Wrap model with sklearn interface

Time:12-21

What is the best way to wrap an existing model with the sklearn BaseEstimator interface so that it is compatible with GridSearchCV? The model I have has neither set_params nor get_params. My approach is the following:

from sklearn.base import BaseEstimator

class Wrapper(BaseEstimator):

    def __init__(self, param1, param2):
        self.model = ModelClass(param1, param2)

    def fit(self, data):
        self.model.fit(data)
        return self

    def predict(self, data):
        return self.model.predict(data)

    def get_params(self, deep=True):  # ?
        return self.model.__dict__

    def set_params(self, **parameters):  # ? Do I have to recreate the model?
        for parameter, value in parameters.items():
            setattr(self.model, parameter, value)
        return self

Answer:

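In get_params, return a dictionary of the Wrapper's own constructor parameters (here param1 and param2), not the wrapped model's __dict__. GridSearchCV clones the estimator by passing the values returned by get_params back to __init__, and then applies each candidate from param_grid through set_params, so the keys must match both the __init__ argument names and the names used in param_grid.

Don't forget to store each argument under the same name in __init__ (self.param1 = param1, self.param2 = param2) so that get_params and set_params can access it. And yes, in this design set_params has to recreate the wrapped model, because the model is built in __init__ from the old values.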

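from sklearn.base import BaseEstimator

class Wrapper(BaseEstimator):

    def __init__(self, param1, param2):
        # Store each argument under the same name so get_params/set_params
        # (and sklearn's clone) can find it.
        self.param1 = param1
        self.param2 = param2
        self.model = ModelClass(param1, param2)

    def fit(self, X, y=None):
        # GridSearchCV calls fit(X, y); this assumes ModelClass.fit accepts (X, y).
        self.model.fit(X, y)
        return self

    def predict(self, X):
        return self.model.predict(X)

    def score(self, X, y=None):
        # Without a `scoring` argument, GridSearchCV ranks candidates by score(X, y);
        # this assumes ModelClass.score accepts (X, y).
        return self.model.score(X, y)

    def get_params(self, deep=True):
        return {'param1': self.param1, 'param2': self.param2}

    def set_params(self, **parameters):
        self.param1 = parameters.get('param1', self.param1)
        self.param2 = parameters.get('param2', self.param2)
        # The model was built in __init__, so rebuild it with the new values.
        self.model = ModelClass(self.param1, self.param2)
        return self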
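BaseEstimator also ships its own get_params and set_params, which read the parameter names from the signature of __init__. So if __init__ does nothing except store its arguments and the wrapped model is built in fit instead, the overrides are optional and set_params never has to rebuild anything. A minimal sketch of that variant, assuming a hypothetical LegacyModel as a stand-in for ModelClass; the model_ attribute name (trailing underscore, sklearn's convention for fitted attributes) is also an assumption of this sketch.

```python
from sklearn.base import BaseEstimator, clone


class LegacyModel:
    """Hypothetical stand-in for ModelClass: no get_params/set_params."""

    def __init__(self, param1, param2):
        self.param1 = param1
        self.param2 = param2

    def fit(self, X, y=None):
        # Placeholder: a real model would learn from X (and y) here.
        return self

    def predict(self, X):
        return [self.param1 * x + self.param2 for x in X]


class Wrapper(BaseEstimator):
    def __init__(self, param1, param2):
        # Only store the arguments, under the same names; the inherited
        # get_params/set_params discover them from this signature.
        self.param1 = param1
        self.param2 = param2

    def fit(self, X, y=None):
        # Build the wrapped model from the current parameters at fit time,
        # so set_params never has to rebuild it.
        self.model_ = LegacyModel(self.param1, self.param2)
        self.model_.fit(X, y)
        return self

    def predict(self, X):
        return self.model_.predict(X)


w = Wrapper(param1=2, param2=0.5)
print(w.get_params())
w.set_params(param1=3)  # handled by the inherited BaseEstimator.set_params
print(clone(w).get_params())
print(w.fit([1, 2]).predict([1, 2]))
```

This layout also matches the scikit-learn developer guide's advice that __init__ should only assign its arguments to attributes and leave all real work to fit.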

Example of using GridSearchCV:

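from sklearn.model_selection import GridSearchCV

# X_train, y_train, X_test, y_test are your own train/test splits.
param_grid = {'param1': [1, 10, 100], 'param2': [0.01, 0.1, 1]}

# Wrapper.__init__ has no default values, so pass starting values explicitly;
# GridSearchCV replaces them with each candidate from param_grid.
model = Wrapper(param1=1, param2=0.1)
grid_search = GridSearchCV(estimator=model, param_grid=param_grid)
grid_search.fit(X_train, y_train)
print('Best parameters:', grid_search.best_params_)
test_score = grid_search.score(X_test, y_test)
print(f'Test score: {test_score:.2f}')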
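To see the answer's pattern (explicit get_params/set_params that rebuild the model) run end to end, here is a sketch with a hypothetical LegacyRidge class standing in for ModelClass and synthetic, noise-free data y = 2x + 3. In this sketch param1 plays the role of a ridge penalty and param2 toggles the intercept; the class, the data, and that mapping are all assumptions for illustration only.

```python
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.model_selection import GridSearchCV


class LegacyRidge:
    """Hypothetical stand-in for ModelClass: no get_params/set_params."""

    def __init__(self, alpha, fit_intercept):
        self.alpha = alpha
        self.fit_intercept = fit_intercept

    def fit(self, X, y):
        X = np.asarray(X, dtype=float)
        y = np.asarray(y, dtype=float)
        # Center the data only when an intercept is fitted.
        x_mean = X.mean(axis=0) if self.fit_intercept else np.zeros(X.shape[1])
        y_mean = y.mean() if self.fit_intercept else 0.0
        Xc, yc = X - x_mean, y - y_mean
        # Closed-form ridge solution: (Xc^T Xc + alpha * I)^-1 Xc^T yc
        A = Xc.T @ Xc + self.alpha * np.eye(X.shape[1])
        self.coef_ = np.linalg.solve(A, Xc.T @ yc)
        self.intercept_ = y_mean - x_mean @ self.coef_
        return self

    def predict(self, X):
        return np.asarray(X, dtype=float) @ self.coef_ + self.intercept_

    def score(self, X, y):
        # Coefficient of determination (R^2); higher is better.
        y = np.asarray(y, dtype=float)
        residual = ((y - self.predict(X)) ** 2).sum()
        total = ((y - y.mean()) ** 2).sum()
        return 1.0 - residual / total


class Wrapper(BaseEstimator):
    def __init__(self, param1, param2):
        self.param1 = param1  # ridge penalty (alpha)
        self.param2 = param2  # whether to fit an intercept
        self.model = LegacyRidge(param1, param2)

    def fit(self, X, y=None):
        self.model.fit(X, y)
        return self

    def predict(self, X):
        return self.model.predict(X)

    def score(self, X, y=None):
        return self.model.score(X, y)

    def get_params(self, deep=True):
        return {"param1": self.param1, "param2": self.param2}

    def set_params(self, **parameters):
        self.param1 = parameters.get("param1", self.param1)
        self.param2 = parameters.get("param2", self.param2)
        # Rebuild the wrapped model so the new values take effect.
        self.model = LegacyRidge(self.param1, self.param2)
        return self


# Synthetic, noise-free data: y = 2x + 3.
X = np.arange(40, dtype=float).reshape(-1, 1)
y = 2.0 * X.ravel() + 3.0

param_grid = {"param1": [0.0, 10.0, 100.0], "param2": [True, False]}
grid_search = GridSearchCV(estimator=Wrapper(param1=1.0, param2=True), param_grid=param_grid)
grid_search.fit(X, y)
print(grid_search.best_params_, grid_search.best_score_)
```

Because the data are exactly linear, the unpenalized model with an intercept should fit every fold almost perfectly, so the search is expected to pick param1=0.0 and param2=True.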