I am trying to use TensorFlow Probability to test the Kalman filter. I followed the official documentation (https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/LinearGaussianStateSpaceModel) to define the model, generate a sample, and finally compute the log-likelihood of that sample.
I am running the code provided in the documentation:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
ndims = 2
step_std = 1.0
noise_std = 5.0
model = tfd.LinearGaussianStateSpaceModel(
    num_timesteps=1000,
    transition_matrix=tf.linalg.LinearOperatorIdentity(ndims),
    transition_noise=tfd.MultivariateNormalDiag(
        scale_diag=step_std**2 * tf.ones([ndims])),
    observation_matrix=tf.linalg.LinearOperatorIdentity(ndims),
    observation_noise=tfd.MultivariateNormalDiag(
        scale_diag=noise_std**2 * tf.ones([ndims])),
    initial_state_prior=tfd.MultivariateNormalDiag(
        scale_diag=tf.ones([ndims])))
x = model.sample(1) # Sample from the prior on sequences of observations.
lp = model.log_prob(x) # Marginal likelihood of a (batch of) observations.
print(lp)
It takes about 30 seconds to compute the log-likelihood. PS: I ran the code on Colab, and a GPU was used.
My questions: why is it so slow, and how can I improve the performance?
Thanks.
CodePudding user response:
Eager mode (the default in TF) is pretty slow in general. You can graph-ify this with tf.function:
lp = tf.function(model.log_prob, autograph=False, jit_compile=False)(x)
You can also set jit_compile to True to lower the computation to XLA. That adds some compile time (possibly nontrivial), but it usually makes the code faster, and the cost amortizes if you run the function many times.