Generate a case_when() function with an indeterminate number of conditions

Time:07-26

I have a data frame df with a column x and I would like to create another column y that is determined by some cut points which I would like to provide to a function as an input.

The function should take an arbitrary number of cut points, such as 0, 5, 10, 15 in the example below. It should also take a second input with the corresponding values for the y variable, in the example 3, 6, 9, 12.

library(tidyverse)

df <- tibble(x = 0:20)

mutate(df, y = case_when(x >= 0 & x < 5 ~ 3, 
                         x >= 5 & x < 10 ~ 6,
                         x >= 10 & x < 15 ~ 9, 
                         x >= 15 ~ 12))

my_cut_points <- c(0, 5, 10, 15)
my_corr_values <- c(3, 6, 9, 12)

my_func <- function(df, cut_points = my_cut_points, corr_values = my_corr_values) {
  ...
}

The solution needs to accept input vectors of various lengths. However, the pattern is always the same: the lower cut point should be included in the interval and the upper cut point should not. It does not need to use case_when; it could also use cut or anything else that gives the correct result.
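The interval rule described here (lower cut point included, upper one excluded, last interval open-ended) is the default convention of base R's findInterval(). The following is a minimal sketch of that idea, offered as an illustration rather than as either answer's method; the helper name lookup_by_cuts() is hypothetical, and it assumes the cut points are sorted and that values below the first cut point should become NA.

```r
# Hypothetical helper: lookup_by_cuts(), not part of any package
lookup_by_cuts <- function(x, cut_points, corr_values) {
  stopifnot(length(cut_points) == length(corr_values), !is.unsorted(cut_points))
  # findInterval() returns i such that cut_points[i] <= x < cut_points[i + 1],
  # length(cut_points) for x >= the last cut point, and 0 for x below the first
  idx <- findInterval(x, cut_points)
  # Assumption: values below the first cut point have no interval, so map 0 to NA
  idx[which(idx == 0L)] <- NA_integer_
  corr_values[idx]
}
```

With the tidyverse loaded as in the question, it could be called as mutate(df, y = lookup_by_cuts(x, my_cut_points, my_corr_values)).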

Answer:

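I would use cut for this. With right = FALSE each interval includes its lower cut point and excludes its upper one, and labels = FALSE returns the index of the interval, which can be used to look up the matching entry of corr_values.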

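my_func <- function(df, cut_points = my_cut_points, corr_values = my_corr_values) {
  # Append Inf so the last cut point starts an open-ended interval [last, Inf);
  # right = FALSE gives [a, b) intervals, labels = FALSE returns the interval index
  interv <- cut(df$x, c(cut_points, Inf), right = FALSE, labels = FALSE)
  mutate(df, y = corr_values[interv])
}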
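To make the intermediate step visible, here is the same cut() call applied to a plain vector instead of a data frame (a sketch in base R only, so mutate() is left out). It shows the interval index that cut() produces, and that a shorter set of cut points needs no change to the code; the three-cut-point input is made up for illustration.

```r
# The cut() call from the answer, applied to a plain vector (no data frame)
x <- 0:20
cut_points <- c(0, 5, 10, 15)
corr_values <- c(3, 6, 9, 12)

# Interval index: 1 for [0, 5), 2 for [5, 10), 3 for [10, 15), 4 for [15, Inf)
interv <- cut(x, c(cut_points, Inf), right = FALSE, labels = FALSE)
y <- corr_values[interv]

# Illustrative shorter input: three cut points with three values
y3 <- c(1, 2, 3)[cut(x, c(0, 2, 8, Inf), right = FALSE, labels = FALSE)]
```

Values of x below the first cut point get an NA index from cut(), and therefore an NA in y.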

Answer:

You may use:

my_func <- function(df, cut_points, corr_values) {
    # Append Inf as an extra upper bound, so the interval that starts at the
    # last cut point, [last cut point, Inf), behaves like an "else" branch
    cut_points <- c(cut_points, Inf)
    mutate(df, y = case_when(
        # Each check is guarded by the vector length; beyond the end,
        # cut_points[k] is NA and FALSE & NA is FALSE, so the check never matches
        length(cut_points) >= 2 &
            x >= cut_points[1] & x < cut_points[2] ~ corr_values[1],
        length(cut_points) >= 3 &
            x >= cut_points[2] & x < cut_points[3] ~ corr_values[2],
        length(cut_points) >= 4 &
            x >= cut_points[3] & x < cut_points[4] ~ corr_values[3],
        length(cut_points) >= 5 &
            x >= cut_points[4] & x < cut_points[5] ~ corr_values[4],
        length(cut_points) >= 6 &
            x >= cut_points[5] & x < cut_points[6] ~ corr_values[5],
        length(cut_points) >= 7 &
            x >= cut_points[6] & x < cut_points[7] ~ corr_values[6],
        length(cut_points) >= 8 &
            x >= cut_points[7] & x < cut_points[8] ~ corr_values[7],
        length(cut_points) >= 9 &
            x >= cut_points[8] & x < cut_points[9] ~ corr_values[8],
        length(cut_points) >= 10 &
            x >= cut_points[9] & x < cut_points[10] ~ corr_values[9]))
}

Explanation: The function above hard-codes nine range checks inside case_when(), so it supports at most nine cut points. It first appends Inf to the incoming cut_points vector as an extra upper bound, which makes the range check for the last cut point behave essentially like an else: every value at or above that cut point matches it. Before each condition it also checks the length of cut_points, so for an input of, e.g., three cut points, only the first three conditions can ever match.
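The fixed upper limit on the number of cut points comes from writing each check out by hand. A minimal sketch of the same one-check-per-interval logic as a loop, in base R, removes that limit; the helper name assign_by_cuts() is hypothetical, and it assumes the cut points are sorted and that values below the first cut point should stay NA.

```r
# Hypothetical helper: assign_by_cuts(), not part of any package
assign_by_cuts <- function(x, cut_points, corr_values) {
  stopifnot(length(cut_points) == length(corr_values))
  # Upper bound of each interval: the next cut point, or Inf for the last one
  upper <- c(cut_points[-1], Inf)
  # Assumption: values below the first cut point keep NA
  y <- rep(NA_real_, length(x))
  for (i in seq_along(cut_points)) {
    # Same [lower, upper) test as each case_when() condition
    in_interval <- which(x >= cut_points[i] & x < upper[i])
    y[in_interval] <- corr_values[i]
  }
  y
}
```

Because the number of iterations follows length(cut_points), inputs with more than nine cut points work as well; inside mutate() it would be called as y = assign_by_cuts(x, cut_points, corr_values).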
