I have a PySpark DataFrame that looks like this:
df = spark.createDataFrame(
    [(0, 'foo'),
     (0, 'bar'),
     (0, 'foo'),
     (0, None),
     (1, 'bar'),
     (1, 'foo'),
     (2, None),
     (2, None)],
    ['group', 'value'])
df.show()
+-----+-----+
|group|value|
+-----+-----+
|    0|  foo|
|    0|  bar|
|    0|  foo|
|    0| null|
|    1|  bar|
|    1|  foo|
|    2| null|
|    2| null|
+-----+-----+
I would like to add rows for each variant of column value within each group (column group), and then fill each additional row with that variant. As @samkart mentioned, since there are 4 rows with group 0, there should be 4 foo and 4 bar values within group 0. None values should not be counted as additional variants, but groups that contain only None values should keep None as their value, so that the result looks like this:
+-----+-----+
|group|value|
+-----+-----+
|    0|  foo|
|    0|  foo|
|    0|  foo|
|    0|  foo|
|    0|  bar|
|    0|  bar|
|    0|  bar|
|    0|  bar|
|    1|  bar|
|    1|  bar|
|    1|  foo|
|    1|  foo|
|    2| null|
|    2| null|
+-----+-----+
I experimented with counting the variants and then exploding the rows with

df = df.withColumn("n", func.expr("explode(array_repeat(n, int(n)))"))

but I can't figure out a way to fill in the variant values in the desired way.
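For context, a runnable sketch of that attempt (assuming `func` is `pyspark.sql.functions` and `n` is a per-group row count computed over a window; both are assumptions, since the question doesn't show how `n` was built) looks like this:

from pyspark.sql import functions as func
from pyspark.sql.window import Window

# count the rows per group and duplicate each row that many times
df = df.withColumn("n", func.count("group").over(Window.partitionBy("group")).cast("int"))
df = df.withColumn("n", func.expr("explode(array_repeat(n, n))"))
# the rows do get multiplied, but every copy still carries its original `value`,
# so the variants are not spread across the added rows
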
CodePudding user response:
You're close. Here's a working example using your input data.
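The snippets below use `data_sdf` for the input DataFrame and assume the functions module and Window spec are imported under the aliases the code uses:

from pyspark.sql import functions as func
from pyspark.sql.window import Window as wd
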
# count rows per group (nulls included), keep one row per non-null (group, value)
# pair, then repeat each variant group_count times and explode back into rows
data_sdf. \
    withColumn('group_count',
               func.count('group').over(wd.partitionBy('group')).cast('int')
               ). \
    filter(func.col('value').isNotNull()). \
    dropDuplicates(). \
    withColumn('new_val_arr', func.expr('array_repeat(value, group_count)')). \
    selectExpr('group', 'explode(new_val_arr) as value'). \
    show()
# +-----+-----+
# |group|value|
# +-----+-----+
# | 0| foo|
# | 0| foo|
# | 0| foo|
# | 0| foo|
# | 0| bar|
# | 0| bar|
# | 0| bar|
# | 0| bar|
# | 1| bar|
# | 1| bar|
# | 1| foo|
# | 1| foo|
# +-----+-----+
EDIT - The question was updated to keep null values as-is for groups where all values are null. There are two ways to do this.

Filter out the nulls, then append the records for the groups whose values are all null:
# per group: total row count and the number of null values
data2_sdf = data_sdf. \
    withColumn('group_count',
               func.count('group').over(wd.partitionBy('group')).cast('int')
               ). \
    withColumn('null_count',
               func.sum(func.col('value').isNull().cast('int')).over(wd.partitionBy('group'))
               )
# groups with at least one non-null value get the repeat-and-explode treatment;
# all-null groups are appended back unchanged via the union
data2_sdf. \
    filter(func.col('group_count') != func.col('null_count')). \
    filter(func.col('value').isNotNull()). \
    dropDuplicates(). \
    withColumn('new_val_arr', func.expr('array_repeat(value, group_count)')). \
    selectExpr('group', 'explode(new_val_arr) as value'). \
    unionByName(data2_sdf.
                filter(func.col('group_count') == func.col('null_count')).
                select('group', 'value')
                ). \
    show()
# +-----+-----+
# |group|value|
# +-----+-----+
# | 0| foo|
# | 0| foo|
# | 0| foo|
# | 0| foo|
# | 0| bar|
# | 0| bar|
# | 0| bar|
# | 0| bar|
# | 1| bar|
# | 1| bar|
# | 1| foo|
# | 1| foo|
# | 2| null|
# | 2| null|
# +-----+-----+
Or, create an array of the unique values per group and explode it:
# keep non-null rows plus the rows of all-null groups (nulls encoded as the string 'null'),
# collect the distinct values per group, repeat the whole set group_count times,
# explode, and finally turn the 'null' placeholder back into a real null
data_sdf. \
    withColumn('group_count',
               func.count('group').over(wd.partitionBy('group')).cast('int')
               ). \
    withColumn('null_count',
               func.sum(func.col('value').isNull().cast('int')).over(wd.partitionBy('group'))
               ). \
    filter(func.col('value').isNotNull() | (func.col('group_count') == func.col('null_count'))). \
    groupBy('group', 'group_count'). \
    agg(func.collect_set(func.coalesce('value', func.lit('null'))).alias('val_set')). \
    withColumn('new_val_arr', func.expr('flatten(array_repeat(val_set, group_count))')). \
    selectExpr('group', 'explode(new_val_arr) as value'). \
    withColumn('value', func.when(func.col('value') != 'null', func.col('value'))). \
    show()
# +-----+-----+
# |group|value|
# +-----+-----+
# | 0| bar|
# | 0| foo|
# | 0| bar|
# | 0| foo|
# | 0| bar|
# | 0| foo|
# | 0| bar|
# | 0| foo|
# | 1| bar|
# | 1| foo|
# | 1| bar|
# | 1| foo|
# | 2| null|
# | 2| null|
# +-----+-----+
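Note that this variant interleaves the values, since the whole set is repeated group_count times. If identical values should sit together within each group, a sort can be appended. A minimal sketch, assuming the chain above (without the final show()) has been assigned to a hypothetical result_sdf:

# hypothetical: result_sdf holds the chain above without the final .show()
# sort so that repeated values sit together within each group; the relative
# order of the variants may differ from the earlier output
result_sdf. \
    orderBy('group', func.col('value').desc_nulls_last()). \
    show()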