Does anybody know how to select multiple columns from a pandas DataFrame whose contents meet multiple complex conditions? For example, each column I'm trying to select MUST contain at least one instance of each of the integers 1 through 4, MAY contain np.nan, and CANNOT contain any other integers (below 1 or above 4) or data types.
I know I can do this by executing many lines of code in a row, but I'm hoping to find a better solution (especially one that will catch edge cases).
Example:
# import packages
import pandas as pd
import numpy as np
# Initialize list of lists
data = [[1, 4, 3, 5], [2, np.nan, 1, 3], [4, 3, 2, 1], [3, 3, 1, 4]]
# Create the pandas DataFrame
df = pd.DataFrame(data, columns=['A', 'B', 'C', 'D'])
print(df)
Output:
   A    B  C  D
0  1  4.0  3  5
1  2  NaN  1  3
2  4  3.0  2  1
3  3  3.0  1  4
Right now I run many statements like these; I'm looking for more optimized code that fulfills the conditions above:
df = df.select_dtypes(include='number')
df = df.loc[:,[(df[col] == 1).any() for col in df.columns]]
df = df.loc[:,[(df[col] == 2).any() for col in df.columns]]
df = df.loc[:,[(df[col] == 3).any() for col in df.columns]]
df = df.loc[:,[(df[col] == 4).any() for col in df.columns]]
df = df.loc[:,[~(df[col] > 4).any() for col in df.columns]]
df = df.loc[:,[~(df[col] < 1).any() for col in df.columns]]
print(df)
Output:
   A
0  1
1  2
2  4
3  3
CodePudding user response:
import numpy as np
import pandas as pd
data = [[1, 4, 3, 5], [2, np.nan, 1, 3], [4, 3, 2, 1], [3, 3, 1, 4]]
df = pd.DataFrame(data, columns=['A', 'B', 'C', 'D'])
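mat = df.values
# NaN compares False against both bounds, so NaN cells pass the range check
in_range = ~((mat < 1) | (mat > 4))
# np.isin keeps mat's shape, so no reshape is needed
cols = in_range.all(axis=0) & np.isin(mat, np.arange(1, 5)).any(axis=0)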
print(df.loc[:,cols])
Output:
   A    B  C
0  1  4.0  3
1  2  NaN  1
2  4  3.0  2
3  3  3.0  1
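Note that this mask keeps every column that has at least one value anywhere in 1..4, which is why B and C survive. If each of 1, 2, 3 and 4 has to appear (as in the question's expected output), a stricter NumPy variant might look like the sketch below. It assumes every column is numeric, and the names `required`, `allowed` and `present` are only illustrative.

```python
import numpy as np
import pandas as pd

data = [[1, 4, 3, 5], [2, np.nan, 1, 3], [4, 3, 2, 1], [3, 3, 1, 4]]
df = pd.DataFrame(data, columns=['A', 'B', 'C', 'D'])

# Assumption: all columns are numeric, so the frame converts cleanly to float
mat = df.to_numpy(dtype=float)
required = np.arange(1, 5)

# True where a cell is NaN or one of the required values
allowed = np.isnan(mat) | np.isin(mat, required)

# For each required value, check whether it occurs anywhere in each column.
# Resulting shape: (len(required), number of columns)
present = (mat[None, :, :] == required[:, None, None]).any(axis=1)

cols = allowed.all(axis=0) & present.all(axis=0)
strict = df.loc[:, cols]
print(strict)
```

On the sample data this leaves only column A.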
CodePudding user response:
Combine them in one call and pass to loc:
(df
 .select_dtypes('number')
 .loc(axis=1)[lambda df: df.eq(1).any() &
                         df.eq(2).any() &
                         df.eq(3).any() &
                         df.eq(4).any() &
                         ~df.gt(4).any() &  # no value above 4 (NaN passes)
                         ~df.lt(1).any()    # no value below 1 (NaN passes)
              ]
)
   A
0  1
1  2
2  4
3  3
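One subtlety with the range checks: any comparison against NaN evaluates to False. So "every value is between 1 and 4" rejects a column that contains NaN, while "no value is outside 1 to 4" accepts it, and the second form is the one that matches the question's MAY-contain-NaN rule. A small sketch on a column shaped like B (the variable names are just for illustration):

```python
import numpy as np
import pandas as pd

s = pd.Series([4, np.nan, 3, 3])

# "All values in range": NaN fails both comparisons, so this is False
in_range_all = bool((s.le(4) & s.ge(1)).all())

# "No value out of range": NaN violates neither bound, so this is True
no_violation = not (s.gt(4) | s.lt(1)).any()

print(in_range_all, no_violation)
```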
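Taking it a bit further, let's encapsulate your logic with fewer steps (this only checks that 1 through 4 are all present, which is enough for the sample data):
# apply rather than transform, so it also works when the row count is not 4
trim = df.apply(lambda col: pd.Series([1, 2, 3, 4]).isin(col)).all()
df.loc(axis=1)[trim]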
   A
0  1
1  2
2  4
3  3
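If you want all the rules from the question in one place, including the edge cases, something like the sketch below may help. `select_columns` is a hypothetical helper name, not a pandas function, and the test frame is made up for illustration. It treats a non-integer such as 2.5 as a disallowed value, which is my reading of "cannot contain any other integers or data types".

```python
import numpy as np
import pandas as pd

def select_columns(df, required=(1, 2, 3, 4)):
    """Hypothetical helper: keep numeric columns that contain every value in
    `required`, may contain NaN, and contain nothing else."""
    required = list(required)
    num = df.select_dtypes(include='number')

    # Every cell must be NaN or one of the required values
    only_allowed = (num.isin(required) | num.isna()).all()

    # Every required value must occur at least once in the column
    has_all = pd.Series(
        {col: set(required) <= set(num[col].dropna()) for col in num.columns},
        dtype=bool,
    )
    return num.loc[:, only_allowed & has_all]

# Made-up frame covering the edge cases
df = pd.DataFrame({
    'A': [1, 2, 4, 3, 1],            # valid
    'B': [1, 2, 3, 4, np.nan],       # valid: NaN is allowed
    'C': [1, 2, 3, 4, 2.5],          # contains a non-integer value
    'D': [1, 2, 3, 4, 5],            # contains a value above 4
    'E': ['1', '2', '3', '4', '1'],  # strings: wrong data type
    'F': [1, 2, 3, 3, 1],            # missing 4
})

result = select_columns(df)
print(list(result.columns))  # ['A', 'B']
```

Using `isin` for the "nothing else" rule covers both bounds and non-integers in a single check, so the separate greater-than and less-than comparisons are not needed.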