How to compute multiple counts with different conditions on a pyspark DataFrame, fast?

Time:10-19

Let's say I have this pyspark Dataframe:

from pyspark.sql import functions as F

data = spark.createDataFrame(schema=['Country'], data=[('AT',), ('BE',), ('France',), ('Latvia',)])

And let's say I want to collect various statistics about this data. For example, I might want to know how many rows use a 2-character country code and how many use longer country names:

count_short = data.where(F.length(F.col('Country')) == 2).count()
count_long = data.where(F.length(F.col('Country')) > 2).count()

This works, but when I want to collect many different counts based on different conditions, it becomes very slow even for tiny datasets. In Azure Synapse Studio, where I am working, every count takes 1-2 seconds to compute.

I need to do 100 counts, and it takes multiple minutes to compute for a dataset of 10 rows. And before somebody asks, the conditions for those counts are more complex than in my example. I cannot group by length or do other tricks like that.

I am looking for a general way to do multiple counts on arbitrary conditions, fast.

I am guessing that the reason for the slow performance is that for every count call, my pyspark notebook starts some Spark processes that have significant overhead. So I assume that if there was some way to collect these counts in a single query, my performance problems would be solved.

One possible solution I thought of is to build a temporary column that indicates which of my conditions have been matched, and then call countDistinct on it. But then I would have individual counts for all combinations of condition matches. I also noticed that depending on the situation, the performance is a bit better when I do data = data.localCheckpoint() before computing my statistics, but the general problem still persists.

Is there a better way?

CodePudding user response:

Each "count" can be replaced by a conditional "sum", so that all counts are computed in a single query (Scala):

import org.apache.spark.sql.functions.{col, length, sum, when}

data.select(
  sum(
    when(length(col("Country")) === 2, 1).otherwise(0)
  ).alias("two_characters"),
  sum(
    when(length(col("Country")) > 2, 1).otherwise(0)
  ).alias("more_than_two_characters")
)