Let's say I have this PySpark DataFrame:
data = spark.createDataFrame(schema=['Country'], data=[('AT',), ('BE',), ('France',), ('Latvia',)])
And let's say I want to collect various statistics about this data. For example, I might want to know how many rows use a 2-character country code and how many use longer country names:
from pyspark.sql import functions as F

count_short = data.where(F.length(F.col('Country')) == 2).count()
count_long = data.where(F.length(F.col('Country')) > 2).count()
This works, but when I want to collect many different counts based on different conditions, it becomes very slow even for tiny datasets. In Azure Synapse Studio, where I am working, every count takes 1-2 seconds to compute.
I need to do 100 counts, and computing them takes several minutes for a dataset of 10 rows. And before somebody asks: the real conditions are more complex than in my example, so I cannot simply group by length or use similar tricks.
I am looking for a general way to do multiple counts on arbitrary conditions, fast.
I am guessing that the slow performance comes from each count call triggering its own Spark job, which carries significant overhead. So I assume that if there were a way to collect all these counts in a single query, my performance problem would be solved.
One possible solution I thought of is to build a temporary column that indicates which of my conditions have been matched, and then call countDistinct on it (sketched below). But then I would get individual counts for every combination of matched conditions rather than one count per condition. I also noticed that, depending on the situation, performance is a bit better when I do data = data.localCheckpoint() before computing my statistics, but the general problem persists.
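For concreteness, here is roughly what I mean by the indicator-column idea (the column name matched_conditions and the flag labels are just placeholders, and I group on the indicator column here rather than using countDistinct):

# Rough sketch of the indicator-column idea; 'matched_conditions' and the
# flag labels are placeholder names.
matched = data.withColumn(
    'matched_conditions',
    F.concat_ws(
        ',',
        F.when(F.length(F.col('Country')) == 2, F.lit('is_short')),
        F.when(F.length(F.col('Country')) > 2, F.lit('is_long')),
    ),
)
# Gives one count per combination of matched conditions, not one count per condition.
matched.groupBy('matched_conditions').count().show()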
Is there a better way?
CodePudding user response:
Function "count" can be replaced by "sum" with condition (Scala):
import org.apache.spark.sql.functions._

data.select(
  sum(
    when(length(col("Country")) === 2, 1).otherwise(0)
  ).alias("two_characters"),
  sum(
    when(length(col("Country")) > 2, 1).otherwise(0)
  ).alias("more_than_two_characters")
)
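The same idea in PySpark, as a sketch against the example DataFrame (the alias names two_characters and more_than_two_characters are arbitrary, and collect() is just one way to pull the values back into Python variables):

# All statistics are aggregate columns of a single select, so Spark runs one job.
counts = data.select(
    F.sum(F.when(F.length(F.col('Country')) == 2, 1).otherwise(0)).alias('two_characters'),
    F.sum(F.when(F.length(F.col('Country')) > 2, 1).otherwise(0)).alias('more_than_two_characters'),
).collect()[0]

count_short = counts['two_characters']
count_long = counts['more_than_two_characters']

Because every statistic is expressed as an aggregate column of one select, Spark evaluates all of them in a single job, so the per-count overhead is paid only once no matter how many conditions you add.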