How to remap and filter MapType keys in PySpark based on a Python dict?

Time:08-08

I have a PySpark df with this schema:

root
 |-- name: string (nullable = true)
 |-- products: struct (nullable = true)
 |    |-- product_hist: map (nullable = true)
 |    |    |-- key: string
 |    |    |-- value: integer (valueContainsNull = true)
 |    |-- tot_visits: long (nullable = true)

Example of a row:

Mary, {{A -> 2000, B -> 100, C -> 250}, 4}

Given a Python dict:

my_dict = {'A': 1, 'C': 2}

I'd like to change the keys in the MapType field using the Python dict and filter out any keys that are not in the dict. I'd then get:

Mary, {{1 -> 2000, 2 -> 250}, 4} 

What's the best way to do that?

Answer:

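Since Spark 3.1 you can use transform_keys to rename map keys and map_filter to drop the keys you don't need. Column.withField, also available from 3.1, replaces the product_hist field inside the struct without rebuilding the whole struct. The code below combines all three. The filter runs first, so every key that reaches transform_keys has an entry in the lookup map.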

from pyspark.sql import functions as F
df = spark.createDataFrame(
    [('Mary', ({'A': 2000, 'B': 100, 'C': 250}, 4))],
    'name string, products struct<product_hist:map<string,bigint>,tot_visits:bigint>')
my_dict = {'A': 1, 'C': 2}

map_col = F.create_map([F.lit(x) for i in my_dict.items() for x in i])
df = df.withColumn(
    'products',
    F.col('products').withField(
        'product_hist',
        F.transform_keys(
            F.map_filter(
                'products.product_hist',
                lambda k, v: k.isin([*my_dict.keys()])
            ),
            lambda k, v: map_col[k]
        )
    )
)
df.show(truncate=0)
# +----+--------------------------+
# |name|products                  |
# +----+--------------------------+
# |Mary|{{1 -> 2000, 2 -> 250}, 4}|
# +----+--------------------------+
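
For checking expectations on a single row, the same logic can be sketched in plain Python. This is only a reference for the semantics, not part of the Spark job; the helper name remap_and_filter and the variable row_hist are made up for illustration. It also shows the flattened key/value list that the comprehension passed to F.create_map produces.

```python
from itertools import chain

my_dict = {'A': 1, 'C': 2}

def remap_and_filter(hist, mapping):
    """Hypothetical helper: keep only keys present in mapping, then rename them."""
    return {mapping[k]: v for k, v in hist.items() if k in mapping}

# Same flattening as the list comprehension fed to F.create_map:
# [key1, value1, key2, value2, ...]
flat_pairs = list(chain.from_iterable(my_dict.items()))

row_hist = {'A': 2000, 'B': 100, 'C': 250}
result = remap_and_filter(row_hist, my_dict)

print(flat_pairs)  # ['A', 1, 'C', 2]
print(result)      # {1: 2000, 2: 250}
```

The dict comprehension filters and renames in one pass. The Spark version splits this into map_filter followed by transform_keys because each function does only one of the two jobs.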

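The physical plan stays a single Project over the scan. Both map_filter and transform_keys run as native Spark expressions, with no Python UDF and no explode or group-by step: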

df.explain()
# == Physical Plan ==
# Project [name#469, if (isnull(products#470)) null else named_struct(product_hist, transform_keys(map_filter(products#470.product_hist, lambdafunction(lambda x_50#475 IN (A,C), lambda x_50#475, lambda y_51#476L, false)), lambdafunction(map(keys: [A,C], values: [1,2])[lambda x_52#477], lambda x_52#477, lambda y_53#478L, false)), tot_visits, products#470.tot_visits) AS products#473]
# +- *(1) Scan ExistingRDD[name#469,products#470]