UDF lookup mapping a pyspark dataframe column


I have a pyspark.sql.dataframe.DataFrame object df that contains Continent and Country codes. I also have a dictionary of dictionaries, dicts, which contains the lookup values for each column.

import pyspark.sql.functions as F
import pyspark.sql.types as T


df = sc.parallelize([('A1','JP'),('A1','CH'),('A2','CA'),
   ('A2','US')]).toDF(['Continent','Country'])

dicts = sc.broadcast(dict([('Country', dict([
                          ('US', 'USA'), 
                          ('JP', 'Japan'),
                          ('CA', 'Canada'),
                          ('CH', 'China')
              ])),
              ('Continent', dict([
                          ('A1','Asia'), 
                          ('A2','America')])
              )
              ]))

+---------+-------+
|Continent|Country|
+---------+-------+
|       A1|     JP|
|       A1|     CH|
|       A2|     CA|
|       A2|     US|
+---------+-------+

I want to replace both Country and Continent with their lookup values. This is what I have tried:

preprocess_request = F.udf(lambda colname, key: 
                       dicts.value[colname].get[key], 
                      T.StringType())
df.withColumn('Continent', preprocess_request('Continent', F.col('Continent')))\
.withColumn('Country', preprocess_request('Country', F.col('Country')))\
.display()

but it raised an error saying the object is not subscriptable.

What I expect is exactly this:

+---------+-------+
|Continent|Country|
+---------+-------+
|     Asia|  Japan|
|     Asia|  China|
|  America| Canada|
|  America|    USA|
+---------+-------+

CodePudding user response:

There is a problem with the arguments to your function: when you pass 'Continent' directly, it is treated as a column name, not a fixed value, so when your UDF is called, the value of that column is passed in rather than the word Continent. To fix this, wrap 'Continent' and 'Country' in F.lit. (Note also that get must be called, as in .get(key), not subscripted as .get[key] - subscripting the method is what raised the "not subscriptable" error.)

preprocess_request = F.udf(lambda colname, key: 
                       dicts.value.get(colname, {}).get(key), 
                      T.StringType())
df.withColumn('Continent', preprocess_request(F.lit('Continent'), F.col('Continent')))\
.withColumn('Country', preprocess_request(F.lit('Country'), F.col('Country')))\
.display()

With this change it gives the correct result:

+---------+-------+
|Continent|Country|
+---------+-------+
|     Asia|  Japan|
|     Asia|  China|
|  America| Canada|
|  America|    USA|
+---------+-------+

But you really don't need a UDF for this, as UDFs are very slow due to serialization overhead. It can be much faster if you use native PySpark APIs and represent the dictionaries as Spark map literals. Something like this:

continents = F.expr("map('A1','Asia', 'A2','America')")
countries = F.expr("map('US', 'USA', 'JP', 'Japan', 'CA', 'Canada', 'CH', 'China')")
df.withColumn('Continent', continents[F.col('Continent')])\
.withColumn('Country', countries[F.col('Country')])\
.show()

gives you the same answer, but should be much faster:

+---------+-------+
|Continent|Country|
+---------+-------+
|     Asia|  Japan|
|     Asia|  China|
|  America| Canada|
|  America|    USA|
+---------+-------+

CodePudding user response:

I would use a pandas UDF instead of a plain UDF, since pandas UDFs are vectorized.

Option 1

from typing import Iterator

import pandas as pd


def map_dict(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    # Map each batch's columns through the broadcast lookup dicts
    for pdf in iterator:
        yield pdf.assign(Continent=pdf.Continent.map(dicts.value['Continent']),
                         Country=pdf.Country.map(dicts.value['Country']))

df.mapInPandas(map_dict, schema=df.schema).show()
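For intuition, the per-batch transformation that mapInPandas applies is just pandas Series.map. Here is the same logic as a standalone pandas sketch, with the data and lookups from the question inlined as plain Python (no Spark session needed):

```python
import pandas as pd

# Plain-Python stand-in for dicts.value from the question
lookups = {
    'Continent': {'A1': 'Asia', 'A2': 'America'},
    'Country': {'US': 'USA', 'JP': 'Japan', 'CA': 'Canada', 'CH': 'China'},
}

pdf = pd.DataFrame({'Continent': ['A1', 'A1', 'A2', 'A2'],
                    'Country': ['JP', 'CH', 'CA', 'US']})

# Series.map replaces each value by its dict lookup (NaN if missing)
mapped = pdf.assign(Continent=pdf.Continent.map(lookups['Continent']),
                    Country=pdf.Country.map(lookups['Country']))
print(mapped)
```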

Option 2

Please note, though, that this option is likely to incur a shuffle because of the groupby.

import pandas as pd


def map_dict(pdf: pd.DataFrame) -> pd.DataFrame:
    # Same per-group mapping through the broadcast lookup dicts
    return pdf.assign(Continent=pdf.Continent.map(dicts.value['Continent']),
                      Country=pdf.Country.map(dicts.value['Country']))

df.groupby("Continent","Country").applyInPandas(map_dict, schema=df.schema).show()

+---------+-------+
|Continent|Country|
+---------+-------+
|     Asia|  China|
|     Asia|  Japan|
|  America| Canada|
|  America|    USA|
+---------+-------+