diff --git a/README.md b/README.md
index c30d7f84..325adea9 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,7 @@
+<p align="center">
+  <img src="./docs/slothy_logo.png" width="160">
+</p>
+
 **SLOTHY** - **S**uper (**L**azy) **O**ptimization of **T**ricky **H**andwritten assembl**Y** -
 is an assembly-level superoptimizer for:
 1. Instruction scheduling
@@ -6,7 +10,7 @@ for:
 
 SLOTHY is generic in the target architecture and microarchitecture. This repository provides
 instantiations for the Cortex-M55 and Cortex-M85 CPUs implementing Armv8.1-M + Helium, and the Cortex-A55 and Cortex-A72
-CPUs implementing Armv8-A + Neon. There is an experimental model for Cortex-X/Neoverse-V cores.
+CPUs implementing Armv8-A + Neon. There is an experimental model for Cortex-X/Neoverse-V cores.
 
 SLOTHY is discussed in [Fast and Clean: Auditable high-performance assembly via constraint solving](https://eprint.iacr.org/2022/1303).
@@ -16,10 +20,10 @@ SLOTHY enables a development workflow where developers write 'clean' assembly by
 
 ### How it works
 
-SLOTHY is essentially a constraint solver frontend: It converts the input source into a data flow graph and
+SLOTHY is essentially a constraint solver frontend: It converts the input source into a data flow graph and
 builds a constraint model capturing valid instruction schedulings, register renamings, and periodic loop
-interleavings. The model is passed to an external constraint solver and, upon success,
-a satisfying assignment converted back into the final code. Currently, SLOTHY uses
+interleavings. The model is passed to an external constraint solver and, upon success,
+a satisfying assignment is converted back into the final code. Currently, SLOTHY uses
 [Google OR-Tools](https://developers.google.com/optimization) as its constraint solver backend.
 
 ### Performance
@@ -51,9 +55,11 @@ and build from scratch, e.g. as follows (also available as [submodules/setup-ort
 for convenience):
 
 ```
+% apt install -y git build-essential python3-pip cmake swig
 % git submodule init
 % git submodule update
 % cd submodules/or-tools
+% git apply ../0001-Pin-pybind11_protobuf-commit-in-cmake-files.patch
 % mkdir build
 % cmake -S. -Bbuild -DBUILD_PYTHON:BOOL=ON
 % make -C build -j8
@@ -270,4 +276,4 @@ The [examples](examples/naive) directory contains numerous exemplary assembly sn
 `python3 example.py --examples={YOUR_EXAMPLE}`. See `python3 examples.py --help` for the list of all available examples.
 
 The use of SLOTHY from the command line is illustrated in [scripts/](scripts/) supporting the real-world optimizations
-for the NTT, FFT and X25519 discussed in [Fast and Clean: Auditable high-performance assembly via constraint solving](https://eprint.iacr.org/2022/1303).
\ No newline at end of file
+for the NTT, FFT and X25519 discussed in [Fast and Clean: Auditable high-performance assembly via constraint solving](https://eprint.iacr.org/2022/1303).
diff --git a/docs/faq.md b/docs/faq.md
new file mode 100644
index 00000000..d8bc21f9
--- /dev/null
+++ b/docs/faq.md
@@ -0,0 +1,49 @@
+---
+layout: default
+---
+
+## Frequently asked questions
+
+[back](index.md)
+
+#### Is SLOTHY a peephole optimizer?
+
+No. SLOTHY is a _fixed-instruction_ superoptimizer: It keeps instructions and optimizes
+register allocation, instruction scheduling, and software pipelining. It is the developer's or another tool's
+responsibility to map the workload at hand to the target architecture.
+
+
+
+#### Is SLOTHY better than {name your favourite superoptimizer}?
+
+Most likely, they serve different purposes. SLOTHY aims to do one thing well: Optimization _after_ instruction selection.
+It is thus independent of and potentially combinable with superoptimizers operating at earlier stages of the code-generation process, such as [souper](https://github.com/google/souper) and [CryptOpt](https://github.com/0xADE1A1DE/CryptOpt).
+
+#### Does SLOTHY support x86?
+
+The core of SLOTHY is architecture- and microarchitecture-agnostic and can accommodate x86. As it stands, however,
+there is no model of the x86 architecture. Feel free to build one!
+
+#### Does SLOTHY support RISC-V?
+
+As for x86.
+
+#### Is SLOTHY formally verified?
+
+No. Arguably, that wouldn't be a good use of time. The more relevant question is the following:
+
+#### Is SLOTHY-generated code formally verified to be equivalent to the input code?
+
+Not yet. SLOTHY runs a self-check confirming that input and output have isomorphic data flow graphs,
+but pitfalls remain, such as bad user configurations allowing SLOTHY to clobber a register that's not
+meant to be reserved. More work is needed for formal verification of the equivalence of input
+and output.
+
+#### Why is my question not here?
+
+Ping us! ([GitHub](https://github.com/slothy-optimizer/slothy/issues), or see the [paper](https://eprint.iacr.org/2022/1303.pdf) for
+contact information).
\ No newline at end of file
diff --git a/docs/index.md b/docs/index.md
index 2d7d04a3..7c3d8df5 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -12,12 +12,14 @@ super-optimizes:
 `SLOTHY` enables a development workflow where developers write 'clean' assembly by hand, emphasizing the logic of the
 computation, while `SLOTHY` automates microarchitecture-specific micro-optimizations. Since `SLOTHY` does not change
 instructions, and scheduling/allocation optimizations are tightly controlled through configurable and extensible
-constraints, the developer keeps close control over the final assembly, while being freed from the most tedious and
-readability- and verifiability-impeding micro-optimizations.
+constraints, the developer keeps close control over the final assembly, while being freed from tedious
+micro-optimizations.
+
+See also the [FAQ](faq.md).
 
 #### Architecture/Microarchitecture support
 
-`SLOTHY` is generic in the target architecture and microarchitecture. So far, it supports Cortex-M55 and Cortex-M85
+`SLOTHY` is generic in the target architecture and microarchitecture. It currently supports Cortex-M55 and Cortex-M85
 implementing Armv8.1-M + Helium, and Cortex-A55 and Cortex-A72 implementing Armv8-A + Neon. Moreover, there is an
 experimental model for Cortex-X/Neoverse-V cores.
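The change set below re-packages SLOTHY as a proper Python package (`slothy/__init__.py` now re-exports `Slothy`, `SlothyException`, `Config`, and `Archery`), so the workflow described above can be scripted directly. The following is a minimal driver sketch in the style of `example.py` and `slothy-cli`; the input file `ntt.s`, the loop label `layer123_start`, and the output file name are hypothetical placeholders:

```
# Minimal SLOTHY driver sketch using the package layout introduced in this change.
# Assumptions: "ntt.s" exists and contains a loop starting at label "layer123_start".
import logging

from slothy import Slothy

import slothy.targets.arm_v81m.arch_v81m as Arch_Armv81M
import slothy.targets.arm_v81m.cortex_m55r1 as Target_CortexM55r1

slothy = Slothy(Arch_Armv81M, Target_CortexM55r1,
                logger=logging.getLogger("slothy"))

slothy.load_source_from_file("ntt.s")        # read the 'clean' input assembly
slothy.config.sw_pipelining.enabled = True   # allow periodic loop interleaving
slothy.optimize_loop("layer123_start")       # optimize the loop at this label
slothy.write_source_to_file("ntt_opt.s")     # emit the micro-optimized assembly
```

`slothy-cli` (further down in this diff) wires up essentially this flow from command-line arguments.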
diff --git a/docs/slothy_logo.png b/docs/slothy_logo.png
index 27f8bb74..00e3102e 100644
Binary files a/docs/slothy_logo.png and b/docs/slothy_logo.png differ
diff --git a/example.py b/example.py
index 1feb4800..3f48cd6e 100644
--- a/example.py
+++ b/example.py
@@ -25,31 +25,35 @@
 # Author: Hanno Becker
 #
 
-import argparse, logging, sys
-from io import StringIO
+import argparse
+import logging
+import sys
 
-from slothy.slothy import Slothy
-from slothy.core import Config
+from slothy import Slothy, Config
 
-import targets.arm_v81m.arch_v81m as Arch_Armv81M
-import targets.arm_v81m.cortex_m55r1 as Target_CortexM55r1
-import targets.arm_v81m.cortex_m85r1 as Target_CortexM85r1
+import slothy.targets.arm_v81m.arch_v81m as Arch_Armv81M
+import slothy.targets.arm_v81m.cortex_m55r1 as Target_CortexM55r1
+import slothy.targets.arm_v81m.cortex_m85r1 as Target_CortexM85r1
 
-import targets.aarch64.aarch64_neon as AArch64_Neon
-import targets.aarch64.cortex_a55 as Target_CortexA55
-import targets.aarch64.cortex_a72_frontend as Target_CortexA72
+import slothy.targets.aarch64.aarch64_neon as AArch64_Neon
+import slothy.targets.aarch64.cortex_a55 as Target_CortexA55
+import slothy.targets.aarch64.cortex_a72_frontend as Target_CortexA72
 
 target_label_dict = {Target_CortexA55: "a55",
                      Target_CortexA72: "a72",
                      Target_CortexM55r1: "m55",
                      Target_CortexM85r1: "m85"}
 
+class ExampleException(Exception):
+    """Exception thrown when an example goes wrong"""
 
 class Example():
+    """Common boilerplate for SLOTHY examples"""
+
     def __init__(self, infile, name=None, funcname=None, suffix="opt",
                  rename=False, outfile="", arch=Arch_Armv81M, target=Target_CortexM55r1,
                  **kwargs):
-        if name == None:
+        if name is None:
             name = infile
 
         self.arch = arch
@@ -61,7 +65,7 @@ def __init__(self, infile, name=None, funcname=None, suffix="opt",
             self.outfile = f"{infile}_{self.suffix}_{target_label_dict[self.target]}"
         else:
             self.outfile = f"{outfile}_{self.suffix}_{target_label_dict[self.target]}"
-        if funcname == None:
+        if funcname is None:
             self.funcname = self.infile
         subfolder = ""
         if self.arch == AArch64_Neon:
@@ -1127,8 +1131,8 @@ def run_example(name, debug=False):
         if e.name == name:
             ex = e
             break
-    if ex == None:
-        raise Exception(f"Could not find example {name}")
+    if ex is None:
+        raise ExampleException(f"Could not find example {name}")
     ex.run(debug=debug)
 
 for e in todo:
diff --git a/examples/misc/gen_roots.py b/examples/misc/gen_roots.py
index eb41a87b..ebf63201 100644
--- a/examples/misc/gen_roots.py
+++ b/examples/misc/gen_roots.py
@@ -21,28 +21,37 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
-import math, sys
+"""Helper script for the generation of twiddle factors for various NTTs"""
+
+import math
+
+class NttRootGenInvalidParameters(Exception):
+    """Invalid parameters for NTT root generation"""
 
 class NttRootGen():
+    """Helper class for the generation of NTT twiddle factors"""
 
     def __init__(self,*,
                  size,
                  modulus,
                  root,
                  layers,
-                 print_label=False,
-                 pad = [],
-                 bitsize = 16,
-                 inverse = False,
+                 print_label = False,
+                 pad = None,
+                 bitsize = 16,
+                 inverse = False,
                  vector_length = 128,
-                 word_offset_mod_4=None,
+                 word_offset_mod_4 = None,
                  incomplete_root = True,
-                 widen_single_twiddles_to_words=True,
-                 block_strided_twiddles=True,
-                 negacyclic = True,
+                 widen_single_twiddles_to_words = True,
+                 block_strided_twiddles = True,
+                 negacyclic = True,
                  iters = None):
 
+        if pad is None:
+            pad = []
 
-        assert bitsize in [16,32]
+        if bitsize not in [16, 32]:
+            raise NttRootGenInvalidParameters("Invalid bit width")
 
         self.pad = pad
         self.print_label=print_label
@@ -76,7 +85,7 @@ def __init__(self,*,
 
         # Need an odd prime modulus
         if self.modulus % 2 == 0:
-            raise Exception("Modulus must be odd")
+            raise NttRootGenInvalidParameters("Modulus must be odd")
         self._inv_mod = pow(self.modulus,-1,2**self.bitsize)
 
         # Check that we've indeed been given a root of unity of the correct order
@@ -87,12 +96,12 @@ def __init__(self,*,
 
         self.log2size = int(math.log(size,2))
         if size != pow(2,self.log2size):
-            raise Exception(f"Size {size} not a power of 2")
+            raise NttRootGenInvalidParameters(f"Size {size} not a power of 2")
 
         self.layers = layers
         self.incompleteness_factor = 2**(self.log2size - self.layers)
 
-        if iters == None:
+        if iters is None:
             if self.layers % 2 == 0:
                 self.iters = [(x,2) for x in range(0,self.layers,2)]
             else:
@@ -106,18 +115,22 @@ def __init__(self,*,
         if ( pow(root, real_root_order, modulus) != 1 or
              pow(root, real_root_order // 2, modulus) == 1 ):
-            raise Exception(f"{root} is not a primitive {real_root_order}-th root of unity modulo {modulus}")
+            raise NttRootGenInvalidParameters(f"{root} is not a primitive {real_root_order}-th "
+                                              f"root of unity modulo {modulus}")
 
         self.radixes = [2] * self.log2size
 
     def get_root_pow(self, exp):
+        """Returns specific power of base root of unity"""
+
         if not exp % self.incompleteness_factor == 0:
-            raise Exception(f"Invalid exponent {exp} for incompleteness factor {self.incompleteness_factor}")
+            raise NttRootGenInvalidParameters(f"Invalid exponent {exp} for incompleteness "
+                                              f"factor {self.incompleteness_factor}")
         if self.incomplete_root:
             exp = exp // self.incompleteness_factor
         return pow(self.root,exp,self.modulus)
 
-    def _prepare_root(self,root,layer=None):
+    def _prepare_root(self,root):
         # Force _signed_ representation of root?
         if root > self.modulus // 2:
@@ -143,6 +156,8 @@ def _bitrev_list(self,num,radix_list):
         return result
 
     def root_of_unity_for_block(self,layer,block):
+        """Returns the twiddle factor to be used for a specific layer and block"""
+
         actual_layer = layer
         if self.negacyclic:
             block += pow(2,layer)
@@ -157,10 +172,14 @@ def root_of_unity_for_block(self,layer,block):
         if self.inverse:
             log = (self.root_order - log) % self.root_order
         root = self.get_root_pow(log)
-        root, root_twisted = self._prepare_root(root,layer)
+        root, root_twisted = self._prepare_root(root)
         return root, root_twisted
 
-    def roots_of_unity_for_layer_core(self, layer, merged):
+    def _roots_of_unity_for_layer_core(self, layer, merged):
+
+        if merged not in [1,2,3,4]:
+            raise NttRootGenInvalidParameters("Invalid layer merge")
+
         for cur_block in range(0,2**layer):
             if merged == 1:
                 root, root_twisted = self.root_of_unity_for_block(layer, cur_block)
@@ -204,11 +223,13 @@ def roots_of_unity_for_layer_core(self, layer, merged):
                 if layer in self.pad:
                     yield ([root0, root1, root2, root3, root4, root5, root6, 0],
-                           [root0_tw, root1_tw, root2_tw, root3_tw, root4_tw, root5_tw, root6_tw, 0])
+                           [root0_tw, root1_tw, root2_tw, root3_tw, root4_tw,
+                            root5_tw, root6_tw, 0])
                 else:
                     yield ([root0, root1, root2, root3, root4, root5, root6],
                            [root0_tw, root1_tw, root2_tw, root3_tw, root4_tw, root5_tw, root6_tw])
-            elif merged == 4:
+            else:
+                assert merged == 4
                 # Compute the roots of unity that we need at this stage
                 fst_layer = layer + 0
                 snd_layer = layer + 1
@@ -261,10 +282,10 @@ def roots_of_unity_for_layer_core(self, layer, merged):
                         root5_tw, root6_tw, root7_tw, root8_tw, root9_tw,
                         root10_tw, root11_tw, root12_tw, root13_tw, root14_tw])
-            else:
-                raise Exception("Something went wrong")
 
     def roots_of_unity_for_layer(self, layer, merged):
+        """Generator yielding the twiddle factors for a number of merged layers"""
+
         num_blocks = 2 ** layer
         block_size = self.size // num_blocks
         butterfly_size = block_size // 2 ** merged
@@ -273,7 +294,7 @@ def roots_of_unity_for_layer(self, layer, merged):
         if butterfly_size < self.vector_length // self.bitsize:
             stride = (self.vector_length // self.bitsize) // butterfly_size
 
-        all_root_pairs = list(self.roots_of_unity_for_layer_core(layer, merged))
+        all_root_pairs = list(self._roots_of_unity_for_layer_core(layer, merged))
         all_roots = [ x[0] for x in all_root_pairs ]
         all_roots_twisted = [ x[1] for x in all_root_pairs ]
         num_pairs = len(all_root_pairs)
@@ -288,7 +309,8 @@ def roots_of_unity_for_layer(self, layer, merged):
             res = [(z,stride) for x in roots for y in x for z in y]
             yield from res
 
-    def get_roots_of_unity_core(self):
+    def get_roots_of_unity(self):
+        """Yields roots of unity for NTT"""
         iters = self.iters.copy()
         if self.inverse:
             iters.reverse()
@@ -297,17 +319,17 @@ def get_roots_of_unity_core(self):
                 yield f"roots_l{''.join([str(i) for i in range(cur_iter,cur_iter+merged)])}:"
             yield from self.roots_of_unity_for_layer(cur_iter,merged)
 
-    def get_roots_of_unity_real(self):
+    def _get_roots_of_unity_asm(self):
+        """Yields the twiddle factors as assembly source lines"""
         if self.bitsize == 16:
             twiddlesize = "short"
-        elif self.bitsize == 32:
-            twiddlesize = "word"
         else:
-            raise Exception("Should not happen")
+            assert self.bitsize == 32
+            twiddlesize = "word"
 
         count = 0
         last_stride = None
-        for x in self.get_roots_of_unity_core():
+        for x in self.get_roots_of_unity():
             if isinstance(x,str):
                 yield x
                 continue
@@ -320,7 +342,7 @@ def _get_roots_of_unity_asm(self):
             count += 1
             if stride > 1:
                 if last_stride == 1:
-                    if self.word_offset_mod_4 != None:
+                    if self.word_offset_mod_4 is not None:
                         yield f"// Word count until here: {count}"
                         cc4 = count % 4
                         diff = self.word_offset_mod_4 - cc4
@@ -338,10 +360,13 @@ def _get_roots_of_unity_asm(self):
             last_stride = stride
 
     def export(self, filename):
-        license = """
+        """Export twiddle factors as file"""
+
+        license_text = """
 ///
 /// Copyright (c) 2022 Arm Limited
 /// Copyright (c) 2022 Hanno Becker
+/// Copyright (c) 2023 Amin Abdulrahman, Matthias Kannwischer
 /// SPDX-License-Identifier: MIT
 ///
 /// Permission is hereby granted, free of charge, to any person obtaining a copy
@@ -364,18 +389,19 @@ def export(self, filename):
 ///
 """
 
-        f = open(filename,"w")
-        f.write(license)
-        f.write('\n'.join(self.get_roots_of_unity_real()))
-        f.close()
+        with open(filename, "w", encoding="utf-8") as f:
+            f.write(license_text)
+            f.write('\n'.join(self._get_roots_of_unity_asm()))
 
-def main():
+def _main():
 
-    ntt_kyber_l345 = NttRootGen(size=256,modulus=3329,root=17,layers=7,iters=[(0,2),(2,3),(5,2)], word_offset_mod_4=2)
+    ntt_kyber_l345 = NttRootGen(size=256,modulus=3329,root=17,layers=7,iters=[(0,2),(2,3),(5,2)],
+                                word_offset_mod_4=2)
     ntt_kyber_l345.export("../naive/ntt_kyber_12_345_67_twiddles.s")
     ntt_kyber_l345.export("../opt/ntt_kyber_12_345_67_twiddles.s")
 
-    ntt_kyber_l123 = NttRootGen(size=256,modulus=3329,root=17,layers=7,iters=[(0,3),(3,2),(5,2)], pad=[0,3], print_label=True, widen_single_twiddles_to_words=False)
+    ntt_kyber_l123 = NttRootGen(size=256,modulus=3329,root=17,layers=7,iters=[(0,3),(3,2),(5,2)],
+                                pad=[0,3], print_label=True, widen_single_twiddles_to_words=False)
     ntt_kyber_l123.export("../naive/ntt_kyber_123_45_67_twiddles.s")
     ntt_kyber_l123.export("../opt/ntt_kyber_123_45_67_twiddles.s")
 
@@ -387,11 +413,13 @@ def main():
     intt_kyber.export("../naive/intt_kyber_1_23_45_67_twiddles.s")
     intt_kyber.export("../opt/intt_kyber_1_23_45_67_twiddles.s")
 
-    ntt_dilithium = NttRootGen(size=256,bitsize=32,modulus=8380417,root=1753,layers=8, word_offset_mod_4=2)
+    ntt_dilithium = NttRootGen(size=256,bitsize=32,modulus=8380417,root=1753,layers=8,
+                               word_offset_mod_4=2)
     ntt_dilithium.export("../naive/ntt_dilithium_12_34_56_78_twiddles.s")
     ntt_dilithium.export("../opt/ntt_dilithium_12_34_56_78_twiddles.s")
 
-    ntt_dilithium_l1234 = NttRootGen(size=256, bitsize=32, modulus=8380417, root=1753, layers=8, iters=[(0,4),(4,2),(6,2)], pad=[0], print_label=True)
+    ntt_dilithium_l1234 = NttRootGen(size=256, bitsize=32, modulus=8380417, root=1753,
+                                     layers=8, iters=[(0,4),(4,2),(6,2)], pad=[0], print_label=True)
     ntt_dilithium_l1234.export("../naive/aarch64/ntt_dilithium_1234_5678_twiddles.s")
     ntt_dilithium_l1234.export("../opt/aarch64/ntt_dilithium_1234_5678_twiddles.s")
 
@@ -400,8 +428,8 @@ def main():
     ntt_dilithium_l123.export("../naive/ntt_dilithium_123_456_78_twiddles.s")
     ntt_dilithium_l123.export("../opt/ntt_dilithium_123_456_78_twiddles.s")
 
-    ntt_dilithium_l123 = NttRootGen(size=256,bitsize=32,modulus=8380417,root=1753,layers=8, print_label=True, pad=[0,3],
-                                    iters=[(0,3),(3,3),(6,2)])
+    ntt_dilithium_l123 = NttRootGen(size=256,bitsize=32,modulus=8380417,root=1753,layers=8,
+                                    print_label=True, pad=[0,3], iters=[(0,3),(3,3),(6,2)])
     ntt_dilithium_l123.export("../naive/aarch64/ntt_dilithium_123_456_78_twiddles.s")
     ntt_dilithium_l123.export("../opt/aarch64/ntt_dilithium_123_456_78_twiddles.s")
 
@@ -448,4 +476,4 @@ def main():
     intt_n256_s32_l8_test.export("../opt/intt_n256_l8_s32_twiddles.s")
 
 if __name__ == "__main__":
-    main()
+    _main()
diff --git a/paper/artifact/0001-Pin-pybind11_protobuf-commit-in-cmake-files.patch b/paper/artifact/0001-Pin-pybind11_protobuf-commit-in-cmake-files.patch
new file mode 100644
index 00000000..526860d5
--- /dev/null
+++ b/paper/artifact/0001-Pin-pybind11_protobuf-commit-in-cmake-files.patch
@@ -0,0 +1,25 @@
+From 3b6f6999c042322268eb3ba84e829097014b7428 Mon Sep 17 00:00:00 2001
+From: Hanno Becker
+Date: Tue, 19 Dec 2023 21:24:57 +0000
+Subject: [PATCH] Pin pybind11_protobuf commit in cmake files
+
+---
+ cmake/dependencies/CMakeLists.txt | 2 +-
+ 1 file changed, 1 insertion(+), 1 deletion(-)
+
+diff --git a/cmake/dependencies/CMakeLists.txt b/cmake/dependencies/CMakeLists.txt
+index c39a44fb89..27923ccedb 100644
+--- a/cmake/dependencies/CMakeLists.txt
++++ b/cmake/dependencies/CMakeLists.txt
+@@ -177,7 +177,7 @@ if(BUILD_PYTHON AND BUILD_pybind11_protobuf)
+   FetchContent_Declare(
+     pybind11_protobuf
+     GIT_REPOSITORY "https://github.com/pybind/pybind11_protobuf.git"
+-    GIT_TAG "main"
++    GIT_TAG "5baa2dc9d93e3b608cde86dfa4b8c63aeab4ac78"
+     PATCH_COMMAND git apply --ignore-whitespace "${CMAKE_CURRENT_LIST_DIR}/../../patches/pybind11_protobuf.patch"
+   )
+   FetchContent_MakeAvailable(pybind11_protobuf)
+--
+2.39.3 (Apple Git-145)
+
diff --git a/paper/artifact/slothy.Dockerfile b/paper/artifact/slothy.Dockerfile
index bc93a1ff..ae50ccc1 100644
--- a/paper/artifact/slothy.Dockerfile
+++ b/paper/artifact/slothy.Dockerfile
@@ -33,6 +33,8 @@ RUN unzip or-tools.zip
 RUN rm or-tools.zip
 RUN mv or-tools-9.7 or-tools
 WORKDIR /home/ubuntu/slothy/submodules/or-tools
+COPY 0001-Pin-pybind11_protobuf-commit-in-cmake-files.patch .
+RUN git apply 0001-Pin-pybind11_protobuf-commit-in-cmake-files.patch
 RUN mkdir /home/ubuntu/slothy/submodules/or-tools/build
 RUN cmake -S. -Bbuild -DBUILD_PYTHON:BOOL=ON -DBUILD_SAMPLES:BOOL=OFF -DBUILD_EXAMPLES:BOOL=OFF
 WORKDIR /home/ubuntu/slothy/submodules/or-tools/build
@@ -45,4 +47,4 @@ RUN ln -s /home/ubuntu/slothy /home/ubuntu/pqax/slothy
 RUN rm -rf /home/ubuntu/pqmx/slothy
 RUN ln -s /home/ubuntu/slothy /home/ubuntu/pqmx/slothy
 WORKDIR /home/ubuntu
-RUN ln -s /home/ubuntu/slothy/paper/README.md /home/ubuntu/README.md
\ No newline at end of file
+RUN ln -s /home/ubuntu/slothy/paper/README.md /home/ubuntu/README.md
diff --git a/paper/clean/neon/X25519-AArch64-simple.s b/paper/clean/neon/X25519-AArch64-simple.s
index d7e9d6b8..a43f2841 100644
--- a/paper/clean/neon/X25519-AArch64-simple.s
+++ b/paper/clean/neon/X25519-AArch64-simple.s
@@ -113,7 +113,7 @@
 .endm
 
 # TODO: also unwrap
-.macro fcsel_dform out, in0, in1, cond // slothy:no-unfold
+.macro fcsel_dform out, in0, in1, cond // @slothy:no-unfold
     fcsel dform_\out, dform_\in0, dform_\in1, \cond
 .endm
@@ -416,10 +416,10 @@ sZ48 .req x22
     stack_vstr_dform \offset\()_32, \vA\()8
 .endm
 
-.macro vector_load_lane vA, offset, lane
 // TODO: eliminate this explicit register assignment by converting stack_vld2_lane to AArch64Instruction
 xvector_load_lane_tmp .req x26
+.macro vector_load_lane vA, offset, lane
     add xvector_load_lane_tmp, sp, #\offset\()_0
     stack_vld2_lane \vA\()0, \vA\()1, xvector_load_lane_tmp, \offset\()_0, \lane, 8
     stack_vld2_lane \vA\()2, \vA\()3, xvector_load_lane_tmp, \offset\()_8, \lane, 8
@@ -591,8 +591,6 @@ sZ48 .req x22
     scalar_decompress_inner \sA\()0, \sA\()1, \sA\()2, \sA\()3, \sA\()4, \sA\()5, \sA\()6, \sA\()7, \sA\()8, \sA\()9
 .endm
 
-.macro vector_addsub_repack_inner vA0, vA1, vA2, vA3, vA4, vA5, vA6, vA7, vA8, vA9, \
-                                  vC0, vC1, vC2, vC3, vC4, vC5, vC6, vC7, vC8, vC9
 // TODO: eliminate those. should be easy
 vR_l4h4l5h5 .req vADBC4
 vR_l6h6l7h7 .req vADBC5
@@ -620,6 +618,8 @@ sZ48 .req x22
 vrepack_inner_tmp .req v19
 vrepack_inner_tmp2 .req v0
 
+.macro vector_addsub_repack_inner vA0, vA1, vA2, vA3, vA4, vA5, vA6, vA7, vA8, vA9, \
+                                  vC0, vC1, vC2, vC3, vC4, vC5, vC6, vC7, vC8, vC9
     vuzp1 vR_l4h4l5h5, \vC4, \vC5
     vuzp1 vR_l6h6l7h7, \vC6, \vC7
     stack_vld1r vrepack_inner_tmp, STACK_MASK1
@@ -949,6 +949,8 @@ scalar_mul_inner \
         \sB\()0, \sB\()1, \sB\()2, \sB\()3, \sB\()4, \sB\()5, \sB\()6, \sB\()7, \sB\()8, \sB\()9
 .endm
 
+xtmp_scalar_sub_0 .req x21
+
 // sC0 .. sC4 output C = A + 4p - B (registers may be the same as A)
 // sA0 .. sA4 first operand A
 // sB0 .. sB4 second operand B
@@ -957,8 +959,6 @@ scalar_mul_inner \
         sA0, sA1, sA2, sA3, sA4, \
         sB0, sB1, sB2, sB3, sB4
 
-    xtmp_scalar_sub_0 .req x21
-
     ldr xtmp_scalar_sub_0, #=0x07fffffe07fffffc
     add \sC1, \sA1, xtmp_scalar_sub_0
     add \sC2, \sA2, xtmp_scalar_sub_0
diff --git a/paper/scripts/slothy_ntt_helium.py b/paper/scripts/slothy_ntt_helium.py
index 202054f6..14a34a0e 100644
--- a/paper/scripts/slothy_ntt_helium.py
+++ b/paper/scripts/slothy_ntt_helium.py
@@ -28,12 +28,11 @@
 import argparse, logging, sys, os, time
 from io import StringIO
 
-from slothy.slothy import Slothy
-from slothy.core import Config
+from slothy import Slothy, Config
 
-import targets.arm_v81m.arch_v81m as Arch_Armv81M
-import targets.arm_v81m.cortex_m55r1 as Target_CortexM55r1
-import targets.arm_v81m.cortex_m85r1 as Target_CortexM85r1
+import slothy.targets.arm_v81m.arch_v81m as Arch_Armv81M
+import slothy.targets.arm_v81m.cortex_m55r1 as Target_CortexM55r1
+import slothy.targets.arm_v81m.cortex_m85r1 as Target_CortexM85r1
 
 target_label_dict = {Target_CortexM55r1: "m55",
                      Target_CortexM85r1: "m85"}
@@ -42,7 +41,7 @@ class Example():
     def __init__(self, infile, name=None, funcname=None, suffix="opt",
                  rename=False, outfile="", arch=Arch_Armv81M, target=Target_CortexM55r1,
                  **kwargs):
-        if name == None:
+        if name is None:
             name = infile
 
         self.arch = arch
@@ -54,7 +53,7 @@ def __init__(self, infile, name=None, funcname=None, suffix="opt",
             self.outfile = f"{infile}_{self.suffix}_{target_label_dict[self.target]}"
         else:
             self.outfile = f"{outfile}_{self.suffix}_{target_label_dict[self.target]}"
-        if funcname == None:
+        if funcname is None:
             self.funcname = self.infile
         self.infile_full = f"../clean/helium/ntt/{self.infile}.s"
         self.outfile_full = f"../opt/helium/ntt/{self.outfile}.s"
diff --git a/slothy-cli b/slothy-cli
index b6408c8f..ff8ea9e7 100755
--- a/slothy-cli
+++ b/slothy-cli
@@ -28,47 +28,52 @@
 import logging
 import time
 import os
 
-from slothy.slothy import Slothy
-from slothy.config import Config as SlothyConfig
-from targets.query import Archery
+from slothy import Slothy, Archery
 
-def main(argv):
+class CmdLineException(Exception):
+    """Exception thrown when a problem is encountered with the command line parameters"""
+
+def _main():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument("arch", type=str,
-                        choices=Archery.list_archs(), help="The target architecture")
-    parser.add_argument("target", type=str,
-                        choices=Archery.list_targets(), help="The target microarchitecture")
+    parser.add_argument("arch", type=str, choices=Archery.list_archs(),
+                        help="The target architecture")
+    parser.add_argument("target", type=str, choices=Archery.list_targets(),
+                        help="The target microarchitecture")
     parser.add_argument("input", type=str,
-                        help="The name of the assembly source file.")
+                        help="The name of the assembly source file.")
     parser.add_argument("-d", "--debug", default=False, action='store_true',
-                       help="Show debug output")
+                        help="Show debug output")
     parser.add_argument("-o", "--output", type=str, default=None,
-                       help="The name of the file to write the generated assembly to. "
-                       "If unspecified, the assembly will be printed on the standard output.")
+                        help="The name of the file to write the generated assembly to. "
+                        "If unspecified, the assembly will be printed on the standard output.")
-    parser.add_argument("-c", "--config", default=[], action="append", nargs='*', metavar="OPTION=VALUE",
-                        help="""A (potentially empty) list of modifications to the default configuration of Slothy.""")
+    parser.add_argument("-c", "--config", default=[], action="append", nargs='*',
+                        metavar="OPTION=VALUE",
+                        help="""A (potentially empty) list of modifications
+                        to the default configuration of Slothy.""")
     parser.add_argument("-l", "--loop", default=[], action='append', type=str,
-                       help="""The starting label for the loop to optimize. This is mutually
-                       exclusive with -s/--start and -e/--end, which allowv you to specify
-                       the code to optimize via start/end separately.""")
+                        help="""The starting label for the loop to optimize. This is mutually
+                        exclusive with -s/--start and -e/--end, which allow you to specify
+                        the code to optimize via start/end separately.""")
+    parser.add_argument("--fusion", default=False, action='store_true')
+    parser.add_argument("--fusion-only", default=False, action='store_true')
     parser.add_argument("-s", "--start", default=None, type=str,
-                       help="""The label or line at which the to code to optimize begins.
-                       This is mutually exclusive with -l/--loop.""")
+                        help="""The label or line at which the code to optimize begins.
+                        This is mutually exclusive with -l/--loop.""")
     parser.add_argument("-e", "--end", default=None, type=str,
-                       help="""The label or line at which the to code to optimize ends
-                       This is mutually exclusive with -l/--loop.""")
+                        help="""The label or line at which the code to optimize ends.
+                        This is mutually exclusive with -l/--loop.""")
     parser.add_argument("-r", "--rename-function", default=None, type=str,
-                       help="""Perform function renaming. Format: 'old_func_name,new_func_name'""")
+                        help="""Perform function renaming. Format: 'old_func_name,new_func_name'""")
     parser.add_argument("--silent", default=False, action='store_true',
-                       help="""Silent mode: Only print warnings and errors""")
+                        help="""Silent mode: Only print warnings and errors""")
     parser.add_argument("--log", default=False, action='store_true',
-                       help="""Write logging output to file""")
+                        help="""Write logging output to file""")
     parser.add_argument("--logdir", default=".", type=str,
-                       help="""Directory to store log output to""")
+                        help="""Directory to store log output to""")
     parser.add_argument("--logfile", default=None, type=str,
-                       help="""File to write logging output to. Can be omitted, in which case a generic name with timestamp is used""")
+                        help="""File to write logging output to. Can be omitted,
+                        in which case a generic name with timestamp is used""")
 
     args = parser.parse_args()
@@ -116,51 +121,51 @@ def main(argv):
 
     logger = logging.getLogger("slothy-cli")
 
-    Arch = Archery.get_arch(args.arch)
-    Target = Archery.get_target(args.target)
-    slothy = Slothy(Arch,Target,logger=logger)
+    arch = Archery.get_arch(args.arch)
+    target = Archery.get_target(args.target)
+    slothy = Slothy(arch,target,logger=logger)
 
     def parse_config_value_as(val, ty):
         def parse_as_float(val):
             try:
                 res = float(val)
                 return res
-            except:
+            except ValueError:
                 return None
         def check_ty(ty_real):
-            if ty == None or ty == type(None) or ty == ty_real:
+            if ty is None or ty == type(None) or ty == ty_real:
                 return
-            raise Exception(f"Configuration value {val} isn't correctly typed -- " \
-                            f"expected {ty}, but got {ty_real}")
+            raise CmdLineException(f"Configuration value {val} isn't correctly typed -- " \
+                                   f"expected {ty}, but got {ty_real}")
         if val == "":
-            raise Exception("Invalid configuration value")
-        logger.debug(f"Parsing configuration value {val} with expected type {ty}")
+            raise CmdLineException("Invalid configuration value")
+        logger.debug("Parsing configuration value %s with expected type %s", val, ty)
         if val.isdigit():
             check_ty(int)
-            logger.debug(f"Value {val} parsed as integer")
+            logger.debug("Value %s parsed as integer", val)
             return int(val)
         if val.lower() == "true":
             check_ty(bool)
-            logger.debug(f"Value {val} parsed as Boolean")
+            logger.debug("Value %s parsed as Boolean", val)
             return True
         if val.lower() == "false":
             check_ty(bool)
-            logger.debug(f"Value {val} parsed as Boolean")
+            logger.debug("Value %s parsed as Boolean", val)
             return False
         # Try to parse as RegisterType
-        ty = Arch.RegisterType.from_string(val)
-        if ty != None:
-            logger.debug(f"Value {val} parsed as RegisterType")
+        ty = arch.RegisterType.from_string(val)
+        if ty is not None:
+            logger.debug("Value %s parsed as RegisterType", val)
             return ty
         f = parse_as_float(val)
-        if f != None:
+        if f is not None:
             check_ty(float)
-            logger.debug(f"Value {val} parsed as float")
+            logger.debug("Value %s parsed as float", val)
             return f
         if val[0] == '[' and val[-1] == ']':
             check_ty(list)
             val = val[1:-1].split(',')
-            logger.debug(f"Parsing {val} is a list -- parse recursively")
+            logger.debug("Parsing %s is a list -- parse recursively", val)
             return [ parse_config_value_as(v,None) for v in val ]
         if val[0] == '{' and val[-1] == '}':
             check_ty(dict)
@@ -168,11 +173,11 @@ def main(argv):
             kvs = [ kv.split(':') for kv in kvs ]
             for kv in kvs:
                 if not len(kv) == 2:
-                    raise Exception("Invalid dictionary entry")
-            logger.debug(f"Parsing {val} is a dictionary -- parse recursively")
+                    raise CmdLineException("Invalid dictionary entry")
+            logger.debug("Parsing %s is a dictionary -- parse recursively", val)
             return { parse_config_value_as(k, None) : parse_config_value_as(v, None)
                      for k,v in kvs }
-        logger.debug(f"Parsing {val} as string")
+        logger.debug("Parsing %s as string", val)
         return val
 
     # A plain '-c' without arguments should list all available configuration options
@@ -196,14 +201,14 @@ def main(argv):
             obj = getattr(obj,attrs.pop(0))
         attr = attrs.pop(0)
         val = parse_config_value_as(val, type(getattr(obj,attr)))
-        logger.info(f"- Setting configuration option {attr} to value {val}")
+        logger.info("Setting configuration option %s to value %s", attr, val)
         setattr(obj,attr,val)
 
-    def check_list_of_fixed_len_list(lst, fixlen):
+    def check_list_of_fixed_len_list(lst):
         invalid = next(filter(lambda o: len(o) != 1, lst), None)
-        if invalid != None:
-            raise Exception(f"Invalid configuration argument {invalid} in {lst}")
-    check_list_of_fixed_len_list(args.config,1)
+        if invalid is not None:
+            raise CmdLineException(f"Invalid configuration argument {invalid} in {lst}")
+    check_list_of_fixed_len_list(args.config)
     config_kv_pairs = [ c[0].split('=') for c in args.config ]
     for kv in config_kv_pairs:
         # We allow shorthands for boolean configurations
@@ -222,17 +227,23 @@ def main(argv):
         elif len(kv) == 2:
             setattr_recursive(slothy.config, kv[0], kv[1])
         else:
-            raise Exception(f"Invalid configuration {kv}")
+            raise CmdLineException(f"Invalid configuration {kv}")
 
     # Read input
     slothy.load_source_from_file(args.input)
 
     # Optimize
-    if len(args.loop) > 0:
-        for l in args.loop:
-            slothy.optimize_loop(l)
-    else:
-        slothy.optimize(start=args.start, end=args.end)
+    if args.fusion is True:
+        if len(args.loop) > 0:
+            for l in args.loop:
+                slothy.fusion_loop(l)
+
+    if not (args.fusion is True and args.fusion_only is True):
+        if len(args.loop) > 0:
+            for l in args.loop:
+                slothy.optimize_loop(l)
+        else:
+            slothy.optimize(start=args.start, end=args.end)
 
     # Rename
     if args.rename_function:
@@ -245,7 +256,7 @@ def main(argv):
     if args.output is not None:
         slothy.write_source_to_file(args.output)
     else:
-        slothy.print_code()
+        print(slothy.get_source_as_string())
 
 if __name__ == "__main__":
-    main(sys.argv[1:])
+    _main()
diff --git a/slothy/__init__.py b/slothy/__init__.py
index e69de29b..51547d11 100644
--- a/slothy/__init__.py
+++ b/slothy/__init__.py
@@ -0,0 +1,4 @@
+from slothy.core.slothy import Slothy
+from slothy.core.core import SlothyException
+from slothy.core.config import Config
+from slothy.targets.query import Archery
diff --git a/targets/__init__.py b/slothy/core/__init__.py
similarity index 100%
rename from targets/__init__.py
rename to slothy/core/__init__.py
diff --git a/slothy/config.py b/slothy/core/config.py
similarity index 78%
rename from slothy/config.py
rename to slothy/core/config.py
index d7df3294..e0b32da3 100644
--- a/slothy/config.py
+++ b/slothy/core/config.py
@@ -25,6 +25,12 @@
 # Author: Hanno Becker
 #
 
+"""
+SLOTHY configuration
+"""
+
+# pylint:disable=too-many-lines
+
 from copy import deepcopy
 import os
 
@@ -39,31 +45,6 @@ class Config(NestedPrint, LockAttributes):
     This configuration object is used both for one-shot optimizations using
     SlothyBase, as well as stateful multi-pass optimizations using Slothy."""
 
-    _default_split_heuristic = False
-    _default_split_heuristic_visualize_stalls = False
-    _default_split_heuristic_visualize_units = False
-    _default_split_heuristic_region = [0.0,1.0]
-    _default_split_heuristic_chunks = False
-    _default_split_heuristic_optimize_seam = 0
-    _default_split_heuristic_bottom_to_top = False
-    _default_split_heuristic_factor = 2
-    _default_split_heuristic_abort_cycle_at = None
-    _default_split_heuristic_stepsize = None
-    _default_split_heuristic_repeat = 1
-    _default_split_heuristic_preprocess_naive_interleaving = False
-    _default_split_heuristic_preprocess_naive_interleaving_by_latency = False
-
-    _default_compiler_binary = "gcc"
-
-    _default_unsafe_skip_address_fixup = False
-
-    _default_with_preprocessor = False
-    _default_max_solutions = 64
-    _default_timeout = None
-    _default_retry_timeout = None
-    _default_ignore_objective = False
-    _default_objective_precision = 0
-
     @property
     def arch(self):
         """The module defining the underlying architecture used by Slothy.
@@ -98,6 +79,97 @@ def reserved_regs(self):
             return self._reserved_regs
         return self._arch.RegisterType.default_reserved()
 
+    @property
+    def selfcheck(self):
+        """Indicates whether SLOTHY performs a self-check on the optimization result.
+
+        The selfcheck confirms that the scheduling permutation found by SLOTHY yields
+        an isomorphism between the data flow graphs of the original and optimized code.
+
+        WARNING: Do not unset this option unless you know what you are doing.
+        It is vital in catching bugs in the model generation early.
+
+        WARNING: The selfcheck is not a formal verification of SLOTHY's output!
+        There are at least two classes of bugs uncaught by the selfcheck:
+
+        - User configuration issues: The selfcheck validates SLOTHY's optimization
+          in the context of the provided configuration. Validation of the configuration
+          is the user's responsibility. Two common pitfalls include missing reserved
+          registers (allowing SLOTHY to clobber more registers than intended), or
+          missing output registers (allowing SLOTHY to overwrite an output register
+          in subsequent instructions).
+
+          This is the most common source of issues for code passing the selfcheck
+          but remaining functionally incorrect.
+
+        - Bugs in address offset fixup: SLOTHY's modelling of post-load/store address
+          increments is deliberately inaccurate to allow for reordering of such instructions
+          leveraging commutativity relations such as
+
+          ```
+          LDR X,[A],#imm; STR Y,[A]    ===    STR Y,[A, #imm]; LDR X,[A],#imm
+          ```
+
+          (See also section "Address offset rewrites" in the SLOTHY paper).
+
+          Bugs in SLOTHY's address fixup logic would not be caught by the selfcheck.
+          If your code doesn't work and you are sure to have configured SLOTHY correctly,
+          you may therefore want to double-check that address offsets have been adjusted
+          correctly by SLOTHY.
+        """
+        return self._selfcheck
+
+    @property
+    def allow_useless_instructions(self):
+        """Indicates whether SLOTHY should tolerate unused instructions.
+
+        SLOTHY requires explicit knowledge of the intended output registers of its
+        input assembly. Unless this option is set, SLOTHY will flag and abort upon
+        encountering an instruction which writes to a register which (a) is not an
+        output register and (b) is not used by any later instruction.
+
+        The reason for this behaviour is that such unused instructions are usually
+        a sign of a buggy configuration, which would likely lead to intended output
+        registers being clobbered by later instructions.
+
+        WARNING: Don't enable this option unless you know what you are doing!
+        Enabling it makes it much easier to overlook configuration
+        issues in SLOTHY and can lead to hard-to-debug optimization failures.
+        """
+        return self._allow_useless_instructions
+
+    @property
+    def variable_size(self):
+        """Model number of stalls as a parameter in the constraint model.
+
+        If this is set, one-shot SLOTHY optimization will make the number of stalls
+        flexible in the model and, by default, task the underlying constraint solver
+        to minimize it.
+
+        If this is not set, one-shot SLOTHY optimizations will search for solutions
+        with a fixed number of stalls, and an external binary search will be used to
+        find the minimum number of stalls.
+
+        For small-to-medium sized assembly input, this option should be set, and will
+        lead to faster optimization. For large assembly input, the user should experiment
+        and consider unsetting it to reduce model complexity.
+ """ + return self._variable_size + + @property + def keep_tags(self): + """Indicates whether tags in the input source should be kept or removed. + + Tags include pre/core/post or ordering annotations that usually become meaningless + post-optimization. However, for preprocessing runs that do not reorder code, it makes + sense to keep them.""" + return self._keep_tags + + @property + def ignore_tags(self): + """Indicates whether tags in the input source should be ignored.""" + return self._ignore_tags + @property def register_aliases(self): """Dictionary mapping symbolic register names to architectural register names. @@ -110,6 +182,7 @@ def register_aliases(self): return { **self._register_aliases, **self._arch.RegisterType.default_aliases() } def add_aliases(self, new_aliases): + """Add further register aliases to the configuration""" self._register_aliases = { **self._register_aliases, **new_aliases } @property @@ -222,7 +295,7 @@ def compiler_binary(self): """The compiler binary to be used. This is only relevant of `with_preprocessor` is set.""" - return self._default_compiler_binary + return self._compiler_binary @property def timeout(self): @@ -238,11 +311,31 @@ def retry_timeout(self): return self._retry_timeout @property - def unsafe_skip_address_fixup(self): - """Warn but not fail if post-optimization address fixup failed. - - (See 4.13, Address offset rewrites, in https://eprint.iacr.org/2022/1303.pdf)""" - return self._unsafe_skip_address_fixup + def do_address_fixup(self): + """Indicates whether post-optimization address fixup should be conducted. + + SLOTHY's modelling of post-load/store address increments is deliberately + inaccurate to allow for reordering of such instructions leveraging commutativity + relations such as: + + ``` + LDR X,[A],#imm; STR Y,[A] === STR Y,[A, #imm]; LDR X,[A],#imm + ``` + + When such reordering happens, a "post-optimization address fixup" of immediate + load/store offsets is necessary. See also section "Address offset rewrites" in + the SLOTHY paper. + + Disabling this option will skip post-optimization address fixup and put the + burden of post-optimization address fixup on the user. + Disabling this option does NOT tighten the constraint model to forbid reorderings + such as the above. + + WARNING: Don't disable this option unless you know what you are doing! + Disabling this will likely lead to optimized code that is functionally incorrect + and needing manual address offset fixup! + """ + return self._do_address_fixup @property def ignore_objective(self): @@ -321,6 +414,9 @@ def split_heuristic_stepsize(self): @property def split_heuristic_optimize_seam(self): + """If the split heuristic is used, the number of instructions above and beyond + the current sliding window that should be fixed but taken into account during + optimization.""" if not self.split_heuristic: raise InvalidConfig("Did you forget to set config.split_heuristic=True? "\ "Shouldn't read config.split_heuristic_optimize_seam otherwise.") @@ -337,27 +433,13 @@ def split_heuristic_chunks(self): @property def split_heuristic_bottom_to_top(self): + """If the split heuristic is used, move the sliding window from bottom to top + rather than from top to bottom.""" if not self.split_heuristic: raise InvalidConfig("Did you forget to set config.split_heuristic=True? 
"\ "Shouldn't read config.split_heuristic_bottom_to_top otherwise.") return self._split_heuristic_bottom_to_top - @property - def split_heuristic_visualize_stalls(self): - """Attempt to visualize the stalls after application of the split heuristic""" - if not self.split_heuristic: - raise InvalidConfig("Did you forget to set config.split_heuristic=True? "\ - "Shouldn't read config.split_heuristic_visualize_stalls otherwise.") - return self._split_heuristic_visualize_stalls - - @property - def split_heuristic_visualize_units(self): - """Attempt to visualize the functional units after application of the split heuristic""" - if not self.split_heuristic: - raise InvalidConfig("Did you forget to set config.split_heuristic=True? "\ - "Shouldn't read config.split_heuristic_visualize_units otherwise.") - return self._split_heuristic_visualize_units - @property def split_heuristic_region(self): """Restrict the split heuristic to a sub-region of the code. @@ -369,7 +451,7 @@ def split_heuristic_region(self): if the split region is set fo [0.25, 0.75] and the split factor is 5, then optimization windows of size .1 will be considered within [0.25, 0.75]. - Note that even if this option is used, the specification of inputs and outputs is still + Note that even if this option is used, the specification of inputs and outputs is still with respect to the entire code; SLOTHY will automatically derive the outputs of the subregion configured here.""" if not self.split_heuristic: @@ -388,21 +470,22 @@ def split_heuristic_preprocess_naive_interleaving(self): optimization.""" if not self.split_heuristic: raise InvalidConfig("Did you forget to set config.split_heuristic=True? "\ - "Shouldn't read config.split_heuristic_preprocess_naive_interleaving otherwise.") + "Shouldn't read config.split_heuristic_preprocess_naive_interleaving otherwise.") return self._split_heuristic_preprocess_naive_interleaving @property def split_heuristic_preprocess_naive_interleaving_by_latency(self): - """If split heuristic with naive preprocessing is used, this option causes the naive interleaving - to be by latency-depth rather than latency.""" + """If split heuristic with naive preprocessing is used, this option causes + the naive interleaving to be by latency-depth rather than latency.""" if not self.split_heuristic: - raise InvalidConfig("Did you forget to set config.split_heuristic=True? "\ - "Shouldn't read config.split_heuristic_preprocess_naive_interleaving_by_latency otherwise.") + raise InvalidConfig("Did you forget to set config.split_heuristic=True? 
Shouldn't" \ + "read config.split_heuristic_preprocess_naive_interleaving_by_latency otherwise.") return self._split_heuristic_preprocess_naive_interleaving_by_latency - # TODO: Consider setting this to True unconditionally @property def flexible_lifetime_start(self): + """Internal property indicating whether the lifetime interval of a register + should be allowed to extend _before_ the instructions which uses it.""" return \ self.constraints.maximize_register_lifetimes or \ (self.sw_pipelining.enabled and self.sw_pipelining.allow_post) @@ -436,22 +519,6 @@ def copy(self): class SoftwarePipelining(NestedPrint, LockAttributes): """Subconfiguration for software pipelining""" - _default_enabled = False - _default_unroll = 1 - _default_pre_before_post = False - _default_allow_pre = True - _default_allow_post = False - _default_unknown_iteration_count = False - _default_minimize_overlapping = True - _default_optimize_preamble = True - _default_optimize_postamble = True - _default_max_overlapping = None - _default_min_overlapping = None - _default_halving_heuristic = False - _default_halving_heuristic_periodic = False - _default_halving_heuristic_split_only = False - _default_max_pre = 1.0 - @property def enabled(self): """Determines whether software pipelining should be enabled.""" @@ -464,14 +531,14 @@ def unroll(self): @property def pre_before_post(self): - """If both early and late instructions are allowed, force late instructions of iteration N - to come _before_ early instructions of iteration N+2.""" + """If both early and late instructions are allowed, force late instructions + of iteration N to come _before_ early instructions of iteration N+2.""" return self._pre_before_post @property def allow_pre(self): - """Allow 'early' instructions, that is, instructions that are pulled forward from iteration N+1 - to iteration N. A typical example would be an early load.""" + """Allow 'early' instructions, that is, instructions that are pulled forward + from iteration N+1 to iteration N. 
A typical example would be an early load.""" return self._allow_pre @property @@ -553,36 +620,21 @@ def max_pre(self): def __init__(self): super().__init__() - self._enabled = \ - Config.SoftwarePipelining._default_enabled - self._unroll = \ - Config.SoftwarePipelining._default_unroll - self._pre_before_post = \ - Config.SoftwarePipelining._default_pre_before_post - self._allow_pre = \ - Config.SoftwarePipelining._default_allow_pre - self._allow_post = \ - Config.SoftwarePipelining._default_allow_post - self._unknown_iteration_count = \ - Config.SoftwarePipelining._default_unknown_iteration_count - self._minimize_overlapping = \ - Config.SoftwarePipelining._default_minimize_overlapping - self._optimize_preamble = \ - Config.SoftwarePipelining._default_optimize_preamble - self._optimize_postamble = \ - Config.SoftwarePipelining._default_optimize_postamble - self._max_overlapping = \ - Config.SoftwarePipelining._default_max_overlapping - self._min_overlapping = \ - Config.SoftwarePipelining._default_min_overlapping - self._halving_heuristic = \ - Config.SoftwarePipelining._default_halving_heuristic - self._halving_heuristic_periodic = \ - Config.SoftwarePipelining._default_halving_heuristic_periodic - self._halving_heuristic_split_only = \ - Config.SoftwarePipelining._default_halving_heuristic_split_only - self._max_pre = \ - Config.SoftwarePipelining._default_max_pre + self.enabled = False + self.unroll = 1 + self.pre_before_post = False + self.allow_pre = True + self.allow_post = False + self.unknown_iteration_count = False + self.minimize_overlapping = True + self.optimize_preamble = True + self.optimize_postamble = True + self.max_overlapping = None + self.min_overlapping = None + self.halving_heuristic = False + self.halving_heuristic_periodic = False + self.halving_heuristic_split_only = False + self.max_pre = 1.0 self.lock() @@ -635,19 +687,6 @@ def max_pre(self,val): class Constraints(NestedPrint, LockAttributes): """Subconfiguration for performance constraints""" - _default_stalls_allowed = 0 - _default_stalls_maximum_attempt = 512 - _default_stalls_minimum_attempt = 0 - _default_stalls_precision = 0 - _default_stalls_timeout_below_precision = None - _default_stalls_first_attempt = 0 - - _default_model_latencies = True - _default_model_functional_units = True - _default_allow_reordering = True - _default_allow_renaming = True - _default_restricted_renaming = None - @property def stalls_allowed(self): """The number of stalls allowed. Internally, this is the number of NOP @@ -697,7 +736,7 @@ def stalls_first_attempt(self): def stalls_precision(self): """The precision of the binary search for the minimum number of stalls - Slothy will stop searching if it can narrow down the minimum number + SLOTHY will stop searching if it can narrow down the minimum number of stalls to an interval of the length provided by this variable. 
             In particular, a value of 1 means the true minimum is searched for."""
             if self.functional_only:
@@ -706,6 +745,9 @@ def stalls_precision(self):
             return self._stalls_precision
 
         @property
         def stalls_timeout_below_precision(self):
+            """If this variable is set to a non-None value, SLOTHY does not abort
+            optimization once binary search is operating on an interval smaller than
+            the stall precision, but instead sets a different (typically smaller) timeout."""
             return self._stalls_timeout_below_precision
 
         @property
@@ -746,10 +788,6 @@ def allow_renaming(self):
             in order to find the number of model violations in a piece of code."""
             return self._allow_renaming
 
-        @property
-        def restricted_renaming(self):
-            return self._restricted_renaming
-
         def __init__(self):
             super().__init__()
 
@@ -767,18 +805,17 @@ def __init__(self):
             self.minimize_use_of_extra_registers = None
             self.allow_extra_registers = {}
 
-            self._model_latencies = Config.Constraints._default_model_latencies
-            self._model_functional_units = Config.Constraints._default_model_functional_units
-            self._allow_reordering = Config.Constraints._default_allow_reordering
-            self._allow_renaming = Config.Constraints._default_allow_renaming
-            self._restricted_renaming = Config.Constraints._default_restricted_renaming
+            self._stalls_allowed = 0
+            self._stalls_maximum_attempt = 512
+            self._stalls_minimum_attempt = 0
+            self._stalls_precision = 0
+            self._stalls_timeout_below_precision = None
+            self._stalls_first_attempt = 0
 
-            self._stalls_allowed = Config.Constraints._default_stalls_allowed
-            self._stalls_maximum_attempt = Config.Constraints._default_stalls_maximum_attempt
-            self._stalls_minimum_attempt = Config.Constraints._default_stalls_minimum_attempt
-            self._stalls_first_attempt = Config.Constraints._default_stalls_first_attempt
-            self._stalls_precision = Config.Constraints._default_stalls_precision
-            self._stalls_timeout_below_precision = Config.Constraints._default_stalls_timeout_below_precision
+            self._model_latencies = True
+            self._model_functional_units = True
+            self._allow_reordering = True
+            self._allow_renaming = True
 
             self.lock()
 
@@ -812,9 +849,6 @@ def allow_reordering(self,val):
         @allow_renaming.setter
         def allow_renaming(self,val):
             self._allow_renaming = val
-        @restricted_renaming.setter
-        def restricted_renaming(self,val):
-            self._restricted_renaming = val
         @functional_only.setter
         def functional_only(self,val):
             if not val:
@@ -825,11 +859,6 @@ def functional_only(self,val):
     class Hints(NestedPrint, LockAttributes):
         """Subconfiguration for solver hints"""
 
-        _default_all_core = True
-        _default_order_hint_orig_order = False
-        _default_rename_hint_orig_rename = False
-        _default_ext_bsearch_remember_successes = False
-
         @property
         def all_core(self):
             """When SW pipelining is used, hint that all instructions
@@ -850,15 +879,20 @@ def rename_hint_orig_rename(self):
 
         @property
         def ext_bsearch_remember_successes(self):
+            """When using an external binary search, hint the previously successful
+            optimization.
+
+            See also Config.variable_size."""
             return self._ext_bsearch_remember_successes
 
         def __init__(self):
             super().__init__()
 
-            self._all_core = Config.Hints._default_all_core
-            self._order_hint_orig_order = Config.Hints._default_order_hint_orig_order
-            self._rename_hint_orig_rename = Config.Hints._default_rename_hint_orig_rename
-            self._ext_bsearch_remember_successes = Config.Hints._default_ext_bsearch_remember_successes
+            self._all_core = True
+            self._order_hint_orig_order = False
+            self._rename_hint_orig_rename = False
+            self._ext_bsearch_remember_successes = False
+
             self.lock()
 
         @all_core.setter
@@ -881,14 +915,7 @@ def __init__(self, Arch, Target):
         self._constraints = Config.Constraints()
         self._hints = Config.Hints()
 
-        # NOTE: - This saves us from having to do a binary search for the minimum
-        #         number of stalls ourselves, but it seems to slow down the tool
-        #         significantly!
-        #       - It also disables the minimization of instruction overlapping
-        #         in loop mode.
-        #
-        # Rather keep it off for now...
-        self.variable_size = False
+        self._variable_size = False
 
         self._register_aliases = {}
         self._outputs = set()
@@ -900,40 +927,38 @@ def __init__(self, Arch, Target):
         self._locked_registers = []
         self._reserved_regs = None
 
-        self.selfcheck = True # Check that that resulting code reordering constitutes an isomorphism of computation flow graphs
-
-        self.allow_useless_instructions = False
-
-        self._split_heuristic = Config._default_split_heuristic
-        self._split_heuristic_region = Config._default_split_heuristic_region
-        self._split_heuristic_factor = Config._default_split_heuristic_factor
-        self._split_heuristic_abort_cycle_at = Config._default_split_heuristic_abort_cycle_at
-        self._split_heuristic_stepsize = Config._default_split_heuristic_stepsize
-        self._split_heuristic_optimize_seam = Config._default_split_heuristic_optimize_seam
-        self._split_heuristic_chunks = Config._default_split_heuristic_chunks
-        self._split_heuristic_bottom_to_top = Config._default_split_heuristic_bottom_to_top
-        self._split_heuristic_repeat = Config._default_split_heuristic_repeat
-        self._split_heuristic_preprocess_naive_interleaving = \
-            Config._default_split_heuristic_preprocess_naive_interleaving
-        self._split_heuristic_preprocess_naive_interleaving_by_latency = \
-            Config._default_split_heuristic_preprocess_naive_interleaving_by_latency
-        self._split_heuristic_optimize_seam = Config._default_split_heuristic_optimize_seam
-
-        self._unsafe_skip_address_fixup = Config._default_unsafe_skip_address_fixup
-
-        self._with_preprocessor = Config._default_with_preprocessor
-        self._compiler_binary = Config._default_compiler_binary
-        self._max_solutions = Config._default_max_solutions
-        self._timeout = Config._default_timeout
-        self._retry_timeout = Config._default_retry_timeout
-        self._ignore_objective = Config._default_ignore_objective
-        self._objective_precision = Config._default_objective_precision
+        self._selfcheck = True
+        self._allow_useless_instructions = False
+
+        self._split_heuristic = False
+        self._split_heuristic_region = [0.0,1.0]
+        self._split_heuristic_chunks = False
+        self._split_heuristic_optimize_seam = 0
+        self._split_heuristic_bottom_to_top = False
+        self._split_heuristic_factor = 2
+        self._split_heuristic_abort_cycle_at = None
+        self._split_heuristic_stepsize = None
+        self._split_heuristic_repeat = 1
+        self._split_heuristic_preprocess_naive_interleaving = False
+        self._split_heuristic_preprocess_naive_interleaving_by_latency = False
+
+        self._compiler_binary = "gcc"
+
+        self.keep_tags = True
+        self.ignore_tags = False
+
+        self._do_address_fixup = True
+
+        self._with_preprocessor = False
+        self._max_solutions = 64
+        self._timeout = None
+        self._retry_timeout = None
+        self._ignore_objective = False
+        self._objective_precision = 0
 
         # Visualization
         self.indentation = 8
         self.visualize_reordering = True
-        self._split_heuristic_visualize_stalls = False
-        self._split_heuristic_visualize_units = False
 
         self.placeholder_char = '.'
         self.early_char = 'e'
@@ -984,6 +1009,15 @@ def _check_rename_config(self, lst):
     @reserved_regs.setter
     def reserved_regs(self,val):
         self._reserved_regs = val
+    @variable_size.setter
+    def variable_size(self,val):
+        self._variable_size = val
+    @selfcheck.setter
+    def selfcheck(self,val):
+        self._selfcheck = val
+    @allow_useless_instructions.setter
+    def allow_useless_instructions(self,val):
+        self._allow_useless_instructions = val
     @locked_registers.setter
     def locked_registers(self,val):
         self._locked_registers = val
@@ -1002,9 +1036,15 @@ def timeout(self, val):
     @retry_timeout.setter
     def retry_timeout(self, val):
         self._retry_timeout = val
-    @unsafe_skip_address_fixup.setter
-    def unsafe_skip_address_fixup(self, val):
-        self._unsafe_skip_address_fixup = val
+    @keep_tags.setter
+    def keep_tags(self, val):
+        self._keep_tags = val
+    @ignore_tags.setter
+    def ignore_tags(self, val):
+        self._ignore_tags = val
+    @do_address_fixup.setter
+    def do_address_fixup(self, val):
+        self._do_address_fixup = val
     @ignore_objective.setter
     def ignore_objective(self, val):
         self._ignore_objective = val
@@ -1032,12 +1072,6 @@ def split_heuristic_optimize_seam(self, val):
     @split_heuristic_bottom_to_top.setter
     def split_heuristic_bottom_to_top(self, val):
         self._split_heuristic_bottom_to_top = val
-    @split_heuristic_visualize_stalls.setter
-    def split_heuristic_visualize_stalls(self, val):
-        self._split_heuristic_visualize_stalls = val
-    @split_heuristic_visualize_units.setter
-    def split_heuristic_visualize_units(self, val):
-        self._split_heuristic_visualize_units = val
     @split_heuristic_region.setter
     def split_heuristic_region(self, val):
         self._split_heuristic_region = val
diff --git a/slothy/core.py b/slothy/core/core.py
similarity index 81%
rename from slothy/core.py
rename to slothy/core/core.py
index 7a9101e7..4a41844f 100644
--- a/slothy/core.py
+++ b/slothy/core/core.py
@@ -25,22 +25,26 @@
 # Author: Hanno Becker
 #
 
-import logging, ortools, math
-
+import logging
+import math
 from types import SimpleNamespace
 from copy import deepcopy
+from functools import cached_property
 
 from sympy import simplify
 
+import ortools
 from ortools.sat.python import cp_model
 
-from functools import cached_property
-
-from slothy.config import Config
-from slothy.helper import LockAttributes, AsmHelper, Permutation, DeferHandler
+from slothy.core.config import Config
+from slothy.helper import LockAttributes, Permutation, DeferHandler, SourceLine
 
-from slothy.dataflow import DataFlowGraph as DFG
-from slothy.dataflow import Config as DFGConfig
-from slothy.dataflow import InstructionOutput, InstructionInOut, ComputationNode
-from slothy.dataflow import SlothyUselessInstructionException
+from slothy.core.dataflow import DataFlowGraph as DFG
+from slothy.core.dataflow import Config as DFGConfig
+from slothy.core.dataflow import InstructionOutput, InstructionInOut, ComputationNode
+from slothy.core.dataflow import SlothyUselessInstructionException
+
+class SlothyException(Exception):
+    """Generic exception thrown by SLOTHY"""
 
 class Result(LockAttributes):
     """The results of a one-shot SLOTHY optimization run"""
 
@@ -67,16 +71,13 @@ def _gen_orig_code_visualized(self):
         def arr_width(arr):
             mi = min(arr)
-            ma = max(0,max(arr))
+            ma = max(0, max(arr)) # pylint:disable=nested-min-max
             return mi, ma-mi
 
         min_pos, width = arr_width(self.reordering.values())
-        if not self.config.constraints.functional_only:
-            min_pos_cycle, width_cycle = \
-                arr_width(self.cycle_position_with_bubbles.values())
 
-        yield ""
-        yield "// original source code"
+        yield SourceLine("")
+        yield SourceLine("").set_comment("original source code")
         for i in range(self.codesize):
             pos = self.reordering[i] - min_pos
             c = core_char
@@ -101,16 +102,11 @@ def arr_width(arr):
                 c_pos += self.codesize
             t_comment = ''.join(t_comment)
 
-            if not self.config.constraints.functional_only and \
-               self.config.target.issue_rate > 1:
-                cycle_pos = self.cycle_position_with_bubbles[i] - min_pos_cycle
-                t_comment_cycle = "|| " + (d * cycle_pos + c + d * (width_cycle - cycle_pos))
-            else:
-                t_comment_cycle = ""
-
-            yield f"// {self.orig_code[i]:{fixlen-3}s} // {t_comment} {t_comment_cycle}"
+            yield SourceLine("") \
+                .set_comment(f"{str(self.orig_code[i]):{fixlen-3}s}") \
+                .add_comment(t_comment)
 
-        yield ""
+        yield SourceLine("")
 
     @property
     def orig_code_visualized(self):
@@ -128,10 +124,19 @@ def orig_outputs(self):
 
     @property
     def codesize(self):
+        """The number of instructions in the (original and optimized) source code."""
         return len(self.orig_code)
 
     @property
     def codesize_with_bubbles(self):
+        """Performance measure for the optimized source code.
+
+        This is the number of issue slots used by the optimized code.
+        Equivalently, after division by the target's issue width, it is
+        SLOTHY's expectation of the performance of the code in cycles.
+
+        It is also the codomain of the xxx_with_bubbles dictionaries.
+        """
         return self._codesize_with_bubbles
     @codesize_with_bubbles.setter
     def codesize_with_bubbles(self, v):
@@ -140,6 +145,21 @@ def codesize_with_bubbles(self, v):
 
     @property
     def pre_core_post_dict(self):
+        """Dictionary indicating interleaving of iterations.
+
+        This dictionary consists of items (i, (pre, core, post)), where
+        i is the original program order position of an instruction, and
+        pre, core, post indicate whether that instruction is an early,
+        core or late instruction in the optimized source code.
+
+        An early instruction is one which is pulled into the previous iteration.
+        A late instruction is one which is deferred until the next iteration.
+        A core instruction is one which is left in its original iteration.
+
+        This property is only meaningful when software pipelining is enabled.
+
+        See also is_pre, is_core, is_post.
+ """ self._require_sw_pipelining() return self._pre_core_post_dict @pre_core_post_dict.setter @@ -290,7 +310,7 @@ def get_periodic_reordering(self, copies): vals = list(t.values()) vals.sort() res = { i : vals.index(v) for (i,v) in t.items() } - assert (Permutation.is_permutation(res, copies * self.codesize)) + assert Permutation.is_permutation(res, copies * self.codesize) return res def get_periodic_reordering_inv(self, copies): @@ -323,15 +343,15 @@ def get_fully_unrolled_loop(self, iterations): self._require_sw_pipelining() assert iterations > self.num_exceptional_iterations kernel_copies = iterations - self.num_exceptional_iterations - new_source = '\n'.join(self._preamble + - ( self._code * kernel_copies ) + - self._postamble ) - old_source = '\n'.join(self._orig_code * iterations) + new_source = (self._preamble + + (self._code * kernel_copies) + + self._postamble ) + old_source = self._orig_code * iterations return old_source, new_source def get_unrolled_kernel(self, iterations): self._require_sw_pipelining() - return '\n'.join(self._code * iterations) + return self._code * iterations @cached_property def reordering(self): @@ -344,6 +364,7 @@ def periodic_reordering_with_bubbles(self): @cached_property def periodic_reordering_with_bubbles_inv(self): + """The inverse dictionary to periodic_reordering_with_bubbles""" return self.get_periodic_reordering_with_bubbles_inv(1) @cached_property @@ -352,6 +373,7 @@ def periodic_reordering(self): @cached_property def periodic_reordering_inv(self): + """The inverse permutation to periodic_reordering""" res = self.get_periodic_reordering_inv(1) assert Permutation.is_permutation(res, self.codesize) return res @@ -362,9 +384,14 @@ def reordering_inv(self): return { v : k for k,v in self.reordering.items() } @property + def code_raw(self): + """Optimized code, without annotations""" + return self._code + @property def code(self): """The optimized source code""" code = self._code + assert SourceLine.is_source(code) ri = self.periodic_reordering_with_bubbles_inv if not self.config.visualize_reordering: return code @@ -380,8 +407,10 @@ def _gen_visualized_code(): for i in range(self.codesize_with_bubbles): p = ri.get(i, None) if p is None: - gapstr = "// gap" - yield f"{gapstr:{fixlen}s} // {d * self.codesize}" + gap_str = "gap" + yield SourceLine("") \ + .set_comment(f"{gap_str:{fixlen-3}s}") \ + .add_comment(d * self.codesize) continue s = code[self.periodic_reordering[p]] c = core_char @@ -389,8 +418,8 @@ def _gen_visualized_code(): c = early_char elif self.is_post(p): c = late_char - comment = d * p + c + d * (self.codesize - p - 1) - yield f"{s:{fixlen}s} // {comment}" + vis = d * p + c + d * (self.codesize - p - 1) + yield s.copy().set_length(fixlen).set_comment(vis) res = list(_gen_visualized_code()) res += self.orig_code_visualized @@ -398,17 +427,18 @@ def _gen_visualized_code(): return res @code.setter def code(self, val): + assert SourceLine.is_source(val) self._code = val - def get_full_code(self, log): + def _get_full_code(self, log): if self.config.sw_pipelining.enabled: # Unroll the loop a fixed number of times iterations = 5 old_source, new_source = self.get_fully_unrolled_loop(iterations) reordering = self.get_reordering(iterations, no_gaps=True) else: - old_source = '\n'.join(self.orig_code) - new_source = '\n'.join(self.code) + old_source = self.orig_code + new_source = self.code reordering = self.reordering.copy() iterations = 1 @@ -417,8 +447,8 @@ def get_full_code(self, log): dfg_old_log = log.getChild("dfg_old") dfg_new_log = 
log.getChild("dfg_new") - SlothyBase.dump(f"Old code ({iterations} copies)", old_source, dfg_old_log) - SlothyBase.dump(f"New code ({iterations} copies)", new_source, dfg_new_log) + SourceLine.log(f"Old code ({iterations} copies)", old_source, dfg_old_log) + SourceLine.log(f"New code ({iterations} copies)", new_source, dfg_new_log) tree_old = DFG(old_source, dfg_old_log, DFGConfig(self.config, outputs=self.orig_outputs)) @@ -437,14 +467,59 @@ def selfcheck(self, log): try: res = self._selfcheck_core(log) except SlothyUselessInstructionException as exc: - raise SlothySelfCheckException("Useless instruction detected during selfcheck: FAIL!") from exc + raise SlothySelfCheckException("Useless instruction detected during selfcheck: FAIL!")\ + from exc if self.config.selfcheck and not res: raise SlothySelfCheckException("Isomorphism between computation flow graphs: FAIL!") return res + def selfcheck_with_fixup(self, log): + """Do selfcheck, and consider preamble/postamble fixup in case of SW pipelining + + In the presence of cross iteration dependencies, the preamble and postamble + may be functionally incorrect and need fixup.""" + + # We gather the log output of the initial selfcheck and only release + # it (a) on success, or (b) when even the selfcheck after fixup fails. + + defer_handler = DeferHandler() + log.propagate = False + log.addHandler(defer_handler) + + try: + retry = not self.selfcheck(log) + exception = None + except SlothySelfCheckException as e: + exception = e + + log.propagate = True + log.removeHandler(defer_handler) + + if exception and self.config.sw_pipelining.enabled: + retry = True + elif exception: + # We don't expect a failure if there are no cross-iteration dependencies + defer_handler.forward(log) + raise e + + if not retry: + # On success, show the log output + defer_handler.forward(log) + else: + log.info("Selfcheck failed! This sometimes happens in the presence "\ + "of cross-iteration dependencies. 
Try fixup...") + self.fixup_preamble_postamble(log.getChild("fixup_preamble_postamble")) + + try: + self.selfcheck(log.getChild("after_fixup")) + except SlothySelfCheckException as e: + log.error("Here is the output of the original selfcheck before fixup") + defer_handler.forward(log) + raise e + def _selfcheck_core(self, log): _, old_source, new_source, tree_old, tree_new, reordering = \ - self.get_full_code(log) + self._get_full_code(log) edges_old = tree_old.edges() edges_new = tree_new.edges() @@ -461,9 +536,9 @@ def _selfcheck_core(self, log): def apply_reordering(x): src,dst,lbl=x if not src in reordering.keys(): - raise Exception(f"Source ID {src} not in remapping {reordering.items()}") + raise SlothyException(f"Source ID {src} not in remapping {reordering.items()}") if not dst in reordering: - raise Exception(f"Destination ID {dst} not in remapping {reordering.items()}") + raise SlothyException(f"Destination ID {dst} not in remapping {reordering.items()}") return (reordering[src], reordering[dst], lbl) edges_old_remapped = set(map(apply_reordering, edges_old)) @@ -480,8 +555,8 @@ def apply_reordering(x): log.error("Input/Output renaming") log.error(reordering) - SlothyBase.dump("old code", old_source, log, err=True) - SlothyBase.dump("new code", new_source, log, err=True) + SourceLine.log("old code", old_source, log, err=True) + SourceLine.log("new code", new_source, log, err=True) new_not_old = [e for e in edges_new if e not in edges_old_remapped] old_not_new = [e for e in edges_old_remapped if e not in edges_new] @@ -500,18 +575,6 @@ def apply_reordering(x): log.error(f"New ({src_idx}:{src})"\ f"---{lbl}--->({dst_idx}:{dst}) not present in old graph") - src_idx_old = reordering_inv[src_idx] - dst_idx_old = reordering_inv[dst_idx] - src_old = tree_old.nodes_by_id[src_idx_old] - dst_old = tree_old.nodes_by_id[dst_idx_old] - log.error(f"Instructions in old graph: {src_old}, {dst_old}") - deps = [(s,d,l) for (s, d, l) in edges_old if s==src_idx_old and d==dst_idx_old] - if len(deps) > 0: - for (s,d,l) in deps: - log.error(f"Edge: {src_old} --{l}--> {dst_old}") - else: - log.error("No dependencies in old graph!") - for (src_idx,dst_idx,lbl) in old_not_new: src_idx_old = reordering_inv[src_idx] dst_idx_old = reordering_inv[dst_idx] @@ -520,21 +583,6 @@ def apply_reordering(x): log.error(f"Old ({src_old})[id:{src_idx_old}]"\ f"---{lbl}--->{dst_old}[id:{dst_idx_old}] not present in new graph") - src = tree_new.nodes_by_id.get(src_idx, None) - dst = tree_new.nodes_by_id.get(dst_idx, None) - - if src is not None and dst is not None: - log.error(f"Instructions in new graph: {src} --> {dst}") - deps = [(s,d,l) for (s,d,l) in edges_new if s==src_idx and d==dst_idx] - if len(deps) > 0: - for (s, d, l) in deps: - log.error(f"Edge: {src} --{l}--> {dst}") - else: - log.error("No dependencies in new graph!") - else: - log.error(f"Indices {src_idx} ({src}) and {dst_idx} ({dst})" - "don't both exist in new DFG?") - log.error("Isomorphism between computation flow graphs: FAIL!") return False @@ -573,6 +621,10 @@ def output_renamings(self, v): @property def stalls(self): + """The number of stalls in the optimization result. 
+ + More precisely: The number of cycles c such that optimization succeeded with + up to c * issue_width unused issue slots.""" return self._stalls @stalls.setter def stalls(self, v): @@ -585,6 +637,8 @@ def _build_stalls_idxs(self): self.reordering_with_bubbles.values() } @property def stall_positions(self): + """The positions of instructions in the optimized assembly where SLOTHY + expects a stall or unused issue slot.""" if self._stalls_idxs is None: self._build_stalls_idxs() return self._stalls_idxs @@ -607,23 +661,19 @@ def kernel_input_output(self, val): self._kernel_input_output = val @property def preamble(self): - """When using software pipelining, the preamble to the loop kernel of the optimized loop.""" + """When using software pipelining, the preamble of the optimized loop.""" self._require_sw_pipelining() return self._preamble @preamble.setter def preamble(self, val): - # For now, double-check that we never set the preamble twice - # assert self._preamble is None self._preamble = val @property def postamble(self): - """When using software pipelining, the postamble to the loop kernel of the optimized loop.""" + """When using software pipelining, the postamble of the optimized loop.""" self._require_sw_pipelining() return self._postamble @postamble.setter def postamble(self, val): - # For now, double-check that we never set the preamble twice - # assert self._postamble is None self._postamble = val @property @@ -635,7 +685,7 @@ def config(self): def success(self): """Whether the optimization was successful""" if not self._valid: - raise Exception("Querying not-yet-populated result object") + raise SlothyException("Querying not-yet-populated result object") return self._success def __bool__(self): return self.success @@ -647,16 +697,236 @@ def success(self, val): @property def valid(self): + """Indicates whether the result object is valid.""" return self._valid - @valid.setter def valid(self, val): self._valid = val def _require_sw_pipelining(self): if not self.config.sw_pipelining.enabled: - raise Exception("Asking for SW-pipelining attribute in result of SLOTHY run" - " without SW pipelining") + raise SlothyException("Asking for SW-pipelining attribute in result " + "of SLOTHY run without SW pipelining") + + @staticmethod + def _fixup_reordered_pair(t0, t1, logger): + + def inst_changes_addr(inst): + return inst.increment is not None + + if not t0.inst.is_load_store_instruction(): + return + if not t1.inst.is_load_store_instruction(): + return + if not t0.inst.addr == t1.inst.addr: + return + if inst_changes_addr(t0.inst) and inst_changes_addr(t1.inst): + logger.error( "======================= ERROR ===============================") + logger.error(f" Cannot handle reordering of two instructions ({t0} and {t1}) ") + logger.error( " which both want to modify the same address ") + logger.error( "=================================================================") + raise SlothyException("Address fixup failure") + + if inst_changes_addr(t0.inst): + # t1 gets reordered before t0, which changes the address + # Adjust t1's address accordingly + logger.debug(f"{t0} moved after {t1}, bumping {t1.fixup} by {t0.inst.increment}, " + f"to {t1.fixup + int(simplify(t0.inst.increment))}") + t1.fixup += int(simplify(t0.inst.increment)) + elif inst_changes_addr(t1.inst): + # t0 gets reordered after t1, which changes the address + # Adjust t0's address accordingly + logger.debug(f"{t1} moved before {t0}, lowering {t0.fixup} by {t1.inst.increment}, " + f"to {t0.fixup - 
int(simplify(t1.inst.increment))}") + t0.fixup -= int(simplify(t1.inst.increment)) + + @staticmethod + def _fixup_reset(nodes): + for t in nodes: + t.fixup = 0 + + @staticmethod + def _fixup_finish(nodes, logger): + def inst_changes_addr(inst): + return inst.increment is not None + + for t in nodes: + if not t.inst.is_load_store_instruction(): + continue + if inst_changes_addr(t.inst): + continue + if t.fixup == 0: + continue + if t.inst.pre_index: + t.inst.pre_index = f"(({t.inst.pre_index}) + ({t.fixup}))" + else: + t.inst.pre_index = f"{t.fixup}" + logger.debug(f"Fixed up instruction {t.inst} by {t.fixup}, to {t.inst}") + + def _offset_fixup_sw(self, log): + n, _, _, _, tree_new, reordering = self._get_full_code(log) + iterations = n // self.codesize + + Result._fixup_reset(tree_new.nodes) + for _, _, ni, nj in Permutation.iter_swaps(reordering, n): + Result._fixup_reordered_pair(tree_new.nodes[ni], tree_new.nodes[nj], log) + Result._fixup_finish(tree_new.nodes, log) + + preamble_len = len(self.preamble) + postamble_len = len(self.postamble) + + assert n // iterations == self.codesize + + preamble_new = list(map(ComputationNode.to_source_line, tree_new.nodes[:preamble_len])) + postamble_new = [ ComputationNode.to_source_line(t) + for t in tree_new.nodes[-postamble_len:] ] \ + if postamble_len > 0 else [] + + code_new = [] + for i in range(iterations - self.num_exceptional_iterations): + code_new.append([ ComputationNode.to_source_line(t) for t in + tree_new.nodes[preamble_len + i*self.codesize: + preamble_len + (i+1)*self.codesize] ]) + + # Flag if address fixup makes the kernel instable. In this case, we'd have to + # widen preamble and postamble, but this is not yet implemented. + count = 0 + for i, (kcur, knext) in enumerate(zip(code_new, code_new[1:])): + if SourceLine.write_multiline(kcur) != SourceLine.write_multiline(knext): + count += 1 + if count != 0: + raise SlothyException("Instable loop kernel after post-optimization address fixup") + code_new = code_new[0] + + self.preamble = preamble_new + self.postamble = postamble_new + self.code = code_new + + def _offset_fixup_straightline(self, log): + n, _, _, _, tree_new, reordering = self._get_full_code(log) + + Result._fixup_reset(tree_new.nodes) + for _, _, ni, nj in Permutation.iter_swaps(reordering, n): + Result._fixup_reordered_pair(tree_new.nodes[ni], tree_new.nodes[nj], log) + Result._fixup_finish(tree_new.nodes, log) + + self.code = [ ComputationNode.to_source_line(t) for t in tree_new.nodes ] + + def offset_fixup(self, log): + """Fixup address offsets after optimization""" + if self.config.sw_pipelining.enabled: + self._offset_fixup_sw(log) + else: + self._offset_fixup_straightline(log) + + def fixup_preamble_postamble(self, log): + """Potentially fix up the preamble and postamble + + When software pipelining is used in the context of a loop with cross-iteration dependencies, + the core optimization step might lead to functionally incorrect preamble and postamble. + This function checks if this is the case and fixes preamble and postamble, if necessary. 
+ """ + + #if not self._has_cross_iteration_dependencies(): + if not self.config.sw_pipelining.enabled: + return + + iterations = self.num_exceptional_iterations + assert iterations in [1,2] + + kernel = self.get_unrolled_kernel(iterations=iterations) + + perm = self.periodic_reordering_inv + assert Permutation.is_permutation(perm, self.codesize) + + dfgc_orig = DFGConfig(self.config, outputs=self.orig_outputs) + dfgc_kernel = DFGConfig(self.config, outputs=self.kernel_input_output) + + tree_orig = DFG(self.orig_code, log.getChild("orig"), dfgc_orig) + + def is_in_preamble(t): + if t.orig_pos is None: + return False + if iterations == 1: + return self.is_pre(t.orig_pos, original_program_order=False) + assert iterations == 2 + if t.orig_pos < self.codesize: + return self.is_pre(t.orig_pos, original_program_order=False) + return not self.is_post(t.orig_pos % self.codesize, + original_program_order=False) + + def is_in_postamble(t): + if t.orig_pos is None: + return False + if iterations == 1: + return not self.is_pre(t.orig_pos, original_program_order=False) + assert iterations == 2 + if t.orig_pos < self.codesize: + return not self.is_pre(t.orig_pos, original_program_order=False) + return self.is_post(t.orig_pos % self.codesize, + original_program_order=False) + + tree_kernel = DFG(kernel, log.getChild("ssa"), dfgc_kernel) + tree_kernel.ssa() + + # Go through early instructions that depend on an instruction from + # the previous iteration. Remap those dependencies as input dependencies. + for (consumer, producer, _, _) in tree_kernel.iter_dependencies(): + producer = producer.reduce() + if not (is_in_preamble(consumer) and not is_in_preamble(producer.src)): + continue + if producer.src.is_virtual: + continue + orig_pos = perm[producer.src.orig_pos % self.codesize] + assert isinstance(producer, InstructionOutput) + producer.src.inst.args_out[producer.idx] = \ + tree_orig.nodes[orig_pos].inst.args_out[producer.idx] + + # Update input and in-out register names + for t in tree_kernel.nodes_all: + for i, v in enumerate(t.src_in): + t.inst.args_in[i] = v.name() + for i, v in enumerate(t.src_in_out): + t.inst.args_in_out[i] = v.name() + + new_preamble = [ ComputationNode.to_source_line(t) + for t in tree_kernel.nodes if is_in_preamble(t) ] + self.preamble = new_preamble + SourceLine.log("New preamble", self.preamble, log) + + dfgc_preamble = DFGConfig(self.config, outputs=self.kernel_input_output) + dfgc_preamble.inputs_are_outputs = False + DFG(self.preamble, log.getChild("new_preamble"), dfgc_preamble) + + tree_kernel = DFG(kernel, log.getChild("ssa"), dfgc_kernel) + tree_kernel.ssa() + + # Go through non-early instructions that feed into an instruction from + # the next iteration. Remap those dependencies as input dependencies. 
+ for (consumer, producer, _, _) in tree_kernel.iter_dependencies(): + producer = producer.reduce() + if not (is_in_postamble(producer.src) and not is_in_postamble(consumer)): + continue + orig_pos = perm[producer.src.orig_pos % self.codesize] + assert isinstance(producer, InstructionOutput) + producer.src.inst.args_out[producer.idx] = \ + tree_orig.nodes[orig_pos].inst.args_out[producer.idx] + + # Update input and in-out register names + for t in tree_kernel.nodes_all: + for i, v in enumerate(t.src_in): + t.inst.args_in[i] = v.reduce().name() + for i, v in enumerate(t.src_in_out): + t.inst.args_in_out[i] = v.reduce().name() + + new_postamble = [ ComputationNode.to_source_line(t) + for t in tree_kernel.nodes if is_in_postamble(t) ] + self.postamble = new_postamble + SourceLine.log("New postamble", self.postamble, log) + + dfgc_postamble = DFGConfig(self.config, outputs=self.orig_outputs) + DFG(self.postamble, log.getChild("new_postamble"), dfgc_postamble) + def __init__(self, config): super().__init__() @@ -678,15 +948,15 @@ def __init__(self, config): self._kernel_input_output = None self._pre_core_post_dict = None self._codesize_with_bubbles = None + self._register_used = None self.lock() class SlothySelfCheckException(Exception): - pass + """Exception thrown upon selfcheck failures""" class SlothyBase(LockAttributes): - """Stateless core of SLOTHY -- - [S]uper ([L]azy) [O]ptimization of [T]ricky [H]andwritten assembl[Y] + """Stateless core of SLOTHY. This class is the technical heart of the package: It implements the conversion of a software optimization problem into a constraint solving @@ -712,10 +982,12 @@ def target(self): @property def result(self): + """The result object of the last optimization.""" return self._result @property def success(self): + """Indicates whether the last optimization succeeded.""" return self._result.success def __init__(self, Arch, Target, *, logger=None, config=None): @@ -752,7 +1024,7 @@ def _reset(self): def _set_timeout(self, timeout): if timeout is None: return - self.logger.info(f"Setting timeout of %d seconds...", timeout) + self.logger.info("Setting timeout of %d seconds...", timeout) self._model.cp_solver.parameters.max_time_in_seconds = timeout def optimize(self, source, prefix_len=0, suffix_len=0, log_model=None, retry=False): @@ -823,7 +1095,7 @@ def optimize(self, source, prefix_len=0, suffix_len=0, log_model=None, retry=Fal self._add_constraints_scheduling() self._add_constraints_lifetime_bounds() self._add_constraints_loop_optimization() - self._add_constraints_N_issue() + self._add_constraints_n_issue() self._add_constraints_dependency_order() self._add_constraints_latencies() self._add_constraints_register_renaming() @@ -841,11 +1113,13 @@ def optimize(self, source, prefix_len=0, suffix_len=0, log_model=None, retry=Fal self._result = Result(self.config) # Do the actual work - self.logger.info(f"Invoking external constraint solver ({self._describe_solver()}) ...") - self.result._success = self._solve() - if not retry and self.result._success: - self.logger.info(f"Booleans in result: {self._model.cp_solver.NumBooleans()}") - self.result._valid = True + self.logger.info("Invoking external constraint solver (%s) ...", self._describe_solver()) + self.result.success = self._solve() + self.result.valid = True + + if not retry and self.success: + self.logger.info("Booleans in result: %d", self._model.cp_solver.NumBooleans()) + if not self.success: return False @@ -853,20 +1127,21 @@ def optimize(self, source, prefix_len=0, suffix_len=0, 
log_model=None, retry=Fal return True def _load_source(self, source, prefix_len=0, suffix_len=0): + assert SourceLine.is_source(source) + # TODO: This does not belong here if self.config.sw_pipelining.enabled and \ ( prefix_len >0 or suffix_len > 0 ): - raise Exception("Invalid arguments") + raise SlothyException("Invalid arguments") - source = AsmHelper.reduce_source(source) - SlothyBase.dump("Source code", source, self.logger.input) + source = SourceLine.reduce_source(source) + SourceLine.log("Source code", source, self.logger.input) self._orig_code = source.copy() - source = '\n'.join(source) # Convert source code to computational flow graph if self.config.sw_pipelining.enabled: - source = source + '\n' + source + source = source + source self._model.tree = DFG(source, self.logger.getChild("dataflow"), DFGConfig(self.config)) @@ -906,7 +1181,7 @@ def _init_model_internals(self): def _usage_check(self): if self._num_optimization_passes > 0: - raise Exception("At the moment, SlothyBase should be used for one-shot optimizations") + raise SlothyException("SlothyBase should be used for one-shot optimizations") self._num_optimization_passes += 1 def _reg_is_architectural(self,reg,ty): @@ -935,7 +1210,7 @@ def static_renaming(conf_val, t): arch_str = "arch" if is_arch else "symbolic" if not isinstance(conf_val, dict): - raise Exception(f"Couldn't make sense of renaming configuration {conf_val}") + raise SlothyException(f"Couldn't make sense of renaming configuration {conf_val}") # Try to look up register in dictionary. There are three ways # it can be specified: Directly by name, via the "arch/symbolic" @@ -947,7 +1222,7 @@ def static_renaming(conf_val, t): val = val if val is not None else conf_val.get( "other" , None ) if val is None: - raise Exception( f"Register {reg} not present in renaming config {conf_val}") + raise SlothyException( f"Register {reg} not present in renaming config {conf_val}") # There are three choices for the value: # - "static" for static assignment, which will statically assign a value @@ -957,12 +1232,12 @@ def static_renaming(conf_val, t): if val == "static": canonical_static_assignment = reg if is_arch else None return True, canonical_static_assignment - elif val == "any": + if val == "any": return False, None - else: - if not self._reg_is_architectural(val,ty): - raise Exception(f"Invalid renaming configuration {val} for {reg}") - return True, val + + if not self._reg_is_architectural(val,ty): + raise SlothyException(f"Invalid renaming configuration {val} for {reg}") + return True, val def tag_input(t): static, val = static_renaming(self.config.rename_inputs, t) @@ -993,7 +1268,7 @@ def get_fresh_renaming_reg(ty): try: # Iterate statically renamed inputs/outputs which have not yet been assigned for v in inputs_tagged + outputs_tagged: - if v.static == False or v.reg is not None: + if v.static is False or v.reg is not None: continue v.reg = get_fresh_renaming_reg(v.ty) except OutOfRegisters as e: @@ -1055,20 +1330,6 @@ def _backup_original_code(self): for t in self._get_nodes(): t.inst_orig = deepcopy(t.inst) - @staticmethod - def dump(name, s, logger=None, err=False): - if err: - fun = logger.error - else: - fun = logger.debug - if isinstance(s,str): - s = s.splitlines() - if len(s) == 0: - return - fun(f"Dump: {name}") - for l in s: - fun(f"> {l}") - class CpSatSolutionCb(cp_model.CpSolverSolutionCallback): """A solution callback class represents objects that are alive during CP-SAT operation and equipped with a callback that is triggered every time CP-SAT finds 
a new solution. @@ -1100,124 +1361,6 @@ def solution_count(self): """The number of solutions found so far""" return self.__solution_count - @staticmethod - def _fixup_reordered_pair(t0, t1, logger, unsafe_skip_address_fixup=False): - - def inst_changes_addr(inst): - return inst.increment is not None - - if not t0.inst.is_load_store_instruction(): - return - if not t1.inst.is_load_store_instruction(): - return - if not t0.inst.addr == t1.inst.addr: - return - if inst_changes_addr(t0.inst) and inst_changes_addr(t1.inst): - if not unsafe_skip_address_fixup: - logger.error( "======================= ERROR ===============================") - logger.error(f" Cannot handle reordering of two instructions ({t0} and {t1}) ") - logger.error( " which both want to modify the same address ") - logger.error( "=================================================================") - raise Exception("Address fixup failure") - - logger.warning( "========================= WARNING ============================") - logger.warning(f" Cannot handle reordering of two instructions ({t0} and {t1}) ") - logger.warning( " which both want to modify the same address ") - logger.warning( " Skipping this -- you have to fix the address offsets manually ") - logger.warning( "==================================================================") - return - if inst_changes_addr(t0.inst): - # t1 gets reordered before t0, which changes the address - # Adjust t1's address accordingly - logger.debug(f"{t0} moved after {t1}, bumping {t1.fixup} by {t0.inst.increment}, " - f"to {t1.fixup + int(simplify(t0.inst.increment))}") - t1.fixup += int(simplify(t0.inst.increment)) - elif inst_changes_addr(t1.inst): - # t0 gets reordered after t1, which changes the address - # Adjust t0's address accordingly - logger.debug(f"{t1} moved before {t0}, lowering {t0.fixup} by {t1.inst.increment}, " - f"to {t0.fixup - int(simplify(t1.inst.increment))}") - t0.fixup -= int(simplify(t1.inst.increment)) - - @staticmethod - def _fixup_reset(nodes): - for t in nodes: - t.fixup = 0 - - @staticmethod - def _fixup_finish(nodes, logger): - def inst_changes_addr(inst): - return inst.increment is not None - - for t in nodes: - if not t.inst.is_load_store_instruction(): - continue - if inst_changes_addr(t.inst): - continue - if t.fixup == 0: - continue - if t.inst.pre_index: - t.inst.pre_index = f"(({t.inst.pre_index}) + ({t.fixup}))" - else: - t.inst.pre_index = f"{t.fixup}" - logger.debug(f"Fixed up instruction {t.inst} by {t.fixup}, to {t.inst}") - - def _offset_fixup_sw(self, log): - n, _, _, _, tree_new, reordering = self._result.get_full_code(log) - iterations = n // self._result.codesize - - SlothyBase._fixup_reset(tree_new.nodes) - for _, _, ni, nj in Permutation.iter_swaps(reordering, n): - SlothyBase._fixup_reordered_pair(tree_new.nodes[ni], tree_new.nodes[nj], log) - SlothyBase._fixup_finish(tree_new.nodes, log) - - preamble_len = len(self._result.preamble) - postamble_len = len(self._result.postamble) - - assert n // iterations == self._result.codesize - - preamble_new = [ str(t.inst) for t in tree_new.nodes[:preamble_len] ] - postamble_new = [ str(t.inst) for t in tree_new.nodes[-postamble_len:] ] \ - if postamble_len > 0 else [] - - code_new = [] - for i in range(iterations - self._result.num_exceptional_iterations): - code_new.append([ str(t.inst) for t in - tree_new.nodes[preamble_len + i*self._result.codesize: - preamble_len + (i+1)*self._result.codesize] ]) - - # Flag if address fixup makes the kernel instable. 
In this case, we'd have to - # widen preamble and postamble, but this is not yet implemented. - count = 0 - for i, (kcur, knext) in enumerate(zip(code_new, code_new[1:])): - if kcur != knext: - count += 1 - if count != 0: - raise Exception("Instable loop kernel after post-optimization address fixup") - code_new = code_new[0] - - self._result.preamble = preamble_new - self._result.postamble = postamble_new - self._result.code = code_new - - def _offset_fixup_straightline(self, log): - n, _, _, _, tree_new, reordering = self._result.get_full_code(log) - - SlothyBase._fixup_reset(tree_new.nodes) - for _, _, ni, nj in Permutation.iter_swaps(reordering, n): - SlothyBase._fixup_reordered_pair(tree_new.nodes[ni], tree_new.nodes[nj], log) - SlothyBase._fixup_finish(tree_new.nodes, log) - - self._result.code = [ str(t.inst) for t in tree_new.nodes ] - - def offset_fixup(self): - """Fixup address offsets after optimization""" - log = self.logger.getChild("offset_fixup") - if self.config.sw_pipelining.enabled: - self._offset_fixup_sw(log) - else: - self._offset_fixup_straightline(log) - def fixup_preamble_postamble(self): """Potentially fix up the preamble and postamble @@ -1233,9 +1376,7 @@ def fixup_preamble_postamble(self): log = self.logger.getChild("fixup_preamble_postamble") iterations = self._result.num_exceptional_iterations - assert iterations == 1 or iterations == 2 - - n = self._result.codesize * iterations + assert iterations in [1,2] kernel = self._result.get_unrolled_kernel(iterations=iterations) @@ -1252,24 +1393,26 @@ def is_in_preamble(t): return False if iterations == 1: return self._result.is_pre(t.orig_pos, original_program_order=False) - elif iterations == 2: - if t.orig_pos < self._result.codesize: - return self._result.is_pre(t.orig_pos, original_program_order=False) - else: - return not self._result.is_post(t.orig_pos % self._result.codesize, - original_program_order=False) + + assert iterations == 2 + if t.orig_pos < self._result.codesize: + return self._result.is_pre(t.orig_pos, original_program_order=False) + + return not self._result.is_post(t.orig_pos % self._result.codesize, + original_program_order=False) def is_in_postamble(t): if t.orig_pos is None: return False if iterations == 1: return not self._result.is_pre(t.orig_pos, original_program_order=False) - elif iterations == 2: - if t.orig_pos < self._result.codesize: - return not self._result.is_pre(t.orig_pos, original_program_order=False) - else: - return self._result.is_post(t.orig_pos % self._result.codesize, - original_program_order=False) + + assert iterations == 2 + if t.orig_pos < self._result.codesize: + return not self._result.is_pre(t.orig_pos, original_program_order=False) + + return self._result.is_post(t.orig_pos % self._result.codesize, + original_program_order=False) tree_kernel = DFG(kernel, log.getChild("ssa"), dfgc_kernel) tree_kernel.ssa() @@ -1294,9 +1437,10 @@ def is_in_postamble(t): for i, v in enumerate(t.src_in_out): t.inst.args_in_out[i] = v.name() - new_preamble = [ str(t.inst) for t in tree_kernel.nodes if is_in_preamble(t) ] + new_preamble = [ ComputationNode.to_source_line(t) + for t in tree_kernel.nodes if is_in_preamble(t) ] self._result.preamble = new_preamble - SlothyBase.dump("New preamble", self._result.preamble, log) + SourceLine.log("New preamble", self._result.preamble, log) dfgc_preamble = DFGConfig(self.config, outputs=self._result.kernel_input_output) dfgc_preamble.inputs_are_outputs = False @@ -1323,9 +1467,10 @@ def is_in_postamble(t): for i, v in enumerate(t.src_in_out): 
t.inst.args_in_out[i] = v.reduce().name() - new_postamble = [ str(t.inst) for t in tree_kernel.nodes if is_in_postamble(t) ] + new_postamble = [ ComputationNode.to_source_line(t) + for t in tree_kernel.nodes if is_in_postamble(t) ] self._result.postamble = new_postamble - SlothyBase.dump("New postamble", self._result.postamble, log) + SourceLine.log("New postamble", self._result.postamble, log) dfgc_postamble = DFGConfig(self.config, outputs=self._result.orig_outputs) DFG(self._result.postamble, log.getChild("new_postamble"), dfgc_postamble) @@ -1341,48 +1486,8 @@ def _extract_result(self): self._extract_input_output_renaming() self._extract_code() - - # In the presence of cross iteration dependencies, the preamble and postamble - # may be functionally incorrect and need fixup. - # We therefore gather the log output of the initial selfcheck and only release - # it (a) on success, or (b) when even the selfcheck after fixup fails. - - log = self.logger.getChild("selfcheck") - defer_handler = DeferHandler() - log.propagate = False - log.addHandler(defer_handler) - - try: - retry = not self._result.selfcheck(log) - exception = None - except SlothySelfCheckException as e: - exception = e - - log.propagate = True - log.removeHandler(defer_handler) - - if exception and self._has_cross_iteration_dependencies(): - retry = True - elif exception: - # We don't expect a failure if there are no cross-iteration dependencies - defer_handler.forward(log) - raise e - - if not retry: - # On success, show the log output - defer_handler.forward(log) - else: - self.logger.info("Selfcheck failed! This sometimes happens in the presence of cross-iteration dependencies. Try fixup...") - self.fixup_preamble_postamble() - - try: - self._result.selfcheck(self.logger.getChild("selfcheck_after_fixup")) - except SlothySelfCheckException as e: - self.logger.error("Here is the output of the original selfcheck before fixup") - defer_handler.forward(log) - raise e - - self.offset_fixup() + self._result.selfcheck_with_fixup(self.logger.getChild("selfcheck")) + self._result.offset_fixup(self.logger.getChild("fixup")) def _extract_positions(self, get_value): @@ -1477,10 +1582,6 @@ def _extract_kernel_input_output(self): def _extract_code(self): - def add_indentation(src): - indentation = ' ' * self.config.indentation - src = [ indentation + s for s in src ] - def get_code(filter_func=None, top=False): if len(self._model.tree.nodes) == 0: return @@ -1494,7 +1595,7 @@ def get_code_line(line_no): t = self._model.tree.nodes[periodic_reordering_with_bubbles_inv[line_no]] if filter_func and not filter_func(t): return - yield str(t.inst) + yield ComputationNode.to_source_line(t) base = 0 lines = self._result.codesize_with_bubbles @@ -1510,42 +1611,54 @@ def get_code_line(line_no): preamble += list(get_code(filter_func=lambda t: t.pre, top=True)) if self._result.num_post > 0: preamble += list(get_code(filter_func=lambda t: not t.post)) - self._result.preamble = preamble postamble = [] if self._result.num_pre > 0: postamble += list(get_code(filter_func=lambda t: not t.pre, top=True)) if self._result.num_post > 0: postamble += list(get_code(filter_func=lambda t: t.post)) - self._result.postamble = postamble - self._result.code = list(get_code()) - self._extract_kernel_input_output() + kernel = list(get_code()) log = self.logger.result.getChild("sw_pipelining") log.debug("Kernel dependencies: %s", self._result.kernel_input_output) - SlothyBase.dump("Preamble", self._result.preamble, log) - SlothyBase.dump("Kernel", self._result.kernel, 
log) - SlothyBase.dump("Postamble", self._result.postamble, log) + SourceLine.log("Preamble", preamble, log) + SourceLine.log("Kernel", kernel, log) + SourceLine.log("Postamble", postamble, log) + + preamble = SourceLine.apply_indentation(preamble, self.config.indentation) + postamble = SourceLine.apply_indentation(postamble, self.config.indentation) + kernel = SourceLine.apply_indentation(kernel, self.config.indentation) - add_indentation(self._result.preamble) - add_indentation(self._result.kernel) - add_indentation(self._result.postamble) + if self.config.keep_tags is False: + SourceLine.drop_tags(preamble) + SourceLine.drop_tags(postamble) + SourceLine.drop_tags(kernel) + + self._result.preamble = preamble + self._result.postamble = postamble + self._result.code = kernel + + self._extract_kernel_input_output() else: - self._result.code = list(get_code()) + code = list(get_code()) + code = SourceLine.apply_indentation(code, self.config.indentation) + + if self.config.keep_tags is False: + SourceLine.drop_tags(code) + + self._result.code = code self.logger.result.debug("Optimized code") for s in self._result.code: - self.logger.result.debug("> " + s.strip()) - - add_indentation(self._result.code) + self.logger.result.debug("> " + str(s).strip()) if self.config.visualize_reordering: self._result._code += self._result.orig_code_visualized - def _add_path_constraint( self, consumer, producer, cb, force=False): + def _add_path_constraint( self, consumer, producer, cb): """Add model constraint cb() relating to the pair of producer-consumer instructions Outside of loop mode, this ignores producer and consumer, and just adds cb(). In loop mode, however, the condition has to be omitted in two cases: @@ -1607,16 +1720,15 @@ def _get_nodes_by_program_order(self, low=False, high=False, allnodes=False, inputs=False, outputs=False): if low: return self._model.tree.nodes_low - elif high: + if high: return self._model.tree.nodes_high - elif allnodes: + if allnodes: return self._model.tree.nodes_all - elif inputs: + if inputs: return self._model.tree.nodes_input - elif outputs: + if outputs: return self._model.tree.nodes_output - else: - return self._model.tree.nodes + return self._model.tree.nodes def _get_nodes_by_depth(self, **kwargs): return sorted(self._get_nodes_by_program_order(**kwargs), @@ -1625,8 +1737,7 @@ def _get_nodes_by_depth(self, **kwargs): def _get_nodes(self, by_depth=False, **kwargs): if by_depth: return self._get_nodes_by_depth(**kwargs) - else: - return self._get_nodes_by_program_order(**kwargs) + return self._get_nodes_by_program_order(**kwargs) # ================================================================ # VARIABLES (Instruction scheduling) # @@ -1694,7 +1805,7 @@ def _add_variables_functional_units(self): t.unique_unit = False t.exec_unit_choices = {} for unit_choices in units: - if type(unit_choices) != list: + if not isinstance(unit_choices, list): unit_choices = [unit_choices] for unit in unit_choices: unit_var = self._NewBoolVar(f"[{t.inst}].unit_choice.{unit}") @@ -1721,13 +1832,15 @@ def make_start_var(name=""): # When we optimize for longest register lifetimes, we allow the starting time of the # usage interval to be smaller than the program order position of the instruction. 
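        # (Sketch of the interval semantics, hedged: with flexible_lifetime_start,
        #  an output written at program position p may get lifetime_start <= p,
        #  so the usage interval [lifetime_start, lifetime_end] can stretch and
        #  the maximize_register_lifetimes objective can then maximize
        #  out_lifetime_duration over it.)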
if self.config.flexible_lifetime_start: - t.out_lifetime_start = [ make_start_var(f"{t.varname()}_out_{i}_lifetime_start") - for i in range(t.inst.num_out) ] - t.inout_lifetime_start = [ make_start_var(f"{t.varname()}_inout_{i}_lifetime_start") - for i in range(t.inst.num_in_out) ] + t.out_lifetime_start = [ + make_start_var(f"{t.varname()}_out_{i}_lifetime_start") + for i in range(t.inst.num_out) ] + t.inout_lifetime_start = [ + make_start_var(f"{t.varname()}_inout_{i}_lifetime_start") + for i in range(t.inst.num_in_out) ] else: - t.out_lifetime_start = [ t.program_start_var for i in range(t.inst.num_out) ] - t.inout_lifetime_start = [ t.program_start_var for i in range(t.inst.num_in_out) ] + t.out_lifetime_start = [ t.program_start_var for _ in range(t.inst.num_out) ] + t.inout_lifetime_start = [ t.program_start_var for _ in range(t.inst.num_in_out) ] t.out_lifetime_end = [ make_var(f"{t.varname()}_out_{i}_lifetime_end") for i in range(t.inst.num_out) ] @@ -1745,28 +1858,10 @@ def make_start_var(name=""): def _add_variables_register_renaming(self): """Add boolean variables indicating if an instruction uses a certain output register""" - def get_metric(t): - return int(t.id) // (max(t.depth,1)) - - if self.config.constraints.restricted_renaming is not None: - nodes_sorted_by_metric = [ t for t in self._get_nodes() ] # Refs only - nodes_sorted_by_metric.sort(key=get_metric) - start_idx = int(len(nodes_sorted_by_metric) * - self.config.constraints.restricted_renaming) - renaming_allowed_list = nodes_sorted_by_metric[start_idx:] - - def _allow_renaming(t): + def _allow_renaming(_): if not self.config.constraints.allow_renaming: return False - if self.config.constraints.restricted_renaming is None: - return True - if t.is_virtual: - return True - if t in renaming_allowed_list: - self.logger.info("Exceptionally allow renaming for %s, position %s, depth %d", - t, t.id, t.depth) - return True - return False + return True self.logger.debug("Adding variables for register allocation...") @@ -1787,7 +1882,8 @@ def _allow_renaming(t): self.logger.debug("- Output %s (%s)", arg_out, arg_ty) - # Locked output register aren't renamed, and neither are outputs of locked instructions. + # Locked output registers aren't renamed, and neither are + # outputs of locked instructions. 
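+            # Example (hedged): with config.locked_registers = ["x30"], any
+            # output named x30 keeps its name, while symbolic registers are
+            # still renamed regardless (see the checks just below).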
self.logger.debug("Locked registers: %s", self.config.locked_registers) is_locked = arg_out in self.config.locked_registers # Symbolic registers are always renamed @@ -1815,7 +1911,7 @@ def _allow_renaming(t): self.logger.error("Original candidates: %s", candidates) self.logger.error("Restricted candidates: %s", candidates_restricted) self.logger.error("Restrictions: %s", restrictions) - raise Exception() + raise SlothyException() self.logger.input.debug("Registers available for renaming of " f"[{t.inst}].{arg_out} ({t.orig_pos})") @@ -1877,7 +1973,7 @@ def add_arg_combination_vars(combinations, vs, name, t=t): ## Create intervals tracking the usage of registers for t in self._get_nodes(allnodes=True): - self.logger.debug(f"Create register usage intervals for {t}") + self.logger.debug("Create register usage intervals for %s", t) ivals = [] ivals += list(zip(t.inst.arg_types_out, t.alloc_out_var, @@ -1944,16 +2040,16 @@ def _iter_dependencies_with_lifetime(self): def _get_lifetime_start(src): if isinstance(src, InstructionOutput): return src.src.out_lifetime_start[src.idx] - elif isinstance(src, InstructionInOut): + if isinstance(src, InstructionInOut): return src.src.inout_lifetime_start[src.idx] - raise Exception("Unknown register source") + raise SlothyException("Unknown register source") def _get_lifetime_end(src): if isinstance(src, InstructionOutput): return src.src.out_lifetime_end[src.idx] - elif isinstance(src, InstructionInOut): + if isinstance(src, InstructionInOut): return src.src.inout_lifetime_end[src.idx] - raise Exception("Unknown register source") + raise SlothyException("Unknown register source") for (consumer, producer, ty, idx) in self._iter_dependencies(): start_var = _get_lifetime_start(producer) @@ -1997,8 +2093,9 @@ def _add_constraints_lifetime_bounds(self): # For every instruction depending on the output, add a lifetime bound for (consumer, producer, _, _, _, end_var, _) in \ self._iter_dependencies_with_lifetime(): - self._add_path_constraint(consumer, producer.src, lambda end_var=end_var, consumer=consumer: - self._Add(end_var >= consumer.program_start_var), force=True) + self._add_path_constraint(consumer, producer.src, + lambda end_var=end_var, consumer=consumer: + self._Add(end_var >= consumer.program_start_var)) # ================================================================ # CONSTRAINTS (Register allocation) # @@ -2040,7 +2137,9 @@ def _force_renaming_collision(self, var_dic_a, var_dic_b): def _force_allocation_restriction_single(self, valid_allocs, var_dict): for k,v in var_dict.items(): if k not in valid_allocs: - self._Add(v == False) + # Disabling pylint warning here since we're building a + # CP-SAT constraint here, rather than making a boolean comparison. + self._Add(v == False) # pylint:disable=singleton-comparison def _force_allocation_restriction_many(self, restriction_lst, var_dict_lst): for r, v in zip(restriction_lst, var_dict_lst, strict=True): @@ -2057,7 +2156,7 @@ def _add_constraints_register_renaming(self): if len(arr) > 0: self._model.AddMaxEquality(self._register_used[reg], arr) else: - self._Add(self._register_used[reg] == False) + self._Add(self._register_used[reg] is False) # Ensure that outputs are unambiguous for t in self._get_nodes(allnodes=True): @@ -2075,8 +2174,10 @@ def _add_constraints_register_renaming(self): t.alloc_in_out_combinations_vars) # Enforce individual input argument restrictions (for outputs this has already # been done at the time when we created the allocation variables). 
- self._force_allocation_restriction_many(t.inst.args_in_restrictions, t.alloc_in_var) - self._force_allocation_restriction_many(t.inst.args_in_out_restrictions, t.alloc_in_out_var) + self._force_allocation_restriction_many(t.inst.args_in_restrictions, + t.alloc_in_var) + self._force_allocation_restriction_many(t.inst.args_in_out_restrictions, + t.alloc_in_out_var) # Enforce exclusivity of arguments self._forbid_renaming_collision_many( t.inst.args_in_out_different, t.alloc_out_var, @@ -2090,10 +2191,10 @@ def find_out_node(t_in): c = list(filter(lambda t: t.inst.orig_reg == t_in.inst.orig_reg, self._model.tree.nodes_output)) if len(c) == 0: - raise Exception("Could not find matching output for input:" + + raise SlothyException("Could not find matching output for input:" + t_in.inst.orig_reg) if len(c) > 1: - raise Exception("Found multiple matching output nodes for input: " + + raise SlothyException("Found multiple matching output nodes for input: " + f"{t_in.inst.orig_reg}: {c}") return c[0] for t_in in self._model.tree.nodes_input: @@ -2123,9 +2224,28 @@ def _add_constraints_loop_optimization(self): self._AddExactlyOne([t.pre_var, t.post_var, t.core_var]) + # Check if source line was tagged pre/core/post + force_pre = t.inst.source_line.tags.get("pre", None) + force_core = t.inst.source_line.tags.get("core", None) + force_post = t.inst.source_line.tags.get("post", None) + if force_pre is not None: + assert force_pre is True or force_pre is False + self._Add(t.pre_var == force_pre) + self.logger.debug("Force pre=%s instruction for %s", force_pre, t.inst) + if force_core is not None: + assert force_core is True or force_core is False + self._Add(t.core_var == force_core) + self.logger.debug("Force core=%s instruction for %s", force_core, t.inst) + if force_post is not None: + assert force_post is True or force_post is False + self._Add(t.post_var == force_post) + self.logger.debug("Force post=%s instruction for %s", force_post, t.inst) + if not self.config.sw_pipelining.allow_pre: + # pylint:disable=singleton-comparison self._Add(t.pre_var == False) if not self.config.sw_pipelining.allow_post: + # pylint:disable=singleton-comparison self._Add(t.post_var == False) if self.config.hints.all_core: @@ -2136,8 +2256,9 @@ def _add_constraints_loop_optimization(self): # Allow early instructions only in a certain range if self.config.sw_pipelining.max_pre < 1.0 and self._is_low(t): relpos = t.orig_pos / len(self._get_nodes(low=True)) - if relpos < 1 and relpos > self.config.sw_pipelining.max_pre: - self._Add( t.pre_var == False ) + if self.config.sw_pipelining.max_pre < relpos < 1: + # pylint:disable=singleton-comparison + self._Add(t.pre_var == False) if self.config.sw_pipelining.pre_before_post: for t, s in [(t,s) for t in self._get_nodes(low=True) \ @@ -2154,14 +2275,17 @@ def _add_constraints_loop_optimization(self): # An instruction with forward dependency to the next iteration # cannot be an early instruction, and an instruction depending # on an instruction from a previous iteration cannot be late. 
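            # Illustration (hedged): for a cross-iteration edge
            #   producer (iteration i) ---> consumer (iteration i+1),
            # marking the producer as early or the consumer as late would move
            # it across the iteration boundary and reverse the dependency;
            # hence both variables are pinned to False below.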
+ + # pylint:disable=singleton-comparison self._Add(producer.src.pre_var == False) + # pylint:disable=singleton-comparison self._Add(consumer.post_var == False) # ================================================================ # CONSTRAINTS (Single issuing) # # ================================================================ - def _add_constraints_N_issue(self): + def _add_constraints_n_issue(self): self._AddAllDifferent([ t.program_start_var for t in self._get_nodes() ] ) if self.config.variable_size: @@ -2174,7 +2298,6 @@ def _add_constraints_N_issue(self): self._Add( t.program_start_var == t.cycle_start_var * self.target.issue_rate + t.slot_var ) - def _add_constraints_locked_ordering(self): def inst_changes_addr(inst): @@ -2203,11 +2326,52 @@ def _change_same_address(t0,t1): self._AddImplication( t0.pre_var, t1.post_var.Not() ) if _change_same_address(t0,t1): - self.logger.debug("Forbid reordering of (%s,%s) to avoid address fixup issues", t0, t1) + self.logger.debug("Forbid reordering of (%s,%s) to avoid address fixup issues", + t0, t1) self._add_path_constraint( t1, t0, lambda t0=t0, t1=t1: self._Add(t0.program_start_var < t1.program_start_var) ) + # Look for source annotations forcing orderings + + if self.config.sw_pipelining.enabled is True: + nodes = self._get_nodes(low=True) + else: + nodes = self._get_nodes() + + def find_node_by_source_id(src_id): + for t in nodes: + cur_id = t.inst.source_line.tags.get("id", None) + if cur_id == src_id: + return t + raise SlothyException(f"Could not find node with source ID {src_id}") + + for i, t1 in enumerate(nodes): + force_after = t1.inst.source_line.tags.get("after", []) + if not isinstance(force_after, list): + force_after = [force_after] + t0s = list(map(find_node_by_source_id, force_after)) + force_after_last = t1.inst.source_line.tags.get("after_last", False) + if force_after_last is True: + if i == 0: + # Ignore after_last tag for first instruction + continue + t0s.append(nodes[i-1]) + for t0 in t0s: + self.logger.info("Force %s < %s by source annotation", t0, t1) + self._add_path_constraint(t1, t0, + lambda t0=t0, t1=t1: self._Add(t0.program_start_var < t1.program_start_var)) + + for t0 in nodes: + force_before = t0.inst.source_line.tags.get("before", []) + if not isinstance(force_before, list): + force_before = [force_before] + for t1_id in force_before: + t1 = find_node_by_source_id(t1_id) + self.logger.info("Force %s < %s by source annotation", t0, t1) + self._add_path_constraint(t1, t0, + lambda t0=t0, t1=t1: self._Add(t0.program_start_var < t1.program_start_var)) + # ================================================================ # CONSTRAINTS (Single issuing) # # ================================================================ @@ -2303,11 +2467,19 @@ def _add_constraints_misc(self): self.target.add_further_constraints(self) def get_inst_pairs(self, cond=None): - if cond is None: - cond = lambda a,b: True + """Yields all instruction pairs satisfying the provided predicate. + + This can be useful for the specification of additional + microarchitecture-specific constraints. + + Args: + cond: Predicate on pairs of ComputationNode's. True by default. 
+ + Returns: + Generator of all instruction pairs satisfying the predicate.""" for t0 in self._model.tree.nodes: for t1 in self._model.tree.nodes: - if cond(t0,t1): + if cond is None or cond(t0,t1): yield (t0,t1) # ================================================================# @@ -2355,13 +2527,16 @@ def _add_constraints_loop_periodic(self): self._Add( t0.post_var == t1.post_var ) self._Add( t0.core_var == t1.core_var ) # Early - self._Add( t0.program_start_var == t1.program_start_var + self._model.program_padded_size_half )\ + self._Add(t0.program_start_var == \ + t1.program_start_var + self._model.program_padded_size_half) \ .OnlyEnforceIf(t0.pre_var) # Core - self._Add( t1.program_start_var == t0.program_start_var + self._model.program_padded_size_half )\ + self._Add(t1.program_start_var == \ + t0.program_start_var + self._model.program_padded_size_half) \ .OnlyEnforceIf(t0.core_var) # Late - self._Add( t0.program_start_var == t1.program_start_var + self._model.program_padded_size_half )\ + self._Add(t0.program_start_var == \ + t1.program_start_var + self._model.program_padded_size_half) \ .OnlyEnforceIf(t0.post_var) ## Register allocations must be the same assert t0.inst.arg_types_out == t1.inst.arg_types_out @@ -2379,7 +2554,7 @@ def _add_constraints_loop_periodic(self): for reg in t1_vars: v0 = t0.alloc_out_var[o][reg] v1 = t1.alloc_out_var[o][reg] - self._Add( v0 == v1 ) + self._Add(v0 == v1) def restrict_early_late_instructions(self, filter_func): """Forces all instructions not passing the filter_func to be `core`, that is, @@ -2387,7 +2562,7 @@ def restrict_early_late_instructions(self, filter_func): This is only meaningful if software pipelining is enabled.""" if not self.config.sw_pipelining.enabled: - raise Exception("restrict_early_late_instructions() only useful in SW pipelining mode") + raise SlothyException("restrict_early_late_instructions() only in SW pipelining mode") for t in self._get_nodes(): if filter_func(t.inst): @@ -2400,12 +2575,12 @@ def force_early(self, filter_func, early=True): This is only meaningful if software pipelining is enabled.""" if not self.config.sw_pipelining.enabled: - raise Exception("force_early() only useful in SW pipelining mode") + raise SlothyException("force_early() only useful in SW pipelining mode") invalid_pre = early and not self.config.sw_pipelining.allow_pre invalid_post = not early and not self.config.sw_pipelining.allow_post if invalid_pre or invalid_post: - raise Exception("Invalid SW pipelining configuration in force_early()") + raise SlothyException("Invalid SW pipelining configuration in force_early()") for t in self._get_nodes(): if filter_func(t.inst): @@ -2458,8 +2633,8 @@ def restrict_slots_for_instructions_by_class(self, cls_lst, slots): provided list of instruction classes. Args: - - cls_lst: A list of instruction classes - - slots: A list of issue slots represented as integers.""" + cls_lst: A list of instruction classes + slots: A list of issue slots represented as integers.""" self.restrict_slots_for_instructions( self.filter_instructions_by_class(cls_lst), slots ) @@ -2468,8 +2643,8 @@ def restrict_slots_for_instructions_by_property(self, filter_func, slots): filter function. 
Args: - - cls_lst: A predicate on instructions - - slots: A list of issue slots represented as integers.""" + filter_func: A predicate on instructions + slots: A list of issue slots represented as integers.""" self.restrict_slots_for_instructions( self.filter_instructions_by_property(filter_func), slots ) @@ -2497,7 +2672,8 @@ def _add_objective(self, force_objective=False): name = "minimize iteration overlapping" elif self.config.constraints.maximize_register_lifetimes: name = "maximize register lifetimes" - maxlist = [ v for t in self._get_nodes(allnodes=True) for v in t.out_lifetime_duration ] + maxlist = [ v for t in self._get_nodes(allnodes=True) + for v in t.out_lifetime_duration ] elif self.config.constraints.move_stalls_to_bottom is True: minlist = [ t.program_start_var for t in self._get_nodes() ] name = "move stalls to bottom" @@ -2542,8 +2718,7 @@ def _describe_solver(self): workers = self._model.cp_solver.parameters.num_workers if workers > 0: return f"OR-Tools CP-SAT v{ortools.__version__}, {workers} threads" - else: - return f"OR-Tools CP-SAT v{ortools.__version__}" + return f"OR-Tools CP-SAT v{ortools.__version__}" def _init_external_model_and_solver(self): self._model.cp_model = cp_model.CpModel() @@ -2557,44 +2732,38 @@ def _init_external_model_and_solver(self): self.logger.warning("Please consider upgrading OR-Tools to version >= 9.5.2040") self._model.cp_solver.parameters.symmetry_level = 1 - def _NewIntVar(self, minval, maxval, name=""): + def _NewIntVar(self, minval, maxval, name=""): # pylint:disable=invalid-name r = self._model.cp_model.NewIntVar(minval,maxval, name) self._model.variables.append(r) return r - def _NewIntervalVar(self, base, dur, end, name=""): - r = self._model.cp_model.NewIntervalVar(base,dur,end,name) - return r - def _NewOptionalIntervalVar(self, base, dur, end, cond,name=""): - r = self._model.cp_model.NewOptionalIntervalVar(base,dur,end,cond,name) - return r - def _NewBoolVar(self, name=""): + def _NewIntervalVar(self, base, dur, end, name=""): # pylint:disable=invalid-name + return self._model.cp_model.NewIntervalVar(base,dur,end,name) + def _NewOptionalIntervalVar(self, base, dur, end, cond,name=""): # pylint:disable=invalid-name + return self._model.cp_model.NewOptionalIntervalVar(base,dur,end,cond,name) + def _NewBoolVar(self, name=""): # pylint:disable=invalid-name r = self._model.cp_model.NewBoolVar(name) self._model.variables.append(r) return r - def _NewConstant(self, val): + def _NewConstant(self, val): # pylint:disable=invalid-name r = self._model.cp_model.NewConstant(val) return r - def _Add(self,c): + def _Add(self,c): # pylint:disable=invalid-name return self._model.cp_model.Add(c) - def _AddNoOverlap(self,lst): + def _AddNoOverlap(self,lst): # pylint:disable=invalid-name return self._model.cp_model.AddNoOverlap(lst) - def _AddExactlyOne(self,lst): + def _AddExactlyOne(self,lst): # pylint:disable=invalid-name return self._model.cp_model.AddExactlyOne(lst) - def _AddImplication(self,a,b): + def _AddImplication(self,a,b): # pylint:disable=invalid-name return self._model.cp_model.AddImplication(a,b) - def _AddAtLeastOne(self,lst): + def _AddAtLeastOne(self,lst): # pylint:disable=invalid-name return self._model.cp_model.AddAtLeastOne(lst) - def _AddAbsEq(self,dst,expr): + def _AddAbsEq(self,dst,expr): # pylint:disable=invalid-name return self._model.cp_model.AddAbsEquality(dst,expr) - def _AddAllDifferent(self,lst): - if len(lst) < 2: - return + def _AddAllDifferent(self,lst): # pylint:disable=invalid-name return 
self._model.cp_model.AddAllDifferent(lst) - def _AddHint(self,var,val): + def _AddHint(self,var,val): # pylint:disable=invalid-name return self._model.cp_model.AddHint(var,val) - def _AddNoOverlap(self,interval_list): - if len(interval_list) < 2: - return + def _AddNoOverlap(self,interval_list): # pylint:disable=invalid-name return self._model.cp_model.AddNoOverlap(interval_list) def _export_model(self, log_model): @@ -2649,7 +2818,7 @@ def retry(self, fix_stalls=None): self._set_timeout(self.config.retry_timeout) # - Objective - self._add_objective(force_objective = (fix_stalls is not None)) + self._add_objective(force_objective = fix_stalls is not None) # Do the actual work self.logger.info("Invoking external constraint solver...") @@ -2663,4 +2832,4 @@ def retry(self, fix_stalls=None): def _dump_model_statistics(self): # Extract and report results - SlothyBase.dump("Statistics", self._model.cp_model.cp_solver.ResponseStats(), self.logger) + SourceLine.log("Statistics", self._model.cp_model.cp_solver.ResponseStats(), self.logger) diff --git a/slothy/dataflow.py b/slothy/core/dataflow.py similarity index 88% rename from slothy/dataflow.py rename to slothy/core/dataflow.py index e25d61c0..965a6db6 100644 --- a/slothy/dataflow.py +++ b/slothy/core/dataflow.py @@ -26,7 +26,7 @@ # from functools import cached_property -from .helper import AsmHelper +from slothy.helper import SourceLine class SlothyUselessInstructionException(Exception): """An instruction was found whose outputs are neither used by a subsequent instruction @@ -127,6 +127,7 @@ def __init__(self, reg, reg_ty): self.num_in = 1 self.args_in = [reg] self.arg_types_in = [reg_ty] + self.args_in_restrictions = [None] def write(self): return f"// output renaming: {self.orig_reg} -> {self.args_in_out[0]}" @@ -141,6 +142,7 @@ def __init__(self, reg, reg_ty): self.num_out = 1 self.args_out = [reg] self.arg_types_out = [reg_ty] + self.args_out_restrictions = [None] def write(self): return f"// input renaming: {self.orig_reg} -> {self.args_out[0]}" @@ -209,6 +211,17 @@ def isinstancelist(l, c): self.dst_out = [ [] for _ in range(inst.num_out) ] self.dst_in_out = [ [] for _ in range(inst.num_in_out) ] + def to_source_line(self): + """Convert node in data flow graph to source line. + + This keeps original tags and comments from the source line that + gave rise to the node, but updates the text with the stringification + of the instruction underlying the node. + """ + line = self.inst.source_line.copy() + inst_txt = str(self.inst) + return line.set_text(inst_txt) + @cached_property + def is_virtual_input(self): + """Indicates whether the node is an input node.""" @@ -278,9 +291,9 @@ def arch(self): def typing_hints(self): """A dictionary of 'typing hints' explicitly assigning to symbolic register names a register type. - - This can be necessary to disambiguate the type of symbolic registers. - For example, the Helium vector extension has various instructions which + + This can be necessary to disambiguate the type of symbolic registers. + For example, the Helium vector extension has various instructions which accept either vector or GPR arguments.""" typing_hints = { name : ty for ty in self.arch.RegisterType \ for name in self.arch.RegisterType.list_registers(ty, with_variants=True) } @@ -291,12 +304,12 @@ def outputs(self): return self._outputs @property def inputs_are_outputs(self): - """Every input is automatically treated as an output. + """Every input is automatically treated as an output. 
This is typically set for loop kernels.""" return self._inputs_are_outputs @property def allow_useless_instructions(self): - """Indicates whether data flow creation should raise SlothyUselessInstructionException + """Indicates whether data flow creation should raise SlothyUselessInstructionException when a useless instruction is detected.""" return self._allow_useless_instructions @@ -325,6 +338,7 @@ def __init__(self, slothy_config=None, **kwargs): self._outputs = None self._inputs_are_outputs = None self._allow_useless_instructions = None + self._locked_registers = None self._load_slothy_config(slothy_config) for k,v in kwargs.items(): setattr(self,k,v) @@ -334,6 +348,7 @@ def _load_slothy_config(self, slothy_config): return self._slothy_config = slothy_config self._arch = slothy_config.arch + self._locked_registers = slothy_config.locked_registers self._typing_hints = self._slothy_config.typing_hints self._outputs = self._slothy_config.outputs self._inputs_are_outputs = self._slothy_config.inputs_are_outputs @@ -461,7 +476,7 @@ def _iter_edges_with_label(): def depth(self): """The depth of the data flow graph. - + Equivalently, the maximum length of a dependency chain in the assembly source represented by the graph.""" if self.nodes is None or len(self.nodes) == 0: @@ -479,6 +494,73 @@ def arch(self): """The underlying architectural model""" return self.config.arch + def apply_cbs(self, cb, logger, one_a_time=False): + """Apply callback to all nodes in the graph""" + + count = 0 + while True: + count += 1 + assert count < 100 # There shouldn't be many repeated modifications to the CFG + + some_change = False + + for t in self.nodes: + t.delete = False + t.changed = False + + for t in self.nodes: + if cb(t): + some_change = True + if one_a_time is True: + break + + if some_change is False: + break + + z = zip(self.nodes, self.src) + z = filter(lambda x: x[0].delete is False, z) + z = map(lambda x: ([x[0].inst], x[0].inst.write()), z) + + self.src = list(z) + + # Otherwise, parse again + changed = [t for t in self.nodes if t.changed is True] + deleted = [t for t in self.nodes if t.delete is True] + + logger.debug("Some instruction changed in callback -- need to build dataflow graph again...") + + for t in deleted: + logger.debug("* %s was deleted", t) + for t in changed: + logger.debug("* %s was changed", t) + + self._build_graph() + + def apply_parsing_cbs(self): + """Apply parsing callbacks to all nodes in the graph. + + Typically, we only build the computation flow graph once. However, sometimes we make + retrospective modifications to instructions afterwards, and then need to reparse. + + An example of this is the class of jointly destructive instruction patterns: sequences of + instructions where each instruction individually overwrites only part of a register, + but where the sequence jointly overwrites the register as a whole. In this case, we can remove the + output register as an input dependency for the first instruction in the sequence, + thereby creating more reordering and renaming flexibility. To do so, we change + the instruction and then rebuild the computation flow graph. 
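+ (Hypothetical illustration, not taken from one of the included architecture models: if `mov.lo r0, ...` wrote only the low half of `r0` and a subsequent `mov.hi r0, ...` wrote the high half, the pair would jointly overwrite `r0`, so the dependency of the first instruction on the old value of `r0` could be dropped.) 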
+ """ + logger = self.logger.getChild("parsing_cbs") + def parsing_cb(t): + return t.inst.global_parsing_cb(t, log=logger.info) + return self.apply_cbs(parsing_cb, logger) + + def apply_fusion_cbs(self): + """Apply fusion callbacks to nodes in the graph""" + logger = self.logger.getChild("fusion_cbs") + def fusion_cb(t): + return t.inst.global_fusion_cb(t, log=logger.info) + return self.apply_cbs(fusion_cb, logger, one_a_time=True) + def __init__(self, src, logger, config, parsing_cb=True): """Compute a data flow graph from a source code snippet. @@ -497,45 +579,10 @@ def __init__(self, src, logger, config, parsing_cb=True): self.config = config self.src = self._parse_source(src) - # Typically, we only build the computation flow graph once. However, sometimes we make - # retrospective modifications to instructions afterwards, and then need to reparse. - # - # An example for this are jointly destructive instruction patterns: A sequence of - # instructions where each instruction individually overwrites only part of a register, - # but jointly they overwrite the register as a whole. In this case, we can remove the - # output register as an input dependency for the first instruction in the sequence, - # thereby creating more reordering and renaming flexibility. In this case, we change - # the instruction and then rebuild the computation flow graph. - count = 0 - while True: - count += 1 - assert count < 10 # There shouldn't be many repeated modifications to the CFG - - self._build_graph() - - if not parsing_cb: - break + self._build_graph() - changed = [] - for t in self.nodes: - was_changed = t.inst.global_parsing_cb(t) - if was_changed: # remember to build the dataflow graph again - changed.append(t) - - changes = len(changed) - # If no instruction was modified, we're done - if changes == 0: - break - - self.src = list(zip([[t.inst] for t in self.nodes], [s[1] for s in self.src])) - - # Otherwise, parse again - logger.debug("%d instructions changed -- need to build dataflow graph again...", - changes) - logger.debug("The following instructions have changed:") - if changes > 0: - for t in changed: - logger.debug(t) + if parsing_cb is True: + self.apply_parsing_cbs() self._selfcheck_outputs() @@ -565,15 +612,24 @@ def outputs_unused(t): self.logger.warning(f"Instruction details: {t}, {t.inst.inputs}") self._dump_instructions("Source code", error=False) + def _parse_line(self, l): + assert SourceLine.is_source_line(l) + insts = self.arch.Instruction.parser(l) + # Remember options from source line + # TODO: Might not be the right place to remember options + for inst in insts: + inst.source_line = l + return (insts, l) + def _parse_source(self, src): - return [ (self.arch.Instruction.parser(l),l) for l in AsmHelper.reduce_source(src) ] + return list(map(self._parse_line, SourceLine.reduce_source(src))) def iter_dependencies(self): """Returns an iterator over all dependencies in the data flow graph. - + Each returned element has the form (consumer, producer, ty, idx), representing a dependency - from output producer to the idx-th input (if ty=="in") or input/output (if ty=="inout") of - consumer. The producer field is an instance of RegisterSource and contains the output index + from output producer to the idx-th input (if ty=="in") or input/output (if ty=="inout") of + consumer. 
The producer field is an instance of RegisterSource and contains the output index and source instruction as producer.idx and producer.src, respectively.""" for consumer in self.nodes_all: for idx, producer in enumerate(consumer.src_in): @@ -647,8 +703,8 @@ def get_fresh_reg(): no_ssa.append((producer.src, producer.idx)) for t in self.nodes: - for (i,_) in enumerate(t.inst.args_out): - if (t,i) in no_ssa: + for (i,c) in enumerate(t.inst.args_out): + if c in self.config._locked_registers or (t,i) in no_ssa: continue t.inst.args_out[i] = get_fresh_reg() @@ -763,6 +819,7 @@ def find_sources(types,names): step = ComputationNode(node_id=s_id, orig_pos=orig_pos, inst=s, src_in=src_in, src_in_out=src_in_out) + step.reg_state = self.reg_state.copy() def change_reg_ref(reg, ref): self._remember_type(reg, ref.get_type()) diff --git a/slothy/heuristics.py b/slothy/core/heuristics.py similarity index 58% rename from slothy/heuristics.py rename to slothy/core/heuristics.py index edb9a330..b278d50a 100644 --- a/slothy/heuristics.py +++ b/slothy/core/heuristics.py @@ -25,22 +25,33 @@ # Author: Hanno Becker # +"""SLOTHY heuristics + +The one-shot SLOTHY approach tends to become computationally infeasible above +200 assembly instructions. To optimize kernels beyond that threshold, this +module provides heuristics to split the optimization problem into several +smaller-sized problems amenable to one-shot SLOTHY. +""" + import math import random -import numpy as np -from slothy.dataflow import DataFlowGraph as DFG -from slothy.dataflow import Config as DFGConfig -from slothy.core import SlothyBase, Result -from slothy.helper import Permutation, AsmHelper +from slothy.core.dataflow import DataFlowGraph as DFG +from slothy.core.dataflow import Config as DFGConfig, ComputationNode +from slothy.core.core import SlothyBase, Result, SlothyException +from slothy.helper import Permutation, SourceLine from slothy.helper import binary_search, BinarySearchLimitException class Heuristics(): + """Break down large optimization problems into smaller ones. + + The one-shot SLOTHY approach tends to become computationally infeasible above + 200 assembly instructions. 
To optimize kernels beyond that threshold, this + class provides heuristics to split the optimization problem into several + smaller-sized problems amenable to one-shot SLOTHY.""" @staticmethod - def optimize_binsearch_core(source, logger, conf, **kwargs): - """Shim wrapper around Slothy performing a binary search for the - minimization of stalls""" + def _optimize_binsearch_core(source, logger, conf, **kwargs): logger_name = logger.name.replace(".","_") last_successful = None @@ -73,70 +84,139 @@ def try_with_stalls(stalls, timeout=None): try: return binary_search(try_with_stalls, - minimum= conf.constraints.stalls_minimum_attempt - 1, - start=conf.constraints.stalls_first_attempt, - threshold=conf.constraints.stalls_maximum_attempt, - precision=conf.constraints.stalls_precision, - timeout_below_precision=conf.constraints.stalls_timeout_below_precision) + minimum=conf.constraints.stalls_minimum_attempt - 1, + start=conf.constraints.stalls_first_attempt, + threshold=conf.constraints.stalls_maximum_attempt, + precision=conf.constraints.stalls_precision, + timeout_below_precision=conf.constraints.stalls_timeout_below_precision) + except BinarySearchLimitException: logger.error("Exceeded stall limit without finding a working solution") logger.error("Here's what you asked me to optimize:") - Heuristics._dump("Original source code", source, logger=logger, err=True, no_comments=True) - logger.error("Configuration") + + Heuristics._dump("Original source code", source, + logger=logger, err=True, no_comments=True) + logger.error("Configuration:") conf.log(logger.error) err_file = conf.log_dir + f"/{logger_name}_ERROR.s" - f = open(err_file, "w") - conf.log(lambda l: f.write("// " + l + "\n")) - f.write('\n'.join(source)) - f.close() + with open(err_file, "w", encoding="utf-8") as f: + conf.log(lambda l: f.write("// " + l + "\n")) + f.write('\n'.join(source)) + logger.error(f"Stored this information in {err_file}") @staticmethod def optimize_binsearch(source, logger, conf, **kwargs): + """Optimize for minimum number of stalls, and potentially a secondary objective. + + Args: + source: The source code to be optimized. Must be a list of + SourceLine instances. + logger: The logger to be used. + conf: The configuration to apply. This is fixed for all one-shot SLOTHY + runs invoked by this call, except for the variation of the stall count. + + The `variable_size` configuration option determines whether the minimization of + stalls happens internally or externally. Internal minimization means that the + number of stalls is part of the model, and its minimization registered as the + objective to the underlying solver. External minimization means that the number + of stalls is statically fixed per one-shot SLOTHY optimization, and that an + external binary search is used to minimize it. + + Returns: + The Result object for the succeeding optimization with the smallest + number of stalls. + """ if conf.variable_size: return Heuristics.optimize_binsearch_internal(source, logger, conf, **kwargs) - else: - return Heuristics.optimize_binsearch_external(source, logger, conf, **kwargs) + + return Heuristics.optimize_binsearch_external(source, logger, conf, **kwargs) + + @staticmethod + def _log_reoptimization_failure(log): + log.warning("Re-optimization with objective at minimum number of stalls failed. "\
" \ + "Will just pick previous result...") + + @staticmethod + def _log_input_output_warning(log): + log.warning("You are using SW pipelining without setting inputs_are_outputs=True."\ + "This means that the last iteration of the loop may overwrite inputs "\ + "to the loop (such as address registers), unless they are marked as " \ + "reserved registers. If this is intended, ignore this warning. " \ + "Otherwise, consider setting inputs_are_outputs=True to ensure that " \ + "nothing that is used as an input to the loop is overwritten, " \ + "not even in the last iteration.") @staticmethod def optimize_binsearch_external(source, logger, conf, flexible=True, **kwargs): - """Find minimum number of stalls without objective, then optimize - the objective for a fixed number of stalls.""" + """Externally optimize for minimum number of stalls, and potentially a secondary objective. + + This function uses an external binary search to find the minimum number of stalls + for which a one-shot SLOTHY optimization succeeds. If the provided configuration + has a secondary objective, it then re-optimizes the result for that secondary + objective, fixing the minimal number of stalls. + + Args: + source: The source code to be optimized. Must be a list of SourceLine instances. + logger: The logger to be used. + conf: The configuration to apply. This is fixed for all one-shot SLOTHY + runs invoked by this call, except for variation of stall count. + flexible: Indicates whether the number of stalls should be minimized + through a binary search, or whether a single one-shot SLOTHY optimization + for a fixed number of stalls (encoded in the configuration) should be + conducted. + + Returns: + A Result object representing the final optimization result. + """ if not flexible: core = SlothyBase(conf.arch, conf.target, logger=logger,config=conf) if not core.optimize(source): - raise Exception("Optimization failed") + raise SlothyException("Optimization failed") return core.result logger.info("Perform external binary search for minimal number of stalls...") c = conf.copy() c.ignore_objective = True - min_stalls, core = Heuristics.optimize_binsearch_core(source, logger, c, **kwargs) + min_stalls, core = Heuristics._optimize_binsearch_core(source, logger, c, **kwargs) if not conf.has_objective: return core.result - logger.info(f"Optimize again with minimal number of {min_stalls} stalls, with objective...") + logger.info("Optimize again with minimal number of %d stalls, with objective...", + min_stalls) first_result = core.result core.config.ignore_objective = False success = core.retry() if not success: - logger.warning("Re-optimization with objective at minimum number of stalls failed -- should not happen? Will just pick previous result...") + Heuristics._log_reoptimization_failure(logger) return first_result - # core = SlothyBase(conf.arch, conf.target, logger=logger, config=c) - # success = core.optimize(source, **kwargs) return core.result @staticmethod def optimize_binsearch_internal(source, logger, conf, **kwargs): - """Find minimum number of stalls without objective, then optimize - the objective for a fixed number of stalls.""" + """Internally optimize for minimum number of stalls, and potentially a secondary objective. + + This finds the minimum number of stalls for which a one-shot SLOTHY optimization succeeds. + If the provided configuration has a secondary objective, it then re-optimizes the result + for that secondary objective, fixing the minimal number of stalls. 
+ + Args: + source: The source code to be optimized. Must be a list of SourceLine instances. + logger: The logger to be used. + conf: The configuration to apply. This is fixed for all one-shot SLOTHY + runs invoked by this call, except for variation of stall count. + + Returns: + A Result object representing the final optimization result. + """ logger.info("Perform internal binary search for minimal number of stalls...") @@ -148,7 +228,7 @@ def optimize_binsearch_internal(source, logger, conf, **kwargs): c.variable_size = True c.constraints.stalls_allowed = cur_attempt - logger.info(f"Attempt optimization with max {cur_attempt} stalls...") + logger.info("Attempt optimization with max %d stalls...", cur_attempt) core = SlothyBase(c.arch, c.target, logger=logger, config=c) success = core.optimize(source, **kwargs) @@ -160,40 +240,63 @@ def optimize_binsearch_internal(source, logger, conf, **kwargs): cur_attempt = max(1,cur_attempt * 2) if cur_attempt > conf.constraints.stalls_maximum_attempt: logger.error("Exceeded stall limit without finding a working solution") - raise Exception("No solution found") + raise SlothyException("No solution found") logger.info(f"Minimum number of stalls: {min_stalls}") if not conf.has_objective: return core.result - logger.info(f"Optimize again with minimal number of {min_stalls} stalls, with objective...") + logger.info("Optimize again with minimal number of %d stalls, with objective...", + min_stalls) first_result = core.result success = core.retry(fix_stalls=min_stalls) if not success: - logger.warning("Re-optimization with objective at minimum number of stalls failed -- should not happen? Will just pick previous result...") + Heuristics._log_reoptimization_failure(logger) return first_result return core.result @staticmethod def periodic(body, logger, conf): - """Heuristics for the optimization of large loops - - Can be called if software pipelining is disabled. In this case, it just - forwards to the linear heuristic.""" + """Entrypoint for optimization of loops. + + If software pipelining is disabled, this function forwards to + the straightline optimization via Heuristics.linear(). + + If software pipelining is enabled but the halving heuristic + is disabled, this function performs a one-shot SLOTHY optimization + without heuristics. + + If software pipelining is enabled and the halving heuristic is + enabled, this function optimizes the loop body via straightline + optimization first, splits the result as `[A;B]`, and optimizes + `[B;A]` again via straightline optimizations. The optimized loop + is then given by the preamble `A`, kernel `opt(B;A)`, and postamble + `B`. The straightline optimizations applied in this heuristic are + done via Heuristics.linear() and thus themselves subject to the + splitting heuristic, if enabled. + + Args: + body: The loop body to be optimized. This must be a list of + SourceLine instances. + logger: The logger to be used. + conf: The configuration to be applied. + + Returns: + Tuple (preamble, kernel, postamble, num_exceptional_iterations) + of preamble, kernel and postamble (each as a list of SourceLine + objects), plus the number of iterations jointly accounted for by + the preamble and postamble (the caller will need this to adjust the + loop counter). + """ if conf.sw_pipelining.enabled and not conf.inputs_are_outputs: - logger.warning("You are using SW pipelining without setting inputs_are_outputs=True. 
This means that the last iteration of the loop may overwrite inputs to the loop (such as address registers), unless they are marked as reserved registers. If this is intended, ignore this warning. Otherwise, consider setting inputs_are_outputs=True to ensure that nothing that is used as an input to the loop is overwritten, not even in the last iteration.") + Heuristics._log_input_output_warning(logger) - def unroll(source): - if conf.sw_pipelining.enabled: - source = source * conf.sw_pipelining.unroll - source = '\n'.join(source) - return source - - body = unroll(body) + if conf.sw_pipelining.enabled: + body = body * conf.sw_pipelining.unroll if conf.inputs_are_outputs: dfg = DFG(body, logger.getChild("dfg_generate_outputs"), @@ -202,10 +305,10 @@ def unroll(source): conf.inputs_are_outputs = False # If we're not asked to do software pipelining, just forward to - # the heurstics for linear optimization. + # the heuristics for linear optimization. if not conf.sw_pipelining.enabled: - core = Heuristics.linear( body, logger=logger, conf=conf) - return [], core, [], 0 + res = Heuristics.linear( body, logger=logger, conf=conf) + return [], res.code, [], 0 if conf.sw_pipelining.halving_heuristic: return Heuristics._periodic_halving( body, logger, conf) @@ -224,6 +327,7 @@ def unroll(source): num_exceptional_iterations = result.num_exceptional_iterations kernel = result.code + assert SourceLine.is_source(kernel) # Second step: Separately optimize preamble and postamble preamble = result.preamble if conf.sw_pipelining.optimize_preamble: logger.debug("Optimize preamble...") Heuristics._dump("Preamble", preamble, logger) - logger.debug(f"Dependencies within kernel: "\ - f"{result.kernel_input_output}") + logger.debug("Dependencies within kernel: %s", result.kernel_input_output) c = conf.copy() c.outputs = result.kernel_input_output c.sw_pipelining.enabled=False - preamble = Heuristics.linear(preamble,conf=c, logger=logger.getChild("preamble")) + res_preamble = Heuristics.linear(preamble,conf=c, logger=logger.getChild("preamble")) + preamble = res_preamble.code postamble = result.postamble if conf.sw_pipelining.optimize_postamble: @@ -244,27 +348,43 @@ def unroll(source): Heuristics._dump("Postamble", postamble, logger) c = conf.copy() c.sw_pipelining.enabled=False - postamble = Heuristics.linear(postamble, conf=c, logger=logger.getChild("postamble")) + res_postamble = Heuristics.linear(postamble, conf=c, + logger=logger.getChild("postamble")) + postamble = res_postamble.code return preamble, kernel, postamble, num_exceptional_iterations @staticmethod - def linear(body, logger, conf, visualize_stalls=True): - """Heuristic for the optimization of large linear chunks of code. + def linear(body, logger, conf): + """Entrypoint for straightline optimization. + + If the split heuristic is disabled, this forwards to a one-shot optimization. + + If the split heuristic is enabled (conf.split_heuristic == True), the assembly + input is optimized by successively applying one-shot optimizations to a + 'sliding window' of code. - Must only be called if software pipelining is disabled.""" + Args: + body: The assembly input to be optimized. This must be a list of + SourceLine objects. + logger: The logger to be used. + conf: The configuration to be applied. Software pipelining must be disabled. + + Raises: + SlothyException: If software pipelining is enabled. 
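+ + Example (sketch; the concrete values are illustrative only, but all options appear in the implementation below): + + conf.split_heuristic = True + conf.split_heuristic_repeat = 2 # number of sweeps over the code + conf.split_heuristic_stepsize = 0.05 # window advance per step + conf.split_heuristic_region = [0.0, 1.0] # fraction of the code to consider 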
+ """ + assert SourceLine.is_source(body) if conf.sw_pipelining.enabled: - raise Exception("Linear heuristic should only be called with SW pipelining disabled") + raise SlothyException("Linear heuristic should only be called " + "with SW pipelining disabled") Heuristics._dump("Starting linear optimization...", body, logger) # So far, we only implement one heuristic: The splitting heuristic -- # If that's disabled, just forward to the core optimization if not conf.split_heuristic: - result = Heuristics.optimize_binsearch(body,logger.getChild("slothy"), conf) - return result.code + return Heuristics.optimize_binsearch(body,logger.getChild("slothy"), conf) - return Heuristics._split( body, logger, conf, visualize_stalls) + return Heuristics._split(body, logger, conf) @staticmethod def _naive_reordering(body, logger, conf, use_latency_depth=False): @@ -285,11 +405,11 @@ def _naive_reordering(body, logger, conf, use_latency_depth=False): else: # Calculate latency-depth of instruction nodes nodes_by_depth = dfg.nodes.copy() - nodes_by_depth.sort(key=(lambda t: t.depth)) + nodes_by_depth.sort(key=lambda t: t.depth) for t in dfg.nodes_all: t.latency_depth = 0 def get_latency(tp,t): - if tp.src.is_virtual(): + if tp.src.is_virtual: return 0 return conf.target.get_latency(tp.src.inst, tp.idx, t.inst) for t in nodes_by_depth: @@ -304,14 +424,15 @@ def get_latency(tp,t): perm = Permutation.permutation_id(l) - for i in range(l): - def get_inputs(inst): - return set(inst.args_in + inst.args_in_out) - def get_outputs(inst): - return set(inst.args_out + inst.args_in_out) + def get_inputs(inst): + return set(inst.args_in + inst.args_in_out) + def get_outputs(inst): + return set(inst.args_out + inst.args_in_out) - joint_prev_inputs = {} - joint_prev_outputs = {} + joint_prev_inputs = {} + joint_prev_outputs = {} + + for i in range(l): cur_joint_prev_inputs = set() cur_joint_prev_outputs = set() for j in range(i,l): @@ -339,24 +460,15 @@ def could_come_next(j): def pick_candidate(candidate_idxs): - # print("CANDIDATES: " + '\n* '.join(list(map(lambda idx: str((body[idx], conf.target.get_units(insts[idx]))), candidate_idxs)))) - # There a different strategies one can pursue here, some being: - # - Always pick the candidate instruction of the smallest depth - # - Peek into the uarch model and try to alternate between functional units - # It's a bit disappointing if this is necessary, since SLOTHY should do this. - # However, running it on really large snippets (1000 instructions) remains - # infeasible, even if latencies and renaming are disabled. 
- strategy = "minimal_depth" - # strategy = "alternate_functional_units" if strategy == "minimal_depth": candidate_depths = list(map(lambda j: depths[j], candidate_idxs)) logger.debug("Candidate %s: %s", depth_str, candidate_depths) choice_idx = candidate_idxs[candidate_depths.index(min(candidate_depths))] - elif strategy == "alternate_functional_units": - + else: + assert strategy == "alternate_functional_units" def flatten_units(units): res = [] for u in units: @@ -389,44 +501,28 @@ def units_different(a,b): candidate_depths = list(map(lambda j: depths[j], candidate_idxs)) logger.debug(f"Candidate {depth_str}s: {candidate_depths}") min_depth = min(candidate_depths) - refined_candidates = [ candidate_idxs[i] for i,d in enumerate(candidate_depths) if d == min_depth ] + refined_candidates = [ candidate_idxs[i] + for i,d in enumerate(candidate_depths) if d == min_depth ] choice_idx = random.choice(refined_candidates) - else: - raise Exception("Unknown preprocessing strategy") - return choice_idx - def move_entry_forward(lst, idx_from, idx_to, callback=None): + def move_entry_forward(lst, idx_from, idx_to): entry = lst[idx_from] del lst[idx_from] - - if callback is not None: - for before in lst[idx_to:idx_from]: - callback(before, entry) - return lst[:idx_to] + [entry] + lst[idx_to:] - def inst_reorder_cb(t0,t1): - SlothyBase._fixup_reordered_pair(t0,t1,logger) - - SlothyBase._fixup_reset(insts) choice_idx = None while choice_idx is None: - try: - choice_idx = pick_candidate(candidate_idxs) - insts = move_entry_forward(insts, choice_idx, i, inst_reorder_cb) - except: - candidate_idxs.remove(choice_idx) - choice_idx = None - SlothyBase._fixup_finish(insts, logger) + choice_idx = pick_candidate(candidate_idxs) + insts = move_entry_forward(insts, choice_idx, i) local_perm = Permutation.permutation_move_entry_forward(l, choice_idx, i) perm = Permutation.permutation_comp (local_perm, perm) - body = [ str(j.inst) for j in insts] + body = list(map(ComputationNode.to_source_line, insts)) depths = move_entry_forward(depths, choice_idx, i) - body[i] = f" {body[i].strip():100s} // {depth_str} {depths[i]}" + body[i].set_length(100).set_comment(f"{depth_str} {depths[i]}") Heuristics._dump("New code", body, logger) # Selfcheck @@ -441,6 +537,9 @@ def inst_reorder_cb(t0,t1): res.valid = True res.selfcheck(logger.getChild("naive_interleaving_selfcheck")) + res.offset_fixup(logger.getChild("naive_interleaving_fixup")) + body = res.code_raw + Heuristics._dump("Before naive interleaving", old, logger) Heuristics._dump("After naive interleaving", body, logger) return body, perm @@ -454,11 +553,11 @@ def _get_ssa_form(body, logger, conf): logger.info("Transform DFG into SSA...") dfg = DFG(body, logger.getChild("dfg_ssa"), DFGConfig(conf.copy()), parsing_cb=True) dfg.ssa() - ssa = [ str(t.inst) for t in dfg.nodes ] + ssa = [ ComputationNode.to_source_line(t) for t in dfg.nodes ] return ssa @staticmethod - def _split_inner(body, logger, conf, visualize_stalls=True, ssa=False): + def _split_inner(body, logger, conf, ssa=False): l = len(body) if l == 0: @@ -484,34 +583,22 @@ def _split_inner(body, logger, conf, visualize_stalls=True, ssa=False): c = conf.copy() c.constraints.allow_reordering = False c.constraints.functional_only = True - body = AsmHelper.reduce_source(body) - result = Heuristics.optimize_binsearch(body, log.getChild("remove_symbolics"),conf=c) + body = SourceLine.reduce_source(body) + result = Heuristics.optimize_binsearch(body, + log.getChild("remove_symbolics"),conf=c) body = result.code - body = 
AsmHelper.reduce_source(body) + body = SourceLine.reduce_source(body) else: perm = Permutation.permutation_id(l) - # log.debug("Remove symbolics...") - # c = conf.copy() - # c.constraints.allow_reordering = False - # c.constraints.functional_only = True - # body = AsmHelper.reduce_source(body) - # result = Heuristics.optimize_binsearch(body, log.getChild("remove_symbolics"),conf=c) - # body = result.code - # body = AsmHelper.reduce_source(body) - - # conf.outputs = result.outputs - - # Heuristics._dump("Source code without symbolic registers", body, log) - - # conf.outputs = result.outputs - def print_intarr(arr, l,vals=50): - m = max(10,max(arr)) + m = max(10, max(arr)) # pylint:disable=nested-min-max start_idxs = [ (l * i) // vals for i in range(vals) ] end_idxs = [ (l * (i+1)) // vals for i in range(vals) ] avgs = [] for (s,e) in zip(start_idxs, end_idxs): + if s == e: + continue avg = sum(arr[s:e]) // (e-s) avgs.append(avg) log.info(f"[{s:3d}-{e:3d}]: {'*'*avg}{'.'*(m-avg)} ({avg})") @@ -522,7 +609,8 @@ def print_stalls(stalls,l): stalls_arr = [ i in stalls for i in range(l) ] for v in stalls_arr: assert v in {0,1} - stalls_cumulative = [ sum(stalls_arr[max(0,i-math.floor(chunk_len/2)):i+math.ceil(chunk_len/2)]) for i in range(l) ] + stalls_cumulative = [ sum(stalls_arr[max(0,i-math.floor(chunk_len/2)) + :i+math.ceil(chunk_len/2)]) for i in range(l) ] print_intarr(stalls_cumulative,l) def optimize_chunk(start_idx, end_idx, body, stalls,show_stalls=True): @@ -549,7 +637,8 @@ def optimize_chunk(start_idx, end_idx, body, stalls,show_stalls=True): pre_pad = len(cur_pre) post_pad = len(cur_post) - Heuristics._dump(f"Optimizing chunk [{start_idx}-{prefix_len}:{end_idx}+{suffix_len}]", cur_body, log) + Heuristics._dump(f"Optimizing chunk [{start_idx}-{prefix_len}:{end_idx}+{suffix_len}]", + cur_body, log) if prefix_len > 0: Heuristics._dump("Using prefix", cur_prefix, log) if suffix_len > 0: @@ -571,7 +660,7 @@ def optimize_chunk(start_idx, end_idx, body, stalls,show_stalls=True): log.getChild(f"{start_idx}_{end_idx}"), c, prefix_len=prefix_len, suffix_len=suffix_len) Heuristics._dump(f"New chunk [{start_idx}:{end_idx}]", result.code, log) - new_body = cur_pre + AsmHelper.reduce_source(result.code) + cur_post + new_body = cur_pre + SourceLine.reduce_source(result.code) + cur_post perm = Permutation.permutation_pad(result.reordering, pre_pad, post_pad) @@ -606,8 +695,7 @@ def make_idx_list_consecutive(factor, increment): end_pos = [] while cur_end < 1.0: cur_end = cur_start + chunk_len - if cur_end > 1.0: - cur_end = 1.0 + cur_end = min(cur_end, 1.0) start_pos.append(cur_start) end_pos.append(cur_end) @@ -639,19 +727,14 @@ def not_empty(x): else: increment = conf.split_heuristic_stepsize - # orig_body = AsmHelper.reduce_source(cur_body).copy() - # perm = Permutation.permutation_id(len(orig_body)) - # Remember inputs and outputs dfgc = DFGConfig(conf.copy()) outputs = conf.outputs.copy() inputs = DFG(orig_body, log.getChild("dfg_infer_inputs"),dfgc).inputs.copy() - last_base = None - - for i in range(conf.split_heuristic_repeat): + for _ in range(conf.split_heuristic_repeat): - cur_body = AsmHelper.reduce_source(cur_body) + cur_body = SourceLine.reduce_source(cur_body) if conf.split_heuristic_chunks: start_pos = [ x[0] for x in conf.split_heuristic_chunks ] @@ -667,86 +750,78 @@ def not_empty(x): idx_lst.reverse() cur_body, stalls, local_perm = optimize_chunks_many(idx_lst, cur_body, stalls, - abort_stall_threshold=conf.split_heuristic_abort_cycle_at) + 
abort_stall_threshold=conf.split_heuristic_abort_cycle_at) perm = Permutation.permutation_comp(local_perm, perm) # Check complete result res = Result(conf) res.orig_code = orig_body - res.code = AsmHelper.reduce_source(cur_body).copy() + res.code = SourceLine.reduce_source(cur_body).copy() res.codesize_with_bubbles = res.codesize res.success = True res.reordering_with_bubbles = perm res.input_renamings = { s:s for s in inputs } res.output_renamings = { s:s for s in outputs } res.valid = True - res.selfcheck(log.getChild("full_selfcheck")) - cur_body = res.code - - maxlen = max(len(s.rstrip()) for s in cur_body) - for i in stalls: - if i > len(cur_body): - log.error("Something is wrong: Index %d, body length %d", i, len(cur_body)) - Heuristics._dump("Body:", cur_body, log, err=True) - cur_body[i] = f"{cur_body[i].rstrip():{maxlen+8}s} // gap(s) to follow" - - # Visualize model violations - if conf.split_heuristic_visualize_stalls: - cur_body = AsmHelper.reduce_source(cur_body) - c = conf.copy() - c.constraints.allow_reordering = False - c.constraints.allow_renaming = False - c.visualize_reordering = False - cur_body = Heuristics.optimize_binsearch( cur_body, log.getChild("visualize_stalls"), c).code - cur_body = ["// Start split region"] + cur_body + ["// End split region"] - - # Visualize functional units - if conf.split_heuristic_visualize_units: - dfg = DFG(cur_body, logger.getChild("visualize_functional_units"), DFGConfig(c)) - new_body = [] - for (l,t) in enumerate(dfg.nodes): - unit = conf.target.get_units(t.inst)[0] - indentation = conf.target.ExecutionUnit.get_indentation(unit) - new_body[i] = f"{'':{indentation}s}" + l - cur_body = new_body - - return cur_body + res.selfcheck(log.getChild("split_heuristic_full")) + return res @staticmethod - def _split(body, logger, conf, visualize_stalls=True): + def _split(body, logger, conf): c = conf.copy() # Focus on the chosen subregion - body = AsmHelper.reduce_source(body) + body = SourceLine.reduce_source(body) if c.split_heuristic_region == [0.0, 1.0]: - return Heuristics._split_inner(body, logger, c, visualize_stalls) + return Heuristics._split_inner(body, logger, c) + + inputs = DFG(body, logger.getChild("dfg_generate_inputs"), DFGConfig(c)).inputs start_end_idxs = Heuristics._idxs_from_fractions(c.split_heuristic_region, body) start_idx = start_end_idxs[0] end_idx = start_end_idxs[1] pre = body[:start_idx] - cur = body[start_idx:end_idx] + partial_body = body[start_idx:end_idx] post = body[end_idx:] # Adjust the outputs c.outputs = DFG(post, logger.getChild("dfg_generate_outputs"), DFGConfig(c)).inputs c.inputs_are_outputs = False - cur = Heuristics._split_inner(cur, logger, c, visualize_stalls) - body = pre + cur + post - return body + res = Heuristics._split_inner(partial_body, logger, c) + new_partial_body = res.code + + pre_pad = len(pre) + post_pad = len(post) + perm = Permutation.permutation_pad(res.reordering, pre_pad, post_pad) + + new_body = SourceLine.reduce_source(pre + new_partial_body + post) + + res2 = Result(conf) + res2.orig_code = body.copy() + res2.code = new_body + res2.codesize_with_bubbles = pre_pad + post_pad + res.codesize_with_bubbles + res2.success = True + res2.reordering_with_bubbles = perm + res2.input_renamings = { s:s for s in inputs } + res2.output_renamings = { s:s for s in conf.outputs } + res2.valid = True + res2.selfcheck(logger.getChild("split")) + + return res2 @staticmethod def _dump(name, s, logger, err=False, no_comments=False): + assert SourceLine.is_source(s) + s = [ str(l) for l in s] + def 
strip_comments(sl): return [ s.split("//")[0].strip() for s in sl ] fun = logger.debug if not err else logger.error - fun(f"Dump: {name}") - if isinstance(s, str): - s = s.splitlines() + fun(f"Dump: {name} (size {len(s)})") if no_comments: s = strip_comments(s) for l in s: @@ -759,6 +834,8 @@ def _periodic_halving(body, logger, conf): assert conf.sw_pipelining.enabled assert conf.sw_pipelining.halving_heuristic + body = SourceLine.reduce_source(body) + # Find kernel dependencies kernel_deps = DFG(body, logger.getChild("dfg_kernel_deps"), DFGConfig(conf.copy())).inputs @@ -770,11 +847,80 @@ def _periodic_halving(body, logger, conf): c.outputs = c.outputs.union(kernel_deps) if not conf.sw_pipelining.halving_heuristic_split_only: - kernel = Heuristics.linear(body,logger.getChild("slothy"),conf=c, - visualize_stalls=False) + res_halving_0 = Heuristics.linear(body,logger.getChild("slothy"),conf=c) + + # Split resulting kernel as [A;B] and synthesize result structure + # as if SW pipelining has been used and the result would have been + # [B;A], with preamble A and postamble B. + # + # Run the normal SW-pipelining selfcheck on this result. + # + # The overall goal here is to produce a result structure that's structurally + # the same as for normal SW pipelining, including checks and visualization. + # + # TODO: The 2nd optimization step below does not yet produce a Result structure. + reordering = res_halving_0.reordering + codesize = res_halving_0.codesize + def rotate_pos(p): + return p - (codesize // 2) + def is_pre(i): + return rotate_pos(reordering[i]) < 0 + + kernel = SourceLine.reduce_source(res_halving_0.code) + preamble = kernel[:codesize//2] + postamble = kernel[codesize//2:] + + # Swap halves around and consider new kernel [B;A] + kernel = postamble + preamble + + dfgc = DFGConfig(c.copy()) + dfgc.inputs_are_outputs = False + core_out = DFG(postamble, logger.getChild("dfg_kernel_deps"),dfgc).inputs + + dfgc = DFGConfig(conf.copy()) + dfgc.inputs_are_outputs = True + dfgc.outputs = core_out + new_kernel_deps = DFG(kernel, logger.getChild("dfg_kernel_deps"),dfgc).inputs + + c2 = c.copy() + c2.sw_pipelining.enabled = True + + reordering1 = { i : rotate_pos(reordering[i]) + for i in range(codesize) } + pre_core_post_dict1 = { i : (is_pre(i), not is_pre(i), False) + for i in range(codesize) } + + res = Result(c2) + res.orig_code = body + res.code = kernel + res.preamble = preamble + res.postamble = postamble + res.kernel_input_output = new_kernel_deps + res.codesize_with_bubbles = res_halving_0.codesize_with_bubbles + res.reordering_with_bubbles = reordering1 + res.pre_core_post_dict = pre_core_post_dict1 + res.input_renamings = { s:s for s in kernel_deps } + res.output_renamings = { s:s for s in c.outputs } + res.success = True + res.valid = True + + # Check result as if it has been produced by SW pipelining run + res.selfcheck(logger.getChild("halving_heuristic_1")) + else: logger.info("Halving heuristic: Split-only -- no optimization") - kernel = body + codesize = len(body) + preamble = body[:codesize//2] + postamble = body[codesize//2:] + kernel = postamble + preamble + + dfgc = DFGConfig(c.copy()) + dfgc.inputs_are_outputs = False + kernel_deps = DFG(postamble, logger.getChild("dfg_kernel_deps"),dfgc).inputs + + dfgc = DFGConfig(conf.copy()) + dfgc.inputs_are_outputs = True + kernel_deps = DFG(kernel, logger.getChild("dfg_kernel_deps"),dfgc).inputs # # Second step: @@ -792,27 +938,9 @@ def _periodic_halving(body, logger, conf): # iteration followed by the early half of the successive 
iteration. The hope is that this # enables good interleaving even without calling SLOTHY in SW pipelining mode. - kernel = AsmHelper.reduce_source(kernel) - kernel_len = len(kernel) - kernel_lenh = kernel_len // 2 - kernel_low = kernel[:kernel_lenh] - kernel_high = kernel[kernel_lenh:] - kernel = kernel_high.copy() + kernel_low.copy() - - preamble, postamble = kernel_low, kernel_high - - dfgc = DFGConfig(conf.copy()) - dfgc.outputs = kernel_deps - dfgc.inputs_are_outputs = False - kernel_deps = DFG(kernel_high, logger.getChild("dfg_kernel_deps"),dfgc).inputs - - dfgc = DFGConfig(conf.copy()) - dfgc.inputs_are_outputs = True - kernel_deps = DFG(kernel, logger.getChild("dfg_kernel_deps"),dfgc).inputs - logger.info("Apply halving heuristic to optimize two halves of consecutive loop kernels...") - # The 'periodic' version considers the 'seam' between loop iterations; otherwise, we consider + # The 'periodic' version considers the 'seam' between iterations; otherwise, we consider # [B;A] as a non-periodic snippet, which may still lead to stalls at the loop boundary. if conf.sw_pipelining.halving_heuristic_periodic: @@ -827,10 +955,50 @@ def _periodic_halving(body, logger, conf): getChild("periodic heuristic"), conf=c).code elif not conf.sw_pipelining.halving_heuristic_split_only: c = conf.copy() - c.outputs = kernel_deps - c.sw_pipelining.enabled=False - - kernel = Heuristics.linear( kernel, logger.getChild("heuristic"), conf=c) + c.outputs = new_kernel_deps + c.inputs_are_outputs = True + c.sw_pipelining.enabled = False + + res_halving_1 = Heuristics.linear(kernel, logger.getChild("heuristic"), conf=c) + final_kernel = res_halving_1.code + + reordering2 = res_halving_1.reordering_with_bubbles + + c2 = conf.copy() + + def get_reordering2(i): + is_pre = res.pre_core_post_dict[i][0] + p = reordering2[res.periodic_reordering[i]] + if is_pre: + p -= res_halving_1.codesize_with_bubbles + return p + reordering2 = { i : get_reordering2(i) for i in range(codesize) } + + res2 = Result(c2) + res2.orig_code = body + res2.code = final_kernel + res2.kernel_input_output = new_kernel_deps + res2.codesize_with_bubbles = res_halving_1.codesize_with_bubbles + res2.reordering_with_bubbles = reordering2 + res2.pre_core_post_dict = pre_core_post_dict1 + res2.input_renamings = res.input_renamings + res2.output_renamings = res.output_renamings + + new_preamble = [ final_kernel[i] for i in range(res2.codesize) + if res2.is_pre(i, original_program_order=False) is True ] + new_postamble = [ final_kernel[i] for i in range(res2.codesize) + if res2.is_pre(i, original_program_order=False) is False ] + + res2.preamble = new_preamble + res2.postamble = new_postamble + res2.success = True + res2.valid = True + + # TODO: This does not yet work since there can be renaming at the boundary between + # preamble and postamble that we don't account for in the selfcheck. 
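+ # In other words (sketch of the layout, matching the description in periodic()): with + # the first-pass kernel split as [A;B], the emitted loop is A; loop { opt(B;A) }; B, + # so preamble and postamble jointly account for one loop iteration, whence + # num_exceptional_iterations = 1 below. 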
+ # res2.selfcheck(logger.getChild("halving_heuristic_2")) + + kernel = res2.code num_exceptional_iterations = 1 return preamble, kernel, postamble, num_exceptional_iterations diff --git a/slothy/slothy.py b/slothy/core/slothy.py similarity index 53% rename from slothy/slothy.py rename to slothy/core/slothy.py index 44d1102c..e23e74be 100644 --- a/slothy/slothy.py +++ b/slothy/core/slothy.py @@ -1,4 +1,3 @@ - # # Copyright (c) 2022 Arm Limited # Copyright (c) 2022 Hanno Becker @@ -26,16 +25,51 @@ # Author: Hanno Becker # +"""SLOTHY optimizer + +SLOTHY - Super Lazy Optimization of Tricky Handwritten assemblY - is a +fixed-instruction assembly superoptimizer based on constraint solving. +It takes handwritten assembly as input and simultaneously super-optimizes: + +- Instruction scheduling +- Register allocation +- Software pipelining + +SLOTHY enables a development workflow where developers write 'clean' assembly by hand, +emphasizing the logic of the computation, while SLOTHY automates microarchitecture-specific +micro-optimizations. Since SLOTHY does not change instructions, and scheduling/allocation +optimizations are tightly controlled through configurable and extensible constraints, the +developer keeps close control over the final assembly, while being freed from the most tedious +and readability- and verifiability-impeding micro-optimizations. + +This module provides the Slothy class, which is a stateful interface to both +one-shot and heuristic optimizations using SLOTHY.""" + import logging from types import SimpleNamespace -from slothy.dataflow import DataFlowGraph as DFG -from slothy.dataflow import Config as DFGConfig -from slothy.core import Config -from slothy.heuristics import Heuristics -from slothy.helper import AsmAllocation, AsmMacro, AsmHelper, CPreprocessor - -class Slothy(): +from slothy.core.dataflow import DataFlowGraph as DFG +from slothy.core.dataflow import Config as DFGConfig, ComputationNode +from slothy.core.core import Config +from slothy.core.heuristics import Heuristics +from slothy.helper import AsmAllocation, AsmMacro, AsmHelper, CPreprocessor, SourceLine + +class Slothy: + """SLOTHY optimizer + + This class provides a stateful interface to both one-shot and heuristic + optimizations using SLOTHY. + + The basic flow of operation is the following: + - Initialize an instance, providing models of the target architecture + and microarchitecture as arguments. + - Load source code from file or raw string. + - Repeat: Adjust configuration and conduct an optimization of a loop body or + straightline block of code, using optimize() or optimize_loop(). + - Write source code to file or raw string. + + The use of heuristics is controlled through the configuration. 
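+ + Example (sketch; file names, the label, and the configuration are placeholders): + + slothy = Slothy(arch, target) + slothy.load_source_from_file("ntt_kernel.s") + slothy.config.sw_pipelining.enabled = True + slothy.optimize_loop("layer123_loop") + slothy.write_source_to_file("ntt_kernel_opt.s") 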
+ """ # Quick convenience access to architecture and target from the config def _get_arch(self): @@ -47,41 +81,64 @@ def _get_target(self): def __init__(self, arch, target, logger=None): self.config = Config(arch, target) - self.logger = logger if logger != None else logging.getLogger("slothy") - self.source = None + self.logger = logger if logger is not None else logging.getLogger("slothy") + + # The source, once loaded, is represented as a list of strings + self._source = None self.results = None self.last_result = None self.success = None + @property + def source(self): + """Returns the current source code as an array of SourceLine objects + + If you want the current source code as a multiline string, use get_source_as_string().""" + return self._source + + @source.setter + def source(self, val): + assert SourceLine.is_source(val) + self._source = val + + def get_source_as_string(self, comments=True, indentation=True, tags=True): + """Retrieve current source code as multi-line string""" + return SourceLine.write_multiline(self.source, comments=comments, + indentation=indentation, tags=tags) + + def set_source_as_string(self, s): + """Provide input source code as multi-line string""" + assert isinstance(s, str) + reduce = not self.config.ignore_tags + self.source = SourceLine.read_multiline(s, reduce=reduce) + def load_source_raw(self, source): - self.source = source.replace("\\\n", "") + """Load source code from multi-line string""" + self.set_source_as_string(source) self.results = [] def load_source_from_file(self, filename): + """Load source code from file""" if self.source is not None: self.logger.warning("Overwriting previous source code") - f = open(filename,"r") - self.load_source_raw(f.read()) - f.close() + with open(filename,"r", encoding="utf8") as f: + self.load_source_raw(f.read()) def write_source_to_file(self, filename): - f = open(filename,"w") - f.write(self.source) - f.close() - - def print_code(self): - print(self.source) + """Write current source code to file""" + with open(filename,"w", encoding="utf8") as f: + f.write(self.get_source_as_string()) def rename_function(self, old_funcname, new_funcname): + """Rename a function in the current source code""" self.source = AsmHelper.rename_function(self.source, old_funcname, new_funcname) @staticmethod def _dump(name, s, logger, err=False): + assert isinstance(s, list) fun = logger.debug if not err else logger.error fun(f"Dump: {name}") - if isinstance(s, str): - s = s.splitlines() for l in s: fun(f"> {l}") @@ -106,6 +163,7 @@ def optimize(self, start=None, end=None, loop_synthesis_cb=None, logname=None): loop_synthesis_cb: Optional (None by default) callback synthesis final source code from tuple of (preamble, kernel, postamble, # exceptional iterations). 
""" + # pylint:disable=too-many-locals if logname is None and start is not None: logname = start @@ -128,42 +186,43 @@ def optimize(self, start=None, end=None, loop_synthesis_cb=None, logname=None): self.logger.debug("Code after preprocessor:") Slothy._dump("preprocessed", body, self.logger, err=False) - body = AsmHelper.split_semicolons(body) + body = SourceLine.split_semicolons(body) body = AsmMacro.unfold_all_macros(pre, body) body = AsmAllocation.unfold_all_aliases(c.register_aliases, body) - body = AsmHelper.apply_indentation(body, indentation) + body = SourceLine.apply_indentation(body, indentation) self.logger.info("Instructions in body: %d", len(list(filter(None, body)))) early, core, late, num_exceptional = Heuristics.periodic(body, logger, c) def indented(code): - indent = ' ' * self.config.indentation - return [ indent + s for s in code ] + return [ SourceLine(l).set_indentation(indentation) for l in code] if start is not None: - core = [f"{start}:"] + core + core = [SourceLine(f"{start}:")] + core if end is not None: - core += [f"{end}:"] + core += [SourceLine(f"{end}:")] if not self.config.sw_pipelining.enabled: assert early == [] assert late == [] assert num_exceptional == 0 - optimized_source = indented(core) - elif loop_synthesis_cb != None: - optimized_source = indented(loop_synthesis_cb( pre, core, post, num_exceptional)) + optimized_source = core + elif loop_synthesis_cb is not None: + optimized_source = loop_synthesis_cb( pre, core, post, num_exceptional) else: optimized_source = [] optimized_source += indented([f"// Exceptional iterations: {num_exceptional}", "// Preamble"]) - optimized_source += indented(early) + optimized_source += early optimized_source += indented(["// Kernel"]) - optimized_source += indented(core) + optimized_source += core optimized_source += indented(["// Postamble"]) - optimized_source += indented(late) + optimized_source += late - self.source = '\n'.join(pre + optimized_source + post) + self.source = pre + optimized_source + post + assert SourceLine.is_source(self.source) def get_loop_input_output(self, loop_lbl): + """Find all registers that a loop body depends on""" logger = self.logger.getChild(loop_lbl) _, body, _, _, _ = self.arch.Loop.extract(self.source, loop_lbl) @@ -173,6 +232,7 @@ def get_loop_input_output(self, loop_lbl): return list(DFG(body, logger.getChild("dfg_kernel_deps"), dfgc).inputs) def get_input_from_output(self, start, end, outputs=None): + """For the piece of straightline code, infer which input registers affect its output""" if outputs is None: outputs = {} logger = self.logger.getChild(f"{start}_{end}_infer_input") @@ -188,11 +248,7 @@ def get_input_from_output(self, start, end, outputs=None): dfgc = DFGConfig(c) return list(DFG(body, logger.getChild("dfg_find_deps"), dfgc).inputs) - def ssa_region(self, start, end, outputs=None): - if outputs is None: - outputs = {} - logger = self.logger.getChild(f"{start}_{end}_infer_input") - pre, body, post = AsmHelper.extract(self.source, start, end) + def _fusion_core(self, pre, body, logger): c = self.config.copy() if c.with_preprocessor: @@ -200,20 +256,65 @@ def ssa_region(self, start, end, outputs=None): body = CPreprocessor.unfold(pre, body, c.compiler_binary) self.logger.debug("Code after preprocessor:") Slothy._dump("preprocessed", body, self.logger, err=False) - body = AsmHelper.split_semicolons(body) + body = SourceLine.split_semicolons(body) aliases = AsmAllocation.parse_allocs(pre) c.add_aliases(aliases) - c.outputs = outputs body = AsmMacro.unfold_all_macros(pre, 
@@ -200,20 +256,65 @@ def ssa_region(self, start, end, outputs=None):
             body = CPreprocessor.unfold(pre, body, c.compiler_binary)
             self.logger.debug("Code after preprocessor:")
             Slothy._dump("preprocessed", body, self.logger, err=False)
 
-        body = AsmHelper.split_semicolons(body)
+        body = SourceLine.split_semicolons(body)
         aliases = AsmAllocation.parse_allocs(pre)
         c.add_aliases(aliases)
-        c.outputs = outputs
 
         body = AsmMacro.unfold_all_macros(pre, body)
         body = AsmAllocation.unfold_all_aliases(c.register_aliases, body)
         dfgc = DFGConfig(c)
-        dfg = DFG(body, logger.getChild("dfg_find_deps"), dfgc)
+
+        dfg = DFG(body, logger.getChild("ssa"), dfgc, parsing_cb=False)
         dfg.ssa()
+        body = [ ComputationNode.to_source_line(t) for t in dfg.nodes ]
 
-        body_ssa = [ f"{start}:" ] + [ str(t.inst) for t in dfg.nodes ] + [ f"{end}:" ]
-        self.source = '\n'.join(pre + body_ssa + post)
+        dfg = DFG(body, logger.getChild("fusion"), dfgc, parsing_cb=False)
+        dfg.apply_fusion_cbs()
+        body = [ ComputationNode.to_source_line(t) for t in dfg.nodes ]
+
+        return body
+
+    def fusion_region(self, start, end):
+        """Run fusion callbacks on straightline code"""
+        logger = self.logger.getChild(f"ssa_{start}_{end}")
+        pre, body, post = AsmHelper.extract(self.source, start, end)
+
+        body_ssa = [ SourceLine(f"{start}:") ] + \
+            self._fusion_core(pre, body, logger) + \
+            [ SourceLine(f"{end}:") ]
+        self.source = pre + body_ssa + post
+        assert SourceLine.is_source(self.source)
+
+    def fusion_loop(self, loop_lbl):
+        """Run fusion callbacks on loop body"""
+        logger = self.logger.getChild(f"ssa_loop_{loop_lbl}")
+
+        pre, body, post, _, other_data = \
+            self.arch.Loop.extract(self.source, loop_lbl)
+
+        indentation = AsmHelper.find_indentation(body)
+
+        loop = self.arch.Loop(lbl_start=loop_lbl)
+        body_ssa = SourceLine.read_multiline(loop.start()) + \
+            SourceLine.apply_indentation(self._fusion_core(pre, body, logger), indentation) + \
+            SourceLine.read_multiline(loop.end(other_data))
+
+        self.source = pre + body_ssa + post
+        assert SourceLine.is_source(self.source)
+
+        c = self.config.copy()
+        self.config.keep_tags = True
+        self.config.constraints.functional_only = True
+        self.config.constraints.allow_reordering = False
+        self.config.sw_pipelining.enabled = False
+        self.config.split_heuristic = False
+        self.config.inputs_are_outputs = True
+        self.config.sw_pipelining.unknown_iteration_count = False
+        self.optimize_loop(loop_lbl)
+        self.config = c
+
+        assert SourceLine.is_source(self.source)
 
     def optimize_loop(self, loop_lbl, postamble_label=None):
         """Optimize the loop starting at a given label"""
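`fusion_region()` and `fusion_loop()` wrap `_fusion_core()`, which rewrites the selected code into SSA form and then applies the target's fusion callbacks; `fusion_loop()` additionally re-runs a purely functional (no-reordering) optimization pass over the rewritten loop. A sketch with hypothetical labels:

```
slothy.fusion_region("kernel_start", "kernel_end")  # straightline region between two labels
slothy.fusion_loop("loop_start")                    # loop body identified by its start label
```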
@@ -236,22 +337,33 @@ def optimize_loop(self, loop_lbl, postamble_label=None):
         self.logger.debug("Code after preprocessor:")
         Slothy._dump("preprocessed", body, self.logger, err=False)
 
-        body = AsmHelper.split_semicolons(body)
+        body = SourceLine.split_semicolons(body)
         body = AsmMacro.unfold_all_macros(early, body)
         body = AsmAllocation.unfold_all_aliases(c.register_aliases, body)
-        body = AsmHelper.apply_indentation(body, indentation)
+        body = SourceLine.apply_indentation(body, indentation)
 
-        insts = len(list(filter(None, body)))
-        self.logger.info("Optimizing loop %s (%d instructions) ...", loop_lbl, insts)
+        self.logger.info("Optimizing loop %s (%d instructions) ...",
+                         loop_lbl, len(body))
 
         preamble_code, kernel_code, postamble_code, num_exceptional = \
             Heuristics.periodic(body, logger, c)
+
         def indented(code):
-            indent = ' ' * self.config.indentation
-            return [ indent + s for s in code ]
+            if not SourceLine.is_source(code):
+                code = SourceLine.read_multiline(code)
+            return SourceLine.apply_indentation(code, self.config.indentation)
+
+        loop_lbl_end = f"{loop_lbl}_end"
+        def loop_lbl_iter(i):
+            return SourceLine(f"{loop_lbl}_iter_{i}")
 
-        loop = self.arch.Loop(lbl_start=loop_lbl)
         optimized_code = []
+
+        if self.config.sw_pipelining.unknown_iteration_count:
+            for i in range(1, num_exceptional):
+                optimized_code += indented(self.arch.Branch.if_equal(i, loop_lbl_iter(i)))
+
+        loop = self.arch.Loop(lbl_start=loop_lbl)
         optimized_code += indented(preamble_code)
 
         if self.config.sw_pipelining.unknown_iteration_count:
@@ -261,22 +373,34 @@ def indented(code):
         else:
             jump_if_empty = None
 
-        optimized_code += list(loop.start(
+        optimized_code += SourceLine.read_multiline(loop.start(
             indentation=self.config.indentation,
             fixup=num_exceptional,
             unroll=self.config.sw_pipelining.unroll,
             jump_if_empty=jump_if_empty))
         optimized_code += indented(kernel_code)
-        optimized_code += list(loop.end(other_data, indentation=self.config.indentation))
+        optimized_code += SourceLine.read_multiline(loop.end(other_data,
+            indentation=self.config.indentation))
         if postamble_label is not None:
-            optimized_code += [ f"{postamble_label}: // end of loop kernel" ]
+            optimized_code += [ SourceLine(f"{postamble_label}:")
+                                .add_comment("end of loop kernel") ]
         optimized_code += indented(postamble_code)
 
+        if self.config.sw_pipelining.unknown_iteration_count:
+            optimized_code += indented(self.arch.Branch.unconditional(loop_lbl_end))
+            for i in range(1, num_exceptional):
+                optimized_code += [SourceLine(f"{loop_lbl_iter(i)}:")]
+                optimized_code += i * indented(body)
+                optimized_code += [SourceLine(f"{loop_lbl_iter(i)}_end:")]
+                if i != num_exceptional - 1:
+                    optimized_code += indented(self.arch.Branch.unconditional(loop_lbl_end))
+            optimized_code += [SourceLine(f"{loop_lbl_end}:")]
+
         self.last_result = SimpleNamespace()
         dfgc = DFGConfig(c)
         dfgc.inputs_are_outputs = True
         self.last_result.kernel_input_output = \
             list(DFG(kernel_code, logger.getChild("dfg_kernel_deps"), dfgc).inputs)
 
-        self.source = '\n'.join(early + optimized_code + late)
+        self.source = early + optimized_code + late
         self.success = True
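With `sw_pipelining.unknown_iteration_count` enabled, the generated code now begins with a dispatch: each iteration count `i` below the number of exceptional iterations branches to a fully unrolled copy of the original body at `{loop_lbl}_iter_{i}`, while all other counts fall through to the software-pipelined loop. A configuration sketch (hypothetical labels):

```
slothy.config.sw_pipelining.enabled = True
slothy.config.sw_pipelining.unknown_iteration_count = True
slothy.optimize_loop("my_loop", postamble_label="my_loop_postamble")
```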
diff --git a/slothy/helper.py b/slothy/helper.py
index cb43560b..88f148fc 100644
--- a/slothy/helper.py
+++ b/slothy/helper.py
@@ -29,6 +29,348 @@
 import subprocess
 import logging
 
+class SourceLine:
+    """Representation of a single line of source code"""
+
+    def _extract_comments_from_text(self):
+        if not "//" in self._raw:
+            return
+        s = list(map(str.strip, self._raw.split("//")))
+        self._raw = s[0]
+        self._comments += s[1:]
+        self._trim_comments()
+
+    def _extract_indentation_from_text(self):
+        old = self._raw
+        new = old.lstrip()
+        self._indentation = len(old) - len(new)
+        self._raw = new
+
+    @staticmethod
+    def _parse_tags_in_string(s, tags):
+        def parse_value(v):
+            if v.lower() == "true":
+                return True
+            if v.lower() == "false":
+                return False
+            if v.isnumeric():
+                return int(v)
+            return v
+
+        def tag_value_callback(g):
+            tag = g.group("tag")
+            value = parse_value(g.group("value"))
+            tags[tag] = value
+            return ""
+
+        def tag_callback(g):
+            tag = g.group("tag")
+            tags[tag] = True
+            return ""
+
+        tag_value_regexp_txt = r"@slothy:(?P<tag>(\w|-)+)=(?P<value>\w+)"
+        tag_regexp_txt = r"@slothy:(?P<tag>(\w|-)+)"
+        s = re.sub(tag_value_regexp_txt, tag_value_callback, s)
+        s = re.sub(tag_regexp_txt, tag_callback, s)
+        return s
+
+    def _strip_comments(self):
+        self._comments = list(map(str.strip, self._comments))
+
+    def _trim_comments(self):
+        self._strip_comments()
+        self._comments = list(filter(lambda s: s != "", self._comments))
+
+    def _extract_tags_from_comments(self):
+        tags = {}
+        self._comments = list(map(lambda c: SourceLine._parse_tags_in_string(c, tags),
+                                  self._comments))
+        self._trim_comments()
+        self.add_tags(tags)
+
+    def reduce(self):
+        """Extract metadata (tags, comments, indentation) from raw text
+
+        The extracted components are removed from the text."""
+        self._extract_indentation_from_text()
+        self._extract_comments_from_text()
+        self._extract_tags_from_comments()
+        return self
+
+    def add_comment(self, comment):
+        """Add a comment to the metadata of a source line"""
+        self._comments.append(comment)
+        return self
+
+    def add_comments(self, comments):
+        """Add one or more comments to the metadata of a source line"""
+        for c in comments:
+            self.add_comment(c)
+        return self
+
+    def set_comments(self, comments):
+        """Set comments for source line.
+
+        Overwrites existing comments."""
+        self._comments = comments
+        return self
+
+    def set_comment(self, comment):
+        """Set single comment for source line.
+
+        Overwrites existing comments."""
+        self.set_comments([comment])
+        return self
+
+    def __init__(self, s, reduce=True):
+        """Create source line from string"""
+        assert isinstance(s, str)
+
+        self._raw = s
+        self._tags = {}
+        self._indentation = 0
+        self._fixlength = None
+        self._comments = []
+
+        if reduce is True:
+            self.reduce()
+
+    def set_tag(self, tag, value=True):
+        """Set source line tag"""
+        self._tags[tag] = value
+        return self
+
+    def set_length(self, length):
+        """Set the padded length of the text component of the source line
+
+        When printing the source line with to_string(), the source text will be
+        whitespace padded to the specified length before adding comments and tags.
+        This allows printing multiple commented source lines with a uniform
+        indentation for the comments, improving readability."""
+        self._fixlength = length
+        return self
+
+    @property
+    def tags(self):
+        """Return the dictionary of tags for the source line
+
+        Tags are source annotations of the form @slothy:(tag[=value]?).
+        """
+        return self._tags
+    @tags.setter
+    def tags(self, v):
+        self._tags = v
+
+    @property
+    def comments(self):
+        """Return the list of comments for the source line"""
+        return self._comments
+    @comments.setter
+    def comments(self, v):
+        self._comments = v
+
+    def has_text(self):
+        """Indicates if the source line contains some text"""
+        return self._raw.strip() != ""
+
+    @property
+    def text(self):
+        """Returns the (non-metadata) text in the source line"""
+        return self._raw
+
+    def to_string(self, indentation=False, comments=False, tags=False):
+        """Convert source line to a string
+
+        This includes formatting the metadata in a way reversing the
+        parsing done in the _extract_xxx() routines."""
+        if self._fixlength is None:
+            core = self._raw
+        else:
+            core = f"{self._raw:{self._fixlength}s}"
+
+        indentation = ' ' * self._indentation \
+            if indentation is True else ""
+        comments = ''.join(map(lambda s: f"// {s}", self._comments)) \
+            if comments is True else ""
+        tags = ' '.join(map(lambda tv: f" // @slothy:{tv[0]}={tv[1]}", self._tags.items())) \
+            if tags is True else ""
+
+        return f"{indentation}{core}{comments}{tags}"
+
+    def __str__(self):
+        return self.to_string()
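On construction, a `SourceLine` splits metadata off the instruction text, and `to_string()` reverses the split. For example:

```
line = SourceLine("    ldr q0, [x0] // load // @slothy:interleave=true")
line.text      # 'ldr q0, [x0]'
line.comments  # ['load']
line.tags      # {'interleave': True}
```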
+    @staticmethod
+    def reduce_source(src):
+        """Extract metadata (e.g. indentation, tags, comments) from source lines"""
+        assert SourceLine.is_source(src)
+        for l in src:
+            l.reduce()
+        return [ l for l in src if l.has_text() ]
+
+    @staticmethod
+    def log(name, s, logger=None, err=False):
+        """Send source to logger"""
+        assert isinstance(s, list)
+        if err:
+            fun = logger.error
+        else:
+            fun = logger.debug
+        if len(s) == 0:
+            return
+        fun(f"Dump: {name}")
+        for l in s:
+            fun(f"> {l}")
+
+    def set_text(self, s):
+        """Set the text of the source line
+
+        This only affects the instruction text of the source line, but leaves
+        metadata (such as comments, indentation or tags) unmodified."""
+        self._raw = s
+        return self
+
+    def add_text(self, s):
+        """Add text to a source line
+
+        This only affects the instruction text of the source line, but leaves
+        metadata (such as comments, indentation or tags) unmodified."""
+        self._raw += " " + s
+        return self
+
+    @property
+    def is_escaped(self):
+        """Indicates if line text ends with a backslash"""
+        return self.text.endswith("\\")
+
+    def remove_escaping(self):
+        """Remove escape character at end of line, if present"""
+        if not self.is_escaped:
+            return self
+        self._raw = self._raw[:-1]
+        return self
+
+    def __len__(self):
+        return len(self._raw)
+
+    def copy(self):
+        """Create a copy of a source line"""
+        return SourceLine(self._raw) \
+            .add_tags(self._tags.copy()) \
+            .set_indentation(self._indentation) \
+            .add_comments(self._comments.copy()) \
+            .set_length(self._fixlength)
+
+    @staticmethod
+    def read_multiline(s, reduce=True):
+        """Parse multi-line string or array of strings into list of SourceLine instances"""
+        if isinstance(s, str):
+            s = s.splitlines()
+        return SourceLine.merge_escaped_lines([ SourceLine(l, reduce=reduce) for l in s ])
+
+    @staticmethod
+    def merge_escaped_lines(s):
+        """Merge lines ending in a backslash with subsequent line(s)"""
+        assert SourceLine.is_source(s)
+        res = []
+        cur = None
+        for l in s:
+            if cur is not None:
+                cur.add_text(l.text)
+            else:
+                cur = l.copy()
+            if cur.is_escaped:
+                cur.remove_escaping()
+            else:
+                res.append(cur)
+                cur = None
+        assert cur is None
+        return res
+
+    @staticmethod
+    def copy_source(s):
+        """Create a copy of a list of source lines"""
+        assert SourceLine.is_source(s)
+        return [ l.copy() for l in s ]
+
+    @staticmethod
+    def write_multiline(s, comments=True, indentation=True, tags=True):
+        """Write source as multiline string"""
+        return '\n'.join(map(lambda t: t.to_string(
+            comments=comments, tags=tags, indentation=indentation), s))
+
+    def set_indentation(self, indentation):
+        """Set the indentation (in number of spaces) to be used in to_string()"""
+        self._indentation = indentation
+        return self
+
+    def add_tags(self, tags):
+        """Add one or more tags to the metadata of the source line
+
+        tags must be a tag:value dictionary."""
+        self._tags = {**self._tags, **tags}
+        return self
+
+    def add_tag(self, tag, value):
+        """Add a single tag-value pair to the metadata of the source line
+
+        If the tag is already specified, it is overwritten."""
+        return self.add_tags({ tag: value })
+
+    def inherit_tags(self, l):
+        """Inherits the tags from another source line
+
+        In case of overlapping tags, source line l takes precedence."""
+        assert SourceLine.is_source_line(l)
+        self.add_tags(l.tags)
+        return self
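`read_multiline()` and `write_multiline()` convert between multi-line strings and lists of `SourceLine` objects; on the way in, backslash-continued lines are merged by `merge_escaped_lines()`. A small sketch:

```
src = SourceLine.read_multiline("ldr q0, \\\n    [x0]\nadd x0, x0, #16")
len(src)                        # 2: the escaped pair was merged into one line
out = SourceLine.write_multiline(src)
```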
+    @staticmethod
+    def apply_indentation(source, indentation):
+        """Apply consistent indentation to assembly source"""
+        assert SourceLine.is_source(source)
+        if indentation is None:
+            return source
+        assert isinstance(indentation, int)
+        return [ l.copy().set_indentation(indentation) for l in source ]
+
+    @staticmethod
+    def drop_tags(source):
+        """Drop all tags from a source"""
+        assert SourceLine.is_source(source)
+        for l in source:
+            l.tags = {}
+        return source
+
+    @staticmethod
+    def split_semicolons(s):
+        """Split the text of source lines at semicolons
+
+        The resulting source lines inherit their metadata from the lines they
+        were split from."""
+        assert SourceLine.is_source(s)
+        res = []
+        for line in s:
+            for l in str(line).split(';'):
+                t = line.copy()
+                t.set_text(l)
+                res.append(t)
+        return res
+
+    @staticmethod
+    def is_source(s):
+        """Check if parameter is a list of SourceLine instances"""
+        if isinstance(s, list) is False:
+            return False
+        for t in s:
+            if isinstance(t, SourceLine) is False:
+                return False
+        return True
+
+    @staticmethod
+    def is_source_line(s):
+        """Checks if the parameter is a SourceLine instance"""
+        return isinstance(s, SourceLine)
+
 class NestedPrint():
     """Helper for recursive printing of structures"""
     def __str__(self):
@@ -44,7 +386,7 @@ def log(self, fun):
         for l in str(self).splitlines():
             fun(l)
 
-class LockAttributes(object):
+class LockAttributes:
     """Base class adding support for 'locking' the set of attributes, that is,
     preventing the creation of any further attributes. Note that the modification
     of already existing attributes remains possible.
@@ -62,7 +404,7 @@ def __setattr__(self, attr, val):
             varlist = [v for v in dir(self) if not v.startswith("_") ]
             varlist = '\n'.join(map(lambda x: '* ' + x, varlist))
             raise TypeError(f"Unknown attribute {attr}. \nValid attributes are:\n{varlist}")
-        elif self._locked and attr == "_locked":
+        if self._locked and attr == "_locked":
             raise TypeError("Can't unlock an object")
         object.__setattr__(self,attr,val)
@@ -79,6 +421,7 @@ def find_indentation(source):
         def get_indentation(l):
             return len(l) - len(l.lstrip())
 
+        source = map(str, source)
         # Remove empty lines
         source = list(filter(lambda t: t.strip() != "", source))
         l = len(source)
@@ -99,88 +442,24 @@ def get_indentation(l):
 
         return None
 
-    @staticmethod
-    def apply_indentation(source, indentation):
-        """Apply consistent indentation to assembly source"""
-        if indentation is None:
-            return source
-        assert isinstance(indentation, int)
-        indent = ' ' * indentation
-        return [ indent + l.lstrip() for l in source ]
-
     @staticmethod
     def rename_function(source, old_funcname, new_funcname):
         """Rename function in assembly snippet"""
         # For now, just replace function names line by line
-        def change_funcname(s):
+        def change_funcname(line):
+            s = str(line)
             s = re.sub( f"{old_funcname}:", f"{new_funcname}:", s)
             s = re.sub( f"\\.global(\\s+){old_funcname}", f".global\\1{new_funcname}", s)
             s = re.sub( f"\\.type(\\s+){old_funcname}", f".type\\1{new_funcname}", s)
-            return s
-        return '\n'.join([ change_funcname(s) for s in source.splitlines() ])
-
-    @staticmethod
-    def split_semicolons(body):
-        """Split assembly snippet across semicolons`"""
-        return [ l for s in body for l in s.split(';') ]
-
-    @staticmethod
-    def reduce_source_line(line):
-        """Simplify or ignore assembly source line"""
-        regexp_align_txt = r"^\s*\.(?:p2)?align"
-        regexp_req_txt = r"\s*(?P<alias>\w+)\s+\.req\s+(?P<reg>\w+)"
-        regexp_unreq_txt = r"\s*\.unreq\s+(?P<alias>\w+)"
-        regexp_label_txt = r"\s*(?P