I am processing a dataset of 20 million XML documents using Spark. I originally processed all of them, but I only actually need about a third. In a separate Spark workflow, I created a dataframe keyfilter with one column holding the key of each XML and a second boolean column that is True if the XML corresponding to that key should be processed and False otherwise (a rough sketch of its layout is shown below).
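(Sketch only; the keys here are made up, and I'm assuming the boolean column is named filter, which is what the join code below relies on.)

from pyspark.sql.types import StructType, StructField, StringType, BooleanType

# Hypothetical illustration of keyfilter's shape: one row per XML key plus a
# boolean 'filter' column saying whether that XML should be processed.
keyfilter_layout = StructType([
    StructField('key', StringType(), False),
    StructField('filter', BooleanType(), False),
])
keyfilter_example = spark.createDataFrame(
    [('doc-00000001', True), ('doc-00000002', False)],  # made-up keys
    schema=keyfilter_layout,
)
keyfilter_example.show()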
The XMLs themselves are processed using a Pandas UDF, which I can't share.
My notebook on Databricks works essentially like this:
import pyspark
import time
from pyspark.sql.types import StringType
from pyspark.sql.functions import pandas_udf, col
from pyspark.sql.utils import AnalysisException
from multiprocessing import Pool
from multiprocessing.pool import ThreadPool
import pandas as pd
DATE = '20200314'
<define UDF pandas_xml_convert_string()>
keyfilter = spark.read.parquet('/path/to/keyfilter/on/s3.parquet')
keyfilter.cache()
def process_part(part, fraction=1, filter=True, return_df=False):
    try:
        # Assuming zero-padded part file names, e.g. part-00042
        df = spark.read.parquet('/path/to/parquets/on/s3/%s/part-%05d*' % (DATE, part))
    # Sometimes, the file part-xxxxx doesn't exist
    except AnalysisException:
        return None
    if fraction < 1:
        df = df.sample(fraction=fraction, withReplacement=False)
    if filter:
        df_with_filter = df.join(keyfilter, on='key', how='left').fillna(False)
        filtered_df = df_with_filter.filter(col('filter')).drop('filter')
        mod_df = filtered_df.select(col('key'), pandas_xml_convert_string(col('xml')).alias('xmlplain'), col('xml'))
    else:
        mod_df = df.select(col('key'), pandas_xml_convert_string(col('xml')).alias('xmlplain'), col('xml'))
    mod_df.write.parquet('/output/path/on/s3/part-%05d_%s_%d' % (part, DATE, time.time()))
    if return_df:
        return mod_df

n_cores = 6
i = 0
while n_cores*i < 1024:
    with ThreadPool(n_cores) as p:
        p.map(process_part, range(n_cores*i, min(1024, n_cores*i + n_cores)))
    i += 1
The reason I'm posting this question is that, even though the Pandas UDF should be the most expensive operation taking place, adding the filtering actually makes my code run much slower than not filtering at all. I am very new to Spark, and I'm wondering whether I'm doing something stupid here that makes the joins with keyfilter very slow, and if so, whether there is a way to make them fast (e.g., is there a way to make keyfilter act like a hash table from keys to booleans, similar to CREATE INDEX in SQL?). I imagine the large size of keyfilter plays some role here; it has 20 million rows, while df in process_part has only a tiny fraction of those rows (df is much larger in bytes, however, since it contains the XML documents). Should I maybe combine all the parts into one giant dataframe instead of processing them one at a time? Or is there a way of informing Spark that keys are unique in both dataframes?
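As an aside, I understand the join strategy Spark actually chooses can be inspected from the physical plan; an illustrative check (using one sample part, not part of my pipeline) would be something like:

# Illustrative diagnostic: print the physical plan to see whether the keyfilter
# join is planned as a SortMergeJoin or a BroadcastHashJoin.
sample_df = spark.read.parquet('/path/to/parquets/on/s3/%s/part-00000*' % DATE)
sample_df.join(keyfilter, on='key', how='left').explain()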
CodePudding user response:
The key to getting the join to happen in a reasonable time frame was to use broadcast on keyfilter so that Spark performs a Broadcast Hash Join instead of a standard join. I also merged some of the parts and reduced the parallelism (for some reason, too many threads sometimes seem to crash the engine). My newly performant code looks like this:
import pyspark
import time
from pyspark.sql.types import StringType
from pyspark.sql.functions import pandas_udf, col, broadcast
from functools import reduce  # used below to union the per-part dataframes
from pyspark.sql.utils import AnalysisException
from multiprocessing import Pool
from multiprocessing.pool import ThreadPool
import pandas as pd
DATE = '20200314'
<define UDF pandas_xml_convert_string()>
keyfilter = spark.read.parquet('/path/to/keyfilter/on/s3.parquet')
keyfilter.cache()
def process_parts(part_pair, fraction=1, return_df=False, filter=True):
    dfs = []
    parts_start, parts_end = part_pair
    parts = range(parts_start, parts_end)
    for part in parts:
        try:
            # Assuming zero-padded part file names, e.g. part-00042
            df = spark.read.parquet('/input/path/on/s3/%s/part-%05d*' % (DATE, part))
            dfs.append(df)
        except AnalysisException:
            print("There is no part %d!" % part)
            continue
    if len(dfs) >= 2:
        df = reduce(lambda x, y: x.union(y), dfs)
    elif len(dfs) == 1:
        df = dfs[0]
    else:
        return None
    if fraction < 1:
        df = df.sample(fraction=fraction, withReplacement=False)
    if filter:
        df_with_filter = df.join(broadcast(keyfilter), on='key', how='left').fillna(False)
        filtered_df = df_with_filter.filter(col('filter')).drop('filter')
        mod_df = filtered_df.select(col('key'), pandas_xml_convert_string(col('xml')).alias('xmlplain'), col('xml'))
    else:
        mod_df = df.select(col('key'), pandas_xml_convert_string(col('xml')).alias('xmlplain'), col('xml'))
    mod_df.write.parquet('/output/path/on/s3/parts-%05d-%05d_%s_%d' % (parts_start, parts_end-1, DATE, time.time()))
    if return_df:
        return mod_df
start_time = time.time()
pairs = [(i*4, i*4 + 4) for i in range(256)]
with ThreadPool(3) as p:
    batch_start_time = time.time()
    for i, _ in enumerate(p.imap_unordered(process_parts, pairs, chunksize=1)):
        batch_end_time = time.time()
        batch_len = batch_end_time - batch_start_time
        cum_len = batch_end_time - start_time
        print('Processed group %d/256 %d minutes and %d seconds after previous group.' % (i + 1, batch_len // 60, batch_len % 60))
        print('%d hours, %d minutes, %d seconds since start.' % (cum_len // 3600, (cum_len % 3600) // 60, cum_len % 60))
        batch_start_time = time.time()
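As a footnote (general Spark behavior rather than something specific to this code): Spark also broadcasts a join side automatically when its estimated size falls below spark.sql.autoBroadcastJoinThreshold, so raising that threshold is an alternative to the explicit broadcast() hint, e.g.:

# Optional alternative to the explicit broadcast() hint: raise the size limit
# below which Spark automatically broadcasts the smaller side of a join.
spark.conf.set('spark.sql.autoBroadcastJoinThreshold', 200 * 1024 * 1024)  # ~200 MB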