sklearn custom transformer with pd.drop()

Time: 03-24

I've made a custom transformer whose transform method is the following:

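def transform(self, X):
    # Work on a copy so the caller's DataFrame is not modified
    data = X.copy()
    data = data.drop(self.columns_dropped, axis=1)
    # Lowercase the categorical string columns
    for col in self.cat_col:
        data[col] = data[col].str.lower()
    # Fill missing ordinal values, then map them to their ordinal codes
    data[self.col_to_ordinal] = data[self.col_to_ordinal].fillna("na")
    data[self.col_to_ordinal] = data[self.col_to_ordinal].replace(self.qc_order)

    # Replace the values in each entry of g1 / g2 with that entry's position
    for x in range(len(self.g1)):
        data = data.replace(self.g1[x], x)
    for x in range(len(self.g2)):
        data = data.replace(self.g2[x], x)
    # Use 'Id' as the index instead of keeping it as a feature column
    data = data.set_index(data['Id'])
    data = data.drop('Id', axis=1)
    return data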
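The question only shows transform(). For context, here is a minimal sketch of the kind of class such a method usually lives in, assuming the common BaseEstimator / TransformerMixin pattern and trimmed down to the column-dropping step. The class name ColumnDropper is hypothetical and is not the asker's PandasTransform. Because GridSearchCV clones estimators through get_params(), __init__ stores its arguments unchanged.

```python
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin


class ColumnDropper(BaseEstimator, TransformerMixin):
    """Hypothetical, stripped-down transformer that only drops columns."""

    def __init__(self, columns_dropped=None):
        # Store the argument as-is so get_params()/clone() (used by
        # GridSearchCV) can rebuild an identical, unfitted copy.
        self.columns_dropped = columns_dropped

    def fit(self, X, y=None):
        # Nothing to learn; returning self is what the sklearn API expects.
        return self

    def transform(self, X):
        # Copy so the caller's DataFrame is never modified in place.
        data = X.copy()
        return data.drop(columns=self.columns_dropped or [])


# Toy data with hypothetical values.
df = pd.DataFrame({"Id": [1, 2], "Alley": ["Grvl", "Pave"], "LotFrontage": [65.0, 80.0]})
out = ColumnDropper(columns_dropped=["LotFrontage"]).fit_transform(df)
print(list(out.columns))  # ['Id', 'Alley']
```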

I've made it part of the following pipeline:

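from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder

cat_pipeline = Pipeline([
    ('pandas_transform', PandasTransform()),
    ('cat_encoder', OneHotEncoder(handle_unknown='ignore'))
    ])
# cat_pipeline receives only the cat_col columns; SimpleImputer only num_col
preprocessor = ColumnTransformer(transformers=[
    ('cat', cat_pipeline, cat_col),
    ('num', SimpleImputer(), num_col)
    ])
Pipeline([
    ('preprocessor', preprocessor),
    ('model', model)
    ])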
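In a ColumnTransformer, each transformer is fitted only on the columns listed for it, so cat_pipeline (and the PandasTransform inside it) sees the cat_col columns rather than the whole DataFrame. A minimal sketch to check this, using a hypothetical ShowColumns transformer that records the columns it is fitted on, and a toy DataFrame:

```python
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.compose import ColumnTransformer


class ShowColumns(BaseEstimator, TransformerMixin):
    """Hypothetical transformer that records the columns passed to fit()."""

    def fit(self, X, y=None):
        self.seen_columns_ = list(X.columns)
        return self

    def transform(self, X):
        return X


# Toy data with hypothetical values.
df = pd.DataFrame({
    "Id": [1, 2],
    "Alley": ["Grvl", "Pave"],    # categorical
    "LotFrontage": [65.0, 80.0],  # numeric
})

ct = ColumnTransformer(transformers=[("cat", ShowColumns(), ["Alley"])])
ct.fit(df)

# ColumnTransformer fits a clone; the fitted clone only saw the "cat" columns.
print(ct.named_transformers_["cat"].seen_columns_)  # ['Alley']
```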

But when I then run GridSearchCV, it raises this error:

"['LotFrontage', 'Alley', 'FireplaceQu', 'PoolQC', 'Fence', 'MiscFeature', 'YearRemodAdd', 'YearBuilt', 'MoSold', 'YrSold', 'GarageType', 'GarageYrBlt', 'GarageFinish', 'GarageArea', 'GarageCond', 'Exterior2nd'] not found in axis"

(The list is the contents of columns_dropped.) When I run the transform function outside the pipeline, it works. I suspected the original data was being modified, so that on consecutive calls there would be no columns left to drop because the previous call had already dropped them. To rule that out I copy the input inside transform() with X.copy(), but I still get the error.
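One plausible explanation, assuming columns_dropped contains columns that are not in cat_col: outside the pipeline the transform gets the full DataFrame, but inside the ColumnTransformer it only gets the cat_col subset, so DataFrame.drop raises the "not found in axis" KeyError no matter how the input is copied. The same would apply to 'Id', which transform() uses for set_index, if 'Id' is not in cat_col. A minimal pandas-only sketch with toy columns (values are hypothetical):

```python
import pandas as pd

# Toy stand-in for the housing data (hypothetical values).
full = pd.DataFrame({
    "Id": [1, 2],
    "Alley": ["Grvl", "Pave"],
    "LotFrontage": [65.0, 80.0],
})
columns_dropped = ["LotFrontage"]
cat_col = ["Alley"]

# Outside the pipeline: the full frame has every column, so drop succeeds.
kept = full.drop(columns_dropped, axis=1).columns.tolist()
print(kept)  # ['Id', 'Alley']

# Inside the pipeline: the transformer only gets the cat_col columns.
# Copying first changes nothing, because the column was never passed in.
subset = full[cat_col].copy()
error_message = None
try:
    subset.drop(columns_dropped, axis=1)
except KeyError as err:
    error_message = str(err)
print(error_message)
```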

Any ideas?

Answer:

Perhaps you could try the following, which only drops the columns that are actually present in the input:

def transform(self, X):
    data = X.copy()
    # Only drop the columns that are actually present in this input
    for x in self.columns_dropped:
        if x in data.columns:
            data = data.drop(x, axis=1)

    for col in self.cat_col:
        data[col] = data[col].str.lower()
    data[self.col_to_ordinal] = data[self.col_to_ordinal].fillna("na") 
    data[self.col_to_ordinal] = data[self.col_to_ordinal].replace(self.qc_order)
    
    for x in range(len(self.g1)):
        data = data.replace(self.g1[x], x)
    for x in range(len(self.g2)):
        data = data.replace(self.g2[x], x)
    data = data.set_index(data['Id'])
    data = data.drop('Id',axis=1)
    return data 

Let me know if that works. It isn't the most elegant solution, but from what you've provided, and if I've understood the issue correctly, it should at least get you past the error.
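A more compact way to get the same behaviour is the errors="ignore" option of DataFrame.drop, which skips labels that are not present instead of raising. A minimal sketch on a toy frame (column values are illustrative):

```python
import pandas as pd

# Toy input containing only some of the columns listed for dropping.
data = pd.DataFrame({"Alley": ["Grvl", "Pave"], "Street": ["Pave", "Pave"]})
columns_dropped = ["LotFrontage", "Alley", "PoolQC"]

# errors="ignore" drops the listed labels that exist and skips the missing
# ones, the same effect as checking membership before each drop.
data = data.drop(columns=columns_dropped, errors="ignore")
print(data.columns.tolist())  # ['Street']
```

Either version only avoids the error on missing labels: inside the ColumnTransformer the transformer still receives just the cat_col columns, so any other column it reads later (for example 'Id' in set_index, assuming 'Id' is not in cat_col) would still be absent.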
