import torch
import os
from datetime import datetime
import numpy as np
from mpi4py import MPI
from mpi_utils.mpi_utils import sync_networks, sync_grads
from rl_modules.base_agent import BaseAgent
from rl_modules.replay_buffer import replay_buffer
from rl_modules.models import actor, critic
from mpi_utils.normalizer import normalizer
from her_modules.her import her_sampler
"""
ddpg with HER (MPI-version)
"""

class DDPG(BaseAgent):
    def __init__(self, args, env, env_params):
        super().__init__(args, env, env_params)
        # self.args = args
        # self.env = env
        # self.env_params = env_params
        # create the networks
        self.actor_network = actor(env_params)
        self.critic_network = critic(env_params)
        # sync the networks across the MPI workers
        sync_networks(self.actor_network)
        sync_networks(self.critic_network)
        # build the target networks
        self.actor_target_network = actor(env_params)
        self.critic_target_network = critic(env_params)
        # load the weights into the target networks
        self.actor_target_network.load_state_dict(self.actor_network.state_dict())
        self.critic_target_network.load_state_dict(self.critic_network.state_dict())
        # move the networks to the GPU if requested
        if self.args.cuda:
            self.actor_network.cuda()
            self.critic_network.cuda()
            self.actor_target_network.cuda()
            self.critic_target_network.cuda()
        # create the optimizers
        self.actor_optim = torch.optim.Adam(self.actor_network.parameters(), lr=self.args.lr_actor)
        self.critic_optim = torch.optim.Adam(self.critic_network.parameters(), lr=self.args.lr_critic)
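
    # sync_networks / sync_grads (imported from mpi_utils.mpi_utils) are assumed to
    # broadcast the rank-0 parameters to every MPI worker and to average gradients
    # across workers, respectively, so all workers keep identical actor/critic
    # copies; see mpi_utils for the actual behaviour.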

    # choose an action for the agent and add exploration noise
    def _stochastic_actions(self, input_tensor):
        pi = self.actor_network(input_tensor)
        action = pi.cpu().numpy().squeeze()
        # add gaussian noise scaled by the action range
        action += self.args.noise_eps * self.env_params['action_max'] * np.random.randn(*action.shape)
        action = np.clip(action, -self.env_params['action_max'], self.env_params['action_max'])
        # sample a uniformly random action
        random_actions = np.random.uniform(low=-self.env_params['action_max'], high=self.env_params['action_max'],
                                           size=self.env_params['action'])
        # decide whether to use the random action
        action += np.random.binomial(1, self.args.random_eps, 1)[0] * (random_actions - action)
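        # the binomial draw implements epsilon-random exploration: with probability
        # random_eps the noisy policy action is replaced by the uniformly sampled
        # random action, otherwise it is left unchanged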
        return action

    def _deterministic_action(self, input_tensor):
        action = self.actor_network(input_tensor)
        return action

    # update the network
    def _update_network(self, future_p=None):
        # sample a batch of transitions
        sample_batch = self.sample_batch(future_p=future_p)
        transitions = sample_batch['transitions']
        # normalize the observations and goals
        obs_norm = self.o_norm.normalize(transitions['obs'])
        g_norm = self.g_norm.normalize(transitions['g'])
        inputs_norm = np.concatenate([obs_norm, g_norm], axis=1)
        obs_next_norm = self.o_norm.normalize(transitions['obs_next'])
        g_next_norm = self.g_norm.normalize(transitions['g_next'])
        inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1)
        # convert to tensors
        inputs_norm_tensor = torch.tensor(inputs_norm, dtype=torch.float32)
        inputs_next_norm_tensor = torch.tensor(inputs_next_norm, dtype=torch.float32)
        actions_tensor = torch.tensor(transitions['actions'], dtype=torch.float32)
        r_tensor = torch.tensor(transitions['r'], dtype=torch.float32)
        if self.args.cuda:
            inputs_norm_tensor = inputs_norm_tensor.cuda()
            inputs_next_norm_tensor = inputs_next_norm_tensor.cuda()
            actions_tensor = actions_tensor.cuda()
            r_tensor = r_tensor.cuda()
        # calculate the target Q value with the target networks
        with torch.no_grad():
            actions_next = self.actor_target_network(inputs_next_norm_tensor)
            q_next_value = self.critic_target_network(inputs_next_norm_tensor, actions_next)
            q_next_value = q_next_value.detach()
            target_q_value = r_tensor + self.args.gamma * q_next_value
            target_q_value = target_q_value.detach()
            # clip the q value
            clip_return = 1 / (1 - self.args.gamma)
            target_q_value = torch.clamp(target_q_value, -clip_return, 0)
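            # assuming the standard sparse HER reward of -1 (goal not reached) / 0
            # (goal reached), the discounted return lies in [-1 / (1 - gamma), 0],
            # which is why the target is clamped to that interval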
        # the q loss
        real_q_value = self.critic_network(inputs_norm_tensor, actions_tensor)
        critic_loss = (target_q_value - real_q_value).pow(2).mean()
        # the actor loss
        actions_real = self.actor_network(inputs_norm_tensor)
        actor_loss = -self.critic_network(inputs_norm_tensor, actions_real).mean()
        actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
        # start to update the actor network
        self.actor_optim.zero_grad()
        actor_loss.backward()
        sync_grads(self.actor_network)
        self.actor_optim.step()
        # update the critic network
        self.critic_optim.zero_grad()
        critic_loss.backward()
        sync_grads(self.critic_network)
        self.critic_optim.step()
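

# A minimal, hypothetical usage sketch showing how a DDPG instance might be
# constructed from this module. The environment id, the env_params keys other
# than 'action' / 'action_max', the argparse defaults, and the assumption that
# BaseAgent needs no further fields are all guesses that depend on the rest of
# the repository; this is a sketch, not the project's training entry point.
if __name__ == '__main__':
    import argparse
    import gym

    parser = argparse.ArgumentParser()
    # fields referenced directly in this file; BaseAgent may require more
    parser.add_argument('--lr-actor', type=float, default=1e-3)
    parser.add_argument('--lr-critic', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.98)
    parser.add_argument('--noise-eps', type=float, default=0.2)
    parser.add_argument('--random-eps', type=float, default=0.3)
    parser.add_argument('--action-l2', type=float, default=1.0)
    parser.add_argument('--cuda', action='store_true')
    args = parser.parse_args()

    # any goal-conditioned env with a dict observation would do here
    env = gym.make('FetchReach-v1')
    obs = env.reset()
    # env_params layout inferred from how it is indexed in this file and in the
    # actor/critic models; treat the exact keys as an assumption
    env_params = {'obs': obs['observation'].shape[0],
                  'goal': obs['desired_goal'].shape[0],
                  'action': env.action_space.shape[0],
                  'action_max': env.action_space.high[0],
                  'max_timesteps': env._max_episode_steps}

    agent = DDPG(args, env, env_params)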