play.py
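"""Watch a trained DQN agent play Snake in the terminal.

Loads the checkpoint saved as models/<grid size>.pth and replays one
episode greedily, redrawing the board after every step.
"""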
from itertools import count
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from snake_env import Env
from dqn import DQN
# Reference copy of the architecture (the live class is imported from dqn.py):
# class DQN(nn.Module):
#     def __init__(self, n_observations, n_actions):
#         super(DQN, self).__init__()
#         self.layer1 = nn.Linear(n_observations, 128)
#         self.layer2 = nn.Linear(128, 128)
#         self.layer3 = nn.Linear(128, n_actions)
#
#     def forward(self, x):
#         x = F.relu(self.layer1(x))
#         x = F.relu(self.layer2(x))
#         return self.layer3(x)
device = torch.device("cpu")

# Build the environment; the checkpoint loaded below must match this board size.
env = Env(render_mode="ansi", size=5)
n_actions = env.action_space.n
state, info = env.reset()
n_observations = len(state.flatten())

# Load the trained weights for this board size and switch to inference mode.
model = DQN(n_observations, n_actions).to(device)
model.load_state_dict(torch.load(f"models/{env.size}.pth", map_location=device))
model.eval()

# Start a fresh episode and convert the observation to a batched float tensor.
state, info = env.reset()
state = torch.tensor(state.flatten(), dtype=torch.float32, device=device).unsqueeze(0)
# Roll out one episode, always taking the greedy (highest Q-value) action.
for t in count():
    with torch.no_grad():
        action = model(state).max(1).indices.view(1, 1)
    observation, reward, terminated, truncated, info = env.step(action.item())
    if observation is not None:
        observation = observation.flatten()
    done = terminated or truncated
    if done:
        next_state = None
    else:
        next_state = torch.tensor(
            observation, dtype=torch.float32, device=device
        ).unsqueeze(0)
    state = next_state
    if done:
        break
    # Redraw the board in the terminal between steps.
    os.system("clear")
    print(env.render_state(info))
    time.sleep(0.1)
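# Usage (a sketch, assuming a checkpoint was saved by the training script
# under models/<grid size>.pth, e.g. models/5.pth for the 5x5 board used here):
#   python play.py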