Spark - how to use function in where condition?


I am trying to use the average (f.avg("age")) as the condition, as below.

from pyspark.sql import SparkSession
from pyspark.sql import functions as f
...
...
(
    titanic
    .where("Survived = 1")
    .where(f.col("age") > f.avg("age"))
    .groupBy("Sex", "Pclass")
    .count()
    .show()
)

But it does not work. How can I do it?

Something similar using Spark SQL (this works):

spark.sql("""
    select sex, Pclass,count(*) from titan where Survived = 1 
    and age > (select avg(age) from titan) group by sex,Pclass
""").show()

CodePudding user response:

An aggregate like f.avg cannot be used directly inside where; you have to compute the average first. I assume you want the average age per ('Sex', 'Pclass') group; you can change the partitioning if that is not what you want.

from pyspark.sql import functions as f
from pyspark.sql.window import Window

(
    titanic
    # average age within each ('Sex', 'Pclass') partition, computed over all rows
    .withColumn("avg_age", f.mean("age").over(Window.partitionBy("Sex", "Pclass")))
    .where("Survived = 1")
    .where(f.col("age") > f.col("avg_age"))
    .groupBy("Sex", "Pclass")
    .count()
    .show()
)
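
If you instead want the exact equivalent of your SQL query, where the subquery computes a single global average rather than a per-group one, a simple option is to collect the aggregate into a Python value first and filter against it. A minimal sketch, assuming titanic is the DataFrame from your question:

from pyspark.sql import functions as f

# global average age, matching the SQL subquery (select avg(age) from titan);
# first() returns a Row, so [0] extracts the plain Python value
avg_age = titanic.select(f.avg("age")).first()[0]

(
    titanic
    .where("Survived = 1")
    .where(f.col("age") > avg_age)  # column compared against the collected scalar
    .groupBy("Sex", "Pclass")
    .count()
    .show()
)

Since first() pulls the one-row aggregate back to the driver, the comparison in where is just a column-versus-literal comparison, which Spark accepts.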