Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[treewide] refactor: replace sys.path.append() with pyproject.toml configuration #174

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .github/workflows/build-docker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
name: build-docker
on:
push:
branches: [main]
branches: [jssc/double-buffering]
pull_request:
workflow_dispatch:
jobs:
build-docker:
Expand Down Expand Up @@ -39,6 +40,6 @@ jobs:
context: .
file: util/container/Dockerfile
push: true
tags: ghcr.io/pulp-platform/snitch_cluster:${{ github.ref_name }}
tags: ghcr.io/pulp-platform/snitch_cluster:double-buffering
build-args: |-
SNITCH_LLVM_VERSION=latest
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
name: Build documentation
runs-on: ubuntu-22.04
container:
image: ghcr.io/pulp-platform/snitch_cluster:main
image: ghcr.io/pulp-platform/snitch_cluster:double-buffering
steps:
- uses: actions/checkout@v2
- name: Build docs
Expand All @@ -29,7 +29,7 @@ jobs:
name: Simulate SW on Snitch Cluster w/ Verilator
runs-on: ubuntu-22.04
container:
image: ghcr.io/pulp-platform/snitch_cluster:main
image: ghcr.io/pulp-platform/snitch_cluster:double-buffering
steps:
- uses: actions/checkout@v2
with:
Expand Down Expand Up @@ -61,7 +61,7 @@ jobs:
name: Simulate SW on Snitch Cluster w/ Banshee
runs-on: ubuntu-22.04
container:
image: ghcr.io/pulp-platform/snitch_cluster:main
image: ghcr.io/pulp-platform/snitch_cluster:double-buffering
steps:
- uses: actions/checkout@v2
with:
Expand Down
3 changes: 3 additions & 0 deletions iis-setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ mkdir tmp
TMPDIR=tmp pip install -r python-requirements.txt
rm -rf tmp

# Install local packages in editable mode.
pip install -e .

# Bender initialization
$BENDER vendor init

Expand Down
19 changes: 19 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright 2023 ETH Zurich and University of Bologna.
# Solderpad Hardware License, Version 0.51, see LICENSE for details.
# SPDX-License-Identifier: SHL-0.51

[build-system]
requires = ["setuptools>=42"]
build-backend = "setuptools.build_meta"

[project]
name = "snitch"
authors = [
{name = "Luca Colagrande", email = "[email protected]"}
]
dynamic = ["version"]

[tool.setuptools.package-dir]
"snitch.dnn" = "sw/dnn"
"snitch.blas" = "sw/blas"
"snitch.util" = "util"
7 changes: 2 additions & 5 deletions sw/apps/atax/scripts/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,9 @@
# Luca Colagrande <[email protected]>

import numpy as np
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../util/sim/"))
from data_utils import format_scalar_definition, format_array_definition, \
format_array_declaration, format_ifdef_wrapper, DataGen # noqa: E402
from snitch.util.sim.data_utils import format_scalar_definition, format_array_definition, \
format_array_declaration, format_ifdef_wrapper, DataGen


# AXI splits bursts crossing 4KB address boundaries. To minimize
Expand Down
4 changes: 1 addition & 3 deletions sw/apps/atax/scripts/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@

import numpy as np
import sys
from pathlib import Path
from datagen import AtaxDataGen

sys.path.append(str(Path(__file__).parent / '../../../util/sim/'))
from verif_utils import Verifier # noqa: E402
from snitch.util.sim.verif_utils import Verifier


class AtaxVerifier(Verifier):
Expand Down
7 changes: 2 additions & 5 deletions sw/apps/correlation/scripts/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,9 @@
# Luca Colagrande <[email protected]>

import numpy as np
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../util/sim/"))
from data_utils import format_scalar_definition, format_array_definition, \
format_array_declaration, format_ifdef_wrapper, DataGen # noqa: E402
from snitch.util.sim.data_utils import format_scalar_definition, format_array_definition, \
format_array_declaration, format_ifdef_wrapper, DataGen


# AXI splits bursts crossing 4KB address boundaries. To minimize
Expand Down
4 changes: 1 addition & 3 deletions sw/apps/correlation/scripts/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@

import numpy as np
import sys
from pathlib import Path
from datagen import CorrelationDataGen

sys.path.append(str(Path(__file__).parent / '../../../util/sim/'))
from verif_utils import Verifier # noqa: E402
from snitch.util.sim.verif_utils import Verifier


class CorrelationVerifier(Verifier):
Expand Down
7 changes: 2 additions & 5 deletions sw/apps/covariance/scripts/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,9 @@
# Luca Colagrande <[email protected]>

import numpy as np
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../util/sim/"))
from data_utils import format_scalar_definition, format_array_definition, \
format_array_declaration, format_ifdef_wrapper, DataGen # noqa: E402
from snitch.util.sim.data_utils import format_scalar_definition, format_array_definition, \
format_array_declaration, format_ifdef_wrapper, DataGen


# AXI splits bursts crossing 4KB address boundaries. To minimize
Expand Down
4 changes: 1 addition & 3 deletions sw/apps/covariance/scripts/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@

import numpy as np
import sys
from pathlib import Path
from datagen import CovarianceDataGen

sys.path.append(str(Path(__file__).parent / '../../../util/sim/'))
from verif_utils import Verifier # noqa: E402
from snitch.util.sim.verif_utils import Verifier


class CovarianceVerifier(Verifier):
Expand Down
9 changes: 9 additions & 0 deletions sw/blas/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright 2024 ETH Zurich and University of Bologna.
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0
#
# Luca Colagrande <[email protected]>

from . import gemm

__all__ = ['gemm']
6 changes: 2 additions & 4 deletions sw/blas/axpy/scripts/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
# Author: Luca Colagrande <[email protected]>

import numpy as np
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../util/sim/"))
from data_utils import format_scalar_definition, format_array_definition, \
format_array_declaration, format_ifdef_wrapper, DataGen # noqa: E402
from snitch.util.sim.data_utils import format_scalar_definition, format_array_definition, \
format_array_declaration, format_ifdef_wrapper, DataGen


class AxpyDataGen(DataGen):
Expand Down
4 changes: 1 addition & 3 deletions sw/blas/axpy/scripts/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@
# Luca Colagrande <[email protected]>

import sys
from pathlib import Path
from datagen import AxpyDataGen

sys.path.append(str(Path(__file__).parent / '../../../../util/sim/'))
from verif_utils import Verifier # noqa: E402
from snitch.util.sim.verif_utils import Verifier


class AxpyVerifier(Verifier):
Expand Down
6 changes: 2 additions & 4 deletions sw/blas/dot/scripts/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../util/sim/"))
from data_utils import format_scalar_definition, format_array_definition, \
format_scalar_declaration, format_ifdef_wrapper, DataGen # noqa: E402
from snitch.util.sim.data_utils import format_scalar_definition, format_array_definition, \
format_scalar_declaration, format_ifdef_wrapper, DataGen


class DotDataGen(DataGen):
Expand Down
4 changes: 1 addition & 3 deletions sw/blas/dot/scripts/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
# SPDX-License-Identifier: Apache-2.0

import sys
from pathlib import Path
from datagen import DotDataGen

sys.path.append(str(Path(__file__).parent / '../../../../util/sim/'))
from verif_utils import Verifier # noqa: E402
from snitch.util.sim.verif_utils import Verifier


class DotVerifier(Verifier):
Expand Down
4 changes: 2 additions & 2 deletions sw/blas/gemm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
#
# Luca Colagrande <[email protected]>

from .scripts import datagen
from .scripts.datagen import GemmDataGen

__all__ = ['datagen']
__all__ = ['GemmDataGen']
100 changes: 54 additions & 46 deletions sw/blas/gemm/scripts/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,13 @@
# Luca Colagrande <[email protected]>

import numpy as np
import os
import re
import pyflexfloat as ff
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../util/sim/"))
import data_utils # noqa: E402
from data_utils import DataGen, format_array_declaration, format_struct_definition, \
format_array_definition, format_ifdef_wrapper # noqa: E402
from snitch.util.sim import data_utils
from snitch.util.sim.data_utils import DataGen, format_array_declaration, \
format_struct_definition, format_array_definition, format_ifdef_wrapper


np.random.seed(42)
Expand Down Expand Up @@ -47,55 +45,49 @@ def infer_implementation(self, gemm_fp):
prec, impl = re.search(r'gemm_fp(\d+)_(\w+)', gemm_fp).group(1, 2)
return (int(prec) / 8), impl

def validate_config(self, gemm_fp, parallelize_m,
parallelize_k, m_tiles, n_tiles, k_tiles, transa,
transb, M, N, K, beta, **kwargs):
frac_m = M / m_tiles
frac_n = N / n_tiles
frac_k = K / k_tiles

dtype, impl = self.infer_implementation(gemm_fp)

# Calculate total TCDM occupation
# Note: doesn't account for double buffering
prec = data_utils.size_from_precision_t(dtype)
a_size = frac_m * frac_k * prec
b_size = frac_k * frac_n * prec
c_size = frac_m * frac_n * prec
def load_params(self, params):
self.M = params.get('M')
self.N = params.get('N')
self.K = params.get('K')
self.m_tiles = params.get('m_tiles')
self.n_tiles = params.get('n_tiles')
self.k_tiles = params.get('k_tiles')
self.load_a = params.get('load_a')
self.load_b = params.get('load_b')
self.load_c = params.get('load_c')
self.setup_ssr = params.get('setup_ssr')
self.parallelize_m = params.get('parallelize_m')
self.parallelize_k = params.get('parallelize_k')
self.gemm_fp = params.get('gemm_fp')
self.transa = params.get('transa')
self.transb = params.get('transb')
self.alpha = params.get('alpha', 1)
self.beta = params.get('beta')
self.section = params.get('section')
self.dtype, self.impl = self.infer_implementation(self.gemm_fp)
self.prec = data_utils.size_from_precision_t(self.dtype)
self.ff_desc = data_utils.ff_desc_from_precision_t(self.dtype)
self.ctype = data_utils.ctype_from_precision_t(self.dtype)

def validate(self):
frac_m = self.M / self.m_tiles
frac_n = self.N / self.n_tiles
frac_k = self.K / self.k_tiles

a_size = frac_m * frac_k * self.prec
b_size = frac_k * frac_n * self.prec
c_size = frac_m * frac_n * self.prec
total_size = a_size
total_size += b_size
total_size += c_size
data_utils.validate_tcdm_footprint(total_size)

assert (M % m_tiles) == 0, 'M is not an integer multiple of tile size'
assert (N % n_tiles) == 0, 'N is not an integer multiple of tile size'
assert (K % k_tiles) == 0, 'K is not an integer multiple of tile size'
assert not (parallelize_m and parallelize_k), 'Cannot parallelize K and M simultaneously'
assert not transa, 'SIMD kernels don\'t support transposed A matrix'
assert (dtype == 8) or (impl == 'baseline') or (impl == 'naive') \
or transb, 'Optimized SIMD kernels only support transposed B matrix'
assert not transb or n_tiles == 1, 'Tiling in the N dimension not supported' \
' if B is transposed'
assert not transb or k_tiles == 1, 'Tiling in the K dimension not supported' \
' if B is transposed'
assert (impl == 'baseline') or (impl == 'naive') or frac_n >= 8, \
'N dimension of tile size must be greater or equal to the unrolling factor (8) ' \
'when using optimized kernels'
assert beta == 0 or beta == 1, 'Only values of 0 or 1 supported for beta'
assert not (dtype == 8 and impl == "baseline"), 'No baseline implemented' \
' for FP64 (switch to NAIVE)'
assert not (((dtype == 8) or (dtype == 4)) and impl == "opt_ex"), \
'Expanding GEMM kernels' \
' not supported for FP64 and FP32'
assert not (dtype == 1 and impl == "opt"), 'FP8 not supported in' \
' optimized implementation' \
' (switch to opt_ex)'

def emit_header(self, **kwargs):
header = [super().emit_header()]
self.load_params(kwargs)

# Validate parameters
self.validate_config(**kwargs)
self.validate()

M, N, K = kwargs['M'], kwargs['N'], kwargs['K']

Expand All @@ -119,7 +111,23 @@ def emit_header(self, **kwargs):

cfg = {
'prec': prec,
**kwargs,
'setup_ssr': kwargs['setup_ssr'],
'parallelize_m': kwargs['parallelize_m'],
'parallelize_k': kwargs['parallelize_k'],
'm_tiles': kwargs['m_tiles'],
'n_tiles': kwargs['n_tiles'],
'k_tiles': kwargs['k_tiles'],
'load_a': kwargs['load_a'],
'load_b': kwargs['load_b'],
'load_c': kwargs['load_c'],
'transa': kwargs['transa'],
'transb': kwargs['transb'],
'M': M,
'N': N,
'K': K,
'alpha': kwargs['alpha'],
'beta': kwargs['beta'],
'gemm_fp': kwargs['gemm_fp'],
'a': a_uid,
'b': b_uid,
'c': c_uid,
Expand Down
6 changes: 2 additions & 4 deletions sw/blas/gemm/scripts/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,10 @@

import numpy as np
import sys
from pathlib import Path
from datagen import GemmDataGen

sys.path.append(str(Path(__file__).parent / '../../../../util/sim/'))
from verif_utils import Verifier # noqa: E402
from data_utils import ctype_from_precision_t # noqa: E402
from snitch.util.sim.verif_utils import Verifier
from snitch.util.sim.data_utils import ctype_from_precision_t


class GemmVerifier(Verifier):
Expand Down
Loading
Loading