train test split by specific class count

Time:06-14

I have data with features X and a binary target y (0 or 1). My problem is imbalanced, so I want to make sure that, after the split, about 50% of the samples in y_test belong to class 1.

I tried using train_test_split with stratify, but that keeps the original 0/1 ratio in both splits, and since class 1 makes up less than 50% of my data, it doesn't give me what I want.
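To illustrate why stratify alone cannot produce a 50/50 test set, here is a minimal sketch on made-up data (the 900/100 class counts and test_size=0.2 are assumptions for illustration). Stratification preserves the class proportions of y in both splits, so a 10% minority stays at 10% in the test set.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical imbalanced data: 900 samples of class 0, 100 of class 1
X = np.arange(2000).reshape(1000, 2)
y = np.array([0] * 900 + [1] * 100)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0
)

# Stratification keeps the 10% share of class 1 in both splits
print((y_test == 1).sum(), len(y_test))    # 20 200
print((y_train == 1).sum(), len(y_train))  # 80 800
```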

Any suggestions?

Answer:

You shouldn't alter the train-test split because of the imbalance. The test set has to match the distribution the model will actually face when it is used. If your problem is imbalanced, so should your test set be!

What you can change, though, is the metric you use and/or the training regime, e.g., evaluating with a class-balanced metric, or reweighting the classes during training.

Both of these technically have the same effect of treating the classes as equally important, but you do not have to "split things" differently.
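A minimal sketch of both options, assuming scikit-learn is used: class_weight='balanced' stands in for the training regime and balanced_accuracy_score for the metric. These specific choices, and the synthetic data, are assumptions for illustration. The split itself stays an ordinary stratified one.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, balanced_accuracy_score
from sklearn.model_selection import train_test_split

# Synthetic imbalanced data (assumption): roughly 10% class 1
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 2))
y = (rng.random(1000) < 0.1).astype(int)
X[y == 1] += 1.5  # shift class 1 so there is something to learn

# Ordinary stratified split: the test set keeps the real class ratio
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0
)

# Training regime: class_weight='balanced' weights samples inversely
# to class frequency, so both classes contribute equally to the loss
clf = LogisticRegression(class_weight='balanced').fit(X_train, y_train)
y_pred = clf.predict(X_test)

# Metric: balanced accuracy is the mean of per-class recall,
# so the minority class counts as much as the majority class
print('accuracy:', accuracy_score(y_test, y_pred))
print('balanced accuracy:', balanced_accuracy_score(y_test, y_pred))

# A model that always predicts 0 scores high plain accuracy here,
# but its balanced accuracy is only 0.5 (recall 1.0 and 0.0, averaged)
always_zero = np.zeros_like(y_test)
print(balanced_accuracy_score(y_test, always_zero))  # 0.5
```

The last two lines show why a balanced metric matters on an imbalanced test set: plain accuracy rewards ignoring the minority class, balanced accuracy does not.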

And if you really insist on splitting the data in such an unusual way, just do it by hand:

import numpy as np


def odd_split(X, y, minority_class=1, minority_test_size=0.1):
    """Split X, y so the test set holds equal numbers of minority and majority samples."""
    minority_indices = np.where(y == minority_class)[0]
    majority_indices = np.where(y != minority_class)[0]

    # Number of minority samples to move into the test set (at least one)
    n = max(1, int(minority_test_size * len(minority_indices)))

    # Draw n minority samples without replacement
    test_minority_indices = np.random.choice(minority_indices, n, replace=False)
    assert (y[test_minority_indices] == minority_class).all()

    # Draw the same number of majority samples, giving a 50/50 test set
    test_majority_indices = np.random.choice(majority_indices, n, replace=False)
    assert (y[test_majority_indices] != minority_class).all()

    test_indices = np.concatenate((test_minority_indices, test_majority_indices))
    # Every index not selected for the test set goes to training
    # (setdiff1d avoids rebuilding a set for every element)
    train_indices = np.setdiff1d(np.arange(len(y)), test_indices)

    return X[train_indices], y[train_indices], X[test_indices], y[test_indices]

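from collections import Counter

X = np.random.normal(size=(1000, 2))
y = np.random.choice([0, 1], p=[0.9, 0.1], size=1000)
# .tolist() turns NumPy integers into plain Python ints, so the keys print
# as 0/1 rather than np.int64(0)/np.int64(1) on NumPy 2.x
print('Whole', Counter(y.tolist()))

# Note the return order: X_train, y_train, X_test, y_test
X_train, y_train, X_test, y_test = odd_split(X, y)
print('Train', Counter(y_train.tolist()))
print('Test', Counter(y_test.tolist()))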

which gives, for one random run (the counts vary from run to run because no seed is set):

Whole Counter({0: 886, 1: 114})
Train Counter({0: 875, 1: 103})
Test Counter({1: 11, 0: 11})
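If you need the same split on every run, one option is a seeded variant. The sketch below is a hypothetical odd_split_seeded (not part of the answer above): it keeps the same 50/50 logic but draws from a numpy.random.Generator created with an assumed seed parameter, and shuffles the test rows so they are not ordered minority-first.

```python
import numpy as np


def odd_split_seeded(X, y, minority_class=1, minority_test_size=0.1, seed=None):
    """Hypothetical reproducible variant of odd_split with a shuffled test set."""
    rng = np.random.default_rng(seed)
    minority_indices = np.flatnonzero(y == minority_class)
    majority_indices = np.flatnonzero(y != minority_class)

    # Same sizing rule as odd_split: n minority plus n majority test samples
    n = max(1, int(minority_test_size * len(minority_indices)))
    test_indices = np.concatenate((
        rng.choice(minority_indices, n, replace=False),
        rng.choice(majority_indices, n, replace=False),
    ))
    # Shuffle in place so the test set is not ordered minority-first
    rng.shuffle(test_indices)
    train_indices = np.setdiff1d(np.arange(len(y)), test_indices)
    return X[train_indices], y[train_indices], X[test_indices], y[test_indices]


# Deterministic toy labels (assumption): 900 of class 0, 100 of class 1
X = np.random.default_rng(1).normal(size=(1000, 2))
y = np.array([0] * 900 + [1] * 100)

X_train, y_train, X_test, y_test = odd_split_seeded(X, y, seed=42)
print(len(y_test), int((y_test == 1).sum()))  # 20 10
print(len(y_train))                           # 980
```

Calling it again with seed=42 returns exactly the same split.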