class AArch64IfElse(Example):
    """Example wrapper for the `aarch64_ifelse` assembly input.

    The input file is `aarch64_ifelse`, optionally suffixed with `_<var>`.
    The example name is the same string with the label of the selected
    target (looked up in `target_label_dict`) appended.
    """

    def __init__(self, var="", arch=AArch64_Neon, target=Target_CortexA55):
        base = "aarch64_ifelse"
        # An empty variant contributes no suffix at all.
        variant_suffix = f"_{var}" if var != "" else ""

        infile = base + variant_suffix
        name = f"{base}{variant_suffix}_{target_label_dict[target]}"

        super().__init__(infile, name, rename=True, arch=arch, target=target)

    def core(self, slothy):
        """Optimize the whole input; no start/end labels are passed."""
        slothy.optimize()
-0,0 +1,54 @@ + // Instructions: 20 + // Expected cycles: 28 + // Expected IPC: 0.71 + // + // Wall time: 0.25s + // User time: 0.25s + // + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + ldr q2, [x0, #48] // *............................. + ldr q4, [x1, #0] // ..*........................... + ldr q13, [x0, #16] // ....*......................... + mul v27.8H, v2.8H, v4.H[0] // ......*....................... + sqrdmulh v6.8H, v2.8H, v4.H[1] // .......*...................... + ldr q3, [x2, #0] // ........*..................... + mul v24.8H, v13.8H, v4.H[0] // ..........*................... + ldr q28, [x0, #32] // ...........*.................. + mls v27.8H, v6.8H, v3.H[0] // .............*................ + sqrdmulh v14.8H, v13.8H, v4.H[1] // ..............*............... + ldr q2, [x0] // ...............*.............. + sub v18.8H, v28.8H, v27.8H // .................*............ + mls v24.8H, v14.8H, v3.H[0] // ..................*........... + add v9.8H, v28.8H, v27.8H // ....................*......... + str q18, [x0, #48] // .....................*........ + add v12.8H, v2.8H, v24.8H // ......................*....... + str q9, [x0, #32] // .......................*...... + sub v3.8H, v2.8H, v24.8H // ........................*..... + str q12, [x0], #4*16 // .........................*.... + str q3, [x0, #-48] // ...........................*.. + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // ldr q0, [x1, #0] // ..*............................ + // ldr q1, [x2, #0] // ........*...................... + // ldr q8, [x0] // ...............*............... + // ldr q9, [x0, #1*16] // ....*.......................... + // ldr q10, [x0, #2*16] // ...........*................... + // ldr q11, [x0, #3*16] // *.............................. + // mul v24.8h, v9.8h, v0.h[0] // ..........*.................... + // sqrdmulh v9.8h, v9.8h, v0.h[1] // ..............*................ 
+ // mls v24.8h, v9.8h, v1.h[0] // ..................*............ + // sub v9.8h, v8.8h, v24.8h // ........................*...... + // add v8.8h, v8.8h, v24.8h // ......................*........ + // mul v24.8h, v11.8h, v0.h[0] // ......*........................ + // sqrdmulh v11.8h, v11.8h, v0.h[1] // .......*....................... + // mls v24.8h, v11.8h, v1.h[0] // .............*................. + // sub v11.8h, v10.8h, v24.8h // .................*............. + // add v10.8h, v10.8h, v24.8h // ....................*.......... + // str q8, [x0], #4*16 // .........................*..... + // str q9, [x0, #-3*16] // ...........................*... + // str q10, [x0, #-2*16] // .......................*....... + // str q11, [x0, #-1*16] // .....................*......... diff --git a/slothy/core/slothy.py b/slothy/core/slothy.py index b0b776ee..76cbc1ae 100644 --- a/slothy/core/slothy.py +++ b/slothy/core/slothy.py @@ -53,7 +53,7 @@ from slothy.core.core import Config from slothy.core.heuristics import Heuristics from slothy.helper import CPreprocessor, SourceLine -from slothy.helper import AsmAllocation, AsmMacro, AsmHelper +from slothy.helper import AsmAllocation, AsmMacro, AsmHelper, AsmIfElse from slothy.helper import CPreprocessor, LLVM_Mca, LLVM_Mca_Error class Slothy: @@ -251,6 +251,7 @@ def optimize(self, start=None, end=None, loop_synthesis_cb=None, logname=None): body = SourceLine.split_semicolons(body) body = AsmMacro.unfold_all_macros(pre, body, inherit_comments=c.inherit_macro_comments) body = AsmAllocation.unfold_all_aliases(c.register_aliases, body) + body = AsmIfElse.process_instructions(body) body = SourceLine.apply_indentation(body, indentation) self.logger.info("Instructions in body: %d", len(list(filter(None, body)))) diff --git a/slothy/helper.py b/slothy/helper.py index 14bac9b8..86f118d7 100644 --- a/slothy/helper.py +++ b/slothy/helper.py @@ -29,9 +29,10 @@ import subprocess import logging from abc import ABC, abstractmethod - +from sympy 
class AsmIfElse:
    """Helper for resolving `.if` / `.else` / `.endif` assembler conditionals.

    Conditions are evaluated statically and only the lines of the taken
    branches are kept. Only the plain `.if <expr>` form is recognized:
    `.elseif`, `.ifdef` and the other `.if*` variants are not matched by any
    of the patterns below and are therefore passed through like ordinary lines.
    """

    # The condition is captured in the named group `cond`, which `check_if` reads.
    _REGEXP_IF_TXT = r"\s*\.if\s+(?P<cond>.*)"
    # `\b` keeps `.else` from also matching `.elseif`, which must not be
    # treated as a plain branch inversion.
    _REGEXP_ELSE_TXT = r"\s*\.else\b"
    _REGEXP_ENDIF_TXT = r"\s*\.endif\b"

    _REGEXP_IF = re.compile(_REGEXP_IF_TXT)
    _REGEXP_ELSE = re.compile(_REGEXP_ELSE_TXT)
    _REGEXP_ENDIF = re.compile(_REGEXP_ENDIF_TXT)

    @staticmethod
    def check_if(line):
        """Check if an assembly line is an `.if` directive. Return the
        condition text (everything after `.if`), if so. Otherwise, return None."""
        assert SourceLine.is_source_line(line)

        p = AsmIfElse._REGEXP_IF.match(line.text)
        if p is not None:
            return p.group("cond")
        return None

    @staticmethod
    def is_if(line):
        """Return True if the line is an `.if` directive."""
        return AsmIfElse.check_if(line) is not None

    @staticmethod
    def check_else(line):
        """Check if an assembly line is an `.else` directive. Return True,
        if so. Otherwise, return None."""
        assert SourceLine.is_source_line(line)

        p = AsmIfElse._REGEXP_ELSE.match(line.text)
        if p is not None:
            return True
        return None

    @staticmethod
    def is_else(line):
        """Return True if the line is an `.else` directive."""
        return AsmIfElse.check_else(line) is not None

    @staticmethod
    def check_endif(line):
        """Check if an assembly line is an `.endif` directive. Return True,
        if so. Otherwise, return None."""
        assert SourceLine.is_source_line(line)

        p = AsmIfElse._REGEXP_ENDIF.match(line.text)
        if p is not None:
            return True
        return None

    @staticmethod
    def is_endif(line):
        """Return True if the line is an `.endif` directive."""
        return AsmIfElse.check_endif(line) is not None

    @staticmethod
    def evaluate_condition(condition):
        """Evaluate the condition string of an `.if` and return True or False.

        The string is handed to sympy's `simplify`. If the expression cannot
        be parsed, or its truth value cannot be decided, an error message is
        printed and the condition is treated as False.
        """
        # NOTE(review): sympy parses strings via eval(); conditions must come
        # from trusted assembly sources only.
        try:
            # bool() is deliberately inside the try: for an expression sympy
            # cannot decide (e.g. a relational with a free symbol) taking the
            # truth value raises, and that must be reported here rather than
            # surface later in the caller.
            return bool(simplify(condition))
        except Exception as e:
            print(f"Error evaluating condition '{condition}': {e}")
            return False

    @staticmethod
    def process_instructions(instructions):
        """Resolve conditional directives in a list of source lines.

        Returns a new list containing only the lines that sit in taken
        branches; the `.if`, `.else` and `.endif` lines themselves are
        always removed. An `.else` or `.endif` without an open `.if` is
        dropped without further effect, and an `.if` that is never closed
        stays in effect until the end of the input.
        """
        output_lines = []
        # One entry per open `.if`: True while its current branch is not taken.
        skip_stack = []

        for instruction in instructions:
            condition = AsmIfElse.check_if(instruction)
            if condition is not None:
                skip_stack.append(not AsmIfElse.evaluate_condition(condition))
                continue

            if AsmIfElse.is_else(instruction):
                if skip_stack:
                    # Switch to the other branch of the innermost `.if`
                    skip_stack[-1] = not skip_stack[-1]
                continue

            if AsmIfElse.is_endif(instruction):
                if skip_stack:
                    skip_stack.pop()
                continue

            # A line is dropped if ANY enclosing conditional is in a
            # non-taken branch, not only the innermost one.
            if any(skip_stack):
                continue

            output_lines.append(instruction)

        return output_lines