Create new columns based on frequency of array from one column


I have a dataframe like this:

     column_1     column_2
    ['a','c']            1
    ['b','c']            2
['a','b','c']            1

Now I want to add three columns (a, b and c), based on the frequency of occurrence.

Desired output:

a   b    c   column_2
1   0    1          1
0   1    1          2
1   1    1          1

CodePudding user response:

Assuming your PySpark DataFrame is:

df.show()

+---------+--------+
| column_1|column_2|
+---------+--------+
|   [a, c]|       1|
|   [b, c]|       2|
|[a, b, c]|       1|
+---------+--------+
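
For reference, a DataFrame like this can be built with spark.createDataFrame (a minimal sketch, mirroring the construction used in the second answer below; the SparkSession variable spark is assumed to exist):

df = spark.createDataFrame(
    [(['a', 'c'], 1),
     (['b', 'c'], 2),
     (['a', 'b', 'c'], 1)],
    ['column_1', 'column_2']
)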

You can first explode the column column_1:

import pyspark.sql.functions as F

df_1 = df.withColumn("explode_col1", F.explode("column_1"))
df_1.show()
+---------+--------+------------+
| column_1|column_2|explode_col1|
+---------+--------+------------+
|   [a, c]|       1|           a|
|   [a, c]|       1|           c|
|   [b, c]|       2|           b|
|   [b, c]|       2|           c|
|[a, b, c]|       1|           a|
|[a, b, c]|       1|           b|
|[a, b, c]|       1|           c|
+---------+--------+------------+

Then use groupby and pivot to count the elements (keeping only the columns you want):

df_2 = df_1.groupby('column_1', 'column_2').pivot('explode_col1').count().na.fill(0).drop('column_1')
df_2.show()
+--------+---+---+---+
|column_2|  a|  b|  c|
+--------+---+---+---+
|       1|  1|  1|  1|
|       1|  1|  0|  1|
|       2|  0|  1|  1|
+--------+---+---+---+
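
As a side note: if the pivot values are known up front (as they are here), you can pass them to pivot explicitly, which saves Spark an extra pass over the data to discover the distinct values (a small sketch based on the same df_1 as above):

df_2 = df_1.groupby('column_1', 'column_2').pivot('explode_col1', ['a', 'b', 'c']).count().na.fill(0).drop('column_1')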

And if you want to have it all in one chained expression:

df.withColumn("explode_col1", F.explode("column_1")).\
groupby('column_1', 'column_2').pivot('explode_col1').count().\
na.fill(0).\
drop('column_1').\
show()

CodePudding user response:

Assuming you know beforehand the names of the columns you will create (so you can store them in a list), the following approaches do it without shuffling.

If you just need to know whether the array contains the value:

from pyspark.sql import functions as F
df = spark.createDataFrame(
    [(['a','c'], 1),
     (['b','c'], 2),
     (['a','b','c'], 1)],
    ['column_1', 'column_2']
)
cols = ['a', 'b', 'c']
arr_cols = F.array([F.lit(x) for x in cols])
arr_vals = F.transform(arr_cols, lambda c: F.array_contains('column_1', c).cast('int'))
df = df.select(
    *[F.element_at(arr_vals, i + 1).alias(c) for i, c in enumerate(cols)],
    'column_2'
)
df.show()
# +---+---+---+--------+
# |  a|  b|  c|column_2|
# +---+---+---+--------+
# |  1|  0|  1|       1|
# |  0|  1|  1|       2|
# |  1|  1|  1|       1|
# +---+---+---+--------+
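
For this yes/no case, the same result can also be written without the intermediate array column, using array_contains directly in a list comprehension (a sketch under the same assumptions; df_in here stands for the input DataFrame created above, before it is overwritten by the select):

from pyspark.sql import functions as F

cols = ['a', 'b', 'c']
df_out = df_in.select(
    # one 0/1 flag per known element name
    *[F.array_contains('column_1', c).cast('int').alias(c) for c in cols],
    'column_2'
)
df_out.show()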

If you need to know the count of occurrences:

from pyspark.sql import functions as F
df = spark.createDataFrame(
    [(['a','c'], 1),
     (['b','c'], 2),
     (['a','a','b','c'], 1)],
    ['column_1', 'column_2']
)
cols = ['a', 'b', 'c']
arr_cols = F.array([F.lit(x) for x in cols])
arr_vals = F.transform(arr_cols, lambda c: F.size(F.array_remove(F.transform('column_1', lambda v: v == c), False)))
df = df.select(
    *[F.element_at(arr_vals, i + 1).alias(c) for i, c in enumerate(cols)],
    'column_2'
)
df.show()
# +---+---+---+--------+
# |  a|  b|  c|column_2|
# +---+---+---+--------+
# |  1|  0|  1|       1|
# |  0|  1|  1|       2|
# |  2|  1|  1|       1|
# +---+---+---+--------+
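
On Spark 3.1+, the count variant can also be expressed with the higher-order filter function, which some may find easier to read. This is a sketch under the same assumptions (df_in again stands for the input DataFrame before the select above):

from pyspark.sql import functions as F

cols = ['a', 'b', 'c']
df_out = df_in.select(
    # count occurrences of each known element by filtering the array
    *[F.size(F.filter('column_1', lambda v: v == c)).alias(c) for c in cols],
    'column_2'
)
df_out.show()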