Skip to content

Commit

Permalink
fix: fix env_id reset for in-hand-rotating env
Browse files Browse the repository at this point in the history
  • Loading branch information
yzqin committed Feb 29, 2024
1 parent cbc0925 commit f562d09
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions mani_skill2/envs/tasks/dexterity/rotate_single_object_in_hand.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,11 @@ def _load_actors(self):

def _initialize_actors(self, env_idx: torch.Tensor):
with torch.device(self.device):
b = len(env_idx)
# Initialize object pose
self.table_scene.initialize()
pose = self.obj.pose
new_pos = torch.randn((self.num_envs, 3)) * self.obj_init_pos_noise
new_pos = torch.randn((b, 3)) * self.obj_init_pos_noise
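# z = |sampled noise| + hand_init_height (robot hand position) + obj_heights, so the object
# never starts below the hand height plus its own height (the 0.03 margin mentioned in the original comment is not visible in this expression; TODO confirm where it is applied)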
new_pos[:, 2] = (
torch.abs(new_pos[:, 2]) + self.hand_init_height + self.obj_heights
Expand All @@ -156,9 +157,9 @@ def _initialize_actors(self, env_idx: torch.Tensor):

# Initialize object axis
if self.difficulty_level <= 2:
axis = torch.ones((self.num_envs,), dtype=torch.long) * 2
axis = torch.ones((b,), dtype=torch.long) * 2
else:
axis = torch.randint(0, 3, (self.num_envs,), dtype=torch.long)
axis = torch.randint(0, 3, (b,), dtype=torch.long)
self.rot_dir = F.one_hot(axis, num_classes=3)

# Sample a unit vector on the tangent plane of rotating axis
Expand All @@ -169,24 +170,25 @@ def _initialize_actors(self, env_idx: torch.Tensor):
self.unit_vector = vector
self.prev_unit_vector = vector.clone()
self.success_threshold = torch.pi * 4
self.cum_rotation_angle = torch.zeros((self.num_envs,))
self.cum_rotation_angle = torch.zeros((b,))

# Controller parameters
stiffness = torch.tensor(self.agent.controller.config.stiffness)
damping = torch.tensor(self.agent.controller.config.damping)
force_limit = torch.tensor(self.agent.controller.config.force_limit)
self.controller_param = (
stiffness.expand(self.num_envs, self.agent.robot.dof[0]),
damping.expand(self.num_envs, self.agent.robot.dof[0]),
force_limit.expand(self.num_envs, self.agent.robot.dof[0]),
stiffness.expand(b, self.agent.robot.dof[0]),
damping.expand(b, self.agent.robot.dof[0]),
force_limit.expand(b, self.agent.robot.dof[0]),
)

def _initialize_agent(self, env_idx: torch.Tensor):
with torch.device(self.device):
b = len(env_idx)
dof = self.agent.robot.dof
if isinstance(dof, torch.Tensor):
dof = dof[0]
init_qpos = torch.zeros((self.num_envs, dof))
init_qpos = torch.zeros((b, dof))
self.agent.reset(init_qpos)
self.agent.robot.set_pose(
Pose.create_from_pq(
Expand Down Expand Up @@ -259,6 +261,7 @@ def evaluate(self, **kwargs) -> dict:
success=success,
qf=qf,
power=power,
fail=obj_fall,
)

def compute_dense_reward(self, obs: Any, action: Array, info: Dict):
Expand Down

0 comments on commit f562d09

Please sign in to comment.