Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow to force loop type #127

Merged
merged 1 commit into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,23 @@ def core(self,slothy):
slothy.config.variable_size=True
slothy.config.outputs = ["r6"]
slothy.optimize_loop("start")

class Armv7mLoopVmovCmpForced(Example):
def __init__(self, var="", arch=Arch_Armv7M, target=Target_CortexM7):
name = "loop_vmov_cmp_forced"
infile = name

if var != "":
name += f"_{var}"
infile += f"_{var}"
name += f"_{target_label_dict[target]}"

super().__init__(infile, name, rename=True, arch=arch, target=target)

def core(self,slothy):
slothy.config.variable_size=True
slothy.config.outputs = ["r5", "r6"]
slothy.optimize_loop("start", forced_loop_type=Arch_Armv7M.CmpLoop)

class AArch64IfElse(Example):
def __init__(self, var="", arch=AArch64_Neon, target=Target_CortexA55):
Expand Down Expand Up @@ -1561,6 +1578,7 @@ def main():
Armv7mLoopSubs(),
Armv7mLoopCmp(),
Armv7mLoopVmovCmp(),
Armv7mLoopVmovCmpForced(),

CRT(),

Expand Down
15 changes: 15 additions & 0 deletions examples/naive/armv7m/loop_vmov_cmp_forced.s
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/* For example, r5 represents an address where we will stop iterating and r6 is
the actual pointer which is incremented inside the loop.

In this specific example, the vmov shall not be accounted towards the loop
boundary but rather the body. */

mov.w r6, #0
add.w r5, r6, #64
vmov s0, r5

start:
add r6, r6, #4
vmov r5, s0
cmp r6, r5
bne start
35 changes: 35 additions & 0 deletions examples/opt/armv7m/loop_vmov_cmp_forced_opt_m7.s
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/* For example, r5 represents an address where we will stop iterating and r6 is
the actual pointer which is incremented inside the loop.

In this specific example, the vmov shall not be accounted towards the loop
boundary but rather the body. */

mov.w r6, #0
add.w r5, r6, #64
vmov s0, r5

1:
// Instructions: 2
// Expected cycles: 1
// Expected IPC: 2.00
//
// Cycle bound: 1.0
// IPC bound: 2.00
//
// Wall time: 0.04s
// User time: 0.04s
//
// ----- cycle (expected) ------>
// 0 25
// |------------------------|----
add r6, r6, #4 // *.............................
vmov r5, s0 // *.............................

// ------ cycle (expected) ------>
// 0 25
// |------------------------|-----
// add r6, r6, #4 // *..............................
// vmov r5, s0 // *..............................

cmp r6, r5
bne 1b
25 changes: 17 additions & 8 deletions slothy/core/slothy.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,10 +355,15 @@ def indented(code):
self.source = pre + optimized_source + post
assert SourceLine.is_source(self.source)

def get_loop_input_output(self, loop_lbl):
"""Find all registers that a loop body depends on"""
def get_loop_input_output(self, loop_lbl, forced_loop_type=None):
"""Find all registers that a loop body depends on

Args:
loop_lbl: Label of loop to process.
forced_loop_type: Forces the loop to be parsed as a certain type.
"""
logger = self.logger.getChild(loop_lbl)
_, body, _, _, _ = self.arch.Loop.extract(self.source, loop_lbl)
_, body, _, _, _ = self.arch.Loop.extract(self.source, loop_lbl, forced_loop_type=forced_loop_type)

c = self.config.copy()
dfgc = DFGConfig(c)
Expand Down Expand Up @@ -432,18 +437,19 @@ def fusion_region(self, start, end, **kwargs):
self.source = pre + body_ssa + post
assert SourceLine.is_source(self.source)

def fusion_loop(self, loop_lbl, **kwargs):
def fusion_loop(self, loop_lbl, forced_loop_type=None, **kwargs):
"""Run fusion callbacks on loop body replacing certain instruction
(sequences) with an alternative. These replacements are defined in the
architectural model by setting an instruction class' global_fusion_cb.

Args:
loop_lbl: Label of loop to which the fusions are applied to.
forced_loop_type: Forces the loop to be parsed as a certain type.
"""
logger = self.logger.getChild(f"ssa_loop_{loop_lbl}")

pre , body, post, _, other_data, loop = \
self.arch.Loop.extract(self.source, loop_lbl)
self.arch.Loop.extract(self.source, loop_lbl, forced_loop_type=forced_loop_type)
loop_cnt = other_data['cnt']
indentation = AsmHelper.find_indentation(body)

Expand All @@ -454,15 +460,18 @@ def fusion_loop(self, loop_lbl, **kwargs):
self.source = pre + body_ssa + post
assert SourceLine.is_source(self.source)

def optimize_loop(self, loop_lbl, postamble_label=None):
def optimize_loop(self, loop_lbl, postamble_label=None, forced_loop_type=None):
"""Optimize the loop starting at a given label
The postamble_label marks the end of the loop kernel.

Args:
postamble_label: Marks end of loop kernel.
forced_loop_type: Forces the loop to be parsed as a certain type.
"""

logger = self.logger.getChild(loop_lbl)

early, body, late, _, other_data, loop = \
self.arch.Loop.extract(self.source, loop_lbl)
self.arch.Loop.extract(self.source, loop_lbl, forced_loop_type=forced_loop_type)
loop_cnt = other_data['cnt']

# Check if the body has a dominant indentation
Expand Down
17 changes: 15 additions & 2 deletions slothy/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1559,8 +1559,21 @@ def _extract(self, source, lbl):
return pre, body, post, lbl, self.additional_data

@staticmethod
def extract(source, lbl):
for loop_type in Loop.__subclasses__():
def extract(source, lbl, forced_loop_type=None):
"""
Find a loop with start label `lbl` in `source` and return it together
with its type.

Args:
source: list of SourceLine objects
lbl: label of the loop to extract
forced_loop_type: if not None, only try to extract this type of loop
"""
if forced_loop_type is not None:
loop_types = [forced_loop_type]
else:
loop_types = Loop.__subclasses__()
for loop_type in loop_types:
try:
l = loop_type(lbl)
# concatenate the extracted loop with an instance of the
Expand Down
Loading