Shift rows dynamically based on column value

Below is my input dataframe:

+---+----------+--------+
|ID |date      |shift_by|
+---+----------+--------+
|1  |2021-01-01|2       |
|1  |2021-02-05|2       |
|1  |2021-03-27|2       |
|2  |2022-02-28|1       |
|2  |2022-04-30|1       |
+---+----------+--------+

I need to groupBy "ID" and shift the "date" column based on the "shift_by" column. In the end, the result should look like this:

+---+----------+----------+
|ID |date1     |date2     |
+---+----------+----------+
|1  |2021-01-01|2021-03-27|
|2  |2022-02-28|2022-04-30|
+---+----------+----------+

I have implemented the logic using a UDF, but it makes my code slow. I would like to understand whether this logic can be implemented without a UDF.

Below is a sample dataframe showing the expected output:

import datetime
from pyspark.sql.types import *

data2 = [(1, datetime.date(2021, 1, 1), datetime.date(2021, 3, 27)),
    (2, datetime.date(2022, 2, 28), datetime.date(2022, 4, 30))
]
schema = StructType([
    StructField("ID", IntegerType(), True),
    StructField("date1", DateType(), True),
    StructField("date2", DateType(), True),
])
df = spark.createDataFrame(data=data2, schema=schema)

CodePudding user response:

Based on the comments and chats, you can try calculating the first and last values of the lat/lon fields of concern.

import pyspark.sql.functions as func
from pyspark.sql.window import Window as wd
import sys

# the unbounded frame (rowsBetween(-sys.maxsize, sys.maxsize)) lets first()/last()
# see every row of the 'id' partition, not just the rows up to the current one
data_sdf. \
    withColumn('foo_first', func.first('foo').over(wd.partitionBy('id').orderBy('date').rowsBetween(-sys.maxsize, sys.maxsize))). \
    withColumn('foo_last', func.last('foo').over(wd.partitionBy('id').orderBy('date').rowsBetween(-sys.maxsize, sys.maxsize))). \
    select('id', 'foo_first', 'foo_last'). \
    dropDuplicates()
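
Applied to the ID/date/shift_by dataframe from the question, the same pattern would look roughly like this (a sketch reusing the imports above and assuming the input sits in a dataframe called df with the columns shown in the question):

# unbounded window per ID, ordered by date
w = wd.partitionBy('ID').orderBy('date').rowsBetween(-sys.maxsize, sys.maxsize)

df. \
    withColumn('date1', func.first('date').over(w)). \
    withColumn('date2', func.last('date').over(w)). \
    select('ID', 'date1', 'date2'). \
    dropDuplicates()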

Or, you can create structs and take the min/max:

data_sdf = spark.createDataFrame(
    [(1, '2021-01-01', 2, 2),
     (1, '2021-02-05', 3, 2),
     (1, '2021-03-27', 4, 2),
     (2, '2022-02-28', 1, 5),
     (2, '2022-04-30', 5, 1)],
    ['ID', 'date', 'lat', 'lon'])

# structs compare field by field, so min/max on the struct is effectively driven by 'date' first
data_sdf. \
    withColumn('dt_lat_lon_struct', func.struct('date', 'lat', 'lon')). \
    groupBy('id'). \
    agg(func.min('dt_lat_lon_struct').alias('min_dt_lat_lon_struct'),
        func.max('dt_lat_lon_struct').alias('max_dt_lat_lon_struct')
        ). \
    selectExpr('id', 
               'min_dt_lat_lon_struct.lat as lat_first', 'min_dt_lat_lon_struct.lon as lon_first',
               'max_dt_lat_lon_struct.lat as lat_last', 'max_dt_lat_lon_struct.lon as lon_last'
               )

# +---+---------+---------+--------+--------+
# | id|lat_first|lon_first|lat_last|lon_last|
# +---+---------+---------+--------+--------+
# |  1|        2|        2|       4|       2|
# |  2|        1|        5|       5|       1|
# +---+---------+---------+--------+--------+
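
For the question's ID/date/shift_by dataframe the struct version reduces to picking the earliest and latest date per ID (a sketch reusing the imports above and assuming the input is in df); the struct trick mainly pays off when other columns from the first/last row have to travel along with the date:

df. \
    withColumn('row_struct', func.struct('date', 'shift_by')). \
    groupBy('ID'). \
    agg(func.min('row_struct').alias('first_row'),
        func.max('row_struct').alias('last_row')
        ). \
    selectExpr('ID', 'first_row.date as date1', 'last_row.date as date2')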

CodePudding user response:

Aggregation using min and max seems like it could work in your case.

from pyspark.sql import functions as F
df = spark.createDataFrame(
    [(1, '2021-01-01', 2),
     (1, '2021-02-05', 2),
     (1, '2021-03-27', 2),
     (2, '2022-02-28', 1),
     (2, '2022-04-30', 1)],
    ['ID', 'date', 'shift_by'])

df = df.groupBy('ID').agg(
    F.min('date').alias('date1'),
    F.max('date').alias('date2'),
)
df.show()
# +---+----------+----------+
# | ID|     date1|     date2|
# +---+----------+----------+
# |  1|2021-01-01|2021-03-27|
# |  2|2022-02-28|2022-04-30|
# +---+----------+----------+
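
If the shift really has to follow the shift_by value (rather than always being the last row of the group), one non-UDF sketch is to collect the ordered dates per ID and index into the array with element_at (assumes Spark 2.4+; collect_list over an ordered window is the usual pattern for this, though Spark does not formally guarantee its ordering):

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# unbounded window per ID, ordered by date
w = Window.partitionBy('ID').orderBy('date') \
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

result = df \
    .withColumn('dates', F.collect_list('date').over(w)) \
    .withColumn('date1', F.col('dates')[0]) \
    .withColumn('date2', F.expr('element_at(dates, shift_by + 1)')) \
    .select('ID', 'date1', 'date2') \
    .dropDuplicates()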