Pivot column based on aggregation result

Time:10-25

I have a PySpark dataframe:

userid sku action
123 2345 2
123 2345 0
123 5422 0
123 7622 0
231 4322 2
231 4322 0
231 8342 0
231 5342 0

The output should look like this:

userid sku_pos sku_neg
123 2345 5422
123 2345 7622
231 4322 8342
231 4322 5342

For each distinct "userid", every "sku" that has no "action" > 0 goes to column "sku_neg", while every "sku" that has an "action" > 0 goes to column "sku_pos".

Answer:

Create positive and negative dataframes by filtering pos/neg records and grouping by "userid":

from pyspark.sql import functions as F

df_pos = df \
  .filter(F.col("action") > 0) \
  .groupBy("userid") \
  .agg(F.collect_set("sku").alias("sku_pos_list")) \
  .withColumnRenamed("userid", "userid_pos")

[Out]:
+----------+------------+
|userid_pos|sku_pos_list|
+----------+------------+
|       123|      [2345]|
|       231|      [4322]|
+----------+------------+


df_neg = df \
  .filter(F.col("action") <= 0) \
  .groupBy("userid") \
  .agg(F.collect_set("sku").alias("sku_neg_list")) \
  .withColumnRenamed("userid", "userid_neg")

[Out]:
+----------+------------------+
|userid_neg|      sku_neg_list|
+----------+------------------+
|       123|[2345, 5422, 7622]|
|       231|[8342, 4322, 5342]|
+----------+------------------+
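Note that both filters work per row, not per sku: 2345 has one row with action 2 and one with action 0, so it ends up in user 123's negative list as well as the positive one. A small plain-Python sketch of that row-level split (not Spark code; the sets stand in for `collect_set`) makes the overlap visible:

```python
from collections import defaultdict

# Sample rows from the question: (userid, sku, action)
rows = [
    (123, 2345, 2), (123, 2345, 0), (123, 5422, 0), (123, 7622, 0),
    (231, 4322, 2), (231, 4322, 0), (231, 8342, 0), (231, 5342, 0),
]

# Row-level split, mirroring filter(action > 0) / filter(action <= 0) + collect_set
pos_lists, neg_lists = defaultdict(set), defaultdict(set)
for userid, sku, action in rows:
    (pos_lists if action > 0 else neg_lists)[userid].add(sku)

print(sorted(neg_lists[123]))           # [2345, 5422, 7622]
print(pos_lists[123] & neg_lists[123])  # {2345}: a positive sku also sits in the negative list
```

This is why the positive skus have to be removed from the pairs again after joining.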

Join the positive and negative dataframes back together on "userid" and explode the pos/neg lists:

df_joined = df_pos.join(df_neg, (F.col("userid_pos")==F.col("userid_neg")), how="full")

# Merge the two userid columns into one, and replace a missing pos/neg list
# (a user with no positive or no negative rows) with a [-1] placeholder
df_joined = df_joined \
  .withColumn("userid", F.when(F.col("userid_pos").isNotNull(), F.col("userid_pos")).otherwise(F.col("userid_neg"))) \
  .drop("userid_pos", "userid_neg") \
  .withColumn("sku_pos_list", F.when(F.col("sku_pos_list").isNull(), F.array([F.lit(-1)])).otherwise(F.col("sku_pos_list"))) \
  .withColumn("sku_neg_list", F.when(F.col("sku_neg_list").isNull(), F.array([F.lit(-1)])).otherwise(F.col("sku_neg_list")))

[Out]:
+------------+------------------+------+
|sku_pos_list|sku_neg_list      |userid|
+------------+------------------+------+
|[2345]      |[2345, 5422, 7622]|123   |
|[4322]      |[8342, 4322, 5342]|231   |
+------------+------------------+------+


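# Remove positive skus from the negative list (F.array_except, Spark 2.4+),
# then explode both lists into one row per (sku_pos, sku_neg) pair
df_joined = df_joined \
  .withColumn("sku_neg_list", F.array_except("sku_neg_list", "sku_pos_list")) \
  .withColumn("sku_pos", F.explode("sku_pos_list")) \
  .withColumn("sku_neg", F.explode("sku_neg_list")) \
  .drop("sku_pos_list", "sku_neg_list")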

[Out]:
+------+-------+-------+
|userid|sku_pos|sku_neg|
+------+-------+-------+
|   123|   2345|   5422|
|   123|   2345|   7622|
|   231|   4322|   8342|
|   231|   4322|   5342|
+------+-------+-------+
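The two explode calls produce every (sku_pos, sku_neg) combination for a user, and the negative list may still contain positive skus. Dropping only the pairs where sku_pos == sku_neg is enough when a user has a single positive sku, but not when a user has several. The plain-Python sketch below uses a hypothetical user (skus 1, 2 and 3 are made up, not from the question) to compare that filter with a set difference taken before exploding, which is what `F.array_except` computes on arrays:

```python
from itertools import product

# Hypothetical user with TWO positive skus (1 and 2), both of which also have
# action == 0 rows, plus one purely negative sku (3).
sku_pos_list = [1, 2]     # from filter(action > 0)
sku_neg_list = [1, 2, 3]  # from filter(action <= 0): row-level, so 1 and 2 reappear

# Two explodes behave like a cross product of the two lists.
pairs = list(product(sku_pos_list, sku_neg_list))

# Dropping only equal pairs keeps (1, 2) and (2, 1), which pair
# two positive skus with each other.
filtered = [(p, n) for p, n in pairs if p != n]

# Removing positive skus from the negative list first leaves
# only genuinely negative skus to pair with.
neg_only = [n for n in sku_neg_list if n not in sku_pos_list]
fixed = list(product(sku_pos_list, neg_only))

print(filtered)  # [(1, 2), (1, 3), (2, 1), (2, 3)]
print(fixed)     # [(1, 3), (2, 3)]
```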

Dataset used:

df = spark.createDataFrame([
  (123,2345,2),
  (123,2345,0),
  (123,5422,0),
  (123,7622,0),
  (231,4322,2),
  (231,4322,0),
  (231,8342,0),
  (231,5342,0),
], ["userid", "sku", "action"])

Answer:

The other proposed solution seems perfectly fine, but just in case, here is another approach that does not need a join. Note that I assume there is only one sku_pos per userid; if that's not the case, this won't work.

from pyspark.sql import functions as f

# sku_pos is taken only from the positive group, so f.first cannot pick a negative sku
spark.read.option("header", "true").csv("sku")\
    .withColumn("action", f.col("action") > 0)\
    .groupBy("userid", "sku")\
    .agg(f.max("action").alias("action"))\
    .groupBy("userid", "action")\
    .agg(f.collect_set("sku").alias("skus"))\
    .withColumn("sku_pos", f.when(f.col("action"), f.col("skus").getItem(0)))\
    .withColumn("sku_neg", f.when(~ f.col("action"), f.col("skus")))\
    .groupBy("userid")\
    .agg(f.first("sku_pos", ignorenulls=True).alias("sku_pos"), f.first("sku_neg", ignorenulls=True).alias("sku_neg"))\
    .withColumn("sku_neg", f.explode("sku_neg"))\
    .show()

+------+-------+-------+
|userid|sku_pos|sku_neg|
+------+-------+-------+
|   123|   2345|   5422|
|   123|   2345|   7622|
|   231|   4322|   5342|
|   231|   4322|   8342|
+------+-------+-------+

The idea is to first group by ("userid", "sku") to decide whether each sku is positive, then group by ("userid", "action") to collect the positive and negative skus separately. Then f.col("skus").getItem(0) selects a single sku_pos, another groupBy brings the data back to one row per userid, and finally the sku_neg array is exploded.
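To see the one-sku_pos-per-userid assumption in action, here is a plain-Python mirror of the same steps (a sketch, not Spark code; `one_pos_per_user` is a hypothetical helper name). Python lists keep insertion order, whereas Spark's `collect_set` has no guaranteed order, so in Spark `getItem(0)` picks an arbitrary positive sku:

```python
from collections import defaultdict

def one_pos_per_user(rows):
    """Plain-Python mirror of the join-free pipeline (hypothetical helper)."""
    # groupBy("userid", "sku") + max(action > 0): is any action of this sku positive?
    flag = defaultdict(bool)
    for userid, sku, action in rows:
        flag[(userid, sku)] = flag[(userid, sku)] or action > 0

    # groupBy("userid", "action") + collect_set("sku"): split skus by flag
    skus = defaultdict(list)
    for (userid, sku), is_pos in flag.items():
        skus[(userid, is_pos)].append(sku)

    # groupBy("userid") + getItem(0) + explode("sku_neg"):
    # keep ONE positive sku and pair it with each negative sku
    out = []
    for userid in sorted({u for u, _ in skus}):
        pos = skus.get((userid, True))
        sku_pos = pos[0] if pos else None
        for sku_neg in skus.get((userid, False), []):
            out.append((userid, sku_pos, sku_neg))
    return out

rows = [
    (123, 2345, 2), (123, 2345, 0), (123, 5422, 0), (123, 7622, 0),
    (231, 4322, 2), (231, 4322, 0), (231, 8342, 0), (231, 5342, 0),
]
print(one_pos_per_user(rows))

# Hypothetical user 999 with two positive skus (1 and 2): only one of them survives
print(one_pos_per_user([(999, 1, 5), (999, 2, 1), (999, 3, 0)]))  # [(999, 1, 3)]
```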

Answer:

Two aggregations are needed:

  • first, assign a pos/neg status to each ("userid", "sku") pair
  • then pivot on this status in a second aggregation to collect the "sku" values into pos and neg lists

Finally, explode the lists.
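The same three steps can be sketched in plain Python (an illustration of the logic, not Spark code). The per-user dict with 'p' and 'n' keys plays the role of the pivoted columns, and `itertools.product` plays the role of the two explodes:

```python
from collections import defaultdict
from itertools import product

rows = [
    ('123', '2345', 2), ('123', '2345', 0), ('123', '5422', 0), ('123', '7622', 0),
    ('231', '4322', 2), ('231', '4322', 0), ('231', '8342', 0), ('231', '5342', 0),
]

# 1st aggregation: one flag per (userid, sku): 'p' if max(action) > 0, else 'n'
max_action = {}
for userid, sku, action in rows:
    key = (userid, sku)
    max_action[key] = max(max_action.get(key, action), action)
flags = {key: 'p' if m > 0 else 'n' for key, m in max_action.items()}

# 2nd aggregation: group by userid and "pivot" the flag into two lists
pivoted = defaultdict(lambda: {'p': [], 'n': []})
for (userid, sku), flag in flags.items():
    pivoted[userid][flag].append(sku)

# Explode both lists: every positive sku is paired with every negative sku;
# a user with an empty list produces no rows, just as explode drops them
result = [
    (userid, p, n)
    for userid, lists in pivoted.items()
    for p, n in product(lists['p'], lists['n'])
]
print(result)
```

Because the flag is computed per sku before any lists are built, 2345 never reaches the negative list, so no extra filtering is needed afterwards.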

Input:

from pyspark.sql import functions as F
df = spark.createDataFrame(
    [('123', '2345', 2),
     ('123', '2345', 0),
     ('123', '5422', 0),
     ('123', '7622', 0),
     ('231', '4322', 2),
     ('231', '4322', 0),
     ('231', '8342', 0),
     ('231', '5342', 0)],
    ['userid', 'sku', 'action'])

Script:

df = df.groupBy('userid', 'sku').agg(
    F.when(F.max('action') > 0, 'p').otherwise('n').alias('_flag')
)
df = (df
    .groupBy('userid').pivot('_flag', ['p', 'n']).agg(F.collect_list('sku'))
    .withColumn('sku_pos', F.explode('p'))
    .withColumn('sku_neg', F.explode('n'))
    .drop('p', 'n')
)

df.show()
# +------+-------+-------+
# |userid|sku_pos|sku_neg|
# +------+-------+-------+
# |   231|   4322|   5342|
# |   231|   4322|   8342|
# |   123|   2345|   7622|
# |   123|   2345|   5422|
# +------+-------+-------+