Filter then count for many different thresholds

Time:03-28

I want to count the rows of a very large DataFrame that satisfy a condition, which can be done with:

df.filter(col("value") >= thresh).count()

I need the result for every threshold in the range [1, 10]. Looping over the thresholds and running this action once per threshold scans the DataFrame 10 times, which is slow.

Can I get all the counts while scanning the DataFrame only once?

CodePudding user response:

Using conditional aggregation with when expressions should do the job.

Here's an example:

from pyspark.sql import functions as F

df = spark.createDataFrame([(1,), (2,), (3,), (4,), (4,), (6,), (7,)], ["value"])

count_expr = [
    F.count(F.when(F.col("value") >= th, 1)).alias(f"gte_{th}")
    for th in range(1, 11)
]

df.select(*count_expr).show()
#+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
#|gte_1|gte_2|gte_3|gte_4|gte_5|gte_6|gte_7|gte_8|gte_9|gte_10|
#+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
#|    7|    6|    5|    4|    2|    2|    1|    0|    0|     0|
#+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+

CodePudding user response:

Create an indicator column for each threshold, then sum:

import random
import pyspark.sql.functions as F
from pyspark.sql import Row

df = spark.createDataFrame([Row(value=random.randint(0,10)) for _ in range(1_000_000)])

df.select([
    (F.col("value") >= thresh)
    .cast("int")
    .alias(f"ind_{thresh}") 
    for thresh in range(1,11)
]).groupBy().sum().show()

# +----------+----------+----------+----------+----------+----------+----------+----------+----------+-----------+
# |sum(ind_1)|sum(ind_2)|sum(ind_3)|sum(ind_4)|sum(ind_5)|sum(ind_6)|sum(ind_7)|sum(ind_8)|sum(ind_9)|sum(ind_10)|
# +----------+----------+----------+----------+----------+----------+----------+----------+----------+-----------+
# |    908971|    818171|    727240|    636334|    545463|    454279|    363143|    272460|    181729|      90965|
# +----------+----------+----------+----------+----------+----------+----------+----------+----------+-----------+

CodePudding user response:

You can use higher-order functions for this, available since Spark 2.4.

Assuming this to be the input df:

+---+-----------+
|A  |B          |
+---+-----------+
|10 |placeholder|
|1  |placeholder|
|2  |placeholder|
|100|placeholder|
|50 |placeholder|
+---+-----------+

Code:

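from pyspark.sql import functions as F

start, stop = 1, 5  # change stop to 10 for the thresholds in the question

# Build the array of thresholds, then compare A against each one per row.
# Note: this uses a strict ">"; use ">=" to match the question's condition.
out = df.withColumn("Range", F.expr(f"sequence({start},{stop})"))\
        .withColumn("Check", F.expr("transform(Range, x -> A > x)"))\
        .select(*df.columns, F.arrays_zip("Range", "Check").alias("Res"))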

out.show(truncate=False)
+---+-----------+------------------------------------------------------------+
|A  |B          |Res                                                         |
+---+-----------+------------------------------------------------------------+
|10 |placeholder|[{1, true}, {2, true}, {3, true}, {4, true}, {5, true}]     |
|1  |placeholder|[{1, false}, {2, false}, {3, false}, {4, false}, {5, false}]|
|2  |placeholder|[{1, true}, {2, false}, {3, false}, {4, false}, {5, false}] |
|100|placeholder|[{1, true}, {2, true}, {3, true}, {4, true}, {5, true}]     |
|50 |placeholder|[{1, true}, {2, true}, {3, true}, {4, true}, {5, true}]     |
+---+-----------+------------------------------------------------------------+

Note that you can explode the Res column later and split its fields into individual columns.
