I want to count the rows of a very large dataframe that satisfy a condition, which can be done with
df.filter(col("value") >= thresh).count()
I need this count for every threshold in the range [1, 10]. Looping over the thresholds and running this action for each one scans the dataframe 10 times, which is slow.
Can I achieve this by scanning the dataframe only once?
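For reference, this is the naive loop I'm running now; each count() is a separate action that re-scans the dataframe:
# one full scan per threshold -- this is the slow part
counts = {}
for thresh in range(1, 11):
    counts[thresh] = df.filter(col("value") >= thresh).count()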
CodePudding user response:
Using conditional aggregation with when expressions should do the job.
Here's an example:
from pyspark.sql import functions as F

df = spark.createDataFrame([(1,), (2,), (3,), (4,), (4,), (6,), (7,)], ["value"])

# one conditional count per threshold; all ten are computed in a single scan
count_expr = [
    F.count(F.when(F.col("value") >= th, 1)).alias(f"gte_{th}")
    for th in range(1, 11)
]

df.select(*count_expr).show()
#+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
#|gte_1|gte_2|gte_3|gte_4|gte_5|gte_6|gte_7|gte_8|gte_9|gte_10|
#+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
#|    7|    6|    5|    4|    2|    2|    1|    0|    0|     0|
#+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
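If you need the results as a plain Python dict rather than a one-row dataframe, you can collect the single result row; a small sketch reusing count_expr from above:
# collect the one-row result and map threshold -> count
row = df.select(*count_expr).first()
counts = {th: row[f"gte_{th}"] for th in range(1, 11)}  # {1: 7, 2: 6, ..., 10: 0}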
CodePudding user response:
Create an indicator column for each threshold, then sum:
import random

import pyspark.sql.functions as F
from pyspark.sql import Row

# one million random integers in [0, 10] as sample data
df = spark.createDataFrame([Row(value=random.randint(0, 10)) for _ in range(1_000_000)])

# cast each boolean indicator to int, then sum all of them in one aggregation
df.select([
    (F.col("value") >= thresh)
    .cast("int")
    .alias(f"ind_{thresh}")
    for thresh in range(1, 11)
]).groupBy().sum().show()
# +----------+----------+----------+----------+----------+----------+----------+----------+----------+-----------+
# |sum(ind_1)|sum(ind_2)|sum(ind_3)|sum(ind_4)|sum(ind_5)|sum(ind_6)|sum(ind_7)|sum(ind_8)|sum(ind_9)|sum(ind_10)|
# +----------+----------+----------+----------+----------+----------+----------+----------+----------+-----------+
# |    908971|    818171|    727240|    636334|    545463|    454279|    363143|    272460|    181729|      90965|
# +----------+----------+----------+----------+----------+----------+----------+----------+----------+-----------+
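The same single-scan aggregation can also be written without the intermediate indicator columns, by summing the casted conditions directly; a sketch equivalent to the groupBy().sum() above:
# sum of a 0/1 condition == count of rows where it holds
df.select([
    F.sum((F.col("value") >= thresh).cast("int")).alias(f"gte_{thresh}")
    for thresh in range(1, 11)
]).show()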
CodePudding user response:
You can use higher-order functions, available since Spark 2.4, for this purpose.
Assuming this to be the input df:
+---+-----------+
|A  |B          |
+---+-----------+
|10 |placeholder|
|1  |placeholder|
|2  |placeholder|
|100|placeholder|
|50 |placeholder|
+---+-----------+
Code:
import pyspark.sql.functions as F

start, stop = 1, 5  # change stop to 10 for thresholds in [1, 10]

# per row: build the threshold array once, evaluate A >= x for each threshold x,
# then zip thresholds with their boolean results
out = df.withColumn("Range", F.expr(f"sequence({start},{stop})"))\
        .withColumn("Check", F.expr("transform(Range, x -> A >= x)"))\
        .select(*df.columns, F.arrays_zip("Range", "Check").alias("Res"))

out.show(truncate=False)
+---+-----------+-----------------------------------------------------------+
|A  |B          |Res                                                        |
+---+-----------+-----------------------------------------------------------+
|10 |placeholder|[{1, true}, {2, true}, {3, true}, {4, true}, {5, true}]   |
|1  |placeholder|[{1, true}, {2, false}, {3, false}, {4, false}, {5, false}]|
|2  |placeholder|[{1, true}, {2, true}, {3, false}, {4, false}, {5, false}] |
|100|placeholder|[{1, true}, {2, true}, {3, true}, {4, true}, {5, true}]   |
|50 |placeholder|[{1, true}, {2, true}, {3, true}, {4, true}, {5, true}]   |
+---+-----------+-----------------------------------------------------------+
Note that you can explode the Res column later and aggregate to get the count for each individual threshold.
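For example, a sketch of that final step against the out dataframe above: explode Res, keep the entries whose Check is true, and count per threshold:
from pyspark.sql import functions as F

# one row per (threshold, boolean) pair, then count the true ones per threshold
(out.select(F.explode("Res").alias("r"))
    .where(F.col("r.Check"))
    .groupBy(F.col("r.Range").alias("threshold"))
    .count()
    .orderBy("threshold")
    .show())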