I am trying to use linear regression to predict salary in USD. I have the following data:
Data:
- 607 records
- Numerical columns: year, salary, salary in USD
- Categorical columns: experience, type, residence, currency, remote work, company location, and company size.
- Target: salary in USD
Preprocessing dataset:
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.compose import ColumnTransformer
# Columns to drop:
drop_cols = ['Currency', 'Company location', 'Salary', 'Title']
# Attributes of interest
num_attributes = ['Year']
one_hot_attributes = ['Experience', 'Type', 'Remote work', 'Residence', 'Company size']
# Drop columns:
data.drop(columns=drop_cols, inplace=True)
# Set up the transformer for each column group:
preprocessor = ColumnTransformer([
    ('nums', StandardScaler(), num_attributes),
    ('one_hot', OneHotEncoder(drop='first', sparse=False), one_hot_attributes)],
    remainder='passthrough')
Pipe:
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LinearRegression
pipe = Pipeline(steps=[
    ('preprocessor', preprocessor),
    ('model', LinearRegression()),
])
pipe.fit(X_train, y_train)
Perform prediction:
prediction = pipe.predict(X_test)
Error:
ValueError: Found unknown categories ['IR', 'HN', 'MT', 'PH', 'NZ', 'CZ', 'MD'] in column 3 during transform
Answer:
Your test data contains locations that were never seen during training, but you are using OneHotEncoder. How do you want to represent these previously unseen locations?
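A quick way to confirm which values are new is to compare the two splits directly. This is a minimal sketch on toy data: the arrays below are hypothetical stand-ins for the Residence column of X_train and X_test, not the actual dataset.

```python
import numpy as np

# Hypothetical stand-ins for X_train['Residence'] and X_test['Residence'].
train_residence = np.array(["US", "GB", "US", "DE"])
test_residence = np.array(["US", "NZ", "PH", "NZ"])

# setdiff1d returns the sorted unique values present in the first
# array but absent from the second.
unseen = np.setdiff1d(test_residence, train_residence)
print(unseen)
```

Any value listed here is one the encoder cannot have learned during fit.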
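You need to set the handle_unknown parameter of OneHotEncoder (https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OneHotEncoder.html). By default it is set to 'error'. Instead, you can set it to 'ignore', which encodes an unknown category as all zeros, or to 'infrequent_if_exist', which maps it to a shared infrequent category (you then need to set min_frequency so that rare values from the training data are mapped there too). Note that with drop='first', the dropped first category is also encoded as all zeros, so under 'ignore' an unknown location becomes indistinguishable from it, and scikit-learn warns about this during transform.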
For example:
# Set up the transformer for each column group:
preprocessor = ColumnTransformer([
    ('nums', StandardScaler(), num_attributes),
    ('one_hot', OneHotEncoder(drop='first',
                              sparse=False,  # renamed to sparse_output in scikit-learn 1.2
                              handle_unknown='ignore'), one_hot_attributes)],
    remainder='passthrough')