-
Notifications
You must be signed in to change notification settings - Fork 16
/
rollout.py
56 lines (49 loc) · 1.75 KB
/
rollout.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
class Rollout(object):
    """ Rollout Policy

    Holds a separately-updated copy of a generator (``own_model``) that is used
    to complete partial sequences when estimating per-token rewards.
    """

    def __init__(self, model, update_rate):
        """
        Inputs: model, update_rate
            - model: generator being trained; kept by reference as ``ori_model``
            - update_rate: weight given to ``own_model``'s current parameters
              when ``update_params`` blends in ``ori_model``'s parameters
        """
        self.ori_model = model
        # Separate copy used for rollouts; see update_params() for how it is synced.
        self.own_model = copy.deepcopy(model)
        self.update_rate = update_rate

    def get_reward(self, x, num, discriminator):
        """Estimate a reward for every token of ``x`` via Monte Carlo rollouts.

        For each prefix length l in 1..seq_len-1, ``own_model.sample`` produces
        a length-seq_len sequence from the prefix ``x[:, :l]``. Column 1 of the
        discriminator output on that sequence is the reward for token l-1.
        The reward for the final token is column 1 of the discriminator output
        on ``x`` itself. Every estimate is averaged over ``num`` repetitions.

        Inputs: x, num, discriminator
            - x: (batch_size, seq_len) input data
            - num: rollout number (repetitions averaged); expected to be >= 1
            - discriminator: discriminator model; must return a 2-D tensor with
              at least 2 columns (column 1 presumably scores "real" -- TODO confirm)

        Returns:
            numpy array of shape (batch_size, seq_len) holding the averaged rewards.
        """
        batch_size = x.size(0)
        seq_len = x.size(1)
        rewards = []  # one (batch_size,) array per token position
        # The rewards are converted to numpy constants, so autograd tracking is
        # pure overhead for the num * seq_len forward passes below.
        with torch.no_grad():
            for i in range(num):
                for l in range(1, seq_len + 1):
                    if l < seq_len:
                        samples = self.own_model.sample(batch_size, seq_len, x[:, 0:l])
                    else:
                        # for the last token, the full sequence is scored as-is
                        samples = x
                    pred = discriminator(samples)
                    # copy() decouples the array from the tensor's memory, so the
                    # in-place accumulation below never writes into that tensor
                    pred = pred.detach().cpu()[:, 1].numpy().copy()
                    if i == 0:
                        rewards.append(pred)
                    else:
                        rewards[l - 1] += pred
        rewards = np.transpose(np.array(rewards)) / (1.0 * num)  # batch_size * seq_len
        return rewards

    def update_params(self):
        """Blend ``ori_model``'s parameters into ``own_model``.

        Parameters whose name starts with 'emb' are rebound to ``ori_model``'s
        tensor. They then share its storage and keep following ``ori_model``
        afterwards. Every other parameter becomes
        ``update_rate * own + (1 - update_rate) * ori``.
        """
        dic = {name: param.data for name, param in self.ori_model.named_parameters()}
        for name, param in self.own_model.named_parameters():
            if name.startswith('emb'):
                param.data = dic[name]
            else:
                param.data = self.update_rate * param.data + (1 - self.update_rate) * dic[name]