Use np.where to check if a column string exists in another list column


I have a large DataFrame with 100 million records, and I am trying to optimize the run time by using numpy.

Sample data:

import numpy as np
import pandas as pd

dat = pd.DataFrame({'ID': [1, 2, 3, 4, 5],
                    'item': ['beauty', 'beauty', 'shoe', 'shoe', 'handbag'],
                    'mylist': [['beauty', 'something'], ['shoe', 'something', 'else'],
                               ['shoe', 'else', 'some'], ['else'], ['some', 'thing', 'else']]})


dat

    ID  item    mylist
0   1   beauty  [beauty, something]
1   2   beauty  [shoe, something, else]
2   3   shoe    [shoe, else, some]
3   4   shoe    [else]
4   5   handbag [some, thing, else]

I am trying to filter the rows where the string in the item column exists in the mylist column, using:

dat[np.where(dat['item'].isin(dat['mylist']), True, False)]

But I am not getting the expected output: all of the values above come back as False. I could get the required results using:

dat[dat.apply(lambda row : row['item'] in row['mylist'], axis = 1)]

    ID  item    mylist
0   1   beauty  [beauty, something]
2   3   shoe    [shoe, else, some]

But as numpy operations are faster, I am trying to use np.where. Could someone please let me know how to fix the code?

CodePudding user response:

np.where is not the problem here. Series.isin tests each value of dat['item'] against the collection of values in dat['mylist'], i.e. the list objects themselves, not row-wise membership; a string never equals a list, so every row is False. You can't easily vectorize over a Series of lists, but you can use a list comprehension, which is a bit faster than apply:

out = dat.loc[[item in l for item, l in zip(dat['item'], dat['mylist'])]]
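If the lists are long, a variant worth trying is to convert each list to a set first, so each membership test is O(1). Treat this as an assumption to benchmark: for short lists like these, the conversion cost may dominate.

# hedged variant: pre-convert each list to a set for O(1) membership tests;
# only worthwhile if the lists are long, so time it on real data
out = dat.loc[[item in s for item, s in zip(dat['item'], dat['mylist'].map(set))]]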

A vectorized solution would be:

out = dat.loc[dat.explode('mylist').eval('item == mylist').groupby(level=0).any()]

# or
out = dat.explode('mylist').query('item == mylist').groupby(level=0).first()

# or, if you are sure there is at most one match per row
out = dat.explode('mylist').query('item == mylist')
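To see why this works, here is the intermediate result of the explode step on the sample data: explode repeats each row once per list element while keeping the original index, so the row-wise comparison can be aggregated back per original row.

exploded = dat.explode('mylist')
#    ID     item     mylist
# 0   1   beauty     beauty
# 0   1   beauty  something
# 1   2   beauty       shoe
# 1   2   beauty  something
# 1   2   beauty       else
# 2   3     shoe       shoe
# ...
mask = exploded.eval('item == mylist').groupby(level=0).any()
# mask: [True, False, True, False, False], indexed like dat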

But the explode step might be a bottleneck; you should test with your real data.

output:

   ID    item               mylist
0   1  beauty  [beauty, something]
2   3    shoe   [shoe, else, some]

timing

I ran a quick test on 100k rows (using df = pd.concat([dat]*20000, ignore_index=True)); a sketch to reproduce it follows the list:

  • the list comprehension is the fastest (~20 ms)
  • the explode approaches take between 60-90 ms (the explode itself accounts for ~40 ms)
  • apply is by far the slowest (almost 600 ms)
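A minimal sketch to reproduce the comparison with the timeit module, assuming dat from the question is defined (the function names are just labels for the benchmark; exact numbers will vary with hardware):

import timeit

df = pd.concat([dat] * 20000, ignore_index=True)  # ~100k rows

def listcomp():
    return df.loc[[item in l for item, l in zip(df['item'], df['mylist'])]]

def explode_any():
    return df.loc[df.explode('mylist').eval('item == mylist').groupby(level=0).any()]

def apply_row():
    return df[df.apply(lambda row: row['item'] in row['mylist'], axis=1)]

for f in (listcomp, explode_any, apply_row):
    print(f.__name__, timeit.timeit(f, number=5) / 5)  # average seconds per run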