Home > Software engineering >  PySpark UDF using ElementTree returns Pickling Error
PySpark UDF using ElementTree returns Pickling Error

Time:11-11

I have the following DataFrame:

import pyspark.sql.functions as F
import pyspark.sql.types as T
from lxml import etree

# One sample row: two integer columns followed by three string columns.
data = [
    (123, 1, "string123", "string456", "string789"),
]

# Every column is declared nullable (third StructField argument).
importSchema = T.StructType(
    [
        T.StructField(column_name, column_type, True)
        for column_name, column_type in (
            ("field1", T.IntegerType()),
            ("field2", T.IntegerType()),
            ("field3", T.StringType()),
            ("field4", T.StringType()),
            ("field5", T.StringType()),
        )
    ]
)

# NOTE(review): `spark` is not created in this snippet; presumably an existing
# SparkSession supplied by the notebook/shell environment.
df = spark.createDataFrame(data=data, schema=importSchema)

I'm trying to create a UDF that takes the values from each of these fields and builds an XML string using lxml's `etree`.

def create_str(field1,field2,field3,field4,field5):
    """Build an XML string from five column values (the failing version from the question).

    NOTE(review): ``root`` is not defined inside this function, so it resolves
    to a module-level name that is not shown here. The reported error ("cannot
    pickle 'lxml.etree._Element'") suggests that global is an lxml element,
    which Spark fails to pickle when serializing the UDF. Creating the root
    element inside the function avoids capturing it.
    """

    outer = etree.SubElement(root, 'outer')
    # The column *values* (not the column names) are used as tag names here.
    # NOTE(review): field1/field2 come from IntegerType columns, so these are
    # ints; a tag name must be a string.
    field1s = etree.SubElement(outer, field1)
    field2s = etree.SubElement(outer, field2)
    field3s = etree.SubElement(outer, field3)
    field4s = etree.SubElement(outer, field4)
    field5s = etree.SubElement(outer, field5)
    # NOTE(review): the int values need str() conversion before being used as
    # element text.
    field1s.text = field1
    field2s.text = field2
    field3s.text = field3
    field4s.text = field4
    field5s.text = field5

    # Serialize the whole (shared) root tree; pretty_print is an lxml-only option.
    var=etree.tostring(root, pretty_print=True).decode('utf-8')

    return var
  
# Wrap the Python function as a Spark UDF (default return type: string).
udf_create_str = F.udf(create_str)

# Pass each of the five columns as a separate positional UDF argument.
df.withColumn(
    "output",
    udf_create_str(
        *[df[column_name] for column_name in ("field1", "field2", "field3", "field4", "field5")]
    ),
).show()

However, this fails with:

PicklingError: Could not serialize object: TypeError: cannot pickle 'lxml.etree._Element' object

How can I get the result of `etree.tostring()` into a variable or a new column?

CodePudding user response:

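I see several issues.

- `root` is never created inside `create_str`, so the UDF refers to a global `root`. Given the error message, that global is an lxml element, which Spark cannot pickle; this is the source of the `PicklingError`.
- The column values are used as tag names instead of the column names.
- You have int values, which you need to convert to `str`.

Creating the root element inside the function fixes the pickling problem.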

Here is my attempt:

from pyspark.sql import SparkSession
import pyspark.sql.functions as F
import pyspark.sql.types as T
from xml.etree import ElementTree as etree


def create_str(fields, field_names=("field1", "field2", "field3", "field4", "field5")):
    """Serialize row values as ``<root><outer><fieldN>value</fieldN>...</outer></root>``.

    Args:
        fields: Sequence of column values (e.g. the list Spark passes for an
            ``F.array(...)`` column), or ``None``.
        field_names: Tag names, paired positionally with ``fields``. Defaults
            to the five column names used in this example.

    Returns:
        The XML document as a ``str`` without an XML declaration, or ``None``
        when ``fields`` is ``None`` (null in, null out).
    """
    if fields is None:
        return None
    root = etree.Element("root")
    outer = etree.SubElement(root, "outer")
    # zip() stops at the shorter sequence, so unmatched values/names are ignored.
    for field_name, value in zip(field_names, fields):
        element = etree.SubElement(outer, field_name)
        # str(None) would emit the literal text "None"; leave nulls as empty elements.
        if value is not None:
            element.text = str(value)
    # encoding="unicode" returns a str directly (same output as utf-8 + decode).
    return etree.tostring(root, encoding="unicode")


if __name__ == "__main__":
    spark = SparkSession.builder.getOrCreate()

    # Single sample row: two integer columns, then three string columns.
    data = [(123, 1, "string123", "string456", "string789")]

    column_specs = (
        ("field1", T.IntegerType()),
        ("field2", T.IntegerType()),
        ("field3", T.StringType()),
        ("field4", T.StringType()),
        ("field5", T.StringType()),
    )
    importSchema = T.StructType(
        [T.StructField(spec_name, spec_type, True) for spec_name, spec_type in column_specs]
    )

    df = spark.createDataFrame(data=data, schema=importSchema)
    udf_create_str = F.udf(create_str)

    # Pack all five columns into one array column so the UDF receives a single list,
    # then print every row without truncating the XML output.
    df.withColumn(
        "output",
        udf_create_str(F.array(*[df[spec_name] for spec_name, _ in column_specs])),
    ).show(n=20, truncate=False)

Result:

+------+------+---------+---------+---------+------------------------------------------------------------------------------------------------------------------------------------------------+
|field1|field2|field3   |field4   |field5   |output                                                                                                                                          |
+------+------+---------+---------+---------+------------------------------------------------------------------------------------------------------------------------------------------------+
|123   |1     |string123|string456|string789|<root><outer><field1>123</field1><field2>1</field2><field3>string123</field3><field4>string456</field4><field5>string789</field5></outer></root>|
+------+------+---------+---------+---------+------------------------------------------------------------------------------------------------------------------------------------------------+
  • Related