"""
Environment wrapper for Robomimic environments with state observations.
Also return done=False since we do not terminate episode early.
Modified from https://github.com/real-stanford/diffusion_policy/blob/main/diffusion_policy/env/robomimic/robomimic_lowdim_wrapper.py
For consistency, we will use Dict{} for the observation space, with the key "state" for the state observation.
"""

import numpy as np
import gym
from gym import spaces
import imageio


class RobomimicLowdimWrapper(gym.Env):

    def __init__(
        self,
        env,
        normalization_path=None,
        low_dim_keys=[
            "robot0_eef_pos",
            "robot0_eef_quat",
            "robot0_gripper_qpos",
            "object",
        ],
        clamp_obs=False,
        init_state=None,
        render_hw=(256, 256),
        render_camera_name="agentview",
    ):
        self.env = env
        self.init_state = init_state
        self.render_hw = render_hw
        self.render_camera_name = render_camera_name
        self.video_writer = None
        self.clamp_obs = clamp_obs

        # set up normalization
        self.normalize = normalization_path is not None
        if self.normalize:
            normalization = np.load(normalization_path)
            self.obs_min = normalization["obs_min"]
            self.obs_max = normalization["obs_max"]
            self.action_min = normalization["action_min"]
            self.action_max = normalization["action_max"]

        # set up spaces
        low = np.full(env.action_dimension, fill_value=-1)
        high = np.full(env.action_dimension, fill_value=1)
        self.action_space = gym.spaces.Box(
            low=low,
            high=high,
            shape=low.shape,
            dtype=low.dtype,
        )
        self.obs_keys = low_dim_keys
        self.observation_space = spaces.Dict()
        obs_example_full = self.env.get_observation()
        obs_example = np.concatenate(
            [obs_example_full[key] for key in self.obs_keys], axis=0
        )
        low = np.full_like(obs_example, fill_value=-1)
        high = np.full_like(obs_example, fill_value=1)
        self.observation_space["state"] = spaces.Box(
            low=low,
            high=high,
            shape=low.shape,
            dtype=np.float32,
        )

    def normalize_obs(self, obs):
        obs = 2 * (
            (obs - self.obs_min) / (self.obs_max - self.obs_min + 1e-6) - 0.5
        )  # -> [-1, 1]
        if self.clamp_obs:
            obs = np.clip(obs, -1, 1)
        return obs

    def unnormalize_action(self, action):
        action = (action + 1) / 2  # [-1, 1] -> [0, 1]
        return action * (self.action_max - self.action_min) + self.action_min

    def get_observation(self, raw_obs):
        obs = {"state": np.concatenate([raw_obs[key] for key in self.obs_keys], axis=0)}
        if self.normalize:
            obs["state"] = self.normalize_obs(obs["state"])
        return obs
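
    # Note: seeding below affects only numpy's global RNG; the wrapped robomimic
    # env is assumed to draw its reset randomness from it (an assumption, not
    # something this wrapper enforces).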
    def seed(self, seed=None):
        if seed is not None:
            np.random.seed(seed=seed)
        else:
            np.random.seed()

    def reset(self, options={}, **kwargs):
        """Reset the environment.

        Seeding and video recording are controlled through `options` (keys "seed"
        and "video_path"); other passed-in arguments, such as a gym-style `seed`
        kwarg, are ignored.
        """
        # Close the previous video writer if one exists
        if self.video_writer is not None:
            self.video_writer.close()
            self.video_writer = None

        # Start a new video if specified
        if "video_path" in options:
            self.video_writer = imageio.get_writer(options["video_path"], fps=30)

        # Call reset
        new_seed = options.get(
            "seed", None
        )  # used to set all environments to specified seeds
        if self.init_state is not None:
            # always reset to the same state to be compatible with gym
            raw_obs = self.env.reset_to({"states": self.init_state})
        elif new_seed is not None:
            self.seed(seed=new_seed)
            raw_obs = self.env.reset()
        else:
            # random reset
            raw_obs = self.env.reset()
        return self.get_observation(raw_obs)

    def step(self, action):
        if self.normalize:
            action = self.unnormalize_action(action)
        raw_obs, reward, done, info = self.env.step(action)
        obs = self.get_observation(raw_obs)

        # render if specified
        if self.video_writer is not None:
            video_img = self.render(mode="rgb_array")
            self.video_writer.append_data(video_img)
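        # Always report done=False: per the module docstring, this wrapper never
        # terminates episodes early.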
        return obs, reward, False, info

    def render(self, mode="rgb_array"):
        h, w = self.render_hw
        return self.env.render(
            mode=mode,
            height=h,
            width=w,
            camera_name=self.render_camera_name,
        )
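

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original wrapper).
# It assumes robomimic is installed and that a low-dim dataset exists at the
# placeholder DATASET_PATH below; the path and the absence of a normalization
# file are assumptions, and the underlying env is built with robomimic's
# FileUtils / EnvUtils helpers.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import robomimic.utils.env_utils as EnvUtils
    import robomimic.utils.file_utils as FileUtils

    DATASET_PATH = "data/robomimic/lift/low_dim.hdf5"  # placeholder path

    # Recover the env metadata stored in the dataset and build the underlying env.
    env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=DATASET_PATH)
    base_env = EnvUtils.create_env_from_metadata(
        env_meta=env_meta,
        render=False,
        render_offscreen=False,  # set True if recording video via options["video_path"]
        use_image_obs=False,
    )

    env = RobomimicLowdimWrapper(base_env)
    obs = env.reset(options={"seed": 0})
    print("state shape:", obs["state"].shape)

    # Roll out a few uniform random actions; done is always False by design.
    for _ in range(10):
        action = np.random.uniform(-1, 1, size=env.action_space.shape)
        obs, reward, done, info = env.step(action)
        print("reward:", reward, "done:", done)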