I have a pyspark.sql.dataframe.DataFrame object df which contains Continent and Country codes.
I also have a dictionary of dictionaries, dicts, which contains the lookup value for each column.
import pyspark.sql.functions as F
import pyspark.sql.types as T

df = sc.parallelize([('A1', 'JP'), ('A1', 'CH'), ('A2', 'CA'),
                     ('A2', 'US')]).toDF(['Continent', 'Country'])

dicts = sc.broadcast({
    'Country': {'US': 'USA', 'JP': 'Japan', 'CA': 'Canada', 'CH': 'China'},
    'Continent': {'A1': 'Asia', 'A2': 'America'},
})
+---------+-------+
|Continent|Country|
+---------+-------+
|       A1|     JP|
|       A1|     CH|
|       A2|     CA|
|       A2|     US|
+---------+-------+
I want to replace both Country and Continent with their lookup values, so I have tried:
preprocess_request = F.udf(lambda colname, key: dicts.value[colname].get[key],
                           T.StringType())

df.withColumn('Continent', preprocess_request('Continent', F.col('Continent')))\
  .withColumn('Country', preprocess_request('Country', F.col('Country')))\
  .display()
but it gives me an error saying the object is not subscriptable.
What I expect is exactly this:
+---------+-------+
|Continent|Country|
+---------+-------+
|     Asia|  Japan|
|     Asia|  China|
|  America| Canada|
|  America|    USA|
+---------+-------+
CodePudding user response:
There is a problem with the arguments to your function: when you specify 'Continent', it is treated as a column name, not a fixed value, so when your UDF is called it is passed the value of that column, not the word Continent. To fix this you need to wrap Continent and Country in F.lit:
preprocess_request = F.udf(lambda colname, key: dicts.value.get(colname, {}).get(key),
                           T.StringType())

df.withColumn('Continent', preprocess_request(F.lit('Continent'), F.col('Continent')))\
  .withColumn('Country', preprocess_request(F.lit('Country'), F.col('Country')))\
  .display()
With this it gives the correct result:
+---------+-------+
|Continent|Country|
+---------+-------+
|     Asia|  Japan|
|     Asia|  China|
|  America| Canada|
|  America|    USA|
+---------+-------+
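As a side note, you can also avoid passing the column name through F.lit entirely by baking it into the UDF with a small factory. A minimal sketch of that variation (my own, not required for the fix; it assumes the same dicts broadcast as above):

def make_lookup_udf(colname):
    # Capture the per-column lookup dict in a closure, so the UDF only
    # needs the value to translate.
    mapping = dicts.value[colname]
    return F.udf(lambda key: mapping.get(key), T.StringType())

df.withColumn('Continent', make_lookup_udf('Continent')(F.col('Continent')))\
  .withColumn('Country', make_lookup_udf('Country')(F.col('Country')))\
  .display()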
But you really don't need a UDF for this, as UDFs are very slow due to serialization overhead. It will be much faster if you use the native PySpark APIs and represent the dictionaries as Spark map literals. Something like this:
continents = F.expr("map('A1','Asia', 'A2','America')")
countries = F.expr("map('US', 'USA', 'JP', 'Japan', 'CA', 'Canada', 'CH', 'China')")
df.withColumn('Continent', continents[F.col('Continent')])\
  .withColumn('Country', countries[F.col('Country')])\
  .show()
gives you the same answer, but should be much faster:
+---------+-------+
|Continent|Country|
+---------+-------+
|     Asia|  Japan|
|     Asia|  China|
|  America| Canada|
|  America|    USA|
+---------+-------+
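If you'd rather not hand-write the map(...) SQL expression, you can build the same map columns from the existing dicts broadcast with F.create_map. A minimal sketch, assuming the same dicts as defined in the question:

import itertools
import pyspark.sql.functions as F

def to_map_col(mapping):
    # Flatten {'US': 'USA', ...} into [lit('US'), lit('USA'), ...] and
    # turn it into a MapType literal column.
    return F.create_map([F.lit(x) for x in itertools.chain(*mapping.items())])

continents = to_map_col(dicts.value['Continent'])
countries = to_map_col(dicts.value['Country'])

df.withColumn('Continent', continents[F.col('Continent')])\
  .withColumn('Country', countries[F.col('Country')])\
  .show()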
CodePudding user response:
I would use a pandas UDF instead of a plain UDF, since pandas UDFs are vectorized.
Option 1

from typing import Iterator
import pandas as pd

def map_dict(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    for pdf in iterator:
        Continent = pdf.Continent
        Country = pdf.Country
        yield pdf.assign(Continent=Continent.map(dicts.value['Continent']),
                         Country=Country.map(dicts.value['Country']))

df.mapInPandas(map_dict, schema=df.schema).show()
Option 2

Note, though, that this is likely to incur a shuffle.
from typing import Iterator, Tuple
import pandas as pd
from pyspark.sql.functions import pandas_udf

def map_dict(pdf: pd.DataFrame) -> pd.DataFrame:
    Continent = pdf.Continent
    Country = pdf.Country
    return pdf.assign(Continent=Continent.map(dicts.value['Continent']),
                      Country=Country.map(dicts.value['Country']))

df.groupby("Continent", "Country").applyInPandas(map_dict, schema=df.schema).show()
+---------+-------+
|Continent|Country|
+---------+-------+
|     Asia|  China|
|     Asia|  Japan|
|  America| Canada|
|  America|    USA|
+---------+-------+
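For completeness, the pandas_udf import above is not actually used by either option. If you prefer a scalar pandas UDF per column, a minimal sketch (again assuming the same dicts broadcast from the question) could look like this:

import pandas as pd
import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.sql.functions import pandas_udf

# Scalar pandas UDFs: each receives a pd.Series batch and returns the
# mapped pd.Series, so the lookup stays vectorized.
@pandas_udf(T.StringType())
def map_continent(s: pd.Series) -> pd.Series:
    return s.map(dicts.value['Continent'])

@pandas_udf(T.StringType())
def map_country(s: pd.Series) -> pd.Series:
    return s.map(dicts.value['Country'])

df.withColumn('Continent', map_continent('Continent'))\
  .withColumn('Country', map_country('Country'))\
  .show()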