diff --git a/dsgd/DSModelMultiQ.py b/dsgd/DSModelMultiQ.py index d43ab3b..e2a7c37 100644 --- a/dsgd/DSModelMultiQ.py +++ b/dsgd/DSModelMultiQ.py @@ -2,8 +2,10 @@ import dill # import pickle from torch import nn +from torch.nn.functional import pad import numpy as np from scipy.stats import norm +from itertools import count from dsgd.DSRule import DSRule from dsgd.core import create_random_maf_k @@ -27,6 +29,7 @@ def __init__(self, k, precompute_rules=False): self.precompute_rules = precompute_rules self.rmap = {} self.active_rules = [] + self._all_rules = None def add_rule(self, pred, m_sing=None, m_uncert=None): """ @@ -54,33 +57,48 @@ def forward(self, X): :param X: Set of inputs :return: Set of prediction for each input in one hot encoding format """ - out = torch.zeros(len(X), self.k) ms = torch.stack(self._params) - for i in range(len(X)): - sel = self._select_rules(X[i, 1:], int(X[i, 0].item())) - if len(sel) == 0: - # raise RuntimeError("No rule especified for input No %d" % i) - # print("Warning: No rule especified for input No %d" % i) - out[i] = torch.ones((self.k,)) / self.k - else: - mt = torch.index_select(ms, 0, torch.LongTensor(sel)) - qt = mt[:, :-1] + mt[:, -1].view(-1, 1) * torch.ones_like(mt[:, :-1]) - res = qt.prod(0) - # if torch.isnan(res).any(): - # print(self._params) - # print(mt) - # print(qt) - # print(res) - # raise RuntimeError("NaN found in computation") - if res.sum().item() <= 1e-16: - res = res + 1e-16 - out[i] = res / res.sum() - else: - out[i] = res / res.sum() + # transform to commonalities before selecting the rules that apply + qs = ms[:, :-1] + ms[:, -1].view(-1, 1) * torch.ones_like(ms[:, :-1]) + qt = qs.repeat(len(X), 1, 1) + vectors, indices = X[:, 1:], X[:, 0].long() + sel = self._select_all_rules(vectors, indices) + # replace rules that don't apply with ones + qt[sel] = 1 + res = qt.prod(1) + res2 = res.clone() + res[res2.sum(1) <= 1e-16] += 1e-16 + out = res / res.sum(1, keepdim=True) return out def clear_rmap(self): self.rmap = {} + self._all_rules = None + + def _select_all_rules(self, X, indices): + """ + This works based on the assumption that indices will be + provided in order. Otherwise, the function may return uninitialized + values. + :return a bool tensor with shape (len(X), num_rules) with Trues for the rules that don't apply + """ + if self._all_rules is None: + self._all_rules = torch.zeros(0, self.n, dtype=torch.bool) + max_index = torch.max(indices) + len_all_rules = len(self._all_rules) + if max_index < len_all_rules: + return self._all_rules[indices] + else: + desired_len = max_index + 1 + padding = (0, 0, 0, desired_len-len_all_rules) + self._all_rules = pad(self._all_rules, padding) + sel = torch.zeros(len(X), self.n, dtype=torch.bool) + X = X.data.numpy() + for i, sample, index in zip(count(), X, indices): + for j in range(self.n): + sel[i, j] = not bool(self.preds[j](sample)) + self._all_rules[index] = sel[i] + return sel def normalize(self): """