Sample Pandas dataframe based on multiple values in column

Time:03-24

I'm trying to balance a dataset for machine learning. There are good answers on how to sample a dataframe when a column has only two values (a binary choice).

In my case, column x has many values. I want an equal number of records in the dataframe for each of these conditions:

  • x is 0 or not 0
  • or, in a more complicated example, x is 0, 5, or some other value

Examples

     x
0    5
1    5
2    5
3    0
4    0
5    9
6   18
7    3
8    5
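To reproduce the examples, the frame above can be built as follows. This is a minimal sketch that assumes only the values shown and the default 0-8 index; the name df matches the answers below.

```python
import pandas as pd

# Example frame from the question (default RangeIndex 0..8).
df = pd.DataFrame({"x": [5, 5, 5, 0, 0, 9, 18, 3, 5]})

print(df["x"].eq(0).sum())  # number of rows where x == 0
print(df["x"].eq(5).sum())  # number of rows where x == 5
```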

For the first case, I have 2 rows where x = 0 and 7 where x != 0. The result should balance this and contain 4 rows: the two with x = 0 and 2 randomly selected rows where x != 0. The original index is preserved for illustration:

1    5
3    0
4    0
6   18

For the second case, I have 2 rows where x = 0, 4 rows where x = 5, and 3 rows where x != 0 and x != 5. The result should balance this and contain 6 rows in total: two for each condition. The original index is preserved for illustration:

1    5
3    0
4    0
5    9
6   18
8    5

I've given examples with 2 and 3 conditions, but a solution that generalises to more would be good. Ideally it would detect the minimum group size (the x = 0 rows in this example) so I don't have to work it out before writing the conditions.

How do I do this with pandas? Can I pass a custom function to .groupby() to do this?

Answer:

IIUC, you could group by whether "x" is 0 and, from each group, sample as many rows as the smallest group contains:

g = df.groupby(df['x']==0)['x']
out = g.sample(n=g.count().min()).sort_index()

Example output:

1    5
3    0
4    0
5    9
Name: x, dtype: int64
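On the question's "Can I pass a custom function to .groupby()?": a callable passed directly to .groupby() is called on each index label, not on the column values, so one practical route is to map column x through the function and group by the resulting labels. The sketch below applies the answer's idea to the three-way split that way; the function bucket and its labels are hypothetical, and random_state is added only to make the draw repeatable.

```python
import pandas as pd

df = pd.DataFrame({"x": [5, 5, 5, 0, 0, 9, 18, 3, 5]})

def bucket(value):
    """Hypothetical labelling function: one label per condition."""
    if value == 0:
        return "zero"
    if value == 5:
        return "five"
    return "other"

# Map each value of column x to a label, then group on those labels.
labels = df["x"].map(bucket)
g = df.groupby(labels)

# Size of the smallest group, found automatically (the x == 0 group here).
n = g.size().min()

# Draw n rows from every group; sort to restore the original row order.
out = g.sample(n=n, random_state=0).sort_index()
print(out)
```

Adding a condition only means adding a branch to bucket; the minimum group size is still computed for you.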

For the second case, we could use numpy.select to label the groups and numpy.unique to count them (the rest is essentially the same as above):

import numpy as np
groups = np.select([df['x']==0, df['x']==5], [1,2], 3)
g = df.groupby(groups)['x']
out = g.sample(n=np.unique(groups, return_counts=True)[1].min()).sort_index()

An example output:

2    5
3    0
4    0
5    9
7    3
8    5
Name: x, dtype: int64
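To generalise to any number of conditions, as the question asks, the same approach can be wrapped in a helper that takes a list of boolean masks. This is a sketch under assumptions: the name balance and its signature are hypothetical, it relies on DataFrameGroupBy.sample (pandas 1.1+), and it uses g.size().min() instead of numpy.unique because the grouped object already reports each group's size.

```python
import numpy as np
import pandas as pd

def balance(df, conditions, random_state=None):
    """Hypothetical helper: equal rows per condition plus a catch-all group.

    `conditions` is a list of boolean masks; rows matching none of them form
    an extra group. As with np.select, the first matching condition wins.
    """
    labels = np.select(conditions, list(range(len(conditions))), default=len(conditions))
    g = df.groupby(labels)
    n = g.size().min()  # smallest group size, detected automatically
    return g.sample(n=n, random_state=random_state).sort_index()

df = pd.DataFrame({"x": [5, 5, 5, 0, 0, 9, 18, 3, 5]})

# Two groups: x == 0 vs everything else.
two = balance(df, [df["x"] == 0], random_state=0)
print(two)

# Three groups: x == 0, x == 5, everything else.
three = balance(df, [df["x"] == 0, df["x"] == 5], random_state=0)
print(three)
```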

Answer:

IIUC, you want all the x = 0 records plus the same number of randomly chosen non-zero records:

import pandas as pd

mask = df['x'].eq(0)
pd.concat([df[mask], df[~mask].sample(mask.sum())]).sort_index()

Output:

   x
1  5
2  5
3  0
4  0

Part II:

mask0 = df['x'].eq(0)
mask5 = df['x'].eq(5)
pd.concat([df[mask0], 
           df[mask5].sample(mask0.sum()), 
           df[~(mask0 | mask5)].sample(mask0.sum())]).sort_index()

Output:

    x
2   5
3   0
4   0
6  18
7   3
8   5
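Part II passes mask0.sum() as the sample size, which works here because the x = 0 group is the smallest; if another group were smaller, .sample() would raise a ValueError, since it cannot draw more rows than a group has without replace=True. A sketch that instead takes the minimum over all masks, assuming the masks are disjoint (overlapping masks could select the same row twice); the helper name balance_masks is hypothetical:

```python
import pandas as pd

def balance_masks(df, masks, random_state=None):
    """Hypothetical helper: sample the same number of rows for every mask.

    The sample size is the smallest mask count, so no group has to be
    assumed to be the smallest in advance. Masks should not overlap.
    """
    n = min(int(m.sum()) for m in masks)
    parts = [df[m].sample(n=n, random_state=random_state) for m in masks]
    return pd.concat(parts).sort_index()

df = pd.DataFrame({"x": [5, 5, 5, 0, 0, 9, 18, 3, 5]})
mask0 = df["x"].eq(0)
mask5 = df["x"].eq(5)

out = balance_masks(df, [mask0, mask5, ~(mask0 | mask5)], random_state=0)
print(out)
```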