Skip to content

Commit

Permalink
More pylint
Browse files Browse the repository at this point in the history
  • Loading branch information
hanno-becker committed Dec 23, 2023
1 parent f5baf3c commit 62a7d57
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 60 deletions.
6 changes: 5 additions & 1 deletion example.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,12 @@
Target_CortexM55r1: "m55",
Target_CortexM85r1: "m85"}

class ExampleException(Exception):
"""Exception thrown when an example goes wrong"""

class Example():
"""Common boilerplate for SLOTHY examples"""

def __init__(self, infile, name=None, funcname=None, suffix="opt",
rename=False, outfile="", arch=Arch_Armv81M, target=Target_CortexM55r1,
**kwargs):
Expand Down Expand Up @@ -1128,7 +1132,7 @@ def run_example(name, debug=False):
ex = e
break
if ex is None:
raise Exception(f"Could not find example {name}")
raise ExampleException(f"Could not find example {name}")
ex.run(debug=debug)

for e in todo:
Expand Down
114 changes: 71 additions & 43 deletions examples/misc/gen_roots.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,28 +21,37 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import math, sys
"""Helper script for the generation of twiddle factors for various NTTs"""

import math

class NttRootGenInvalidParameters(Exception):
"""Invalid parameters for NTT root generation"""

class NttRootGen():
"""Helper class for the generation of NTT twiddle factors"""

def __init__(self,*,
size,
modulus,
root,
layers,
print_label=False,
pad = [],
bitsize = 16,
inverse = False,
print_label = False,
pad = None,
bitsize = 16,
inverse = False,
vector_length = 128,
word_offset_mod_4=None,
word_offset_mod_4 = None,
incomplete_root = True,
widen_single_twiddles_to_words=True,
block_strided_twiddles=True,
negacyclic = True,
widen_single_twiddles_to_words = True,
block_strided_twiddles = True,
negacyclic = True,
iters = None):
if pad is None:
pad = []

assert bitsize in [16,32]
if bitsize not in [16, 32]:
raise NttRootGenInvalidParameters("Invalid bit width")

self.pad = pad
self.print_label=print_label
Expand Down Expand Up @@ -76,7 +85,7 @@ def __init__(self,*,

# Need an odd prime modulus
if self.modulus % 2 == 0:
raise Exception("Modulus must be odd")
raise NttRootGenInvalidParameters("Modulus must be odd")
self._inv_mod = pow(self.modulus,-1,2**self.bitsize)

# Check that we've indeed been given a root of unity of the correct order
Expand All @@ -87,12 +96,12 @@ def __init__(self,*,

self.log2size = int(math.log(size,2))
if size != pow(2,self.log2size):
raise Exception(f"Size {size} not a power of 2")
raise NttRootGenInvalidParameters(f"Size {size} not a power of 2")

self.layers = layers
self.incompleteness_factor = 2**(self.log2size - self.layers)

if iters == None:
if iters is None:
if self.layers % 2 == 0:
self.iters = [(x,2) for x in range(0,self.layers,2)]
else:
Expand All @@ -106,18 +115,22 @@ def __init__(self,*,

if ( pow(root, real_root_order, modulus) != 1 or
pow(root, real_root_order // 2, modulus) == 1 ):
raise Exception(f"{root} is not a primitive {real_root_order}-th root of unity modulo {modulus}")
raise NttRootGenInvalidParameters(f"{root} is not a primitive {real_root_order}-th "
f"root of unity modulo {modulus}")

self.radixes = [2] * self.log2size

def get_root_pow(self, exp):
"""Returns specific power of base root of unity"""

if not exp % self.incompleteness_factor == 0:
raise Exception(f"Invalid exponent {exp} for incompleteness factor {self.incompleteness_factor}")
raise NttRootGenInvalidParameters(f"Invalid exponent {exp} for incompleteness "
f"factor {self.incompleteness_factor}")
if self.incomplete_root:
exp = exp // self.incompleteness_factor
return pow(self.root,exp,self.modulus)

def _prepare_root(self,root,layer=None):
def _prepare_root(self,root):

# Force _signed_ representation of root?
if root > self.modulus // 2:
Expand All @@ -143,6 +156,8 @@ def _bitrev_list(self,num,radix_list):
return result

def root_of_unity_for_block(self,layer,block):
"""Returns the twiddle factor to be used for a specific layer and block"""

actual_layer = layer
if self.negacyclic:
block += pow(2,layer)
Expand All @@ -157,10 +172,14 @@ def root_of_unity_for_block(self,layer,block):
if self.inverse:
log = (self.root_order - log) % self.root_order
root = self.get_root_pow(log)
root, root_twisted = self._prepare_root(root,layer)
root, root_twisted = self._prepare_root(root)
return root, root_twisted

def roots_of_unity_for_layer_core(self, layer, merged):
def _roots_of_unity_for_layer_core(self, layer, merged):

if not merged in [1,2,3,4]:
raise NttRootGenInvalidParameters("Invalid layer merge")

for cur_block in range(0,2**layer):
if merged == 1:
root, root_twisted = self.root_of_unity_for_block(layer, cur_block)
Expand Down Expand Up @@ -204,11 +223,13 @@ def roots_of_unity_for_layer_core(self, layer, merged):

if layer in self.pad:
yield ([root0, root1, root2, root3, root4, root5, root6, 0],
[root0_tw, root1_tw, root2_tw, root3_tw, root4_tw, root5_tw, root6_tw, 0])
[root0_tw, root1_tw, root2_tw, root3_tw, root4_tw,
root5_tw, root6_tw, 0])
else:
yield ([root0, root1, root2, root3, root4, root5, root6],
[root0_tw, root1_tw, root2_tw, root3_tw, root4_tw, root5_tw, root6_tw])
elif merged == 4:
else:
assert merged == 4
# Compute the roots of unity that we need at this stage
fst_layer = layer + 0
snd_layer = layer + 1
Expand Down Expand Up @@ -261,10 +282,10 @@ def roots_of_unity_for_layer_core(self, layer, merged):
root5_tw, root6_tw, root7_tw, root8_tw, root9_tw,
root10_tw, root11_tw, root12_tw, root13_tw,
root14_tw])
else:
raise Exception("Something went wrong")

def roots_of_unity_for_layer(self, layer, merged):
"""Generator yielding the twiddle factors for a number of merged layers"""

num_blocks = 2 ** layer
block_size = self.size // num_blocks
butterfly_size = block_size // 2 ** merged
Expand All @@ -273,7 +294,7 @@ def roots_of_unity_for_layer(self, layer, merged):
if butterfly_size < self.vector_length // self.bitsize:
stride = (self.vector_length // self.bitsize) // butterfly_size

all_root_pairs = list(self.roots_of_unity_for_layer_core(layer, merged))
all_root_pairs = list(self._roots_of_unity_for_layer_core(layer, merged))
all_roots = [ x[0] for x in all_root_pairs ]
all_roots_twisted = [ x[1] for x in all_root_pairs ]
num_pairs = len(all_root_pairs)
Expand All @@ -288,7 +309,8 @@ def roots_of_unity_for_layer(self, layer, merged):
res = [(z,stride) for x in roots for y in x for z in y]
yield from res

def get_roots_of_unity_core(self):
def get_roots_of_unity(self):
"""Yields roots of unity for NTT"""
iters = self.iters.copy()
if self.inverse:
iters.reverse()
Expand All @@ -297,17 +319,17 @@ def get_roots_of_unity_core(self):
yield f"roots_l{''.join([str(i) for i in range(cur_iter,cur_iter+merged)])}:"
yield from self.roots_of_unity_for_layer(cur_iter,merged)

def get_roots_of_unity_real(self):
def _get_roots_of_unity_asm(self):
"""Yields roots of unity """
if self.bitsize == 16:
twiddlesize = "short"
elif self.bitsize == 32:
twiddlesize = "word"
else:
raise Exception("Should not happen")
assert self.bitsize == 32
twiddlesize = "word"

count = 0
last_stride = None
for x in self.get_roots_of_unity_core():
for x in self.get_roots_of_unity():
if isinstance(x,str):
yield x
continue
Expand All @@ -320,7 +342,7 @@ def get_roots_of_unity_real(self):
count += 1
if stride > 1:
if last_stride == 1:
if self.word_offset_mod_4 != None:
if self.word_offset_mod_4 is not None:
yield f"// Word count until here: {count}"
cc4 = count % 4
diff = self.word_offset_mod_4 - cc4
Expand All @@ -338,10 +360,13 @@ def get_roots_of_unity_real(self):
last_stride = stride

def export(self, filename):
license = """
"""Export twiddle factors as file"""

license_text = """
///
/// Copyright (c) 2022 Arm Limited
/// Copyright (c) 2022 Hanno Becker
/// Copyright (c) 2023 Amin Abdulrahman, Matthias Kannwischer
/// SPDX-License-Identifier: MIT
///
/// Permission is hereby granted, free of charge, to any person obtaining a copy
Expand All @@ -364,18 +389,19 @@ def export(self, filename):
///
"""
f = open(filename,"w")
f.write(license)
f.write('\n'.join(self.get_roots_of_unity_real()))
f.close()
with open(filename, "w", encoding="utf-8") as f:
f.write(license_text)
f.write('\n'.join(self._get_roots_of_unity_asm()))

def main():
def _main():

ntt_kyber_l345 = NttRootGen(size=256,modulus=3329,root=17,layers=7,iters=[(0,2),(2,3),(5,2)], word_offset_mod_4=2)
ntt_kyber_l345 = NttRootGen(size=256,modulus=3329,root=17,layers=7,iters=[(0,2),(2,3),(5,2)],
word_offset_mod_4=2)
ntt_kyber_l345.export("../naive/ntt_kyber_12_345_67_twiddles.s")
ntt_kyber_l345.export("../opt/ntt_kyber_12_345_67_twiddles.s")

ntt_kyber_l123 = NttRootGen(size=256,modulus=3329,root=17,layers=7,iters=[(0,3),(3,2),(5,2)], pad=[0,3], print_label=True, widen_single_twiddles_to_words=False)
ntt_kyber_l123 = NttRootGen(size=256,modulus=3329,root=17,layers=7,iters=[(0,3),(3,2),(5,2)],
pad=[0,3], print_label=True, widen_single_twiddles_to_words=False)
ntt_kyber_l123.export("../naive/ntt_kyber_123_45_67_twiddles.s")
ntt_kyber_l123.export("../opt/ntt_kyber_123_45_67_twiddles.s")

Expand All @@ -387,11 +413,13 @@ def main():
intt_kyber.export("../naive/intt_kyber_1_23_45_67_twiddles.s")
intt_kyber.export("../opt/intt_kyber_1_23_45_67_twiddles.s")

ntt_dilithium = NttRootGen(size=256,bitsize=32,modulus=8380417,root=1753,layers=8, word_offset_mod_4=2)
ntt_dilithium = NttRootGen(size=256,bitsize=32,modulus=8380417,root=1753,layers=8,
word_offset_mod_4=2)
ntt_dilithium.export("../naive/ntt_dilithium_12_34_56_78_twiddles.s")
ntt_dilithium.export("../opt/ntt_dilithium_12_34_56_78_twiddles.s")

ntt_dilithium_l1234 = NttRootGen(size=256, bitsize=32, modulus=8380417, root=1753, layers=8, iters=[(0,4),(4,2),(6,2)], pad=[0], print_label=True)
ntt_dilithium_l1234 = NttRootGen(size=256, bitsize=32, modulus=8380417, root=1753,
layers=8, iters=[(0,4),(4,2),(6,2)], pad=[0], print_label=True)
ntt_dilithium_l1234.export("../naive/aarch64/ntt_dilithium_1234_5678_twiddles.s")
ntt_dilithium_l1234.export("../opt/aarch64/ntt_dilithium_1234_5678_twiddles.s")

Expand All @@ -400,8 +428,8 @@ def main():
ntt_dilithium_l123.export("../naive/ntt_dilithium_123_456_78_twiddles.s")
ntt_dilithium_l123.export("../opt/ntt_dilithium_123_456_78_twiddles.s")

ntt_dilithium_l123 = NttRootGen(size=256,bitsize=32,modulus=8380417,root=1753,layers=8, print_label=True, pad=[0,3],
iters=[(0,3),(3,3),(6,2)])
ntt_dilithium_l123 = NttRootGen(size=256,bitsize=32,modulus=8380417,root=1753,layers=8,
print_label=True, pad=[0,3], iters=[(0,3),(3,3),(6,2)])
ntt_dilithium_l123.export("../naive/aarch64/ntt_dilithium_123_456_78_twiddles.s")
ntt_dilithium_l123.export("../opt/aarch64/ntt_dilithium_123_456_78_twiddles.s")

Expand Down Expand Up @@ -448,4 +476,4 @@ def main():
intt_n256_s32_l8_test.export("../opt/intt_n256_l8_s32_twiddles.s")

if __name__ == "__main__":
main()
_main()
1 change: 1 addition & 0 deletions slothy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from slothy.core.slothy import Slothy
from slothy.core.core import SlothyException
from slothy.core.config import Config
from slothy.targets.query import Archery
26 changes: 10 additions & 16 deletions slothy/core/heuristics.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

from slothy.core.dataflow import DataFlowGraph as DFG
from slothy.core.dataflow import Config as DFGConfig, ComputationNode
from slothy.core.core import SlothyBase, Result
from slothy.core.core import SlothyBase, Result, SlothyException
from slothy.helper import Permutation, SourceLine
from slothy.helper import binary_search, BinarySearchLimitException

Expand Down Expand Up @@ -175,7 +175,7 @@ def optimize_binsearch_external(source, logger, conf, flexible=True, **kwargs):
if not flexible:
core = SlothyBase(conf.arch, conf.target, logger=logger,config=conf)
if not core.optimize(source):
raise Exception("Optimization failed")
raise SlothyException("Optimization failed")
return core.result

logger.info("Perform external binary search for minimal number of stalls...")
Expand Down Expand Up @@ -240,7 +240,7 @@ def optimize_binsearch_internal(source, logger, conf, **kwargs):
cur_attempt = max(1,cur_attempt * 2)
if cur_attempt > conf.constraints.stalls_maximum_attempt:
logger.error("Exceeded stall limit without finding a working solution")
raise Exception("No solution found")
raise SlothyException("No solution found")

logger.info(f"Minimum number of stalls: {min_stalls}")

Expand Down Expand Up @@ -370,11 +370,12 @@ def linear(body, logger, conf):
conf: The configuration to be applied. Software pipelining must be disabled.
Raises:
Raises an exception if software pipelining is enabled.
Raises a SlothyException if software pipelining is enabled.
"""
assert SourceLine.is_source(body)
if conf.sw_pipelining.enabled:
raise Exception("Linear heuristic should only be called with SW pipelining disabled")
raise SlothyException("Linear heuristic should only be called "
"with SW pipelining disabled")

Heuristics._dump("Starting linear optimization...", body, logger)

Expand Down Expand Up @@ -466,8 +467,8 @@ def pick_candidate(candidate_idxs):
logger.debug("Candidate %s: %s", depth_str, candidate_depths)
choice_idx = candidate_idxs[candidate_depths.index(min(candidate_depths))]

elif strategy == "alternate_functional_units":

else:
assert strategy == "alternate_functional_units"
def flatten_units(units):
res = []
for u in units:
Expand Down Expand Up @@ -504,9 +505,6 @@ def units_different(a,b):
for i,d in enumerate(candidate_depths) if d == min_depth ]
choice_idx = random.choice(refined_candidates)

else:
raise Exception("Unknown preprocessing strategy")

return choice_idx

def move_entry_forward(lst, idx_from, idx_to):
Expand All @@ -516,12 +514,8 @@ def move_entry_forward(lst, idx_from, idx_to):

choice_idx = None
while choice_idx is None:
try:
choice_idx = pick_candidate(candidate_idxs)
insts = move_entry_forward(insts, choice_idx, i)
except:
candidate_idxs.remove(choice_idx)
choice_idx = None
choice_idx = pick_candidate(candidate_idxs)
insts = move_entry_forward(insts, choice_idx, i)

local_perm = Permutation.permutation_move_entry_forward(l, choice_idx, i)
perm = Permutation.permutation_comp (local_perm, perm)
Expand Down

0 comments on commit 62a7d57

Please sign in to comment.