diff --git a/lib/gpt/core/einsum.py b/lib/gpt/core/einsum.py index a59c3b5d..943353f4 100644 --- a/lib/gpt/core/einsum.py +++ b/lib/gpt/core/einsum.py @@ -166,14 +166,19 @@ def process_indices(names, values, epsilon_tensors, sign0): index_value[j] = 0 # now sort by target index - code = sorted(code, key = lambda c: c[1]) + code = sorted(code, key=lambda c: c[1]) # now verify that segmentation works assert len(code) % nsegment == 0 - use_segmentation = all([len(set([c[1] for c in code[i:i+nsegment]])) == 1 for i in range(0, len(code), nsegment)]) + use_segmentation = all( + [ + len(set([c[1] for c in code[i : i + nsegment]])) == 1 + for i in range(0, len(code), nsegment) + ] + ) if not use_segmentation: nsegment = 1 - + # create segmentation segments = [(len(code) // nsegment, nsegment)]