neural network with R package nnet: rubbish prediction due to overfitting?

Time:06-28

Trying to figure out if I have an R problem or a general neural net problem.

Say I have this data:

set.seed(123)
n = 1e3
x = rnorm(n)
y = 1 + 3*sin(x/2) + 15*cos(pi*x) + rnorm(n = length(x))
df = data.frame(y,x)
df$train = sample(c(TRUE, FALSE), length(y), replace=TRUE, prob=c(0.7,0.3))
df_train = subset(df, train = TRUE)
df_test = subset(df, train = FALSE)

Then I train the neural net, and it looks good on the holdout:

library(nnet)
nn = nnet(y~x, data = df_train, size = 60, linout=TRUE) 
yhat_nn = predict(nn, newdata = df_test)
plot(df_test$x,df_test$y)
points(df_test$x, yhat_nn, col = 'blue')


Ok, so then I thought, let's just generate new data and then predict using the trained net. But the predictions are way off:

x2 = rnorm(n)
y2 = 1 + 3*sin(x2/2) + 15*cos(pi*x2) + rnorm(n = length(x2))
df2 = data.frame(y2,x2)
plot(df2$x, df2$y)
points(df2$x, predict(nn, newdata = df2), col = 'blue')


Is this because I overfitted to the training set? I thought by splitting the original data into test-train I would avoid overfitting.

Answer:

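The fatal issue is that your new data frame, df2, does not have the correct variable names. As a result, predict.nnet cannot find the right values in newdata. Since df2 has no column named x, predict most likely falls back to the x still sitting in your workspace, so the blue points are predictions for the old inputs plotted against the new ones.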

names(df)
#[1] "y"     "x"     "train"

names(df2)
#[1] "y2"     "x2"

Be careful when you construct a data frame for predict.

## the right way
df2 <- data.frame(y = y2, x = x2)

## and it solves the mystery
plot(df2$x, df2$y)
points(df2$x, predict(nn, newdata = df2), col = 'blue')

(Figure: prediction on df2)
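To make this kind of mistake fail loudly instead of silently, one option is to compare the predictors named in the model formula against the columns of newdata before calling predict. The sketch below is only an illustration: check_newdata is a hypothetical helper (not part of nnet or base R), and lm is used as a stand-in so the example needs nothing beyond base R. It assumes the fitted object stores a terms component, which formula-based fits generally do.

```r
# Toy stand-in for the nnet fit; the name lookup works the same way
# for any formula-based model.
set.seed(1)
x <- rnorm(50)
y <- 2 * x + rnorm(50)
fit <- lm(y ~ x)

# Hypothetical helper: stop if newdata lacks a predictor used in the formula.
check_newdata <- function(model, newdata) {
  needed <- all.vars(delete.response(terms(model)))
  absent <- setdiff(needed, names(newdata))
  if (length(absent) > 0) {
    stop("newdata is missing: ", paste(absent, collapse = ", "))
  }
  invisible(TRUE)
}

bad  <- data.frame(x2 = rnorm(50))  # wrong column name
good <- data.frame(x = bad$x2)      # same values, correct name

# With the wrong name, predict() does not fail. It picks up the original x
# from the environment of the formula and returns the training-set fits.
silent_fallback <- isTRUE(all.equal(unname(predict(fit, newdata = bad)),
                                    unname(fitted(fit))))

# The helper turns the silent fallback into an explicit error.
bad_ok  <- tryCatch(check_newdata(fit, bad),  error = function(e) FALSE)
good_ok <- tryCatch(check_newdata(fit, good), error = function(e) FALSE)
```

Calling check_newdata(nn, df2) right before predict(nn, newdata = df2) would be the analogous use in the session above.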


Another, minor, issue is your use of subset. It should be:

## not train = TRUE or train = FALSE
df_train <- subset(df, train == TRUE) ## or simply subset(df, train)
df_test <- subset(df, train == FALSE) ## or simply subset(df, !train)

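This has an interesting effect. Written with a single =, train = TRUE is passed to subset as a named argument rather than as a row condition, so no rows are dropped: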

nrow(subset(df, train == TRUE))
#[1] 718

nrow(subset(df, train = TRUE))  ## oops!!
#[1] 1000
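The same behaviour can be reproduced on a tiny data frame. This is just a sketch with made-up column names (v and keep), comparing the single = form with the forms that actually filter rows:

```r
# Six rows, four of which are flagged TRUE.
d <- data.frame(v = 1:6,
                keep = c(TRUE, FALSE, TRUE, TRUE, FALSE, TRUE))

# Single "=": keep = TRUE is swallowed as a named argument, nothing is filtered.
n_named <- nrow(subset(d, keep = TRUE))

# Logical condition: only rows where keep is TRUE remain.
n_logical <- nrow(subset(d, keep == TRUE))

# Equivalent shorthand, and plain indexing for comparison.
n_short <- nrow(subset(d, keep))
n_index <- nrow(d[d$keep, , drop = FALSE])
```

Here n_named equals the full row count, while the other three agree on the number of flagged rows.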

The complete R session

set.seed(123)
n = 1e3
x = rnorm(n)
y = 1 + 3*sin(x/2) + 15*cos(pi*x) + rnorm(n = length(x))
df = data.frame(y,x)
df$train = sample(c(TRUE, FALSE), length(y), replace=TRUE, prob=c(0.7,0.3))
df_train = subset(df, train == TRUE)  ## fixed
df_test = subset(df, train == FALSE)  ## fixed
library(nnet)
nn = nnet(y~x, data = df_train, size = 60, linout=TRUE) 
yhat_nn = predict(nn, newdata = df_test)
plot(df_test$x,df_test$y)
points(df_test$x, yhat_nn, col = 'blue')
x2 = rnorm(n)
y2 = 1 + 3*sin(x2/2) + 15*cos(pi*x2) + rnorm(n = length(x2))
df2 = data.frame(y = y2, x = x2)  ## fixed
plot(df2$x, df2$y)
points(df2$x, predict(nn, newdata = df2), col = 'blue')
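If you would rather not judge overfitting by eye, a rough numeric check is to compare the error on the holdout with the error on the freshly generated data. The following is a sketch under the same setup as the session above; rmse is a hypothetical helper defined here, and trace = FALSE is only there to silence the fitting log. The simulated noise has standard deviation 1, so values near 1 on both sets would suggest the net is generalizing rather than overfitting.

```r
library(nnet)

set.seed(123)
n <- 1e3
x <- rnorm(n)
y <- 1 + 3*sin(x/2) + 15*cos(pi*x) + rnorm(n)
df <- data.frame(y, x)
df$train <- sample(c(TRUE, FALSE), n, replace = TRUE, prob = c(0.7, 0.3))
df_train <- subset(df, train)
df_test  <- subset(df, !train)

nn <- nnet(y ~ x, data = df_train, size = 60, linout = TRUE, trace = FALSE)

# Fresh draw from the same process, with column names matching the formula.
x2 <- rnorm(n)
df2 <- data.frame(y = 1 + 3*sin(x2/2) + 15*cos(pi*x2) + rnorm(n), x = x2)

# Hypothetical helper: root mean squared error.
rmse <- function(actual, predicted) sqrt(mean((actual - predicted)^2))

# predict() returns a one-column matrix, so flatten it before comparing.
rmse_test <- rmse(df_test$y, as.numeric(predict(nn, newdata = df_test)))
rmse_new  <- rmse(df2$y,     as.numeric(predict(nn, newdata = df2)))
```

A large gap between rmse_test and rmse_new would point to a real generalization problem; similar values point back to a data-handling bug like the one above.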