Filter table on aggregate value and multi-index

Time:02-22

I have the following data set:

shipments.head(7)
    Origin Destination        Date  Quantity
0  Atlanta          LA  2021-09-09         1
1  Atlanta          LA  2021-09-11         4
2  Atlanta     Chicago  2021-09-16         1
3  Atlanta     Seattle  2021-09-27        12
4  Seattle          LA  2021-09-29         2
5  Seattle     Atlanta  2021-09-13         2
6  Seattle      Newark  2021-09-17         7

This table represents the number of items (Quantity) that were sent from a given origin to a given destination on a given date. The table contains 1 month of data. This table was read with:

shipments = pd.read_csv('shipments.csv', parse_dates=['Date'])
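To try the steps without the original shipments.csv, the seven sample rows above can be passed to the same read_csv call through an in-memory buffer. This is a minimal sketch that assumes the real file has a header row with exactly these four column names; parse_dates makes Date a datetime64 column rather than plain strings.

```python
import io

import pandas as pd

# The seven sample rows from the question, as CSV text
# (a stand-in for shipments.csv, assuming the same header row)
csv_text = """Origin,Destination,Date,Quantity
Atlanta,LA,2021-09-09,1
Atlanta,LA,2021-09-11,4
Atlanta,Chicago,2021-09-16,1
Atlanta,Seattle,2021-09-27,12
Seattle,LA,2021-09-29,2
Seattle,Atlanta,2021-09-13,2
Seattle,Newark,2021-09-17,7
"""

# parse_dates converts the Date column to datetime64 while reading
shipments = pd.read_csv(io.StringIO(csv_text), parse_dates=["Date"])
print(shipments.dtypes)
```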

Using the shipment data, I can create a new aggregated table that shows the total quantity shipped between every Origin and Destination pair during this month:

shipments_agg = shipments.groupby(['Origin', 'Destination'])[['Quantity']].sum()
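The aggregated table is indexed by an (Origin, Destination) MultiIndex, so any pair's total can be looked up with a tuple, and boolean filtering on Quantity keeps that index. The sketch below uses the seven sample rows. Selecting Quantity before summing is a choice made here so the datetime Date column is not aggregated: in pandas 2.x, groupby reductions no longer drop non-numeric columns by default, and summing a datetime64 column can raise a TypeError.

```python
import pandas as pd

# The seven sample rows from the question
shipments = pd.DataFrame({
    "Origin": ["Atlanta", "Atlanta", "Atlanta", "Atlanta", "Seattle", "Seattle", "Seattle"],
    "Destination": ["LA", "LA", "Chicago", "Seattle", "LA", "Atlanta", "Newark"],
    "Date": pd.to_datetime(["2021-09-09", "2021-09-11", "2021-09-16", "2021-09-27",
                            "2021-09-29", "2021-09-13", "2021-09-17"]),
    "Quantity": [1, 4, 1, 12, 2, 2, 7],
})

# Aggregate only Quantity; the result is indexed by (Origin, Destination)
shipments_agg = shipments.groupby(["Origin", "Destination"])[["Quantity"]].sum()

print(list(shipments_agg.index.names))  # ['Origin', 'Destination']
# A single pair's total is looked up with a tuple of index labels
print(shipments_agg.loc[("Atlanta", "LA"), "Quantity"])  # 5
# Filtering on the total keeps the MultiIndex entries of the qualifying pairs
print(shipments_agg[shipments_agg["Quantity"] > 50].index.tolist())  # [] for these rows
```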

As a last step, I'd like to create a new table from the shipments table that keeps a row (Origin, Destination, Date, Quantity) only if the aggregate Quantity for its (Origin, Destination) pair is larger than 50. In other words, a row should only be included if its (Origin, Destination) pair has a Quantity larger than 50 in shipments_agg. I'm not quite sure how to accomplish this.
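One way to express this condition directly is DataFrameGroupBy.filter, which keeps or drops whole groups based on a function evaluated on each group. This is a different technique from the index-based answer below. The sketch runs on the seven sample rows, so a cutoff of 50 yields an empty table and a cutoff of 4 is shown as well; the helper name keep_busy_pairs is hypothetical.

```python
import pandas as pd

# The seven sample rows from the question
shipments = pd.DataFrame({
    "Origin": ["Atlanta", "Atlanta", "Atlanta", "Atlanta", "Seattle", "Seattle", "Seattle"],
    "Destination": ["LA", "LA", "Chicago", "Seattle", "LA", "Atlanta", "Newark"],
    "Date": pd.to_datetime(["2021-09-09", "2021-09-11", "2021-09-16", "2021-09-27",
                            "2021-09-29", "2021-09-13", "2021-09-17"]),
    "Quantity": [1, 4, 1, 12, 2, 2, 7],
})


def keep_busy_pairs(df, threshold):
    """Hypothetical helper: keep every row whose (Origin, Destination) total exceeds threshold."""
    # filter() calls the lambda once per group and keeps all rows of the
    # groups for which it returns True, in their original order
    return df.groupby(["Origin", "Destination"]).filter(
        lambda group: group["Quantity"].sum() > threshold
    )


print(keep_busy_pairs(shipments, 50))  # empty for these seven rows
print(keep_busy_pairs(shipments, 4))   # rows 0, 1, 3 and 6
```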

Answer:

You can do this by using the (Origin, Destination) MultiIndex of the aggregated DataFrame to select the matching rows from the original DataFrame.
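The idea can be wrapped in a small function. This is a sketch under assumptions: the helper name filter_by_pair_total is hypothetical, and it calls set_index without inplace=True so the caller's DataFrame is left unchanged.

```python
import pandas as pd


def filter_by_pair_total(shipments, threshold):
    """Hypothetical helper: rows whose (Origin, Destination) total Quantity exceeds threshold."""
    keys = ["Origin", "Destination"]
    # Total quantity per pair, indexed by an (Origin, Destination) MultiIndex
    totals = shipments.groupby(keys)["Quantity"].sum()
    # MultiIndex entries of the pairs above the cutoff
    busy_pairs = totals[totals > threshold].index
    # Index the original rows by the same keys and look the pairs up with .loc;
    # set_index returns a new frame, so the input DataFrame is not modified
    return shipments.set_index(keys).loc[busy_pairs].reset_index()


# The seven sample rows from the question
shipments = pd.DataFrame({
    "Origin": ["Atlanta", "Atlanta", "Atlanta", "Atlanta", "Seattle", "Seattle", "Seattle"],
    "Destination": ["LA", "LA", "Chicago", "Seattle", "LA", "Atlanta", "Newark"],
    "Date": pd.to_datetime(["2021-09-09", "2021-09-11", "2021-09-16", "2021-09-27",
                            "2021-09-29", "2021-09-13", "2021-09-17"]),
    "Quantity": [1, 4, 1, 12, 2, 2, 7],
})

print(filter_by_pair_total(shipments, 4))
```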

There's probably a way to do all of this in a mega-one-liner, which I'm not a fan of because of readability / troubleshooting issues, but here is an approach broken down in steps:

In [67]: shipments = pd.read_clipboard()

In [68]: shipments
Out[68]: 
    Origin Destination        Date  Quantity
0  Atlanta          LA  2021-09-09         1
1  Atlanta          LA  2021-09-11         4
2  Atlanta     Chicago  2021-09-16         1
3  Atlanta     Seattle  2021-09-27        12
4  Seattle          LA  2021-09-29         2
5  Seattle     Atlanta  2021-09-13         2
6  Seattle      Newark  2021-09-17         7

In [69]: shipments_agg = shipments.groupby(["Origin", "Destination"])[["Quantity"]].sum()

In [70]: shipments_agg
Out[70]: 
                     Quantity
Origin  Destination          
Atlanta Chicago             1
        LA                  5
        Seattle            12
Seattle Atlanta             2
        LA                  2
        Newark              7

In [71]: # the sample is tiny, so let's use a cutoff of 4 instead of 50

In [72]: hi_qty_shipments = shipments_agg[shipments_agg["Quantity"] > 4]

In [73]: hi_qty_shipments
Out[73]: 
                     Quantity
Origin  Destination          
Atlanta LA                  5
        Seattle            12
Seattle Newark              7

In [74]: # now index the original DataFrame by (Origin, Destination) and select rows with the filtered MultiIndex

In [75]: shipments.set_index(["Origin", "Destination"], inplace=True)

In [76]: shipments.loc[hi_qty_shipments.index]
Out[76]: 
                           Date  Quantity
Origin  Destination                      
Atlanta LA           2021-09-09         1
        LA           2021-09-11         4
        Seattle      2021-09-27        12
Seattle Newark       2021-09-17         7
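For comparison, a compact alternative to the step-by-step lookup is groupby(...).transform("sum"), which broadcasts each pair's total back onto its original rows so it can be compared directly. This is a different technique from the index lookup above: it keeps the original row index and order and does not modify shipments. The sketch uses the same cutoff of 4 on the seven sample rows.

```python
import pandas as pd

# The seven sample rows from the question
shipments = pd.DataFrame({
    "Origin": ["Atlanta", "Atlanta", "Atlanta", "Atlanta", "Seattle", "Seattle", "Seattle"],
    "Destination": ["LA", "LA", "Chicago", "Seattle", "LA", "Atlanta", "Newark"],
    "Date": pd.to_datetime(["2021-09-09", "2021-09-11", "2021-09-16", "2021-09-27",
                            "2021-09-29", "2021-09-13", "2021-09-17"]),
    "Quantity": [1, 4, 1, 12, 2, 2, 7],
})

# transform("sum") returns each row's (Origin, Destination) total,
# aligned with the original index
pair_total = shipments.groupby(["Origin", "Destination"])["Quantity"].transform("sum")

# Boolean mask on the per-row totals
result = shipments[pair_total > 4]
print(result)
```

With the full month of data, replacing 4 by 50 gives the table asked for in the question.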