PySpark: get the latest non-null element of every column in one row


Let me explain my question using an example. I have a dataframe:

import pandas as pd

pd_1 = pd.DataFrame({'day': [1, 2, 3, 2, 1, 3],
                     'code': [10, 10, 20, 20, 30, 30],
                     'A': [44, 55, 66, 77, 88, 99],
                     'B': ['a', None, 'c', None, 'd', None],
                     'C': [None, None, '12', None, None, None]
                    })
df_1 = spark.createDataFrame(pd_1)  # assuming an active SparkSession named spark
df_1.show()

Output:

+---+----+---+----+----+
|day|code|  A|   B|   C|
+---+----+---+----+----+
|  1|  10| 44|   a|null|
|  2|  10| 55|null|null|
|  3|  20| 66|   c|  12|
|  2|  20| 77|null|null|
|  1|  30| 88|   d|null|
|  3|  30| 99|null|null|
+---+----+---+----+----+

What I want is a new dataframe in which each row corresponds to one code, and each column holds the most recent non-null value (i.e. the value from the highest day).

In pandas, I can simply do

pd_2 = pd_1.sort_values('day', ascending=True).groupby('code').last()
pd_2.reset_index()

to get

   code  day   A  B     C
0    10    2  55  a  None
1    20    3  66  c    12
2    30    3  99  d  None

My question is: how can I do this in PySpark (preferably in a version < 3)?


What I have tried so far is:

from pyspark.sql import Window
import pyspark.sql.functions as F

# full-partition window, most recent day first
w = Window.partitionBy('code').orderBy(F.desc('day')).rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

# Update: after applying @Steven's idea to remove the for loop.
# collect_list skips nulls, so getItem(0) picks the latest non-null value of each column.
df_1 = df_1.select([F.collect_list(x).over(w).getItem(0).alias(x) for x in df_1.columns])

# for x in df_1.columns:
#     df_1 = df_1.withColumn(x, F.collect_list(x).over(w).getItem(0))

df_1 = df_1.distinct()
df_1.show()

Output:

+---+----+---+---+----+
|day|code|  A|  B|   C|
+---+----+---+---+----+
|  2|  10| 55|  a|null|
|  3|  30| 99|  d|null|
|  3|  20| 66|  c|  12|
+---+----+---+---+----+

I'm not very happy with this, though, especially because of the original for loop.

CodePudding user response:

I think your current solution is quite nice. If you want another solution, you can try using the first/last window functions:

from pyspark.sql import functions as F, Window

w = Window.partitionBy("code").orderBy(F.col("day").desc())


# df is the original dataframe from the question (df_1 before it was overwritten)
df2 = (
    df.select(
        "day",
        "code",
        F.row_number().over(w).alias("rwnb"),
        *(
            F.first(F.col(col), ignorenulls=True)
            .over(w.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))
            .alias(col)
            for col in ("A", "B", "C")
        ),
    )
    .where("rwnb = 1")
    .drop("rwnb")
)

and the result:

df2.show()

+---+----+---+---+----+
|day|code|  A|  B|   C|
+---+----+---+---+----+
|  2|  10| 55|  a|null|
|  3|  30| 99|  d|null|
|  3|  20| 66|  c|  12|
+---+----+---+---+----+
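
For completeness, here is a minimal sketch of the last() variant mentioned above (assuming, as before, that df is the original dataframe from the question): order the window ascending by day, take the last non-null value over the full partition, and deduplicate.

from pyspark.sql import functions as F, Window

# sketch of the last() variant; assumes df is the original dataframe from the question
w_asc = (
    Window.partitionBy("code")
    .orderBy("day")
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
)

df3 = df.select(
    "code",
    F.max("day").over(w_asc).alias("day"),
    # last non-null value over the whole partition, ordered by ascending day
    *[F.last(c, ignorenulls=True).over(w_asc).alias(c) for c in ("A", "B", "C")],
).distinct()

df3.show()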

CodePudding user response:

Here's another way of doing it, using array functions and struct ordering instead of a window. Structs are compared field by field, so array_max over structs whose first field is day returns the entry with the highest day:

from pyspark.sql import functions as F

other_cols = ["day", "A", "B", "C"]

# requires Spark 2.4+ for array_max and the filter higher-order function
df_1 = df_1.groupBy("code").agg(
    F.collect_list(F.struct(*other_cols)).alias("values")
).selectExpr(
    "code",
    *[f"array_max(filter(values, x -> x.{c} is not null))['{c}'] as {c}" for c in other_cols]
)

df_1.show()
#+----+---+---+---+----+
#|code|day|  A|  B|   C|
#+----+---+---+---+----+
#|  10|  2| 55|  a|null|
#|  30|  3| 99|  d|null|
#|  20|  3| 66|  c|  12|
#+----+---+---+---+----+
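
A related sketch, under the same assumption (df_1 is the original dataframe from the question): the struct-ordering trick also works without collect_list, by wrapping each value in a (day, value) struct only when the value is non-null, taking the max per group, and extracting the value back out.

from pyspark.sql import functions as F

# sketch of an equivalent aggregation without collect_list
# (assumes df_1 is the original dataframe from the question)
value_cols = ["A", "B", "C"]

df_2 = df_1.groupBy("code").agg(
    F.max("day").alias("day"),
    *[
        # keep (day, value) only when the value is non-null; max ignores nulls and
        # compares structs by day first, so this yields the latest non-null value
        F.max(F.when(F.col(c).isNotNull(), F.struct("day", c)))[c].alias(c)
        for c in value_cols
    ],
)
df_2.show()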