-
Notifications
You must be signed in to change notification settings - Fork 0
/
memory.py
81 lines (68 loc) · 2.96 KB
/
memory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from collections import namedtuple
import random
# Record types held by the replay memories in this module.

# Environment state at one step, with fields obs, description and inventory.
State = namedtuple('State', 'obs description inventory')

# One step of experience; this is ReplayMemory's default record type.
Transition = namedtuple(
    'Transition',
    'state act reward next_state next_acts done',
)

# A whole episode's worth of experience. The plural field names suggest
# per-step sequences — NOTE(review): confirm against the code that builds them.
Episode = namedtuple(
    'Episode',
    'states acts rewards next_acts dones',
)
class ReplayMemory(object):
    """Fixed-capacity ring buffer of experience records.

    Once ``capacity`` records are stored, each further push overwrites the
    oldest slot in round-robin order.

    Args:
        capacity: Maximum number of records kept.
        obj_type: Record factory used by :meth:`push`; ``push(*args)`` stores
            ``obj_type(*args)``. Defaults to ``Transition``.
    """

    def __init__(self, capacity, obj_type=Transition):
        self.capacity = capacity
        self.memory = []
        # Slot index the next push writes to.
        self.position = 0
        self.obj_type = obj_type

    def push(self, *args):
        """Build ``obj_type(*args)`` and store it, overwriting the oldest record when full."""
        if len(self.memory) < self.capacity:
            # Still filling up: grow the list so the slot at self.position exists.
            self.memory.append(None)
        # Use the configured record type directly. The previous code built an
        # Episode for every non-Transition obj_type (e.g. State), silently
        # storing the wrong record type.
        self.memory[self.position] = self.obj_type(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored records chosen uniformly at random.

        Raises:
            ValueError: if ``batch_size`` is negative or larger than the number
                of stored records (propagated from ``random.sample``).
        """
        return random.sample(self.memory, batch_size)

    def pull_all(self):
        """Remove and return every stored record, leaving the buffer empty.

        The returned list is in slot order. Once the buffer has wrapped, this
        is not the order in which records were pushed.
        """
        current_transitions = self.memory
        self.memory = []
        self.position = 0
        return current_transitions

    def __len__(self):
        """Number of records currently stored (never more than ``capacity``)."""
        return len(self.memory)
class PrioritizedReplayMemory(object):
    """Replay memory split into a prioritized ("alpha") and a regular ("beta") ring buffer.

    ``int(capacity * priority_fraction)`` slots go to the alpha buffer and the
    remaining slots to the beta buffer. Each buffer overwrites its own oldest
    slot once full.

    Args:
        capacity: Total number of slots across both buffers.
        priority_fraction: Share of ``capacity`` given to the alpha buffer, and
            the share of each sampled batch requested from it.
    """

    def __init__(self, capacity, priority_fraction):
        self.priority_fraction = priority_fraction
        self.alpha_capacity = int(capacity * priority_fraction)
        self.beta_capacity = capacity - self.alpha_capacity
        self.alpha_memory, self.beta_memory = [], []
        # Slot index each ring buffer writes to next.
        self.alpha_position, self.beta_position = 0, 0

    def clear_alpha(self):
        """Empty the prioritized (alpha) buffer; the beta buffer is left untouched."""
        self.alpha_memory = []
        self.alpha_position = 0

    def push(self, transition, is_prior=False):
        """Saves a transition.

        Args:
            transition: Item to store (stored as-is).
            is_prior: If true, store in the alpha buffer; otherwise in the
                beta buffer.

        A prioritized transition is redirected to the beta buffer when the
        alpha buffer has no slots. A transition whose destination buffer has
        no slots (e.g. the beta buffer when ``priority_fraction`` is 1.0) is
        dropped rather than raising; ``sample`` would never draw it anyway.
        """
        # int(capacity * priority_fraction) can round down to 0 even for a
        # non-zero fraction. Writing into a zero-capacity buffer would raise
        # IndexError, and the modulo would raise ZeroDivisionError.
        if self.priority_fraction == 0.0 or self.alpha_capacity <= 0:
            is_prior = False
        if is_prior:
            if len(self.alpha_memory) < self.alpha_capacity:
                # Still filling: grow so the slot at alpha_position exists.
                self.alpha_memory.append(None)
            self.alpha_memory[self.alpha_position] = transition
            self.alpha_position = (
                self.alpha_position + 1) % self.alpha_capacity
        elif self.beta_capacity > 0:
            if len(self.beta_memory) < self.beta_capacity:
                # Still filling: grow so the slot at beta_position exists.
                self.beta_memory.append(None)
            self.beta_memory[self.beta_position] = transition
            self.beta_position = (self.beta_position + 1) % self.beta_capacity
        # Otherwise the beta buffer has no slots, so there is nowhere to keep it.

    def sample(self, batch_size):
        """Draw a random batch of stored transitions.

        ``int(priority_fraction * batch_size)`` items are requested from the
        alpha buffer and the rest from the beta buffer. Each request is clamped
        to that buffer's current size, and a shortfall in one buffer is NOT
        made up from the other. The result may therefore contain fewer than
        ``batch_size`` items. Sampling is without replacement within each
        buffer.
        """
        if self.priority_fraction == 0.0:
            # Everything lives in beta; random.sample already returns items
            # in random order.
            from_beta = min(batch_size, len(self.beta_memory))
            res = random.sample(self.beta_memory, from_beta)
        else:
            from_alpha = min(int(self.priority_fraction *
                                 batch_size), len(self.alpha_memory))
            from_beta = min(
                batch_size - int(self.priority_fraction * batch_size), len(self.beta_memory))
            res = random.sample(self.alpha_memory, from_alpha) + \
                random.sample(self.beta_memory, from_beta)
        # Mix so alpha items are not grouped at the front of the batch.
        random.shuffle(res)
        return res

    def __len__(self):
        """Total number of transitions stored across both buffers."""
        return len(self.alpha_memory) + len(self.beta_memory)