function to decide which classification method to use with optional parameters - Python

Time:04-05

My goal is to add a "mod_type" parameter that selects the type of model to run, either a decision tree or k-NN, and to use kwargs so the user can pass in the optional keyword parameters "k" for k-NN and "max_depth" for the decision tree. If the user passes these in, they should be used when initializing the model. The function should return the model object.

For that, I'm using the function below:

from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
import pandas as pd
import numpy as np


def my_classification(x,y,mod_type,**kwargs):
    if mod_type == "dt":
        if max_d in kwargs.keys():
            dt = DecisionTreeClassifier(max_depth=max_d.values())
            dt.fit(x,y)
            return dt
        else:
            dt = DecisionTreeClassifier()
            dt.fit(x,y)
            return dt
    elif mod_type == "knn":
        if k in kwargs.keys():
            knn = KNeighborsClassifier(k.values())
            knn.fit(x,y)
            return knn
        else:
            knn = KNeighborsClassifier()
            knn.fit(x,y)
            return knn
    else:
        print("unavailable type")

iris = load_iris()
x = pd.DataFrame(iris.data)
y = iris.target
my_classification(x,y,"dt")

Understanding kwargs wasn't easy, but I think I have it now. The error I'm getting is: NameError: name 'max_d' is not defined. I've tried creating these variables before the function and then changing them inside it, but the model is printed without any alteration.

Could someone please help?

Answer:

kwargs is a dictionary with the argument names as keys and the passed values as values.
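A quick way to see this is a function that simply returns its kwargs. `show_kwargs` is a hypothetical name used only for this illustration:

```python
def show_kwargs(**kwargs):
    # **kwargs gathers every keyword argument not matched by a named parameter
    # into a plain dict; the argument names become string keys.
    return kwargs


print(show_kwargs(max_d=3))             # {'max_d': 3}
print(show_kwargs(k=5, foo="x"))        # {'k': 5, 'foo': 'x'}
print(show_kwargs())                    # {}
print("max_d" in show_kwargs(max_d=3))  # True
```

Note that the keys are the strings 'max_d' and 'k', not variables, which is why the membership test needs quotes.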

This is how you can use it:

if mod_type == "dt":
    if "max_d" in kwargs:
        dt = DecisionTreeClassifier(max_depth=kwargs["max_d"])
        ...
elif mod_type == "knn":
    if "k" in kwargs:
        knn = KNeighborsClassifier(kwargs["k"])
        ...
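For reference, here is a complete, runnable version of the question's function using the approach above. It is a sketch, not the only way to do it: it assumes the key names from the stated goal ("k" and "max_depth"; the question's code checks for "max_d", so swap the key if you prefer that), passes k explicitly as `n_neighbors`, and raises a `ValueError` for an unknown `mod_type` instead of printing, which is a design choice not in the original.

```python
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier


def my_classification(x, y, mod_type, **kwargs):
    """Fit and return a decision tree ("dt") or k-NN ("knn") classifier.

    Optional keyword arguments (names assumed from the stated goal):
      max_depth: passed to DecisionTreeClassifier(max_depth=...)
      k: passed to KNeighborsClassifier(n_neighbors=...)
    """
    if mod_type == "dt":
        if "max_depth" in kwargs:  # the key is the string "max_depth"
            model = DecisionTreeClassifier(max_depth=kwargs["max_depth"])
        else:
            model = DecisionTreeClassifier()
    elif mod_type == "knn":
        if "k" in kwargs:
            model = KNeighborsClassifier(n_neighbors=kwargs["k"])
        else:
            model = KNeighborsClassifier()
    else:
        # Raising makes the failure visible to the caller; the original
        # printed a message and implicitly returned None.
        raise ValueError(f"unavailable type: {mod_type!r}")
    model.fit(x, y)
    return model


iris = load_iris()
x = pd.DataFrame(iris.data)
y = iris.target

dt = my_classification(x, y, "dt", max_depth=3)
knn = my_classification(x, y, "knn", k=7)
print(dt.get_depth())   # at most 3
print(knn.n_neighbors)  # 7
```

Called without the optional keywords, each model falls back to scikit-learn's own defaults (`max_depth=None` for the tree, `n_neighbors=5` for k-NN).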

Answer:

The keys of kwargs are strings (the names of the keyword arguments), so the name has to be quoted:

if 'max_d' in kwargs.keys():

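Without the quotes, Python treats max_d as a variable name and tries to evaluate it before checking membership in kwargs.keys(); since no variable named max_d exists, you get the NameError. Once the check passes, read the value with kwargs['max_d'] rather than max_d.values().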
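The difference between the quoted and unquoted check can be reproduced without scikit-learn. This minimal sketch assumes kwargs holds what a call like `my_classification(x, y, "dt", max_d=3)` would collect; it also shows `dict.get`, which returns `None` (or a chosen default) when a key is missing:

```python
# What **kwargs would hold after a call such as f(..., max_d=3).
kwargs = {"max_d": 3}

# Quoted: the string "max_d" is compared against the dict's keys.
print("max_d" in kwargs)  # True
print(kwargs["max_d"])    # 3 (index the dict to get the value)

# Unquoted: Python first evaluates the bare name max_d. No variable with that
# name exists, so NameError is raised before the membership test even runs.
try:
    found = max_d in kwargs
    error_message = None
except NameError as err:
    error_message = str(err)
print(error_message)      # name 'max_d' is not defined

# dict.get returns a default instead of raising when the key is missing.
# None is also DecisionTreeClassifier's own default for max_depth.
print(kwargs.get("max_d"))  # 3
print({}.get("max_d"))      # None
```

Using `kwargs.get(...)` lets you pass the value straight to the constructor and drop the if/else branch, at the cost of tying your default to the library's.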
