I am trying to run JAX on an NVIDIA DGX box, but it aborts on the very first array operation:
>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.arange(10)
2021-10-25 13:00:05.863667: W external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:80] Couldn't get ptxas version string: INTERNAL: Couldn't invoke ptxas --version
2021-10-25 13:00:05.864713: F external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:435] ptxas returned an error during compilation of ptx to sass: 'INTERNAL: Failed to launch ptxas' If the error message indicates that a file could not be written, please verify that sufficient filesystem space is provided.
Aborted (core dumped)
Any suggestions would be much appreciated.
CodePudding user response:
This means that your CUDA installation is not configured correctly: XLA could not launch ptxas. This can generally be fixed by ensuring that the CUDA toolkit binaries (including ptxas) are present in your $PATH. See https://github.com/google/jax/discussions/6843 and https://github.com/google/jax/issues/7239 for responses to users reporting similar issues.