Sklearn SVM custom rbf kernel function

Time:11-14

I was creating a custom rbf function for the SVC class of sklearn as follows:

def rbf_kernel(x, y, gamma):
    dis = np.sqrt(((x.reshape(-1, 1)) - y.reshape(1, -1)) ** 2)
    return np.exp(-(gamma*dis)**2)


def eval_kernel(kernel):
    model = SVC(kernel=kernel, C=C, gamma=gamma, degree=degree, coef0=coef0)
    model.fit(X_train, y_train)
    X_test_predict = model.predict(X_test)
    acc = (X_test_predict == y_test).sum() / y_test.shape[0]
    return acc

for k1, k2 in [('rbf', lambda x, y: rbf_kernel(x, y, gamma))]:
    acc1 = eval_kernel(k1)
    acc2 = eval_kernel(k2)

    assert(abs(acc1 - acc2) < eps)

The shape of X_train is (396, 10), y_train is (396, 10) and X_test is (132, 10). However, when I try to run it, I get an error saying:

ValueError: X.shape[1] = 3960 should be equal to 396, the number of samples at training time

It seems the error is due to the difference in dimensions between X_test and X_train, but is there any way to fix it?

Thank you in advance!
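One way to test that guess is to call the kernel directly and look at the shape it returns. The sketch below copies the question's kernel unchanged and feeds it small random arrays as stand-ins for the real data (an assumption; only the shapes matter here), keeping 10 features as in the question:

```python
import numpy as np


def rbf_kernel(x, y, gamma):
    # The kernel from the question, copied unchanged
    dis = np.sqrt(((x.reshape(-1, 1)) - y.reshape(1, -1)) ** 2)
    return np.exp(-(gamma*dis)**2)


n_features = 10
rng = np.random.default_rng(0)
X_train = rng.normal(size=(4, n_features))  # small stand-ins for the real data
X_test = rng.normal(size=(3, n_features))

# SVC expects kernel(A, B) to have shape (len(A), len(B))
K_fit = rbf_kernel(X_train, X_train, 0.1)
K_pred = rbf_kernel(X_test, X_train, 0.1)
print(K_fit.shape)   # (40, 40) instead of (4, 4)
print(K_pred.shape)  # (30, 40) instead of (3, 4)
```

With n training samples and 10 features the flattened kernel is 10n columns wide, so 396 training samples become 3960 columns, which matches the number in the error message.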

Answer:

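Your rbf kernel is written incorrectly. SVC calls a custom kernel as kernel(X, Y) and expects a matrix of shape (n_samples_X, n_samples_Y): (n_train, n_train) during fit and (n_test, n_train) during predict. Your code flattens both inputs with reshape(-1, 1) and reshape(1, -1), so it compares individual feature values rather than whole samples, and with 10 features the matrix at predict time has 3960 columns instead of 396, hence the error. Note also that sklearn's RBF kernel is exp(-gamma * ||x - y||^2), while exp(-(gamma * d)^2) effectively uses gamma squared, so the accuracy would not match kernel='rbf' even with the right shape. You can refer to the implementation of rbf_kernel in sklearn.metrics.pairwise; if we insert that, it works: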

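import numpy as np
from sklearn.datasets import make_classification
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

X, y = make_classification(528)  # 528 samples, 20 features by default

# test_size=0.25 gives 396 training and 132 test samples, as in the question
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)

def my_kernel(X, Y, gamma=0.1):
    # Squared Euclidean distances between every row of X and every row of Y,
    # so K has shape (n_samples_X, n_samples_Y)
    K = euclidean_distances(X, Y, squared=True)
    K *= -gamma
    np.exp(K, K)  # exponentiate K in-place
    return K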

def eval_kernel(kernel):
    # gamma is ignored when kernel is a callable; my_kernel uses its own default of 0.1
    model = SVC(kernel=kernel,gamma=0.1)
    model.fit(X_train, y_train)
    X_test_predict = model.predict(X_test)
    acc = (X_test_predict == y_test).sum() / y_test.shape[0]
    return acc

eval_kernel('rbf')
0.8409090909090909

eval_kernel(my_kernel)
0.8409090909090909
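If you would rather keep the question's pure-NumPy approach, the fix is to broadcast over samples instead of flattening them. This is a minimal sketch, not sklearn's implementation: the helper name numpy_rbf is hypothetical, it uses sklearn's form exp(-gamma * ||x - y||^2), and it is checked against sklearn.metrics.pairwise.rbf_kernel on random stand-in data with the question's shapes:

```python
import numpy as np
from sklearn.metrics.pairwise import rbf_kernel


def numpy_rbf(X, Y, gamma=0.1):
    # Hypothetical helper: broadcast (n_X, 1, d) - (1, n_Y, d) -> (n_X, n_Y, d)
    diff = X[:, np.newaxis, :] - Y[np.newaxis, :, :]
    sq_dist = np.sum(diff ** 2, axis=-1)  # squared Euclidean distance, (n_X, n_Y)
    return np.exp(-gamma * sq_dist)


rng = np.random.default_rng(0)
X_train = rng.normal(size=(396, 10))  # random stand-ins with the question's shapes
X_test = rng.normal(size=(132, 10))

K = numpy_rbf(X_test, X_train)
print(K.shape)  # (132, 396)
print(np.allclose(K, rbf_kernel(X_test, X_train, gamma=0.1)))  # True
```

A function like this can be passed as SVC(kernel=numpy_rbf). The 3-D intermediate holds n_X * n_Y * d floats, which gets expensive for large data; euclidean_distances avoids it by expanding ||x - y||^2 = ||x||^2 - 2 x·y + ||y||^2.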