How to filter or delete the row in spark dataframe by a specific number?

Time:05-21

I want to manipulate a Spark DataFrame. For example, here is a DataFrame with two columns:

+--------------------+--------------------+
|                 key|               value|
+--------------------+--------------------+
|1                   |Bob                 |
|2                   |Bob                 |
|3                   |Alice               |
|4                   |Alice               |
|5                   |Alice               |
............

The column value contains only two names, and there are more rows with Alice than with Bob. I want to delete some of the Alice rows so that the number of Alice rows equals the number of Bob rows. The rows should be deleted randomly, but I found no API that supports this. How can I delete rows down to a specific number?

Answer:

Here's some pseudocode:

  1. Count the "Bob" rows.
  2. Repartition or group the data by value (partitionBy/groupBy).
  3. Iterate over each partition or group and cut it off at Bob's count (mapPartitions/mapGroups).

Remember that Spark technically does not guarantee any row ordering in a Dataset, so adding new data can change the order of the rows unpredictably. You could therefore treat the order as random (strictly speaking it is arbitrary, not statistically random) and simply cut each group off at Bob's count when you're done. This should be faster than creating a window. If you really want true randomness, you could write your own random probability function that returns a fraction of each partition.

You can also do this with a window: partitionBy("value").orderBy("value"), then use row_number and where to filter each partition down to Bob's count.

Answer:

Perhaps you can use a Spark window function with row_number and then filter, something like this:

>>> df.show(truncate=False)
+---+-----+
|key|value|
+---+-----+
|1  |Bob  |
|2  |Bob  |
|3  |Alice|
|4  |Alice|
|5  |Alice|
+---+-----+

>>> from pyspark.sql import Window
>>> from pyspark.sql.functions import row_number
>>> window = Window.partitionBy("value").orderBy("value")
>>> df2 = df.withColumn("seq", row_number().over(window))
>>> df2.show(truncate=False)
+---+-----+---+
|key|value|seq|
+---+-----+---+
|1  |Bob  |1  |
|2  |Bob  |2  |
|3  |Alice|1  |
|4  |Alice|2  |
|5  |Alice|3  |
+---+-----+---+

>>> N = 2
>>> df3 = df2.where("seq <= %d" % N).drop("seq")
>>> df3.show(truncate=False)
+---+-----+
|key|value|
+---+-----+
|1  |Bob  |
|2  |Bob  |
|3  |Alice|
|4  |Alice|
+---+-----+

>>> 