'DecisionTree' object has no attribute 'criterion'

Time:01-04

I get the error 'DecisionTree' object has no attribute 'criterion' on the line reg = GridSearchCV....

The task is the following:

Using 5-fold cross-validation, find the optimal values of the max_depth and criterion parameters. For max_depth use range(2, 9), and for criterion use {'variance', 'mad_median'}. Use scoring='neg_mean_squared_error' as the quality criterion.
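To get a sense of the scale of this search, it can be sketched with the standard library alone (itertools.product stands in here for scikit-learn's own grid handling, and the variable names are illustrative): 7 values of max_depth times 2 criteria give 14 candidate combinations, and 5-fold cross-validation fits each of them 5 times.

```python
from itertools import product

# Candidate values taken from the task statement
max_depth_values = list(range(2, 9))           # 2, 3, ..., 8 -> 7 values
criterion_values = ['variance', 'mad_median']  # 2 values
n_folds = 5

# Every combination a grid search would try (the order here is illustrative only)
candidates = [
    {'max_depth': depth, 'criterion': crit}
    for depth, crit in product(max_depth_values, criterion_values)
]

n_candidates = len(candidates)      # 7 * 2 combinations
n_cv_fits = n_candidates * n_folds  # each combination is fitted once per fold

print(n_candidates, n_cv_fits)  # 14 70
```

With GridSearchCV's default refit=True, the best combination is then fitted once more on the whole training set.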

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.model_selection import GridSearchCV

class DecisionTree(BaseEstimator):

    def __init__(self, max_depth=np.inf, min_samples_split=2,
                 criterion='gini', debug=False):
        params = {'max_depth': max_depth,
                  'min_samples_split': min_samples_split,
                  'criterion': criterion,
                  'debug': debug}
    
    def fit(self, X, y):
        pass
        
    def predict(self, X):
        pass
        
    def predict_proba(self, X):
        pass


tree_params = {'max_depth': list(range(2,9)),
              'criterion':['variance','mad_median']}

reg = GridSearchCV(DecisionTree(), tree_params,
                  cv=5, scoring='neg_mean_squared_error',n_jobs=8)

reg.fit(X_train, y_train)

How can I fix it?

Answer:

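I believe it's because the constructor never stores its arguments: it builds a local params dictionary and then discards it. scikit-learn's BaseEstimator.get_params() (which GridSearchCV relies on, for example when it clones the estimator) looks up every constructor argument as an attribute of the instance, so it fails on criterion. You can add these lines after your dictionary: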

for key, value in params.items():
    setattr(self, key, value)

It will look like this:

import numpy as np
from sklearn.base import BaseEstimator

class DecisionTree(BaseEstimator):

    def __init__(self, max_depth=np.inf, min_samples_split=2,
                 criterion='gini', debug=False):
        params = {'max_depth': max_depth,
                  'min_samples_split': min_samples_split,
                  'criterion': criterion,
                  'debug': debug}
        # Store each constructor argument as an attribute of the same name,
        # so that BaseEstimator.get_params() can find it
        for key, value in params.items():
            setattr(self, key, value)

    def fit(self, X, y):
        pass

    def predict(self, X):
        pass

    def predict_proba(self, X):
        pass
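To see why the error names criterion in particular, here is a simplified imitation of what BaseEstimator.get_params() does, written with the standard library only. It is not scikit-learn's actual code, and simplified_get_params, BrokenTree and FixedTree are hypothetical names for illustration. It takes the parameter names from the __init__ signature, sorts them, and fetches each one with getattr. The name criterion sorts before debug, max_depth and min_samples_split, so it is the first missing attribute to be reported.

```python
import inspect


def simplified_get_params(estimator):
    """Rough imitation of BaseEstimator.get_params(deep=False)."""
    init_signature = inspect.signature(type(estimator).__init__)
    # Constructor argument names, sorted alphabetically (self and **kwargs excluded)
    names = sorted(
        name for name, p in init_signature.parameters.items()
        if name != 'self' and p.kind != p.VAR_KEYWORD
    )
    # getattr raises AttributeError if __init__ never stored the value
    return {name: getattr(estimator, name) for name in names}


class BrokenTree:
    def __init__(self, max_depth=None, criterion='gini'):
        # The dictionary is local and is thrown away when __init__ returns
        params = {'max_depth': max_depth, 'criterion': criterion}


class FixedTree:
    def __init__(self, max_depth=None, criterion='gini'):
        params = {'max_depth': max_depth, 'criterion': criterion}
        for key, value in params.items():
            setattr(self, key, value)


try:
    simplified_get_params(BrokenTree())
except AttributeError as exc:
    print(exc)  # 'BrokenTree' object has no attribute 'criterion'

print(simplified_get_params(FixedTree(max_depth=3)))
# {'criterion': 'gini', 'max_depth': 3}
```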

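Or, the more usual way, assigning each argument directly to an attribute of the same name (this is what scikit-learn's conventions for custom estimators expect):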

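import numpy as np
from sklearn.base import BaseEstimator

class DecisionTree(BaseEstimator):

    def __init__(self, max_depth=np.inf, min_samples_split=2,
                 criterion='gini', debug=False):
        # Each argument is stored unchanged under its own name,
        # which is what get_params() and clone() rely on
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.criterion = criterion
        self.debug = debug

    def fit(self, X, y):
        pass

    def predict(self, X):
        pass

    def predict_proba(self, X):
        pass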
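Storing each argument unchanged matters because GridSearchCV builds a fresh copy of the estimator for every candidate, conceptually clone(estimator) followed by set_params(**candidate), and scikit-learn's clone rebuilds the object from get_params() and then checks that the constructor kept each argument as the very same object. The sketch below imitates that round trip under simplifying assumptions: clone_with and get_params are hypothetical helpers, not scikit-learn functions, and float('inf') replaces np.inf so that the example needs only the standard library.

```python
import inspect


class DecisionTree:
    """Stand-in for the class above; only the constructor matters here."""

    def __init__(self, max_depth=float('inf'), min_samples_split=2,
                 criterion='gini', debug=False):
        # Store every argument unchanged, under the same name
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.criterion = criterion
        self.debug = debug


def get_params(estimator):
    """Constructor arguments read back as attributes (simplified)."""
    sig = inspect.signature(type(estimator).__init__)
    return {name: getattr(estimator, name)
            for name in sorted(sig.parameters) if name != 'self'}


def clone_with(estimator, **candidate):
    """Rebuild the estimator from its own parameters with one grid
    candidate swapped in, then check nothing was altered on the way."""
    params = {**get_params(estimator), **candidate}
    new = type(estimator)(**params)
    stored = get_params(new)
    for name, value in params.items():
        # Identity check, similar in spirit to the one in scikit-learn's clone
        if stored[name] is not value:
            raise RuntimeError(f'__init__ altered parameter {name!r}')
    return new


tree = clone_with(DecisionTree(), max_depth=4, criterion='mad_median')
print(get_params(tree))
# {'criterion': 'mad_median', 'debug': False, 'max_depth': 4, 'min_samples_split': 2}
```

A constructor that transformed its arguments (for example by wrapping one in a new list) would fail that identity check, which is one reason the plain self.x = x style is preferred.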