
nvidia GPU acceleration question #344

Open
HowdyMoto opened this issue Apr 24, 2023 · 1 comment

Comments

@HowdyMoto

I have a machine set up with the following:
NVIDIA RTX 4090 GPU
Fresh install of Ubuntu 22.04
530 NVIDIA drivers with CUDA 12.1

I set up Brax with `pip install -e .`
After doing so, `learn` works fine on the CPU.

I then set up GPU-accelerated jax with:
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
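Before digging into training speed, it can help to confirm that JAX actually picked the CUDA backend. A minimal check using the standard JAX API (`jax.devices()` and `jax.default_backend()`); the import fallback is only there so the snippet doesn't crash on a machine without JAX:

```python
# Report which backend JAX selected; on a working CUDA install this
# should print something like [CudaDevice(id=0)] and return "gpu".
def jax_backend() -> str:
    try:
        import jax
        print(jax.devices())          # e.g. [CudaDevice(id=0)] on GPU
        return jax.default_backend()  # "gpu", "tpu", or "cpu"
    except ImportError:
        return "jax not installed"

print(jax_backend())
```

If this returns "cpu" even after the `jax[cuda12_pip]` install, the problem is the JAX setup rather than Brax.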

When I run `learn` with no flags, I see the output below: it tries and fails to initialize the TPU and ROCm backends, and CUDA appears among the available platforms, but nothing explicitly says CUDA is being used. If I let it run, it's painfully slow, running for many hours with few updates. `learn --helpfull` doesn't show any flags I need to set to use CUDA and cuDNN. Am I missing any steps to get it working? I know this should be extremely fast with this GPU.

I0423 20:06:11.578360 140285023629312 metrics.py:42] Hyperparameters: {'num_evals': 10, 'num_envs': 4, 'total_env_steps': 50000000}
I0423 20:06:11.665900 140285023629312 xla_bridge.py:440] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter Host CUDA
I0423 20:06:11.666139 140285023629312 xla_bridge.py:440] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
I0423 20:06:11.666195 140285023629312 xla_bridge.py:440] Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
I0423 20:06:13.768080 140285023629312 train.py:107] Device count: 1, process count: 1 (id 0), local device count: 1, devices to be used count: 1
I0423 20:06:41.432691 140285023629312 train.py:320] {'eval/walltime': 22.236454486846924, 'eval/episode_distance_from_origin': Array(48.86863, dtype=float32), 'eval/episode_forward_reward': Array(0.9787996, dtype=float32), 'eval/episode_reward': Array(-8.558953, dtype=float32), 'eval/episode_reward_contact': Array(0., dtype=float32), 'eval/episode_reward_ctrl': Array(-49.092438, dtype=float32), 'eval/episode_reward_forward': Array(0.9787996, dtype=float32), 'eval/episode_reward_survive': Array(39.554688, dtype=float32), 'eval/episode_x_position': Array(0.007761, dtype=float32), 'eval/episode_x_velocity': Array(0.9787996, dtype=float32), 'eval/episode_y_position': Array(-5.1503973, dtype=float32), 'eval/episode_y_velocity': Array(-1.67458, dtype=float32), 'eval/avg_episode_length': Array(39.554688, dtype=float32), 'eval/epoch_eval_time': 22.236454486846924, 'eval/sps': 5756.3133581261}
I0423 20:06:41.434379 140285023629312 metrics.py:51] [0] eval/avg_episode_length=39.5546875, eval/episode_distance_from_origin=48.868629455566406, eval/episode_forward_reward=0.97879958152771, eval/episode_reward=-8.558953285217285, eval/episode_reward_contact=0.0, eval/episode_reward_ctrl=-49.092437744140625, eval/episode_reward_forward=0.97879958152771, eval/episode_reward_survive=39.5546875, eval/episode_x_position=0.0077610015869140625, eval/episode_x_velocity=0.97879958152771, eval/episode_y_position=-5.150397300720215, eval/episode_y_velocity=-1.6745799779891968, eval/epoch_eval_time=22.236454, eval/sps=5756.313358, eval/walltime=22.236454
I0423 20:06:41.435662 140285023629312 train.py:326] starting iteration 0 27.667602062225342

@erikfrey
Collaborator

Hello @HowdyMoto - you'll want to adjust the hparams for your machine setup. In your case, `num_envs` should be much higher, e.g. 2048 or 4096; that is probably the cause of the slowness. Check the training colab for hparams that work well with an accelerator:

https://colab.sandbox.google.com/github/google/brax/blob/main/notebooks/training.ipynb
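To make the suggestion concrete, here is a hedged sketch of the override, reusing the hyperparameter names from the log above (the 4096 value mirrors the advice in the comment; treat these as a starting point, not canonical Brax settings):

```python
# Hyperparameter names copied from the logged hparams in the issue;
# only num_envs is changed, per the suggestion above.
logged = {"num_evals": 10, "num_envs": 4, "total_env_steps": 50_000_000}

gpu_hparams = dict(logged, num_envs=4096)  # thousands of parallel envs

# With only 4 envs, each step gathers ~1000x less experience per device
# launch, so a large GPU sits mostly idle between kernel launches.
batch_factor = gpu_hparams["num_envs"] // logged["num_envs"]
print(batch_factor)  # 1024
```

The key point is that Brax simulates environments on the accelerator itself, so throughput scales with how many environments run in parallel; a handful of envs cannot saturate an RTX 4090.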
