Skip to content

Commit

Permalink
Add BranchLoop type
Browse files Browse the repository at this point in the history
  • Loading branch information
mkannwischer committed Dec 18, 2024
1 parent bae251d commit acb04d1
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 1 deletion.
5 changes: 4 additions & 1 deletion slothy/core/slothy.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,10 @@ def optimize_loop(self, loop_lbl, postamble_label=None, forced_loop_type=None):

early, body, late, _, other_data, loop = \
self.arch.Loop.extract(self.source, loop_lbl, forced_loop_type=forced_loop_type)
loop_cnt = other_data['cnt']
try:
loop_cnt = other_data['cnt']
except KeyError:
loop_cnt = None

# Check if the body has a dominant indentation
indentation = AsmHelper.find_indentation(body)
Expand Down
91 changes: 91 additions & 0 deletions slothy/targets/arm_v7m/arch_v7m.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,97 @@ def end(self, other, indentation=0):
yield f'{indent}cmp {other["cnt"]}, {other["end"]}'
yield f'{indent}bne {lbl_start}'


class BranchLoop(Loop):
def __init__(self, lbl="lbl", lbl_start="1", lbl_end="2", loop_init="lr") -> None:
super().__init__(lbl_start=lbl_start, lbl_end=lbl_end, loop_init=loop_init)
self.lbl = lbl
self.lbl_regex = r"^\s*(?P<label>\w+)\s*:(?P<remainder>.*)$"
self.end_regex = (rf"^\s*(cbnz|cbz|bne)(?:\.w)?\s+{lbl}",)

def start(self, loop_cnt, indentation=0, fixup=0, unroll=1, jump_if_empty=None, preamble_code=None, body_code=None, postamble_code=None, register_aliases=None):
"""Emit starting instruction(s) and jump label for loop"""
indent = ' ' * indentation
if body_code is None:
logging.debug(f"No body code in loop start: Just printing label.")
yield f"{self.lbl}:"
return
# Identify the register that is used as a loop counter
body_code = [l for l in body_code if l.text != ""]
for l in body_code:
inst = Instruction.parser(l)
# Flags are set through cmp
# LIMITATION: By convention, we require the first argument to be the
# "counter" and the second the one marking the iteration end.
if isinstance(inst[0], cmp):
# Assume this mapping
loop_cnt_reg = inst[0].args_in[0]
loop_end_reg = inst[0].args_in[1]
logging.debug(f"Assuming {loop_cnt_reg} as counter register and {loop_end_reg} as end register.")
break
# Flags are set through subs
elif isinstance(inst[0], subs_imm_short):
loop_cnt_reg = inst[0].args_in_out[0]
loop_end_reg = inst[0].args_in_out[0]
break

# Find FPR that is used to stash the loop end incase it's vmov loop
loop_end_reg_fpr = None
for li, l in enumerate(body_code):
inst = Instruction.parser(l)
# Flags are set through cmp
if isinstance(inst[0], vmov_gpr):
if loop_end_reg in inst[0].args_out:
logging.debug(f"Copying from {inst[0].args_in} to {loop_end_reg}")
loop_end_reg_fpr = inst[0].args_in[0]

# The last vmov occurance before the cmp that writes to the register
# we compare to will be the right one. The same GPR could be written
# previously due to renaming, before it becomes the value used in
# the cmp.
if isinstance(inst[0], cmp):
break

if unroll > 1:
assert unroll in [1,2,4,8,16,32]
yield f"{indent}lsr {loop_end_reg}, {loop_end_reg}, #{int(math.log2(unroll))}"

inc_per_iter = 0
for l in body_code:
inst = Instruction.parser(l)
# Increment happens through pointer modification
if loop_cnt_reg.lower() == inst[0].addr and inst[0].increment is not None:
inc_per_iter = inc_per_iter + simplify(inst[0].increment)
# Increment through explicit modification
elif loop_cnt_reg.lower() in (inst[0].args_out + inst[0].args_in_out) and inst[0].immediate is not None:
# TODO: subtract if we have a subtraction
inc_per_iter = inc_per_iter + simplify(inst[0].immediate)
logging.debug(f"Loop counter {loop_cnt_reg} is incremented by {inc_per_iter} per iteration.")

if fixup != 0 and loop_end_reg_fpr is not None:
yield f"{indent}push {{{loop_end_reg}}}"
yield f"{indent}vmov {loop_end_reg}, {loop_end_reg_fpr}"

if fixup != 0:
yield f"{indent}sub {loop_end_reg}, {loop_end_reg}, #{fixup*inc_per_iter}"

if fixup != 0 and loop_end_reg_fpr is not None:
yield f"{indent}vmov {loop_end_reg_fpr}, {loop_end_reg}"
yield f"{indent}pop {{{loop_end_reg}}}"

if jump_if_empty is not None:
yield f"cbz {loop_cnt}, {jump_if_empty}"
yield f"{self.lbl}:"

def end(self, other, indentation=0):
"""Emit compare-and-branch at the end of the loop"""
indent = ' ' * indentation
lbl_start = self.lbl
if lbl_start.isdigit():
lbl_start += "b"

yield f'{indent}bne {lbl_start}'

class CmpLoop(Loop):
"""
Loop ending in a compare and a branch.
Expand Down

0 comments on commit acb04d1

Please sign in to comment.