Spark performance issues with multiple subsequent joins


We are migrating a lot of locally running Python ETL code (using pandas) to Spark running on Databricks. We are running into some performance issues in a part that performs many subsequent joins (which runs fine in pandas).

We are running our code as a package on the Databricks cluster (which is one reason it is hard to share isolated, working code in this question). All joins happen one after another in a single function.

The main dataframe we are joining to is not very large: 819,000 records over 44 columns. In total we left join 27 other dataframes to this main dataframe, each adding only between 1 and 3 extra columns.

All dataframes are joined on the same column (a customer identifier), which is unique in the main dataframe and should not be skewed in any way.
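
For reference, a quick sanity check for duplicate or skewed keys can be written in a few lines of standard PySpark (a sketch; bc is the identifier column used in the code further down):

from pyspark.sql import functions as F

# For a truly unique key this returns an empty result; a heavily repeated
# key (skew) would show up at the top.
df_pop.groupBy('bc').count() \
      .filter(F.col('count') > 1) \
      .orderBy(F.desc('count')) \
      .show(5)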

The problem arises when we run the full code on the Databricks cluster and try to perform any action, like count() or display(), after all the joins have been run. The runtime of a simple count on the main dataframe explodes once it has to perform all the subsequent joins, and we do not understand why.

Some extra info:

  • The cluster is a simple DS4_v2 (28 GB RAM, 8 cores) with 2-8 worker nodes.
  • Changing the default shuffle partition number to a very low value such as 2 did not seem to help
  • Persisting and unpersisting dataframes in between joins did not seem to help (a sketch of both attempts follows this list)
  • Scaling up or down the cluster did not seem to help
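
For completeness, the shuffle-partition and persist attempts looked roughly like this (a sketch; the exact values and storage levels we tried varied):

from pyspark import StorageLevel

# Lower the default number of shuffle partitions (we tried values as low as 2).
spark.conf.set('spark.sql.shuffle.partitions', '2')

# Persist an intermediate result and force materialization before the next join;
# the dataframe was unpersisted again once downstream results were materialized.
df_pop = df_pop.persist(StorageLevel.MEMORY_AND_DISK)
df_pop.count()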

Below is the code for all the joins that are performed. Each code block is part of a separate function; these functions are called in sequence by another function that we invoke on the Databricks cluster. df_pop is our main dataframe.

Join 1:

df_pop = df_pop.join(other_df1, df_pop.bc == other_df1.bc, how='left_outer')

Joins 2-21 (a loop over 20 population dataframes, joining each to df_pop):

for pop in self.des_config.get('populations'):
    population = self.cleaned_data.get(pop).withColumnRenamed('bc', f'bc_{pop}').select(col(f'bc_{pop}'))
    df_pop = df_pop.join(population, df_pop.bc == f'bc_{pop}', how = 'left_outer')\
                   .withColumn(f'bc_{pop}', F.when(col(f'bc_{pop}').isNull(), F.lit(False)).otherwise(F.lit(True)))\
                   .withColumnRenamed(f'bc_{pop}', f'pop_{pop}')

Joins 22-25:

df_pop = df_pop.withColumn('pop_Real Estate', df_pop.bo_sector == 'REAL ESTATE')\
               .withColumn('pop_O&O', df_pop.bo_sector == 'GOVERNMENT & EDUCATION')\
               .join(other_df2, on='bc', how='left_outer').drop('other_df2.bc')\
               .join(other_df3, on='bc', how='left_outer').drop('other_df3.bc')\
               .join(other_df4, on='bc', how='left_outer').drop('other_df4.bc')\
               .join(other_df5, other_df5.bc == df_pop.bc, how='left_outer')

Join 26:

df_pop = df_pop.join(other_df6, other_df6.bc == df_pop.bc, how='left_outer')\
               .fillna(False, subset=['pop_edr_backlog'])

Join 27:

df_pop = df_pop.join(other_df7, other_df7.bc == df_pop.bc, how = 'left_outer')

Any action on df_pop effectively never completes after these joins. Any idea how this can be fixed?

CodePudding user response:

We just resolved this issue. Joins 2-21, performed in the for-loop, apparently resulted in a 'broadcast nested loop join' under the hood because of the .withColumnRenamed that happens before the join. Removing this resulted in 'broadcast hash joins' being performed instead, which are a lot faster. So check the execution plan with `df_pop.explain('formatted')` and verify which types of joins are being performed.
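
For anyone debugging something similar, below is a minimal sketch of the plan check and of a reworked loop body, reusing the names from the question (pop, self.cleaned_data, bc). One plausible mechanism for the slow plan: in df_pop.bc == f'bc_{pop}' the right-hand side is a plain string, which PySpark wraps in lit(), so the condition references only the left dataframe and Spark has no equi-join keys to work with.

from pyspark.sql import functions as F

# Inspect the physical plan: BroadcastNestedLoopJoin is the slow path,
# BroadcastHashJoin (or SortMergeJoin) is what you want for equi-joins.
df_pop.explain(mode='formatted')

# Reworked loop body: compare column to column so Spark keeps an equi-join,
# then derive the boolean membership flag from the joined key.
population = self.cleaned_data.get(pop).select(F.col('bc').alias(f'bc_{pop}'))
df_pop = df_pop.join(population, df_pop.bc == F.col(f'bc_{pop}'), how='left_outer')\
               .withColumn(f'pop_{pop}', F.col(f'bc_{pop}').isNotNull())\
               .drop(f'bc_{pop}')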
