jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
#12
I would like to report a serious issue I have encountered.
To provide context, my local environment is configured as follows:
GPU: RTX 4090
CUDA version: 12.5
cuDNN version: 8.9
OS: Ubuntu 22.04
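The list above covers the system-level versions but not the Python packages inside the Conda environment. A minimal sketch of how those could be collected is below. The distribution names in `CANDIDATES` are assumptions about what might be installed (they are not taken from this report), and `installed_versions` is a hypothetical helper, not part of Recall2Imagine or JAX.

```python
from importlib import metadata

# Assumed distribution names; adjust to whatever the environment actually uses.
CANDIDATES = ["jax", "jaxlib", "nvidia-cudnn-cu12", "nvidia-cudnn-cu11"]


def installed_versions(names):
    """Return {name: version string, or None if the package is not installed}."""
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, version in installed_versions(CANDIDATES).items():
        print(f"{name}: {version or 'not installed'}")
```

Running this inside the activated environment shows which jax and jaxlib builds are present, which can then be compared with the CUDA and cuDNN versions listed above.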
When I attempted to run the example provided in a virtual Conda environment, I encountered the following error.
Traceback (most recent call last):
File "recall2imagine/train.py", line 241, in <module>
main()
File "recall2imagine/train.py", line 63, in main
agent = agt.Agent(env.obs_space, env.act_space, step, config)
File "/home/jm/r2i_ws/Recall2Imagine/recall2imagine/jaxagent.py", line 20, in __init__
super().__init__(agent_cls, *args, **kwargs)
File "/home/jm/r2i_ws/Recall2Imagine/recall2imagine/jaxagent.py", line 50, in __init__
self.varibs = self._init_varibs(obs_space, act_space)
File "/home/jm/r2i_ws/Recall2Imagine/recall2imagine/jaxagent.py", line 245, in _init_varibs
state, varibs = self._init_train(varibs, rng, data['is_first'])
File "/home/jm/r2i_ws/Recall2Imagine/recall2imagine/ninjax.py", line 199, in wrapper
created = init(statics, rng, *args, **kw)
File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/pjit.py", line 250, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **params)
File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/core.py", line 2677, in bind
return self.bind_with_trace(top_trace, args, params)
File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/core.py", line 383, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/core.py", line 815, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/pjit.py", line 1203, in _pjit_call_impl
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/pjit.py", line 1187, in call_impl_cache_miss
out_flat, compiled = _pjit_call_impl_python(
File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/pjit.py", line 1120, in _pjit_call_impl_python
compiled = _pjit_lower(
File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2323, in compile
executable = UnloadedMeshExecutable.from_hlo(
File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2645, in from_hlo
xla_executable, compile_options = _cached_compilation(
File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2555, in _cached_compilation
xla_executable = dispatch.compile_or_get_cached(
File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
return backend_compile(backend, computation, compile_options,
File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/jm/anaconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
Additionally, I observed that even a simple JAX operation (e.g., jnp.ones((3,))) results in the same error within the Conda environment. Based on my observations, I believe this issue may be related to compatibility between JAX, CUDA, and cuDNN.
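The standalone check described above can be written as a short script, sketched here under stated assumptions. `smoke_test` is a hypothetical helper name. Setting `XLA_PYTHON_CLIENT_PREALLOCATE=false` is included only as a diagnostic step, to see whether GPU memory preallocation plays a role; it is not confirmed as a fix for this error.

```python
import os

# Assumption: turning off preallocation is a diagnostic step, not a confirmed fix.
# It must be set before JAX is imported, so the import happens inside the function.
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")


def smoke_test():
    """Run the smallest possible JAX op and describe the outcome as a string."""
    try:
        import jax
        import jax.numpy as jnp
    except ImportError as exc:
        return f"import failed: {exc}"
    try:
        x = jnp.ones((3,))
        x.block_until_ready()  # wait for the device computation to finish
        return f"ok: devices={jax.devices()}, sum={float(x.sum())}"
    except Exception as exc:  # e.g. XlaRuntimeError from backend initialization
        return f"failed: {type(exc).__name__}: {exc}"


if __name__ == "__main__":
    print(smoke_test())
```

If this script fails with the same `FAILED_PRECONDITION` message, the problem is independent of the Recall2Imagine code and lies in the JAX, CUDA, and cuDNN installation.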
I hope this matter can be resolved soon. Thank you for your attention to this issue.