Calculate cumulative probability (Kaplan-Meier survival function) in a `dplyr` pipeline

Time:11-17

I'm trying to create Kaplan-Meier life tables using a dplyr pipeline. I'm having trouble calculating the cumulative probability of survival without using a for loop. Here is some example data.

library(dplyr)  # re-exports tibble()

df <- tibble(
  months = c(1, 3, 9, 13, 17, 20),
  n_at_risk = c(10, 8, 7, 5, 3, 2),
  cond_prob_event = c(0.100, 0.125, 0.143, 0.200, 0.333, 0.500),
  cond_prob_surv = c(0.900, 0.875, 0.857, 0.800, 0.667, 0.50)
)

df
# A tibble: 6 × 4
  months n_at_risk cond_prob_event cond_prob_surv
   <dbl>     <dbl>           <dbl>          <dbl>
1      1        10           0.1            0.9  
2      3         8           0.125          0.875
3      9         7           0.143          0.857
4     13         5           0.2            0.8  
5     17         3           0.333          0.667
6     20         2           0.5            0.5  
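The `cond_prob_event` values above all look like `1 / n_at_risk` rounded to three decimals (for example 0.143 ≈ 1/7). That suggests exactly one event at each listed month. Under that assumption, the conditional columns can also be derived inside the pipeline. This is a minimal sketch: the `n_events` column is hypothetical and not part of the original data.

```r
library(dplyr)

# Hypothetical input: assumes exactly one event at each listed month.
# `n_events` is an illustrative column, not part of the question's data.
events <- tibble(
  months = c(1, 3, 9, 13, 17, 20),
  n_at_risk = c(10, 8, 7, 5, 3, 2),
  n_events = c(1, 1, 1, 1, 1, 1)
)

life_table <- events %>%
  mutate(
    cond_prob_event = n_events / n_at_risk,  # proportion of those at risk who have the event
    cond_prob_surv  = 1 - cond_prob_event    # conditional probability of surviving past this time
  )

life_table
```

Because these are exact fractions, they differ slightly from the rounded values in the question (6/7 = 0.857142... rather than 0.857).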

In this case, the cumulative probability of survival is calculated as the product of the previous (lagged) cumulative probability of survival and the current conditional probability of survival. I can get the answer I'm looking for using a for loop:

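# Pre-allocate one slot per row
out <- numeric(nrow(df))

for (i in seq_along(df$cond_prob_surv)) {
  if (i == 1) {
    # First row: cumulative survival equals the conditional survival
    out[i] <- df$cond_prob_surv[i]
  } else {
    # Later rows: previous cumulative survival times current conditional survival
    out[i] <- out[i - 1] * df$cond_prob_surv[i]
  }
}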

df$cum_prob_survival <- out
df
# A tibble: 6 × 5
  months n_at_risk cond_prob_event cond_prob_surv cum_prob_survival
   <dbl>     <dbl>           <dbl>          <dbl>             <dbl>
1      1        10           0.1            0.9               0.9  
2      3         8           0.125          0.875             0.788
3      9         7           0.143          0.857             0.675
4     13         5           0.2            0.8               0.540
5     17         3           0.333          0.667             0.360
6     20         2           0.5            0.5               0.180

However, for reasons, I'd really like to find a `dplyr`-only solution. Any help is greatly appreciated!

Answer:

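You can use `cumprod()` here. The cumulative survival is just the running product of the conditional survival probabilities: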

library(dplyr)
df <- df %>% 
    mutate(cum_prob_survival = cumprod(cond_prob_surv))

Output:

df
# A tibble: 6 × 5
  months n_at_risk cond_prob_event cond_prob_surv cum_prob_survival
   <dbl>     <dbl>           <dbl>          <dbl>             <dbl>
1      1        10           0.1            0.9               0.9  
2      3         8           0.125          0.875             0.788
3      9         7           0.143          0.857             0.675
4     13         5           0.2            0.8               0.540
5     17         3           0.333          0.667             0.360
6     20         2           0.5            0.5               0.180
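The question's `for` loop and `cumprod()` agree because the loop's recurrence unrolls into a plain product. Write $p_i$ for `cond_prob_surv`, $\hat S(t_i)$ for `cum_prob_survival`, $n_i$ for `n_at_risk`, and $d_i$ for the number of events at $t_i$. The symbol $d_i$ is the standard Kaplan-Meier notation and is not a column in the question's data:

```latex
\hat S(t_1) = p_1, \qquad \hat S(t_i) = \hat S(t_{i-1})\, p_i \quad (i > 1)
\;\Longrightarrow\;
\hat S(t_i) = \prod_{j=1}^{i} p_j, \qquad p_j = 1 - \frac{d_j}{n_j}
```

For example, for the second row this gives $0.9 \times 0.875 = 0.7875$, which the tibble prints as 0.788.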

Another option is `purrr::accumulate()` with multiplication as the combining function:

library(purrr)
df <- df %>% 
     mutate(cum_prob_survival = accumulate(cond_prob_surv, `*`))

Output:

df
# A tibble: 6 × 5
  months n_at_risk cond_prob_event cond_prob_surv cum_prob_survival
   <dbl>     <dbl>           <dbl>          <dbl>             <dbl>
1      1        10           0.1            0.9               0.9  
2      3         8           0.125          0.875             0.788
3      9         7           0.143          0.857             0.675
4     13         5           0.2            0.8               0.540
5     17         3           0.333          0.667             0.360
6     20         2           0.5            0.5               0.180
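If the data contains several strata (for example, treatment arms), `cumprod()` combines naturally with `group_by()`, so the running product restarts in each group. This is a minimal sketch that assumes a hypothetical `arm` column and made-up values that are not in the question's data. It also assumes that rows can be ordered by `months` within each group:

```r
library(dplyr)

# Hypothetical two-arm life table; `arm` and these values are illustrative only.
strata <- tibble(
  arm = c("A", "A", "A", "B", "B"),
  months = c(1, 3, 9, 2, 5),
  cond_prob_surv = c(0.9, 0.875, 0.857, 0.8, 0.75)
)

strata <- strata %>%
  arrange(arm, months) %>%   # the running product depends on row order
  group_by(arm) %>%
  mutate(cum_prob_survival = cumprod(cond_prob_surv)) %>%  # restarts within each arm
  ungroup()

strata
```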

Answer:

A base R option using `Reduce()` with `accumulate = TRUE`, which keeps every intermediate product:

transform(
  df,
  cum_prob_survival = Reduce(`*`, cond_prob_surv, accumulate = TRUE)
)

gives

  months n_at_risk cond_prob_event cond_prob_surv cum_prob_survival
1      1        10           0.100          0.900         0.9000000
2      3         8           0.125          0.875         0.7875000
3      9         7           0.143          0.857         0.6748875
4     13         5           0.200          0.800         0.5399100
5     17         3           0.333          0.667         0.3601200
6     20         2           0.500          0.500         0.1800600
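To confirm that the approaches above are interchangeable, you can compute all of them side by side and compare them. This is a sketch: the helper `loop_cumprod()` is a hypothetical wrapper around the question's loop, and the input vector is the question's `cond_prob_surv` column:

```r
library(dplyr)
library(purrr)

p <- c(0.900, 0.875, 0.857, 0.800, 0.667, 0.50)

# Hypothetical helper: the question's loop, wrapped as a function
loop_cumprod <- function(x) {
  out <- numeric(length(x))
  for (i in seq_along(x)) {
    out[i] <- if (i == 1) x[i] else out[i - 1] * x[i]
  }
  out
}

results <- tibble(cond_prob_surv = p) %>%
  mutate(
    via_loop       = loop_cumprod(cond_prob_surv),
    via_cumprod    = cumprod(cond_prob_surv),
    via_accumulate = accumulate(cond_prob_surv, `*`),
    via_reduce     = Reduce(`*`, cond_prob_surv, accumulate = TRUE)
  )

# All four approaches should give the same running product
stopifnot(
  isTRUE(all.equal(results$via_loop, results$via_cumprod)),
  isTRUE(all.equal(results$via_accumulate, results$via_cumprod)),
  isTRUE(all.equal(results$via_reduce, results$via_cumprod))
)

results
```

In practice, `cumprod()` is the simplest choice inside `mutate()`. `accumulate()` and `Reduce()` are more general because they accept any binary function.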