In pandas, how to replace values matching a condition with a value from the names of their columns?

Time:01-21

I'm writing a function that takes as input a DataFrame and a "mask". The DataFrame is assumed to have MultiIndex columns such as ("some string", 0.4), i.e. pairs whose second element is numeric. The mask is intended to be a boolean frame such as df < 2, df >= 4, etc.

The output should be a new table in which every value that doesn't match the mask is left alone, and every value that does is replaced by the numeric part of its column name (the second element of the column tuple).

NaNs in the input should be left alone (unless of course the mask is something like df.isna()).
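The NaN rule largely takes care of itself, because element-wise comparisons against NaN evaluate to False. A minimal sketch on a single column from the example data; note the one exception, != (ne), which is True for NaN, so a mask like df != 2 would select the missing cells as well:

```python
import numpy as np
import pandas as pd

# Column ("A", 0.2) from the example frame
s = pd.Series([4.0, 1.0, np.nan])

gt_mask = s > 2       # NaN > 2 is False
lt_mask = s < 2       # NaN < 2 is False
ne_mask = s != 2      # NaN != 2 is True
na_mask = s.isna()    # selects only the missing cell

print(gt_mask.tolist())  # [True, False, False]
print(lt_mask.tolist())  # [False, True, False]
print(ne_mask.tolist())  # [True, True, True]
print(na_mask.tolist())  # [False, False, True]
```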

This is what I've come up with (assume this is in a file called mytable.py):

import pandas as pd
import numpy as np


data = {
    ("A", 0.2): [4.0, 1.0, np.nan],
    ("B", 0.6): [0.0, np.nan, 4.0],
    ("C", 0.7): [0.0, 5.0, 1.0],
}
df = pd.DataFrame(data)


def replaced_with_colname(table, mask):
    series1 = (table[col][mask[col]] for col in table.columns)
    series2 = (s.apply(lambda x: s.name[1]) for s in series1)
    t2 = table.copy()
    for s in series2:
        t2.update(s)
    return t2
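The loop works because DataFrame.update aligns the incoming Series on its name (the column tuple) and its index, and only writes non-NA values, so rows the mask did not select are never touched. A minimal sketch of that behaviour on one column of the example data:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({("A", 0.2): [4.0, 1.0, np.nan]})

# A partial Series named after the column and covering only row 0,
# shaped like one element of series2 when the mask is df > 2
patch = pd.Series([0.2], index=[0], name=("A", 0.2))

t2 = df.copy()
t2.update(patch)  # modifies t2 in place; rows absent from patch keep their values
print(t2)
```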

An example:

$ python3 -i mytable.py
>>> df
     A    B    C
   0.2  0.6  0.7
0  4.0  0.0  0.0
1  1.0  NaN  5.0
2  NaN  4.0  1.0
>>> replaced_with_colname(df, df>2)
     A    B    C
   0.2  0.6  0.7
0  0.2  0.0  0.0
1  1.0  NaN  0.7
2  NaN  0.6  1.0

It does the job, but it feels convoluted and is probably slow, though I haven't benchmarked it. My question is: is there a more vectorized, idiomatic way of doing this, using more pandas methods and fewer for-loops?

Similar questions that helped me, and why they're not exactly what I'm trying to do:

CodePudding user response:

This is a perfect use case for np.where: where the mask is True, take the second-level value of the column name; otherwise keep the original value.

def replaced_with_colname(table, mask):
    # One number per column, in column order (levels[1] is sorted and unique)
    data = np.where(mask, table.columns.get_level_values(1), table)
    return pd.DataFrame(data, index=table.index, columns=table.columns)
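Two details make this work. First, np.where broadcasts a 1-D array holding one number per column across every row of the 2-D mask. Second, that array really has to be one number per column: MultiIndex.levels[1] holds the level's sorted unique values, which lines up with the columns in this example only because 0.2, 0.6 and 0.7 are already sorted and distinct, whereas get_level_values(1) returns one value per column in column order. A small sketch of both points, the second using hypothetical columns:

```python
import numpy as np
import pandas as pd

# 1) Broadcasting: the df > 2 mask from the example, written out by hand
values = np.array([[4.0, 0.0, 0.0],
                   [1.0, np.nan, 5.0],
                   [np.nan, 4.0, 1.0]])
mask = np.array([[True, False, False],
                 [False, False, True],
                 [False, True, False]])
col_numbers = np.array([0.2, 0.6, 0.7])  # shape (3,) broadcasts to (3, 3)
out = np.where(mask, col_numbers, values)
print(out)

# 2) levels vs get_level_values, with hypothetical unsorted, repeated numbers
cols = pd.MultiIndex.from_tuples([("A", 0.7), ("B", 0.2), ("C", 0.7)])
unique_sorted = list(cols.levels[1])         # sorted unique values of level 1
per_column = list(cols.get_level_values(1))  # one value per column, in order
print(unique_sorted)  # [0.2, 0.7]
print(per_column)     # [0.7, 0.2, 0.7]
```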

Usage:

>>> replaced_with_colname(df, df>2)
     A    B    C
   0.2  0.6  0.7
0  0.2  0.0  0.0
1  1.0  NaN  0.7
2  NaN  0.6  1.0

>>> replaced_with_colname(df, df.isna())
     A    B    C
   0.2  0.6  0.7
0  4.0  0.0  0.0
1  1.0  0.6  5.0
2  0.2  4.0  1.0

>>> replaced_with_colname(df, (0<=df) & (df<=1) | df.isna())
     A    B    C
   0.2  0.6  0.7
0  4.0  0.6  0.7
1  0.2  0.6  5.0
2  0.2  4.0  0.7
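The last call relies on Python's operator precedence: & binds tighter than |, and both bind tighter than comparisons such as <=, which is why each comparison needs its own parentheses. A short sketch confirming that the mask above equals the fully parenthesized version:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({
    ("A", 0.2): [4.0, 1.0, np.nan],
    ("B", 0.6): [0.0, np.nan, 4.0],
    ("C", 0.7): [0.0, 5.0, 1.0],
})

# & binds tighter than |, so these two masks are identical
implicit = (0 <= df) & (df <= 1) | df.isna()
explicit = ((0 <= df) & (df <= 1)) | df.isna()
print(implicit.equals(explicit))  # True

# Without the inner parentheses, 0 <= df & df <= 1 would be parsed as the
# chained comparison 0 <= (df & df) <= 1, which is not what is intended
```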

CodePudding user response:

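You can approach this with pandas.Index.get_level_values. Keeping the condition in its own variable lets .where(cond) discard the unmatched cells directly, without relying on 0 as a "no match" marker:

cond = df.gt(2)
out = (
        cond
          .mul(df.columns.get_level_values(1))  # True -> column number, False -> 0.0
          .where(cond)                          # NaN wherever cond is False
          .combine_first(df)                    # fill those NaNs from df
      )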

The comparison methods (eq, ne, le, lt, ge, gt) are equivalent to the operators (==, !=, <=, <, >=, >).

Output:

print(out)

     A    B    C
   0.2  0.6  0.7
0  0.2  0.0  0.0
1  1.0  NaN  0.7
2  NaN  0.6  1.0
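To see why the chain works, here is a stage-by-stage sketch of the same idea on the question's data, using .where(cond) to blank the unmatched cells (so a column whose number is 0 would still be handled correctly). It also checks that df.gt(2) and df > 2 produce the same frame:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({
    ("A", 0.2): [4.0, 1.0, np.nan],
    ("B", 0.6): [0.0, np.nan, 4.0],
    ("C", 0.7): [0.0, 5.0, 1.0],
})

# Method form and operator form give the same boolean frame
cond = df.gt(2)
print(cond.equals(df > 2))  # True

# Stage 1: True * number -> number, False * number -> 0.0
stage1 = cond.mul(df.columns.get_level_values(1))
# Stage 2: blank out every cell the condition did not select
stage2 = stage1.where(cond)
# Stage 3: fill those blanks (and only those) from the original frame
out = stage2.combine_first(df)
print(out)
```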

If you need a reusable function:

def replace_with_colname(table, cond):
    out = (
            cond
              .mul(table.columns.get_level_values(1))
              .where(cond)
              .combine_first(table)
          )
    return out
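A further variant, not taken from either answer, expresses the whole operation with DataFrame.mask: build a frame of the same shape holding each column's number, and let mask swap it in wherever the condition is True. It is sketched here under a hypothetical name and checked on a hypothetical non-square frame whose first column number is 0.0, a case where any approach that uses 0 as a "no match" marker would go wrong:

```python
import numpy as np
import pandas as pd


def replaced_with_colname_via_mask(table, cond):
    """Hypothetical helper: cells where `cond` is True get the numeric second
    level of their column name; every other cell (NaN included) is kept."""
    numbers = table.columns.get_level_values(1).to_numpy(dtype=float)
    # Same shape as `table`, each column filled with its own number
    replacement = pd.DataFrame(
        np.tile(numbers, (len(table), 1)),
        index=table.index,
        columns=table.columns,
    )
    # DataFrame.mask takes values from `other` where `cond` is True
    return table.mask(cond, replacement)


# Hypothetical 2 x 3 frame; column ("A", 0.0) has the number 0.0
df = pd.DataFrame({
    ("A", 0.0): [4.0, 1.0],
    ("B", 0.5): [np.nan, 3.0],
    ("C", 0.9): [5.0, 0.0],
})

print(replaced_with_colname_via_mask(df, df > 2))
print(replaced_with_colname_via_mask(df, df.isna()))
```

Building the replacement frame explicitly costs one extra frame of memory, but it keeps the logic to a single vectorized call and does not depend on the frame being square.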