How to filter a dataframe with a specific condition in Spark

Time:11-14

I have the following DF:

Cod  Category  N
1    A         1 
1    A         2
1    A         3
1    B         1
1    B         2
1    B         3
1    B         4
1    B         5
2    D         1
3    Z         1
3    Z         2
3    Z         3
3    Z         4

I need to filter this DF so that, whenever any row of a category has N > 3, all rows for that category are retrieved. My expected output, to keep the example simple:

Cod  Category  N
1    B         1
1    B         2
1    B         3
1    B         4
1    B         5
3    Z         1
3    Z         2
3    Z         3
3    Z         4

How can I implement this type of filter? I tried to use window functions to generate another column with a flag indicating which rows to keep, but with no success.
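To pin the condition down precisely: a row is kept when any row in the same (Cod, Category) group has N > 3. The plain-Python sketch below is not Spark code; the `rows` list and the `filter_groups` helper are hypothetical names used only to show the rule on the sample data from the question.

```python
# Sample data from the question as (Cod, Category, N) tuples.
rows = [
    (1, "A", 1), (1, "A", 2), (1, "A", 3),
    (1, "B", 1), (1, "B", 2), (1, "B", 3), (1, "B", 4), (1, "B", 5),
    (2, "D", 1),
    (3, "Z", 1), (3, "Z", 2), (3, "Z", 3), (3, "Z", 4),
]


def filter_groups(rows, threshold=3):
    """Keep every row whose (Cod, Category) group contains some N > threshold."""
    # First pass: record which groups qualify.
    qualifying = set()
    for cod, category, n in rows:
        if n > threshold:
            qualifying.add((cod, category))
    # Second pass: keep all rows of qualifying groups, in their original order.
    return [r for r in rows if (r[0], r[1]) in qualifying]


for row in filter_groups(rows):
    print(row)
```

Groups A and D never exceed 3, so none of their rows survive, while every row of B and Z is kept, matching the expected output above.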

Answer:

You can use a window to attach to each row the maximum N present in its category. Then apply your condition to this new column to filter the categories.

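from pyspark.sql import Window
import pyspark.sql.functions as F

# one partition per (Cod, Category) group
w = Window.partitionBy("Cod", "Category")

# attach the group's maximum N to every row
df = df.withColumn("max_N_in_category", F.max("N").over(w))

N = 3
# keep only groups whose maximum exceeds N, then drop the helper column
df = df \
    .filter(F.col("max_N_in_category") > N) \
    .drop("max_N_in_category")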
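The window approach works in three steps: compute the maximum N per partition, attach it to every row of that partition, then filter on it and drop the helper column. A plain-Python model of those steps might look like the sketch below. It is not Spark; `window_max_filter` is a hypothetical name, and the dicts stand in for DataFrame rows (a subset of the question's data).

```python
# Rows as dicts, standing in for DataFrame rows.
rows = [
    {"Cod": 1, "Category": "A", "N": 3},
    {"Cod": 1, "Category": "B", "N": 1},
    {"Cod": 1, "Category": "B", "N": 5},
    {"Cod": 2, "Category": "D", "N": 1},
]


def window_max_filter(rows, threshold=3, keys=("Cod", "Category")):
    # Step 1: max N per partition, analogous to F.max("N").over(Window.partitionBy(*keys)).
    group_max = {}
    for r in rows:
        k = tuple(r[c] for c in keys)
        group_max[k] = max(group_max.get(k, r["N"]), r["N"])
    # Step 2: attach the partition max to each row as a helper column.
    annotated = [
        dict(r, max_N_in_category=group_max[tuple(r[c] for c in keys)])
        for r in rows
    ]
    # Step 3: filter on the helper column, then drop it.
    return [
        {c: v for c, v in r.items() if c != "max_N_in_category"}
        for r in annotated
        if r["max_N_in_category"] > threshold
    ]


print(window_max_filter(rows))
```

Because the maximum is attached to every row rather than collapsing each group to one row, all original rows of a qualifying group remain available to the filter, which is what makes the window version return whole categories.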

Answer:

Data

from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import collect_list, expr

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [(1, 'A', 1), (1, 'A', 2), (1, 'A', 3),
     (1, 'B', 1), (1, 'B', 2), (1, 'B', 3), (1, 'B', 4), (1, 'B', 5),
     (2, 'D', 1),
     (3, 'Z', 1), (3, 'Z', 2), (3, 'Z', 3), (3, 'Z', 4)],
    ('Cod', 'Category', 'N'))

new = (df
       # collect every N of the group into an array column 'check'
       .withColumn('check', collect_list('N').over(Window.partitionBy('Cod', 'Category')))
       # keep groups whose array has at least one element greater than 3
       .where(expr("exists(check, c -> c > 3)"))
       # drop the helper column
       .drop('check'))

new.orderBy('Cod', 'Category', 'N').show()

outcome

+---+--------+---+
|Cod|Category|  N|
+---+--------+---+
|  1|       B|  1|
|  1|       B|  2|
|  1|       B|  3|
|  1|       B|  4|
|  1|       B|  5|
|  3|       Z|  1|
|  3|       Z|  2|
|  3|       Z|  3|
|  3|       Z|  4|
+---+--------+---+
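One detail worth checking in an array-based predicate: the question asks for groups that contain some N > 3, which is a different test from "the group contains the value 3". A small plain-Python comparison over per-group lists shows where the two disagree. The `groups` dict only mimics the collected `check` column and is illustrative.

```python
# Per-group lists of N, like the 'check' column built with collect_list over the window.
groups = {
    (1, "A"): [1, 2, 3],
    (1, "B"): [1, 2, 3, 4, 5],
    (2, "D"): [1],
    (3, "Z"): [1, 2, 3, 4],
}

# "Some element is greater than 3", i.e. exists(check, c -> c > 3).
greater_than_3 = sorted(k for k, ns in groups.items() if any(c > 3 for c in ns))
# "The list contains the value 3", i.e. array_contains(check, 3).
contains_3 = sorted(k for k, ns in groups.items() if 3 in ns)

print(greater_than_3)  # [(1, 'B'), (3, 'Z')]
print(contains_3)      # [(1, 'A'), (1, 'B'), (3, 'Z')]
```

Group (1, 'A') has 3 as its largest value, so it passes the "contains 3" test while failing N > 3. In Spark SQL, `exists(check, c -> c > 3)` or `array_max(check) > 3` expresses the "greater than" version.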

Answer:

If you don't want to use window functions, it can be done with groupBy and a filter using isin():

df.filter(df.Category.isin([x.Category for x in df.groupBy("Category").max("N").collect() if x["max(N)"] > 3]))

[Out]:
+---+--------+---+
|Cod|Category|  N|
+---+--------+---+
|  1|       B|  1|
|  1|       B|  2|
|  1|       B|  3|
|  1|       B|  4|
|  1|       B|  5|
|  3|       Z|  1|
|  3|       Z|  2|
|  3|       Z|  3|
|  3|       Z|  4|
+---+--------+---+