
XLA-introduced copies supersede lax.optimization_barrier #25399

Closed
stephen-huan opened this issue Dec 11, 2024 · 3 comments
Labels: bug (Something isn't working)

@stephen-huan

Description

This is almost certainly an XLA bug; I'm happy to report it there if so.

Consider the function

from functools import partial

from jax import Array, jit

@partial(jit, donate_argnums=0)
def f(x: Array) -> tuple[Array, Array]:
    y = x[0, 0]
    x = x.at[0, 0].add(1)
    return x, y

Since XLA has control over scheduling, for efficiency it should schedule the slice first and then the in-place update, to avoid an unnecessary copy. However, specifically on the CPU backend it chooses to copy twice instead, generating

ENTRY %main.13 (Arg_0.1: f32[10000,10000]) -> (f32[10000,10000], f32[]) {
  %Arg_0.1 = f32[10000,10000]{1,0} parameter(0), metadata={op_name="x"}
  %copy.1 = f32[10000,10000]{1,0} copy(f32[10000,10000]{1,0} %Arg_0.1)
  %copy = f32[10000,10000]{1,0} copy(f32[10000,10000]{1,0} %copy.1)
  %add_dynamic-update-slice_fusion = f32[10000,10000]{1,0} fusion(f32[10000,10000]{1,0} %copy), kind=kLoop, calls=%fused_computation.1, metadata={op_name="jit(g)/jit(main)/scatter-add" source_file="..." source_line=30}
  %slice_bitcast_fusion = f32[] fusion(f32[10000,10000]{1,0} %copy.1), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(g)/jit(main)/squeeze" source_file="..." source_line=29}
  ROOT %tuple.4 = (f32[10000,10000]{1,0}, f32[]) tuple(f32[10000,10000]{1,0} %add_dynamic-update-slice_fusion, f32[] %slice_bitcast_fusion)
}

(I'm not sure why it needs to make two copies here instead of just one, but the important part is that it copies at all.)
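For reference, the optimized HLO above can be dumped along the following lines (a sketch: the (10000, 10000) float32 shape is taken from the dumps, f is the jitted function defined above, and lower/compile/as_text are the standard jax.stages methods):

import jax.numpy as jnp

# shape taken from the HLO dumps above
x = jnp.zeros((10000, 10000), dtype=jnp.float32)
# lower and compile f for the current backend, then print the optimized HLO
print(f.lower(x).compile().as_text())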

By the semantics of lax.optimization_barrier, I would expect that introducing an explicit dependency of x on y would force the slice to happen first, after which liveness analysis would kick in and remove the copies.

from functools import partial

from jax import Array, jit, lax

@partial(jit, donate_argnums=0)
def f(x: Array) -> tuple[Array, Array]:
    y = x[0, 0]
    x, y = lax.optimization_barrier((x, y))
    x = x.at[0, 0].add(1)
    return x, y

However, what ends up happening is that XLA still introduces the copies and reorders the calls, so the generated code is the same as the one shown above. This seems to violate the scheduling control one would expect from optimization_barrier.
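A quick way to check this (a sketch, assuming the barrier variant of f above) is to look for the full-array copy in the compiled module:

import jax.numpy as jnp

x = jnp.zeros((10000, 10000), dtype=jnp.float32)
hlo = f.lower(x).compile().as_text()
# on the CPU backend this still prints True, with or without the barrier
print("copy(f32[10000,10000]" in hlo)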

Note that for this particular example, setting the XLA flag --xla_cpu_copy_insertion_use_region_analysis=true removes the copies and generates

ENTRY %main.13 (Arg_0.1: f32[10000,10000]) -> (f32[10000,10000], f32[]) {
  %Arg_0.1 = f32[10000,10000]{1,0} parameter(0), sharding={replicated}, metadata={op_name="x"}
  %slice_bitcast_fusion = f32[] fusion(f32[10000,10000]{1,0} %Arg_0.1), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(g)/jit(main)/squeeze" source_file="..." source_line=28}
  %add_dynamic-update-slice_fusion = f32[10000,10000]{1,0} fusion(f32[10000,10000]{1,0} %Arg_0.1), kind=kLoop, calls=%fused_computation.1, control-predecessors={%slice_bitcast_fusion}, metadata={op_name="jit(g)/jit(main)/scatter-add" source_file="..." source_line=30}
  ROOT %tuple.4 = (f32[10000,10000]{1,0}, f32[]) tuple(f32[10000,10000]{1,0} %add_dynamic-update-slice_fusion, f32[] %slice_bitcast_fusion)
}

as expected, with or without optimization_barrier. Also, using a GPU device generates the copyless

ENTRY %main.13 (Arg_0.1.0: f32[10000,10000]) -> (f32[10000,10000], f32[]) {
  %Arg_0.1.0 = f32[10000,10000]{1,0} parameter(0), metadata={op_name="x"}
  %wrapped_slice = f32[1,1]{1,0} fusion(f32[10000,10000]{1,0} %Arg_0.1.0), kind=kLoop, calls=%wrapped_slice_computation
  %bitcast.43.0 = f32[] bitcast(f32[1,1]{1,0} %wrapped_slice)
  %loop_dynamic_update_slice_fusion = f32[10000,10000]{1,0} fusion(f32[10000,10000]{1,0} %Arg_0.1.0), kind=kLoop, calls=%fused_dynamic_update_slice, control-predecessors={%wrapped_slice}, metadata={op_name="jit(g)/jit(main)/scatter-add" source_file="..." source_line=30}
  ROOT %tuple.5 = (f32[10000,10000]{1,0}, f32[]) tuple(f32[10000,10000]{1,0} %loop_dynamic_update_slice_fusion, f32[] %bitcast.43.0)
}

also with or without optimization_barrier.
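For completeness, the CPU workaround above can be applied through the XLA_FLAGS environment variable, which has to be set before the backend is initialized (a sketch; only the flag name comes from above, the rest is the usual environment-variable mechanism):

import os

# append the flag to any existing XLA_FLAGS before jax initializes the backend
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + " --xla_cpu_copy_insertion_use_region_analysis=true"
)

import jax  # noqa: E402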

Some miscellaneous related questions:

  1. Is there a JAX interface to HloOrdering, particularly SequentialHloOrdering, or is that controlled by the XLA flag --xla_cpu_enable_concurrency_optimized_scheduler? In particular, is there a way of manually writing schedules without relying only on optimization_barrier (which is not precise enough in cases like these)?
  2. I'm a bit confused about why the workaround works now, since region analysis was introduced more than 3 years ago in openxla/xla@92292d1, and the core logic of RemoveUnnecessaryCopies and TryElideCopy doesn't seem to have changed much in that time either. Rather, what has recently changed is that the flag xla_cpu_copy_insertion_use_region_analysis was added to CPU (disabled by default) and region analysis was disabled on GPU. Is there some context I'm missing?

(originally reported in the discussion #19165.)

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.37
jaxlib: 0.4.36
numpy:  2.1.3
python: 3.12.1 (main, Oct  7 2024, 00:00:00) [GCC 11.4.1 20231218 (Red Hat 11.4.1-3)]
device info: NVIDIA L40S-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='babel-6-5', release='5.14.0-427.40.1.el9_4.x86_64', version='#1 SMP PREEMPT_DYNAMIC Wed Oct 16 07:08:17 EDT 2024', machine='x86_64')


$ nvidia-smi
Wed Dec 11 04:51:52 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA L40S                    On  |   00000000:21:00.0 Off |                    0 |
| N/A   30C    P0             37W /  350W |     439MiB /  46068MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                    
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A    403772      C   python                                        430MiB |
+-----------------------------------------------------------------------------------------+
@hawkinsp
Collaborator

Yes, I think this would be better reported on the XLA GitHub issue tracker.

There's currently no JAX way to control the HLO schedule, but that's something we're actively looking into adding as a way to control communication/compute overlap.

@stephen-huan
Author

Opened openxla/xla#20440 on the XLA side. Should this issue be closed?

@hawkinsp
Collaborator

Yeah, let's track this there.
