I'm looking for a function along the lines of
df.groupby('column').agg(sample(10))
so that I can take ten or so randomly-selected elements from each group.
This is specifically so I can read in a LazyFrame and work with a small sample of each group as opposed to the entire dataframe.
Update:
One approximate solution is:
df = lf.groupby('column').agg(
    pl.all().sample(.001)
)
df = df.explode(df.columns[1:])
Update 2
That approximate solution is just the same as sampling the whole dataframe and doing a groupby after. No good.
CodePudding user response:
Let's start with some dummy data:
import polars as pl

n = 100
seed = 0
df = pl.DataFrame(
    {
        "groups": (pl.arange(0, n, eager=True) % 5).shuffle(seed=seed),
        "values": pl.arange(0, n, eager=True).shuffle(seed=seed)
    }
)
df
shape: (100, 2)
┌────────┬────────┐
│ groups ┆ values │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞════════╪════════╡
│ 0 ┆ 55 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 0 ┆ 40 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 2 ┆ 57 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 4 ┆ 99 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ ... ┆ ... │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 2 ┆ 87 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 1 ┆ 96 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 3 ┆ 43 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 4 ┆ 44 │
└────────┴────────┘
This gives us 100 / 5 = 20, that is, 5 groups of 20 elements each. Let's verify that:
df.groupby("groups").agg(pl.count())
shape: (5, 2)
┌────────┬───────┐
│ groups ┆ count │
│ --- ┆ --- │
│ i64 ┆ u32 │
╞════════╪═══════╡
│ 1 ┆ 20 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 3 ┆ 20 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 4 ┆ 20 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 2 ┆ 20 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 0 ┆ 20 │
└────────┴───────┘
Sample our data
Now we are going to use a window function to take a sample of our data.
df.filter(
    pl.arange(0, pl.count()).shuffle().over("groups") < 10
)
shape: (50, 2)
┌────────┬────────┐
│ groups ┆ values │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞════════╪════════╡
│ 0 ┆ 85 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 0 ┆ 0 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 4 ┆ 84 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 4 ┆ 19 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ ... ┆ ... │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 2 ┆ 87 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 1 ┆ 96 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 3 ┆ 43 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 4 ┆ 44 │
└────────┴────────┘
For every group in over("groups"), the pl.arange(0, pl.count()) expression creates a row index. We then shuffle that range so that we take a sample and not a slice. Finally, we keep only the index values that are lower than 10. This creates a boolean mask that we can pass to the filter method.
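To make the mechanics of that mask concrete, here is a minimal plain-Python sketch of the same idea, without Polars: number the rows within each group, shuffle those numbers, and keep the rows whose shuffled number is below n. The function name sample_mask and its seed parameter are made up for illustration, and this is not a claim about how Polars evaluates the expression internally.

```python
import random
from collections import defaultdict


def sample_mask(groups, n, seed=0):
    """Return a boolean mask that keeps at most n random rows per group."""
    rng = random.Random(seed)

    # Collect the row positions that belong to each group.
    positions = defaultdict(list)
    for row, key in enumerate(groups):
        positions[key].append(row)

    mask = [False] * len(groups)
    for rows in positions.values():
        # Per-group equivalent of arange(0, count), followed by a shuffle.
        ranks = list(range(len(rows)))
        rng.shuffle(ranks)
        # Keep only the rows whose shuffled index is lower than n.
        for row, rank in zip(rows, ranks):
            mask[row] = rank < n
    return mask


# Same shape as the dummy data: 100 rows spread over 5 groups of 20.
groups = [i % 5 for i in range(100)]
mask = sample_mask(groups, n=10)
kept = [g for g, keep in zip(groups, mask) if keep]
print({g: kept.count(g) for g in sorted(set(kept))})
```

Because the comparison is on a shuffled per-group index, a group with fewer than n rows is simply kept whole rather than raising an error.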
CodePudding user response:
We can try making our own groupby-like functionality and sampling from the filtered subsets.
samples = []
cats = df.get_column('column').unique().to_list()
for cat in cats:
    samples.append(df.filter(pl.col('column') == cat).sample(10))
samples = pl.concat(samples)
I found partition_by in the documentation. This should be more efficient, since at least the groups are made by the API in a single pass over the dataframe. Unfortunately, the groups are still sampled one after another.
pl.concat([x.sample(10) for x in df.partition_by("column")])
Third attempt, sampling indices:
import numpy as np
import random
indices = df.groupby("group").agg(pl.col("value").agg_groups()).get_column("value").to_list()
sampled = np.array([random.sample(x, 10) for x in indices]).flatten()
df[sampled]
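One caveat with this attempt: random.sample raises ValueError when a group has fewer than 10 rows. Below is a small sketch of just the index-sampling step that caps the sample size per group. sample_indices is a hypothetical helper, and the indices list here is a hand-written stand-in for the per-group row indices built above.

```python
import random


def sample_indices(group_indices, n, seed=0):
    """Pick up to n row indices from each group's list of indices."""
    rng = random.Random(seed)
    sampled = []
    for idx in group_indices:
        # random.sample raises ValueError if n > len(idx), so cap it.
        sampled.extend(rng.sample(idx, min(n, len(idx))))
    return sampled


# Hand-written stand-in: three groups with 4, 3 and 6 rows.
indices = [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8, 10, 11, 12]]
picked = sample_indices(indices, n=4)
print(len(picked))  # 4 + 3 + 4 = 11
```

The result is already a flat list of row indices, so the ragged group sizes never have to go through np.array(...).flatten().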