I have a PySpark DataFrame with a map column, as shown below:
root
|-- id: long (nullable = true)
|-- map_col: map (nullable = true)
| |-- key: string
| |-- value: double (valueContainsNull = true)
The map_col has keys which need to be converted based on a dictionary. For example, the dictionary might be:
mapping = {'a': '1', 'b': '2', 'c': '5', 'd': '8' }
So, the DataFrame needs to change from:
[Row(id=123, map_col={'a': 0.0, 'b': -42.19}),
Row(id=456, map_col={'a': 13.25, 'c': -19.6, 'd': 15.6})]
to the following:
[Row(id=123, map_col={'1': 0.0, '2': -42.19}),
Row(id=456, map_col={'1': 13.25, '5': -19.6, '8': 15.6})]
I see that transform_keys is an option if I could write out the dictionary, but it is too large and is dynamically generated earlier in the workflow. I think an explode/pivot could also work (roughly sketched below), but that seems non-performant. Any ideas?
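For reference, a minimal sketch of the explode-and-rebuild route I was picturing (assuming Spark 2.4+ for map_from_entries, and the df and mapping shown above):
import pyspark.sql.functions as F

# Explode the map into (id, k, v) rows, translate the keys via the dict,
# then collect the entries back into a map per id.
exploded = df.select("id", F.explode("map_col").alias("k", "v"))
remapped = exploded.replace(mapping, subset=["k"])
rebuilt = remapped.groupBy("id").agg(
    F.map_from_entries(F.collect_list(F.struct("k", "v"))).alias("map_col")
)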
Edit: Added a bit to show that the size of the map in map_col is not uniform.
CodePudding user response:
transform_keys can take a lambda, as shown in the example below; it is not limited to an expr. However, the lambda or Python callable must be built from functions defined in pyspark.sql.functions, Column methods, or a Scala UDF, so a Python UDF that refers to the mapping dictionary object isn't currently possible with this mechanism. We can, however, apply the mapping with the when function, by unrolling the key-value pairs of mapping into chained when conditions. The example below illustrates the idea:
from typing import Dict, Callable
from functools import reduce

from pyspark.sql import SparkSession, Column
from pyspark.sql.functions import when, transform_keys


def apply_mapping(mapping: Dict[str, str]) -> Callable[[Column, Column], Column]:
    def convert_mapping_into_when_conditions(key: Column, _: Column) -> Column:
        # Build when(key == k1, v1).when(key == k2, v2)... from the mapping
        # dict, without mutating it.
        items = list(mapping.items())
        initial_key, initial_value = items[0]
        initial_condition = when(key == initial_key, initial_value)
        return reduce(
            lambda cond, item: cond.when(key == item[0], item[1]),
            items[1:],
            initial_condition,
        )

    return convert_mapping_into_when_conditions


if __name__ == "__main__":
    spark = SparkSession \
        .builder \
        .appName("Temp") \
        .getOrCreate()

    df = spark.createDataFrame([(1, {"foo": -2.0, "bar": 2.0})], ("id", "data"))
    mapping = {'foo': 'a', 'bar': 'b'}

    df.select(
        transform_keys("data", apply_mapping(mapping)).alias("data_transformed")
    ).show(truncate=False)
The output of the above is:
+---------------------+
|data_transformed     |
+---------------------+
|{b -> 2.0, a -> -2.0}|
+---------------------+
which shows that the defined mapping (foo -> a, bar -> b) was successfully applied to the column. The apply_mapping function should be generic enough to copy and use in your own pipeline.
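For example, applying it to the schema in the question would look roughly like this (a sketch, assuming a DataFrame df with columns id and map_col, the mapping dict from the question, and Spark 3.1+ for the Python transform_keys API):
mapping = {'a': '1', 'b': '2', 'c': '5', 'd': '8'}

# Replace map_col in place. The mapping must cover every key that can occur:
# an unmatched key evaluates to NULL and Spark does not allow null map keys.
# Adding an .otherwise(key) fallback to the chained conditions would let
# unmapped keys pass through unchanged instead.
result = df.withColumn("map_col", transform_keys("map_col", apply_mapping(mapping)))
result.show(truncate=False)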