I have a PySpark dataframe:
userid | sku | action |
---|---|---|
123 | 2345 | 2 |
123 | 2345 | 0 |
123 | 5422 | 0 |
123 | 7622 | 0 |
231 | 4322 | 2 |
231 | 4322 | 0 |
231 | 8342 | 0 |
231 | 5342 | 0 |
The output should be like:
userid | sku_pos | sku_neg |
---|---|---|
123 | 2345 | 5422 |
123 | 2345 | 7622 |
231 | 4322 | 8342 |
231 | 4322 | 5342 |
For each distinct "userid", the "sku" values that don't have an "action" > 0 should go to the "sku_neg" column, while the "sku" values that do have an "action" > 0 should go to the "sku_pos" column.
CodePudding user response:
Create positive and negative dataframes by filtering pos/neg records and grouping by "userid":
from pyspark.sql import functions as F

df_pos = df \
    .filter(F.col("action") > 0) \
    .groupBy("userid") \
    .agg(F.collect_set("sku").alias("sku_pos_list")) \
    .withColumnRenamed("userid", "userid_pos")
[Out]:
+----------+------------+
|userid_pos|sku_pos_list|
+----------+------------+
|       123|      [2345]|
|       231|      [4322]|
+----------+------------+
df_neg = df \
    .filter(F.col("action") <= 0) \
    .groupBy("userid") \
    .agg(F.collect_set("sku").alias("sku_neg_list")) \
    .withColumnRenamed("userid", "userid_neg")
[Out]:
+----------+------------------+
|userid_neg|      sku_neg_list|
+----------+------------------+
|       123|[2345, 5422, 7622]|
|       231|[8342, 4322, 5342]|
+----------+------------------+
Join back the positive and negative dataframes and explode the pos/neg records:
df_joined = df_pos.join(df_neg, F.col("userid_pos") == F.col("userid_neg"), how="full")

# Clean up nulls left by the full join: keep whichever userid is present, and
# replace a missing sku list with a [-1] placeholder so explode keeps the row
df_joined = df_joined \
    .withColumn("userid", F.when(F.col("userid_pos").isNotNull(), F.col("userid_pos")).otherwise(F.col("userid_neg"))) \
    .drop("userid_pos", "userid_neg") \
    .withColumn("sku_pos_list", F.when(F.col("sku_pos_list").isNull(), F.array([F.lit(-1)])).otherwise(F.col("sku_pos_list"))) \
    .withColumn("sku_neg_list", F.when(F.col("sku_neg_list").isNull(), F.array([F.lit(-1)])).otherwise(F.col("sku_neg_list")))
[Out]:
+------------+------------------+------+
|sku_pos_list|sku_neg_list      |userid|
+------------+------------------+------+
|[2345]      |[2345, 5422, 7622]|123   |
|[4322]      |[8342, 4322, 5342]|231   |
+------------+------------------+------+
df_joined = df_joined \
    .withColumn("sku_pos", F.explode("sku_pos_list")) \
    .withColumn("sku_neg", F.explode("sku_neg_list")) \
    .drop("sku_pos_list", "sku_neg_list") \
    .filter(F.col("sku_pos") != F.col("sku_neg"))
[Out]:
+------+-------+-------+
|userid|sku_pos|sku_neg|
+------+-------+-------+
|   123|   2345|   5422|
|   123|   2345|   7622|
|   231|   4322|   8342|
|   231|   4322|   5342|
+------+-------+-------+
Dataset used:
df = spark.createDataFrame([
    (123, 2345, 2),
    (123, 2345, 0),
    (123, 5422, 0),
    (123, 7622, 0),
    (231, 4322, 2),
    (231, 4322, 0),
    (231, 8342, 0),
    (231, 5342, 0),
], ["userid", "sku", "action"])
CodePudding user response:
The other proposed solution seems perfectly fine, but just in case, here is another approach that does not need a join. Note that I assume that there is only one sku_pos per userid. If that's not the case, this won't work.
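(Side note, not from the original answer: the snippet below reads the sample from a CSV file named "sku". If you don't have that file, a minimal stand-in built from the question's data should behave the same; the values are kept as strings to mimic what the CSV reader returns, and the chain would then start from df instead of the spark.read line.)
# hypothetical stand-in for spark.read.option("header", "true").csv("sku")
df = spark.createDataFrame(
    [("123", "2345", "2"), ("123", "2345", "0"),
     ("123", "5422", "0"), ("123", "7622", "0"),
     ("231", "4322", "2"), ("231", "4322", "0"),
     ("231", "8342", "0"), ("231", "5342", "0")],
    ["userid", "sku", "action"])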
from pyspark.sql import functions as f

spark.read.option("header", "true").csv("sku")\
    .withColumn("action", f.col("action") > 0)\
    .groupBy("userid", "sku")\
    .agg(f.max("action").alias("action"))\
    .groupBy("userid", "action")\
    .agg(f.collect_set("sku").alias("skus"))\
    .withColumn("sku_pos", f.col("skus").getItem(0))\
    .withColumn("sku_neg", f.when(~ f.col("action"), f.col("skus")))\
    .groupBy("userid")\
    .agg(f.first("sku_pos").alias("sku_pos"), f.first("sku_neg", ignorenulls=True).alias("sku_neg"))\
    .withColumn("sku_neg", f.explode("sku_neg"))\
    .show()
+------+-------+-------+
|userid|sku_pos|sku_neg|
+------+-------+-------+
|   123|   5422|   5422|
|   123|   5422|   7622|
|   231|   4322|   5342|
|   231|   4322|   8342|
+------+-------+-------+
Basically the idea is first to use a groupBy to collect the positive and negative sku separately. Then I use f.col("skus").getItem(0) to only select one sku_pos, use another groupBy to have one line per userid, and finally explode the sku_neg array.
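One caveat worth flagging (my addition, not part of the original answer): in the run shown above, sku_pos for user 123 came out as 5422 rather than 2345, because getItem(0) is applied to the positive and the negative row alike and f.first("sku_pos") then picks whichever of the two comes first. A hedged tweak, assuming the same input as above, restricts sku_pos to the positive row:
from pyspark.sql import functions as f

fixed = (df
    .withColumn("action", f.col("action") > 0)
    .groupBy("userid", "sku")
    .agg(f.max("action").alias("action"))
    .groupBy("userid", "action")
    .agg(f.collect_set("sku").alias("skus"))
    # keep sku_pos only on the positive row so first() cannot pick it up
    # from the negative row
    .withColumn("sku_pos", f.when(f.col("action"), f.col("skus").getItem(0)))
    .withColumn("sku_neg", f.when(~f.col("action"), f.col("skus")))
    .groupBy("userid")
    .agg(f.first("sku_pos", ignorenulls=True).alias("sku_pos"),
         f.first("sku_neg", ignorenulls=True).alias("sku_neg"))
    .withColumn("sku_neg", f.explode("sku_neg")))
fixed.show()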
CodePudding user response:
A couple of aggregations are needed:
- first, assign a pos/neg status to every "sku"
- then use this status in the second aggregation to collect all "sku" into lists
Finally, explode the lists.
Input:
from pyspark.sql import functions as F
df = spark.createDataFrame(
    [('123', '2345', 2),
     ('123', '2345', 0),
     ('123', '5422', 0),
     ('123', '7622', 0),
     ('231', '4322', 2),
     ('231', '4322', 0),
     ('231', '8342', 0),
     ('231', '5342', 0)],
    ['userid', 'sku', 'action'])
Script:
df = df.groupBy('userid', 'sku').agg(
    F.when(F.max('action') > 0, 'p').otherwise('n').alias('_flag')
)
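For orientation (my addition, not part of the original answer): after this first aggregation every (userid, sku) pair carries a single flag, so with the input above the rows should look roughly like this (row order may differ):
df.show()
# +------+----+-----+
# |userid| sku|_flag|
# +------+----+-----+
# |   123|2345|    p|
# |   123|5422|    n|
# |   123|7622|    n|
# |   231|4322|    p|
# |   231|8342|    n|
# |   231|5342|    n|
# +------+----+-----+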
df = (df
    .groupBy('userid').pivot('_flag', ['p', 'n']).agg(F.collect_list('sku'))
    .withColumn('sku_pos', F.explode('p'))
    .withColumn('sku_neg', F.explode('n'))
    .drop('p', 'n')
)
df.show()
# +------+-------+-------+
# |userid|sku_pos|sku_neg|
# +------+-------+-------+
# |   231|   4322|   5342|
# |   231|   4322|   8342|
# |   123|   2345|   7622|
# |   123|   2345|   5422|
# +------+-------+-------+