I have a dataframe like this:
column_1      column_2
['a','c']     1
['b','c']     2
['a','b','c'] 1
Now I want to add 3 columns (a, b and c), based on the frequency of occurrence.
Desired output:
a b c column_2
1 0 1 1
0 1 1 2
1 1 1 1
CodePudding user response:
Assuming your pyspark dataframe is:
df.show()
+---------+--------+
| column_1|column_2|
+---------+--------+
|   [a, c]|       1|
|   [b, c]|       2|
|[a, b, c]|       1|
+---------+--------+
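(For reference, a minimal way to build this sample DataFrame, assuming an active SparkSession named spark:)
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [(['a','c'], 1),
     (['b','c'], 2),
     (['a','b','c'], 1)],
    ['column_1', 'column_2']
)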
You can first explode column column_1:
import pyspark.sql.functions as F
df_1 = df.withColumn("explode_col1", F.explode("column_1"))
df_1.show()
+---------+--------+------------+
| column_1|column_2|explode_col1|
+---------+--------+------------+
|   [a, c]|       1|           a|
|   [a, c]|       1|           c|
|   [b, c]|       2|           b|
|   [b, c]|       2|           c|
|[a, b, c]|       1|           a|
|[a, b, c]|       1|           b|
|[a, b, c]|       1|           c|
+---------+--------+------------+
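One caveat not covered above: explode drops rows whose array is empty or null. If that can happen in your data and you still want those rows (with all-zero counts), explode_outer is the variant to reach for, e.g.:
df_1 = df.withColumn("explode_col1", F.explode_outer("column_1"))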
Then use groupby and pivot to count the elements (and keep only the columns you want):
df_2 = df_1.groupby('column_1', 'column_2').pivot('explode_col1').count().na.fill(0).drop('column_1')
df_2.show()
+--------+---+---+---+
|column_2|  a|  b|  c|
+--------+---+---+---+
|       1|  1|  1|  1|
|       1|  1|  0|  1|
|       2|  0|  1|  1|
+--------+---+---+---+
And if you want to have it all in one line:
df.withColumn("explode_col1", F.explode("column_1")) \
    .groupby('column_1', 'column_2').pivot('explode_col1').count() \
    .na.fill(0) \
    .drop('column_1') \
    .show()
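As a side note, when the pivot values are known in advance you can pass them explicitly, which lets Spark skip the extra job it otherwise runs to collect the distinct values. A sketch under that assumption (same df as above):
df.withColumn("explode_col1", F.explode("column_1")) \
    .groupby('column_1', 'column_2').pivot('explode_col1', ['a', 'b', 'c']).count() \
    .na.fill(0) \
    .drop('column_1') \
    .show()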
CodePudding user response:
Assuming you know the names of the columns you will create beforehand (so you can store them in a list), the following approaches do it without shuffling.
If you just need to know whether the array contains the value:
from pyspark.sql import functions as F
df = spark.createDataFrame(
    [(['a','c'], 1),
     (['b','c'], 2),
     (['a','b','c'], 1)],
    ['column_1', 'column_2']
)
cols = ['a', 'b', 'c']
arr_cols = F.array([F.lit(x) for x in cols])
arr_vals = F.transform(arr_cols, lambda c: F.array_contains('column_1', c).cast('int'))
df = df.select(
    *[F.element_at(arr_vals, i + 1).alias(c) for i, c in enumerate(cols)],  # element_at is 1-based
    'column_2'
)
df.show()
# +---+---+---+--------+
# |  a|  b|  c|column_2|
# +---+---+---+--------+
# |  1|  0|  1|       1|
# |  0|  1|  1|       2|
# |  1|  1|  1|       1|
# +---+---+---+--------+
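Since cols already holds the names, the same 0/1 result can also be sketched with one array_contains expression per name (this is an equivalent formulation, not the code above, and it must be applied to the original df before the select overwrites it):
df.select(
    *[F.array_contains('column_1', c).cast('int').alias(c) for c in cols],
    'column_2'
).show()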
If you need to know the count of occurrences:
from pyspark.sql import functions as F
df = spark.createDataFrame(
    [(['a','c'], 1),
     (['b','c'], 2),
     (['a','a','b','c'], 1)],
    ['column_1', 'column_2']
)
cols = ['a', 'b', 'c']
arr_cols = F.array([F.lit(x) for x in cols])
arr_vals = F.transform(arr_cols, lambda c: F.size(F.array_remove(F.transform('column_1', lambda v: v == c), False)))
df = df.select(
    *[F.element_at(arr_vals, i + 1).alias(c) for i, c in enumerate(cols)],  # element_at is 1-based
    'column_2'
)
df.show()
# +---+---+---+--------+
# |  a|  b|  c|column_2|
# +---+---+---+--------+
# |  1|  0|  1|       1|
# |  0|  1|  1|       2|
# |  2|  1|  1|       1|
# +---+---+---+--------+
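If you are on Spark 3.1+, an alternative sketch for the counts uses the higher-order filter function instead of transform plus array_remove (again applied to the original df before it is reassigned):
df.select(
    *[F.size(F.filter('column_1', lambda v: v == F.lit(c))).alias(c) for c in cols],
    'column_2'
).show()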