CatBoost randomized_search with train/test splits

Time:06-29

The CatBoost documentation says the randomized_search method can accept predefined train and test splits through the cv parameter, instead of a cross-validation strategy. To do this, one should provide:

An iterable yielding train and test splits as arrays of indices.

How do we define this object?

As a (non-working) example, say my feature dataset X has 10 rows, and I want to use the first 5 rows for training and the last 5 rows for validation/testing.

I extract the index values:

train_index = X[0:5].index
test_index = X[5:10].index
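One thing worth checking, on the assumption that CatBoost follows the scikit-learn convention its cv documentation mirrors: the "arrays of indices" are row positions, whereas X[0:5].index returns index labels. The two only coincide for a default 0..n-1 RangeIndex. A minimal sketch with a hypothetical DataFrame whose labels start at 100:

```python
import numpy as np
import pandas as pd

# Hypothetical feature frame whose index labels start at 100 instead of 0
X = pd.DataFrame({"x1": np.arange(10.0)}, index=np.arange(100, 110))

# Labels of the first five rows: 100..104, which are NOT row positions
label_index = X.iloc[0:5].index

# Convert labels back to row positions if you already have them
positions = X.index.get_indexer(label_index)

# Or build row positions directly
train_index = np.arange(0, 5)        # first five rows
test_index = np.arange(5, len(X))    # last five rows

print(label_index.tolist(), positions.tolist())
print(train_index.tolist(), test_index.tolist())
```

Building positions with np.arange (or converting with get_indexer) keeps the split correct even if X is later filtered or re-indexed.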

Then I pass these indices to the randomized_search method:

a_search = model.randomized_search(param_distributions=params, 
                                   X = X,
                                   y = y,
                                   n_iter=5,
                                   cv={train_index,test_index})

The set I pass as cv={train_index,test_index} is a non-starter (pandas Index objects are unhashable, so the set literal itself raises a TypeError, and a set would not yield (train, test) pairs anyway), but I am at a loss as to how the required iterable should look. I simply want to define which rows of X and y are used for training and which for testing. The goal is to speed up training by dispensing with cross-validation and using a dedicated validation set.

Answer:

You could pass a list of (train, test) tuples; for a single fixed split, the list holds just one tuple:

cv=[(train_index, test_index)]
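The documented shape is just something that, when iterated, produces (train, test) pairs of index arrays. A list containing one tuple is the smallest such iterable; a generator function also fits the description. A minimal sketch in plain NumPy (the helper single_split is hypothetical, not part of CatBoost):

```python
import numpy as np

n_rows = 10
train_index = np.arange(0, 5)       # first five rows for training
test_index = np.arange(5, n_rows)   # last five rows for validation

# Option 1: a list holding a single (train, test) tuple
cv_list = [(train_index, test_index)]

# Option 2: a hypothetical generator yielding the same single split
def single_split(train, test):
    yield train, test

cv_gen = single_split(train_index, test_index)

# Iterating the list yields exactly one (train, test) pair
for train, test in cv_list:
    print(train.tolist(), test.tolist())
```

The list is the safer of the two: it can be iterated any number of times, whereas a generator is exhausted after a single pass.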

Example:

import pandas as pd
import numpy as np
from catboost import CatBoost

# generate the data
data = np.random.normal(loc=0, scale=1, size=(10, 3))
labels = np.mean(a=data, axis=1)

df = pd.DataFrame(
    data=np.hstack([data, np.expand_dims(labels, axis=1)]),
    columns=['x1', 'x2', 'x3', 'y']
)

# split the data
train_index = df.index[:5]
test_index = df.index[5:]

# instantiate the model
model = CatBoost()

# tune the model
results = model.randomized_search(
    param_distributions={
        'depth': [2, 3, 4],
        'iterations': [5, 6, 7]
    },
    X=df[['x1', 'x2', 'x3']],
    y=df['y'],
    cv=[(train_index, test_index)],  # a list holding one (train, test) pair
)
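If the held-out rows are not one contiguous block, the same one-split list can be built from a boolean mask. The sketch below assumes a hypothetical column is_test that flags validation rows; np.flatnonzero turns the mask into row positions, independent of the DataFrame's index labels.

```python
import numpy as np
import pandas as pd

# Hypothetical data: is_test marks the rows reserved for validation
df = pd.DataFrame({
    "x1": np.arange(10.0),
    "is_test": [False, True, False, False, True, False, True, False, False, True],
})

mask = df["is_test"].to_numpy()
train_index = np.flatnonzero(~mask)  # positions of training rows
test_index = np.flatnonzero(mask)    # positions of validation rows

# Pass this as cv=... to randomized_search; drop is_test from X beforehand
cv = [(train_index, test_index)]
print(train_index.tolist(), test_index.tolist())
```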