Polars - Perform matrix inner product on lazy frames to produce sparse representation of gram matrix

Time:08-19

Suppose we have a polars dataframe like:

import polars as pl

df = pl.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5]}).lazy()

shape: (3, 2)
┌─────┬─────┐
│ a   ┆ b   │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 1   ┆ 3   │
├╌╌╌╌╌┼╌╌╌╌╌┤
│ 2   ┆ 4   │
├╌╌╌╌╌┼╌╌╌╌╌┤
│ 3   ┆ 5   │
└─────┴─────┘

I would like to compute X^T X (the Gram matrix) and keep the result in a sparse, long (row, column, value) format that suits Arrow*. In pandas I would do something like:

pdf = df.collect().to_pandas()
numbers = pdf[["a", "b"]]
(numbers.T @ numbers).melt(ignore_index=False)

  variable  value
a        a     14
b        a     26
a        b     26
b        b     50
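
To pin down the target layout, here is a plain-Python sketch (no pandas or Polars; gram_long is a hypothetical helper name, not a library function). It computes every entry G[i][j] = sum over rows k of X[k][i] * X[k][j] as a (row, column, value) triple, which carries the same information as the melted pandas output above.

```python
# Plain-Python sketch of the long (row, column, value) Gram-matrix format.
# gram_long is a hypothetical helper used only for illustration.

def gram_long(columns):
    """Return (row, col, value) triples for X^T X.

    `columns` maps each column name to its list of values; every
    entry is the dot product of two columns.
    """
    names = list(columns)
    triples = []
    for i in names:
        for j in names:
            value = sum(x * y for x, y in zip(columns[i], columns[j]))
            triples.append((i, j, value))
    return triples


data = {"a": [1, 2, 3], "b": [3, 4, 5]}
for triple in gram_long(data):
    print(triple)
# ('a', 'a', 14)
# ('a', 'b', 26)
# ('b', 'a', 26)
# ('b', 'b', 50)
```

For example, the (a, b) entry is 1*3 + 2*4 + 3*5 = 26, matching the pandas result.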

I did something like this in polars:

df.select(
    [
        (pl.col("a") * pl.col("a")).sum().alias("aa"),
        (pl.col("a") * pl.col("b")).sum().alias("ab"),
        (pl.col("b") * pl.col("a")).sum().alias("ba"),
        (pl.col("b") * pl.col("b")).sum().alias("bb"),
    ]
).melt().collect()

shape: (4, 2)
┌──────────┬───────┐
│ variable ┆ value │
│ ---      ┆ ---   │
│ str      ┆ i64   │
╞══════════╪═══════╡
│ aa       ┆ 14    │
├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ ab       ┆ 26    │
├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ ba       ┆ 26    │
├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ bb       ┆ 50    │
└──────────┴───────┘

This is almost there, but not quite. The combined names are a hack to get around the fact that I can't use lists as column names (if I could, I could unnest them into two columns representing the x and y axes of the matrix). Is there a way to get the same format as in the pandas example?

*Arrow is a columnar format, so it scales well with the number of rows but not with the number of columns. That is why I think a long, sparse-style representation of the Gram matrix is better if I want to chain the result with other pl.LazyFrames later in the graph. I could be wrong, though!

Answer:

Polars doesn't have matrix multiplication, but we can tweak your algorithm slightly to accomplish what we need:

  • use the built-in dot expression
  • calculate each inner product only once, since <a, b> = <b, a>. We'll use Python's combinations_with_replacement iterator from itertools to accomplish this.
  • automatically generate the list of expressions that will run in parallel
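
The second bullet can be checked in isolation. This plain-Python snippet (the column names a, b, c are chosen to match the expanded data) lists the pairs that combinations_with_replacement yields. For n columns there are n(n+1)/2 pairs instead of n^2, which is where the saving comes from.

```python
from itertools import combinations_with_replacement

columns = ["a", "b", "c"]

# Each unordered pair, including a column paired with itself, appears once,
# so only the upper triangle (plus diagonal) of the symmetric matrix is computed.
pairs = list(combinations_with_replacement(columns, 2))
print(pairs)
# [('a', 'a'), ('a', 'b'), ('a', 'c'), ('b', 'b'), ('b', 'c'), ('c', 'c')]

# n columns give n * (n + 1) / 2 inner products instead of n * n.
n = len(columns)
assert len(pairs) == n * (n + 1) // 2  # 6 instead of 9
```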

Let's expand your data a bit:

from itertools import combinations_with_replacement
import polars as pl

df = pl.DataFrame(
    {"a": [1, 2, 3, 4, 5], "b": [3, 4, 5, 6, 7], "c": [5, 6, 7, 8, 9]}
).lazy()
df.collect()
shape: (5, 3)
┌─────┬─────┬─────┐
│ a   ┆ b   ┆ c   │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪═════╡
│ 1   ┆ 3   ┆ 5   │
├╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌┤
│ 2   ┆ 4   ┆ 6   │
├╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌┤
│ 3   ┆ 5   ┆ 7   │
├╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌┤
│ 4   ┆ 6   ┆ 8   │
├╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌┤
│ 5   ┆ 7   ┆ 9   │
└─────┴─────┴─────┘

The algorithm would be as follows:

expr_list = [
    pl.col(col1).dot(pl.col(col2)).alias(col1 + "|" + col2)
    for col1, col2 in combinations_with_replacement(df.columns, 2)
]

dot_prods = (
    df
    .select(expr_list)
    .melt()
    .with_columns(
        pl.col('variable').str.split_exact('|', 1)
    )
    .unnest('variable')
    .cache()
)

result = (
    pl.concat([
        dot_prods,
        dot_prods
        .filter(pl.col('field_0') != pl.col('field_1'))
        .select(['field_1', 'field_0', 'value'])
        .rename({'field_0':'field_1', 'field_1': 'field_0'})
        ],
    )
    .sort(['field_0', 'field_1'])
)
result.collect()
shape: (9, 3)
┌─────────┬─────────┬───────┐
│ field_0 ┆ field_1 ┆ value │
│ ---     ┆ ---     ┆ ---   │
│ str     ┆ str     ┆ i64   │
╞═════════╪═════════╪═══════╡
│ a       ┆ a       ┆ 55    │
├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ a       ┆ b       ┆ 85    │
├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ a       ┆ c       ┆ 115   │
├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ b       ┆ a       ┆ 85    │
├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ b       ┆ b       ┆ 135   │
├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ b       ┆ c       ┆ 185   │
├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ c       ┆ a       ┆ 115   │
├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ c       ┆ b       ┆ 185   │
├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ c       ┆ c       ┆ 255   │
└─────────┴─────────┴───────┘
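
As a sanity check of the values above, here is a plain-Python sketch of the same three steps: compute the upper triangle once, mirror the off-diagonal entries, then concatenate and sort. It does not use Polars, so it only illustrates the logic of the query; the dot helper is a hypothetical name.

```python
from itertools import combinations_with_replacement

data = {
    "a": [1, 2, 3, 4, 5],
    "b": [3, 4, 5, 6, 7],
    "c": [5, 6, 7, 8, 9],
}


def dot(xs, ys):
    """Inner product of two equal-length lists."""
    return sum(x * y for x, y in zip(xs, ys))


# Step 1: each inner product once (upper triangle, diagonal included).
upper = [
    (c1, c2, dot(data[c1], data[c2]))
    for c1, c2 in combinations_with_replacement(data, 2)
]

# Step 2: mirror the off-diagonal entries by swapping row and column,
# analogous to the filter / select / rename step in the Polars query.
lower = [(c2, c1, v) for c1, c2, v in upper if c1 != c2]

# Step 3: concatenate and sort by (row, column).
result = sorted(upper + lower)
for row in result:
    print(row)
```

Because each (row, column) pair occurs exactly once, sorting the whole tuples gives the same order as sorting on the first two fields, which is what .sort(['field_0', 'field_1']) does.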

A couple of notes:

  • I'm assuming that a pipe (|) is a safe delimiter, i.e. that none of your column names contain one.
  • The Python code and the itertools iterator will not noticeably hurt performance: they only generate the list of expressions, while the calculations themselves run inside Polars.
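
If you are not sure a pipe is safe for your column names, a small guard can run before the expressions are built. The sketch below uses a hypothetical helper, pick_delimiter (not part of Polars), that returns the first candidate delimiter found in no column name; the candidate list is an arbitrary assumption. Whatever it returns must then be used both when building the aliases and in the str.split_exact step.

```python
def pick_delimiter(columns, candidates=("|", "::", "\x1f")):
    """Return the first candidate that appears in no column name.

    Hypothetical helper; the default candidates are an arbitrary choice.
    """
    for delim in candidates:
        if not any(delim in name for name in columns):
            return delim
    raise ValueError("no safe delimiter among the candidates")


print(pick_delimiter(["a", "b", "c"]))       # |
print(pick_delimiter(["price|usd", "qty"]))  # ::

# Round trip: joining two names and splitting once recovers both,
# analogous to alias(col1 + delim + col2) followed by split_exact(delim, 1).
delim = pick_delimiter(["price|usd", "qty"])
name = "price|usd" + delim + "qty"
assert name.split(delim, 1) == ["price|usd", "qty"]
```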