
io/image: Replace PyTinyrenderer with a Pure JAX renderer #367

Open
wants to merge 6 commits into main

Conversation


@JoeyTeng JoeyTeng commented Jun 2, 2023

This solves #331, #108, #47, and #67, and is related to #302.

Example Colab (most of the code is adapted from the Brax team's Brax Training notebook)

Known issue: the plane may not display fully or correctly, because vertices behind or at the camera plane are not clipped. This will be resolved once a proper clipping algorithm is implemented in JaxRenderer.
Update: the plane rendering issue is now resolved by a rasteriser based on homogeneous (perspective-correct) interpolation; see the sketch below.
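
For reference, here is a minimal sketch of the homogeneous (perspective-correct) interpolation idea, not the actual jaxrenderer code: each vertex's contribution is weighted by 1/w in clip space and the weights are renormalised, which keeps attribute interpolation correct for geometry near the camera plane.

```python
import jax.numpy as jnp

def perspective_correct_interpolate(bary, w_clip, attrs):
    """Sketch only (not jaxrenderer's API): perspective-correct interpolation.

    bary:   (3,) screen-space barycentric coordinates of the fragment
    w_clip: (3,) clip-space w of the triangle's three vertices
    attrs:  (3, C) per-vertex attributes (colour, normal, depth, ...)
    """
    weights = bary / w_clip               # weight each vertex by 1/w
    weights = weights / jnp.sum(weights)  # renormalise so the weights sum to 1
    return weights @ attrs                # (C,) interpolated attribute
```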

Currently this only touches the v2 API, with minimal changes. Let me know if you think it should be back-ported to the v1 API as well. I will try to optimise performance further later.

Update:

  1. With 8a111d9, the rendering of each frame is now jitted when rendering a batch of states (see the sketch after this list).
  2. 828848a refactors batch rendering a bit, but with no significant performance improvement (if any); see the simple benchmark here.
  3. With Performance Improvement JoeyTeng/jaxrenderer#2 (release 0.3.0), performance improves by roughly 10x: rendering one frame of the Ant environment at 960x540 resolution with 1x SSAA now takes about 500ms on a T4 (0.5 fps) and ~70-100ms on an A100 (10-14 fps), versus about 200ms per frame (5 fps) for the previous CPU implementation.
  4. With Lower minimum Python version to 3.8; Improve typing annotations JoeyTeng/jaxrenderer#3 (release 0.3.1), the minimum Python version is lowered to 3.8, the same as brax.
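
As a rough illustration of the batching strategy in items 1 and 2 (a sketch only; `render_frame` below is a toy stand-in, not the brax or jaxrenderer API):

```python
import jax
import jax.numpy as jnp

# Toy stand-in for a per-state renderer that maps a state to an (H, W, 3) image.
def render_frame(state):
    return jnp.full((64, 64, 3), state[0])

# Option 1: jit once, then reuse the compiled render for every frame in the batch.
render_jit = jax.jit(render_frame)
states = jnp.linspace(0.0, 1.0, 8)[:, None]          # a batch of 8 toy states
frames = jnp.stack([render_jit(s) for s in states])  # Python loop, single compile

# Option 2: vmap the renderer and rasterise the whole batch in one jitted call.
frames_batched = jax.jit(jax.vmap(render_frame))(states)
```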

@JoeyTeng JoeyTeng marked this pull request as ready for review June 3, 2023 15:44
@JoeyTeng
Author

Friendly ping @erikfrey. With the release of jaxrenderer 0.3.0, I believe this is a suitable replacement for the current CPU renderer pytinyrenderer, since 1) it is a pure JAX implementation, and 2) it performs better on high-end GPUs like the A100 (>2x speedup). Let me know what you think :)

@erikfrey
Collaborator

Hi Joey - thanks for this tremendous effort! Unfortunately we're bound by certain restrictions that make it very difficult for us to depend on external packages unless they go through an arduous vetting process behind the scenes.

I'm curious how your approach performs on CPU?

You've probably already found that using XLA on its own limits your performance here, as XLA wasn't designed for rasterization-like operations and it does not know how to use the underlying GPU primitives to actually do this performantly.

If you're interested, there is a way to register custom XLA ops, so that you could call into CUDA's rasterization functions. Tensorflow graphics does such a thing:

https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/rendering/opengl/rasterizer_op.cc

You could then make such an op accessible to JAX. This would be very fast, but also quite tricky to get all the plumbing to work. Just thought I'd put it out there in case you're curious to hack on such a thing.
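
To make the custom-op idea concrete, here is a heavily simplified sketch. A real custom XLA op needs a compiled CUDA/OpenGL target registered through XLA's custom-call machinery (as in the tensorflow_graphics file above); the `jax.pure_callback` bridge below only shows where such an op would plug into jitted code, and it pays a host round trip, so it is far slower than a native custom call. All names here are illustrative, not an existing renderer API.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Placeholder for an external rasteriser (e.g. an OpenGL/CUDA renderer).
# A real custom op would run this on-device instead of on the host.
def host_rasterize(vertices):
    vertices = np.asarray(vertices)
    return np.zeros((64, 64, 3), dtype=np.float32)

@jax.jit
def render(vertices):
    out_shape = jax.ShapeDtypeStruct((64, 64, 3), jnp.float32)
    # jax.pure_callback calls back into Python/NumPy from inside jitted code.
    return jax.pure_callback(host_rasterize, out_shape, vertices)

image = render(jnp.ones((12, 3)))  # (64, 64, 3)
```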

Either way though, it would be hard for us to accept this PR due to the policies that control which external packages we can rely on. Sorry about that! But I'll leave the issue open for a while - I'd love to hear if you continue hacking on this.

@JoeyTeng
Author

Thank you so much for your suggestions! I am also thinking about a customised op, as that would dramatically improve the performance.

Regarding current performance: on high-end GPUs rendering is actually faster than on CPU. The current implementation is not optimised for TPU, so performance on TPUs is quite poor in both execution time and memory usage (up to 98% padding...). I will benchmark and improve performance on batches of small images (e.g. 84x84, as used in RL for Atari environments) to see whether my implementation benefits from rendering batches of environments in parallel (although we would definitely benefit on a TPU Pod, with pmap/xmap over devices for both environment simulation and rendering); a sketch of that idea follows.
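
A hypothetical sketch of that last point, assuming a toy `render_frame` (not the jaxrenderer API): shard the environment batch across devices with `pmap` and render each device's sub-batch with `vmap`.

```python
import jax
import jax.numpy as jnp

# Toy stand-in for a per-state renderer producing an 84x84 image.
def render_frame(state):
    return jnp.full((84, 84, 3), state[0])

n_dev = jax.local_device_count()
# (devices, envs per device, state dim); in practice these would be env states.
states = jnp.zeros((n_dev, 16, 8))
render_sharded = jax.pmap(jax.vmap(render_frame))
frames = render_sharded(states)  # (devices, envs per device, 84, 84, 3)
```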

@sai-prasanna

How about a brax-contrib package, or something like that, where this could be added?
