I'm looking for the most elegant and effective way to convert a dictionary to a Spark DataFrame with PySpark, given the input and output described below.
Input:
data = {"key1" : ["val1", "val2", "val3"], "key2" : ["val3", "val4", "val5"]}
Output:
vals | keys
------------
"val1" | ["key1"]
"val2" | ["key1"]
"val3" | ["key1", "key2"]
"val4" | ["key2"]
"val5" | ["key2"]
Edit: I prefer to do most of the manipulation with Spark, maybe by first converting it to:
vals | keys
------------
"val1" | "key1"
"val2" | "key1"
"val3" | "key1"
"Val3" | "key2"
"val4" | "key2"
"val5" | "key2"
CodePudding user response:
First construct the Spark DataFrame from the dictionary items, then explode the vals column, and finally group by vals and collect all the keys that contain each value.
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
data = {"key1" : ["val1", "val2", "val3"], "key2" : ["val3", "val4", "val5"]}
df = spark.createDataFrame(data.items(), ("keys", "vals"))
(df.withColumn("vals", F.explode("vals"))
   .groupBy("vals")
   .agg(F.collect_list("keys").alias("keys"))
).show()
"""
+----+------------+
|vals|        keys|
+----+------------+
|val1|      [key1]|
|val3|[key1, key2]|
|val2|      [key1]|
|val4|      [key2]|
|val5|      [key2]|
+----+------------+
"""
CodePudding user response:
data = {"key1" : ["val1", "val2", "val3"], "key2" : ["val3", "val4", "val5"]}
df = spark.createDataFrame(data.items(), ("keys", "vals"))
df.show()
from pyspark.sql.functions import *
from pyspark.sql.types import *
def flatten_test(df, sep="_"):
"""Returns a flattened dataframe.
.. versionadded:: x.X.X
Parameters
----------
sep : str
Delimiter for flattened columns. Default `_`.
Notes
-----
Don't use `.` as `sep`:
it won't work on data frames nested more than one level,
and you would have to reference columns as `column.name`.
Flattening MapType columns requires collecting every key in the column,
which can be slow.
Examples
--------
data_mixed = [
{
"state": "Florida",
"shortname": "FL",
"info": {"governor": "Rick Scott"},
"counties": [
{"name": "Dade", "population": 12345},
{"name": "Broward", "population": 40000},
{"name": "Palm Beach", "population": 60000},
],
},
{
"state": "Ohio",
"shortname": "OH",
"info": {"governor": "John Kasich"},
"counties": [
{"name": "Summit", "population": 1234},
{"name": "Cuyahoga", "population": 1337},
],
},
]
data_mixed = spark.createDataFrame(data=data_mixed)
data_mixed.printSchema()
root
|-- counties: array (nullable = true)
| |-- element: map (containsNull = true)
| | |-- key: string
| | |-- value: string (valueContainsNull = true)
|-- info: map (nullable = true)
| |-- key: string
| |-- value: string (valueContainsNull = true)
|-- shortname: string (nullable = true)
|-- state: string (nullable = true)
data_mixed_flat = flatten_test(data_mixed, sep=":")
data_mixed_flat.printSchema()
root
|-- shortname: string (nullable = true)
|-- state: string (nullable = true)
|-- counties:name: string (nullable = true)
|-- counties:population: string (nullable = true)
|-- info:governor: string (nullable = true)
data = [
{
"id": 1,
"name": "Cole Volk",
"fitness": {"height": 130, "weight": 60},
},
{"name": "Mark Reg", "fitness": {"height": 130, "weight": 60}},
{
"id": 2,
"name": "Faye Raker",
"fitness": {"height": 130, "weight": 60},
},
]
df = spark.createDataFrame(data=data)
df.printSchema()
root
|-- fitness: map (nullable = true)
| |-- key: string
| |-- value: long (valueContainsNull = true)
|-- id: long (nullable = true)
|-- name: string (nullable = true)
df_flat = flatten_test(df, sep=":")
df_flat.printSchema()
root
|-- id: long (nullable = true)
|-- name: string (nullable = true)
|-- fitness:height: long (nullable = true)
|-- fitness:weight: long (nullable = true)
data_struct = [
(("James",None,"Smith"),"OH","M"),
(("Anna","Rose",""),"NY","F"),
(("Julia","","Williams"),"OH","F"),
(("Maria","Anne","Jones"),"NY","M"),
(("Jen","Mary","Brown"),"NY","M"),
(("Mike","Mary","Williams"),"OH","M")
]
schema = StructType([
StructField('name', StructType([
StructField('firstname', StringType(), True),
StructField('middlename', StringType(), True),
StructField('lastname', StringType(), True)
])),
StructField('state', StringType(), True),
StructField('gender', StringType(), True)
])
df_struct = spark.createDataFrame(data = data_struct, schema = schema)
df_struct.printSchema()
root
|-- name: struct (nullable = true)
| |-- firstname: string (nullable = true)
| |-- middlename: string (nullable = true)
| |-- lastname: string (nullable = true)
|-- state: string (nullable = true)
|-- gender: string (nullable = true)
df_struct_flat = flatten_test(df_struct, sep=":")
df_struct_flat.printSchema()
root
|-- state: string (nullable = true)
|-- gender: string (nullable = true)
|-- name:firstname: string (nullable = true)
|-- name:middlename: string (nullable = true)
|-- name:lastname: string (nullable = true)
"""
# compute Complex Fields (Arrays, Structs and Maptypes) in Schema
complex_fields = dict(
[
(field.name, field.dataType)
for field in df.schema.fields
if type(field.dataType) == ArrayType
or type(field.dataType) == StructType
or type(field.dataType) == MapType
]
)
while len(complex_fields) != 0:
col_name = list(complex_fields.keys())[0]
# print ("Processing :" col_name " Type : " str(type(complex_fields[col_name])))
# if StructType then convert all sub element to columns.
# i.e. flatten structs
if type(complex_fields[col_name]) == StructType:
expanded = [
col(col_name "." k).alias(col_name sep k)
for k in [n.name for n in complex_fields[col_name]]
]
df = df.select("*", *expanded).drop(col_name)
# if ArrayType then add the Array Elements as Rows using the explode function
# i.e. explode Arrays
elif type(complex_fields[col_name]) == ArrayType:
df = df.withColumn(col_name, explode_outer(col_name))
# if MapType then convert all sub element to columns.
# i.e. flatten
elif type(complex_fields[col_name]) == MapType:
keys_df = df.select(explode_outer(map_keys(col(col_name)))).distinct()
keys = list(map(lambda row: row[0], keys_df.collect()))
key_cols = list(
map(
lambda f: col(col_name).getItem(f).alias(str(col_name sep f)),
keys,
)
)
drop_column_list = [col_name]
df = df.select(
[
col_name
for col_name in df.columns
if col_name not in drop_column_list
]
key_cols
)
# recompute remaining Complex Fields in Schema
complex_fields = dict(
[
(field.name, field.dataType)
for field in df.schema.fields
if type(field.dataType) == ArrayType
or type(field.dataType) == StructType
or type(field.dataType) == MapType
]
)
return df
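As an aside, the MapType branch is the part the docstring flags as slow, because it first has to collect every distinct key in the column. A minimal sketch of that step in isolation, on a hypothetical one-column map DataFrame mdf:
from pyspark.sql.functions import col, explode_outer, map_keys

mdf = spark.createDataFrame([({"a": 1, "b": 2},)], ["m"])  # hypothetical map column
keys = [r[0] for r in mdf.select(explode_outer(map_keys(col("m")))).distinct().collect()]
# one output column per distinct key; column order follows whatever distinct() returns
mdf.select(*[col("m").getItem(k).alias("m_" + k) for k in keys]).show()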
Applying it to the question's df:
df_flat = flatten_test(df)
df_flat.show()
+----+----+
|keys|vals|
+----+----+
|key1|val1|
|key1|val2|
|key1|val3|
|key2|val3|
|key2|val4|
|key2|val5|
+----+----+
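From here you can reuse the groupBy from the first answer to reach the output the question asked for (collect_list comes from the wildcard import of pyspark.sql.functions above; row order in the result is not guaranteed):
df_flat.groupBy("vals").agg(collect_list("keys").alias("keys")).show()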