I'd like to model a multivariate distribution with structured parameters, for example a multivariate normal whose covariance matrix is made up of a low-rank part and a diagonal part. What is the recommended way to achieve this? (TensorFlow 2.8)
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

DIM = 4
mean = tf.Variable(np.zeros(DIM), dtype=tf.float32, name='mean')
low_rank = tf.Variable(np.zeros((DIM, 2)), dtype=tf.float32, name='cov')
diagonal = tf.Variable(np.zeros(DIM), dtype=tf.float32, name='noise')
# Covariance = low_rank @ low_rank^T + diag(softplus(diagonal))
target_distribution = tfd.MultivariateNormalTriL(
    loc=mean,
    scale_tril=tf.linalg.cholesky(
        tf.linalg.matmul(low_rank, low_rank, transpose_b=True)
        + tf.linalg.diag(tf.math.softplus(diagonal))
    )
)
print(target_distribution.trainable_variables)
returns only
(<tf.Variable 'mean:0' shape=(4,) dtype=float32, numpy=array([0., 0., 0., 0.], dtype=float32)>,)
i.e. only the variables passed in directly are tracked, not those that enter through an expression.
What is the syntax so that low_rank and diagonal become trainable variables that I can fit to data?
I realize that there is tfd.MultivariateNormalDiagPlusLowRank, which solves this specific example, but I'm still interested in the recommended way to model structured parameters.
Answer:
When you run any TF op on a tf.Variable (in eager mode), the Variable's value is read into a Tensor and the new value is computed, losing any association with the Variable. So in your example, the matmul and cholesky both run before the Distribution is constructed, and the Distribution only ever sees the resulting Tensor, never those Variables.
In TFP, we created a few utilities for working around this kind of issue, in particular tfp.util.DeferredTensor, tfp.util.TransformedVariable, and tfp.experimental.util.DeferredModule. Each of these aims to allow lazy evaluation/construction of something. TransformedVariable is nice because it also handles updating the underlying variable in pre-transform space. It's limited in the sense that it can only have a single underlying Variable, and your example suggests you'll want several floating around. Check out the examples in DeferredModule; it might get you what you want. You may also want to parameterize some composition of tf.linalg.LinearOperators (https://www.tensorflow.org/api_docs/python/tf/linalg/LinearOperator) with some variables, or something like that.
Here's your example above rewritten with DeferredModule: https://colab.research.google.com/drive/1DRX_Jv58abfsWE6h1BIz6YiQRAzCmn8r