From 320ac96ba714b35c09f0a276a5c9d1cc2f4b2f17 Mon Sep 17 00:00:00 2001 From: Christoph Lehner Date: Fri, 20 Dec 2024 12:56:54 +0100 Subject: [PATCH] better compatibility for tensors --- lib/gpt/core/tensor.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/lib/gpt/core/tensor.py b/lib/gpt/core/tensor.py index f7e2975b..11e5ed2f 100644 --- a/lib/gpt/core/tensor.py +++ b/lib/gpt/core/tensor.py @@ -22,6 +22,16 @@ from gpt.core.foundation import tensor as foundation, base as foundation_base +def get_mt_entry(self_otype, other_otype): + self_tag = self_otype.__name__ + other_tag = other_otype.__name__ + if other_tag in self_otype.mtab: + return self_otype.mtab[other_tag] + elif self_tag in other_otype.rmtab: + return other_otype.rmtab[self_tag] + return None + + class tensor(foundation_base): foundation = foundation @@ -106,12 +116,10 @@ def norm2(self): def __mul__(self, other): if isinstance(other, gpt.tensor): - self_tag = self.otype.__name__ - other_tag = other.otype.__name__ - if other_tag in self.otype.mtab: - mt = self.otype.mtab[other_tag] - elif self_tag in other.otype.rmtab: - mt = other.otype.rmtab[self_tag] + mt = get_mt_entry(self.otype, other.otype) + if mt is None: + mt = get_mt_entry(self.otype.data_otype(), other.otype.data_otype()) + assert mt is not None a = np.tensordot(self.array, other.array, axes=mt[1]) if len(mt) > 2: a = np.transpose(a, mt[2])