From 3e217ea13a7bfb74d61d55301df4521c6735185b Mon Sep 17 00:00:00 2001 From: Joey Teng Date: Thu, 1 Jun 2023 15:02:31 +0100 Subject: [PATCH 1/5] feat(io/image): move to JaxRenderer backend https://github.com/JoeyTeng/jaxrenderer --- brax/io/image.py | 125 ++++++++++++++++++++++++++--------------------- setup.py | 1 + 2 files changed, 70 insertions(+), 56 deletions(-) diff --git a/brax/io/image.py b/brax/io/image.py index d099012e..ad6d91ad 100644 --- a/brax/io/image.py +++ b/brax/io/image.py @@ -24,38 +24,29 @@ from jax import numpy as jp import numpy as onp from PIL import Image -from pytinyrenderer import TinyRenderCamera as Camera -from pytinyrenderer import TinyRenderLight as Light -from pytinyrenderer import TinySceneRenderer as Renderer +from renderer import CameraParameters as Camera +from renderer import LightParameters as Light +from renderer import Model as RendererMesh +from renderer import ShadowParameters as Shadow +from renderer import Renderer, Scene, UpAxis, transpose_for_display import trimesh -class TextureRGB888: +def grid(grid_size, color): + grid = onp.zeros((grid_size, grid_size, 3), dtype=onp.single) + grid[:, :] = onp.array(color) / 255.0 + grid[0] = onp.zeros((grid_size, 3), dtype=onp.single) + # to reverse texture along y direction + grid[:, -1] = onp.zeros((grid_size, 3), dtype=onp.single) + return jp.asarray(grid) - def __init__(self, pixels): - self.pixels = pixels - self.width = int(onp.sqrt(len(pixels) / 3)) - self.height = int(onp.sqrt(len(pixels) / 3)) +_GROUND = grid(100, [200, 200, 200]) -class Grid(TextureRGB888): - def __init__(self, grid_size, color): - grid = onp.zeros((grid_size, grid_size, 3), dtype=onp.int32) - grid[:, :] = onp.array(color) - grid[0] = onp.zeros((grid_size, 3), dtype=onp.int32) - grid[:, 0] = onp.zeros((grid_size, 3), dtype=onp.int32) - super().__init__(list(grid.ravel())) - - -_BASIC = TextureRGB888([133, 118, 102]) -_TARGET = TextureRGB888([255, 34, 34]) -_GROUND = Grid(100, [200, 200, 200]) - - -def _scene(sys: brax.System, state: brax.State) -> Tuple[Renderer, List[int]]: - """Converts a brax System and state to a pytinyrenderer scene and instances.""" - scene = Renderer() +def _scene(sys: brax.System, state: brax.State) -> Tuple[Scene, List[int]]: + """Converts a brax System and state to a jaxrenderer scene and instances.""" + scene = Scene() instances = [] def take_i(obj, i): @@ -72,50 +63,60 @@ def take_i(obj, i): for _, geom in link_geoms.items(): for col in geom: - tex = TextureRGB888((col.rgba[:3] * 255).astype('uint32')) + tex = col.rgba[:3].reshape((1, 1, 3)) if isinstance(col, base.Capsule): half_height = col.length / 2 - model = scene.create_capsule(col.radius, half_height, 2, - tex.pixels, tex.width, tex.height) + scene, model = scene.add_capsule( + radius=col.radius, + half_height=half_height, + up_axis=UpAxis.Z, + diffuse_map=tex, + ) elif isinstance(col, base.Box): - model = scene.create_cube(col.halfsize, tex.pixels, tex.width, - tex.height, 16.) + scene, model = scene.add_cube( + half_extents=col.halfsize, + diffuse_map=tex, + texture_scaling=16., + ) elif isinstance(col, base.Sphere): - model = scene.create_capsule( - col.radius, 0, 2, tex.pixels, tex.width, tex.height + scene, model = scene.add_capsule( + radius=col.radius, + half_height=0., + up_axis=UpAxis.Z, + diffuse_map=tex, ) elif isinstance(col, base.Plane): tex = _GROUND - model = scene.create_cube([1000.0, 1000.0, 0.0001], tex.pixels, - tex.width, tex.height, 8192) + scene, model = scene.add_cube( + half_extents=[1000.0, 1000.0, 0.0001], + diffuse_map=tex, + texture_scaling=8192., + ) elif isinstance(col, base.Convex): # convex objects are not visual continue elif isinstance(col, base.Mesh): tm = trimesh.Trimesh(vertices=col.vert, faces=col.face) - vert_norm = tm.vertex_normals - model = scene.create_mesh( - col.vert.reshape((-1)).tolist(), - vert_norm.reshape((-1)).tolist(), - [0] * col.vert.shape[0] * 2, - col.face.reshape((-1)).tolist(), - tex.pixels, - tex.width, - tex.height, - 1.0, + mesh = RendererMesh.create( + verts=tm.vertices, + norms=tm.vertex_normals, + uvs=jp.zeros((tm.vertices.shape[0], 2), dtype=int), + faces=tm.faces, + diffuse_map=tex, ) + scene, model = scene.add_model(mesh) else: raise RuntimeError(f'unrecognized collider: {type(col)}') i = col.link_idx if col.link_idx is not None else -1 x = state.x.concatenate(base.Transform.zero((1,))) - instance = scene.create_object_instance(model) + scene, instance = scene.add_object_instance(model) off = col.transform.pos - pos = onp.array(x.pos[i]) + math.rotate(off, x.rot[i]) + pos = x.pos[i] + math.rotate(off, x.rot[i]) rot = col.transform.rot rot = math.quat_mul(x.rot[i], rot) - scene.set_object_position(instance, list(pos)) - scene.set_object_orientation(instance, [rot[1], rot[2], rot[3], rot[0]]) + scene = scene.set_object_position(instance, pos) + scene = scene.set_object_orientation(instance, rot) instances.append(instance) return scene, instances @@ -132,7 +133,7 @@ def _eye(sys: brax.System, state: brax.State) -> List[float]: def _up(unused_sys: brax.System) -> List[float]: """Determines the up orientation of the camera.""" - return [0, 0, 1] + return [0., 0., 1.] def get_camera( @@ -160,7 +161,9 @@ def render_array(sys: brax.System, height: int, light: Optional[Light] = None, camera: Optional[Camera] = None, - ssaa: int = 2) -> onp.ndarray: + ssaa: int = 2, + shadow: Optional[Shadow] = None, + enable_shadow: bool = True) -> onp.ndarray: """Renders an RGB array of a brax system and QP.""" if (len(state.x.pos.shape), len(state.x.rot.shape)) != (2, 2): raise RuntimeError('unexpected shape in state') @@ -173,7 +176,7 @@ def render_array(sys: brax.System, ambient=0.8, diffuse=0.8, specular=0.6, - shadowmap_center=target) + ) if camera is None: eye, up = _eye(sys, state), _up(sys) hfov = 58.0 @@ -186,12 +189,22 @@ def render_array(sys: brax.System, up=up, hfov=hfov, vfov=vfov) - img = scene.get_camera_image(instances, light, camera).rgb - arr = onp.reshape( - onp.array(img, dtype=onp.uint8), - (camera.view_height, camera.view_width, -1)) + if shadow is None and enable_shadow: + shadow = Shadow(centre=camera.target) + objects = [scene.objects[inst] for inst in instances] + img = Renderer.get_camera_image( + objects=objects, + light=light, + camera=camera, + width=camera.viewWidth, + height=camera.viewHeight, + shadow_param=shadow, + ) + arr = transpose_for_display(jax.lax.clamp(0., img * 255, 255.).astype(jp.uint8)) if ssaa > 1: - arr = onp.asarray(Image.fromarray(arr).resize((width, height))) + arr = onp.asarray(Image.fromarray(onp.asarray(arr)).resize((width, height))) + else: + arr = onp.asarray(arr) return arr diff --git a/setup.py b/setup.py index 4f6f0fec..7b87051e 100644 --- a/setup.py +++ b/setup.py @@ -50,6 +50,7 @@ "jax>=0.4.6", "jaxlib>=0.4.6", "jaxopt", + "jaxrenderer>=0.1.3", "jinja2", "mujoco", "numpy", From 24fc22fad69f380706d444b6638221195a4a7e6f Mon Sep 17 00:00:00 2001 From: Joey Teng Date: Sat, 3 Jun 2023 15:33:40 +0100 Subject: [PATCH 2/5] build: bump jaxrenderer to >= 0.2.0 for correct rendering --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 7b87051e..bd5b05fd 100644 --- a/setup.py +++ b/setup.py @@ -50,7 +50,7 @@ "jax>=0.4.6", "jaxlib>=0.4.6", "jaxopt", - "jaxrenderer>=0.1.3", + "jaxrenderer>=0.2.0", "jinja2", "mujoco", "numpy", From b77de828c16b3cd2a49b5dd849ea3ddf1ccf10a1 Mon Sep 17 00:00:00 2001 From: Joey Teng Date: Sat, 3 Jun 2023 16:16:08 +0100 Subject: [PATCH 3/5] set default precision for dot to be highest --- brax/io/image.py | 1 + 1 file changed, 1 insertion(+) diff --git a/brax/io/image.py b/brax/io/image.py index ad6d91ad..98ceec3d 100644 --- a/brax/io/image.py +++ b/brax/io/image.py @@ -155,6 +155,7 @@ def get_camera( return camera +@jax.default_matmul_precision("float32") def render_array(sys: brax.System, state: brax.State, width: int, From 8a111d97004de74f40fcd8fbd41a02915043e89c Mon Sep 17 00:00:00 2001 From: Joey Teng Date: Sun, 4 Jun 2023 09:30:18 +0100 Subject: [PATCH 4/5] refactor(io/image): refactor to enable jit for rendering --- brax/io/image.py | 226 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 159 insertions(+), 67 deletions(-) diff --git a/brax/io/image.py b/brax/io/image.py index 98ceec3d..3f40710c 100644 --- a/brax/io/image.py +++ b/brax/io/image.py @@ -15,7 +15,7 @@ """Exports a system config and state as an image.""" import io -from typing import List, Optional, Tuple +from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Sequence import brax from brax import base @@ -27,8 +27,9 @@ from renderer import CameraParameters as Camera from renderer import LightParameters as Light from renderer import Model as RendererMesh +from renderer import ModelObject as Instance from renderer import ShadowParameters as Shadow -from renderer import Renderer, Scene, UpAxis, transpose_for_display +from renderer import Renderer, UpAxis, create_capsule, create_cube, transpose_for_display import trimesh @@ -43,18 +44,33 @@ def grid(grid_size, color): _GROUND = grid(100, [200, 200, 200]) +class Obj(NamedTuple): + """An object to be rendered in the scene. -def _scene(sys: brax.System, state: brax.State) -> Tuple[Scene, List[int]]: - """Converts a brax System and state to a jaxrenderer scene and instances.""" - scene = Scene() - instances = [] + Assume the system is unchanged throughout the rendering. + + col is accessed from the batched geoms `sys.geoms`, representing one geom. + """ + instance: Instance + """An instance to be rendered in the scene, defined by jaxrenderer.""" + link_idx: int + """col.link_idx if col.link_idx is not None else -1""" + off: jp.ndarray + """col.transform.rot""" + rot: jp.ndarray + """col.transform.rot""" + +def _build_objects(sys: brax.System) -> List[Obj]: + """Converts a brax System to a list of Obj.""" + objs: List[Obj] = [] def take_i(obj, i): return jax.tree_map(lambda x: jp.take(x, i, axis=0), obj) + link_names: List[str] link_names = [n or f'link {i}' for i, n in enumerate(sys.link_names)] link_names += ['world'] - link_geoms = {} + link_geoms: Dict[str, List[Any]] = {} for batch in sys.geoms: num_geoms = len(batch.friction) for i in range(num_geoms): @@ -64,62 +80,87 @@ def take_i(obj, i): for _, geom in link_geoms.items(): for col in geom: tex = col.rgba[:3].reshape((1, 1, 3)) + # reference: https://github.com/erwincoumans/tinyrenderer/blob/89e8adafb35ecf5134e7b17b71b0f825939dc6d9/model.cpp#L215 + specular_map = jax.lax.full(tex.shape[:2], 2.0) + if isinstance(col, base.Capsule): half_height = col.length / 2 - scene, model = scene.add_capsule( + model = create_capsule( radius=col.radius, half_height=half_height, up_axis=UpAxis.Z, diffuse_map=tex, + specular_map=specular_map, ) elif isinstance(col, base.Box): - scene, model = scene.add_cube( + model = create_cube( half_extents=col.halfsize, diffuse_map=tex, - texture_scaling=16., + texture_scaling=jp.array(16.), + specular_map=specular_map, ) elif isinstance(col, base.Sphere): - scene, model = scene.add_capsule( + model = create_capsule( radius=col.radius, - half_height=0., + half_height=jp.array(0.), up_axis=UpAxis.Z, diffuse_map=tex, + specular_map=specular_map, ) elif isinstance(col, base.Plane): tex = _GROUND - scene, model = scene.add_cube( - half_extents=[1000.0, 1000.0, 0.0001], + model = create_cube( + half_extents=jp.array([1000.0, 1000.0, 0.0001]), diffuse_map=tex, - texture_scaling=8192., + texture_scaling=jp.array(8192.), + specular_map=specular_map, ) elif isinstance(col, base.Convex): # convex objects are not visual continue elif isinstance(col, base.Mesh): tm = trimesh.Trimesh(vertices=col.vert, faces=col.face) - mesh = RendererMesh.create( + model = RendererMesh.create( verts=tm.vertices, norms=tm.vertex_normals, uvs=jp.zeros((tm.vertices.shape[0], 2), dtype=int), faces=tm.faces, diffuse_map=tex, ) - scene, model = scene.add_model(mesh) else: raise RuntimeError(f'unrecognized collider: {type(col)}') - i = col.link_idx if col.link_idx is not None else -1 - x = state.x.concatenate(base.Transform.zero((1,))) - scene, instance = scene.add_object_instance(model) + i: int = col.link_idx if col.link_idx is not None else -1 + instance = Instance(model=model) off = col.transform.pos - pos = x.pos[i] + math.rotate(off, x.rot[i]) rot = col.transform.rot - rot = math.quat_mul(x.rot[i], rot) - scene = scene.set_object_position(instance, pos) - scene = scene.set_object_orientation(instance, rot) - instances.append(instance) + obj = Obj(instance=instance, link_idx=i, off=off, rot=rot) + + objs.append(obj) + + return objs + +def _with_state(objs: Iterable[Obj], x: brax.Transform) -> List[Instance]: + """x must has at least 1 element. This can be ensured by calling + `x.concatenate(base.Transform.zero((1,)))`. x is `state.x`. - return scene, instances + This function does not modify any inputs, rather, it produces a new list of + `Instance`s. + """ + if (len(x.pos.shape), len(x.rot.shape)) != (2, 2): + raise RuntimeError('unexpected shape in state') + + instances: List[Instance] = [] + for obj in objs: + i = obj.link_idx + pos = x.pos[i] + math.rotate(obj.off, x.rot[i]) + rot = math.quat_mul(x.rot[i], obj.rot) + instance = obj.instance + instance = instance.replace_with_position(pos) + instance = instance.replace_with_orientation(rot) + instances.append(instance) + + return instances def _eye(sys: brax.System, state: brax.State) -> List[float]: @@ -136,6 +177,10 @@ def _up(unused_sys: brax.System) -> List[float]: return [0., 0., 1.] +def get_target(state: brax.State) -> jp.ndarray: + """Gets target of camera.""" + return jp.array([state.x.pos[0, 0], state.x.pos[0, 1], 0]) + def get_camera( sys: brax.System, state: brax.State, width: int, height: int, ssaa: int = 2 ) -> Camera: @@ -143,7 +188,7 @@ def get_camera( eye, up = _eye(sys, state), _up(sys) hfov = 58.0 vfov = hfov * height / width - target = [state.x.pos[0, 0], state.x.pos[0, 1], 0] + target = get_target(state) camera = Camera( viewWidth=width * ssaa, viewHeight=height * ssaa, @@ -156,52 +201,75 @@ def get_camera( @jax.default_matmul_precision("float32") -def render_array(sys: brax.System, - state: brax.State, - width: int, - height: int, - light: Optional[Light] = None, - camera: Optional[Camera] = None, - ssaa: int = 2, - shadow: Optional[Shadow] = None, - enable_shadow: bool = True) -> onp.ndarray: - """Renders an RGB array of a brax system and QP.""" - if (len(state.x.pos.shape), len(state.x.rot.shape)) != (2, 2): - raise RuntimeError('unexpected shape in state') - scene, instances = _scene(sys, state) - target = state.x.pos[0, :] +def render_instances( + instances: Sequence[Instance], + width: int, + height: int, + camera: Camera, + light: Optional[Light] = None, + shadow: Optional[Shadow] = None, + camera_target: Optional[jp.ndarray] = None, + enable_shadow: bool = True, +) -> jp.ndarray: + """Renders an RGB array of sequence of instances. + + Rendered result is not transposed with `transpose_for_display`; it is in + floating numbers in [0, 1], not `uint8` in [0, 255]. + """ if light is None: - direction = [0.57735, -0.57735, 0.57735] + direction = jp.array([0.57735, -0.57735, 0.57735]) light = Light( direction=direction, ambient=0.8, diffuse=0.8, specular=0.6, ) - if camera is None: - eye, up = _eye(sys, state), _up(sys) - hfov = 58.0 - vfov = hfov * height / width - camera = Camera( - viewWidth=width * ssaa, - viewHeight=height * ssaa, - position=eye, - target=target, - up=up, - hfov=hfov, - vfov=vfov) if shadow is None and enable_shadow: - shadow = Shadow(centre=camera.target) - objects = [scene.objects[inst] for inst in instances] + assert camera_target is not None, 'camera_target is None' + shadow = Shadow(centre=camera_target) + elif not enable_shadow: + shadow = None + img = Renderer.get_camera_image( - objects=objects, + objects=instances, light=light, camera=camera, - width=camera.viewWidth, - height=camera.viewHeight, + width=width, + height=height, shadow_param=shadow, ) - arr = transpose_for_display(jax.lax.clamp(0., img * 255, 255.).astype(jp.uint8)) + arr = jax.lax.clamp(0., img, 1.) + + return arr + +@jax.default_matmul_precision("float32") +def render_array(sys: brax.System, + state: brax.State, + width: int, + height: int, + light: Optional[Light] = None, + camera: Optional[Camera] = None, + ssaa: int = 2, + shadow: Optional[Shadow] = None, + enable_shadow: bool = True) -> onp.ndarray: + """Renders an RGB array of a brax system and QP.""" + objs = _build_objects(sys) + instances = _with_state(objs, state.x.concatenate(base.Transform.zero((1,)))) + + if camera is None: + camera = get_camera(sys, state, width, height, ssaa) + + img = render_instances( + instances=instances, + width=width * ssaa, + height=height * ssaa, + camera=camera, + light=light, + shadow=shadow, + camera_target=get_target(state), + enable_shadow=enable_shadow, + ) + arr = transpose_for_display((img * 255).astype(jp.uint8)) if ssaa > 1: arr = onp.asarray(Image.fromarray(onp.asarray(arr)).resize((width, height))) else: @@ -216,18 +284,42 @@ def render(sys: brax.System, light: Optional[Light] = None, cameras: Optional[List[Camera]] = None, ssaa: int = 2, - fmt='png') -> bytes: + fmt='png', + shadow: Optional[Shadow] = None, + enable_shadow: bool = True) -> bytes: """Returns an image of a brax system and QP.""" if not states: raise RuntimeError('must have at least one qp') if cameras is None: - cameras = [None] * len(states) + cameras = [get_camera(sys, state, width, height, ssaa) for state in states] + + objs = _build_objects(sys) + _render = jax.jit( + render_instances, + static_argnames=("width", "height", "enable_shadow"), + ) + frames: List[Image.Image] = [] + for state, camera in zip(states, cameras): + x = state.x.concatenate(base.Transform.zero((1,))) + instances = _with_state(objs, x) + target = state.x.pos[0, :] + img = _render( + instances=instances, + width=width * ssaa, + height=height * ssaa, + camera=camera, + light=light, + shadow=shadow, + camera_target=get_target(state), + enable_shadow=enable_shadow, + ) + arr = transpose_for_display((img * 255).astype(jp.uint8)) + frame = Image.fromarray(onp.asarray(arr)) + if ssaa > 1: + frame = frame.resize((width, height)) + + frames.append(frame) - frames = [ - Image.fromarray( - render_array(sys, state, width, height, light, camera, ssaa)) - for state, camera in zip(states, cameras) - ] f = io.BytesIO() if len(frames) == 1: frames[0].save(f, format=fmt) From 828848a8aa46d6c166c87e170fcc708a0218c410 Mon Sep 17 00:00:00 2001 From: Joey Teng Date: Mon, 5 Jun 2023 09:08:31 +0100 Subject: [PATCH 5/5] Refactor(io/image): submit render jobs first then copy back in parallel Simple benchmark suggests no improvment though, see https://colab.research.google.com/drive/1gBIevFjnRrEpo2uU9blZ5qu6KIzDWTl7 --- brax/io/image.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/brax/io/image.py b/brax/io/image.py index 3f40710c..0a9cfa62 100644 --- a/brax/io/image.py +++ b/brax/io/image.py @@ -242,7 +242,7 @@ def render_instances( return arr -@jax.default_matmul_precision("float32") + def render_array(sys: brax.System, state: brax.State, width: int, @@ -298,11 +298,10 @@ def render(sys: brax.System, render_instances, static_argnames=("width", "height", "enable_shadow"), ) - frames: List[Image.Image] = [] + images: List[jp.ndarray] = [] for state, camera in zip(states, cameras): x = state.x.concatenate(base.Transform.zero((1,))) instances = _with_state(objs, x) - target = state.x.pos[0, :] img = _render( instances=instances, width=width * ssaa, @@ -314,11 +313,14 @@ def render(sys: brax.System, enable_shadow=enable_shadow, ) arr = transpose_for_display((img * 255).astype(jp.uint8)) - frame = Image.fromarray(onp.asarray(arr)) - if ssaa > 1: - frame = frame.resize((width, height)) - - frames.append(frame) + images.append(arr) + + images_in_device: List[jp.ndarray] = jax.device_get(images) + np_arrays: Iterable[onp.ndarray] = map(onp.asarray, images_in_device) + frames: List[Image.Image] = [ + Image.fromarray(arr).resize((width, height)) + if ssaa > 1 else Image.fromarray(arr) + for arr in np_arrays] f = io.BytesIO() if len(frames) == 1: