I have a data frame like this:
+------------+-----------------+------------------------------------+
| Name       | Age             | Answers                            |
+------------+-----------------+------------------------------------+
| Maria      | 23              | [apple, mango, orange, banana]     |
| John       | 55              | [apple, orange, banana]            |
| Brad       | 44              | [banana]                           |
| Alex       | 55              | [apple, mango, orange, banana]     |
+------------+-----------------+------------------------------------+
The "Answers" column contains an array of elements.
My expected output:
+-----+---+--------+-------+
| Name|Age|  answer| value |
+-----+---+--------+-------+
|Maria| 23|   apple|  True |
|Maria| 23|   mango|  True |
|Maria| 23|  orange|  True |
|Maria| 23|  banana|  True |
| John| 55|   apple|  True |
| John| 55|   mango| False |
| John| 55|  orange|  True |
| John| 55|  banana|  True |
| Brad| 44|   apple| False |
| Brad| 44|   mango| False |
| Brad| 44|  orange| False |
| Brad| 44|  banana|  True |
|Alex | 55|   apple|  True |
|Alex | 55|   mango|  True |
|Alex | 55|  orange|  True |
|Alex | 55|  banana|  True |
+-----+---+--------+-------+
How can I explode the "Answers" column in such a way that I would get the "value" column with True or False based on the array?
For example, in the row
| John| 55|   mango| False |
the value is False because "mango" is not among John's answers. Similarly, Brad will have three False rows.
Answer:
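Before exploding, collect all distinct values that occur in the "Answers" column. Add them to every row as a new array column, explode that column, and then check for each exploded value whether it is present in the row's original "Answers" array.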
Input:
from pyspark.sql import SparkSession, functions as F

# Reuse the active session if there is one, otherwise start a local one
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [('Maria', 23, ['apple', 'mango', 'orange', 'banana']),
     ('John', 55, ['apple', 'orange', 'banana']),
     ('Brad', 44, ['banana']),
     ('Alex', 55, ['apple', 'mango', 'orange', 'banana'])],
    ['Name', 'Age', 'Answers'])
Script:
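# Collect every distinct value that appears in any "Answers" array (as a Python set on the driver)
unique_answers = set(df.agg(F.flatten(F.collect_set('Answers'))).head()[0])
# Produce one row per (person, possible answer) pair
df = df.withColumn('answer', F.explode(F.array([F.lit(x) for x in unique_answers])))
df = df.select(
    'Name', 'Age', 'answer',
    # True if the exploded value occurs in the row's original array (F.exists needs Spark 3.1+)
    F.exists('Answers', lambda x: x == F.col('answer')).alias('value'),
    # Pass through any other columns the dataframe may have
    *[c for c in df.columns if c not in {'Name', 'Age', 'Answers', 'answer'}]
)
df.show()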
# +-----+---+------+-----+
# | Name|Age|answer|value|
# +-----+---+------+-----+
# |Maria| 23|orange| true|
# |Maria| 23| mango| true|
# |Maria| 23| apple| true|
# |Maria| 23|banana| true|
# | John| 55|orange| true|
# | John| 55| mango|false|
# | John| 55| apple| true|
# | John| 55|banana| true|
# | Brad| 44|orange|false|
# | Brad| 44| mango|false|
# | Brad| 44| apple|false|
# | Brad| 44|banana| true|
# | Alex| 55|orange| true|
# | Alex| 55| mango| true|
# | Alex| 55| apple| true|
# | Alex| 55|banana| true|
# +-----+---+------+-----+
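As a cross-check of the logic, here is a minimal plain-Python sketch of the same three steps: collect the distinct answers, pair each row with every one of them, and test membership. It is not Spark code and is only meant to show what the script computes. The `rows` list and the `expand_answers` helper are hypothetical names introduced for illustration, and the distinct answers are sorted here so that the order is stable, which the Spark script above does not guarantee.

```python
# Hypothetical in-memory copy of the question's data: (Name, Age, Answers)
rows = [
    ("Maria", 23, ["apple", "mango", "orange", "banana"]),
    ("John", 55, ["apple", "orange", "banana"]),
    ("Brad", 44, ["banana"]),
    ("Alex", 55, ["apple", "mango", "orange", "banana"]),
]


def expand_answers(rows):
    """Return one (name, age, answer, value) tuple per row and possible answer."""
    # Step 1: collect every distinct answer across all rows (sorted for a stable order)
    unique_answers = sorted({a for _, _, answers in rows for a in answers})
    # Steps 2 and 3: pair each row with every possible answer and
    # record whether that answer occurs in the row's own list
    return [
        (name, age, answer, answer in answers)
        for name, age, answers in rows
        for answer in unique_answers
    ]


result = expand_answers(rows)
for row in result:
    print(row)
```

Each person ends up with one row per distinct answer, so the four people and four distinct answers give sixteen rows, with False wherever the person's list lacks that answer.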