jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details. #12

Open
jmSNU opened this issue Sep 25, 2024 · 2 comments

jmSNU commented Sep 25, 2024

Hello,

I would like to report a serious issue I have encountered.

To provide context, my local environment is configured as follows:

GPU: RTX 4090
CUDA version: 12.5
cuDNN version: 8.9
OS: Ubuntu 22.04

When I attempted to run the provided example in a Conda virtual environment, I encountered the following error:

Traceback (most recent call last):
File "recall2imagine/train.py", line 241, in <module>
main()
File "recall2imagine/train.py", line 63, in main
agent = agt.Agent(env.obs_space, env.act_space, step, config)
File "/home/jm/r2i_ws/Recall2Imagine/recall2imagine/jaxagent.py", line 20, in __init__
super().__init__(agent_cls, *args, **kwargs)
File "/home/jm/r2i_ws/Recall2Imagine/recall2imagine/jaxagent.py", line 50, in __init__
self.varibs = self._init_varibs(obs_space, act_space)
File "/home/jm/r2i_ws/Recall2Imagine/recall2imagine/jaxagent.py", line 245, in _init_varibs
state, varibs = self._init_train(varibs, rng, data['is_first'])
File "/home/jm/r2i_ws/Recall2Imagine/recall2imagine/ninjax.py", line 199, in wrapper
created = init(statics, rng, *args, **kw)
File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/pjit.py", line 250, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **params)
File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/core.py", line 2677, in bind
return self.bind_with_trace(top_trace, args, params)
File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/core.py", line 383, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/core.py", line 815, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/pjit.py", line 1203, in _pjit_call_impl
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/pjit.py", line 1187, in call_impl_cache_miss
out_flat, compiled = _pjit_call_impl_python(
File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/pjit.py", line 1120, in _pjit_call_impl_python
compiled = _pjit_lower(
File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2323, in compile
executable = UnloadedMeshExecutable.from_hlo(
File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2645, in from_hlo
xla_executable, compile_options = _cached_compilation(
File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2555, in _cached_compilation
xla_executable = dispatch.compile_or_get_cached(
File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
return backend_compile(backend, computation, compile_options,
File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

Additionally, I observed that even a simple JAX operation (e.g., jnp.ones((3,))) results in the same error within the Conda environment. Based on my observations, I believe this issue may be related to compatibility between JAX, CUDA, and cuDNN.
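
A check along the following lines can reproduce the failure outside the Recall2Imagine code. It is only a sketch: the helper name `check_jax` and the layout of the returned dictionary are made up for illustration, and it assumes nothing beyond Python's standard library plus whatever `jax` is installed in the active environment.

```python
import importlib.util


def check_jax():
    """Run a tiny JAX computation and report whether the backend initializes.

    Hypothetical helper for illustration; not part of Recall2Imagine or JAX.
    """
    report = {"installed": False, "ok": False, "devices": [], "error": None}

    # Avoid an ImportError if jax is missing from this environment.
    if importlib.util.find_spec("jax") is None:
        report["error"] = "jax is not installed in this environment"
        return report
    report["installed"] = True

    try:
        import jax
        import jax.numpy as jnp

        # Listing devices initializes the backend (GPU if available).
        report["devices"] = [str(d) for d in jax.devices()]
        # Same operation as mentioned above; wait until it has actually run.
        x = jnp.ones((3,))
        x.block_until_ready()
        report["ok"] = True
    except Exception as exc:  # e.g. XlaRuntimeError raised during initialization
        report["error"] = f"{type(exc).__name__}: {exc}"
    return report


if __name__ == "__main__":
    print(check_jax())
```

If `ok` is false while `installed` is true, the problem lies in the JAX/CUDA/cuDNN setup rather than in the training script.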

I hope this matter can be resolved soon. Thank you for your attention to this issue.

@artemZholus artemZholus self-assigned this Sep 25, 2024
@artemZholus
Collaborator

Hi @jmSNU!

I think your issue is caused by a broken JAX installation. I recommend following the steps at https://github.com/chandar-lab/Recall2Imagine/tree/main?tab=readme-ov-file#conda to create a fresh Conda environment. If that does not work, try building a Docker image by following the steps at https://github.com/chandar-lab/Recall2Imagine/tree/main?tab=readme-ov-file#conda . If that also does not work, I can provide the Docker image we used for our experiments.
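
One way to see whether the installation is inconsistent is to list the versions of the relevant packages in the active environment. The sketch below uses only the standard library. The package names in `PACKAGES` are an assumption about what a pip-based JAX GPU install might contain, and a cuDNN installed system-wide (outside pip) will not show up here.

```python
from importlib import metadata

# Assumed package names; adjust them to match the environment being checked.
PACKAGES = ("jax", "jaxlib", "nvidia-cudnn-cu11", "nvidia-cudnn-cu12")


def installed_versions(packages=PACKAGES):
    """Map each package name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Comparing this output between a failing environment and a freshly created one may show which package differs.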

@jmSNU
Author

jmSNU commented Sep 27, 2024

Thank you for your reply.

However, I tried both approaches and both failed. With Conda, JAX was still the problem; with Docker, the pip installation failed.
