io/image: Replace PyTinyrenderer with a Pure JAX renderer #367
Conversation
A simple benchmark suggests no improvement, though; see https://colab.research.google.com/drive/1gBIevFjnRrEpo2uU9blZ5qu6KIzDWTl7
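For context, a per-frame benchmark of this kind could look roughly like the sketch below; `render_frame` is only a runnable placeholder for whichever renderer is being measured, not code from this PR or the Colab:

```python
import time
import jax
import jax.numpy as jnp

# Stand-in for a real renderer call; returns a dummy RGB image so the
# timing harness below runs as-is.
def render_frame(scene_state):
    return jnp.zeros((540, 960, 3)) + scene_state

def time_renderer(render_fn, scene_state, n_frames=20):
    """Mean per-frame latency in seconds, excluding JIT compilation time."""
    render_jit = jax.jit(render_fn)
    render_jit(scene_state).block_until_ready()  # warm-up, triggers compilation
    start = time.perf_counter()
    for _ in range(n_frames):
        render_jit(scene_state).block_until_ready()
    return (time.perf_counter() - start) / n_frames

print(time_renderer(render_frame, jnp.float32(0.0)))
```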
Friendly ping @erikfrey. With the release of jaxrenderer 0.3.0, I believe this should be a suitable replacement for the current CPU renderer pytinyrenderer, as 1) it is a pure JAX implementation; and 2) it performs better on high-end GPUs like the A100 (>2x speedup). Let me know what you think :)
Hi Joey - thanks for this tremendous effort! Unfortunately we're bound by certain restrictions that make it very difficult for us to depend on external packages unless they go through an arduous vetting process behind the scenes.

I'm curious how your approach performs on CPU? You've probably already found that using XLA on its own limits your performance here, as XLA wasn't designed for rasterization-like operations and it does not know how to use the underlying GPU primitives to actually do this performantly. If you're interested, there is a way to register custom XLA ops, so that you could call into CUDA's rasterization functions. TensorFlow Graphics does such a thing. You could then make such an op accessible to JAX. This would be very fast, but also quite tricky to get all the plumbing to work. Just thought I'd put it out there in case you're curious to hack on such a thing.

Either way though, it would be hard for us to accept this PR due to the policies that control which external packages we can rely on. Sorry about that! But I'll leave the issue open for a while - I'd love to hear if you continue hacking on this.
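For illustration only (not part of this PR): a full registered XLA custom call with its own CUDA kernel is the approach described above and involves more plumbing, but a lightweight way to call an external rasterizer from JAX is `jax.pure_callback`. It round-trips through the host, so it is not a performant substitute for a custom op; it only sketches the calling convention. The `cuda_rasterize_host` function below is hypothetical:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical host-side wrapper around an external (e.g. CUDA) rasterizer.
# In a real custom-op setup this work would live in a registered XLA custom
# call so the data never leaves the device.
def cuda_rasterize_host(vertices, faces):
    height, width = 540, 960
    return np.zeros((height, width, 3), dtype=np.float32)

def rasterize(vertices, faces):
    out_shape = jax.ShapeDtypeStruct((540, 960, 3), jnp.float32)
    # pure_callback transfers inputs to the host and back, so this is only a
    # sketch of the interface, not a fast path.
    return jax.pure_callback(cuda_rasterize_host, out_shape, vertices, faces)

image = jax.jit(rasterize)(jnp.zeros((8, 3)), jnp.zeros((4, 3), jnp.int32))
```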
Thank you so much for your suggestions! I am also thinking about a custom op, as that would dramatically improve the performance. Regarding current performance: with high-end GPUs, rendering is actually faster than on CPU; the current implementation is not optimised for TPU, so performance there is really bad, both in execution time and memory usage (up to 98% padding...). I will benchmark and improve the performance over a batch of small images (e.g. 84x84, as used in RL for Atari environments) to see whether my implementation could benefit from rendering batches of environments in parallel (although we could definitely benefit if we are using a TPU Pod and pmap/xmap all environment simulation + rendering over devices).
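As a sketch of that batching idea (assuming a hypothetical per-environment `render_env` function in place of the actual jaxrenderer call):

```python
import jax
import jax.numpy as jnp

# Hypothetical per-environment render: one camera position -> one 84x84 image.
# Stands in for a jaxrenderer call so the batching pattern itself runs.
def render_env(camera_pos):
    return jnp.broadcast_to(camera_pos.sum(), (84, 84, 3))

# Render a whole batch of environments in one call on a single device.
render_batch = jax.jit(jax.vmap(render_env))
images = render_batch(jnp.zeros((128, 3)))   # shape (128, 84, 84, 3)

# On a TPU pod or multi-GPU host, the same function could additionally be
# pmapped so each device renders its own shard of environments, e.g.:
# jax.pmap(jax.vmap(render_env))(jnp.zeros((n_devices, 128 // n_devices, 3)))
```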
How about a brax-contrib package, or something like that, where this could be added?
This solves #331, #108, #47, and #67, and is related to #302.
Example Colab (adapts the majority of the code from the Brax team's Brax Training).
Known issue: the plane may not display fully correctly, due to unclipped vertices behind/at the camera plane. This will be solved once a proper clipping algorithm is implemented in JaxRenderer (to be delivered soon). Update: the plane rendering issue is now solved by implementing rasterisation based on homogeneous interpolation.
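For reference, "homogeneous interpolation" here refers to perspective-correct interpolation of per-vertex attributes using the clip-space w. A minimal sketch of that idea (not jaxrenderer's actual implementation) could look like:

```python
import jax.numpy as jnp

def perspective_correct_interp(bary, attrs, w_clip):
    """Perspective-correct interpolation of per-vertex attributes.

    bary:   (3,)    screen-space barycentric coordinates of the pixel
    attrs:  (3, k)  per-vertex attributes (colour, uv, depth, ...)
    w_clip: (3,)    clip-space w of each triangle vertex
    """
    weights = bary / w_clip            # lambda_i / w_i
    weights = weights / weights.sum()  # renormalise so the weights sum to 1
    return weights @ attrs

# Example: interpolate per-vertex RGB colours at one pixel.
colour = perspective_correct_interp(
    jnp.array([0.2, 0.3, 0.5]),   # barycentric coordinates
    jnp.eye(3),                   # red, green, blue at the three vertices
    jnp.array([1.0, 2.0, 4.0]),   # clip-space w per vertex
)
```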
Currently this only changes the v2 API, and the changes are minimal. Feel free to tell me if you think it should be backported to the v1 API as well. I will try to optimise the performance further later.
Update: the Ant environment at 960x540 resolution with 1x SSAA now takes about 500 ms per frame on a T4 (~2 fps), and ~70-100 ms on an A100 (10-14 fps). The previous CPU implementation took about 200 ms per frame (5 fps).