Below is my input dataframe:
+---+----------+--------+
|ID |date      |shift_by|
+---+----------+--------+
|1  |2021-01-01|2       |
|1  |2021-02-05|2       |
|1  |2021-03-27|2       |
|2  |2022-02-28|1       |
|2  |2022-04-30|1       |
+---+----------+--------+
I need to groupBy "ID" and shift based on the "shift_by" column. The result should look like the table below:
+---+----------+----------+
|ID |date1     |date2     |
+---+----------+----------+
|1  |2021-01-01|2021-03-27|
|2  |2022-02-28|2022-04-30|
+---+----------+----------+
I have implemented the logic using a UDF, but it makes my code slow. I would like to understand whether this logic can be implemented without a UDF.
Below is a sample dataframe of the expected output:
import datetime
from pyspark.sql.types import StructType, StructField, IntegerType, DateType

data2 = [(1, datetime.date(2021, 1, 1), datetime.date(2021, 3, 27)),
         (2, datetime.date(2022, 2, 28), datetime.date(2022, 4, 30))]

schema = StructType([
    StructField("ID", IntegerType(), True),
    StructField("date1", DateType(), True),
    StructField("date2", DateType(), True),
])

df = spark.createDataFrame(data=data2, schema=schema)
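For completeness, the input dataframe from the first table could be built the same way, reusing the imports above. A minimal sketch (input_data, input_schema and input_df are names made up for illustration):

# input dataframe matching the first table above
input_data = [(1, datetime.date(2021, 1, 1), 2),
              (1, datetime.date(2021, 2, 5), 2),
              (1, datetime.date(2021, 3, 27), 2),
              (2, datetime.date(2022, 2, 28), 1),
              (2, datetime.date(2022, 4, 30), 1)]

input_schema = StructType([
    StructField("ID", IntegerType(), True),
    StructField("date", DateType(), True),
    StructField("shift_by", IntegerType(), True),
])

input_df = spark.createDataFrame(data=input_data, schema=input_schema)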
CodePudding user response:
Based on the comments and chat, you can calculate the first() and last() values of the lat/lon fields of concern over a window:
import sys

import pyspark.sql.functions as func
from pyspark.sql.window import Window as wd

# (-sys.maxsize, sys.maxsize) makes the window span the whole partition,
# so first()/last() see every row for the id
data_sdf. \
    withColumn('foo_first', func.first('foo').over(wd.partitionBy('id').orderBy('date').rowsBetween(-sys.maxsize, sys.maxsize))). \
    withColumn('foo_last', func.last('foo').over(wd.partitionBy('id').orderBy('date').rowsBetween(-sys.maxsize, sys.maxsize))). \
    select('id', 'foo_first', 'foo_last'). \
    dropDuplicates()
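Applied back to the question's columns, the same pattern reduces to the earliest and latest date per ID. A minimal sketch, assuming the input_df built in the question section above and the imports from the snippet just shown:

# unbounded window per ID, ordered by date
w = wd.partitionBy('ID').orderBy('date').rowsBetween(wd.unboundedPreceding, wd.unboundedFollowing)

input_df. \
    withColumn('date1', func.first('date').over(w)). \
    withColumn('date2', func.last('date').over(w)). \
    select('ID', 'date1', 'date2'). \
    dropDuplicates()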
Or, you can create structs and take min/max:
data_sdf = spark.createDataFrame(
    [(1, '2021-01-01', 2, 2),
     (1, '2021-02-05', 3, 2),
     (1, '2021-03-27', 4, 2),
     (2, '2022-02-28', 1, 5),
     (2, '2022-04-30', 5, 1)],
    ['ID', 'date', 'lat', 'lon'])

data_sdf. \
    withColumn('dt_lat_lon_struct', func.struct('date', 'lat', 'lon')). \
    groupBy('id'). \
    agg(func.min('dt_lat_lon_struct').alias('min_dt_lat_lon_struct'),
        func.max('dt_lat_lon_struct').alias('max_dt_lat_lon_struct')
        ). \
    selectExpr('id',
               'min_dt_lat_lon_struct.lat as lat_first', 'min_dt_lat_lon_struct.lon as lon_first',
               'max_dt_lat_lon_struct.lat as lat_last', 'max_dt_lat_lon_struct.lon as lon_last'
               )
# +---+---------+---------+--------+--------+
# | id|lat_first|lon_first|lat_last|lon_last|
# +---+---------+---------+--------+--------+
# |  1|        2|        2|       4|       2|
# |  2|        1|        5|       5|       1|
# +---+---------+---------+--------+--------+
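The struct trick works because Spark compares structs field by field in declaration order, so putting date first means min()/max() pick the row with the earliest/latest date and carry the remaining fields along with it. A tiny sketch with made-up values to illustrate (demo_sdf is just an illustrative name):

demo_sdf = spark.createDataFrame(
    [(1, '2021-01-01', 99),
     (1, '2021-03-27', 0)],
    ['ID', 'date', 'lat'])

demo_sdf. \
    withColumn('s', func.struct('date', 'lat')). \
    groupBy('ID'). \
    agg(func.min('s').alias('mn'), func.max('s').alias('mx')). \
    selectExpr('ID', 'mn.lat as lat_first', 'mx.lat as lat_last'). \
    show()

# min/max are decided by 'date' (the first struct field), so lat_first comes
# out as 99 (the value on the earliest date) even though 99 > 0:
# +---+---------+--------+
# | ID|lat_first|lat_last|
# +---+---------+--------+
# |  1|       99|       0|
# +---+---------+--------+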
CodePudding user response:
Aggregation using min and max should work in your case:
from pyspark.sql import functions as F

df = spark.createDataFrame(
    [(1, '2021-01-01', 2),
     (1, '2021-02-05', 2),
     (1, '2021-03-27', 2),
     (2, '2022-02-28', 1),
     (2, '2022-04-30', 1)],
    ['ID', 'date', 'shift_by'])

df = df.groupBy('ID').agg(
    F.min('date').alias('date1'),
    F.max('date').alias('date2'),
)
df.show()
# +---+----------+----------+
# | ID|     date1|     date2|
# +---+----------+----------+
# |  1|2021-01-01|2021-03-27|
# |  2|2022-02-28|2022-04-30|
# +---+----------+----------+
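As a side note: if, as in the first answer, you also need to pull other columns from the earliest/latest row rather than just the date itself, F.min_by / F.max_by can replace the struct trick. A sketch only, assuming Spark 3.3+ and the data_sdf from the first answer:

data_sdf.groupBy('ID').agg(
    F.min_by('lat', 'date').alias('lat_first'),
    F.max_by('lat', 'date').alias('lat_last'),
)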