I am trying to understand how to use groupby(...).all(). For example:
import pandas as pd
df = pd.DataFrame({
"user_id": [1,1,1,1,1,2,2,2,3,3,3,3],
"score": [1,2,3,4,5,3,4,5,5,6,7,8]
})
When I try:
df.groupby("user_id").all(lambda x: x["score"] > 2)
I get:
score
user_id
1 True
2 True
3 True
But I expect:
score
user_id
1 False # Since for the first user, not all scores are greater than 2
2 True
3 True
In fact, it doesn't matter what value I pass instead of 2: the resulting DataFrame always has True
in the score column.
Why do I get the result that I get? How can I get my expected result?
I looked at the documentation: https://pandas.pydata.org/docs/reference/api/pandas.core.groupby.DataFrameGroupBy.all.html, but it is very brief and did not help me.
Answer:
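The line
df.groupby("user_id").all(lambda x: x["score"] > 2)
is not asking "are all data points greater than 2?". The lambda is never called (it is silently taken as the skipna argument), so in reality the line asks "are all values in each group truthy (non-zero)?". No score is 0, so every group comes out True.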
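A minimal sketch of this behaviour (assuming a reasonably recent pandas release): calling .all() with no argument gives the same all-True result, and only a falsy value such as 0 can make a group False. The df_with_zero frame is a hypothetical variant created just for this illustration.

```python
import pandas as pd

df = pd.DataFrame({
    "user_id": [1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3],
    "score": [1, 2, 3, 4, 5, 3, 4, 5, 5, 6, 7, 8],
})

# GroupBy.all() only checks whether every value in each group is truthy
# (non-zero for numbers); no comparison such as "> 2" is ever evaluated.
plain = df.groupby("user_id").all()
print(plain)  # every group is True because no score is 0

# Hypothetical variant: set user 1's first score to 0, which is falsy.
df_with_zero = df.copy()
df_with_zero.loc[0, "score"] = 0
with_zero = df_with_zero.groupby("user_id").all()
print(with_zero)  # user 1 is now False, users 2 and 3 stay True
```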
To ask what you really want, do the comparison first and then group:
df['score'].gt(2).groupby(df['user_id']).all()
Out
user_id
1 False
2 True
3 True
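If you do want to pass your own function per group, GroupBy.agg and GroupBy.filter accept callables. The sketch below (assuming pandas 1.x or 2.x; the variable names are illustrative) checks that a per-group lambda gives the same answer as the vectorised line, and uses filter to keep only the rows of users whose scores are all above 2.

```python
import pandas as pd

df = pd.DataFrame({
    "user_id": [1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3],
    "score": [1, 2, 3, 4, 5, 3, 4, 5, 5, 6, 7, 8],
})

# Vectorised: compare first, then reduce each group with all().
vectorised = df["score"].gt(2).groupby(df["user_id"]).all()

# Per-group function: agg() calls the lambda once per group with that
# group's scores, so it accepts the kind of function all() does not.
per_group = df.groupby("user_id")["score"].agg(lambda s: s.gt(2).all())

# filter() keeps every row of the groups for which the function is True.
kept = df.groupby("user_id").filter(lambda g: g["score"].gt(2).all())

print(vectorised.tolist())                # [False, True, True]
print(kept["user_id"].unique().tolist())  # [2, 3]
```

The vectorised version is generally preferable on large frames because it avoids calling a Python function once per group.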
Answer:
groupby.all does not take a function as a parameter. Its only parameter, skipna, accepts a boolean and controls whether NaN values are ignored during truth testing.
You probably want:
df['score'].gt(2).groupby(df['user_id']).all()
Which can also be written as:
df.assign(flag=df['score'].gt(2)).groupby('user_id')['flag'].all()
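A quick sketch checking that the two spellings agree (assuming pandas 1.x or 2.x). It also uses transform("all"), which broadcasts each user's flag back to every row, in case the goal is to select the rows of qualifying users rather than get one value per user; row_mask and qualifying_rows are illustrative names.

```python
import pandas as pd

df = pd.DataFrame({
    "user_id": [1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3],
    "score": [1, 2, 3, 4, 5, 3, 4, 5, 5, 6, 7, 8],
})

# Form 1: group a boolean Series by the user_id column.
a = df["score"].gt(2).groupby(df["user_id"]).all()

# Form 2: store the flag as a temporary column, then group the frame.
b = df.assign(flag=df["score"].gt(2)).groupby("user_id")["flag"].all()

# Same values and index; only the Series name differs ("score" vs "flag").
assert a.tolist() == b.tolist()
assert a.index.equals(b.index)

# transform("all") returns one value per original row instead of per group,
# so it can be used directly as a boolean mask on df.
row_mask = df["score"].gt(2).groupby(df["user_id"]).transform("all")
qualifying_rows = df[row_mask]
```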