Skip to content

Commit

Permalink
GPU Simulated Scenes, New more organized asset locations (#217)
Browse files Browse the repository at this point in the history
* Fix: some scene manip env issues

* Better fetch cameras

* Fix: disable collisions only on fetch wheels

* Fix: fetch grapsing retval

* Progress: updates to scene-related classes/envs

* Fix: ai2thor scene builder set poses in initialize()

* Fix: ai2thor fetch scene builder ground collisions

* Fix: ai2 sb don't reconfigure by default

* Fix: sb load_actors instead of load_agents

* Trivial: add fetch resting qpos as attr

* Fix: pick_object scene pose setting

* Fix: renderer follows camera

* Fix: options is None

* Fix: fetch body delta pos controller

* Fix: record wrapper options=None by default

* Progress: support for loading navigable positions

* Progress: update base scene builder w/ DexM3 version

* Trivial: add arm camera, more "working" coacd objects

* minor ai2 sb changes

* Trivial: port over everything needed for scene manip testing

* bigger place reward

* Fix: keep x/y/zrot qvel in obs

* Try tune reward

* scale place reward more

* Try lessen place rew

* work

* work

* add bug fix

* Create scene_gpu.py

* Update sequential_task.py

* temp code

* make more objects static

* base empty scene manipulation env

* work

* work

* tests

* work

* more work

* modify benchmark code to support cpu and gpu sim tests

* add lighting?

* work

* articulations

* work

* Update scene_builder.py

* new dataset location

* fix for new dataset paths for some tasks, remove old tasks

* remove scene tasks as they are not done

* work

* work

* fixes

* change some defaults

---------

Co-authored-by: arth-shukla <[email protected]>
  • Loading branch information
StoneT2000 and arth-shukla authored Mar 1, 2024
1 parent 69c95ea commit 1defbae
Show file tree
Hide file tree
Showing 31 changed files with 915 additions and 323 deletions.
13 changes: 7 additions & 6 deletions examples/benchmarking/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@ Code here is used to benchmark the performance of various simulators/benchmarks
To benchmark ManiSkill + SAPIEN, after following the setup instructions on this repository's README.md, run

```
python benchmark_gpu_sim.py -e "PickCube-v1" --num-envs=1024 --obs-mode=state # test just state simulation
python benchmark_gpu_sim.py -e "PickCube-v1" --num-envs=128 --obs-mode=rgbd # test state sim + parallel rendering one 128x128 RGBD cameras per environment
python benchmark_gpu_sim.py -e "PickCube-v1" --num-envs=128 --save-video # save a video showing all 128 visual observations
python benchmark_maniskill.py -e "PickCube-v1" --num-envs=1024 --obs-mode=state # test just state simulation
python benchmark_maniskill.py -e "PickCube-v1" --num-envs=128 --obs-mode=rgbd # test state sim + parallel rendering one 128x128 RGBD cameras per environment
python benchmark_maniskill.py -e "PickCube-v1" --num-envs=128 --save-video # save a video showing all 128 visual observations
```


To get the best reported results, we run two commands on a machine with a RTX 4090:
To get the reported results, we run two commands on a machine with a RTX 4090:
```
python benchmark_gpu_sim.py -e "PickCube-v1" --num-envs=4096 --obs-mode=state
python benchmark_gpu_sim.py -e "PickCube-v1" --num-envs=1536 --obs-mode=rgbd
python benchmark_maniskill.py -e "PickCube-v1" --num-envs=4096 --obs-mode=state --control-freq=50
python benchmark_maniskill.py -e "PickCube-v1" --num-envs=1536 --obs-mode=rgbd --control-freq=50
# note we use --control-freq=50 as this is the control frequency isaac sim based repos tend to use
```

These are the expected state-based only results:
Expand Down
44 changes: 0 additions & 44 deletions examples/benchmarking/benchmark_cpu_sim.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
import tqdm

import mani_skill2.envs
from mani_skill2.envs.scenes.tasks.planner.planner import PickSubtask
from mani_skill2.envs.scenes.tasks.sequential_task import SequentialTaskEnv
from mani_skill2.utils.scene_builder.ai2thor.variants import ArchitecTHORSceneBuilder
from mani_skill2.utils.scene_builder.replicacad.scene_builder import ReplicaCADSceneBuilder
from mani_skill2.vector.wrappers.gymnasium import ManiSkillVectorEnv
from profiling import Profiler
from mani_skill2.utils.visualization.misc import images_to_video, tile_images
Expand All @@ -21,20 +25,30 @@
def main(args):
profiler = Profiler(output_format=args.format)
num_envs = args.num_envs
env = gym.make(
args.env_id,
num_envs=num_envs,
obs_mode=args.obs_mode,
# enable_shadow=True,
render_mode=args.render_mode,
control_mode=args.control_mode,
sim_cfg=dict(control_freq=50)
)
if isinstance(env.action_space, gym.spaces.Dict):
env = FlattenActionSpaceWrapper(env)
env = ManiSkillVectorEnv(env)
sim_cfg = dict()
if args.control_freq:
sim_cfg["control_freq"] = args.control_freq
if args.sim_freq:
sim_cfg["sim_freq"] = args.sim_freq
if not args.cpu_sim:
env = gym.make(
args.env_id,
num_envs=num_envs,
obs_mode=args.obs_mode,
# enable_shadow=True,
render_mode=args.render_mode,
control_mode=args.control_mode,
sim_cfg=sim_cfg
)
if isinstance(env.action_space, gym.spaces.Dict):
env = FlattenActionSpaceWrapper(env)
env = ManiSkillVectorEnv(env)
base_env = env.base_env
else:
env = gym.make_vec(args.env_id, num_envs=args.num_envs, vectorization_mode="async", vector_kwargs=dict(context="spawn"), obs_mode=args.obs_mode,)
base_env = gym.make(args.env_id, obs_mode=args.obs_mode).unwrapped
sensor_settings_str = []
for uid, cam in env.base_env._sensors.items():
for uid, cam in base_env._sensors.items():
cfg = cam.cfg
sensor_settings_str.append(f"{cfg.width}x{cfg.height}")
sensor_settings_str = "_".join(sensor_settings_str)
Expand All @@ -51,10 +65,10 @@ def main(args):
f"render_mode={args.render_mode}, sensor_details={sensor_settings_str}, save_video={args.save_video}"
)
print(
f"sim_freq={env.base_env.sim_freq}, control_freq={env.base_env.control_freq}"
f"sim_freq={base_env.sim_freq}, control_freq={base_env.control_freq}"
)
print(f"observation space: {env.observation_space}")
print(f"action space: {env.base_env.single_action_space}")
print(f"action space: {base_env.single_action_space}")
print(
"# -------------------------------------------------------------------------- #"
)
Expand All @@ -70,7 +84,7 @@ def main(args):
with profiler.profile("env.step", total_steps=N, num_envs=num_envs):
for i in range(N):
actions = (
2 * torch.rand(env.action_space.shape, device=env.base_env.device)
2 * torch.rand(env.action_space.shape, device=base_env.device)
- 1
)
obs, rew, terminated, truncated, info = env.step(actions)
Expand Down Expand Up @@ -126,6 +140,9 @@ def parse_args():
parser.add_argument("-o", "--obs-mode", type=str, default="state")
parser.add_argument("-c", "--control-mode", type=str, default="pd_joint_delta_pos")
parser.add_argument("-n", "--num-envs", type=int, default=1024)
parser.add_argument("--cpu-sim", action="store_true", help="Whether to use the CPU or GPU simulation")
parser.add_argument("--control-freq", type=int, default=None, help="The control frequency to use")
parser.add_argument("--sim-freq", type=int, default=None, help="The simulation frequency to use")
parser.add_argument(
"--render-mode",
type=str,
Expand Down
7 changes: 6 additions & 1 deletion mani_skill2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
PACKAGE_DIR = Path(__file__).parent.resolve()
PACKAGE_ASSET_DIR = PACKAGE_DIR / "assets"
# Non-package data
ASSET_DIR = Path(os.getenv("MS2_ASSET_DIR", "data"))
ASSET_DIR = Path(
os.getenv("MS_ASSET_DIR", os.path.join(os.path.expanduser("~"), ".maniskill/data"))
)
DEMO_DIR = Path(
os.getenv("MS_ASSET_DIR", os.path.join(os.path.expanduser("~"), ".maniskill/demos"))
)


def format_path(p: str):
Expand Down
52 changes: 32 additions & 20 deletions mani_skill2/agents/robots/fetch/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,25 @@ class Fetch(BaseAgent):
q=[1, 0, 0, 0],
width=128,
height=128,
fov=1.57,
fov=2,
near=0.01,
far=10,
entity_uid="head_camera_link",
),
CameraConfig(
uid="fetch_hand",
p=[-0.1, 0, 0.1],
q=[1, 0, 0, 0],
width=128,
height=128,
fov=2,
near=0.01,
far=10,
entity_uid="gripper_link",
)
]
REACHABLE_DIST = 1.5
RESTING_QPOS = np.array([0, 0, 0, 0.386, 0, -0.370, 0.562, -1.032, 0.695, 0.955, -0.1, 2.077, 0, 0.015, 0.015])

def __init__(self, *args, **kwargs):
self.arm_joint_names = [
Expand Down Expand Up @@ -204,14 +216,14 @@ def controller_configs(self):
# -------------------------------------------------------------------------- #
# Body
# -------------------------------------------------------------------------- #
body_pd_joint_pos = PDJointPosControllerConfig(
body_pd_joint_delta_pos = PDJointPosControllerConfig(
self.body_joint_names,
[-1.57, -1.6056, 0],
[1.57, 1.6056, 0.38615],
-0.1,
0.1,
self.body_stiffness,
self.body_damping,
self.body_force_limit,
normalize_action=True,
use_delta=True,
)

# -------------------------------------------------------------------------- #
Expand All @@ -229,69 +241,69 @@ def controller_configs(self):
pd_joint_delta_pos=dict(
arm=arm_pd_joint_delta_pos,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
body=body_pd_joint_delta_pos,
base=base_pd_joint_vel,
),
pd_joint_pos=dict(
arm=arm_pd_joint_pos,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
body=body_pd_joint_delta_pos,
base=base_pd_joint_vel,
),
pd_ee_delta_pos=dict(
arm=arm_pd_ee_delta_pos,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
body=body_pd_joint_delta_pos,
base=base_pd_joint_vel,
),
pd_ee_delta_pose=dict(
arm=arm_pd_ee_delta_pose,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
body=body_pd_joint_delta_pos,
base=base_pd_joint_vel,
),
pd_ee_delta_pose_align=dict(
arm=arm_pd_ee_delta_pose_align,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
body=body_pd_joint_delta_pos,
base=base_pd_joint_vel,
),
# TODO(jigu): how to add boundaries for the following controllers
pd_joint_target_delta_pos=dict(
arm=arm_pd_joint_target_delta_pos,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
body=body_pd_joint_delta_pos,
base=base_pd_joint_vel,
),
pd_ee_target_delta_pos=dict(
arm=arm_pd_ee_target_delta_pos,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
body=body_pd_joint_delta_pos,
base=base_pd_joint_vel,
),
pd_ee_target_delta_pose=dict(
arm=arm_pd_ee_target_delta_pose,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
body=body_pd_joint_delta_pos,
base=base_pd_joint_vel,
),
# Caution to use the following controllers
pd_joint_vel=dict(
arm=arm_pd_joint_vel,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
body=body_pd_joint_delta_pos,
base=base_pd_joint_vel,
),
pd_joint_pos_vel=dict(
arm=arm_pd_joint_pos_vel,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
body=body_pd_joint_delta_pos,
base=base_pd_joint_vel,
),
pd_joint_delta_pos_vel=dict(
arm=arm_pd_joint_delta_pos_vel,
gripper=gripper_pd_joint_pos,
body=body_pd_joint_pos,
body=body_pd_joint_delta_pos,
base=base_pd_joint_vel,
),
)
Expand All @@ -315,7 +327,7 @@ def _after_init(self):
self.r_wheel_link: Link = get_obj_by_name(
self.robot.get_links(), "r_wheel_link"
)
for link in [self.base_link, self.l_wheel_link, self.r_wheel_link]:
for link in [self.l_wheel_link, self.r_wheel_link]:
for body in link._bodies:
cs = body.get_collision_shapes()[0]
cg = cs.get_collision_groups()
Expand All @@ -326,8 +338,8 @@ def _after_init(self):
self.robot.get_links(), "torso_lift_link"
)

self.torso_lift_link: Link = get_obj_by_name(
self.robot.get_links(), "torso_lift_link"
self.head_camera_link: Link = get_obj_by_name(
self.robot.get_links(), "head_camera_link"
)

self.queries: Dict[
Expand Down Expand Up @@ -419,7 +431,7 @@ def is_grasping(self, object: Actor = None, min_impulse=1e-6, max_angle=85):
and np.rad2deg(rangle) <= max_angle
)

return all([lflag, rflag])
return torch.tensor([all([lflag, rflag])], dtype=bool)

def is_static(self, threshold: float = 0.2):
qvel = self.robot.get_qvel()[..., :-2]
Expand Down
4 changes: 2 additions & 2 deletions mani_skill2/assets/robots/fetch/fetch.urdf
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@
<parent link="root"/>
<child link="root_arm_1_link_1"/>
<axis xyz="1 0 0"/>
<limit effort="100" lower="-10" upper="10" velocity="0.50"/>
<limit effort="100" lower="-20" upper="20" velocity="0.50"/>
<dynamics damping="10.0" friction="10"/>
</joint>
<joint name="root_y_axis_joint" type="prismatic">
<parent link="root_arm_1_link_1"/>
<child link="root_arm_1_link_2"/>
<axis xyz="0 1 0"/>
<limit effort="100" lower="-10" upper="10" velocity="0.50"/>
<limit effort="100" lower="-20" upper="20" velocity="0.50"/>
<dynamics damping="10.0" friction="10"/>
</joint>
<joint name="root_z_rotation_joint" type="continuous">
Expand Down
2 changes: 1 addition & 1 deletion mani_skill2/envs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# from .ms1 import *
from .ms2 import *
from .scenes.pick_object import *
from .scenes import *
from .tasks import *
4 changes: 4 additions & 0 deletions mani_skill2/envs/sapien_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,10 @@ def render_human(self):
if self._viewer is None:
self._viewer = Viewer()
self._setup_viewer()
if "render_camera" in self._human_render_cameras:
self._viewer.set_camera_pose(
self._human_render_cameras["render_camera"].camera.global_pose
)

for obj in self._hidden_objects:
obj.show_visual()
Expand Down
28 changes: 28 additions & 0 deletions mani_skill2/envs/scenes/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from mani_skill2.utils.registration import register_env
from mani_skill2.utils.scene_builder.ai2thor.variants import (
ArchitecTHORSceneBuilder,
ProcTHORSceneBuilder,
RoboTHORSceneBuilder,
iTHORSceneBuilder,
)
from mani_skill2.utils.scene_builder.replicacad.scene_builder import (
ReplicaCADSceneBuilder,
)

from .base_env import SceneManipulationEnv

scene_builders = {
"ReplicaCAD": ReplicaCADSceneBuilder,
"ArchitecTHOR": ArchitecTHORSceneBuilder,
"ProcTHOR": ProcTHORSceneBuilder,
"RoboTHOR": RoboTHORSceneBuilder,
"iTHOR": iTHORSceneBuilder,
}

# Register environments just for benchmarking/exploration and to be creatable by just ID, these don't have any specific tasks designed in them.
for k, scene_builder in scene_builders.items():
register_env(
f"{k}_SceneManipulation-v1",
max_episode_steps=None,
scene_builder_cls=scene_builder,
)(SceneManipulationEnv)
Loading

0 comments on commit 1defbae

Please sign in to comment.