I have a data set (df) as follows:
Company Col1 Col2 Output
AB 10 20 1
AB 20 22 1
AB 14 12 0
XZ 33 22 1
XZ 43 62 0
I want to train_test_split the data so that if a company is in the test set, it does not appear in the training set at all. That is, if the first row (AB, 10, 20, 1) is in the test set, the other AB rows, such as the second row (AB, 20, 22, 1), must also be in the test set. I know that passing stratify=df["Company"] would do the exact opposite of what I want, because it spreads each company's rows across both sets. Is there a built-in function that does this?
P.S. The Company column contains strings.
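For reference, train_test_split itself has no grouping option, but scikit-learn provides GroupShuffleSplit, which keeps every group entirely on one side of a split. The sketch below assumes scikit-learn and pandas are installed and rebuilds the example data by hand. Note that test_size here is the fraction of companies held out, not the fraction of rows.

```python
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

# The example data from the question
df = pd.DataFrame({
    "Company": ["AB", "AB", "AB", "XZ", "XZ"],
    "Col1": [10, 20, 14, 33, 43],
    "Col2": [20, 22, 12, 22, 62],
    "Output": [1, 1, 0, 1, 0],
})

# One random split; test_size is the fraction of *companies* held out, not rows
gss = GroupShuffleSplit(n_splits=1, test_size=0.5, random_state=42)
train_idx, test_idx = next(gss.split(df, groups=df["Company"]))

# split() yields positional indices, so select rows with iloc
df_train = df.iloc[train_idx]
df_test = df.iloc[test_idx]

# No company appears on both sides of the split
assert set(df_train["Company"]).isdisjoint(df_test["Company"])
```

Because the split is over companies, the row proportions depend on how many rows each company has; with only two companies here, each side gets exactly one.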
Answer:
This is a little verbose and not a generic function, but the following approach might work for you:
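# Number of rows per company (groupby sorts the companies alphabetically)
counts = df.groupby("Company").size()
frac = 0.8  # Fraction of rows for the training set; will only be approximated
train_companies = []
i = 0
c = 0  # Rows assigned to the training set so far
total_count = counts.values.sum()
train_count = total_count * frac

# Add whole companies until the training set holds at least train_count rows
while c < train_count:
    train_companies.append(counts.index[i])
    c = c + counts.values[i]
    i = i + 1

# All rows of a company land on the same side of the split
df_train = df[df["Company"].isin(train_companies)]
df_test = df[~df["Company"].isin(train_companies)]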
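The loop in this answer always overshoots its target, because it stops only once the training set has at least train_count rows, and it always takes companies in alphabetical order. With the five example rows (AB: 3 rows, XZ: 2 rows, target 0.8 × 5 = 4), it adds AB, then XZ, leaving the test set empty. A minimal sketch of a variant, assuming NumPy and pandas, shuffles the companies and adds one to the training set only if that keeps the row count at or below the target. The function name group_train_test_split is hypothetical, not a library API.

```python
import numpy as np
import pandas as pd


def group_train_test_split(df, group_col, train_frac=0.8, seed=None):
    """Hypothetical helper: split df so each group_col value is entirely in train or in test.

    Groups are visited in random order; a group joins the training set only if
    the training row count stays <= train_frac * len(df). All other groups go to test.
    """
    rng = np.random.default_rng(seed)
    counts = df[group_col].value_counts()               # rows per group
    order = rng.permutation(counts.index.to_numpy())    # random group order
    target = train_frac * len(df)

    train_groups = []
    n_train = 0
    for g in order:
        if n_train + counts[g] <= target:
            train_groups.append(g)
            n_train += counts[g]

    in_train = df[group_col].isin(train_groups)
    return df[in_train], df[~in_train]


# Usage on the example data from the question
df = pd.DataFrame({
    "Company": ["AB", "AB", "AB", "XZ", "XZ"],
    "Col1": [10, 20, 14, 33, 43],
    "Col2": [20, 22, 12, 22, 62],
    "Output": [1, 1, 0, 1, 0],
})
train, test = group_train_test_split(df, "Company", train_frac=0.8, seed=0)
assert set(train["Company"]).isdisjoint(test["Company"])
```

With this rule, a single company with more rows than the target always goes to the test set, so on very small data either side can still end up empty; check both before training.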