-
Notifications
You must be signed in to change notification settings - Fork 0
/
preprocess.py
66 lines (50 loc) · 1.97 KB
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import gymnasium as gym
import numpy as np
import torch
import torchvision.transforms as T
class SkipFrame(gym.Wrapper):
    """Repeat each action for ``skip`` inner steps and sum the rewards.

    The loop stops early as soon as the inner environment reports
    ``terminated`` or ``truncated``. The returned observation, flags and
    ``info`` are those of the last inner step that was actually taken.

    Args:
        env: Environment to wrap.
        skip: Number of inner ``step`` calls per outer ``step``; must be >= 1.

    Raises:
        ValueError: If ``skip`` is smaller than 1.
    """

    def __init__(self, env, skip):
        super().__init__(env)
        # With skip < 1 the loop in ``step`` would never run and the return
        # statement would hit unbound locals; reject it up front instead.
        if skip < 1:
            raise ValueError(f"skip must be a positive integer, got {skip!r}")
        self._skip = skip

    def step(self, action):
        """Apply ``action`` up to ``skip`` times; return the usual 5-tuple."""
        total_reward = 0.0
        for _ in range(self._skip):
            obs, reward, done, trunk, info = self.env.step(action)
            total_reward += reward
            # Do not keep stepping a finished episode.
            if done or trunk:
                break
        return obs, total_reward, done, trunk, info
class GrayScaleObservation(gym.ObservationWrapper):
    """Convert channels-last colour frames into single-channel float tensors.

    Each incoming NumPy frame is assumed to be laid out as (H, W, C)
    (TODO confirm against the wrapped env). It is transposed to (C, H, W),
    converted to a ``torch.float`` tensor and passed through
    ``torchvision.transforms.Grayscale``, which yields a (1, H, W) tensor.

    NOTE: the declared observation space has shape (H, W) while the emitted
    tensor is (1, H, W); the leading channel dimension is expected to be
    dropped by a later wrapper.
    """

    def __init__(self, env):
        super().__init__(env)
        obs_shape = self.observation_space.shape[:2]
        # Observations leave this wrapper as float32 tensors, so declare a
        # float32 space (values stay within [0, 255] assuming 8-bit inputs,
        # since the grayscale weights are non-negative and sum to ~1).
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=obs_shape, dtype=np.float32
        )
        # Build the transform once rather than on every step.
        self._grayscale = T.Grayscale()

    def permute_orientation(self, observation):
        """Transpose an (H, W, C) array to a float (C, H, W) torch tensor."""
        observation = np.transpose(observation, (2, 0, 1))
        # The copy gives torch a plain contiguous buffer to read from.
        return torch.tensor(observation.copy(), dtype=torch.float)

    def observation(self, observation):
        """Return the grayscale (1, H, W) float tensor for one frame."""
        observation = self.permute_orientation(observation)
        return self._grayscale(observation)
class ResizeObservation(gym.ObservationWrapper):
    """Resize frames to ``shape`` and scale pixel values by 1/255.

    Accepted inputs:

    * torch tensors laid out as (C, H, W), e.g. the (1, H, W) output of
      ``GrayScaleObservation``;
    * raw NumPy frames laid out as (H, W, C) or (H, W). These are converted
      to float (C, H, W) tensors first, because ``T.Resize``/``T.Normalize``
      operate on tensors (previously such frames crashed here).

    A leading channel dimension of size 1 is squeezed away, so
    single-channel frames come out as ``shape`` and multi-channel frames as
    ``(C, *shape)``. ``T.Normalize(0, 255)`` divides by 255, so inputs in
    [0, 255] come out as float32 values in [0, 1].

    Args:
        env: Environment to wrap.
        shape: Target (height, width) of the resized frames.
    """

    def __init__(self, env, shape):
        super().__init__(env)
        self.shape = tuple(shape)
        in_shape = self.observation_space.shape
        # A channels-last input with several channels keeps them, moved to
        # the front; 2-D or single-channel inputs end up as plain (h, w).
        if len(in_shape) == 3 and in_shape[2] > 1:
            obs_shape = (in_shape[2],) + self.shape
        else:
            obs_shape = self.shape
        # Declare what is actually emitted: float32 frames scaled into [0, 1].
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=obs_shape, dtype=np.float32
        )
        # Build the pipeline once rather than on every step.
        self._transforms = T.Compose(
            [T.Resize(self.shape, antialias=True), T.Normalize(0, 255)]
        )

    @staticmethod
    def _to_chw_tensor(frame):
        """Convert an (H, W) or (H, W, C) NumPy frame to a float (C, H, W) tensor."""
        if frame.ndim == 2:
            frame = frame[np.newaxis, ...]
        else:
            frame = np.transpose(frame, (2, 0, 1))
        return torch.tensor(frame.copy(), dtype=torch.float)

    def observation(self, observation):
        """Resize and rescale one frame; drop a size-1 channel dimension."""
        if isinstance(observation, np.ndarray):
            observation = self._to_chw_tensor(observation)
        return self._transforms(observation).squeeze(0)
def preprocess(env, skip=4, grayscale=True, shape=(72, 128), num_stack=4):
    """Wrap ``env`` in the frame-preprocessing pipeline.

    Wrappers are applied innermost first:

    1. ``SkipFrame``: repeat each action ``skip`` times, summing rewards.
    2. ``GrayScaleObservation``: only when ``grayscale`` is true.
    3. ``ResizeObservation``: resize frames to ``shape`` and rescale.
    4. ``gym.wrappers.FrameStack``: stack the last ``num_stack`` frames.

    With ``grayscale=True`` each observation is a stack of ``num_stack``
    frames of size ``shape``.

    Args:
        env: Environment to wrap.
        skip: Inner steps per outer step.
        grayscale: Whether to convert frames to a single channel.
        shape: Target (height, width) of each frame.
        num_stack: Number of consecutive frames per observation.

    Returns:
        The fully wrapped environment.
    """
    wrapped = SkipFrame(env, skip=skip)
    wrapped = GrayScaleObservation(wrapped) if grayscale else wrapped
    wrapped = ResizeObservation(wrapped, shape=shape)
    return gym.wrappers.FrameStack(wrapped, num_stack=num_stack)