Parallel querying indices for a list of filter expressions in polars dataframe

Time:11-12

I want to get the row indices that match each filter in a list of polars filter expressions and build a sparse matrix from them. How can I parallelize the process? This is what I have right now: a naive, brute-force approach that does what I need but has serious performance issues.

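import numpy as np
import polars as pl
from scipy.sparse import csc_matrix

def get_sparse_matrix(df: pl.DataFrame, exprs: list[pl.Expr]) -> csc_matrix:
    # Attach a row number so each filter can report which rows matched.
    df = df.with_row_count('_index')
    rows: list[int] = []
    cols: list[int] = []
    for col, expr in enumerate(exprs):
        # One separate filter pass over the frame per expression.
        r = df.filter(expr)['_index']
        rows.extend(r)
        cols.extend([col] * len(r))

    X = csc_matrix((np.ones(len(rows)), (rows, cols)),
                   shape=(len(df), len(exprs)))

    return X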

Example Input:

# df is a polars dataframe with 8 rows and 4 columns (column_0 to column_3)
df = pl.DataFrame(
[[1,2,3,4,5,6,7,8], 
[3,4,5,6,7,8,9,10], 
[5,6,7,8,9,10,11,12],
[5,6,41,8,21,10,51,12],
])

# three polars expressions
exprs = [pl.col('column_0') > 3, pl.col('column_1') < 6, pl.col('column_3') > 11]

Example output: X is a sparse matrix of size 8 (number of records) x 3 (number of expressions), where the element at (i, j) equals 1 if the i-th record matches the j-th expression.
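
To make the expected output concrete, here is a small pure-Python sketch (no polars or scipy needed) that works out the dense form of X for the example input. It assumes the third expression is meant to test the last column, column_3, since the frame only has columns column_0 through column_3. The names `columns`, `predicates`, `coords` and `dense` are illustrative only.

```python
# Columns of the example frame, keyed by the default polars column names.
columns = {
    "column_0": [1, 2, 3, 4, 5, 6, 7, 8],
    "column_1": [3, 4, 5, 6, 7, 8, 9, 10],
    "column_2": [5, 6, 7, 8, 9, 10, 11, 12],
    "column_3": [5, 6, 41, 8, 21, 10, 51, 12],
}

# Plain-Python stand-ins for the three filter expressions.
# Assumption: the third filter targets column_3.
predicates = [
    ("column_0", lambda v: v > 3),
    ("column_1", lambda v: v < 6),
    ("column_3", lambda v: v > 11),
]

n_rows = len(columns["column_0"])

# (row, col) coordinates of the matching cells: the entries a sparse matrix stores.
coords = [
    (i, j)
    for j, (name, pred) in enumerate(predicates)
    for i in range(n_rows)
    if pred(columns[name][i])
]

# Dense 8 x 3 view of the same information.
dense = [[0] * len(predicates) for _ in range(n_rows)]
for i, j in coords:
    dense[i][j] = 1

for row in dense:
    print(row)
```

Each printed row corresponds to one record, and each position in the row to one expression.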

CodePudding user response:

I am not completely sure what exactly you want, but I hope this satisfies your needs:

import polars as pl
from scipy.sparse import csc_matrix
import numpy as np

df = pl.DataFrame(
    [[1,2,3,4,5,6,7,8], 
    [3,4,5,6,7,8,9,10], 
    [5,6,7,8,9,10,11,12],
    [5,6,41,8,21,10,51,12],
])


exprs = [(pl.col('column_0') > 3).cast(pl.Int8), 
         (pl.col('column_1') < 6).cast(pl.Int8), 
         (pl.col('column_3') > 11).cast(pl.Int8)]

X = df.select(exprs)
csc_matrix(X.to_numpy())

CodePudding user response:

A GroupBy object is a mapping from a key to a list of indices and has a very fast implementation in polars. You could do something like this:

(df
 .with_column((pl.col('column_0') > 3).alias('e1'))
 .groupby('e1')
 ._groups()
 .filter(pl.col("e1"))
)[0,1]

See my recent blog post on this for more detail: https://braaannigan.github.io/software/2022/10/11/polars-index.html
