Sample from each group in polars dataframe?


I'm looking for a function along the lines of

df.groupby('column').agg(sample(10))

so that I can take ten or so randomly-selected elements from each group.

This is specifically so I can read in a LazyFrame and work with a small sample of each group as opposed to the entire dataframe.

Update:

One approximate solution is to sample a fraction inside the aggregation and explode the resulting list columns back into rows:

df = lf.groupby('column').agg(
        pl.all().sample(fraction=0.001)
    )
df = df.explode(df.columns[1:])

Update 2

That approximate solution is effectively the same as sampling the whole dataframe and doing a groupby afterwards: it draws a fraction from each group rather than a fixed number of rows per group. No good.
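One lazy-friendly sketch for a fixed-size sample per group is to shuffle inside the aggregation and take the head. Note the fixed seed is essential here: it is an assumption of this sketch that every column receives the same permutation, which only holds when they share a seed, so rows stay intact.

df = lf.groupby('column').agg(
        # one shared seed => identical permutation for every column, rows stay aligned
        pl.all().shuffle(seed=0).head(10)
    )
df = df.explode(df.columns[1:])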

CodePudding user response:

Let's start with some dummy data:

import polars as pl

n = 100
seed = 0
df = pl.DataFrame(
    {
        "groups": (pl.arange(0, n, eager=True) % 5).shuffle(seed=seed),
        "values": pl.arange(0, n, eager=True).shuffle(seed=seed)
    }
)
df
shape: (100, 2)
┌────────┬────────┐
│ groups ┆ values │
│ ---    ┆ ---    │
│ i64    ┆ i64    │
╞════════╪════════╡
│ 0      ┆ 55     │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 0      ┆ 40     │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 2      ┆ 57     │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 4      ┆ 99     │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ ...    ┆ ...    │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 2      ┆ 87     │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 1      ┆ 96     │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 3      ┆ 43     │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 4      ┆ 44     │
└────────┴────────┘

This gives us 100 / 5 = 5 groups of 20 elements each. Let's verify that:

df.groupby("groups").agg(pl.count())
shape: (5, 2)
┌────────┬───────┐
│ groups ┆ count │
│ ---    ┆ ---   │
│ i64    ┆ u32   │
╞════════╪═══════╡
│ 1      ┆ 20    │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 3      ┆ 20    │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 4      ┆ 20    │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 2      ┆ 20    │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 0      ┆ 20    │
└────────┴───────┘

Sample our data

Now we are going to use a window function to take a sample of our data.

df.filter(
    pl.arange(0, pl.count()).shuffle().over("groups") < 10
)
shape: (50, 2)
┌────────┬────────┐
│ groups ┆ values │
│ ---    ┆ ---    │
│ i64    ┆ i64    │
╞════════╪════════╡
│ 0      ┆ 85     │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 0      ┆ 0      │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 4      ┆ 84     │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 4      ┆ 19     │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ ...    ┆ ...    │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 2      ┆ 87     │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 1      ┆ 96     │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 3      ┆ 43     │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 4      ┆ 44     │
└────────┴────────┘

For every group in over("group") the pl.arange(0, pl.count()) expression creates an index row. We then shuffle that range so that we take a sample and not a slice. Then we only want to take the index values that are lower than 10. This creates a boolean mask that we can pass to the filter method.

CodePudding user response:

We can try making our own groupby-like functionality and sampling from the filtered subsets.

samples = []
cats = df.get_column('column').unique().to_list()
for cat in cats:
    # filter down to a single category, then draw 10 rows from it
    samples.append(df.filter(pl.col('column') == cat).sample(10))
samples = pl.concat(samples)

I found partition_by in the documentation; this should be more efficient, since at least the groups are made through the API in a single pass over the dataframe. Sampling each group is still linear, unfortunately.

pl.concat([x.sample(10) for x in df.partition_by("column")])
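One caveat: DataFrame.sample(10) raises an error for any partition holding fewer than 10 rows. A small guard using DataFrame.height (a sketch of the same idea) avoids that:

pl.concat([
    x.sample(min(10, x.height))  # never request more rows than the partition has
    for x in df.partition_by("column")
])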

Third attempt, sampling row indices (here the group column is named "group" and the value column "value"):

import numpy as np
import random

indices = df.groupby("group").agg(pl.col("value").agg_groups()).get_column("value").to_list()
sampled = np.array([random.sample(x, 10) for x in indices]).flatten()
df[sampled]