Optimization on for loop on columns in Pyspark

I don't know if my title is very clear. I have a table with a lot of columns (more than a hundred). Some of my columns contain values wrapped in brackets, and I need to explode those values into several rows. Here is a reproducible example:

# Import libraries
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
import pandas as pd

spark = SparkSession.builder.getOrCreate()

# Create an example
columns = ["Name", "Age", "Activity", "Studies"]
data = [
    ("Jame", 25, "[Painting,Yoga]", "[Math,Physics]"),
    ("Anne", 20, "[Garden,Cooking,Travel]", "[Communication,Marketing]"),
    ("Jane", 10, "[Gymnastique]", "[Basic School]"),
]
df = spark.createDataFrame(data=data, schema=columns)
df.show(truncate=False)

It shows the following table:

+----+---+-----------------------+-------------------------+
|Name|Age|Activity               |Studies                  |
+----+---+-----------------------+-------------------------+
|Jame|25 |[Painting,Yoga]        |[Math,Physics]           |
|Anne|20 |[Garden,Cooking,Travel]|[Communication,Marketing]|
|Jane|10 |[Gymnastique]          |[Basic School]           |
+----+---+-----------------------+-------------------------+

I need to determine which columns contain brackets in their values:

# Build a small table of the string columns; IsBracket is filled in below
list_col = df.dtypes
df_array_col = spark.createDataFrame(list_col)\
    .withColumnRenamed("_1", "Colname")\
    .withColumnRenamed("_2", "TypeColumn")\
    .filter(col("TypeColumn") == "string")\
    .withColumn("IsBracket", lit(0))\
    .toPandas()

# Function for determining whether a column contains brackets in its values
def func_isSquaredBracket(my_col):
    # Test the first non-null value of the column against the regex
    A = df.select(first(col(my_col).rlike(r"\["), ignorenulls=True).alias(my_col))
    val_IsBracket = A.select(col(my_col)).collect()[0][0]

    return val_IsBracket

# For loop applying the function to each candidate column
for index, row in df_array_col.iterrows():
    IsBracket_value = func_isSquaredBracket(df_array_col.at[index, "Colname"])
    if IsBracket_value:
        df_array_col.at[index, "IsBracket"] = 1

This successfully identifies which columns have bracketed values. Now I can explode my table:

def func_extractStringInBracket_andSplit(my_col):
    # Extract the text between the brackets, split it on '|' or ',',
    # then explode the resulting array into rows
    extract_string = regexp_extract(my_col, r'(?<=\[).+?(?=\])', 0).alias(my_col)
    string_split = split(extract_string, r"\||,").alias(my_col)
    string_explode_array = explode_outer(string_split).alias(my_col)

    return string_explode_array

# Keep only the columns flagged as containing brackets
df_array_bracket_col = df_array_col[df_array_col["IsBracket"] == 1]

df_explode_bracket = df
for index, row in df_array_bracket_col.iterrows():
    colname = df_array_bracket_col["Colname"][index]
    df_explode_bracket = df_explode_bracket.withColumn(colname, func_extractStringInBracket_andSplit(colname))
df_explode_bracket.show(truncate=False)

I obtain the result I want:

+----+---+-----------+-------------+
|Name|Age|Activity   |Studies      |
+----+---+-----------+-------------+
|Jame|25 |Painting   |Math         |
|Jame|25 |Painting   |Physics      |
|Jame|25 |Yoga       |Math         |
|Jame|25 |Yoga       |Physics      |
|Anne|20 |Garden     |Communication|
|Anne|20 |Garden     |Marketing    |
|Anne|20 |Cooking    |Communication|
|Anne|20 |Cooking    |Marketing    |
|Anne|20 |Travel     |Communication|
|Anne|20 |Travel     |Marketing    |
|Jane|10 |Gymnastique|Basic School |
+----+---+-----------+-------------+

However, this solution does not scale: with more than 100 columns it takes over 6 minutes to get the result, along with the following message:

/opt/spark/python/lib/pyspark.zip/pyspark/sql/pandas/conversion.py:289: UserWarning: createDataFrame attempted Arrow optimization because 'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, failed by the reason below:
  'JavaPackage' object is not callable
Attempting non-optimization as 'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to true.
  warnings.warn(msg)

I am pretty new to PySpark and I am not an expert in Python. My question is: how can I optimize this solution with PySpark instead of Pandas? A for loop is not ideal when parallel processing is available.

CodePudding user response:

It's actually pretty easy: use regexp_extract_all:

from pyspark.sql import functions as F

df = (
    df.withColumn("Activity_list", F.expr(r"regexp_extract_all(Activity, '(\\w+)', 1)"))
    .withColumn("Studies_list", F.expr(r"regexp_extract_all(Studies, '(\\w+)', 1)"))
)
df = (
    df.drop("Activity", "Studies")
    .withColumn("Activity", F.explode("Activity_list"))
    .withColumn("Studies", F.explode("Studies_list"))
    .drop("Activity_list", "Studies_list")  # drop the helper array columns
)
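
regexp_extract_all returns every regex match as an array, and exploding the two array columns one after the other yields their cross product, which is exactly the row multiplication in the desired output. A quick one-off check of the extraction step (for illustration only, not part of the solution):

spark.sql(r"SELECT regexp_extract_all('[Painting,Yoga]', '(\\w+)', 1) AS arr").show(truncate=False)
# arr -> [Painting, Yoga]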

Edit: It even works with strings without brackets.
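
To handle the many-column case from the question, the same idea extends without one collect() per column. This is only a sketch, starting again from the original df: flags and bracket_cols are names introduced here, and the pattern [^\[\],|]+ is swapped in for \w+ so that multi-word values such as "Basic School" stay in one piece and both ',' and '|' act as delimiters:

from pyspark.sql import functions as F

# Test the first non-null value of every string column in a single Spark job
flags = df.select(
    [F.first(F.col(c).rlike(r"\["), ignorenulls=True).alias(c)
     for c, t in df.dtypes if t == "string"]
).collect()[0]
bracket_cols = [c for c in flags.asDict() if flags[c]]

# Spark allows only one generator (explode) per projection, so chain
# withColumn calls; swap in F.explode_outer to keep rows with nulls
df_exploded = df
for c in bracket_cols:
    df_exploded = df_exploded.withColumn(
        c, F.explode(F.expr(rf"regexp_extract_all({c}, '([^\\[\\],|]+)', 1)"))
    )
df_exploded.show(truncate=False)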
