Skip to content

Commit

Permalink
py: Lint python files
Browse files Browse the repository at this point in the history
  • Loading branch information
fischeti committed Jan 3, 2024
1 parent 7257365 commit 7857c72
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 113 deletions.
49 changes: 26 additions & 23 deletions util/flit_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,23 @@
# Tim Fischer <[email protected]>

import argparse
import hjson
import pathlib
import math
import hjson
from jsonref import JsonRef
from mako.lookup import TemplateLookup

AXI_CHANNELS = ["aw", "w", "b", "ar", "r"]

templates = TemplateLookup(directories=[pathlib.Path(__file__).parent],
output_encoding="utf-8")
templates = TemplateLookup(directories=[pathlib.Path(__file__).parent], output_encoding="utf-8")


def clog2(x: int) -> int:
"""Compute the ceil of the log2 of x."""
return int(math.ceil(math.log(x, 2)))


def get_axi_chs(channel_mapping: dict, **kwargs) -> list:
def get_axi_chs(channel_mapping: dict, **_kwargs) -> list:
"""Return all the AXI channels."""
channels = []
for axi_chs in channel_mapping.values():
Expand All @@ -34,7 +33,7 @@ def get_axi_chs(channel_mapping: dict, **kwargs) -> list:
return channels


def get_inverted_mapping(channel_mapping: dict, **kwargs) -> dict:
def get_inverted_mapping(channel_mapping: dict, **_kwargs) -> dict:
"""Return the mapping of the link."""
mappings = {}
for phys_ch, ch_types in channel_mapping.items():
Expand All @@ -45,6 +44,7 @@ def get_inverted_mapping(channel_mapping: dict, **kwargs) -> dict:


def get_axi_channel_sizes(aw: int, dw: int, iw: int, uw: int) -> dict:
# pylint: disable=too-many-locals
"""Compute the AXI channel size in bits."""

# Constant widths
Expand All @@ -54,7 +54,7 @@ def get_axi_channel_sizes(aw: int, dw: int, iw: int, uw: int) -> dict:
prot = 3
qos = 4
region = 4
len = 8
length = 8
size = 3
atop = 6
last = 1
Expand All @@ -64,27 +64,29 @@ def get_axi_channel_sizes(aw: int, dw: int, iw: int, uw: int) -> dict:
iw = max(iw.values())

axi_ch_size = {}
axi_ch_size["aw"] = iw + aw + len + size + burst + lock + cache + \
prot + qos + region + atop + uw
axi_ch_size["w"] = dw + dw//8 + last + uw
axi_ch_size["aw"] = (
iw + aw + length + size + burst + lock + cache + prot + qos + region + atop + uw
)
axi_ch_size["w"] = dw + dw // 8 + last + uw
axi_ch_size["b"] = iw + resp + uw
axi_ch_size["ar"] = iw + aw + len + size + burst + lock + cache + \
prot + qos + region + uw
axi_ch_size["ar"] = iw + aw + length + size + burst + lock + cache + prot + qos + region + uw
axi_ch_size["r"] = iw + dw + resp + last + uw

return axi_ch_size


def get_link_sizes(channel_mapping: dict, protocols: list, **kwargs) -> dict:
def get_link_sizes(channel_mapping: dict, protocols: list, **_kwargs) -> dict:
"""Infer the link sizes AXI channels and the mapping."""
link_sizes = {}
for phys_ch, axi_chs in channel_mapping.items():
# Get all protocols that use this channel
used_protocols = [p for p in protocols if p['name'] in axi_chs and p['direction'] == 'input']
used_protocols = [
p for p in protocols if p["name"] in axi_chs and p["direction"] == "input"
]
# Get only the exact AXI channels that are used by the link
used_axi_chs = [axi_chs[p['name']] for p in used_protocols]
used_axi_chs = [axi_chs[p["name"]] for p in used_protocols]
# Get the sizes of the AXI channels
axi_ch_sizes = [get_axi_channel_sizes(**p['params']) for p in used_protocols]
axi_ch_sizes = [get_axi_channel_sizes(**p["params"]) for p in used_protocols]
link_message_sizes = []
for used_axi_ch, axi_ch_size in zip(used_axi_chs, axi_ch_sizes):
link_message_sizes += [axi_ch_size[ch] for ch in used_axi_ch]
Expand All @@ -96,21 +98,22 @@ def get_link_sizes(channel_mapping: dict, protocols: list, **kwargs) -> dict:
def main():
"""Generate a flit packet package."""

parser = argparse.ArgumentParser(
description="Generate flit files for a given configuration")
parser.add_argument("--config", "-c", type=pathlib.Path, required=True, help="Path to the config file")
parser = argparse.ArgumentParser(description="Generate flit files for a given configuration")
parser.add_argument(
"--config", "-c", type=pathlib.Path, required=True, help="Path to the config file"
)

args = parser.parse_args()

# Read HJSON description of System.
with open(args.config, "r") as f:
with open(args.config, "r", encoding="utf-8") as f:
cfg = JsonRef.replace_refs(hjson.load(f))

kwargs = cfg
kwargs['axi_channels'] = get_axi_chs(**kwargs)
kwargs['inv_map'] = get_inverted_mapping(**kwargs)
kwargs['get_axi_channel_sizes'] = get_axi_channel_sizes
kwargs['link_sizes'] = get_link_sizes(**kwargs)
kwargs["axi_channels"] = get_axi_chs(**kwargs)
kwargs["inv_map"] = get_inverted_mapping(**kwargs)
kwargs["get_axi_channel_sizes"] = get_axi_channel_sizes
kwargs["link_sizes"] = get_link_sizes(**kwargs)

tpl = templates.get_template("floo_flit_pkg.sv.mako")
print(tpl.render_unicode(**kwargs))
Expand Down
189 changes: 99 additions & 90 deletions util/gen_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
NUM_X = 4
NUM_Y = 4

data_widths = {'wide': 512, 'narrow': 64}
data_widths = {"wide": 512, "narrow": 64}


def clog2(x: int):
Expand All @@ -24,18 +24,21 @@ def clog2(x: int):

def get_xy_base_addr(x: int, y: int):
"""Get the address of a tile in the mesh."""
assert (x <= NUM_X and y <= NUM_Y)
return (x + 2**clog2(NUM_X+1)*y)*MEM_SIZE


def gen_job_str(length: int,
src_addr: int,
dst_addr: int,
max_src_burst_size: int = 256,
max_dst_burst_size: int = 256,
r_aw_decouple: bool = False,
r_w_decouple: bool = False,
num_errors: int = 0):
assert x <= NUM_X and y <= NUM_Y
return (x + 2 ** clog2(NUM_X + 1) * y) * MEM_SIZE


def gen_job_str(
length: int,
src_addr: int,
dst_addr: int,
max_src_burst_size: int = 256,
max_dst_burst_size: int = 256,
r_aw_decouple: bool = False,
r_w_decouple: bool = False,
num_errors: int = 0,
):
# pylint: disable=too-many-arguments
"""Generate a single job."""
job_str = ""
job_str += f"{int(length)}\n"
Expand All @@ -49,137 +52,143 @@ def gen_job_str(length: int,
return job_str


def emit_jobs(jobs, out_dir, name, id):
def emit_jobs(jobs, out_dir, name, idx):
"""Emit jobs to file."""
# Generate directory if it does not exist
if not os.path.exists(out_dir):
os.makedirs(out_dir)
with open(f'{out_dir}/{name}_{id}.txt', 'w', encoding='utf-8') as job_file:
with open(f"{out_dir}/{name}_{idx}.txt", "w", encoding="utf-8") as job_file:
job_file.write(jobs)
job_file.close()


def gen_chimney2chimney_traffic(narrow_burst_length: int = 16,
num_narrow_bursts: int = 16,
rw: str = 'write',
bidir: bool = False,
out_dir: str = 'jobs',
**kwargs):
def gen_chimney2chimney_traffic(
narrow_burst_length: int = 16,
num_narrow_bursts: int = 16,
rw: str = "write",
bidir: bool = False,
out_dir: str = "jobs"
):
"""Generate Chimney to Chimney traffic."""
num_masters = 2
for i in range(num_masters):
jobs = ""
if bidir or i == 0:
for j in range(num_narrow_bursts):
length = narrow_burst_length*data_widths['narrow']/8
assert (length <= MEM_SIZE)
src_addr = 0 if rw == 'write' else MEM_SIZE
dst_addr = MEM_SIZE if rw == 'write' else 0
for _ in range(num_narrow_bursts):
length = narrow_burst_length * data_widths["narrow"] / 8
assert length <= MEM_SIZE
src_addr = 0 if rw == "write" else MEM_SIZE
dst_addr = MEM_SIZE if rw == "write" else 0
job_str = gen_job_str(length, src_addr, dst_addr)
jobs += job_str
emit_jobs(jobs, out_dir, 'chimney2chimney', i)


def gen_nw_chimney2chimney_traffic(narrow_burst_length: int,
wide_burst_length: int,
num_narrow_bursts: int,
num_wide_bursts: int,
rw: str,
bidir: bool,
out_dir: str,
**kwargs):
emit_jobs(jobs, out_dir, "chimney2chimney", i)


def gen_nw_chimney2chimney_traffic(
narrow_burst_length: int,
wide_burst_length: int,
num_narrow_bursts: int,
num_wide_bursts: int,
rw: str,
bidir: bool,
out_dir: str
):
# pylint: disable=too-many-arguments
"""Generate Narrow Wide Chimney to Chimney traffic."""
num_masters = 2
for i in range(num_masters):
wide_jobs = ""
narrow_jobs = ""
wide_length = wide_burst_length*data_widths['wide']/8
narrow_length = narrow_burst_length*data_widths['narrow']/8
assert (wide_length <= MEM_SIZE and narrow_length <= MEM_SIZE)
src_addr = 0 if rw == 'write' else MEM_SIZE
dst_addr = MEM_SIZE if rw == 'write' else 0
wide_length = wide_burst_length * data_widths["wide"] / 8
narrow_length = narrow_burst_length * data_widths["narrow"] / 8
assert wide_length <= MEM_SIZE and narrow_length <= MEM_SIZE
src_addr = 0 if rw == "write" else MEM_SIZE
dst_addr = MEM_SIZE if rw == "write" else 0
if bidir or i == 0:
for j in range(num_wide_bursts):
for _ in range(num_wide_bursts):
wide_jobs += gen_job_str(wide_length, src_addr, dst_addr)
for j in range(num_narrow_bursts):
for _ in range(num_narrow_bursts):
narrow_jobs += gen_job_str(narrow_length, src_addr, dst_addr)
emit_jobs(wide_jobs, out_dir, 'nw_chimney2chimney', i)
emit_jobs(narrow_jobs, out_dir, 'nw_chimney2chimney', i+100)


def gen_mesh_traffic(narrow_burst_length: int,
wide_burst_length: int,
num_narrow_bursts: int,
num_wide_bursts: int,
rw: str,
type: str,
out_dir: str,
**kwargs):
emit_jobs(wide_jobs, out_dir, "nw_chimney2chimney", i)
emit_jobs(narrow_jobs, out_dir, "nw_chimney2chimney", i + 100)


def gen_mesh_traffic(
narrow_burst_length: int,
wide_burst_length: int,
num_narrow_bursts: int,
num_wide_bursts: int,
rw: str,
traffic_type: str,
out_dir: str
):
# pylint: disable=too-many-arguments, too-many-locals
"""Generate Mesh traffic."""
for x in range(1, NUM_X+1):
for y in range(1, NUM_Y+1):
for x in range(1, NUM_X + 1):
for y in range(1, NUM_Y + 1):
wide_jobs = ""
narrow_jobs = ""
wide_length = wide_burst_length*data_widths['wide']/8
narrow_length = narrow_burst_length*data_widths['narrow']/8
assert (wide_length <= MEM_SIZE and narrow_length <= MEM_SIZE)
if type == 'hbm':
wide_length = wide_burst_length * data_widths["wide"] / 8
narrow_length = narrow_burst_length * data_widths["narrow"] / 8
assert wide_length <= MEM_SIZE and narrow_length <= MEM_SIZE
if traffic_type == "hbm":
# Tile x=0 are the HBM channels
# Each core read from the channel of its y coordinate
hbm_addr = get_xy_base_addr(0, y)
local_addr = get_xy_base_addr(x, y)
src_addr = hbm_addr if rw == 'read' else local_addr
dst_addr = local_addr if rw == 'read' else hbm_addr
elif type == 'random':
src_addr = hbm_addr if rw == "read" else local_addr
dst_addr = local_addr if rw == "read" else hbm_addr
elif traffic_type == "random":
local_addr = get_xy_base_addr(x, y)
ext_addr = get_xy_base_addr(random.randint(1, NUM_X), random.randint(1, NUM_Y))
src_addr = ext_addr if rw == 'read' else local_addr
dst_addr = local_addr if rw == 'read' else ext_addr
elif type == 'onehop':
src_addr = ext_addr if rw == "read" else local_addr
dst_addr = local_addr if rw == "read" else ext_addr
elif traffic_type == "onehop":
if not (x == 1 and y == 1):
wide_length = 0
narrow_length = 0
src_addr = 0
dst_addr = 0
else:
local_addr = get_xy_base_addr(x, y)
ext_addr = get_xy_base_addr(x, y+1)
src_addr = ext_addr if rw == 'read' else local_addr
dst_addr = local_addr if rw == 'read' else ext_addr
ext_addr = get_xy_base_addr(x, y + 1)
src_addr = ext_addr if rw == "read" else local_addr
dst_addr = local_addr if rw == "read" else ext_addr
else:
raise ValueError(f'Unknown traffic type: {type}')
for j in range(num_wide_bursts):
raise ValueError(f"Unknown traffic type: {traffic_type}")
for _ in range(num_wide_bursts):
wide_jobs += gen_job_str(wide_length, src_addr, dst_addr)
for j in range(num_narrow_bursts):
for _ in range(num_narrow_bursts):
narrow_jobs += gen_job_str(narrow_length, src_addr, dst_addr)
emit_jobs(wide_jobs, out_dir, 'mesh', x + (y-1)*NUM_X)
emit_jobs(narrow_jobs, out_dir, 'mesh', x + (y-1)*NUM_X + 100)
emit_jobs(wide_jobs, out_dir, "mesh", x + (y - 1) * NUM_X)
emit_jobs(narrow_jobs, out_dir, "mesh", x + (y - 1) * NUM_X + 100)


def main():

"""Main function."""
# parse arguments
parser = argparse.ArgumentParser()
parser.add_argument('--out_dir', type=str, default='test/jobs')
parser.add_argument('--num_narrow_bursts', type=int, default=10)
parser.add_argument('--num_wide_bursts', type=int, default=100)
parser.add_argument('--narrow_burst_length', type=int, default=1)
parser.add_argument('--wide_burst_length', type=int, default=16)
parser.add_argument('--bidir', action='store_true')
parser.add_argument('--tb', type=str, default='dma_mesh')
parser.add_argument('--type', type=str, default='random')
parser.add_argument('--rw', type=str, default='read')
parser.add_argument("--out_dir", type=str, default="test/jobs")
parser.add_argument("--num_narrow_bursts", type=int, default=10)
parser.add_argument("--num_wide_bursts", type=int, default=100)
parser.add_argument("--narrow_burst_length", type=int, default=1)
parser.add_argument("--wide_burst_length", type=int, default=16)
parser.add_argument("--bidir", action="store_true")
parser.add_argument("--tb", type=str, default="dma_mesh")
parser.add_argument("--type", type=str, default="random")
parser.add_argument("--rw", type=str, default="read")
args = parser.parse_args()

kwargs = vars(args)

if args.tb == 'chimney2chimney':
if args.tb == "chimney2chimney":
gen_chimney2chimney_traffic(**kwargs)
elif args.tb == 'nw_chimney2chimney':
elif args.tb == "nw_chimney2chimney":
gen_nw_chimney2chimney_traffic(**kwargs)
elif args.tb == 'dma_mesh':
elif args.tb == "dma_mesh":
gen_mesh_traffic(**kwargs)
else:
raise ValueError(f'Unknown testbench: {args.tb}')
raise ValueError(f"Unknown testbench: {args.tb}")


if __name__ == "__main__":
Expand Down

0 comments on commit 7857c72

Please sign in to comment.