Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Fusion CB for splitting #105

Merged
merged 4 commits into from
Dec 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,21 @@ def core(self,slothy):
slothy.config.sw_pipelining.optimize_postamble = False
slothy.optimize_loop("start")

class AArch64Split0(Example):
def __init__(self, var="", arch=AArch64_Neon, target=Target_CortexA55):
name = "aarch64_split0"
infile = name

if var != "":
name += f"_{var}"
infile += f"_{var}"
name += f"_{target_label_dict[target]}"

super().__init__(infile, name, rename=True, arch=arch, target=target)

def core(self,slothy):
slothy.config.allow_useless_instructions = True
slothy.fusion_region("start", "end", ssa=False)
class Armv7mExample0(Example):
def __init__(self, var="", arch=Arch_Armv7M, target=Target_CortexM7):
name = "armv7m_simple0"
Expand Down Expand Up @@ -1534,6 +1549,8 @@ def main():
AArch64Example2(target=Target_CortexA72),
AArch64IfElse(),

AArch64Split0(),

# Armv7m examples
Armv7mExample0(),
Armv7mExample0Func(),
Expand Down
15 changes: 15 additions & 0 deletions examples/naive/aarch64/aarch64_split0.s
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
start:
ldr q0, [x1, #0]
ldr q1, [x2, #0]
eor3 v5.16b, v1.16b, v2.16b, v3.16b // @slothy:some_tag // some comment
eor3 v3.16b, v1.16b, v2.16b, v3.16b // Cannot we split naively
ldr q8, [x0]
ldr q9, [x0, #1*16]
ldr q10, [x0, #2*16]
ldr q11, [x0, #3*16]
mul v24.8h, v9.8h, v0.h[0]
sqrdmulh v9.8h, v9.8h, v0.h[1]
mls v24.8h, v9.8h, v1.h[0]
sub v9.8h, v8.8h, v24.8h
add v8.8h, v8.8h, v24.8h
end:
17 changes: 17 additions & 0 deletions examples/opt/aarch64/aarch64_split0_opt_a55.s
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
start:
ldr q0, [x1, #0]
ldr q1, [x2, #0]
eor v5.16B, v1.16B, v2.16B// some comment // @slothy:some_tag
eor v5.16B, v5.16B, v3.16B// some comment // @slothy:some_tag
eor3 v3.16B, v1.16B, v2.16B, v3.16B// Cannot we split naively
ldr q8, [x0]
ldr q9, [x0, #16]
ldr q10, [x0, #32]
ldr q11, [x0, #48]
mul v24.8H, v9.8H, v0.H[0]
sqrdmulh v9.8H, v9.8H, v0.H[1]
mls v24.8H, v9.8H, v1.H[0]
sub v9.8H, v8.8H, v24.8H
add v8.8H, v8.8H, v24.8H
end:

13 changes: 12 additions & 1 deletion slothy/core/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,18 @@ def apply_cbs(self, cb, logger, one_a_time=False):
break

z = filter(lambda x: x.delete is False, self.nodes)
z = map(lambda x: ([x.inst], x.inst.source_line), z)

def pair_with_source(i):
return ([i], i.source_line)
def map_node(t):
s = t.inst
if not isinstance(t.inst, list):
s = [s]
return map(pair_with_source, s)
def flatten(llst):
return [x for y in llst for x in y]

z = flatten(map(map_node, z))

self.src = list(z)

Expand Down
50 changes: 27 additions & 23 deletions slothy/core/slothy.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def get_input_from_output(self, start, end, outputs=None):
dfgc = DFGConfig(c)
return list(DFG(body, logger.getChild("dfg_find_deps"), dfgc).inputs)

def _fusion_core(self, pre, body, post, logger):
def _fusion_core(self, pre, body, post, logger, ssa=True):
c = self.config.copy()

if c.with_preprocessor:
Expand All @@ -400,29 +400,46 @@ def _fusion_core(self, pre, body, post, logger):
body = AsmAllocation.unfold_all_aliases(c.register_aliases, body)
dfgc = DFGConfig(c)

dfg = DFG(body, logger.getChild("ssa"), dfgc, parsing_cb=False)
dfg.ssa()
body = [ ComputationNode.to_source_line(t) for t in dfg.nodes ]
if ssa is True:
dfg = DFG(body, logger.getChild("ssa"), dfgc, parsing_cb=False)
dfg.ssa()
body = [ ComputationNode.to_source_line(t) for t in dfg.nodes ]

dfg = DFG(body, logger.getChild("fusion"), dfgc, parsing_cb=False)
dfg.apply_fusion_cbs()
body = [ ComputationNode.to_source_line(t) for t in dfg.nodes ]

return body

def fusion_region(self, start, end):
"""Run fusion callbacks on straightline code"""
def fusion_region(self, start, end, **kwargs):
""" Run fusion callbacks on straightline code replacing certain
instruction (sequences) with an alternative. These replacements are
defined in the architectural model by setting an instruction class'
global_fusion_cb.

Args:
start: The label marking the beginning of the part of the code to
apply fusion to.
end: The label marking the end of the part of the code to apply
fusion to.
"""
logger = self.logger.getChild(f"ssa_{start}_{end}")
pre, body, post = AsmHelper.extract(self.source, start, end)

body_ssa = [ SourceLine(f"{start}:") ] +\
self._fusion_core(pre, body, logger) + \
self._fusion_core(pre, body, post, logger, **kwargs) + \
mkannwischer marked this conversation as resolved.
Show resolved Hide resolved
[ SourceLine(f"{end}:") ]
self.source = pre + body_ssa + post
assert SourceLine.is_source(self.source)

def fusion_loop(self, loop_lbl):
"""Run fusion callbacks on loop body"""
def fusion_loop(self, loop_lbl, **kwargs):
mkannwischer marked this conversation as resolved.
Show resolved Hide resolved
"""Run fusion callbacks on loop body replacing certain instruction
(sequences) with an alternative. These replacements are defined in the
architectural model by setting an instruction class' global_fusion_cb.

Args:
loop_lbl: Label of loop to which the fusions are applied to.
"""
logger = self.logger.getChild(f"ssa_loop_{loop_lbl}")

pre , body, post, _, other_data, loop = \
Expand All @@ -431,25 +448,12 @@ def fusion_loop(self, loop_lbl):
indentation = AsmHelper.find_indentation(body)

body_ssa = SourceLine.read_multiline(loop.start(loop_cnt)) + \
SourceLine.apply_indentation(self._fusion_core(pre, body, logger), indentation) + \
SourceLine.apply_indentation(self._fusion_core(pre, body, post, logger, **kwargs), indentation) + \
SourceLine.read_multiline(loop.end(other_data))

self.source = pre + body_ssa + post
assert SourceLine.is_source(self.source)

c = self.config.copy()
self.config.keep_tags = True
self.config.constraints.functional_only = True
self.config.constraints.allow_reordering = False
self.config.sw_pipelining.enabled = False
self.config.split_heuristic = False
self.config.inputs_are_outputs = True
self.config.sw_pipelining.unknown_iteration_count = False
self.optimize_loop(loop_lbl)
self.config = c

assert SourceLine.is_source(self.source)

mkannwischer marked this conversation as resolved.
Show resolved Hide resolved
def optimize_loop(self, loop_lbl, postamble_label=None):
"""Optimize the loop starting at a given label
The postamble_label marks the end of the loop kernel.
Expand Down
57 changes: 55 additions & 2 deletions slothy/targets/aarch64/aarch64_neon.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class which generates instruction parsers and writers from instruction templates
from unicorn.arm64_const import *

from slothy.targets.common import *
from slothy.helper import Loop, LLVM_Mc
from slothy.helper import Loop, LLVM_Mc, SourceLine

arch_name = "Arm_AArch64"

Expand Down Expand Up @@ -3219,6 +3219,12 @@ def core(inst,t, log=None):
q_ld2_lane_post_inc.global_parsing_cb = q_ld2_lane_post_inc_parsing_cb()

def eor3_fusion_cb():
"""
Example for a fusion call back. Allows to merge two eor instruction with
two inputs into one eor with three inputs. Such technique can help perform
transformations in case of differences between uArchs.
Note: This is not used in any real (crypto) example. This is merely a PoC.
"""
def core(inst,t,log=None):
succ = None

Expand Down Expand Up @@ -3275,7 +3281,54 @@ def core(inst,t,log=None):

return core

veor.global_fusion_cb = eor3_fusion_cb()
def eor3_splitting_cb():
mkannwischer marked this conversation as resolved.
Show resolved Hide resolved
"""
Example for a splitting call back. Allows to split one eor instruction with
three inputs into two eors with two inputs. Such technique can help perform
transformations in case of differences between uArchs.
Note: This is not used in any real (crypto) example. This is merely a PoC.
"""
def core(inst,t,log=None):

d = inst.args_out[0]
a = inst.args_in[0]
b = inst.args_in[1]
c = inst.args_in[2]

# Check if we can use the output as a temporary
if d in [a,b,c]:
return False

eor0 = AArch64Instruction.build(veor, { "Vd": d, "Va" : a, "Vb" : b,
"datatype0":"16b",
"datatype1":"16b",
"datatype2":"16b" })
eor1 = AArch64Instruction.build(veor, { "Vd": d, "Va" : d, "Vb" : c,
"datatype0":"16b",
"datatype1":"16b",
"datatype2":"16b" })

eor0_src = SourceLine(eor0.write()).\
add_tags(inst.source_line.tags).\
add_comments(inst.source_line.comments)
eor1_src = SourceLine(eor1.write()).\
add_tags(inst.source_line.tags).\
add_comments(inst.source_line.comments)

eor0.source_line = eor0_src
eor1.source_line = eor1_src

if log is not None:
log(f"EOR3 splitting: {t.inst}; {eor0} + {eor1}")

t.changed = True
t.inst = [eor0, eor1]
return True

return core

# Can alternatively set veor3.global_fusion_cb to eor3_fusion_cb() here
veor3.global_fusion_cb = eor3_splitting_cb()

def iter_aarch64_instructions():
yield from all_subclass_leaves(Instruction)
Expand Down
Loading