acting.Evaluator.run_evaluation compilation hangs #569

Open

hartikainen opened this issue Dec 27, 2024 · 1 comment

@hartikainen

Hey folks,

I tried updating our Brax version from 0.11.0 to 0.12.1 but ran into issues while doing so. Our trainer, which uses PPO and has been working fine until now, now freezes because acting.Evaluator.run_evaluation takes forever to compile (literally, at least 12 hours). I've spent a couple of full working days chasing this down, and there seem to be some pretty weird dynamics at play, so I haven't been able to narrow it down to a single underlying cause.

Here's the model I've been using:

Example script
# ruff: noqa: ANN001, ANN202, FIX002, ERA001, PLR0913, PLR0914, PLR0915, PLR0917, TD003

from brax import envs
from brax.envs.base import PipelineEnv
from brax.envs.base import State
from brax.io import mjcf
from brax.training import acting
from brax.training import types
from etils import epath
import jax
from jax import numpy as jnp
import mujoco
from mujoco import mjx
import numpy as np
from robot_descriptions import fr3_mj_description
import time


def create_mj_model() -> mujoco.MjModel:
    franka_fr3_path = epath.Path(fr3_mj_description.MJCF_PATH)
    robot_spec = mujoco.MjSpec.from_file(franka_fr3_path.as_posix())

    for geom in robot_spec.geoms:
        geom.contype = 0
        geom.conaffinity = 0

    robot_spec.worldbody.add_body(
        name="mocap-body",
        pos=[0.5, 0.0, 0.5],
        mocap=True,
    )

    robot_spec.option.integrator = mujoco.mjtIntegrator.mjINT_IMPLICITFAST
    robot_spec.option.timestep = 1 / 1_000
    robot_spec.option.iterations = 6
    robot_spec.option.ls_iterations = 8
    robot_spec.option.disableflags |= mujoco.mjtDisableBit.mjDSBL_EULERDAMP

    model = robot_spec.compile()

    return model


class FrankaFR3(PipelineEnv):
    """Environment for training FR3 and DEX-EE to reach the target position."""

    def __init__(self, control_timestep: float = 1 / 50, **kwargs: dict) -> None:
        mj_model = create_mj_model()
        sys = mjcf.load_model(mj_model)

        kwargs["n_frames"] = 10
        kwargs["backend"] = "mjx"
        super().__init__(sys, **kwargs)

    def reset(self, rng: jax.Array) -> State:
        data = self.pipeline_init(self.sys.qpos0, jnp.zeros(self.sys.nv))
        obs, reward, done = jnp.zeros(7), jnp.zeros(()), jnp.zeros(())
        metrics, info = {}, {}
        return State(data, obs, reward, done, metrics, info)

    def step(self, state: State, action: jax.Array) -> State:
        data0 = state.pipeline_state
        data1 = self.pipeline_step(data0, action)
        mocap_pos = data1.mocap_pos.at[0].set(data1.mocap_pos[0] + 0.01)
        data1 = data1.replace(mocap_pos=mocap_pos)
        return state.replace(pipeline_state=data1)


def main():
    environment = FrankaFR3()

    rng = jax.random.PRNGKey(0)
    key_env, eval_key, rng = jax.random.split(rng, 3)

    num_envs = 1024
    process_count = jax.process_count()
    process_id = jax.process_index()
    local_device_count = jax.local_device_count()
    local_devices_to_use = local_device_count
    device_count = local_devices_to_use * process_count

    key_envs = jax.random.split(key_env, num_envs // process_count)
    key_envs = jnp.reshape(key_envs, (local_devices_to_use, -1) + key_envs.shape[1:])

    episode_length = 600
    action_repeat = 1
    env = envs.training.wrap(
        environment,
        episode_length=episode_length,
        action_repeat=action_repeat,
        randomization_fn=None,
    )  # pytype: disable=wrong-keyword-args

    num_eval_envs = 128

    def make_policy(params: types.Params, deterministic: bool = False):
        del params

        def policy(
            observations: types.Observation,
            key_sample: types.PRNGKey,
        ) -> tuple[types.Action, types.Extra]:
            if deterministic:
                actions = jnp.zeros((*observations.shape[0:-1], env.action_size))
            else:
                actions = jax.random.uniform(
                    key_sample,
                    (*observations.shape[0:-1], env.action_size),
                    minval=-1.0,
                    maxval=+1.0,
                )
            return actions, {}

        return policy

    evaluator = acting.Evaluator(
        env,
        make_policy,
        num_eval_envs=num_eval_envs,
        episode_length=episode_length,
        action_repeat=action_repeat,
        key=eval_key,
    )

    metrics = evaluator.run_evaluation(None, training_metrics={})
    print(metrics)


if __name__ == "__main__":
    main()

And here's what I've observed so far. First, when running this file, the compilation takes a long time. I have not run this particular script for more than 10 minutes, but the original training pipeline in our code base ran for >12 hours without the evaluation loop compilation finishing (it takes a couple of minutes with brax==0.11.0).

$ time python -m evaluator_compile_freeze
2024-12-27 15:31:01.399190: E external/xla/xla/service/slow_operation_alarm.cc:73]
********************************
[Compiling module jit_generate_eval_unroll] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
^C^\[2]    310662 quit (core dumped)  python -m evaluator_compile_freeze
python -m evaluator_compile_freeze  656.21s user 2.35s system 106% cpu 4:02.79 total

Second, I believe, although I'm not sure, that the underlying issue might have something to do with the mocap bodies in the environment. If I remove the mocap code above, the script runs fine:

@@ -24,12 +24,6 @@ def create_mj_model() -> mujoco.MjModel:
         geom.contype = 0
         geom.conaffinity = 0
 
-    robot_spec.worldbody.add_body(
-        name="mocap-body",
-        pos=[0.5, 0.0, 0.5],
-        mocap=True,
-    )
-
     robot_spec.option.integrator = mujoco.mjtIntegrator.mjINT_IMPLICITFAST
     robot_spec.option.timestep = 1 / 1_000
     robot_spec.option.iterations = 6
@@ -61,8 +55,6 @@ def reset(self, rng: jax.Array) -> State:
     def step(self, state: State, action: jax.Array) -> State:
         data0 = state.pipeline_state
         data1 = self.pipeline_step(data0, action)
-        mocap_pos = data1.mocap_pos.at[0].set(data1.mocap_pos[0] + 0.01)
-        data1 = data1.replace(mocap_pos=mocap_pos)
         return state.replace(pipeline_state=data1)
$ time python -m evaluator_compile_freeze
{'eval/walltime': 34.80631756782532, 'eval/episode_reward': Array(0., dtype=float32), 'eval/episode_reward_std': Array(0., dtype=float32), 'eval/avg_episode_length': Array(600., dtype=float32), 'eval/epoch_eval_time': 34.80631756782532, 'eval/sps': 2206.495985975641}
python -m evaluator_compile_freeze  54.26s user 2.43s system 140% cpu 40.298 total

The reason I'm not sure about this connection is that the use of mocap_pos is not the culprit in itself: I tried to reproduce the issue by swapping the above environment for the Humanoid environment from the MJX tutorial notebook, with a similar mocap body added to it, and that compiled fine.
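
For reference, the mocap body in that reproduction attempt was added roughly like this (a minimal sketch, not the exact script I used; the Humanoid XML path is a placeholder):

import mujoco

# Load another model, e.g. the Humanoid from the MJX tutorial notebook, and
# append a mocap body analogous to the one in the FR3 environment above.
spec = mujoco.MjSpec.from_file("humanoid.xml")  # placeholder path
spec.worldbody.add_body(
    name="mocap-body",
    pos=[0.5, 0.0, 0.5],
    mocap=True,
)
model = spec.compile()
print(model.nmocap)  # 1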

jax-ml/jax#6823 and jax-ml/jax#9651 seemed like they could be related, so I tried running the code with JAX_ENABLE_MLIR=0 OMP_NUM_THREADS=1 (as suggested in those issues), but that did not fix the problem.
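
For completeness, the invocation was along these lines:

$ JAX_ENABLE_MLIR=0 OMP_NUM_THREADS=1 python -m evaluator_compile_freeze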

I also tried the recommended XLA_FLAGS=--xla_dump_to=/tmp/foo. I don't know exactly how to interpret the results, but I notice that the working run dumps files up to module_0211.jit__..., whereas the hanging run only dumps files up to module_0109.jit_generate_eval_unroll..., which suggests that compilation gets stuck in generate_eval_unroll. I've manually verified that this is indeed the case.
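
For reference, this is roughly how I produced and skimmed the dump (nothing specific here beyond standard shell tools):

$ XLA_FLAGS=--xla_dump_to=/tmp/foo python -m evaluator_compile_freeze
$ ls /tmp/foo | grep -o 'module_[0-9]*\.jit[^.]*' | sort -u | tail -n 3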

Does anyone have pointers to how to fix or further investigate this issue?

My setup (I've installed jax with jax[cuda12-local] extras):

$ python -c 'import jax, brax, mujoco; print(f"{jax.__version__=}, {brax.__version__=}, {mujoco.__version__=}")'
jax.__version__='0.4.36', brax.__version__='0.12.1', mujoco.__version__='3.2.6'

$ lspci | grep NVIDIA
00:03.0 3D controller: NVIDIA Corporation AD104GL [L4] (rev a1)

$ lsb_release -a
No LSB modules are available.
Distributor ID: Ubuntu
Description:    Ubuntu 24.04.1 LTS
Release:        24.04
Codename:       noble

$ uname -sm
Linux x86_64
hartikainen changed the title Brax compilation hangs → PPO evaluation compilation hangs Dec 27, 2024
hartikainen changed the title PPO evaluation compilation hangs → acting.Evaluator.run_evaluation compilation hangs Dec 27, 2024
@hartikainen
Author

hartikainen commented Dec 28, 2024

I just tried narrowing the error down to a specific version, and here's what I observed:

  • mujoco==3.2.5, jax[cuda12-local]==0.4.35, brax==0.11.0: works
  • mujoco==3.2.6, jax[cuda12-local]==0.4.35, brax==0.12.1: works
  • mujoco==3.2.6, jax[cuda12-local]==0.4.36, brax==0.12.1: hangs
  • mujoco==3.2.5, jax[cuda12-local]==0.4.36, brax==0.11.0: hangs
  • mujoco==3.2.6, jax[cuda12-local]==0.4.37, brax==0.12.1: hangs
  • mujoco==3.2.6, jax[cuda12-local]==0.4.38, brax==0.12.1: hangs

It appears that jax==0.4.36 is what breaks this. I haven't been able to pinpoint the source of the problem in that release, though. The release notes mention new "stackless" tracing machinery, which sounded like it could have changed something, so I reran the hanging cases above with config.jax_data_dependent_tracing_fallback enabled (as per the release notes), but that didn't change anything.
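
In case it helps, this is how I enabled the fallback (assuming the flag from the 0.4.36 release notes can simply be toggled via jax.config; this is my own snippet, not something Brax exposes):

from jax import config

# Opt back into the pre-0.4.36 data-dependent tracing path, as suggested in the
# jax 0.4.36 release notes for regressions from the new "stackless" tracing
# machinery.
config.update("jax_data_dependent_tracing_fallback", True)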

Perhaps this is something that should be reported to JAX. Unfortunately, however, I've been unable to reproduce the issue outside of Brax + MJX. Any leads would be much appreciated.
