Stratified sampler

Time:10-01

I have a dataframe with the following structure:


import pandas as pd


df = pd.DataFrame({
    "x": [0, 0, 1, 1, 0, 0, 1, 1],
    "y": [1, 2, 1, 2, 2, 2, 1, 1],
})

I want a function that generates a random column containing "A" and "B" such that, for a given subset of the columns (the strata, say "x"), every group has the same number of "A" and "B". If we choose "x" as the strata column, a possible outcome is:


import pandas as pd


df = pd.DataFrame({
    "x": [0, 0, 1, 1, 0, 0, 1, 1],
    "y": [1, 2, 1, 2, 2, 2, 1, 1],
    "outcome": ["A", "B", "A", "B", "A", "B", "A", "B"]
})

Keep in mind that the subset might contain both x and y (meaning that, for each pair of x, y values, we should have the same number of A and B). Of course, if a group has an odd number of rows, an exact split is impossible; in that case there should be at most one more A than B, or vice versa.
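To make the requirement concrete, here is a small checker for it. `is_balanced` is a hypothetical helper name, not something from the question, and it assumes the labels are "A" and "B" unless told otherwise: within every stratum, the label counts may differ by at most 1.

```python
import pandas as pd


def is_balanced(df, strata, col="outcome", labels=("A", "B")):
    """Hypothetical helper: True if, inside every stratum (combination of the
    `strata` columns), the counts of `labels` in `col` differ by at most 1."""
    # rows = strata groups, columns = labels; labels absent from a group count as 0
    counts = (
        pd.crosstab([df[c] for c in strata], df[col])
        .reindex(columns=list(labels), fill_value=0)
    )
    spread = counts.max(axis=1) - counts.min(axis=1)
    return bool((spread <= 1).all())


df = pd.DataFrame({
    "x": [0, 0, 1, 1, 0, 0, 1, 1],
    "y": [1, 2, 1, 2, 2, 2, 1, 1],
    "outcome": ["A", "B", "A", "B", "A", "B", "A", "B"],
})
print(is_balanced(df, ["x"]), is_balanced(df, ["x", "y"]))
```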

Can you help me with that? Thanks!

Answer:

If you want an exact split within each group (up to parity, i.e. an odd-sized group gets one extra row of one label), you can use groupby.sample (pandas 1.1 or later):

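import numpy as np

# randomly pick half of the rows in each 'x' group; those rows get "A"
chosen = df.groupby('x').sample(frac=0.5).index

# every row not picked gets "B"
df['outcome'] = np.where(df.index.isin(chosen), 'A', 'B')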

example output:

   x  y outcome
0  0  1       A
1  0  2       A
2  1  1       B
3  1  2       A
4  0  2       B
5  0  2       B
6  1  1       A
7  1  1       B
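The same idea also covers the multi-column case from the question, because `groupby` accepts a list of columns. A minimal sketch, assuming a unique index (needed for `isin`) and pandas 1.1 or later; `assign_ab` and its `seed` parameter are hypothetical names:

```python
import numpy as np
import pandas as pd


def assign_ab(df, strata, seed=None):
    """Hypothetical helper: return "A"/"B" labels aligned with df so that each
    stratum (combination of the `strata` columns) is split as evenly as possible.
    Assumes df has a unique index."""
    # randomly draw half of the rows of every group; those rows become "A"
    chosen = df.groupby(strata).sample(frac=0.5, random_state=seed).index
    # every other row becomes "B"
    return pd.Series(np.where(df.index.isin(chosen), "A", "B"), index=df.index)


df = pd.DataFrame({
    "x": [0, 0, 1, 1, 0, 0, 1, 1],
    "y": [1, 2, 1, 2, 2, 2, 1, 1],
})
df["outcome"] = assign_ab(df, ["x", "y"], seed=0)
print(df)
```

For odd-sized groups, pandas rounds `frac * group size` to get the sample size, so whether the extra row becomes an A or a B is decided by that rounding rather than at random.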

Generalization to N categories.

Here we need to change strategy: shuffle the DataFrame with sample(frac=1), number the rows within each group with cumcount, take that number modulo N, and finally map each value to a category:

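cat = ['A', 'B', 'C']
# optional: also shuffle the category order, so that in groups whose size is
# not a multiple of len(cat) the surplus rows don't always get the first categories
# np.random.shuffle(cat)

# shuffle rows, number them 0, 1, 2, ... within each 'x' group, take that number
# modulo len(cat) and map it to a category; assignment realigns on the index
df['outcome'] = (df.sample(frac=1)
                   .groupby('x').cumcount().mod(len(cat))
                   .map(dict(enumerate(cat)))
                )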

output:

   x  y outcome
0  0  1       C
1  0  2       A
2  1  1       A
3  1  2       A
4  0  2       B
5  0  2       A
6  1  1       C
7  1  1       B
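A self-contained variant of this cumcount recipe that accepts several strata columns (e.g. `['x', 'y']`, as the question asks) and a seed for reproducibility. `assign_categories` and its parameters are hypothetical names; this sketch draws both shuffles from NumPy's `default_rng` instead of `DataFrame.sample`, and it assumes a unique index:

```python
import numpy as np
import pandas as pd


def assign_categories(df, strata, cats, seed=None):
    """Hypothetical helper: return a Series (aligned with df) giving each row one
    of `cats`, so that within every stratum the category counts differ by at
    most 1. Assumes df has a unique index."""
    rng = np.random.default_rng(seed)
    # shuffle the category order once per call, so the surplus rows of groups
    # whose size is not a multiple of len(cats) don't always go to the same ones
    order = rng.permutation(len(cats))
    cats = [cats[i] for i in order]
    # shuffle the rows, then number them 0, 1, 2, ... inside each stratum
    shuffled = df.iloc[rng.permutation(len(df))]
    slot = shuffled.groupby(strata).cumcount() % len(cats)
    # map slot -> category, then restore the original row order
    return slot.map(dict(enumerate(cats))).reindex(df.index)


df = pd.DataFrame({
    "x": [0, 0, 1, 1, 0, 0, 1, 1],
    "y": [1, 2, 1, 2, 2, 2, 1, 1],
})
df["outcome"] = assign_categories(df, ["x", "y"], ["A", "B", "C"], seed=42)
print(df)
```

With `cats=["A", "B"]` this also covers the original two-label problem, and the random category order decides which label gets the extra row in odd-sized groups.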

Check of the distribution on a large sample (share of each outcome within each x group):

x  outcome
0  A          0.333333
   B          0.333333
   C          0.333333
1  A          0.333333
   B          0.333333
   C          0.333333
dtype: float64
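The answer does not show how this check was produced. One way to compute a similar table, as a sketch: the frame size, the seeds and the name `big` are arbitrary choices, and the printed layout varies across pandas versions.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
cat = ["A", "B", "C"]

# a large random frame with a two-valued strata column
big = pd.DataFrame({"x": rng.integers(0, 2, size=300_000)})

# same shuffle / cumcount / modulo / map recipe as in the answer
big["outcome"] = (
    big.sample(frac=1, random_state=0)
    .groupby("x").cumcount().mod(len(cat))
    .map(dict(enumerate(cat)))
)

# share of each category within each x group
share = big.groupby("x")["outcome"].value_counts(normalize=True)
print(share)
```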