I received an unknown error in Python Polars:
thread '<unnamed>' panicked at 'assertion failed: `(left == right)`
left: `Float64[NaN, 1, NaN, NaN, NaN, ...[clip]...
right: `Float64[NaN, 1, NaN, NaN, NaN, ...[clip]...
Is this an internal error?
The code that triggers it is:
df.select([
    pl.col('total').shift().ewm_mean(half_life=10).over('group')
])
It's hard for me to ask more because the error is so inscrutable.
CodePudding user response:
Another temporary workaround is to build the result of shift with an over window in a different way.
Let's say we have the following groups, numbered observations, and totals.
import numpy as np
import polars as pl
df = pl.DataFrame(
    {
        "group": ["a", "a", "b", "a", "b", "b"],
        "obs": [1, 2, 1, 3, 2, 3],
        "total": [1.0, 2, 3, 4, 5, np.nan],
    }
)
df
shape: (6, 3)
┌───────┬─────┬───────┐
│ group ┆ obs ┆ total │
│ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ f64 │
╞═══════╪═════╪═══════╡
│ a ┆ 1 ┆ 1.0 │
├╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ a ┆ 2 ┆ 2.0 │
├╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ b ┆ 1 ┆ 3.0 │
├╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ a ┆ 3 ┆ 4.0 │
├╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ b ┆ 2 ┆ 5.0 │
├╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ b ┆ 3 ┆ NaN │
└───────┴─────┴───────┘
The following code arrives at the same result as a shift over the groups:
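df = (
    # Sort so each group's rows are contiguous and in observation order.
    df.sort(["group", "obs"])
    # Shift the whole column by one row, ignoring group boundaries.
    .with_column(pl.col("total").shift().alias("total_shifted"))
    # Null out the first row of each group, which has no predecessor in its group.
    .with_column(
        pl.when(pl.col("group").is_first())
        .then(None)
        .otherwise(pl.col("total_shifted"))
        .alias("result")
    )
)
df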
shape: (6, 5)
┌───────┬─────┬───────┬───────────────┬────────┐
│ group ┆ obs ┆ total ┆ total_shifted ┆ result │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ f64 ┆ f64 ┆ f64 │
╞═══════╪═════╪═══════╪═══════════════╪════════╡
│ a ┆ 1 ┆ 1.0 ┆ null ┆ null │
├╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ a ┆ 2 ┆ 2.0 ┆ 1.0 ┆ 1.0 │
├╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ a ┆ 3 ┆ 4.0 ┆ 2.0 ┆ 2.0 │
├╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ b ┆ 1 ┆ 3.0 ┆ 4.0 ┆ null │
├╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ b ┆ 2 ┆ 5.0 ┆ 3.0 ┆ 3.0 │
├╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ b ┆ 3 ┆ NaN ┆ 5.0 ┆ 5.0 │
└───────┴─────┴───────┴───────────────┴────────┘
(I've left the intermediate calculations in the dataset for inspection, to show how the algorithm works.)
Notice that the result column holds the same values you would obtain from a shift over groups. You can then run your aggregations on the result column without using shift at all.
df.select([
    pl.col('result').ewm_mean(half_life=10).over('group')
])
Of course, you'll have to adapt this to your particular code, but it should work.
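To make the logic of the workaround easier to follow, here is a minimal sketch of the same sort, shift, and mask steps in plain Python, without Polars. The variable names (rows, total_shifted, is_first, result) are only illustrative, and None stands in for a Polars null. It is meant to show the idea, not to replace the Polars code above.

```python
# Plain-Python sketch of the sort / shift / mask steps (no Polars needed).
rows = [
    ("a", 1, 1.0), ("a", 2, 2.0), ("b", 1, 3.0),
    ("a", 3, 4.0), ("b", 2, 5.0), ("b", 3, float("nan")),
]

# Step 1: sort by group, then by observation number.
ordered = sorted(rows, key=lambda r: (r[0], r[1]))
groups = [r[0] for r in ordered]
totals = [r[2] for r in ordered]

# Step 2: shift the whole column down by one row, ignoring groups.
total_shifted = [None] + totals[:-1]

# Step 3: flag the first row of each group (the data is sorted, so a
# group starts wherever the group label changes).
is_first = [i == 0 or groups[i] != groups[i - 1] for i in range(len(groups))]

# Step 4: null out the flagged rows; everything else keeps the shifted value.
result = [None if first else value for first, value in zip(is_first, total_shifted)]

print(result)  # [None, 1.0, 2.0, None, 3.0, 5.0]
```

The printed list matches the result column in the table above. The NaN in the last row of group b never reaches the result, because the shift pushes it off the end.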
CodePudding user response:
This does look like a bug. It occurs when shift is called inside a window function (over) on an expression that contains NaN values.
import polars as pl
import numpy as np
df = pl.DataFrame(
    {
        "group": ["a", "a", "a", "b", "b", "b"],
        "total": [1.0, 2, 3, 4, 5, np.nan],
    }
)
df.select([
    pl.col('total').shift().over('group')
])
thread '<unnamed>' panicked at 'assertion failed: `(left == right)`
left: `Float64[4, 5, NaN]`,
right: `Float64[4, 5, NaN]`', /github/workspace/polars/polars-core/src/series/unstable.rs:39:9
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/corey/.virtualenvs/StackOverflow3.10/lib/python3.10/site-packages/polars/internals/frame.py", line 4253, in select
self.lazy()
File "/home/corey/.virtualenvs/StackOverflow3.10/lib/python3.10/site-packages/polars/internals/lazy_frame.py", line 476, in collect
return self._dataframe_class._from_pydf(ldf.collect())
pyo3_runtime.PanicException: assertion failed: `(left == right)`
left: `Float64[4, 5, NaN]`,
right: `Float64[4, 5, NaN]`
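The left and right values in the panic print identically, which is what makes the message so confusing. One plausible explanation, which nothing in this thread confirms, is that the internal assertion compares the two series with ordinary floating-point equality, where NaN is never equal to NaN. The plain-Python sketch below illustrates that behaviour; the helper names elementwise_equal and nan_aware_equal are made up for this example and are not Polars functions.

```python
import math

# Two sequences that print identically, like the panic message.
left = [4.0, 5.0, math.nan]
right = [4.0, 5.0, math.nan]

def elementwise_equal(xs, ys):
    """Compare element by element with ordinary float equality."""
    return len(xs) == len(ys) and all(x == y for x, y in zip(xs, ys))

def nan_aware_equal(xs, ys):
    """Same comparison, but treat two NaNs in the same position as equal."""
    return len(xs) == len(ys) and all(
        x == y or (math.isnan(x) and math.isnan(y)) for x, y in zip(xs, ys)
    )

print(elementwise_equal(left, right))  # False: NaN != NaN
print(nan_aware_equal(left, right))    # True
```

If the assertion works this way, any data without NaN values would pass it, which is consistent with the bug only showing up when NaN is present.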
Since you are using the sum aggregation, can you use fill_nan(0) to work around the issue? Or do you need to retain the NaN value in those cases?
df.select([
    pl.col('total').fill_nan(0).shift().sum().over('group')
])
shape: (6, 1)
┌─────────┐
│ literal │
│ --- │
│ f64 │
╞═════════╡
│ 3.0 │
├╌╌╌╌╌╌╌╌╌┤
│ 3.0 │
├╌╌╌╌╌╌╌╌╌┤
│ 3.0 │
├╌╌╌╌╌╌╌╌╌┤
│ 9.0 │
├╌╌╌╌╌╌╌╌╌┤
│ 9.0 │
├╌╌╌╌╌╌╌╌╌┤
│ 9.0 │
└─────────┘
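As a sanity check on the numbers in that output, here is a small plain-Python sketch of the same fill, shift, and sum per group. The function name shifted_group_sums is hypothetical, and the sketch assumes that the sum skips the null introduced by the shift.

```python
import math

groups = ["a", "a", "a", "b", "b", "b"]
totals = [1.0, 2.0, 3.0, 4.0, 5.0, math.nan]

def shifted_group_sums(groups, totals):
    """Replace NaN with 0, shift each group by one row, and sum per group."""
    per_group = {}
    for g, t in zip(groups, totals):
        per_group.setdefault(g, []).append(0.0 if math.isnan(t) else t)
    # Shifting by one row drops the last value of each group; the leading
    # null contributes nothing to the sum, so summing vals[:-1] is enough.
    return {g: sum(vals[:-1]) for g, vals in per_group.items()}

sums = shifted_group_sums(groups, totals)
# Broadcast each group's sum back to its rows, as a window function does.
broadcast = [sums[g] for g in groups]
print(broadcast)  # [3.0, 3.0, 3.0, 9.0, 9.0, 9.0]
```

In this particular dataset the NaN sits in the last row of group b, which the shift drops anyway, so filling it with 0 does not change the sum. If a NaN appeared earlier in a group, fill_nan(0) would replace a NaN result with a finite one, so it is worth checking whether that is acceptable for your data.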
I'll create an issue for it on GitHub.