Skip to content

Commit

Permalink
fix memoization when selecting rules by batches
Browse files Browse the repository at this point in the history
  • Loading branch information
VichoReyes committed Mar 2, 2020
1 parent d88f09b commit 5207334
Showing 1 changed file with 23 additions and 8 deletions.
31 changes: 23 additions & 8 deletions dsgd/DSModelMultiQ.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import dill
# import pickle
from torch import nn
from torch.nn.functional import pad
import numpy as np
from scipy.stats import norm
from itertools import count

from dsgd.DSRule import DSRule
from dsgd.core import create_random_maf_k
Expand Down Expand Up @@ -59,7 +61,8 @@ def forward(self, X):
# transform to commonalities before selecting the rules that apply
qs = ms[:, :-1] + ms[:, -1].view(-1, 1) * torch.ones_like(ms[:, :-1])
qt = qs.repeat(len(X), 1, 1)
sel = self._select_all_rules(X)
vectors, indices = X[:, 1:], X[:, 0].long()
sel = self._select_all_rules(vectors, indices)
# replace rules that don't apply with ones
qt[~sel] = 1
res = qt.prod(1)
Expand All @@ -71,16 +74,28 @@ def clear_rmap(self):
self.rmap = {}
self._all_rules = None

def _select_all_rules(self, X):
if self._all_rules is not None:
return self._all_rules
def _select_all_rules(self, X, indices):
"""
This works based on the assumption that indices will be
provided in order. Otherwise, the function may return uninitialized
values.
"""
if self._all_rules is None:
self._all_rules = torch.zeros(0, self.n, dtype=torch.bool)
max_index = torch.max(indices)
len_all_rules = len(self._all_rules)
if max_index < len_all_rules:
return self._all_rules[indices]
else:
desired_len = max_index + 1
padding = (0, 0, 0, desired_len-len_all_rules)
self._all_rules = pad(self._all_rules, padding)
sel = torch.zeros(len(X), self.n, dtype=torch.bool)
X = X.data.numpy()
for i, sample in enumerate(X):
for i, sample, index in zip(count(), X, indices):
for j in range(self.n):
sel[i, j] = bool(self.preds[j](sample[1:]))
if self.precompute_rules:
self._all_rules = sel
sel[i, j] = bool(self.preds[j](sample))
self._all_rules[index] = sel[i]
return sel

def normalize(self):
Expand Down

0 comments on commit 5207334

Please sign in to comment.