Consider the JAX function
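The snippet itself is not preserved in this copy of the report. As a minimal sketch, assuming the function reads a slice of its input and then overwrites that part of the same input, it could look like the following. The names `f` and `f_jit`, the shape, the zero fill value, and the use of `donate_argnums` are illustrative assumptions, not the reporter's exact code.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Read a slice of the input first ...
    y = x[0]
    # ... then overwrite that same region. If the slice is scheduled before
    # the update, the update can reuse the (donated) input buffer.
    x = x.at[0].set(0.0)
    return x, y


# Assumption: the input buffer is donated so an in-place update is possible.
f_jit = jax.jit(f, donate_argnums=0)

x = jnp.arange(8.0).reshape(2, 4)

# Dump the optimized HLO for the current backend and look for `copy` ops.
hlo_text = f_jit.lower(x).compile().as_text()
print(hlo_text)
```

Inspecting the text returned by `as_text()` is one way to check whether the compiled module contains `copy` instructions on a given backend.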
Since XLA has control over scheduling, for efficiency it should schedule the slice first and then the in-place update, to avoid an unnecessary copy. However, specifically on the CPU backend it chooses to copy twice instead, generating
(I'm not sure why it needs to make two copies here instead of just one, but the important part is that it copies at all.)
By the semantics of `lax.optimization_barrier`, I would expect that introducing an explicit dependency of `x` on `y` would force the slice to happen first, and then the liveness analysis would kick in and remove the copies.
However, what ends up happening is that XLA still introduces copies and re-orders the calls, so the generated code is the same as in the version without the barrier. This seems to violate the scheduling control one expects from `optimization_barrier`.
Note that for this particular example, setting the XLA flag `--xla_cpu_copy_insertion_use_region_analysis=true` removes the copy and generates copy-free code as expected, with or without `optimization_barrier`. Also, using a GPU device generates copyless code, also with or without `optimization_barrier`. Finally, the reverse explicit schedule, which should introduce a copy, does not introduce a copy with `--xla_cpu_copy_insertion_use_region_analysis=true`.
I'm a bit confused why the flag workaround works now, since region analysis was introduced more than 3 years ago in 92292d1. The core logic of `RemoveUnnecessaryCopies` and `TryElideCopy` doesn't seem to have changed much in that time either. Rather, what has recently changed is that the flag `xla_cpu_copy_insertion_use_region_analysis` was added to CPU (disabled by default) (#18521) and region analysis was disabled on GPU (#14680). Is there some context I'm missing?
(Originally reported in the discussion jax-ml/jax#19165 and JAX issue jax-ml/jax#25399.)