Activation offloading with scan fails in XLA with same shape pinned-host/device scan outputs #20373

Open
jaro-sevcik opened this issue Dec 10, 2024 · 1 comment · May be fixed by #20374

@jaro-sevcik
Contributor

If a while loop has one output in pinned host memory and another on device, and both have the same shape, a CSE pass combines their buffers, which breaks the host offloader (on GPU).

The CSE pass was newly introduced as a side effect of turning --xla_gpu_experimental_enable_triton_softmax_priority_fusion on by default (see eab45d5).
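
To make the failure mode concrete, here is a toy model in plain Python. It is not XLA's CSE pass; the `Instr` class, the `toy_cse` function and the instruction names are all hypothetical. The sketch assumes that the two scan outputs look identical to CSE when memory space is not part of the comparison, and it uses 5 for the pinned-host memory space only because the error message below shows `S(5)`.

```python
from dataclasses import dataclass, replace

# Assumed memory space ids for this toy model only.
DEVICE, PINNED_HOST = 0, 5


@dataclass(frozen=True)
class Instr:
    """Hypothetical, heavily simplified stand-in for an HLO instruction."""
    name: str
    opcode: str
    shape: str
    memory_space: int = DEVICE
    operands: tuple = ()


def toy_cse(instrs, key):
    """Keep the first instruction per key and redirect later users to it."""
    canonical = {}  # key -> name of the surviving instruction
    renamed = {}    # eliminated name -> surviving name
    kept = []
    for ins in instrs:
        # Rewrite operands that point at already eliminated instructions.
        ins = replace(ins, operands=tuple(renamed.get(o, o) for o in ins.operands))
        k = key(ins)
        if k in canonical:
            renamed[ins.name] = canonical[k]
        else:
            canonical[k] = ins.name
            kept.append(ins)
    return kept


def shape_only(ins):
    # Comparison that ignores where the buffer lives.
    return (ins.opcode, ins.shape, ins.operands)


def shape_and_space(ins):
    # Comparison that also distinguishes memory spaces.
    return (ins.opcode, ins.shape, ins.memory_space, ins.operands)


program = [
    Instr("init", "constant", "f32[]"),
    Instr("buf_host", "broadcast", "f32[4]", PINNED_HOST, ("init",)),
    Instr("buf_dev", "broadcast", "f32[4]", DEVICE, ("init",)),
    Instr("out", "tuple", "(f32[4], f32[4])", DEVICE, ("buf_host", "buf_dev")),
]

merged = toy_cse(program, shape_only)
separate = toy_cse(program, shape_and_space)

print(merged[-1].operands)    # ('buf_host', 'buf_host')
print(separate[-1].operands)  # ('buf_host', 'buf_dev')
```

With the shape-only key, both tuple elements end up pointing at the same buffer, which has the same structure as the `%tuple.4` instruction in the error message, where both f32 operands are `%custom-call.3`.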

JAX repro (GPU):

import jax
import jax.numpy as jnp

p = jax.sharding.SingleDeviceSharding(jax.devices()[0], memory_kind="pinned_host")

@jax.jit
def f():
  def g(_1, _2):
    return None, (jax.device_put(jnp.array(1.0), p), jnp.array(2.0))
  return jax.lax.scan(g, None, length = 4)[1]

print(f()[0].sharding)  # doesn't crash

Error message:

Traceback (most recent call last):
  File "repro.py", line 14, in <module>
    print(f()[0].sharding)  # doesn't crash
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: during context [end-of-post-layout_assignment]: Expected instruction to have shape equal to (s32[], f32[4]{0:S(5)}, f32[4]{0:S(5)}), actual shape is (s32[], f32[4]{0:S(5)}, f32[4]{0}):
%tuple.4 = (s32[], f32[4]{0:S(5)}, f32[4]{0}) tuple(s32[] %constant.1, f32[4]{0:S(5)} %custom-call.3, f32[4]{0:S(5)} %custom-call.3), metadata={op_name="jit(f)/jit(main)/while" source_file="repro.py" source_line=12}
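
The two shapes in the message differ only in the layout of the last tuple element. A small helper like the one below can point at the mismatching element. It is a sketch, not part of JAX or XLA: the function names are made up, and it assumes that `S(n)` inside a layout denotes memory space `n` and that a layout without `S(...)` means the default memory space 0.

```python
import re

# Matches one array shape such as "s32[]", "f32[4]{0}" or "f32[4]{0:S(5)}".
ELEMENT = re.compile(r"(\w+)\[([\d,]*)\](?:\{([^}]*)\})?")


def parse_tuple_shape(text):
    """Return [(dtype, dims, memory_space), ...] for a tuple shape string."""
    elements = []
    for dtype, dims, layout in ELEMENT.findall(text):
        space = re.search(r"S\((\d+)\)", layout)
        elements.append((dtype, dims, int(space.group(1)) if space else 0))
    return elements


def memory_space_mismatches(expected, actual):
    """Return (index, expected_space, actual_space) for differing elements."""
    exp = parse_tuple_shape(expected)
    act = parse_tuple_shape(actual)
    return [
        (i, e[2], a[2])
        for i, (e, a) in enumerate(zip(exp, act))
        if e[2] != a[2]
    ]


# Shapes copied from the error message above.
expected = "(s32[], f32[4]{0:S(5)}, f32[4]{0:S(5)})"
actual = "(s32[], f32[4]{0:S(5)}, f32[4]{0})"

print(memory_space_mismatches(expected, actual))  # [(2, 5, 0)]
```

Under those assumptions, only tuple element 2 disagrees: the instruction was expected to carry memory space 5 there, but its recorded shape still has the default one.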
@jaro-sevcik
Contributor Author

I think the right fix is to move the host-offload-legalize pass after the CSE pass. PR coming soon.
