How can this R data.table join, group, summarise operation be made a lot faster?


Summary of real-world problem

Essentially this is a scenario evaluation of a linear system of equations. I have two data tables.

  • s_dt contains the scenario data: the drivers (d) and their values (v) for each observed scenario (o).
  • c_dt contains a series of terms (n) for a number of fitted model bases (b).
    The individual powers of the drivers and the associated coefficients are coded into (d, t) as name-value pairs.
    Each basis (b) is essentially a polynomial with n terms (see the sketch below).
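
For concreteness, here is a minimal sketch of what evaluating one basis for one scenario means, with made-up toy values (not my real data): each term is a coefficient (the row where d == "c") multiplied by the product of the driver values raised to their powers, and the basis result r is the sum of its terms.

v <- c(d1 = 0.1, d2 = 0.2)                # driver values for one scenario (o)
terms <- list(                            # two terms (n) of one basis (b)
  list(coef = 3, t = c(d1 = 2, d2 = 1)),  # 3 * d1^2 * d2^1
  list(coef = 1, t = c(d1 = 0, d2 = 3))   # 1 * d2^3
)
r <- sum(sapply(terms, function(tm) tm$coef * prod(v^tm$t)))
r
#> [1] 0.014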

The issue

The repro case below gives the desired output format, but it is far too slow for the required use case, even on a cut-down problem. The numbers are junk (I can't share the actual data), but running on real-world data gives similar timings.

It takes circa 3 seconds for the "lil" problem on my system (12 threads), but the "big" problem is roughly 4000 times larger, so I expect circa 3 hours. Ouch!
The aim is to have the "big" problem run in under 5 minutes (or ideally much faster!).
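
The back-of-the-envelope arithmetic behind that estimate (3 s scaled by the ~4000x size factor):

# ~3 s for the "lil" problem, scaled by ~4000x, expressed in hours
3 * 4000 / 3600
#> [1] 3.333333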

So, awesome clever people, how can this be made a lot faster?
(And what is the root cause of the slowdown?)

I'll happily accept base/tidyverse solutions too, if they meet the performance needs. I just assumed data.table was the best way to go for the size of the problem.

Current solution

Run fun on s_dt, grouping by o.
fun: joins c_dt onto each group's data to populate v, which enables calculation of r (the result of evaluating each of the polynomial equations).

In data.table parlance:

s_dt[, fun(.SD), keyby = .(o)]

Repro case

  • Creates two data.tables that have the combinations and field types matching the real-world problem,
    but with a cut-down size for illustrative purposes.
  • Defines fun, then runs it to populate r for all scenarios.
library(data.table)

# problem sizing ----
dims <- list(o = 50000, d = 50, b = 250, n = 200) # "big" problem - real-life size
dims <- list(o =   100, d = 50, b =  25, n = 200) # "lil" problem (make runtime shorter as example)

# build some test data tables ----
build_s <- function() {
  o <- seq_len(dims$o)
  d <- paste0("d",seq_len(dims$d))
  v <- as.double(seq_len(dims$o * dims$d))/10000
  CJ(o, d)[, `:=`(v = v)]
}
s_dt <- build_s()

build_c <- function() {
  b <- paste0("c", seq_len(dims$b))
  n <- seq_len(dims$n)
  d <- c("c", paste0("d", seq_len(dims$d)))
  t <- as.double(rep_len(0:6, dims$b * dims$n * (dims$d + 1)))
  dt <- CJ(d, b, n)[, `:=`(t = t)]
  dt <- dt[t != 0]
}
c_dt <- build_c()

# define fun and evaluate ---- 
# (this is what needs optimising)
profvis::profvis({
  fun <- function(dt) {
    # don't use chaining here, for more useful profvis output
    dt <- dt[c_dt, on = .(d)]
    dt <- dt[, r := fcase(d == "c", t,
                          is.na(v), 0,
                          rep(TRUE, .N), v^t)]
    dt <- dt[, .(r = prod(r)), keyby = .(b, n)]
    dt <- dt[, .(r = sum(r)),  keyby = .(b)]
  }
  res <- s_dt[, fun(.SD), keyby = .(o)]
})

Example inputs and outputs

> res
        o   b            r
   1:   1  c1 0.000000e+00
   2:   1 c10 0.000000e+00
   3:   1 c11 0.000000e+00
   4:   1 c12 0.000000e+00
   5:   1 c13 0.000000e+00
  ---                     
2496: 100  c5 6.836792e-43
2497: 100  c6 6.629646e-43
2498: 100  c7 6.840915e-43
2499: 100  c8 6.624668e-43
2500: 100  c9 6.842608e-43

> s_dt
        o   d      v
   1:   1  d1 0.0001
   2:   1 d10 0.0002
   3:   1 d11 0.0003
   4:   1 d12 0.0004
   5:   1 d13 0.0005
  ---               
4996: 100 d50 0.4996
4997: 100  d6 0.4997
4998: 100  d7 0.4998
4999: 100  d8 0.4999
5000: 100  d9 0.5000

> c_dt
         d  b   n t
     1:  c c1   2 1
     2:  c c1   3 2
     3:  c c1   4 3
     4:  c c1   5 4
     5:  c c1   6 5
    ---            
218567: d9 c9 195 5
218568: d9 c9 196 6
218569: d9 c9 198 1
218570: d9 c9 199 2
218571: d9 c9 200 3

CodePudding user response:

This would be difficult to fully vectorize. The "big" problem requires so many operations that going parallel is probably the most straightforward way to get to ~5 minutes.

But first, we can get a ~3x speed boost by using RcppArmadillo for the product and sum calculations instead of data.table's grouping operations.

library(data.table)
library(parallel)

Rcpp::cppFunction(
  "std::vector<double> sumprod(arma::cube& a) {
  for(unsigned int i = 1; i < a.n_slices; i++) a.slice(0) %= a.slice(i);
  return(as<std::vector<double>>(wrap(sum(a.slice(0), 0))));
}",
  depends = "RcppArmadillo",
  plugins = "cpp11"
)
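
As a quick sanity check on a toy cube (made-up values, just to illustrate the semantics): sumprod() folds all slices together with an element-wise product and then returns the column sums of that product matrix.

a_demo <- array(as.double(1:8), c(2, 2, 2))  # 2 x 2 x 2 toy cube
# element-wise products across slices: 1*5, 2*6, 3*7, 4*8; then column sums
sumprod(a_demo)
#> [1] 17 53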

cl <- makeForkCluster(detectCores() - 1L)

The following approach requires extensive preprocessing. The upshot is that it makes it trivial to parallelize. However, it will work only if the values of s_dt$d are the same for each o as in the MRE:

identical(s_dt$d, rep(s_dt[o == 1]$d, length.out = nrow(s_dt)))
#> [1] TRUE
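
If the real data does not already satisfy this, one possible fix (an untested suggestion on my part, not something the benchmarks below rely on) is to sort s_dt by reference before calling fun2, so that every o lists its drivers in the same order:

# assumes every o contains the same set of drivers
setkey(s_dt, o, d)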

Now let's build the functions to accept s_dt and c_dt:

# slightly modified original function for comparison
fun1 <- function(dt, c_dt) {
  # don't use chaining here, for more useful profvis output
  dt <- dt[c_dt, on = .(d)]
  dt <- dt[, r := fcase(d == "c", t,
                        is.na(v), 0,
                        rep(TRUE, .N), v^t)]
  dt <- dt[, .(r = prod(r)), keyby = .(b, n)]
  dt <- dt[, .(r = sum(r)),  keyby = .(b)]
}

fun2 <- function(s_dt, c_dt, cl = NULL) {
  s_dt <- copy(s_dt)
  c_dt <- copy(c_dt)
  # preprocess to get "a", "tt", "i", and "idxs"
  i_dt <- s_dt[o == 1][, idxs := .I][c_dt, on = .(d)][, ic := .I][!is.na(v)]
  ub <- unique(c_dt$b)
  un <- unique(c_dt$n)
  nb <- length(ub)
  nn <- length(un)
  c_dt[, `:=`(i = match(n, un) + nn*(match(b, ub) - 1L), r = 0)]
  c_dt[, `:=`(i = i + (0:(.N - 1L))*nn*nb, ni = .N), i]
  c_dt[d == "c", r := t]
  a <- array(1, c(nn, nb, max(c_dt$ni)))
  a[c_dt$i] <- c_dt$r # 3-d array to store v^t (updated for each unique "o")
  i <- c_dt$i[i_dt$ic] # the indices of "a" to update (same for each unique "o")
  tt <- c_dt$t[i_dt$ic] # c_dt$t ordered for "a" (same for each unique "o")
  idxs <- i_dt$idxs # the indices to order s_dt$v (same for each unique "o")
  uo <- unique(s_dt$o)
  v <- collapse::gsplit(s_dt$v, s_dt$o)
  
  if (is.null(cl)) {
    # non-parallel solution
    data.table(
      o = rep(uo, each = length(ub)),
      b = rep(ub, length(v)),
      r = unlist(
        lapply(
          v,
          function(x) {
            a[i] <- x[idxs]^tt
            sumprod(a)
          }
        )
      ),
      key = "o"
    )
  } else {
    # parallel solution
    clusterExport(cl, c("a", "tt", "i", "idxs"), environment())
    
    data.table(
      o = rep(uo, each = length(ub)),
      b = rep(ub, length(v)),
      r = unlist(
        parLapply(
          cl,
          v,
          function(x) {
            a[i] <- x[idxs]^tt
            sumprod(a)
          }
        )
      ),
      key = "o"
    )
  }
}

Now the data:

# problem sizing ----
bigdims <- list(o = 50000, d = 50, b = 250, n = 200) # "big" problem - real-life size
lildims <- list(o =   100, d = 50, b =  25, n = 200) # "lil" problem (make runtime shorter as example)

# build some test data tables ----
build_s <- function(dims) {
  o <- seq_len(dims$o)
  d <- paste0("d",seq_len(dims$d))
  v <- as.double(seq_len(dims$o * dims$d))/10000
  CJ(o, d)[, `:=`(v = v)]
}

build_c <- function(dims) {
  b <- paste0("c", seq_len(dims$b))
  n <- seq_len(dims$n)
  d <- c("c", paste0("d", seq_len(dims$d)))
  t <- as.double(rep_len(0:6, dims$b * dims$n * (dims$d + 1)))
  dt <- CJ(d, b, n)[, `:=`(t = t)]
  dt <- dt[t != 0]
}

Timing the lil problem, which is so small that parallelization doesn't help:

s_dt <- build_s(lildims)
c_dt <- build_c(lildims)

microbenchmark::microbenchmark(fun1 = s_dt[, fun1(.SD, c_dt), o],
                               fun2 = fun2(s_dt, c_dt),
                               times = 10,
                               check = "equal")
#> Unit: seconds
#>  expr      min       lq     mean   median       uq      max neval
#>  fun1 3.204402 3.237741 3.383257 3.315450 3.404692 3.888289    10
#>  fun2 1.134680 1.138761 1.179907 1.179872 1.210293 1.259249    10

Now the big problem:

s_dt <- build_s(bigdims)
c_dt <- build_c(bigdims)

system.time(dt2p <- fun2(s_dt, c_dt, cl))
#>    user  system elapsed 
#>  24.937   9.386 330.600

stopCluster(cl)

A bit longer than 5 minutes with 31 cores.
