From 52073343288b88d8f7846abc8fd0ac71b51a2218 Mon Sep 17 00:00:00 2001 From: Vicente Reyes Date: Mon, 2 Mar 2020 15:11:41 +0900 Subject: [PATCH] fix memoization when selecting rules by batches --- dsgd/DSModelMultiQ.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/dsgd/DSModelMultiQ.py b/dsgd/DSModelMultiQ.py index 3c32208..b1c535d 100644 --- a/dsgd/DSModelMultiQ.py +++ b/dsgd/DSModelMultiQ.py @@ -2,8 +2,10 @@ import dill # import pickle from torch import nn +from torch.nn.functional import pad import numpy as np from scipy.stats import norm +from itertools import count from dsgd.DSRule import DSRule from dsgd.core import create_random_maf_k @@ -59,7 +61,8 @@ def forward(self, X): # transform to commonalities before selecting the rules that apply qs = ms[:, :-1] + ms[:, -1].view(-1, 1) * torch.ones_like(ms[:, :-1]) qt = qs.repeat(len(X), 1, 1) - sel = self._select_all_rules(X) + vectors, indices = X[:, 1:], X[:, 0].long() + sel = self._select_all_rules(vectors, indices) # replace rules that don't apply with ones qt[~sel] = 1 res = qt.prod(1) @@ -71,16 +74,28 @@ def clear_rmap(self): self.rmap = {} self._all_rules = None - def _select_all_rules(self, X): - if self._all_rules is not None: - return self._all_rules + def _select_all_rules(self, X, indices): + """ + This works based on the assumption that indices will be + provided in order. Otherwise, the function may return uninitialized + values. + """ + if self._all_rules is None: + self._all_rules = torch.zeros(0, self.n, dtype=torch.bool) + max_index = torch.max(indices) + len_all_rules = len(self._all_rules) + if max_index < len_all_rules: + return self._all_rules[indices] + else: + desired_len = max_index + 1 + padding = (0, 0, 0, desired_len-len_all_rules) + self._all_rules = pad(self._all_rules, padding) sel = torch.zeros(len(X), self.n, dtype=torch.bool) X = X.data.numpy() - for i, sample in enumerate(X): + for i, sample, index in zip(count(), X, indices): for j in range(self.n): - sel[i, j] = bool(self.preds[j](sample[1:])) - if self.precompute_rules: - self._all_rules = sel + sel[i, j] = bool(self.preds[j](sample)) + self._all_rules[index] = sel[i] return sel def normalize(self):