From d5b9adf39602d1cecfcfe9a3825d8b345f29b15f Mon Sep 17 00:00:00 2001 From: Christoph Lehner Date: Wed, 28 Feb 2024 12:34:38 +0100 Subject: [PATCH] fix automatic einsum segmentation --- lib/gpt/core/einsum.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/lib/gpt/core/einsum.py b/lib/gpt/core/einsum.py index 7680e01b..a59c3b5d 100644 --- a/lib/gpt/core/einsum.py +++ b/lib/gpt/core/einsum.py @@ -165,9 +165,19 @@ def process_indices(names, values, epsilon_tensors, sign0): else: index_value[j] = 0 + # now sort by target index + code = sorted(code, key = lambda c: c[1]) + + # now verify that segmentation works assert len(code) % nsegment == 0 + use_segmentation = all([len(set([c[1] for c in code[i:i+nsegment]])) == 1 for i in range(0, len(code), nsegment)]) + if not use_segmentation: + nsegment = 1 + + # create segmentation segments = [(len(code) // nsegment, nsegment)] + # and tensor ein = g.stencil.tensor(tensors_destination[0], [(0, 0, 0, 0)], code, segments) def exec(*src):