sklearn StandardScaler Ridge Pipeline


I'm trying to standardize features and then run a ridge regression.

As written, the two versions give different answers.

When I set ridge=0, the answers match; when I remove both the StandardScaler and Dn, they also match.

I can't figure out how to reconcile the two versions (raw and using sklearn).

Thanks for your help

from sklearn.linear_model import Ridge
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
import numpy as np
np.random.seed(0)

x = np.random.randn(100, 3)
y = np.random.randn(100, 2)

xx = x.T @ x
xy = x.T @ y
Dn = np.diag(1 / np.sqrt(np.diag(xx)))

ridge = 1

xx = Dn @ xx @ Dn
xy = Dn @ xy
beta_raw = Dn @ np.linalg.solve(xx + np.eye(len(xx)) * ridge, xy)
f_raw = x @ beta_raw

model = Pipeline([("scaler", StandardScaler(with_mean=False)), ("regression", Ridge(alpha=ridge, fit_intercept=False))])
trained_model = model.fit(x, y)
f_ml = trained_model.predict(x)

print(f_ml[:3] / f_raw[:3])

CodePudding user response:

You are scaling by different values; compare:

np.diag(Dn)
array([0.09699826, 0.10123938, 0.1016412 ])

model.steps[0][1].scale_
array([1.02603414, 0.98202661, 0.97598415])

The standard deviation is the square root of the diagonal of the covariance matrix. Even though with_mean=False skips centering at transform time, StandardScaler still subtracts the mean when computing the variance, so you have to subtract the mean as well to reproduce its scale_.
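A quick check (using the same random x as above) confirms this: scale_ equals the population standard deviation (ddof=0) computed about the mean, not the root-mean-square of the raw columns.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

np.random.seed(0)
x = np.random.randn(100, 3)

scaler = StandardScaler(with_mean=False).fit(x)

# scale_ is the per-feature population std (ddof=0), computed about the
# mean even though with_mean=False skips centering at transform time
print(np.allclose(scaler.scale_, x.std(axis=0, ddof=0)))  # True

# The root-mean-square of the raw columns only matches when the
# column means are exactly zero, which they are not here
rms = np.sqrt(np.diag(x.T @ x) / x.shape[0])
print(np.allclose(rms, scaler.scale_))  # False
```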

So if we do it correctly:

x_m = x.mean(axis=0)
x_cov = np.dot((x - x_m).T, x - x_m) / (x.shape[0])
Dn = np.diag(1 / np.sqrt(np.diag(x_cov)))

xx = x.T @ x
xy = x.T @ y

ridge = 1

xx = Dn @ xx @ Dn
xy = Dn @ xy
beta_raw = Dn @ np.linalg.solve(xx + np.eye(len(xx)) * ridge, xy)
f_raw = x @ beta_raw

model = Pipeline([("scaler", StandardScaler(with_mean=False)), ("regression", Ridge(alpha=ridge, fit_intercept=False))])
trained_model = model.fit(x, y)
f_ml = trained_model.predict(x)

print(f_ml[:3] / f_raw[:3])

[[1. 1.]
 [1. 1.]
 [1. 1.]]
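To see why the corrected Dn reconciles the two versions, you can also compare the coefficients directly rather than the predictions. A sketch (same data and pipeline as above; the ridge solution on the scaled features, mapped back through Dn, is exactly beta_raw):

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

np.random.seed(0)
x = np.random.randn(100, 3)
y = np.random.randn(100, 2)
ridge = 1

# Population std about the mean -- identical to StandardScaler's scale_
Dn = np.diag(1 / x.std(axis=0, ddof=0))

# Ridge normal equations in the scaled feature space, mapped back through Dn
xx = Dn @ (x.T @ x) @ Dn
xy = Dn @ (x.T @ y)
beta_raw = Dn @ np.linalg.solve(xx + np.eye(len(xx)) * ridge, xy)

model = Pipeline([
    ("scaler", StandardScaler(with_mean=False)),
    ("regression", Ridge(alpha=ridge, fit_intercept=False)),
]).fit(x, y)

# coef_ has shape (n_targets, n_features); mapping its transpose back
# through Dn recovers the raw-space coefficients
coef = model.named_steps["regression"].coef_
print(np.allclose(Dn @ coef.T, beta_raw))           # True
print(np.allclose(model.predict(x), x @ beta_raw))  # True
```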