Skip to content

Commit

Permalink
[treewide] formatting: align all Python sources and fix import errors…
Browse files Browse the repository at this point in the history
… on init scripts
  • Loading branch information
Viviane Potocnik committed Aug 8, 2024
1 parent f28f88a commit 37dd3ab
Show file tree
Hide file tree
Showing 10 changed files with 26 additions and 26 deletions.
2 changes: 1 addition & 1 deletion sw/blas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@

from . import gemm

__all__ = ['gemm']
__all__ = ['gemm']
4 changes: 3 additions & 1 deletion sw/blas/gemm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@
#
# Luca Colagrande <[email protected]>

from .scripts.datagen import *
from .scripts.datagen import GemmDataGen

__all__ = ['GemmDataGen']
8 changes: 7 additions & 1 deletion sw/dnn/flashattention_2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,11 @@
# SPDX-License-Identifier: Apache-2.0
#
# Luca Colagrande <[email protected]>
# Viviane Potocnik <[email protected]>

from .scripts.datagen import *
from .scripts.datagen import exact_golden_model, exact_flexfloat_golden_model, \
get_gemm_implementation, load_params, \
validate, emit_header

__all__ = ['exact_golden_model', 'exact_flexfloat_golden_model',
'get_gemm_implementation', 'load_params', 'validate', 'emit_header']
17 changes: 2 additions & 15 deletions sw/dnn/flashattention_2/scripts/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import pyflexfloat as ff

from snitch.util.sim import data_utils
from snitch.util.sim.data_utils import DataGen, format_struct_definition, \
from snitch.util.sim.data_utils import format_struct_definition, \
format_array_definition, format_array_declaration, emit_license
from snitch.blas import gemm

Expand Down Expand Up @@ -135,18 +135,6 @@ def exact_flexfloat_golden_model(Q, K, V, B_r, B_c, desc):
return np.concatenate(O_tiles, 0)


def get_gemm_implementation(self, params):
prec = params['dtype'].lower()
impl = f'gemm_{prec}_'
if params['baseline']:
impl += 'naive'
else:
impl += 'opt'
if prec == 'fp8':
impl += '_ex'
return impl


def load_params(self, params):
self.L = params['L']
self.S = params['S']
Expand All @@ -162,9 +150,8 @@ def load_params(self, params):
self.ctype = data_utils.ctype_from_precision_t(self.dtype)
self.prec = data_utils.size_from_precision_t(self.dtype)

# Verify layer parameters are valid


# Verify layer parameters are valid
def validate(self):
assert (self.L % self.B_r) == 0, 'L is not an integer multiple of B_r'
assert (self.S % self.B_c) == 0, 'S is not an integer multiple of B_c'
Expand Down
4 changes: 3 additions & 1 deletion sw/dnn/fused_concat_linear/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@
#
# Luca Colagrande <[email protected]>

from .scripts.datagen import *
from .scripts.datagen import golden_model, emit_header

__all__ = ['golden_model', 'emit_header']
2 changes: 1 addition & 1 deletion sw/dnn/fused_concat_linear/scripts/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import torch

from snitch.util.sim import data_utils
from snitch.util.sim.data_utils import DataGen, format_struct_definition, \
from snitch.util.sim.data_utils import format_struct_definition, \
format_array_definition, format_array_declaration, format_ifdef_wrapper, \
emit_license

Expand Down
2 changes: 1 addition & 1 deletion sw/dnn/gelu/scripts/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import torch

from snitch.util.sim import data_utils
from snitch.util.sim.data_utils import DataGen, format_struct_definition, \
from snitch.util.sim.data_utils import format_struct_definition, \
format_array_definition, format_array_declaration, format_ifdef_wrapper, \
emit_license

Expand Down
5 changes: 4 additions & 1 deletion sw/dnn/layernorm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,7 @@
#
# Luca Colagrande <[email protected]>

from .scripts.datagen import *
from .scripts.datagen import golden_model, golden_model_torch, \
validate_config, emit_header

__all__ = ['golden_model', 'golden_model_torch', 'validate_config', 'emit_header']
2 changes: 1 addition & 1 deletion util/trace/a2l.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def __init__(self, elf, a2l_binary='addr2line'):

@lru_cache(maxsize=1024)
def addr2line(self, addr):
if type(addr) == str:
if isinstance(addr, str):
addr = int(addr, 16)
cmd = f'{self.a2l} -e {self.elf} -f -i {addr:x}'
return Addr2LineOutput(os.popen(cmd).read())
6 changes: 3 additions & 3 deletions util/trace/layout_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def main():
# Symbols must be added to globals to be used in list comprehensions
# see https://bugs.python.org/issue36300
tids = eval(code, {'cfg': cfg}, {'cfg': cfg})
if type(tids) == int:
if isinstance(tids, int):
tids = [tids]

# Iterate hart IDs
Expand All @@ -120,9 +120,9 @@ def main():
row_idx = tid
col_idx = 1 + reg_idx * 2
assert row_idx < df.shape[0], f'Hart ID {row_idx} out of bounds'
assert (col_idx + 1) < df.shape[1],\
assert (col_idx + 1) < df.shape[1], \
f'Region index {reg_idx} out of bounds for hart {tid}'
assert not isnan(df.iat[row_idx, col_idx]),\
assert not isnan(df.iat[row_idx, col_idx]), \
(f'Region {reg_idx} looks empty for hart {tid},'
f'check whether it was simulated')
orow.append(int(df.iat[row_idx, col_idx]))
Expand Down

0 comments on commit 37dd3ab

Please sign in to comment.