Write pyspark UDF with initialization code

I want to run a pyspark UDF with initialization code which will run only once per Python process. I don't need to guarantee this code will run exactly once, but for performance's sake I don't want it to run for each row.

In scala I can do:

object MyUDF {
  // This code runs once per JVM (executor); initialize heavy objects here
}

class MyUDF extends UDF2[Double, Double, String] {
  override def call(lat: Double, long: Double): String = {
    // This code runs per record,
    // but can use the static objects that are already initialized
  }
}

Is something similar possible with Python? I know each executor has a single Python process running to answer that executor's UDF calls.
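
To make the question concrete, here is a rough sketch of the kind of pattern I have in mind in Python (build_heavy_object and lookup are just placeholders for the expensive initialization and the per-row work):

from pyspark.sql import functions as f

# Placeholder for an expensive, read-only object
_heavy_object = None

def _get_heavy_object():
    # Lazily initialize the object the first time it is needed,
    # so the expensive work does not run once per row
    global _heavy_object
    if _heavy_object is None:
        _heavy_object = build_heavy_object()  # hypothetical expensive init
    return _heavy_object

@f.udf("string")
def my_udf(lat, long):
    # Per-record code that reuses the already-initialized object
    return _get_heavy_object().lookup(lat, long)  # hypothetical per-row work

Whether this actually runs once per Python worker process or once per task depends on how the function gets shipped to the executors, which is exactly what I'm unsure about.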

CodePudding user response:

Let's consider a toy example where the big object is a list, and your UDF simply checks if an element is in that list.

One way to go about it is simply to define the object outside of the UDF definition. The initialization code is then executed only once, on the driver.

from pyspark.sql import functions as f

# The big object is defined once on the driver and captured by the UDF's closure
big_list = [1, 2, 4, 6, 8, 10]

is_in_list = f.udf(lambda x: x in big_list)

spark.range(10).withColumn("x", is_in_list(f.col("id"))).show()

Which yields:

+---+-----+
| id|    x|
+---+-----+
|  0|false|
|  1| true|
|  2| true|
|  3|false|
|  4| true|
|  5|false|
|  6| true|
|  7|false|
|  8| true|
|  9|false|
+---+-----+

The problem with that code, and with your Scala code as well, is that Spark will ship a copy of that object with every single task. If your concern is only the time needed to initialize the object, that's fine. But if the object is big, shipping it repeatedly may hurt the performance of the job.

To solve that problem, you can use a broadcast variable. Indeed, according to Spark's documentation:

Broadcast variables allow the programmer to keep a read-only variable cached on each machine rather than shipping a copy of it with tasks.

The code would be very similar:

from pyspark.sql import functions as f

big_list = [1, 2, 4, 6, 8, 10]

# sc is the SparkContext (e.g. spark.sparkContext); broadcasting sends the
# object to each executor once, where it is cached as a read-only copy
big_list_bc = sc.broadcast(big_list)

is_in_list = f.udf(lambda x: x in big_list_bc.value)

spark.range(10).withColumn("x", is_in_list(f.col("id"))).show()

CodePudding user response:

One of the several supported pandas_udf types is Iterator[pd.Series] -> Iterator[pd.Series].
Before mapping each pd.Series and yielding a result, you can have an init stage that runs once on every executor.

A pretty good guide on the subject can be found here.

from typing import Iterator
import pandas as pd
from pyspark.sql.functions import pandas_udf

@pandas_udf("long")
def plus_one(batch_iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
    # Init something here: runs once, before any batch is processed
    for x in batch_iter:
        yield x + 1  # your per-batch code here
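
Applied to the original question, a minimal sketch (with build_heavy_object and lookup as hypothetical placeholders) would build the heavy object once at the top of the UDF body and then reuse it for every batch; for two input columns the iterator variant takes tuples of series:

from typing import Iterator, Tuple
import pandas as pd
from pyspark.sql.functions import pandas_udf

@pandas_udf("string")
def lat_long_udf(batch_iter: Iterator[Tuple[pd.Series, pd.Series]]) -> Iterator[pd.Series]:
    # Runs once, before any batch is processed, instead of once per row
    heavy_object = build_heavy_object()  # hypothetical expensive init
    for lat, long in batch_iter:
        # Per-batch work that reuses the already-initialized object
        yield pd.Series([heavy_object.lookup(a, b) for a, b in zip(lat, long)])

It is then used like any other UDF, e.g. df.withColumn("cell", lat_long_udf("lat", "long")).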