I am trying to get the 5-fold cross-validation error of a model created with TreeBagger using the function crossval, but I keep getting this error:
Error using crossval>evalFun
The function 'regrTree' generated the following error:
Too many input arguments.
My code is below. Can anyone point me in the right direction? Thanks
%Random Forest
%%XX is training data matrix, Y is training labels vector
XX=X_Tbl(:,2:end);
Forest_Mdl = TreeBagger(1000,XX,Y,'Method','regression');
err_std = crossval('mse',XX,Y,'Predfun',@regrTree, 'kFold',5);
function yfit_std = regrTree(Forest_Mdl,XX)
yfit_std = predict(Forest_Mdl,XX);
end
Answer:
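The documentation for crossval specifies the signature required of the function passed as 'Predfun'. It must take three arguments, not two, which is why your two-argument regrTree fails with "Too many input arguments":
function yfit = myfunction(Xtrain,ytrain,Xtest)
% Calculate predicted response
...
end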
Xtrain — Subset of the observations in X used as training predictor data. The function uses Xtrain and ytrain to construct a classification or regression model.
ytrain — Subset of the responses in y used as training response data. The rows of ytrain correspond to the same observations in the rows of Xtrain. The function uses Xtrain and ytrain to construct a classification or regression model.
Xtest — Subset of the observations in X used as test predictor data. The function uses Xtest and the model trained on Xtrain and ytrain to compute the predicted values yfit.
yfit — Set of predicted values for observations in Xtest. The yfit values form a column vector with the same number of rows as Xtest.
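Applied to the question, the TreeBagger model would be trained inside the prediction function on each fold's training split, instead of once beforehand. Here is a minimal sketch of that idea, not a drop-in replacement: the synthetic X and Y, the fixed seed, and the reduced tree count (50 rather than 1000, only to keep the run short) are my assumptions, and it assumes the predictors are a numeric matrix (if XX is a table, converting it with table2array first may be needed).

```matlab
% Hypothetical synthetic data standing in for XX and Y from the question.
rng(0);                                   % fixed seed, for repeatability only
X = rand(100, 3);                         % 100 observations, 3 predictors
Y = 2*X(:,1) + 0.1*randn(100, 1);         % noisy linear response

% Prediction function with the three-argument signature crossval expects:
% fit a regression forest on the fold's training split, then predict on
% the fold's test split. 50 trees is an arbitrary choice for speed.
regrTree = @(Xtrain, ytrain, Xtest) ...
    predict(TreeBagger(50, Xtrain, ytrain, 'Method', 'regression'), Xtest);

% 5-fold cross-validated mean squared error.
err = crossval('mse', X, Y, 'Predfun', regrTree, 'KFold', 5);
disp(err)
```

With this form, the Forest_Mdl fitted on the full data set before the call is not used by crossval at all, since each fold builds its own model. The same body can be written as a named function regrTree(Xtrain,ytrain,Xtest) at the end of the script and passed as @regrTree, as in the original code.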