I want to create a new column in my PySpark dataframe that is equal to one if that row is the last row of its group under some ordering. My solution works but seems very hacky:
My strategy was to create random numbers, use a window function to identify the last random number per sorted group, and finally compare that against each row's value in a case/when:
import sys
from pyspark.sql import functions as F
from pyspark.sql.functions import col, rand
from pyspark.sql.window import Window

df = (df.withColumn('rand', rand())
        .withColumn("marker",
                    F.when(F.last(col('rand'))
                            .over(Window.partitionBy(['bvdidnumber', 'dt_year'])
                                        .orderBy(["dt_rfrnc"])
                                        .rowsBetween(0, sys.maxsize))
                           == col('rand'), 1).otherwise(0))
        .drop('rand'))
The first three columns are the input data, the last column (marker) is created by the code above, and goal is the target to be achieved.
+-----------+-------+--------+----+------+
|bvdidnumber|dt_year|dt_rfrnc|goal|marker|
+-----------+-------+--------+----+------+
|          1|   2020|  202006|   0|     0|
|          1|   2020|  202012|   1|     1|
|          1|   2020|  202012|   0|     0|
|          1|   2021|  202103|   0|     0|
|          1|   2021|  202106|   0|     0|
|          1|   2021|  202112|   1|     1|
|          2|   2020|  202006|   0|     0|
|          2|   2020|  202012|   0|     0|
|          2|   2020|  202012|   1|     1|
|          2|   2021|  202103|   0|     0|
|          2|   2021|  202106|   0|     0|
|          2|   2021|  202112|   1|     1|
+-----------+-------+--------+----+------+
CodePudding user response:
Can't see anything you have done wrong. Maybe just make it neater:
from pyspark.sql import functions as F
from pyspark.sql.window import Window

w = Window.partitionBy('bvdidnumber', 'dt_year').orderBy('bvdidnumber', 'dt_year')
df.withColumn('rand',
              F.when(F.sum((F.last('dt_rfrnc').over(w) == F.col('dt_rfrnc')).cast('int'))
                      .over(w.rowsBetween(-1, 0)) == 1, 1)
               .otherwise(0)).show()
+-----------+-------+--------+----+
|bvdidnumber|dt_year|dt_rfrnc|rand|
+-----------+-------+--------+----+
|          1|   2020|  202006|   0|
|          1|   2020|  202012|   1|
|          1|   2020|  202012|   0|
|          1|   2021|  202103|   0|
|          1|   2021|  202106|   0|
|          1|   2021|  202112|   1|
|          2|   2020|  202006|   0|
|          2|   2020|  202012|   1|
|          2|   2020|  202012|   0|
|          2|   2021|  202103|   0|
|          2|   2021|  202106|   0|
|          2|   2021|  202112|   1|
+-----------+-------+--------+----+
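For comparison, a sketch of a more conventional way to flag the last row per group (this is not from the original answer; it assumes the same dataframe `df` and column names): rank rows within each group by dt_rfrnc descending and mark rank 1. Note that row_number breaks ties arbitrarily, so which of the two duplicate 202012 rows gets marked is not deterministic, but exactly one is, as in the goal column.
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Sketch only: order each (bvdidnumber, dt_year) group by dt_rfrnc descending
# and flag the top-ranked row as the last row of the group.
w_desc = Window.partitionBy('bvdidnumber', 'dt_year').orderBy(F.col('dt_rfrnc').desc())

df_marked = df.withColumn('marker',
                          F.when(F.row_number().over(w_desc) == 1, 1).otherwise(0))
df_marked.show()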
CodePudding user response:
from pyspark.sql.functions import (
    col, struct, collect_list, sort_array, reverse, arrays_zip, expr, explode
)
# create data
schema = ["bvdidnumber","dt_year","dt_rfrnc","goal","marker"]
data = [(1,2020,202006,0,0),
(1,2020,202012,1,1),
(1,2020,202012,0,0),
(1,2021,202103,0,0),
(1,2021,202106,0,0),
(1,2021,202112,1,1),
(2,2020,202006,0,0),
(2,2020,202012,0,0),
(2,2020,202012,1,1),
(2,2021,202103,0,0),
(2,2021,202106,0,0),
(2,2021,202112,1,1),
]
df = spark.createDataFrame( data, schema )
(df
    .select(
        # create a struct to carry all data forward (think: a column of columns);
        # dt_rfrnc must be first so that sorting the struct sorts on this column
        struct(
            col("dt_rfrnc"),
            col("dt_year"),
            col("bvdidnumber")
        ).alias("columns"))
    .groupby(
        col("columns.bvdidnumber"),
        col("columns.dt_year"))
    .agg(
        arrays_zip(  # merge arrays in an indexed manner
            # build an array of a single 1 followed by zeros
            expr("concat(array_repeat(1, 1), array_repeat(0, size(collect_list(columns)) - 1))"),
            reverse(            # sort descending
                sort_array(     # sort ascending
                    collect_list("columns")))  # collect the group's elements
        ).alias("zip"))
    .select(
        col("bvdidnumber"),
        col("dt_year"),
        explode("zip"))  # turn array values into rows
    .select(
        col("bvdidnumber"),
        col("dt_year"),
        # quirky column reference required when working with arrays_zip output
        col("col.`0`").alias("marked"),
        col("col.`1`.dt_rfrnc").alias("dt_rfrnc"))
    .show())
+-----------+-------+------+--------+
|bvdidnumber|dt_year|marked|dt_rfrnc|
+-----------+-------+------+--------+
|          1|   2021|     1|  202112|
|          1|   2021|     0|  202106|
|          1|   2021|     0|  202103|
|          2|   2021|     1|  202112|
|          2|   2021|     0|  202106|
|          2|   2021|     0|  202103|
|          1|   2020|     1|  202012|
|          1|   2020|     0|  202012|
|          1|   2020|     0|  202006|
|          2|   2020|     1|  202012|
|          2|   2020|     0|  202012|
|          2|   2020|     0|  202006|
+-----------+-------+------+--------+
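The key design choice above is that Spark compares structs field by field, so placing dt_rfrnc first in the struct makes sort_array order each collected group by dt_rfrnc. A minimal illustration of that behaviour (a sketch, assuming an active SparkSession named `spark`):
from pyspark.sql.functions import col, collect_list, sort_array, struct

# Tiny demo: structs sort by their first field, so the collected rows
# come back ordered by dt_rfrnc within each id.
demo = spark.createDataFrame([(1, 202012), (1, 202006), (1, 202103)],
                             ["id", "dt_rfrnc"])
demo.groupBy("id").agg(
    sort_array(collect_list(struct(col("dt_rfrnc"), col("id")))).alias("sorted_rows")
).show(truncate=False)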