environment.py

import gym
import torch
import torchvision
import random

# Supported games, mapped to their deterministic-frameskip gym IDs.
games = {
    'breakout': 'BreakoutDeterministic-v4',
    'asterix': 'AsterixDeterministic-v4',
    'seaquest': 'SeaquestDeterministic-v4',
}


class Environment:
    """Atari gym wrapper that keeps a DQN-style history of the four most
    recent 84x84 grayscale frames and clips rewards."""

    def __init__(self, game='breakout'):
        if game.lower() not in games:
            raise ValueError(f"Game {game} not supported by this environment.")
        self.env = gym.make(games[game.lower()])
        # Old gym wrappers forward unknown attributes to the wrapped env,
        # which is how `.ale` is reachable here.
        self.ale = self.env.ale
        self.spec = self.env.spec
        self.action_space = self.env.action_space
        self.metadata = self.env.metadata
        self.observation_space = self.env.observation_space
        self.reward_range = self.env.reward_range
        # Rolling stack of the four most recent preprocessed frames.
        self.history = torch.zeros((4, 84, 84))
        self.noop_steps = 10
        self.transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToPILImage(),
            torchvision.transforms.Grayscale(),
            # Crop away the score bar and borders: crop(img, top, left, height, width).
            lambda x: torchvision.transforms.functional.crop(x, 25, 8, 180, 144),
            # Interpolation 0 = nearest neighbour.
            torchvision.transforms.Resize((84, 84), 0),
            torchvision.transforms.ToTensor(),
            # torchvision.transforms.Normalize(mean=mean, std=std)
        ])

    def reset(self, eval=False):
        self.env.reset()
        self.history = torch.zeros((4, 84, 84))  # drop frames left over from the previous episode
        self.state, _, _, _ = self.step(1)  # experiment: FIRE to start the episode faster. Could also tile zeros.
        if eval:
            # Take a random number of no-op steps so evaluation episodes start
            # from varied states; routing them through step() keeps the frame
            # history in sync with the emulator.
            for _ in range(random.randint(1, self.noop_steps)):
                self.state, _, _, _ = self.step(1)
        return self.state

    def render(self, mode='human'):
        return self.env.render(mode)

    def clip(self, reward):
        # Cap rewards at 1; scores in these games are non-negative, so no
        # lower bound is applied.
        return min(reward, 1)

    def step(self, action):
        observation, reward, done, info = self.env.step(action)
        clipped_reward = self.clip(reward)
        info['unclipped_reward'] = reward
        # Shift the frame stack left and append the newly preprocessed frame.
        self.history = torch.cat([self.history[1:], self.transforms(observation)])
        return self.history.unsqueeze(0), clipped_reward, done, info

    def close(self):
        return self.env.close()
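

# Usage sketch: a minimal random-agent episode loop, assuming an old-style
# gym install (4-tuple step API, as the class above requires) with the
# Breakout ROM available; the loop below is illustrative, not part of any
# training code in this repository.
if __name__ == "__main__":
    env = Environment('breakout')
    state = env.reset()  # (1, 4, 84, 84) stacked frame tensor
    done = False
    score = 0.0
    while not done:
        action = env.action_space.sample()  # random policy stand-in
        state, reward, done, info = env.step(action)
        score += info['unclipped_reward']  # track the raw game score
    env.close()
    print(f"Episode finished with score {score}")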