PySpark - Selecting all rows within each group

Time:05-17

I have a dataframe similar to below.

from datetime import date
rdd = sc.parallelize([
     [123,date(2007,1,31),1],
     [123,date(2007,2,28),1],
     [123,date(2007,3,31),1],
     [123,date(2007,4,30),1],
     [123,date(2007,5,31),1],
     [123,date(2007,6,30),1],
     [123,date(2007,7,31),1],
     [123,date(2007,8,31),1],
     [123,date(2007,8,31),2],
     [123,date(2007,9,30),1],
     [123,date(2007,9,30),2],
     [123,date(2007,10,31),1],
     [123,date(2007,10,31),2],
     [123,date(2007,11,30),1],
     [123,date(2007,11,30),2],
     [123,date(2007,12,31),1],
     [123,date(2007,12,31),2],
     [123,date(2007,12,31),3],
     [123,date(2008,1,31),1],
     [123,date(2008,1,31),2],
     [123,date(2008,1,31),3]
])

df = rdd.toDF(['id','sale_date','sale'])
df.show()

From the above dataframe, I would like to keep only the most recent sale for each date, so that each date appears in exactly one row. For the example above, the output would look like:

rdd_out = sc.parallelize([
        [123,date(2007,1,31),1],
        [123,date(2007,2,28),1],
        [123,date(2007,3,31),1],
        [123,date(2007,4,30),1],
        [123,date(2007,5,31),1],
        [123,date(2007,6,30),1],
        [123,date(2007,7,31),1],
        [123,date(2007,8,31),2],
        [123,date(2007,9,30),2],
        [123,date(2007,10,31),2],
        [123,date(2007,11,30),2],
        [123,date(2007,12,31),2],
        [123,date(2008,1,31),3]
         ])

df_out = rdd_out.toDF(['id','sale_date','sale'])
df_out.show()

Can you please advise how I can get to this result?

As an FYI, using SAS I would have achieved this result as follows:

proc sort data = df; 
   by id date sale;
run;

data want; 
 set df;
 by id date sale;
 if last.date;
run;

CodePudding user response:

There are probably many ways to achieve this, but one way is to use a Window. With a Window you can partition your data on one or more columns (in your case id and sale_date) and then order the data within each partition by a specific column (in your case descending on sale, so that the latest sale comes first). So:

from pyspark.sql.window import Window
from pyspark.sql.functions import desc
# Partition by id as well as sale_date so that rows from different ids are not mixed.
my_window = Window.partitionBy("id", "sale_date").orderBy(desc("sale"))

You can then apply this Window to your DataFrame together with one of the many window functions. One of them is row_number, which numbers the rows within each partition according to your orderBy. Like this:

from pyspark.sql.functions import row_number
df_out = df.withColumn("row_number",row_number().over(my_window))

As a result, the last sale for each date gets row_number = 1. If you then filter on row_number == 1, you get the last sale for each group.

So, the full code:

from pyspark.sql.window import Window
from pyspark.sql.functions import row_number, desc, col

# One partition per (id, sale_date), highest sale first.
my_window = Window.partitionBy("id", "sale_date").orderBy(desc("sale"))
df_out = (
    df
    .withColumn("row_number", row_number().over(my_window))
    .filter(col("row_number") == 1)   # keep the top row of each partition
    .drop("row_number")
)
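To make the window logic concrete without a Spark session, here is a plain-Python sketch of the same three steps: partition, order descending, keep row number 1. It is only an illustration of the idea, not how Spark executes it. The function name keep_latest_sale and the shortened sample rows are made up for this example.

```python
from datetime import date
from itertools import groupby

# Hypothetical sample rows shaped like the question's data: (id, sale_date, sale).
rows = [
    (123, date(2007, 7, 31), 1),
    (123, date(2007, 8, 31), 1),
    (123, date(2007, 8, 31), 2),
    (123, date(2007, 12, 31), 1),
    (123, date(2007, 12, 31), 2),
    (123, date(2007, 12, 31), 3),
]

def keep_latest_sale(rows):
    """Mimic partitionBy(id, sale_date) + orderBy(desc(sale)) + row_number == 1."""
    key = lambda r: (r[0], r[1])  # the partition key: (id, sale_date)
    result = []
    # groupby needs its input sorted by the same key it groups on.
    for _, partition in groupby(sorted(rows, key=key), key=key):
        # Order the partition by sale, descending, and number the rows from 1.
        ordered = sorted(partition, key=lambda r: r[2], reverse=True)
        numbered = enumerate(ordered, start=1)
        # Keep only the row whose row number is 1.
        result.extend(row for n, row in numbered if n == 1)
    return result

latest = keep_latest_sale(rows)
for row in latest:
    print(row)
```

Each (id, sale_date) pair contributes exactly one row, the one with the highest sale, which is what the filter on row_number does in the Spark version.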

CodePudding user response:

Here's a non-window example of the same thing. @Cleared's answer is excellent, but this answer would likely perform better on very large data sets than using a window. In my experience, windows are slower than the logically equivalent groupBy. (Feel free to test what works better for you.) Windows are simple to write and easy to understand, so they are likely the better choice if the data is small.

In the example below, you would replace "Department" with sale_date and "Salary" with sale.

from pyspark.sql import SparkSession
import pyspark.sql.functions as f

spark = SparkSession.builder.appName('SparkExample').getOrCreate()

data = [("James","Sales",3000),("Michael","Sales",4600),
      ("Robert","Sales",4100),("Maria","Finance",3000),
      ("Raman","Finance",3000),("Scott","Finance",3300),
      ("Jen","Finance",3900),("Jeff","Marketing",3000),
      ("Kumar","Marketing",2000)]

df = spark.createDataFrame(data, ["Name", "Department", "Salary"])

# Make a struct with all the record elements. Structs are compared field by
# field; Dept is constant within a group, so Salary decides the sort order.
unGroupedDf = df.select(
    df["Department"],
    f.struct(
        df["Department"].alias("Dept"),
        df["Salary"].alias("Salary"),
        df["Name"].alias("Name"),
    ).alias("record"),
)

(
    unGroupedDf
    .groupBy("Department")                           # group
    .agg(f.collect_list("record").alias("record"))   # gather all the elements in a group
    .select(
        f.reverse(                 # make the sort descending
            f.array_sort(          # sort the array ascending
                f.col("record")    # the array of structs
            )
        )[0].alias("record")       # grab the max element in the array
    )
    .select(f.col("record.*"))     # use the struct fields as columns
    .show()
)
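The reason sorting the collected structs works can be shown with a plain-Python analogy: tuples, like structs, compare field by field from left to right. This is only a sketch of the idea, not Spark code. The function name top_record_per_dept and the shortened sample data are made up for this example.

```python
# Hypothetical records mirroring the struct fields (Dept, Salary, Name).
records = [
    ("Sales", 3000, "James"),
    ("Sales", 4600, "Michael"),
    ("Sales", 4100, "Robert"),
    ("Finance", 3000, "Maria"),
    ("Finance", 3900, "Jen"),
]

def top_record_per_dept(records):
    # Analogue of collect_list: gather every record under its department.
    groups = {}
    for rec in records:
        groups.setdefault(rec[0], []).append(rec)
    # Analogue of array_sort + reverse + [0]: Dept is constant inside a
    # group, so the second field, Salary, decides which record comes first.
    return {dept: sorted(recs, reverse=True)[0] for dept, recs in groups.items()}

top = top_record_per_dept(records)
for dept, rec in top.items():
    print(dept, rec)
```

If two records in a group share the highest Salary, the next field (Name here) breaks the tie, so the field order inside the struct matters.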

NOTE: If you do not specify a partitionBy with a window it will ship all data to one node to be processed. This would be a performance issue.
