I want to write a generalized weighted_summarise() function that will automatically parse and transform user-invoked function calls of the form:
data %>% weighted_summarise(weights, a = sum(b), c = mean(d))
into an actual call that delegates to dplyr::summarise:
data %>% dplyr::summarise(a = sum(weights * b), c = mean(weights * d))
Here, a and c are new columns to be created inside the reduced data, and b, d and weights are existing columns in data.
Ideally, I want to call my function exactly as I would a "native" dplyr::summarise, but with an extra weights argument that gets sprinkled into each aggregation function.
weighted_summarise <- function(data, weights, ...) {
  data %>% dplyr::summarise(
    # how to manipulate the ... and inject the weights in each name-value pair?
  )
}
Question: How can I manipulate the ellipsis so that the weights will be injected into every name-value pair in the appropriate place? I want to somehow capture an AST, walk it, and manipulate it systematically.
CodePudding user response:
Here is one option: interpolate the weights into the expressions passed in ... by converting each expression to a string, inserting the weights after its first opening parenthesis, and then parsing and evaluating the combined string as a single summarise call.
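weighted_summarise <- function(data, weights, ...) {
  # Capture the name of the weights column as a string
  weights <- rlang::as_string(rlang::ensym(weights))
  # Deparse each expression and insert "weights*" after its first "("
  v1 <- purrr::map_chr(rlang::enexprs(...),
    ~ stringr::str_replace(rlang::as_label(.x), "\\(",
        function(x) stringr::str_c("(", weights, "*")))
  # Rebuild the summarise() call as text, then parse and evaluate it
  # (requires dplyr to be attached so that %>% and summarise are found)
  eval(rlang::parse_expr(stringr::str_c("data %>%
    summarise(", stringr::str_c(names(v1), v1, sep = "=",
      collapse = ", "), ")")))
}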
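To make the string step concrete, here is a small sketch of the same replacement applied outside the function. The object names (exprs_in, labels, injected) are made up for illustration; only rlang and stringr are assumed to be installed. It also shows a limitation of this approach: only the first opening parenthesis is rewritten.

```r
library(rlang)
library(stringr)

# Two sample expressions, captured unevaluated (illustrative input)
exprs_in <- rlang::exprs(a = sum(b), c = mean(d))

# Deparse each expression to text, as the function does with as_label()
labels <- vapply(exprs_in, rlang::as_label, character(1))

# Insert "weights * " after the first "(" of each expression
injected <- stringr::str_replace(labels, "\\(", "(weights * ")
names(injected) <- names(exprs_in)
print(injected)

# Only the first "(" is touched, so in a nested call the weight
# lands at the outermost function only
nested <- stringr::str_replace("sum(abs(b))", "\\(", "(weights * ")
print(nested)
```

Because of this, an expression such as sum(b) / sum(d) would receive the weights only in its first call.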
Testing:
> data %>%
weighted_summarise(weights, a = sum(b), c = mean(d))
# A tibble: 1 × 2
a c
<dbl> <dbl>
1 -2.95 1.13
# testing with the original summarise code outside the function
> data %>%
dplyr::summarise(a = sum(weights * b), c = mean(weights * d))
# A tibble: 1 × 2
a c
<dbl> <dbl>
1 -2.95 1.13
Data:
data <- structure(list(b = c(-0.545880758366027, 0.536585304107612, 0.419623148618683,
-0.583627199210279, 0.847460017311944, 0.266021979364892, 0.444585270360416,
-0.466495123565759, -0.848370043948898, 0.00231194241576697),
d = c(-1.31690812429962, 0.598269112694685, -0.7622143703459,
-1.42909030324076, 0.332244449013422, -0.469060687608488,
-0.334986793584065, 1.53625215550584, 0.609994533253692,
0.51633569843567), weights = 1:10), class = c("tbl_df", "tbl",
"data.frame"), row.names = c(NA, -10L))
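Since the question asks about capturing and manipulating an AST, a possible alternative is to edit the captured calls directly instead of going through strings. The following is only a minimal sketch under stated assumptions: the names inject_weights, weighted_summarise_ast and toy are hypothetical, each expression is assumed to be a single call whose first argument is the column to weight, and only that top-level call is modified (nested calls are not walked).

```r
library(dplyr)
library(rlang)

# Hypothetical helper: wrap the first argument of a call in `w * <arg>`.
# Non-calls and calls without arguments (e.g. n()) are returned unchanged.
inject_weights <- function(expr, w) {
  if (rlang::is_call(expr) && length(expr) >= 2) {
    expr[[2]] <- rlang::call2("*", w, expr[[2]])
  }
  expr
}

# Hypothetical AST-based variant of the function from the question
weighted_summarise_ast <- function(data, weights, ...) {
  w <- rlang::ensym(weights)        # weights column as a symbol
  dots <- rlang::enexprs(...)       # named list of unevaluated expressions
  new_dots <- lapply(dots, inject_weights, w = w)
  dplyr::summarise(data, !!!new_dots)  # splice the rewritten expressions
}

# Small illustrative data set (not the data from the post)
toy <- data.frame(b = c(1, 2, 3), d = c(4, 5, 6), weights = c(1, 2, 3))
res <- weighted_summarise_ast(toy, weights, a = sum(b), c = mean(d))
print(res)
```

Working on the call objects avoids re-parsing text and keeps extra arguments such as na.rm = TRUE intact, but a recursive walk would still be needed to weight more than one call per expression.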