Create a boolean column for condition being met in the following n rows in Pandas

Time:08-24

I have a Pandas DataFrame like the following:

df = pd.DataFrame({
    'other_stuff': ['lorem', 'ipsum', 'dolor', 'sit', 'amet', 'consectetur', 'adipiscing', 'elit', 'sed', 'do'],
    'value': [12.0, 12.1, 11.9, 12.1, 12.4, 12.1, 12.2, 12.1, 11.8, 12.5]
})

  other_stuff  value
0       lorem   12.0
1       ipsum   12.1
2       dolor   11.9
3         sit   12.1
4        amet   12.4
5  consectetur  12.1
6   adipiscing  12.2
7         elit  12.1
8          sed  11.8
9           do  12.5

and I want to create a boolean column that is True if the threshold of 12.3 is exceeded anywhere in the following 3 rows. So for index 0, you look at df['value'] at indices 1, 2, and 3, and return False since the threshold was not exceeded in those 3 rows. For the new column at index 1, you check df['value'] at indices 2, 3, and 4, and return True since the threshold was exceeded somewhere in that range (12.4 at index 4). And so on for the rest of the DataFrame; rows near the end only check the rows that exist.

The final DataFrame would look like:

  other_stuff  value  value > 12.3 in next 3 rows
0       lorem   12.0                        False
1       ipsum   12.1                         True
2       dolor   11.9                         True
3         sit   12.1                         True
4         amet  12.4                        False
5  consectetur  12.1                        False
6   adipiscing  12.2                         True
7         elit  12.1                         True
8          sed  11.8                         True
9           do  12.5                        False 

CodePudding user response:

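You can use shift for this: compare the column shifted up by 1, 2, and 3 rows against the threshold and OR the results together.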

import pandas as pd

df = pd.DataFrame({
    'other_stuff': ['lorem', 'ipsum', 'dolor', 'sit', 'amet', 'consectetur', 'adipiscing', 'elit', 'sed', 'do'],
    'value': [12.0, 12.1, 11.9, 12.1, 12.4, 12.1, 12.2, 12.1, 11.8, 12.5]
})

threshold = 12.3
# shift(-k) lines up the value k rows below with the current row;
# the NaNs shifted in at the end of the column compare as False
df['flag'] = (df['value'].shift(-1) > threshold) | (df['value'].shift(-2) > threshold) | (df['value'].shift(-3) > threshold)

Result:

    other_stuff value   flag
0   lorem       12.0    False
1   ipsum       12.1    True
2   dolor       11.9    True
3   sit         12.1    True
4   amet        12.4    False
5   consectetur 12.1    False
6   adipiscing  12.2    True
7   elit        12.1    True
8   sed         11.8    True
9   do          12.5    False
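
If the number of look-ahead rows can change, the same shift idea can be written as a loop instead of spelling out each shift by hand. This is only a sketch of that generalization: the function name exceeds_in_next_n and its parameters are made up for illustration and are not part of pandas.

```python
import pandas as pd


def exceeds_in_next_n(values: pd.Series, threshold: float, n: int) -> pd.Series:
    """Return True where any of the next n rows is strictly greater than threshold."""
    # Start from all False and OR in one shifted comparison per look-ahead step.
    flag = pd.Series(False, index=values.index)
    for k in range(1, n + 1):
        # shift(-k) moves the value from row i+k up to row i; the NaNs that
        # appear at the end of the column compare as False.
        flag = flag | (values.shift(-k) > threshold)
    return flag


df = pd.DataFrame({
    'value': [12.0, 12.1, 11.9, 12.1, 12.4, 12.1, 12.2, 12.1, 11.8, 12.5]
})
df['flag'] = exceeds_in_next_n(df['value'], threshold=12.3, n=3)
print(df)
```

With n=3 and threshold=12.3 this should give the same flag column as the explicit three-shift version above.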

CodePudding user response:

If I understand you correctly, you are testing if any of the following 3 rows has a value greater than 12.3.

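df["matched"] = (
    # rolling looks at the previous rows, not the following ones,
    # so reverse the column first
    df["value"][::-1]
    # closed="left" excludes the current row from the 3-row window;
    # min_periods=1 allows the shorter windows near the end of the frame
    .rolling(3, min_periods=1, closed="left")
    .max()
    # the last row has no following rows, so its max is NaN and gt() gives False
    .gt(12.3)
)
# the assignment aligns on the index, so the reversed order does not matter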
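
A possible variation that avoids reversing the column is a forward-looking window. The sketch below assumes a pandas version that ships pandas.api.indexers.FixedForwardWindowIndexer (I believe 1.1 or later); the column name matched_fwd and the variables n and threshold are just illustrative. The forward window includes the current row, so the column is shifted up by one row first to make the window start at the next row.

```python
import pandas as pd
from pandas.api.indexers import FixedForwardWindowIndexer

df = pd.DataFrame({
    "value": [12.0, 12.1, 11.9, 12.1, 12.4, 12.1, 12.2, 12.1, 11.8, 12.5]
})

n, threshold = 3, 12.3

# A window of n rows that starts at the current row and looks forward.
indexer = FixedForwardWindowIndexer(window_size=n)

df["matched_fwd"] = (
    df["value"]
    .shift(-1)                        # make the window start at the next row
    .rolling(indexer, min_periods=1)  # allow shorter windows near the end
    .max()
    .gt(threshold)                    # NaN (no following rows) compares as False
)
print(df)
```

For the sample data this should agree with the reversed-rolling version, while keeping the rows in their original order throughout.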