-
Notifications
You must be signed in to change notification settings - Fork 11
/
vtrace.py
150 lines (123 loc) · 6.1 KB
/
vtrace.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# coding=utf-8
# Copyright 2019 The SEED Authors
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions to compute V-trace off-policy actor critic targets.
For details and theory see:
"IMPALA: Scalable Distributed Deep-RL with
Importance Weighted Actor-Learner Architectures"
by Espeholt, Soyer, Munos et al.
See https://arxiv.org/abs/1802.01561 for the full paper.
"""
import collections
import tensorflow as tf
# Result container for V-trace: `vs` holds the value targets and
# `pg_advantages` the advantages used for the policy gradient.
VTraceReturns = collections.namedtuple(
    'VTraceReturns', ['vs', 'pg_advantages'])
def from_importance_weights(
    target_action_log_probs, behaviour_action_log_probs,
    discounts, rewards, values, bootstrap_value,
    clip_rho_threshold=1.0, clip_pg_rho_threshold=1.0, lambda_=1.0,
    name='vtrace_from_importance_weights'):
  r"""Computes V-trace value targets and policy-gradient advantages.

  Implements the V-trace actor-critic targets from
  "IMPALA: Scalable Distributed Deep-RL with Importance Weighted
  Actor-Learner Architectures" by Espeholt, Soyer, Munos et al.

  Notation: T is the time dimension (indices 0 to T-1) and B is the batch
  size. Every input is converted to float32. Inputs may carry the same
  number of extra trailing dimensions, e.g. `rewards` and `values` of shape
  [T, B, C] together with a `bootstrap_value` of shape [B, C].

  Args:
    target_action_log_probs: A float32 tensor of shape [T, B] holding the
      log-probability of each taken action under the current (target) policy.
    behaviour_action_log_probs: A float32 tensor of shape [T, B] holding the
      log-probability of each taken action under the behaviour policy.
    discounts: A float32 tensor of shape [T, B] with the discounts encountered
      while following the behaviour policy.
    rewards: A float32 tensor of shape [T, B] with the rewards obtained while
      following the behaviour policy.
    values: A float32 tensor of shape [T, B] with value function estimates
      wrt. the target policy.
    bootstrap_value: A float32 tensor of shape [B] with the value function
      estimate at time T.
    clip_rho_threshold: A scalar float32 tensor; upper bound applied to the
      importance weights (rho) used for the baseline targets (vs). This is
      rho^bar in the paper. Pass None to disable clipping.
    clip_pg_rho_threshold: A scalar float32 tensor; upper bound applied to the
      importance weight rho_s that scales the policy-gradient advantage
      (r + \gamma v_{s+1} - V(x_s)). Pass None to disable clipping.
    lambda_: Mix between 1-step (lambda_=0) and n-step (lambda_=1) targets.
      See Remark 2 in the paper. Defaults to 1.
    name: Name scope under which the V-trace operations are created.

  Returns:
    A VTraceReturns namedtuple (vs, pg_advantages) where:
      vs: A float32 tensor of shape [T, B]. Usable as the regression target
        for a baseline, (V(x_t) - vs_t)^2.
      pg_advantages: A float32 tensor of shape [T, B]. Usable as the advantage
        term when computing policy gradients.
    Both outputs are wrapped in `tf.stop_gradient`.
  """

  def _as_float32(x):
    return tf.convert_to_tensor(x, dtype=tf.float32)

  log_rhos = _as_float32(target_action_log_probs - behaviour_action_log_probs)
  discounts = _as_float32(discounts)
  rewards = _as_float32(rewards)
  values = _as_float32(values)
  bootstrap_value = _as_float32(bootstrap_value)
  if clip_rho_threshold is not None:
    clip_rho_threshold = _as_float32(clip_rho_threshold)
  if clip_pg_rho_threshold is not None:
    clip_pg_rho_threshold = _as_float32(clip_pg_rho_threshold)

  # All per-step tensors must share the rank of the importance weights
  # (usually 2); the bootstrap value lacks the time axis, and the clipping
  # thresholds must be scalars.
  expected_rank = log_rhos.shape.ndims
  values.shape.assert_has_rank(expected_rank)
  bootstrap_value.shape.assert_has_rank(expected_rank - 1)
  discounts.shape.assert_has_rank(expected_rank)
  rewards.shape.assert_has_rank(expected_rank)
  for threshold in (clip_rho_threshold, clip_pg_rho_threshold):
    if threshold is not None:
      threshold.shape.assert_has_rank(0)

  with tf.name_scope(name):
    rhos = tf.exp(log_rhos)
    if clip_rho_threshold is None:
      clipped_rhos = rhos
    else:
      clipped_rhos = tf.minimum(clip_rho_threshold, rhos, name='clipped_rhos')

    # Trace-cutting coefficients: min(1, rho), scaled by lambda_.
    trace_coeffs = tf.minimum(1.0, rhos, name='cs')
    trace_coeffs = trace_coeffs * _as_float32(lambda_)

    # Shift the values one step forward in time and append the bootstrap
    # value, giving [v_1, ..., v_T].
    next_values = tf.concat(
        [values[1:], tf.expand_dims(bootstrap_value, 0)], axis=0)
    td_errors = clipped_rhos * (rewards + discounts * next_values - values)

    # Backward recursion over time for vs_t - V(x_t). The number of steps is
    # read from the static shape of `discounts`, so the loop is unrolled in
    # Python.
    running = tf.zeros_like(bootstrap_value)
    corrections = []
    for t in reversed(range(int(discounts.shape[0]))):
      discount_t, coeff_t, td_error_t = (
          discounts[t], trace_coeffs[t], td_errors[t])
      running = td_error_t + discount_t * coeff_t * running
      corrections.append(running)
    # Entries were collected from the last step to the first; restore time
    # order.
    corrections.reverse()

    # Adding V(x_s) back yields the V-trace targets v_s.
    vs = tf.add(corrections, values, name='vs')

    # Policy-gradient advantage, bootstrapping from the next V-trace target.
    next_vs = tf.concat(
        [vs[1:], tf.expand_dims(bootstrap_value, 0)], axis=0)
    if clip_pg_rho_threshold is None:
      clipped_pg_rhos = rhos
    else:
      clipped_pg_rhos = tf.minimum(clip_pg_rho_threshold, rhos,
                                   name='clipped_pg_rhos')
    pg_advantages = clipped_pg_rhos * (rewards + discounts * next_vs - values)

    # The returned tensors are targets; block gradients from flowing through
    # them.
    return VTraceReturns(vs=tf.stop_gradient(vs),
                         pg_advantages=tf.stop_gradient(pg_advantages))