From d5b9adf39602d1cecfcfe9a3825d8b345f29b15f Mon Sep 17 00:00:00 2001 From: Christoph Lehner Date: Wed, 28 Feb 2024 12:34:38 +0100 Subject: [PATCH 1/2] fix automatic einsum segmentation --- lib/gpt/core/einsum.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/lib/gpt/core/einsum.py b/lib/gpt/core/einsum.py index 7680e01b..a59c3b5d 100644 --- a/lib/gpt/core/einsum.py +++ b/lib/gpt/core/einsum.py @@ -165,9 +165,19 @@ def process_indices(names, values, epsilon_tensors, sign0): else: index_value[j] = 0 + # now sort by target index + code = sorted(code, key = lambda c: c[1]) + + # now verify that segmentation works assert len(code) % nsegment == 0 + use_segmentation = all([len(set([c[1] for c in code[i:i+nsegment]])) == 1 for i in range(0, len(code), nsegment)]) + if not use_segmentation: + nsegment = 1 + + # create segmentation segments = [(len(code) // nsegment, nsegment)] + # and tensor ein = g.stencil.tensor(tensors_destination[0], [(0, 0, 0, 0)], code, segments) def exec(*src): From 17f78a3a15f75db509c19e7f7465e4f0cbd6e242 Mon Sep 17 00:00:00 2001 From: Christoph Lehner Date: Wed, 28 Feb 2024 12:39:03 +0100 Subject: [PATCH 2/2] style --- lib/gpt/core/einsum.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/lib/gpt/core/einsum.py b/lib/gpt/core/einsum.py index a59c3b5d..943353f4 100644 --- a/lib/gpt/core/einsum.py +++ b/lib/gpt/core/einsum.py @@ -166,14 +166,19 @@ def process_indices(names, values, epsilon_tensors, sign0): index_value[j] = 0 # now sort by target index - code = sorted(code, key = lambda c: c[1]) + code = sorted(code, key=lambda c: c[1]) # now verify that segmentation works assert len(code) % nsegment == 0 - use_segmentation = all([len(set([c[1] for c in code[i:i+nsegment]])) == 1 for i in range(0, len(code), nsegment)]) + use_segmentation = all( + [ + len(set([c[1] for c in code[i : i + nsegment]])) == 1 + for i in range(0, len(code), nsegment) + ] + ) if not use_segmentation: nsegment = 1 - + # create segmentation segments = [(len(code) // nsegment, nsegment)]