[BUG] Using XLA crashes out #303

Open
3 tasks done
dapatil211 opened this issue May 24, 2024 · 2 comments
@dapatil211

Describe the bug

Using the XLA interface crashes the program.

To Reproduce

Run the example code in envpool/examples/xla_step.py as

JAX_TRACEBACK_FILTERING=off python xla_step.py

This results in the following error:

Jax plugin configuration error: Exception when calling jax_plugins.xla_cuda12.initialize()
Traceback (most recent call last):
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 623, in discover_pjrt_plugins
    plugin_module.initialize()
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax_plugins/xla_cuda12/__init__.py", line 83, in initialize
    xla_client.register_custom_call_handler(
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jaxlib/xla_client.py", line 633, in register_custom_call_handler
    handler(name, fn, xla_platform_name, api_version)
TypeError: register_custom_call_target(): incompatible function arguments. The following argument types are supported:
    1. register_custom_call_target(c_api: capsule, fn_name: str, fn: capsule, xla_platform_name: str, api_version: int = 0) -> None

Invoked with types: PyCapsule, bytes, PyCapsule, str, int
Traceback (most recent call last):
  File "/home/mila/d/darshan.patil/research/NSRL/test.py", line 106, in <module>
    gym_sync_step()
  File "/home/mila/d/darshan.patil/research/NSRL/test.py", line 56, in gym_sync_step
    run_actor_loop(100, (handle, states))
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/pjit.py", line 305, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/pjit.py", line 182, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/core.py", line 2789, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/core.py", line 391, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/core.py", line 879, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/pjit.py", line 1552, in _pjit_call_impl
    return xc._xla.pjit(
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/pjit.py", line 1534, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/pjit.py", line 1464, in _pjit_call_impl_python
    inline=inline, lowering_parameters=mlir.LoweringParameters()).compile()
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2378, in compile
    executable = UnloadedMeshExecutable.from_hlo(
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2914, in from_hlo
    xla_executable = _cached_compilation(
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2726, in _cached_compilation
    xla_executable = compiler.compile_or_get_cached(
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/compiler.py", line 333, in compile_or_get_cached
    return _compile_and_write_cache(
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/compiler.py", line 504, in _compile_and_write_cache
    executable = backend_compile(
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/profiler.py", line 335, in wrapper
    return func(*args, **kwargs)
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/compiler.py", line 238, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: No registered implementation for custom call to AtariGymEnvPool_140505772519360_send_gpu for platform CUDA
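
The two errors look connected. Below is a simplified, hypothetical model (not jaxlib's actual code) of the call pattern visible in the traceback. A custom-call target that is registered before the CUDA plugin's handler exists gets queued. It is then replayed through the handler inside register_custom_call_handler. If the queued name is bytes and the handler accepts only str, plugin initialization fails. The target then never reaches the CUDA registry, and a later compile can only report "No registered implementation". ToyRegistry and make_strict_handler are illustrative names. That envpool registers its targets with bytes names is inferred from the "Invoked with types: PyCapsule, bytes, ..." line, not confirmed.

```python
# A simplified, hypothetical model of deferred custom-call registration.
# It is NOT jaxlib's implementation; it only mimics the call pattern in the traceback.
from typing import Callable, Dict, List, Set, Tuple

Pending = Tuple[object, object, str, int]


class ToyRegistry:
    def __init__(self) -> None:
        self.pending: Dict[str, List[Pending]] = {}         # queued registrations per platform
        self.handlers: Dict[str, Callable[..., None]] = {}  # plugin handlers per platform
        self.registered: Dict[str, Set[str]] = {}           # names the "backend" knows about

    def register_custom_call_target(self, name, fn, platform: str, api_version: int = 0) -> None:
        handler = self.handlers.get(platform)
        if handler is None:
            # No plugin handler yet: queue the registration for later.
            self.pending.setdefault(platform, []).append((name, fn, platform, api_version))
        else:
            handler(name, fn, platform, api_version)

    def register_custom_call_handler(self, platform: str, handler: Callable[..., None]) -> None:
        self.handlers[platform] = handler
        # Replay every queued registration through the newly installed handler.
        for name, fn, xla_platform_name, api_version in self.pending.pop(platform, []):
            handler(name, fn, xla_platform_name, api_version)


def make_strict_handler(registry: ToyRegistry) -> Callable[..., None]:
    # Stands in for a binding whose fn_name parameter only accepts str.
    def handler(fn_name, fn, xla_platform_name: str, api_version: int = 0) -> None:
        if not isinstance(fn_name, str):
            raise TypeError(f"incompatible function arguments: got {type(fn_name).__name__} for fn_name")
        registry.registered.setdefault(xla_platform_name, set()).add(fn_name)
    return handler


if __name__ == "__main__":
    reg = ToyRegistry()
    # Registered with a bytes name before the CUDA handler exists.
    reg.register_custom_call_target(b"AtariGymEnvPool_send_gpu", object(), "CUDA")
    try:
        reg.register_custom_call_handler("CUDA", make_strict_handler(reg))
    except TypeError as err:
        print("plugin initialization failed:", err)
    # Nothing reached the CUDA registry, so a later compile could not find the target.
    print(reg.registered.get("CUDA", set()))  # prints: set()
```

In this toy, the same registration succeeds when the name is passed as a str. That matches the commenter's observation further down that converting the name avoids the TypeError.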

Expected behavior

The example code should run without issues.

System info

This was the extent of my setup:

conda create -n test python=3.10
pip install -U "jax[cuda12]"
pip install envpool

Then, in Python:

import envpool, numpy, sys
print(envpool.__version__, numpy.__version__, sys.version, sys.platform)

The Python snippet above prints:

0.8.4 1.26.4 3.10.14 (main, May  6 2024, 19:42:50) [GCC 11.2.0] linux
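
The report above lists the envpool and NumPy versions. The first error, however, is raised inside jaxlib while the CUDA plugin initializes, so the jax, jaxlib and plugin versions are likely relevant too. Here is a minimal sketch that collects them using only the standard library's importlib.metadata. The distribution names in PACKAGES (in particular jax-cuda12-plugin and jax-cuda12-pjrt) are assumptions about what pip install "jax[cuda12]" pulls in, and they may need adjusting for your environment.

```python
"""Print installed versions of the packages involved in this crash (sketch)."""
import platform
import sys
from importlib.metadata import PackageNotFoundError, version

# Assumed distribution names; adjust to what is actually installed.
PACKAGES = ["envpool", "numpy", "jax", "jaxlib", "jax-cuda12-plugin", "jax-cuda12-pjrt"]


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if it is absent."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    print(sys.version, sys.platform, platform.machine())
    for name, ver in installed_versions().items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```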

Reason and Possible fixes

N/A

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
@ZhaoRunyi

Same issue here. Could the authors please suggest a solution?

@Michael190502

Michael190502 commented Dec 12, 2024

Hi, I had the same issue, but I was able to work around it by trial and error. I don't know what happens in the function I edited, and I also don't know whether the workaround causes problems elsewhere; I just tried things until the error went away. So copy this solution at your own risk: I have no idea whether it could cause issues with your GPU or anywhere else!

So, the workaround. Initially I got the following error: "TypeError: register_custom_call_target(): incompatible function arguments." I tracked that down to the "register_custom_call_target" function in the "xla_client.py" file. The error went away if I cast name to a string before passing it to "_custom_callback" or "_custom_callback_handler". After that, though, I got the same error as you ("No registered implementation for custom call ..."). That one went away if I checked whether "b'" is in name (i.e. whether the string came from a bytes object) and, if so, trimmed away the bytes-literal prefix and quotes, i.e. name = name[2:-1].

In short, before passing name to "_custom_callback" or "_custom_callback_handler", I cast it to a string, checked for "b'", and, if it was present, trimmed the name to name[2:-1]. Trimming without the check did not work, because some names apparently do not come from bytes objects.
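
The name handling described above can be sketched as a small helper. This is not a patch to jaxlib. normalize_custom_call_name and trim_like_comment are hypothetical names, and the comment does not pin down the exact call site inside xla_client.py. The sketch also plausibly explains why a plain str() cast led to the second error: str(b"abc") is "b'abc'", so the registered name would no longer match the name the compiled program calls. Decoding the bytes gives the same result as the [2:-1] trim for plain ASCII names.

```python
def normalize_custom_call_name(name):
    """Return a custom-call target name as str (hypothetical helper).

    Decoding is used instead of str(): str(b"abc") == "b'abc'", which keeps
    the bytes prefix and quotes and therefore yields a different name.
    """
    if isinstance(name, bytes):
        return name.decode("utf-8")
    return name


def trim_like_comment(name):
    """Roughly the comment's approach: cast to str, then strip b'...' if present."""
    s = str(name)
    if "b'" in s:
        s = s[2:-1]
    return s


if __name__ == "__main__":
    for raw in (b"AtariGymEnvPool_send_gpu", "already_a_str"):
        print(repr(str(raw)), repr(trim_like_comment(raw)), repr(normalize_custom_call_name(raw)))
```

For ASCII names both helpers agree. The substring check in the second one would also mangle a genuine str name that happens to contain b', and decoding avoids that.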

(Again: I have no idea what I am doing here, and I do not know whether my adjustments are safe and/or sensible.)
