I'm working on a problem where a matrix has to be iteratively computed a very large number of times. The necessary matrix multiplication takes the form
t(X) %*% ( ( X %*% W %*% t(X) * mu0 ) * mu1 )
where X is N x P and W is a symmetric P x P matrix. mu0 and mu1 are N x 1 vectors that are cheap to compute and enter the respective products element-wise. Unfortunately, N may be quite large, which leads to an immense computational demand, because X %*% W %*% t(X) is N x N. I was wondering whether there are any strategies or computational tricks, for instance based on matrix decompositions, that could be used to speed up the computation. In every iteration, mu0 and mu1 change, but X and W are fixed, so any precomputation involving only those matrices will pay off (a sketch of the loop structure follows the benchmark below).
BENCHMARKING
The fastest approach that I could think of so far is to do some obvious precomputations:
# fake data
N = 2500
P = 10
X = matrix(rnorm(N*P), N, P)
W = matrix(rnorm(P*P), P, P)  # not symmetrized here; nothing below relies on symmetry
mu0 = rnorm(N)
mu1 = rnorm(N)
# precomputations (X and W are fixed, so these are done once)
tX = t(X)
XWX = X %*% W %*% t(X)  # N x N -- the expensive object
# functions
f_raw = function(X, W, mu0, mu1){t(X) %*% ( ( X %*% W %*% t(X) * mu0 ) * mu1 )}
f_precomp = function(XWX, tX, mu0, mu1){tX %*% ( ( XWX * mu0 ) * mu1 )}
# benchmark
microbenchmark::microbenchmark(f_raw(X, W, mu0, mu1),
f_precomp(XWX, tX, mu0, mu1))
Unit: milliseconds
expr min lq mean median uq max neval
f_raw(X, W, mu0, mu1) 283.5918 286.5080 299.4621 289.5151 302.9726 355.4271 100
f_precomp(XWX, tX, mu0, mu1) 167.4169 168.7336 180.6468 171.0852 197.7475 263.8090 100
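For context, the surrounding loop has roughly the following shape (a minimal sketch; the rnorm() calls stand in for the real, cheap per-iteration updates of mu0 and mu1):
n_iter = 10
for (iter in seq_len(n_iter)) {
  mu0 = rnorm(N)  # placeholder for the actual update
  mu1 = rnorm(N)
  M = f_precomp(XWX, tX, mu0, mu1)  # P x N result needed in every iteration
}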
ANSWER
Reusing your example data and functions from above, a better idea is to pre-compute
WX <- tcrossprod(W, X)  # P x N; equals W %*% t(X), computed once since W and X are fixed
Multiplying an N x N matrix element-wise by an N x 1 vector in R scales its rows (the vector recycles down each column), so the product can be regrouped as (t(X) %*% ((mu0 * mu1) * X)) %*% (W %*% t(X)). This never forms the N x N matrix: the crossprod term is only P x P. Then
f_better <- function (X, mu0, mu1, WX) crossprod(X, (mu0 * mu1) * X) %*% WX
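To see the row-scaling rule in action, here is a minimal sketch on a tiny matrix (A and v are throwaway names for illustration):
A <- matrix(1:12, 4, 3)
v <- c(2, 3, 5, 7)
all.equal(A * v, diag(v) %*% A)  # recycling multiplies row i of A by v[i]
# [1] TRUE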
Benchmark on my laptop:
microbenchmark::microbenchmark(f_raw(X, W, mu0, mu1),
f_precomp(XWX, tX, mu0, mu1),
f_better(X, mu0, mu1, WX))
#Unit: milliseconds
#                         expr        min         lq      mean     median         uq        max neval
#        f_raw(X, W, mu0, mu1) 236.926979 238.984573 243.86319 242.212716 244.721163 270.477737   100
# f_precomp(XWX, tX, mu0, mu1) 151.190689 152.612059 156.25277 155.646093 157.970378 182.935082   100
#    f_better(X, mu0, mu1, WX)   1.113031   1.126434   1.17974   1.138207   1.146352   5.130876   100
To verify correctness:
ans_raw <- f_raw(X, W, mu0, mu1)
ans_precomp <- f_precomp(XWX, tX, mu0, mu1)
ans_better <- f_better(X, mu0, mu1, WX)
all.equal(ans_raw, ans_precomp)
# [1] TRUE
all.equal(ans_raw, ans_better)
# [1] TRUE
I don't know whether mu0 and mu1 are guaranteed to be non-negative in your real application. If so (more precisely, if the element-wise product mu0 * mu1 is non-negative, so that sqrt() is defined), the following is about 2x faster, since crossprod() with a single matrix argument can exploit the symmetry of its result:
f_fastest <- function (X, mu0, mu1, WX) crossprod(sqrt(mu0 * mu1) * X) %*% WX
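As a sanity check of this variant (a sketch reusing the data from above; abs() just manufactures non-negative weights, since the random mu0 and mu1 can be negative):
mu0_nn <- abs(mu0)
mu1_nn <- abs(mu1)
all.equal(f_better(X, mu0_nn, mu1_nn, WX),
          f_fastest(X, mu0_nn, mu1_nn, WX))
# expected: TRUE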