.SD in data.table join to refer to arbitrary list of columns in i


Problem: calculate a weighted mean across the columns of one table, using weights held in a second table, matched on a join key.

Here are the steps as a reprex:

library(data.table)
#DT1 table of values - here just 2 columns, but may be an arbitrary number
DT1 <- data.table(k1 = c('A1','A2','A3'), 
                  k2 = c('X','X','Y'), 
                  v1 = c(10,11,12), 
                  v2 = c(.5, .6, 1.7))
#DT2 table of weights - columns correspond to value columns in table 1
DT2 <- data.table(k2 = c('X','Y'), 
                  w1 = c(5,2), 
                  w2 = c(1,7))
#Vectors of corresponding column names (could be any number of columns)
vals <- c('v1','v2')
weights <- c('w1','w2')
i.weights <- paste0('i.', weights)

#1. This returns all columns
DT1[DT2, on=.(k2)]
#>    k1 k2 v1  v2 w1 w2
#> 1: A1  X 10 0.5  5  1
#> 2: A2  X 11 0.6  5  1
#> 3: A3  Y 12 1.7  2  7
#2. This use of .SD is standard
DT1[DT2, on=.(k2), .SD, .SDcols = vals, by=.(k1)]
#>    k1 v1  v2
#> 1: A1 10 0.5
#> 2: A2 11 0.6
#> 3: A3 12 1.7
#3. But referring to the columns of i (DT2) fails, both without and with the i. prefix
DT1[DT2, on=.(k2), .SD, .SDcols = weights, by=.(k1)]
#> Error in `[.data.table`(DT1, DT2, on = .(k2), .SD, .SDcols = weights, : Some items of .SDcols are not column names: [w1, w2]
DT1[DT2, on=.(k2), .SD, .SDcols = i.weights, by=.(k1)]
#> Error in `[.data.table`(DT1, DT2, on = .(k2), .SD, .SDcols = i.weights, : Some items of .SDcols are not column names: [i.w1, i.w2]
#4. So, following the suggestion in https://stackoverflow.com/questions/43257664/sd-and-sdcols-for-the-i-expression-in-data-table-join,
# turn to mget() - as a single command it fails
DT1[DT2, on=.(k2), c(mget(vals), mget(weights)), by=.(k1,k2)]
#> Error: value for 'w1' not found
#5. But by exploiting 1. above and splitting into chained queries we get success!
DT1[DT2, on=.(k2)][, c(mget(vals), mget(weights)), by=.(k1,k2)]
#>    k1 k2 v1  v2 w1 w2
#> 1: A1  X 10 0.5  5  1
#> 2: A2  X 11 0.6  5  1
#> 3: A3  Y 12 1.7  2  7
#6. Now we can turn to the original intention, but no luck
DT1[DT2, on=.(k2)][, .(wmean = weighted.mean(mget(vals), mget(weights))), by=.(k1,k2)]
#> Error in x * w: non-numeric argument to binary operator
#7. One more step - turn the lists returned by mget() into data.tables - hurrah!
DT1[DT2, on=.(k2)][, .(wmean = weighted.mean(setDT(mget(vals)), setDT(mget(weights)))), by=.(k1,k2)]
#>    k1 k2    wmean
#> 1: A1  X 8.416667
#> 2: A2  X 9.266667
#> 3: A3  Y 3.988889

Created on 2021-11-26 by the reprex package (v2.0.0)

Should it really be this hard to do? Is there a more straightforward (and preferably more performant) way of doing this?

Corollary - I actually want to create a new column in DT1 with this calculation, but since the working version ends up as two chained queries I can't do the := assignment in the same command. Instead I have to create a new table and join it back to the original to add the column. Is there a solution to the above that avoids this extra step?
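
The extra step currently looks roughly like this - an illustrative sketch reusing #7 above, with wmean as a placeholder column name:

res <- DT1[DT2, on=.(k2)][, .(wmean = weighted.mean(setDT(mget(vals)), setDT(mget(weights)))), by=.(k1,k2)]
DT1[res, on=.(k1,k2), wmean := i.wmean]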

CodePudding user response:

Another approach would be to melt the data from wide to long and then join the two molten tables to each other.

## melt to long form; stripping the letter prefix turns v1/w1 into 1 and v2/w2 into 2, giving a shared join key
molten_dt1 = melt(DT1, measure.vars = vals)[, variable := as.integer(substring(variable, 2))]
molten_dt2 = melt(DT2, measure.vars = weights)[, variable := as.integer(substring(variable, 2))]

molten_dt1[molten_dt2, 
           on = .(k2, variable)
           ][,
             weighted.mean(value, i.value),
             by = .(k1, k2)]
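
For the example data this should return the same numbers as step 7 in the question (the column is named V1 because j is unnamed):

#>    k1 k2       V1
#> 1: A1  X 8.416667
#> 2: A2  X 9.266667
#> 3: A3  Y 3.988889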

The reason it is not straightforward is that any time we need parallel column lookups (i.e., v1 * w1 and v2 * w2), complexity increases because we have to account for the relationship between columns. Melting the data simplifies the approach: the long structure turns that relationship into a join key, and weighted.mean() then receives plain vectors instead of data.frames.

One other note: you may be able to simplify the original approach by defining a weighted.mean() method for lists, which lets you skip the setDT() step.

## slight changes made to stats:::weighted.mean.default
weighted.mean.list = function (x, w, ..., na.rm = FALSE) 
{
  x = unlist(x)  # collapse the list of value columns into one numeric vector
  if (missing(w)) {
    if (na.rm) 
      x <- x[!is.na(x)]
    return(sum(x)/length(x))
  }
  w = unlist(w)  # collapse the list of weight columns to align with x
  if (length(w) != length(x)) 
    stop("'x' and 'w' must have the same length")
  if (na.rm) {
    i <- !is.na(x)
    w <- w[i]
    x <- x[i]
  }
  sum((x * w)[w != 0])/sum(w)
}
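
As a quick sanity check, calling the generic directly on plain lists (values taken from the first row of the example data) should dispatch to the new list method:

weighted.mean(list(v1 = 10, v2 = 0.5), list(w1 = 5, w2 = 1))
#> [1] 8.416667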

DT1[DT2, on=.(k2)][, .(wmean = weighted.mean(mget(vals), mget(weights))), by=.(k1,k2)]
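
Assuming the method is picked up by dispatch inside j, this call should then give the same result as step 7 in the question:

#>    k1 k2    wmean
#> 1: A1  X 8.416667
#> 2: A2  X 9.266667
#> 3: A3  Y 3.988889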