-
Notifications
You must be signed in to change notification settings - Fork 0
/
split_buffer.py
89 lines (66 loc) · 2.43 KB
/
split_buffer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import gzip
import numpy as np
from atari_action_utils import LIMITED_ACTION_TO_FULL_ACTION
def get_file_name(root, type, filenumber):
    """Build the path of one gzipped replay-buffer checkpoint file.

    Args:
        root: Directory that holds the checkpoint files.
        type: Array kind inserted into the file name (this file passes
            'terminal', 'observation', 'action' and 'reward').
        filenumber: Checkpoint number inserted into the file name.

    Returns:
        ``root`` joined with ``$store$_<type>_ckpt.<filenumber>.gz``.
    """
    # NOTE: `type` shadows the builtin, but it is part of the signature and
    # is kept so keyword callers continue to work.
    template = '$store$_{}_ckpt.{}.gz'
    return os.path.join(root, template.format(type, filenumber))
# Input root: the per-game loop at the bottom of this file joins it with the
# game name to get the directory passed to split_buffer as `base_dir`.
root_dir = 'data/download'
# Output root: the same loop joins it with the game name to get the directory
# passed to split_buffer as `out_dir`.
out_base_dir = 'data/single_split'
def _load_gz_array(base_dir, kind, file_number):
    """Load one gzipped ``.npy`` checkpoint array and close its file handle."""
    with gzip.GzipFile(filename=get_file_name(base_dir, kind, file_number)) as fh:
        return np.load(fh)


def _episode_returns(reward, terminal, last):
    """Return the undiscounted reward-to-go for indices ``0..last``.

    The running sum is reset at every step flagged terminal *before* that
    step's reward is added, so a terminal step's return is its own reward.
    Entries after ``last`` are left at zero.
    """
    ret = np.zeros_like(reward)
    running = 0
    for i in range(last, -1, -1):
        if terminal[i] == 1:
            running = 0
        running += reward[i]
        ret[i] = running
    return ret


def split_buffer(file_number, base_dir, out_dir, splits, game):
    """Split one replay-buffer checkpoint into episode-aligned chunks on disk.

    Loads the 'terminal', 'observation', 'action' and 'reward' arrays of
    checkpoint ``file_number`` from ``base_dir``, clips rewards to [-1, 1],
    computes per-step undiscounted returns, remaps actions through
    ``LIMITED_ACTION_TO_FULL_ACTION[game]`` and writes consecutive chunks to
    ``out_dir``. Every chunk ends right after a terminal step; data after the
    last terminal step is discarded.

    For chunk number ``k`` the following files are written:
    ``obs-k.npy.gz`` (gzip-compressed) and ``action-k.npy``,
    ``reward-k.npy``, ``return-k.npy``, ``terminal-k.npy`` (uncompressed).

    Args:
        file_number: Checkpoint number used in the input file names.
        base_dir: Directory containing the input ``.gz`` files.
        out_dir: Existing directory that receives the output files.
        splits: Divisor for the target chunk length; a chunk is closed at the
            first episode end that makes it longer than
            ``last_terminal_index // splits`` steps.
        game: Key into ``LIMITED_ACTION_TO_FULL_ACTION``.

    Raises:
        IndexError: If the terminal array contains no terminal step.
    """
    terminal = _load_gz_array(base_dir, 'terminal', file_number)
    end_idx = np.where(terminal == 1)[0]
    if end_idx.size == 0:
        # Same exception type the unguarded end_idx[-1] would raise, but with
        # an explanation of what went wrong.
        raise IndexError('terminal array contains no terminal step; nothing to split')

    obs = _load_gz_array(base_dir, 'observation', file_number)
    action = _load_gz_array(base_dir, 'action', file_number)
    reward = np.clip(_load_gz_array(base_dir, 'reward', file_number), -1, 1)

    last = int(end_idx[-1])
    total = last + 1  # exclusive end of the usable data (includes last terminal)
    ret = _episode_returns(reward, terminal, last)

    # Hoist the per-game table lookup out of the loop; the remap is in place.
    to_full_action = LIMITED_ACTION_TO_FULL_ACTION[game]
    for i in range(total):
        action[i] = to_full_action[action[i]]

    split_length = end_idx[-1] // splits
    plain_arrays = (
        ('action', action),
        ('reward', reward),
        ('return', ret),
        ('terminal', terminal),
    )

    start = 0
    chunk = 0
    # `start < total` (not `start < last`) so that a trailing one-step episode
    # is not silently skipped.
    while start < total:
        # First episode end that makes the chunk longer than split_length:
        # idx + 1 - start > split_length  <=>  idx >= start + split_length.
        # end_idx comes from np.where and is therefore sorted ascending.
        pos = int(np.searchsorted(end_idx, start + split_length))
        if pos < len(end_idx):
            end = int(end_idx[pos]) + 1
        else:
            # Short tail: take everything up to and INCLUDING the last
            # terminal step (previously the exclusive bound was `last`,
            # which dropped that final transition).
            end = total

        print(start, ' ', end)
        with gzip.GzipFile(os.path.join(out_dir, 'obs-{}.npy.gz'.format(chunk)), "w") as fh:
            np.save(file=fh, arr=obs[start:end])
        for name, arr in plain_arrays:
            with open(os.path.join(out_dir, '{}-{}.npy'.format(name, chunk)), 'wb') as fh:
                np.save(file=fh, arr=arr[start:end])

        chunk += 1
        start = end
# Games to process; each name is used as a sub-directory of root_dir and
# out_base_dir and is passed to split_buffer as the `game` argument.
games = [
    'Pong',
    'Qbert',
    'DemonAttack',
    'SpaceInvaders',
    'Breakout',
]

for game in games:
    base = os.path.join(root_dir, game)
    out = os.path.join(out_base_dir, game)
    # exist_ok=True replaces the exists()-then-makedirs() pair, which could
    # fail if the directory appeared between the check and the creation.
    os.makedirs(out, exist_ok=True)
    # Checkpoint number 50, with splits=1.
    split_buffer(50, base, out, 1, game)