purrr::map returns multiple rows instead of one

Time:12-02

I am trying to run a linear regression for each group using functional programming, but I cannot get the output I want: purrr::map returns multiple rows per nested model instead of one row.

library(tidyverse)

#perform linear regression for each cylinder group
mtcars_result <- mtcars %>%
  nest(data = -cyl) %>%
  mutate(model = map(data, ~ lm(as.formula("mpg~disp"), data = .)),
         n = map(data, ~ nrow(.)))

#values of disp to predict at, one per cylinder group
mtcars_result$predict <- 1:3

#helper function to obtain predict values
get_prediction <- function(m,varname,predict){
  predictdata <- data.frame(predict)
  names(predictdata) <- c(varname)
  predict(m,newdata=predictdata,interval="confidence",level=0.95)
}

#prediction: note that it returns three rows per model
mtcars_result2 <- mtcars_result %>%
  mutate(predicted_values = map(model, get_prediction, "disp", predict))
mtcars_result2$predicted_values

[[1]]
       fit      lwr      upr
1 19.08559 11.63407 26.53712
2 19.08920 11.67680 26.50160
3 19.09280 11.71952 26.46609

[[2]]
       fit      lwr      upr
1 40.73681 32.68945 48.78418
2 40.60167 32.62715 48.57619
3 40.46653 32.56482 48.36824

[[3]]
       fit      lwr      upr
1 22.01316 14.74447 29.28186
2 21.99353 14.74479 29.24227
3 21.97390 14.74511 29.20268

My attempt:

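I think the main issue is the predict argument of get_prediction(). When I run this version of get_prediction(), which saves the argument to a global variable, I can see that the whole vector is passed to each call instead of a single value: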

get_prediction <- function(m,varname,predict){
  predict_global<<-predict
  predictdata <- data.frame(predict)
  names(predictdata) <- c(varname)
  predict(m,newdata=predictdata,interval="confidence",level=0.95)
}


> predict_global
[1] 1 2 3

So my instinct was to use rowwise(), but that fails with an error:

mtcars_result2 <- mtcars_result%>%rowwise()%>%mutate(predicted_values=map(model,get_prediction,"disp",predict))

Error in UseMethod("predict") : 
  no applicable method for 'predict' applied to an object of class "c('double', 'numeric')" 

Can anyone shed some light on this? Maybe we can use purrr::pmap instead of purrr::map?

CodePudding user response:

One option is to use imap and subset the predict column with the index .y.

mtcars_result %>%
  mutate(predicted_values = imap(model, ~ get_prediction(.x, "disp", predict[.y])))
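
To see why the original call returned three rows, it may help to note that map() forwards any extra arguments unchanged to every call, so each model received the full predict vector. Below is a minimal sketch with toy data; the names x, extra, whole and single are made up for illustration and are not part of the question.

```r
library(purrr)

x     <- list("a", "b", "c")  # stands in for the list of models
extra <- 1:3                  # stands in for the predict column

# map_*() passes `extra` whole to every call, so each call sees 3 values
whole <- map_int(x, function(el, e) length(e), extra)

# imap_*() also supplies the position as .y, so each call can pick one value
single <- imap_int(x, ~ length(extra[.y]))
```

Here whole is 3 for every element while single is 1, which mirrors the three-row versus one-row results above.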

Alternatively, we can use rowwise(), calling get_prediction() directly instead of through map() and wrapping the result in list():

mtcars_result %>% 
  rowwise() %>% 
  mutate(predicted_values = list(get_prediction(model, "disp", predict)))

#> # A tibble: 3 × 6
#> # Rowwise: 
#>     cyl data               model  n         predict predicted_values
#>   <dbl> <list>             <list> <list>      <int> <list>          
#> 1     6 <tibble [7 × 10]>  <lm>   <int [1]>       1 <dbl [1 × 3]>   
#> 2     4 <tibble [11 × 10]> <lm>   <int [1]>       2 <dbl [1 × 3]>   
#> 3     8 <tibble [14 × 10]> <lm>   <int [1]>       3 <dbl [1 × 3]>

Created on 2022-12-01 with reprex v2.0.2
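
As for the error in the question's rowwise() attempt: under rowwise(), a list column such as model refers to the single element for that row, so map(model, ...) ends up iterating over the components of one lm object, the first of which is the numeric coefficient vector. That would explain the "no applicable method for 'predict'" message. A small sketch under that reading (the names fits and inspected are hypothetical):

```r
library(dplyr)

# two fitted models stored in a list column
fits <- tibble(
  cyl   = c(4, 6),
  model = list(
    lm(mpg ~ disp, data = mtcars[mtcars$cyl == 4, ]),
    lm(mpg ~ disp, data = mtcars[mtcars$cyl == 6, ])
  )
)

inspected <- fits %>%
  rowwise() %>%
  mutate(
    model_class = class(model)[1],      # the bare model, not a list of models
    first_part  = class(model[[1]])[1]  # its first component: the coefficients
  ) %>%
  ungroup()
```

This is why the working rowwise() version calls get_prediction(model, ...) directly and only wraps the result in list().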

CodePudding user response:

This is very similar to the accepted answer, but we can also use purrr::map2 and switch the order of the arguments in get_prediction() so that the two mapped inputs come first:

get_prediction <- function(m,predict,varname){
  predictdata <- data.frame(predict)
  names(predictdata) <- c(varname)
  predict(m,newdata=predictdata,interval="confidence",level=0.95)
}

#prediction: each model now gets a single value, so one row is returned per model
mtcars_result2 <- mtcars_result %>%
  mutate(predicted_values = map2(model, predict, get_prediction, "disp"))
mtcars_result2$predicted_values

[[1]]
       fit      lwr      upr
1 19.08559 11.63407 26.53712

[[2]]
       fit      lwr      upr
1 40.60167 32.62715 48.57619

[[3]]
      fit      lwr      upr
1 21.9739 14.74511 29.20268
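
Since the question asked about purrr::pmap, here is a sketch of how that variant might look. It keeps the question's original argument order in get_prediction() and uses the ..1 and ..2 placeholders to refer to the model and its value. The name mtcars_result3 is made up for illustration, and nest(data = -cyl) assumes a recent version of tidyr.

```r
library(dplyr)
library(tidyr)
library(purrr)

# helper from the question: predict at the given value(s) of `varname`
get_prediction <- function(m, varname, predict) {
  predictdata <- data.frame(predict)
  names(predictdata) <- varname
  predict(m, newdata = predictdata, interval = "confidence", level = 0.95)
}

mtcars_result <- mtcars %>%
  nest(data = -cyl) %>%
  mutate(model = map(data, ~ lm(mpg ~ disp, data = .x)),
         predict = 1:3)

# pmap() walks the listed columns in parallel:
# ..1 is the model for a row, ..2 is that row's value of predict
mtcars_result3 <- mtcars_result %>%
  mutate(predicted_values = pmap(list(model, predict),
                                 ~ get_prediction(..1, "disp", ..2)))
```

With only two inputs this is equivalent to map2(); pmap() mainly pays off when more than two columns have to be matched up row by row.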

CodePudding user response:

Another possibility could be splitting by cyl and using map2:

library(tidyverse)

options(pillar.sigfig = 7)

mtcars %>%
  split(f = .$cyl) %>% 
  map2_dfr(c(2, 1, 3), 
           ~lm(mpg ~ disp, data = .) %>% 
             get_prediction("disp", .y) %>% 
             as_tibble(),
           .id = "cyl")

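This returns a single tibble with the predictions at disp = 2, 1 and 3. The values are supplied in that order because split() sorts the groups by cyl (4, 6, 8), whereas the nested data frame above is ordered 6, 4, 8.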

# A tibble: 3 × 4
  cyl        fit      lwr      upr
  <chr>    <dbl>    <dbl>    <dbl>
1 4     40.60167 32.62715 48.57619
2 6     19.08559 11.63407 26.53712
3 8     21.97390 14.74511 29.20268