access objects in pyspark user-defined function from outer scope, avoid PicklingError: Could not ser


How do I avoid initializing a class within a pyspark user-defined function? Here is an example.

First, create a Spark session and a DataFrame holding four latitude/longitude pairs.

import pandas as pd
from pyspark import SparkConf
from pyspark.sql import SparkSession

conf = SparkConf()
conf.set('spark.sql.execution.arrow.pyspark.enabled', 'true')
spark = SparkSession.builder.config(conf=conf).getOrCreate()

sdf = spark.createDataFrame(pd.DataFrame({
    'lat': [37, 42, 35, -22],
    'lng': [-113, -107, 127, 34]}))

Here is the Spark DataFrame:

+---+----+
|lat| lng|
+---+----+
| 37|-113|
| 42|-107|
| 35| 127|
|-22|  34|
+---+----+

Next, enrich the DataFrame with a timezone string for each latitude/longitude pair using the timezonefinder package:

from typing import Iterator
from timezonefinder import TimezoneFinder

def func(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    for dx in iterator:
        tzf = TimezoneFinder()
        dx['timezone'] = [tzf.timezone_at(lng=a, lat=b) for a, b in zip(dx['lng'], dx['lat'])]
        yield dx
pdf = sdf.mapInPandas(func, schema='lat double, lng double, timezone string').toPandas()

The code above runs without errors and creates the pandas DataFrame below. The issue is that TimezoneFinder is initialized inside the user-defined function, once for every batch, which creates a bottleneck.

In [4]: pdf
    lat    lng         timezone
0  37.0 -113.0  America/Phoenix
1  42.0 -107.0   America/Denver
2  35.0  127.0       Asia/Seoul
3 -22.0   34.0    Africa/Maputo

How can I get this code to run more like the version below, where TimezoneFinder is initialized once, outside the user-defined function? As written, the code below fails with this error: PicklingError: Could not serialize object: TypeError: cannot pickle '_io.BufferedReader' object

def func(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    for dx in iterator:
        dx['timezone'] = [tzf.timezone_at(lng=a, lat=b) for a, b in zip(dx['lng'], dx['lat'])]
        yield dx
tzf = TimezoneFinder()
pdf = sdf.mapInPandas(func, schema='lat double, lng double, timezone string').toPandas()

UPDATE: I also tried functools.partial with an outer function but received the same error. That is, this approach does not work either:

from functools import partial

def outer(iterator, tzf):
    def func(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
        for dx in iterator:
            dx['timezone'] = [tzf.timezone_at(lng=a, lat=b) for a, b in zip(dx['lng'], dx['lat'])]
            yield dx
    return func(iterator)
tzf = TimezoneFinder()
outer = partial(outer, tzf=tzf)
pdf = sdf.mapInPandas(outer, schema='lat double, lng double, timezone string').toPandas()
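
Both attempts probably fail for the same reason: Spark has to pickle the function together with everything it references, and the error message suggests that the TimezoneFinder instance holds an open binary file handle, which pickle refuses to serialize. Below is a minimal sketch of that failure using only the standard library. Here types.SimpleNamespace is a hypothetical stand-in for the real object, and os.devnull is just a convenient file to open; neither comes from timezonefinder.

```python
import os
import pickle
from functools import partial
from types import SimpleNamespace


def try_pickle(obj):
    """Return None if obj pickles, otherwise the exception that was raised."""
    try:
        pickle.dumps(obj)
        return None
    except Exception as exc:
        return exc


# Stand-in for an object that keeps a binary file open as an attribute.
holder = SimpleNamespace(handle=open(os.devnull, "rb"))

# Pickling the object directly fails because of the open file handle.
direct_error = try_pickle(holder)

# Wrapping it in partial does not help: partial stores its arguments,
# so they are pickled along with it.
partial_error = try_pickle(partial(id, holder))

holder.handle.close()

print(type(direct_error).__name__, direct_error)
print(type(partial_error).__name__, partial_error)
```

If this is indeed the cause, any approach that creates the object on the driver and ships it to the workers will hit the same error, whether it travels as a global, a closure variable, or a partial argument.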

You will need a cached instance of the object on every worker, created there rather than shipped from the driver. You could do that as follows:

instance = [None]

def func(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    if instance[0] is None:
        instance[0] = TimezoneFinder()
    tzf = instance[0]
    for dx in iterator:
        dx['timezone'] = [tzf.timezone_at(lng=a, lat=b) for a, b in zip(dx['lng'], dx['lat'])]
        yield dx

Note that for this to work, your function has to be defined within a module, to give the instance cache somewhere to live. Otherwise you would have to hang it off some builtin module, e.g., os.instance = [None].
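
The same idea can be written with functools.lru_cache instead of a hand-rolled list. The sketch below is an illustration under assumptions, not a tested Spark job: ExpensiveLookup is a hypothetical stand-in for TimezoneFinder (so the example runs without Spark or timezonefinder), batches are plain lists rather than pandas DataFrames, and the code is assumed to live in a module that is importable on the workers.

```python
from functools import lru_cache
from typing import Iterable, Iterator, List, Tuple


class ExpensiveLookup:
    """Hypothetical stand-in for TimezoneFinder; counts how often it is built."""

    constructed = 0

    def __init__(self) -> None:
        type(self).constructed += 1

    def timezone_at(self, lng: float, lat: float) -> str:
        return f"zone({lng}, {lat})"  # placeholder result, not a real lookup


@lru_cache(maxsize=None)
def get_lookup() -> ExpensiveLookup:
    """Build the object on the first call in this process, then reuse it."""
    return ExpensiveLookup()


def func(iterator: Iterable[List[Tuple[float, float]]]) -> Iterator[List[str]]:
    # Resolved when the task runs, not captured when func is defined.
    lookup = get_lookup()
    for batch in iterator:
        yield [lookup.timezone_at(lng=lng, lat=lat) for lng, lat in batch]


# Simulate two tasks handled by the same Python process.
first = list(func([[(-113, 37)], [(-107, 42)]]))
second = list(func([[(127, 35)]]))

# Number of times the expensive object was constructed across both tasks.
print(ExpensiveLookup.constructed)
```

In a real job, get_lookup would return TimezoneFinder() and func would add the timezone column to each pandas DataFrame as in the answer above. How often the object is rebuilt then depends on how long Spark keeps each Python worker process alive.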

