How can I pass weights in a variable to rpart?

Time:09-28

For some reason, rpart() from the rpart package can't see a variable defined in the environment from which it is called. In the reprex below, wts is defined just before the call to rpart(), but the call fails with the error "object 'wts' not found".

If I omit the weights argument, there is no problem.

library(rpart)

data(mpg, package = "ggplot2")

scale <- function(x) {
    x/sum(x)
}

fit_rpart <- function(formula, data, iters = 10) {
    data <- as.data.frame(data)
    models <- list()

    for (i in 1:iters) {
        wts <- scale(runif(nrow(data)))
        print(head(wts))
        models[[i]] <- rpart(formula = formula,
                             data = data, weights = wts,
                             method = "class")
    }
    return(models)
}

results <- fit_rpart(cyl == 4 ~ drv + cty + fl + class, mpg)
#> [1] 0.0072092177 0.0059019498 0.0007893446 0.0038617957 0.0067420603
#> [6] 0.0076892493
#> Error in eval(extras, data, env): object 'wts' not found

Created on 2022-09-28 by the reprex package (v2.0.1)
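The lookup rule behind this error can be reproduced with base R alone. The message comes from eval(extras, data, env), the step in stats::model.frame() that evaluates extra arguments such as weights: first among the columns of data, then in environment(formula), and never in the frame of the function that made the call. The sketch below assumes rpart() builds its model frame through model.frame(), as that error message suggests; make_frame(), f and d are hypothetical names used only for illustration.

```r
# The formula is created at top level, so environment(f) is the global environment.
f <- y ~ x
d <- data.frame(x = 1:5, y = c(2.1, 3.9, 6.2, 8.1, 9.8))

# Hypothetical helper: `w` exists only inside this function's frame.
make_frame <- function(fml, dat) {
  w <- rep(1, nrow(dat))
  # model.frame() looks for `w` in `dat`, then in environment(fml);
  # it does not search make_frame()'s own frame.
  model.frame(fml, data = dat, weights = w)
}

# Capture the error message instead of stopping.
msg <- tryCatch(make_frame(f, d), error = function(e) conditionMessage(e))
print(msg)  # same kind of failure as in the question: object 'w' not found
```

Dropping weights = w from the call makes make_frame() succeed, which matches the observation that the question's code works without the weights argument.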

Answer:

The answer seems to be similar to the one provided here (similar but not identical, since this issue involves rpart rather than lm; I'm not sure how to generalize the question to cover both).

If I put these two lines at the top of fit_rpart, the code runs fine.

formula <- as.formula(formula)
environment(formula) <- environment()
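Folding those two lines into the question's function gives the sketch below. Because environment(formula) now points at fit_rpart()'s own environment, the weights lookup finds the local wts on every iteration. To keep it runnable without ggplot2, the sketch swaps in the kyphosis data set that ships with rpart (an assumption of this example, not part of the original answer), seeds the random number generator, and drops the print() call.

```r
library(rpart)

# Normalise a vector so that it sums to 1 (same helper as in the question).
scale <- function(x) {
  x / sum(x)
}

fit_rpart <- function(formula, data, iters = 10) {
  # Re-home the formula: model.frame() evaluates `weights` in `data`, then in
  # environment(formula), so pointing it at this function's environment makes
  # the local `wts` visible.
  formula <- as.formula(formula)
  environment(formula) <- environment()

  data <- as.data.frame(data)
  models <- vector("list", iters)
  for (i in seq_len(iters)) {
    wts <- scale(runif(nrow(data)))
    models[[i]] <- rpart(formula = formula, data = data,
                         weights = wts, method = "class")
  }
  models
}

set.seed(42)
# kyphosis ships with rpart, so no extra package is needed here.
results <- fit_rpart(Kyphosis ~ Age + Number + Start, kyphosis, iters = 3)
length(results)
```

The formula passed in by the caller is left untouched; only the local copy inside fit_rpart() gets the new environment.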

Answer:

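It does seem to be a scoping problem, as you said. You can also use <<- to assign wts in an enclosing environment instead of the local one. Here no enclosing function defines wts, so it is assigned in the global environment, which is the formula's environment and therefore where rpart looks for the weights.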

(Also, data and formula are already the names of existing functions in R, so for clarity you may prefer to rename the arguments, as done below.)

library(rpart)

data(mpg, package = "ggplot2")

scale <- function(x) {
  x/sum(x)
}

fit_rpart <- function(f, d, iters = 10) {
  
  d <- as.data.frame(d)
  models <- list()

  for (i in 1:iters) {
    wts <<- scale(runif(nrow(d)))
    print(head(wts))
    models[[i]] <- rpart(formula = f,
                         data = d,
                         weights = wts,
                         method = "class")
  }
  return(models)
}

results <- fit_rpart(cyl == 4 ~ drv + cty + fl + class, mpg)

# [1] 0.0007921094 0.0039999229 0.0026862458 0.0018832820 0.0006866826 0.0076998391
# [1] 0.005962060 0.006942240 0.001572535 0.009360314 0.005485438 0.008135806
# [1] 0.007224990 0.004209857 0.007706282 0.007071345 0.003784652 0.006335056
# [1] 0.004098536 0.003289626 0.006710783 0.007364727 0.003099702 0.007693150
# [1] 0.005154814 0.001212012 0.001169259 0.005829825 0.004401704 0.004269959
# [1] 0.001262500 0.003705485 0.007466314 0.005450551 0.001292365 0.007920012
# [1] 0.0007083308 0.0033698483 0.0073883706 0.0013445097 0.0068669108 0.0048488413
# [1] 0.005139870 0.002053938 0.002125759 0.006488419 0.007129400 0.006937384
# [1] 0.006957664 0.006296365 0.000640707 0.008121049 0.008014404 0.007706194
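One caveat with <<-: when no enclosing function already defines wts, the assignment lands in the global environment, so a wts object is left behind (or an existing one is overwritten) after fit_rpart() returns. A minimal base-R sketch of that behaviour, using a hypothetical function set_global() and a deliberately unusual variable name so it does not clash with anything in your session:

```r
# `<<-` searches enclosing environments for an existing binding; if none is
# found, it assigns in the global environment.
set_global <- function() {
  tmp_wts_demo <<- c(0.2, 0.3, 0.5)
  invisible(NULL)
}

before <- exists("tmp_wts_demo", envir = globalenv())
set_global()
after <- exists("tmp_wts_demo", envir = globalenv())
c(before = before, after = after)  # the variable outlives the function call
```

This side effect is why re-pointing environment(formula), as in the other answer, keeps wts local to the function.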