How would I subtract N specific rows from a PySpark DataFrame?


I have a dataframe, nsdf, which I would like to sample 5% of. nsdf looks something like this:

col1
8
7
7
8
7
8
8
7
(... and so on)

I sample nsdf like so:

sdf = nsdf.sample(0.05)

I would then like to remove the rows in sdf from nsdf. I would think I could use nsdf.subtract(sdf) here, but that would remove all rows in nsdf that match any row from sdf. For example, if sdf contained

col1
7
8

then every row in nsdf would be removed, as every row is either a 7 or an 8. Is there a way to remove only as many 7s/8s (or whatever else) as appear in sdf? More specifically, in this example I would like to end up with an nsdf that contains the same data, but with one fewer 7 and one fewer 8.

CodePudding user response:

The behavior of subtract is to remove all instances of a row from the left dataframe if that row is present in the right dataframe. What you are looking for is exceptAll, which removes only as many occurrences of each row as appear in the right dataframe.

Example:

Data Setup

df = spark.createDataFrame([(7,), (8,), (7,), (8,)], ("col1", ))

Scenario 1:

df1 = spark.createDataFrame([(7,), (8,)], ("col1", ))

df.exceptAll(df1).show()

Output

+----+
|col1|
+----+
|   7|
|   8|
+----+
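For contrast, subtract performs a set difference on whole rows, so with the same inputs it removes every 7 and every 8 and returns an empty result:

df.subtract(df1).show()

Output

+----+
|col1|
+----+
+----+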

Scenario 2:

df2 = spark.createDataFrame([(7,), (7,), (8,)], ("col1", ))

df.exceptAll(df2).show()

Output

+----+
|col1|
+----+
|   8|
+----+
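Applied back to the original question, the end-to-end flow would look something like the sketch below. Note that sample only approximates the requested fraction and accepts an optional seed for reproducibility; the seed value and the remaining variable name are just for illustration.

# sample roughly 5% of nsdf (the fraction is approximate, not exact)
sdf = nsdf.sample(0.05, seed=42)

# exceptAll removes each sampled row from nsdf exactly once,
# respecting duplicate counts (requires Spark 2.4+)
remaining = nsdf.exceptAll(sdf)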