Classify DataFrame rows based on first matching condition

Time:05-14

I have a pandas DataFrame in which each column represents a quarter, with the most recent quarters to the right. Not all the information arrives at the same time, so some columns may be missing values (NaN).

I would like to create a new column holding the number of the first criterion that the row matches, or zero if it doesn't match any criterion.

The criteria are applied to the 3 most recent columns that have data (an integer, ignoring NaNs). A row matches a criterion if each of those DataFrame values is greater than or equal to its corresponding value in the criterion's list.
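As a concrete reading of this rule, here is a minimal sketch that checks a single row by hand. The names `row`, `recent` and `matches` are only illustrative, and the row is the second one from the sample data.

```python
import numpy as np

criteria_dict = {1: [10, 0, 10], 2: [0, 10, 10]}

# One row, oldest quarter first, with a missing value in the third quarter.
row = np.array([98, -5, np.nan, 18])

# Drop the NaN, then keep the three most recent values that remain.
recent = row[~np.isnan(row)][-3:]

# Compare position by position against each criterion's list.
matches = {k: bool(np.all(recent >= np.array(v))) for k, v in criteria_dict.items()}
print(matches)  # {1: False, 2: False}, so this row is classified as 0
```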

I tried using apply, but I couldn't make it work, and the failed attempts were slow.

import pandas as pd
import numpy as np

criteria_dict = {
    1: [10, 0, 10],
    2: [0, 10, 10],
}

list_of_tuples = [
    (78, 7, 11, 15),       # classify as 2: 7 >= 0, 11 >= 10, 15 >= 10
    (98, -5, np.nan, 18),  # classify as 0: ignoring the NaN, the -5 fails every criterion
    (-78, 20, 64, 28),     # classify as 1: 20 >= 10, 64 >= 0, 28 >= 10
    (35, 63, 27, np.nan),  # classify as 1: the NaN is ignored, 35 >= 10, 63 >= 0, 27 >= 10
    (-11, 0, 56, 10),      # classify as 2: 0 >= 0, 56 >= 10, 10 >= 10
]

df = pd.DataFrame(
    list_of_tuples,
    index=['A', 'B', 'C', 'D', 'E'],
    columns=['2021Q2', '2021Q3', '2021Q4', '2022Q1']
)

print(df)

CodePudding user response:

Applying a custom function to each row should work.

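def func(x):
    # Keep the three most recent values that are not NaN.
    x = x.dropna().to_numpy()[-3:]
    # Rows with fewer than three valid values cannot match anything.
    if len(x) < 3:
        return 0
    # Return the key of the first criterion that every value satisfies.
    for k, v in criteria_dict.items():
        if np.all(x >= v):
            return k
    return 0

df['criteria'] = df.apply(func, axis=1)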
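As a quick check, the function can be run on the sample data from the question. This is a sketch rather than part of the original answer: row `F` is a hypothetical extra row with only two valid values, added to exercise the `len(x) < 3` branch, and the column name `criteria` is an assumption.

```python
import numpy as np
import pandas as pd

criteria_dict = {1: [10, 0, 10], 2: [0, 10, 10]}

df = pd.DataFrame(
    [
        (78, 7, 11, 15),
        (98, -5, np.nan, 18),
        (-78, 20, 64, 28),
        (35, 63, 27, np.nan),
        (-11, 0, 56, 10),
        (5, np.nan, np.nan, 12),  # hypothetical row: only two valid values
    ],
    index=["A", "B", "C", "D", "E", "F"],
    columns=["2021Q2", "2021Q3", "2021Q4", "2022Q1"],
)

def func(x):
    # Three most recent non-NaN values of the row.
    x = x.dropna().to_numpy()[-3:]
    if len(x) < 3:
        return 0
    for k, v in criteria_dict.items():
        if np.all(x >= v):
            return k
    return 0

df["criteria"] = df.apply(func, axis=1)
print(df["criteria"].tolist())  # [2, 0, 1, 1, 2, 0]
```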

CodePudding user response:

Using apply is probably the best option, but I wanted to try a solution with numpy, which should be faster on DataFrames with many rows.

import numpy as np

# Split off the rows with too many NaNs (fewer than three valid values).
too_many_nans = df.isna().sum(axis=1) > len(df.columns) - 3
df_nans = df[too_many_nans]
df_valid = df[~too_many_nans]

df_arr = df_valid.to_numpy()

# Find NaNs.
nans = np.nonzero(np.isnan(df_arr))

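# Move each NaN to the front of its row while keeping the valid values in
# their original order, so the latest three valid values end up on the right.
# (np.roll would wrap the values after a mid-row NaN around in front of the earlier ones.)
for row, col in zip(*nans):
    df_arr[row, :] = np.concatenate(([np.nan], np.delete(df_arr[row, :], col)))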

criterias = np.zeros(len(df_arr))

# Check for matching criteria.
for crit in criteria_dict:
    # Compare the last three columns, whatever the total number of columns is.
    matching_crit = np.all((df_arr[:, -3:] - criteria_dict[crit]) >= 0, axis=1)
    criterias[matching_crit & (criterias == 0)] = crit

# Add the invalid rows back.
df = pd.concat([df_valid, df_nans])
df['criteria'] = np.concatenate((criterias, np.zeros(len(df_nans))))

print(df)
   2021Q2  2021Q3  2021Q4  2022Q1  criteria
A      78       7    11.0    15.0       2.0
B      98      -5     NaN    18.0       0.0
C     -78      20    64.0    28.0       1.0
D      35      63    27.0     NaN       1.0
E     -11       0    56.0    10.0       2.0