Hey folks,

I tried updating our Brax version from 0.11.0 to 0.12.1 but ran into issues while doing so. Our trainer, which uses PPO and has worked fine until now, now freezes because `acting.Evaluator.run_evaluator` takes forever to compile (literally, at least 12 hours). I've spent a couple of full working days chasing this bug, and there seem to be some pretty weird dynamics at play, so I haven't been able to narrow it down to a single underlying cause. Here's the model I've been using:

Example script
Here's what I've observed so far. First, when running this script, the compilation takes extremely long. I haven't run this particular script for more than about 10 minutes, but the original training pipeline in our code base ran for over 12 hours without the evaluation-loop compilation finishing (it takes a couple of minutes with `brax==0.11.0`).
```
$ time python -m evaluator_compile_freeze
2024-12-27 15:31:01.399190: E external/xla/xla/service/slow_operation_alarm.cc:73]
********************************
[Compiling module jit_generate_eval_unroll] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
^C^\[2] 310662 quit (core dumped) python -m evaluator_compile_freeze
python -m evaluator_compile_freeze 656.21s user 2.35s system 106% cpu 4:02.79 total
```
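To confirm that the hang really is in compilation rather than execution, compile time can be isolated with JAX's ahead-of-time lowering API. This is a minimal sketch with a stand-in function and shapes, not the actual Brax evaluator:

```python
# Sketch: separate tracing, XLA compilation, and execution using jax.jit's
# AOT API, so a compile hang can be distinguished from a slow runtime.
# `step` is a placeholder, not the real generate_eval_unroll.
import time

import jax
import jax.numpy as jnp


def step(x):
    # Stand-in computation; substitute the evaluator function here.
    return jnp.sin(x) @ jnp.cos(x).T


x = jnp.ones((64, 64))

t0 = time.perf_counter()
lowered = jax.jit(step).lower(x)   # tracing only, no XLA work yet
compiled = lowered.compile()       # XLA compilation happens here
compile_s = time.perf_counter() - t0
print(f"compile took {compile_s:.2f}s")

out = compiled(x)                  # pure execution, no recompilation
```

Wrapping only the `lowered.compile()` call in a timeout makes it easy to tell which stage never returns.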
Second, I believe, though I'm not sure, that the underlying issue has something to do with the mocap bodies in the environment. If I comment out the mocap code in the script, the code runs fine:
```
$ time python -m evaluator_compile_freeze
{'eval/walltime': 34.80631756782532, 'eval/episode_reward': Array(0., dtype=float32), 'eval/episode_reward_std': Array(0., dtype=float32), 'eval/avg_episode_length': Array(600., dtype=float32), 'eval/epoch_eval_time': 34.80631756782532, 'eval/sps': 2206.495985975641}
python -m evaluator_compile_freeze 54.26s user 2.43s system 140% cpu 40.298 total
```
The reason I'm not sure about this connection is that the use of `mocap_pos` is not a culprit in itself: I tried to reproduce the issue by swapping the above environment for the Humanoid environment from the MJX tutorial notebook, with a similar mocap body added to it, and that compiled fine.
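For reference, by "mocap body" I mean this kind of MJCF fragment (a minimal sketch; the body name, geometry, and sizes are made up, not our actual model):

```xml
<mujoco>
  <worldbody>
    <!-- A mocap body has no joints; its pose is driven externally each step
         by writing to data.mocap_pos / data.mocap_quat. -->
    <body name="target" mocap="true" pos="0 0 1">
      <!-- contype/conaffinity 0 so the marker doesn't collide with anything -->
      <geom type="sphere" size="0.05" contype="0" conaffinity="0"/>
    </body>
  </worldbody>
</mujoco>
```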
jax-ml/jax#6823 and jax-ml/jax#9651 seemed like they could be related, so I tried running the code with `JAX_ENABLE_MLIR=0 OMP_NUM_THREADS=1` (as suggested in those issues), but that did not fix the problem.
I also tried the recommended `XLA_FLAGS=--xla_dump_to=/tmp/foo`. I don't know exactly how to interpret the resulting dumps, but I notice that the working run outputs files up to `module_0211.jit__...`, whereas the hanging run only outputs files up to `module_0109.jit_generate_eval_unroll...`, which suggests that compilation gets stuck in `generate_eval_unroll`. I've manually verified that this is indeed the case.
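For anyone trying the same diagnosis: the highest-numbered `module_NNNN` file in the dump directory names the computation XLA was working on when it hung. A small sketch, assuming the dump path used above:

```shell
# List the last HLO module XLA started dumping; with a hung compile, this
# is the computation the compiler got stuck on. Defaults to the /tmp/foo
# path from XLA_FLAGS=--xla_dump_to=/tmp/foo used earlier.
DUMP="${DUMP:-/tmp/foo}"
if [ -d "$DUMP" ]; then
  ls "$DUMP" | grep -o '^module_[0-9]*\.jit[^.]*' | sort -u | tail -n 1
fi
```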
Does anyone have pointers on how to fix or further investigate this issue?
My setup (I've installed jax with the jax[cuda12-local] extra):
It appears that it's jax==0.4.36 that breaks this. I haven't been able to pin down the source of the problem in that release, though. The release notes mention new "stackless" tracing machinery, which sounded like it could have changed something, but I re-ran the hanging cases above with config.jax_data_dependent_tracing_fallback enabled (as per the release notes) and that didn't change anything.
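For completeness, this is how I toggled the fallback flag from the 0.4.36 release notes (guarded, since the flag only exists in jax >= 0.4.36; it did not resolve the hang for me):

```python
# Sketch: opt out of the new "stackless" tracing introduced in jax 0.4.36
# via the fallback flag mentioned in the release notes.
import jax

try:
    jax.config.update("jax_data_dependent_tracing_fallback", True)
    enabled = True
except Exception:
    # Older jax versions don't define this config option.
    enabled = False
print(f"fallback enabled: {enabled}")
```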
Perhaps this is something that should be reported to JAX. Unfortunately, however, I'm unable to reproduce the issue outside of Brax + MJX. Any leads would be much appreciated.