How do you split a pandas multiindex dataframe into train/test sets?

Time:07-11

I have a pandas DataFrame with a MultiIndex made of a date level and a level representing store locations. I want to split it into training and test sets based on the date: everything before a certain date should be my training set and everything after it my test set. Below is some code for a sample dataset.

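import pandas as pd
from scipy import stats

# One Poisson mean per location; rvs(size=(60, 5)) broadcasts mu across the 5 columns.
data = stats.poisson(mu=[5, 2, 1, 7, 2]).rvs(size=(60, 5)).ravel()
# 'M' = month end (newer pandas versions spell this 'ME').
dates = pd.date_range('2017-01-01', freq='M', periods=60)
locations = [f'location_{i}' for i in range(5)]
# from_product orders rows date-major (each date, then every location), matching ravel() above.
df_train = pd.DataFrame(data, index=pd.MultiIndex.from_product([dates, locations]), columns=['eaches'])
df_train.index.names = ['date', 'location']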

I would like df_train to represent everything before 2021-01 and df_test to represent everything after.

I've tried using df[df.loc['dates'] > '2020-12-31'] but that yielded errors.

Answer:

'date' is an index level, not a column, which is why your query doesn't work. To select by the index, you can use:

df_train.loc['2020-12-31':]

That selects every row whose date is on or after 2020-12-31, because label slices with .loc include both endpoints. If you want only rows with dates strictly after 2020-12-31, use df_train.loc['2021-01-01':] instead.
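Putting both halves together, here is a minimal sketch of the split with .loc slicing. It assumes a stand-in for the question's frame with month-start dates (freq="MS") and sequential values instead of Poisson draws; both are assumptions chosen only to keep the example deterministic.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for the question's frame: 60 monthly dates x 5 locations.
# Month-start dates ("MS") and sequential values are assumptions for this sketch.
dates = pd.date_range("2017-01-01", freq="MS", periods=60)
locations = [f"location_{i}" for i in range(5)]
index = pd.MultiIndex.from_product([dates, locations], names=["date", "location"])
df = pd.DataFrame({"eaches": np.arange(len(index))}, index=index)

# .loc label slices include both endpoints, so the training slice stops at
# 2020-12-31 and the test slice starts at 2021-01-01; no date lands in both.
df_train = df.loc[:"2020-12-31"]
df_test = df.loc["2021-01-01":]

print(len(df_train), len(df_test))  # 240 60
```

Label slicing on a MultiIndex needs a sorted index; MultiIndex.from_product already produces one, but other data may need df.sort_index() first.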

Answer:

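You can't do df.loc['dates'] > '2020-12-31' because .loc selects rows by label, so it returns your numerical data rather than the dates themselves. Here it raises a KeyError before any comparison happens: it searches the date level for a label called 'dates', which doesn't exist (the level is named 'date').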

You can use query, which can refer to index levels by name:

df.query('date>"2020-12-31"')
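Here is a sketch of the full split with query, alongside a boolean-mask version of the original attempt that reads the date level with index.get_level_values instead of df.loc. The stand-in frame (month-start dates, sequential values) is an assumption, not the question's exact data.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for the question's frame (month-start dates, sequential values).
dates = pd.date_range("2017-01-01", freq="MS", periods=60)
locations = [f"location_{i}" for i in range(5)]
index = pd.MultiIndex.from_product([dates, locations], names=["date", "location"])
df = pd.DataFrame({"eaches": np.arange(len(index))}, index=index)

# query() resolves the name "date" to the index level and compares it as a date.
df_train = df.query('date <= "2020-12-31"')
df_test = df.query('date > "2020-12-31"')

# Boolean-mask version of the original attempt: take the "date" level
# values (not df.loc[...]) and compare them to a Timestamp.
cutoff = pd.Timestamp("2020-12-31")
date_level = df.index.get_level_values("date")
train_by_mask = df[date_level <= cutoff]
test_by_mask = df[date_level > cutoff]

print(len(df_train), len(df_test))  # 240 60
```

Both approaches give the same rows; the mask form is handy when the cutoff comes from a variable rather than a string literal.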