Skip to content

Commit

Permalink
better compatibility for tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
lehner committed Dec 20, 2024
1 parent c7789ca commit 320ac96
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions lib/gpt/core/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@
from gpt.core.foundation import tensor as foundation, base as foundation_base


def get_mt_entry(self_otype, other_otype):
self_tag = self_otype.__name__
other_tag = other_otype.__name__
if other_tag in self_otype.mtab:
return self_otype.mtab[other_tag]
elif self_tag in other_otype.rmtab:
return other_otype.rmtab[self_tag]
return None


class tensor(foundation_base):
foundation = foundation

Expand Down Expand Up @@ -106,12 +116,10 @@ def norm2(self):

def __mul__(self, other):
if isinstance(other, gpt.tensor):
self_tag = self.otype.__name__
other_tag = other.otype.__name__
if other_tag in self.otype.mtab:
mt = self.otype.mtab[other_tag]
elif self_tag in other.otype.rmtab:
mt = other.otype.rmtab[self_tag]
mt = get_mt_entry(self.otype, other.otype)
if mt is None:
mt = get_mt_entry(self.otype.data_otype(), other.otype.data_otype())
assert mt is not None
a = np.tensordot(self.array, other.array, axes=mt[1])
if len(mt) > 2:
a = np.transpose(a, mt[2])
Expand Down

0 comments on commit 320ac96

Please sign in to comment.