How to vectorize a function with multiple possible outputs in an R dataframe

Time:10-21

I am trying to apply a function to a data.frame column in R that detects whether specific string patterns are present. Each group of patterns corresponds to its own category. The function should create a new column (dat$id_class) holding the classification of the string in the dat$id column.

I am relying on the stringr and dplyr packages to do this. Specifically, I'm using dplyr::mutate to apply the function.

This code runs and produces exactly the results I want, but I'm looking for a faster solution (if one exists). The example below is simplified; the same approach on my much larger dataset takes far longer than I'd like.

library(stringi)
library(dplyr)
library(stringr)

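# Build 1,000,000 IDs: 5 random letters + 4 digits + 1 letter.
# Only 5 digit/letter suffixes are generated; sprintf() recycles them.
dat <- data.frame(
  id = sprintf(
    "%s%s%s",
    stri_rand_strings(1000000, 5, '[A-Z]'),
    stri_rand_strings(5, 4, '[0-9]'),
    stri_rand_strings(5, 1, '[A-Z]')
  )
)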


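# Scalar classifier: expects a single string, so it has to be called once per row.
classify <- function(x){
  # str_detect() is vectorized over `pattern`; any() collapses the two results
  if(any(stringr::str_detect(x,pattern = c('AA','BB')))){
    'class_1'
  } else if (any(stringr::str_detect(x,pattern = c('AB','BA')))){
    'class_2'
  } else {
    'class_3'
  }
}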

dat <- dat %>% rowwise() %>% mutate(id_class = classify(id))

There's a great chance this has already been answered, and that I'm just not looking in the right place, but it's worth a shot.

Any assistance appreciated!

CodePudding user response:

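Use case_when, which is vectorized, instead of going rowwise with if/else, and combine each pair of patterns into a single regex with OR (|). The conditions are tested in order, so the first match wins, just as in the if/else chain.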

library(stringr)
library(dplyr)
system.time({
  dat1 <- dat %>%
    mutate(id_class = case_when(
      str_detect(id, 'AA|BB') ~ 'class_1',
      str_detect(id, 'AB|BA') ~ 'class_2',
      TRUE ~ 'class_3'
    ))
})
#  user  system elapsed 
#  0.460   0.036   0.493 
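
To make the ordering behaviour concrete, here is a minimal sketch on a handful of made-up IDs (the `ids` vector is hypothetical and not taken from the question's data). It is meant to show that an ID matching both pattern groups still lands in class_1, because case_when uses the first condition that is TRUE, the same way the if/else chain does.

```r
library(dplyr)
library(stringr)

# Hypothetical IDs chosen to hit each branch; the third one matches both groups.
ids <- c("AAXYZ1234Q", "XABYZ1234Q", "AABAX1234Q", "XYZQW1234Q")

id_class <- case_when(
  str_detect(ids, "AA|BB") ~ "class_1",  # checked first, like the `if`
  str_detect(ids, "AB|BA") ~ "class_2",  # used only when the first test is not TRUE
  TRUE                     ~ "class_3"   # fallback, like the final `else`
)

id_class
# [1] "class_1" "class_2" "class_1" "class_3"
```

The whole vector is tested in one call per pattern group, which is where the speed-up over the per-row calls comes from.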

For comparison, the timings with the OP's rowwise function:

system.time({
  dat2 <- dat %>%
    rowwise() %>%
    mutate(id_class = classify(id))
})
#  user  system elapsed 
# 31.927   0.303  32.891 

Checking that the two outputs match:

> all.equal(dat1, as.data.frame(dat2), check.attributes = FALSE)
[1] TRUE
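
If depending on dplyr and stringr is not desirable, the same idea can probably be expressed in base R. The following is only a sketch under that assumption: `classify_vec` is a hypothetical helper name, it has not been timed against the versions above, and nested ifelse() stands in for case_when.

```r
# Hypothetical base R variant: grepl() is vectorized over its input,
# and nested ifelse() keeps the same first-match-wins ordering.
classify_vec <- function(x) {
  ifelse(grepl("AA|BB", x), "class_1",
         ifelse(grepl("AB|BA", x), "class_2", "class_3"))
}

# Small made-up input, one string per branch
classify_vec(c("AAXYZ1234Q", "XABYZ1234Q", "XYZQW1234Q"))
# [1] "class_1" "class_2" "class_3"
```

It could then be applied to the whole column at once with dat$id_class <- classify_vec(dat$id), with no rowwise() step.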