I am trying to import tensorflow_probability.substrates.jax (specifically to use the distributions) and getting the error shown below (it looks like a self-import). I have installed tensorflow (2.8.2), tensorflow-probability (0.14.0) and jax (0.3.25).
When I run
import tensorflow_probability.substrates.jax as tfp
I get the following error:
ImportError: cannot import name 'bijectors' from partially initialized module
'tensorflow_probability.substrates.jax' (most likely due to a circular import)
(/path-to-anaconda3-env/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/__init__.py)
I have tried a few different versions of tensorflow-probability with the same results.
Answer:
This sounds like a version incompatibility. tensorflow_probability v0.14 was released in Sept 2021 (history), when JAX's most recent release was version 0.2.20 (history). JAX has had 36 releases since then, so it's not surprising that some incompatibilities have arisen.
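Before changing anything, it can help to confirm which versions are actually installed in the environment you're running. A minimal sketch, assuming Python 3.8+ for `importlib.metadata`: it reads package metadata instead of importing the packages, so it still works when `tensorflow_probability.substrates.jax` itself fails to import. The names used are PyPI distribution names (`tensorflow-probability`, not `tensorflow_probability`), and `installed_versions` is a hypothetical helper, not part of any library.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# PyPI distribution names (not import names) of the packages in the question.
PACKAGES = ("jax", "tensorflow", "tensorflow-probability")


def installed_versions(packages: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Return {distribution name: installed version, or None if missing}.

    Only package metadata is read, so nothing is imported and a broken
    tensorflow_probability install cannot trigger the circular-import error here.
    """
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Running this in the same interpreter (or notebook kernel) that raises the error avoids confusion between different conda environments.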
I tried in Google Colab and found that the following combination works:
import tensorflow
import tensorflow_probability
import jax
print(f"{jax.__version__=}")
print(f"{tensorflow.__version__=}")
print(f"{tensorflow_probability.__version__=}")
import tensorflow_probability.substrates.jax as tfp
print("loaded!")
Output:
jax.__version__='0.3.25'
tensorflow.__version__='2.9.2'
tensorflow_probability.__version__='0.17.0'
loaded!
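To compare an environment against the combination that loaded successfully in Colab, a small check like the following could be used. It is only a sketch: the pinned versions are simply the ones printed above, not an official compatibility matrix, and other combinations may work as well. `KNOWN_GOOD` and `mismatches` are hypothetical names.

```python
from importlib import metadata
from typing import Dict, Optional, Tuple

# The combination reported to import successfully in Google Colab.
# Not an official compatibility matrix: other combinations may also work.
KNOWN_GOOD = {
    "jax": "0.3.25",
    "tensorflow": "2.9.2",
    "tensorflow-probability": "0.17.0",
}


def mismatches(
    expected: Dict[str, str] = KNOWN_GOOD,
    installed: Optional[Dict[str, Optional[str]]] = None,
) -> Dict[str, Tuple[Optional[str], str]]:
    """Return {package: (installed, expected)} for every version that differs.

    `installed` can be passed in explicitly (e.g. for testing); by default it
    is read from package metadata without importing anything.
    """
    if installed is None:
        installed = {}
        for name in expected:
            try:
                installed[name] = metadata.version(name)
            except metadata.PackageNotFoundError:
                installed[name] = None
    # Exact string comparison: any difference, even a patch release, is reported.
    return {
        name: (installed.get(name), want)
        for name, want in expected.items()
        if installed.get(name) != want
    }


if __name__ == "__main__":
    diff = mismatches()
    if not diff:
        print("Matches the known-good combination.")
    for name, (have, want) in diff.items():
        print(f"{name}: installed {have or 'nothing'}, known-good {want}")
```

For the environment in the question (tensorflow 2.8.2, tensorflow-probability 0.14.0, jax 0.3.25), this would flag tensorflow and tensorflow-probability while leaving jax alone.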
Similar errors can also appear if you're working in a notebook and install new versions of packages that you've already imported in that session: the previously imported versions stay loaded until the Python process restarts. If you're working in a notebook, be sure to restart the runtime after you install or update a package.
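In a notebook it's easy to miss that an upgrade hasn't taken effect yet. One way to spot this, sketched below under the assumption that each package exposes `__version__` in the same format as its installed metadata (as the three packages printed above appear to), is to compare the version of an already-imported module with the version now recorded on disk. A difference means the old code is still loaded and the runtime needs a restart. `stale_imports` and the `IMPORT_TO_DIST` mapping are hypothetical names.

```python
import sys
from importlib import metadata
from typing import Callable, Dict, List, Tuple

# Import name -> PyPI distribution name for the packages discussed above.
IMPORT_TO_DIST = {
    "jax": "jax",
    "tensorflow": "tensorflow",
    "tensorflow_probability": "tensorflow-probability",
}


def stale_imports(
    packages: Dict[str, str] = IMPORT_TO_DIST,
    installed_version: Callable[[str], str] = metadata.version,
) -> List[Tuple[str, str, str]]:
    """Return (module, loaded version, installed version) for every module that
    is already imported but whose installed version has changed since."""
    stale = []
    for module_name, dist_name in packages.items():
        module = sys.modules.get(module_name)
        if module is None:
            continue  # never imported in this session, so nothing is stale
        loaded = getattr(module, "__version__", None)
        try:
            on_disk = installed_version(dist_name)
        except metadata.PackageNotFoundError:
            continue  # uninstalled since import; skip rather than guess
        if loaded is not None and loaded != on_disk:
            stale.append((module_name, loaded, on_disk))
    return stale


if __name__ == "__main__":
    for name, loaded, on_disk in stale_imports():
        print(f"{name}: loaded {loaded}, installed {on_disk} -> restart the runtime")
```

Because some builds format `__version__` differently from their metadata, a hit is best treated as a hint to restart rather than proof of a problem.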