Optimizing Spark code with filtering and UDF


I am processing a dataset of 20 million XML documents using Spark. I was originally processing all of them, but I only actually need about a third. In a separate Spark workflow, I created a DataFrame keyfilter with two columns: the key of each XML, and a boolean that is True if the XML corresponding to that key should be processed and False otherwise.
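
For reference, keyfilter looks roughly like this (the rows below are made up purely for illustration; the column names key and filter are the ones used in the join later, and spark is the notebook's SparkSession):

# Illustrative only: keyfilter maps each XML key to a boolean "filter" flag.
# The real keyfilter is built in a separate workflow and read from parquet.
keyfilter_example = spark.createDataFrame(
    [('key-000001', True), ('key-000002', False), ('key-000003', True)],
    ['key', 'filter'],
)
keyfilter_example.printSchema()
# root
#  |-- key: string (nullable = true)
#  |-- filter: boolean (nullable = true)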

The XMLs themselves are processed using a Pandas UDF, which I can't share.
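
It has this general shape, though (the body below is only a stand-in for the real XML-to-plain-text conversion):

import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import StringType

@pandas_udf(StringType())
def pandas_xml_convert_string(xml: pd.Series) -> pd.Series:
    # Receives a pandas Series of XML strings and returns a Series of converted strings.
    # Stand-in body only: the real conversion logic is omitted.
    return xml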

My notebook on Databricks works essentially like this:

import pyspark
import time
from pyspark.sql.types import StringType
from pyspark.sql.functions import pandas_udf, col
from pyspark.sql.utils import AnalysisException
from multiprocessing import Pool
from multiprocessing.pool import ThreadPool
import pandas as pd
DATE = '20200314'
<define UDF pandas_xml_convert_string()>

keyfilter = spark.read.parquet('/path/to/keyfilter/on/s3.parquet')
keyfilter.cache()

def process_part(part, fraction=1, filter=True, return_df=False):
  try:
    df = spark.read.parquet('/path/to/parquets/on/s3/%s/part-%05d*' % (DATE, part))
  # Sometimes, the file part-xxxxx doesn't exist
  except AnalysisException:
    return None
  if fraction < 1:
    df = df.sample(fraction=fraction, withReplacement=False)
  if filter:
    df_with_filter = df.join(keyfilter, on='key', how='left').fillna(False)
    filtered_df = df_with_filter.filter(col('filter')).drop('filter')
    mod_df = filtered_df.select(col('key'), pandas_xml_convert_string(col('xml')).alias('xmlplain'), col('xml'))
  else:
    mod_df = df.select(col('key'), pandas_xml_convert_string(col('xml')).alias('xmlplain'), col('xml'))
  mod_df.write.parquet('/output/path/on/s3/part-%05d_%s_%d' % (part, DATE, time.time()))
  if return_df:
    return mod_df


n_cores = 6
i=0
while n_cores*i < 1024:
    with ThreadPool(n_cores) as p:
        p.map(process_part, range(n_cores*i, min(1024, n_cores*i + n_cores)))
    i += 1

The reason I'm posting this question is that despite the fact that the Pandas UDF should be the most expensive operation taking place, adding the filtering actually makes my code run much slower than if I weren't filtering at all. I am very new to Spark and I'm wondering if I'm doing something stupid here that is causing the joins with keyfilter to be very slow, and if so, if there is a way to make them fast (e.g., is there a way to make keyfilter act like a hash table from keys to booleans, like CREATE INDEX in SQL?). I imagine that the large size of keyfilter is playing some kind of role here; it has 20 million rows while df in process_part has only a tiny fraction of those rows (df is much larger in size, however, as it contains XML documents). Should I maybe be combining all the parts into one giant dataframe instead of processing them one at a time?

Or is there a way of informing Spark that keys are unique in both dataframes?
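
(In case it helps with diagnosing this, I believe the join strategy Spark picks can be read off the physical plan with explain(), along these lines, although I'm not sure what I should be looking for:)

# Sketch: inspect the plan for a single part to see which join strategy is used.
df = spark.read.parquet('/path/to/parquets/on/s3/%s/part-%05d*' % (DATE, 0))
df.join(keyfilter, on='key', how='left').explain()
# The printed physical plan shows the join node, e.g. SortMergeJoin or BroadcastHashJoin.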

CodePudding user response:

The key to getting the join to happen in a reasonable time frame was to wrap keyfilter in broadcast() so that Spark performs a broadcast hash join instead of a standard shuffle join. I also merged some of the parts and reduced the parallelism (for some reason, using too many threads sometimes seems to crash the engine). My newly performant code looks like this:

import pyspark
import time
from pyspark.sql.types import StringType
from pyspark.sql.functions import pandas_udf, col, broadcast
from functools import reduce  # used below to union the per-part DataFrames
from pyspark.sql.utils import AnalysisException
from multiprocessing import Pool
from multiprocessing.pool import ThreadPool
import pandas as pd
DATE = '20200314'
<define UDF pandas_xml_convert_string()>

keyfilter = spark.read.parquet('/path/to/keyfilter/on/s3.parquet')
keyfilter.cache()

def process_parts(part_pair, fraction=1, return_df=False, filter=True):
  dfs = []
  parts_start, parts_end = part_pair
  parts = range(parts_start, parts_end)
  for part in parts:
    try:
      df = spark.read.parquet('/input/path/on/s3/%s/part-%05d*' % (DATE, part))
      dfs.append(df)
    except AnalysisException:
      print("There is no part d!" % part)
      continue
  if len(dfs) >= 2:
    df = reduce(lambda x, y: x.union(y), dfs)
  elif len(dfs) == 1:
    df = dfs[0]
  else:
    return None
  if fraction < 1:
    df = df.sample(fraction=fraction, withReplacement=False)
  if filter:
    df_with_filter = df.join(broadcast(keyfilter), on='key', how='left').fillna(False)
    filtered_df = df_with_filter.filter(col('filter')).drop('filter')
    mod_df = filtered_df.select(col('key'), pandas_xml_convert_string(col('xml')).alias('xmlplain'), col('xml'))
  else:
    mod_df = df.select(col('key'), pandas_xml_convert_string(col('xml')).alias('xmlplain'), col('xml'))
  mod_df.write.parquet('/output/path/on/s3/parts-%05d-%05d_%s_%d' % (parts_start, parts_end-1, DATE, time.time()))
  if return_df:
    return mod_df


start_time = time.time()
pairs = [(i*4, i*4 + 4) for i in range(256)]
with ThreadPool(3) as p:
  batch_start_time = time.time()
  for i, _ in enumerate(p.imap_unordered(process_parts, pairs, chunksize=1)):
    batch_end_time = time.time()
    batch_len = batch_end_time - batch_start_time
    cum_len = batch_end_time - start_time
    print('Processed group %d/256, %d minutes and %d seconds after previous group.' % (i + 1, batch_len // 60, batch_len % 60))
    print('%d hours, %d minutes, %d seconds since start.' % (cum_len // 3600, (cum_len % 3600) // 60, cum_len % 60))
    batch_start_time = time.time()
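
As a side note, the explicit broadcast() hint asks Spark to broadcast keyfilter regardless of spark.sql.autoBroadcastJoinThreshold (which defaults to 10 MB); presumably keyfilter's 20 million rows exceed that, which is why Spark never chose a broadcast hash join on its own. Assuming keyfilter comfortably fits in driver and executor memory, an alternative to the hint would be to raise that threshold so Spark picks the broadcast join by itself, roughly like this:

# Assumption: keyfilter's serialized size fits in driver and executor memory.
# Raising the auto-broadcast threshold (here to 512 MB) lets Spark choose a
# broadcast hash join on its own, without the explicit broadcast() hint.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", str(512 * 1024 * 1024))

Either way, calling explain() on the joined DataFrame should now show a BroadcastHashJoin instead of a SortMergeJoin in the physical plan.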