Is there a way to remove some rows in the training set based on values in another column

Time:05-31

I have a DataFrame and I split it into training and testing sets (80:20). It looks like this:

V1  V2  V3  V4  V5 Target
5   2   34  12  9   1
1   8   24  14  12  0
12  27  4   12  9   0

Then I built a simple regression model and made predictions.

The code works, but here is my question: after splitting the data into training and testing sets, I need to remove (or exclude) some rows from the training set. That means deleting specific rows in X_train together with their corresponding values in y_train, based on conditions on the values in another column.

For example, I need to remove any row in the training set if V1 > 10.

As a result, this row in X_train and its value in y_train should be deleted:

V1  V2  V3  V4  V5 Target
12  27  4   12  9   0
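
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

# Split the data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)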
print("X_train:", X_train.shape)
print("X_test:", X_test.shape)
print("y_train:", y_train.shape)
print("y_test:", y_test.shape)

# Train and fit the model
regressor = LinearRegression()
regressor.fit(X_train, y_train)

# Make prediction
y_pred = regressor.predict(X_test)
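For reference, here is a minimal end-to-end sketch of where the filtering step would sit: after the split, before `fit`, and applied only to the training data. It assumes `X` is a DataFrame with columns `V1`..`V5` and `y` is the `Target` Series; the random data below is hypothetical and only stands in for the real DataFrame.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

# Hypothetical synthetic data shaped like the question's table (V1..V5, Target)
rng = np.random.default_rng(0)
df = pd.DataFrame(rng.integers(0, 30, size=(100, 5)), columns=["V1", "V2", "V3", "V4", "V5"])
df["Target"] = rng.integers(0, 2, size=100)

X = df.drop(columns="Target")
y = df["Target"]

# Split first (80:20), as in the question
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# Filter only the training set: keep rows where V1 <= 10, i.e. drop rows with V1 > 10
keep = X_train["V1"] <= 10
X_train, y_train = X_train[keep], y_train[keep]

# Fit on the filtered training set; the test set is left untouched
regressor = LinearRegression()
regressor.fit(X_train, y_train)
y_pred = regressor.predict(X_test)
```

Because the mask is built once and applied to both objects, X_train and y_train keep the same rows in the same order.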

I think the way to do it is to extract the indexes of the rows to remove using the required condition, and then remove those rows from X_train and y_train.
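The index-based idea described above can be sketched with pandas as follows. This assumes X_train and y_train are a DataFrame and a Series that share the same index labels, which holds when both come from the same `train_test_split` call on pandas inputs; the three rows are the example rows from the question.

```python
import pandas as pd

# The question's example rows, standing in for X_train / y_train
X_train = pd.DataFrame(
    {"V1": [5, 1, 12], "V2": [2, 8, 27], "V3": [34, 24, 4], "V4": [12, 14, 12], "V5": [9, 12, 9]}
)
y_train = pd.Series([1, 0, 0], name="Target")

# 1. Collect the index labels of the rows that match the removal condition
drop_idx = X_train.index[X_train["V1"] > 10]

# 2. Drop those labels from both objects so features and targets stay aligned
X_train = X_train.drop(index=drop_idx)
y_train = y_train.drop(index=drop_idx)

print(X_train)
print(y_train)
```

Note that `drop` works on labels, not positions, so this approach relies on the index labels being unique.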

The suggested duplicate questions do not answer mine because the scenario is different: they do not consider a split into training and testing sets. I need to delete rows from X_train and the corresponding entries from y_train.

Answer:

If X_train and y_train are NumPy arrays, as I assume, you can simply do:

mask = X_train[:, 0] <= 10  # keep rows where V1 (column 0) <= 10
X_train = X_train[mask]
y_train = y_train[mask]
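A self-contained sketch of the NumPy case, using the question's example rows and assuming V1 is column 0. The key point is that the mask must be computed from the unfiltered X_train: if X_train were overwritten first, a mask recomputed from it would no longer line up with y_train. Computing the mask once avoids that ordering trap.

```python
import numpy as np

# The question's example rows as a NumPy array; column 0 plays the role of V1
X_train = np.array([[5, 2, 34, 12, 9],
                    [1, 8, 24, 14, 12],
                    [12, 27, 4, 12, 9]])
y_train = np.array([1, 0, 0])

# Build the boolean mask once, from the unfiltered X_train
mask = X_train[:, 0] <= 10

# Apply the same mask to both arrays so rows stay paired
X_train, y_train = X_train[mask], y_train[mask]

print(X_train.shape, y_train.shape)  # (2, 5) (2,)
```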

EDIT

If X_train is a pandas DataFrame and y_train is a pandas Series:

mask = X_train["V1"] <= 10  # boolean Series aligned on X_train's index
X_train = X_train.loc[mask]
y_train = y_train.loc[mask]
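The pandas version filters by index label, so it assumes X_train and y_train share the same index. In a hypothetical mixed case where the features are a DataFrame but the targets are a plain NumPy array (for example, if `y` was converted with `.values` before splitting), converting the mask to a NumPy array makes the selection explicitly positional for both objects. A minimal sketch, with made-up index labels:

```python
import numpy as np
import pandas as pd

# Hypothetical mixed case: features in a DataFrame, targets in a NumPy array
X_train = pd.DataFrame(
    {"V1": [5, 1, 12], "V2": [2, 8, 27], "V3": [34, 24, 4], "V4": [12, 14, 12], "V5": [9, 12, 9]},
    index=[7, 3, 42],  # shuffled labels, like those left behind by train_test_split
)
y_train = np.array([1, 0, 0])

# Convert the mask to a plain boolean array so it selects by position in both objects
mask = (X_train["V1"] <= 10).to_numpy()
X_train = X_train.loc[mask]
y_train = y_train[mask]
```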