How to explode array column to produce a boolean column in PySpark

Time:11-02

I have a data frame like this:

+------------+-----------------+------------------------------------+
| Name       |   Age           | Answers                            |
+------------+-----------------+------------------------------------+
| Maria      | 23              | [apple, mango, orange, banana]     |
| John       | 55              | [apple, orange, banana]            |
| Brad       | 44              | [banana]                           |
| Alex       | 55              | [apple, mango, orange, banana]     |
+------------+-----------------+------------------------------------+

The "Answers" column contains an array of elements.

My expected output:

+-----+---+--------+-------+
| Name|Age|  answer| value |
+-----+---+--------+-------+
|Maria| 23|   apple| True  |
|Maria| 23|   mango| True  |
|Maria| 23|  orange| True  |
|Maria| 23|  banana| True  |
| John| 55|   apple| True  |
| John| 55|   mango| False |
| John| 55|  orange| True  |
| John| 55|  banana| True  |
| Brad| 44|   apple| False |
| Brad| 44|   mango| False |
| Brad| 44|  orange| False |
| Brad| 44|  banana| True  |
| Alex| 55|   apple| True  |
| Alex| 55|   mango| True  |
| Alex| 55|  orange| True  |
| Alex| 55|  banana| True  |
+-----+---+--------+-------+

How can I explode the "Answers" column in such a way that I would get the "value" column with True or False based on the array?

For example,

| John| 55|   mango| False |

there is no "mango" in John's answers, so the value is False. Similarly, Brad will have three False rows (apple, mango and orange).

CodePudding user response:

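Before exploding, you could collect all distinct values that appear in the "Answers" column. Add them to every row of the dataframe as an array, explode that array, and then check each exploded value against the row's original "Answers" array while selecting the required columns.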
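To make the idea concrete without Spark, here is a minimal plain-Python sketch of the same three steps on the question's data. The names `rows`, `unique_answers` and `exploded` are only illustrative, and the sort is there just to get a stable order.

```python
# Sample data from the question: (Name, Age, Answers)
rows = [
    ("Maria", 23, ["apple", "mango", "orange", "banana"]),
    ("John", 55, ["apple", "orange", "banana"]),
    ("Brad", 44, ["banana"]),
    ("Alex", 55, ["apple", "mango", "orange", "banana"]),
]

# Step 1: collect every distinct answer (sorted only for a stable order).
unique_answers = sorted({a for _, _, answers in rows for a in answers})

# Steps 2 and 3: pair each row with every distinct answer ("explode"),
# then record whether that answer is in the row's own list.
exploded = [
    (name, age, answer, answer in answers)
    for name, age, answers in rows
    for answer in unique_answers
]

for row in exploded:
    print(row)
```

Every person ends up with one row per distinct answer, which is why John gets a False row for "mango" even though "mango" never appears in his array.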

Input:

from pyspark.sql import SparkSession, functions as F
spark = SparkSession.builder.getOrCreate()  # reuses the active session if one exists
df = spark.createDataFrame(
    [('Maria', 23, ['apple', 'mango', 'orange', 'banana']),
     ('John', 55, ['apple', 'orange', 'banana']),
     ('Brad', 44, ['banana']),
     ('Alex', 55, ['apple', 'mango', 'orange', 'banana'])],
    ['Name', 'Age', 'Answers'])

Script:

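# Collect every distinct value that appears in any "Answers" array (this runs a Spark job).
unique_answers = set(df.agg(F.flatten(F.collect_set('Answers'))).head()[0])
# Attach all distinct values to each row and explode: one row per (person, answer).
df = df.withColumn('answer', F.explode(F.array([F.lit(x) for x in unique_answers])))
df = df.select(
    'Name', 'Age', 'answer',
    # True when the exploded answer is present in the row's original array.
    F.exists('Answers', lambda x: x == F.col('answer')).alias('value'),
    # Carry over any other columns the dataframe may have.
    *[c for c in df.columns if c not in {'Name', 'Age', 'Answers', 'answer'}]
)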
df.show()
# +-----+---+------+-----+
# | Name|Age|answer|value|
# +-----+---+------+-----+
# |Maria| 23|orange| true|
# |Maria| 23| mango| true|
# |Maria| 23| apple| true|
# |Maria| 23|banana| true|
# | John| 55|orange| true|
# | John| 55| mango|false|
# | John| 55| apple| true|
# | John| 55|banana| true|
# | Brad| 44|orange|false|
# | Brad| 44| mango|false|
# | Brad| 44| apple|false|
# | Brad| 44|banana| true|
# | Alex| 55|orange| true|
# | Alex| 55| mango| true|
# | Alex| 55| apple| true|
# | Alex| 55|banana| true|
# +-----+---+------+-----+