Home > Mobile >  PySpark - Dictionary to Dataframe
PySpark - Dictionary to Dataframe

Time:08-21

I'm looking for the most elegant and efficient way to convert a dictionary to a Spark DataFrame with PySpark, given the input and expected output below.

Input :

data = {"key1" : ["val1", "val2", "val3"], "key2" : ["val3", "val4", "val5"]}

Output :

 vals  |  keys
------------
"val1" | ["key1"]
"val2" | ["key1"]
"val3" | ["key1", "key2"]
"val4" | ["key2"]
"val5" | ["key2"]

Edit: I prefer to do most of the manipulation in Spark — for example, by first converting it to:

 vals  |  keys
------------
"val1" | "key1"
"val2" | "key1"
"val3" | "key1"
"Val3" | "key2"
"val4" | "key2"
"val5" | "key2"

CodePudding user response:

First, construct the Spark DataFrame from the dictionary items. Then explode `vals`, group by each value, and collect every key that contains that value.

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

# Reuse the active SparkSession (e.g. the `spark` object of the pyspark shell)
# or start one, so the snippet also runs as a standalone script.
spark = SparkSession.builder.getOrCreate()

data = {"key1": ["val1", "val2", "val3"], "key2": ["val3", "val4", "val5"]}

# One row per dictionary entry: keys (string) | vals (array of strings).
df = spark.createDataFrame(list(data.items()), ["keys", "vals"])

result = (
    df.withColumn("vals", F.explode("vals"))  # one row per (key, value) pair
    .groupBy("vals")
    # collect_set drops duplicate keys (a key that lists the same value
    # twice); array_sort makes the key order deterministic, because the
    # order of collected elements depends on row order after a shuffle.
    .agg(F.array_sort(F.collect_set("keys")).alias("keys"))
    .orderBy("vals")  # groupBy output has no guaranteed row order
)
result.show()

"""
+----+------------+
|vals|        keys|
+----+------------+
|val1|      [key1]|
|val2|      [key1]|
|val3|[key1, key2]|
|val4|      [key2]|
|val5|      [key2]|
+----+------------+
"""

CodePudding user response:

# Input mapping: each key owns a list of values.
data = dict(
    key1=["val1", "val2", "val3"],
    key2=["val3", "val4", "val5"],
)

# Every (key, values) item becomes one row with columns `keys` | `vals`.
column_names = ("keys", "vals")
df = spark.createDataFrame(data.items(), column_names)
df  # notebook-style display of the DataFrame

from pyspark.sql.functions import *
from pyspark.sql.types import *



_COMPLEX_TYPES = (ArrayType, StructType, MapType)


def _first_complex_field(df):
    """Return the first top-level array/struct/map field of ``df``, or ``None``."""
    for field in df.schema.fields:
        if isinstance(field.dataType, _COMPLEX_TYPES):
            return field
    return None


def flatten_test(df, sep="_"):
    """Return a flattened DataFrame with no array, struct or map columns.

    The first complex top-level column (in schema order) is expanded, then
    the schema is re-inspected, until no complex column is left:

    * ``StructType``: every field becomes a column ``<column><sep><field>``
      and the struct column is dropped.
    * ``ArrayType``: the column is exploded in place with ``explode_outer``,
      so each element becomes its own row; a null or empty array yields one
      row with a null value.
    * ``MapType``: every distinct key found in the column becomes a column
      ``<column><sep><key>`` (columns added in sorted key order) holding
      that key's value; the map column is dropped.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
        DataFrame to flatten.
    sep : str
        Delimiter between the parent and child names of flattened columns.
        Default ``"_"``.

    Returns
    -------
    pyspark.sql.DataFrame
        The flattened DataFrame. Exploding arrays can change the row count.

    Notes
    -----
    Don't use ``.`` as ``sep``: Spark reads a ``.`` in a column name as
    nested-field access, so flattening more than one level breaks and the
    resulting columns have to be referenced with backticks.

    Flattening a MapType column runs a Spark job to collect every distinct
    key in that column, which can be slow on large data.

    Examples
    --------

    data_mixed = [
        {
            "state": "Florida",
            "shortname": "FL",
            "info": {"governor": "Rick Scott"},
            "counties": [
                {"name": "Dade", "population": 12345},
                {"name": "Broward", "population": 40000},
                {"name": "Palm Beach", "population": 60000},
            ],
        },
        {
            "state": "Ohio",
            "shortname": "OH",
            "info": {"governor": "John Kasich"},
            "counties": [
                {"name": "Summit", "population": 1234},
                {"name": "Cuyahoga", "population": 1337},
            ],
        },
    ]

    data_mixed = spark.createDataFrame(data=data_mixed)

    data_mixed.printSchema()

    root
    |-- counties: array (nullable = true)
    |    |-- element: map (containsNull = true)
    |    |    |-- key: string
    |    |    |-- value: string (valueContainsNull = true)
    |-- info: map (nullable = true)
    |    |-- key: string
    |    |-- value: string (valueContainsNull = true)
    |-- shortname: string (nullable = true)
    |-- state: string (nullable = true)

    data_mixed_flat = flatten_test(data_mixed, sep=":")
    data_mixed_flat.printSchema()

    root
    |-- shortname: string (nullable = true)
    |-- state: string (nullable = true)
    |-- counties:name: string (nullable = true)
    |-- counties:population: string (nullable = true)
    |-- info:governor: string (nullable = true)


    data = [
        {
            "id": 1,
            "name": "Cole Volk",
            "fitness": {"height": 130, "weight": 60},
        },
        {"name": "Mark Reg", "fitness": {"height": 130, "weight": 60}},
        {
            "id": 2,
            "name": "Faye Raker",
            "fitness": {"height": 130, "weight": 60},
        },
    ]

    df = spark.createDataFrame(data=data)

    df.printSchema()

    root
    |-- fitness: map (nullable = true)
    |    |-- key: string
    |    |-- value: long (valueContainsNull = true)
    |-- id: long (nullable = true)
    |-- name: string (nullable = true)

    df_flat = flatten_test(df, sep=":")

    df_flat.printSchema()

    root
    |-- id: long (nullable = true)
    |-- name: string (nullable = true)
    |-- fitness:height: long (nullable = true)
    |-- fitness:weight: long (nullable = true)


    data_struct = [
        (("James", None, "Smith"), "OH", "M"),
        (("Anna", "Rose", ""), "NY", "F"),
        (("Julia", "", "Williams"), "OH", "F"),
        (("Maria", "Anne", "Jones"), "NY", "M"),
        (("Jen", "Mary", "Brown"), "NY", "M"),
        (("Mike", "Mary", "Williams"), "OH", "M"),
    ]

    schema = StructType([
        StructField('name', StructType([
            StructField('firstname', StringType(), True),
            StructField('middlename', StringType(), True),
            StructField('lastname', StringType(), True)
        ])),
        StructField('state', StringType(), True),
        StructField('gender', StringType(), True)
    ])

    df_struct = spark.createDataFrame(data=data_struct, schema=schema)

    df_struct.printSchema()

    root
    |-- name: struct (nullable = true)
    |    |-- firstname: string (nullable = true)
    |    |-- middlename: string (nullable = true)
    |    |-- lastname: string (nullable = true)
    |-- state: string (nullable = true)
    |-- gender: string (nullable = true)

    df_struct_flat = flatten_test(df_struct, sep=":")

    df_struct_flat.printSchema()

    root
    |-- state: string (nullable = true)
    |-- gender: string (nullable = true)
    |-- name:firstname: string (nullable = true)
    |-- name:middlename: string (nullable = true)
    |-- name:lastname: string (nullable = true)
    """
    while True:
        field = _first_complex_field(df)
        if field is None:
            return df
        name, dtype = field.name, field.dataType

        if isinstance(dtype, StructType):
            # One column per struct field. getField() takes the field name
            # literally, unlike col("parent.child") which re-parses dots.
            expanded = [
                col(name).getField(child.name).alias(name + sep + child.name)
                for child in dtype.fields
            ]
            df = df.select("*", *expanded).drop(name)

        elif isinstance(dtype, ArrayType):
            # explode_outer (not explode) keeps rows whose array is null/empty.
            df = df.withColumn(name, explode_outer(name))

        else:  # MapType -- the only remaining complex type
            # Discover keys with plain explode: explode_outer would add a None
            # "key" for null/empty maps, and None cannot name a column.
            keys_df = df.select(explode(map_keys(col(name)))).distinct()
            # Sort for a deterministic column order; key=str also copes with
            # non-string key types.
            keys = sorted((row[0] for row in keys_df.collect()), key=str)
            key_cols = [
                col(name).getItem(key).alias(f"{name}{sep}{key}") for key in keys
            ]
            df = df.select([c for c in df.columns if c != name] + key_cols)

# Flatten the `keys` | `vals` frame built above: the `vals` array column is
# exploded, giving one (keys, vals) row per value.
df_flat = flatten_test(df)

# .show() prints the rows in scripts too (a bare expression only renders in notebooks).
df_flat.show()
keys vals
key1 val1
key1 val2
key1 val3
key2 val3
key2 val4
key2 val5
  • Related