I have data consisting of features X and a binary label y (0 or 1). My problem is imbalanced, so I want to make sure that about 50% of the samples in y_test are labelled 1 after the split.
I tried train_test_split with stratify, but the share of 1s in my data is below 50%, so it doesn't work.
Any suggestions?
CodePudding user response:
You shouldn't change the train-test split because of imbalance. The test set has to reflect the distribution you will actually see at test time. If your problem is imbalanced, your test set should be too!
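For reference, a stratified split keeps the class ratio of the full data in both parts, which is why stratify cannot produce a 50/50 test set from imbalanced data. A minimal sketch, assuming scikit-learn is installed; the toy data (exactly 10% positives) and the chosen test_size and random_state are made up for illustration:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical toy data: 900 negatives and 100 positives.
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 2))
y = np.array([0] * 900 + [1] * 100)

# stratify=y preserves the 90/10 class ratio in both train and test.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0
)

# Share of positives in the whole data, the train part and the test part.
print(y.mean(), y_train.mean(), y_test.mean())
```

All three shares come out the same here, so the test set mirrors the imbalance of the original data.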
What you can change, though, is the metric you use and/or the training regime, e.g.:
- Use balanced accuracy instead of accuracy for evaluation
- Use class reweighting during training
Both of these technically have the same effect of treating the classes as equally important, but you do not have to split the data differently.
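To see why the choice of metric matters, here is a small sketch in plain Python. The helper names (accuracy, balanced_accuracy, balanced_class_weights) are made up for illustration. The weight formula n_samples / (n_classes * class_count) is the heuristic behind scikit-learn's class_weight="balanced" option, written out by hand here as an assumption about what "class reweighting" would look like in practice:

```python
from collections import Counter


def accuracy(y_true, y_pred):
    """Fraction of samples predicted correctly."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def balanced_accuracy(y_true, y_pred):
    """Mean of the per-class recalls, so every class counts equally."""
    recalls = []
    for cls in sorted(set(y_true)):
        preds_for_cls = [p for t, p in zip(y_true, y_pred) if t == cls]
        recalls.append(sum(p == cls for p in preds_for_cls) / len(preds_for_cls))
    return sum(recalls) / len(recalls)


def balanced_class_weights(y):
    """Weight each class by n_samples / (n_classes * class_count)."""
    counts = Counter(y)
    return {cls: len(y) / (len(counts) * n) for cls, n in counts.items()}


# Imbalanced labels and a useless model that always predicts the majority class.
y_true = [0] * 90 + [1] * 10
y_pred = [0] * 100

print(accuracy(y_true, y_pred))           # 0.9
print(balanced_accuracy(y_true, y_pred))  # 0.5
print(balanced_class_weights(y_true))
```

Plain accuracy rewards the always-0 model with 0.9, while balanced accuracy exposes it as no better than chance. The weights give each minority sample ten times the weight of a majority sample (5.0 versus about 0.56), which is what pushes training to treat both classes as equally important.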
And if you really insist on splitting the data in such an odd way, just do it by hand:
import numpy as np
from collections import Counter


def odd_split(X, y, minority_class=1, minority_test_size=0.1):
    """Build a test set with equal numbers of minority and majority samples."""
    minority_indices = np.where(y == minority_class)[0]
    majority_indices = np.where(y != minority_class)[0]

    # Number of samples to draw from each class for the test set.
    n = max(1, int(minority_test_size * len(minority_indices)))

    test_minority_indices = np.random.choice(minority_indices, n, replace=False)
    assert (y[test_minority_indices] == minority_class).all()

    test_majority_indices = np.random.choice(majority_indices, n, replace=False)
    assert (y[test_majority_indices] != minority_class).all()

    test_indices = np.concatenate((test_minority_indices, test_majority_indices))
    # Everything not selected for testing goes to the training set.
    train_indices = np.setdiff1d(np.arange(len(y)), test_indices)
    return X[train_indices], y[train_indices], X[test_indices], y[test_indices]


X = np.random.normal(size=(1000, 2))
y = np.random.choice([0, 1], p=[0.9, 0.1], size=1000)
print('Whole', Counter(y))
X_train, y_train, X_test, y_test = odd_split(X, y)
print('Train', Counter(y_train))
print('Test', Counter(y_test))
Which gives (exact counts vary from run to run, since the data is random):
Whole Counter({0: 886, 1: 114})
Train Counter({0: 875, 1: 103})
Test Counter({1: 11, 0: 11})