Typing update (#87)
* update pre-commit version

* move DiGraphWrapper from Stream to ZigZag

* fix typeguard issues

* disable pylint issue with super() method

* add mem sharing group to mem instance

* fix bug: shared_memory_group_id not stored in instance

* define class instance types in class body, to ensure they are included in stubs

* memory_sharing_list -> mem_sharing_list

* change open_yaml return type
RobinGeens authored Oct 3, 2024
1 parent 5a9c141 commit cd78c23
Showing 20 changed files with 290 additions and 182 deletions.
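The recurring pattern across these diffs is the "define class instance types in class body" bullet: attribute annotations are declared once at class level rather than only inside __init__, which, per the commit message, ensures they are included in generated stubs. A minimal, hypothetical illustration of the before/after shape (the class and attribute names below are invented for this example and are not taken from the diff):

    class Before:
        def __init__(self, core_id: int):
            # annotation lives only inside __init__; stub generation may miss it
            self.id: int = core_id


    class After:
        id: int  # class-body annotation, visible to stub generators and type checkers

        def __init__(self, core_id: int):
            self.id = core_id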
2 changes: 1 addition & 1 deletion .github/workflows/python-test.yml
@@ -9,7 +9,7 @@ jobs:
- name: Set up Python '3.11'
uses: actions/setup-python@v4
with:
python-version: '3.11'
python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -4,7 +4,7 @@ repos:
# hooks:
# - id: check-yaml
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.278
rev: v0.6.4
hooks:
- id: ruff
types_or: [python, pyi]
@@ -18,7 +18,7 @@ repos:
"120",
]
- repo: https://github.com/psf/black
rev: 23.3.0
rev: 24.8.0
hooks:
- id: black
args: [--line-length, "120"]
4 changes: 1 addition & 3 deletions zigzag/cost_model/cost_model.py
@@ -234,9 +234,7 @@ def __jsonrepr__(self):
"layer": (
self.layer
if isinstance(self, CostModelEvaluation)
else self.layer_ids
if isinstance(self, CumulativeCME)
else None
else self.layer_ids if isinstance(self, CumulativeCME) else None
),
"spatial_mapping": (self.spatial_mapping_int if isinstance(self, CostModelEvaluation) else None),
"temporal_mapping": (self.temporal_mapping if isinstance(self, CostModelEvaluation) else None),
63 changes: 32 additions & 31 deletions zigzag/hardware/architecture/core.py
@@ -13,21 +13,30 @@ class Core:
on top.
"""

id: int
operational_array: OperationalArrayABC
memory_hierarchy: MemoryHierarchy
dataflows: SpatialMapping | None
mem_hierarchy_dict: dict[MemoryOperand, list[MemoryLevel]]
mem_size_dict: dict[MemoryOperand, list[int]]
mem_r_bw_dict: dict[MemoryOperand, list[int]]
mem_w_bw_dict: dict[MemoryOperand, list[int]]
mem_r_bw_min_dict: dict[MemoryOperand, list[int]]
mem_w_bw_min_dict: dict[MemoryOperand, list[int]]
mem_sharing_list: list[dict[MemoryOperand, int]]

def __init__(
self,
core_id: int,
operational_array: OperationalArrayABC,
memory_hierarchy: MemoryHierarchy,
dataflows: SpatialMapping | None = None,
):
self.id = core_id
self.id = core_id
self.operational_array = operational_array
self.memory_hierarchy = memory_hierarchy
self.mem_hierarchy_dict: dict[MemoryOperand, list[MemoryLevel]] = {}

self.dataflows = dataflows
self.mem_hierarchy_dict: dict[MemoryOperand, list[MemoryLevel]] = {}
self.mem_hierarchy_dict = {}

self.dataflows = dataflows
self.recalculate_memory_hierarchy_information()
@@ -40,63 +49,55 @@ def get_memory_level(self, mem_op: MemoryOperand, mem_lv: int) -> MemoryLevel:

def recalculate_memory_hierarchy_information(self):
self.__generate_memory_hierarchy_dict()
self.__generate_memory_sharing_list()
self.__generate_mem_sharing_list()

def __generate_memory_hierarchy_dict(self):
mem_operands = self.memory_hierarchy.nb_levels.keys()
mem_hierarchy_dict: dict[MemoryOperand, list[MemoryLevel]] = {}
mem_size_dict: dict[MemoryOperand, list[int]] = {}
mem_r_bw_dict: dict[MemoryOperand, list[int]] = {}
mem_w_bw_dict: dict[MemoryOperand, list[int]] = {}
mem_r_bw_min_dict: dict[MemoryOperand, list[int]] = {}
mem_w_bw_min_dict: dict[MemoryOperand, list[int]] = {}
self.mem_hierarchy_dict = {}
self.mem_size_dict = {}
self.mem_r_bw_dict = {}
self.mem_w_bw_dict = {}
self.mem_r_bw_min_dict = {}
self.mem_w_bw_min_dict = {}
for mem_op in mem_operands:
mem_hierarchy_dict[mem_op] = [
self.mem_hierarchy_dict[mem_op] = [
node for node in self.memory_hierarchy.topological_sort() if mem_op in node.operands
]
mem_size_dict[mem_op] = [
self.mem_size_dict[mem_op] = [
node.memory_instance.size
for node in self.memory_hierarchy.topological_sort()
if mem_op in node.operands
]
mem_r_bw_dict[mem_op] = [
self.mem_r_bw_dict[mem_op] = [
node.memory_instance.r_bw
for node in self.memory_hierarchy.topological_sort()
if mem_op in node.operands
]
mem_w_bw_dict[mem_op] = [
self.mem_w_bw_dict[mem_op] = [
node.memory_instance.w_bw
for node in self.memory_hierarchy.topological_sort()
if mem_op in node.operands
]
mem_r_bw_min_dict[mem_op] = [
self.mem_r_bw_min_dict[mem_op] = [
node.memory_instance.r_bw_min
for node in self.memory_hierarchy.topological_sort()
if mem_op in node.operands
]
mem_w_bw_min_dict[mem_op] = [
self.mem_w_bw_min_dict[mem_op] = [
node.memory_instance.w_bw_min
for node in self.memory_hierarchy.topological_sort()
if mem_op in node.operands
]
self.mem_hierarchy_dict = mem_hierarchy_dict
self.mem_size_dict = mem_size_dict
self.mem_r_bw_dict = mem_r_bw_dict
self.mem_w_bw_dict = mem_w_bw_dict
self.mem_r_bw_min_dict = mem_r_bw_min_dict
self.mem_w_bw_min_dict = mem_w_bw_min_dict

def __generate_memory_sharing_list(self):

def __generate_mem_sharing_list(self):
"""! Generates a list of dictionary that indicates which operand's which memory levels are sharing the same
physical memory"""
memory_sharing_list: list[dict[MemoryOperand, int]] = []
self.mem_sharing_list = []
for mem_lv in self.mem_hierarchy_dict.values():
for mem in mem_lv:
operand_mem_share = mem.mem_level_of_operands
if len(operand_mem_share) > 1 and operand_mem_share not in memory_sharing_list:
memory_sharing_list.append(operand_mem_share)

self.mem_sharing_list = memory_sharing_list
if len(operand_mem_share) > 1 and operand_mem_share not in self.mem_sharing_list:
self.mem_sharing_list.append(operand_mem_share)

def get_top_memory_instance(self, mem_op: MemoryOperand) -> MemoryInstance:
if mem_op not in self.memory_hierarchy.get_operands():
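For readers unfamiliar with the structure being built in __generate_mem_sharing_list, here is a self-contained sketch of the same logic using plain strings instead of MemoryOperand/MemoryLevel objects; the operand names and level indices below are invented purely for illustration:

    # Each physical memory knows which operands it serves and at which level of that
    # operand's hierarchy (mem_level_of_operands). Memories serving more than one
    # operand are recorded once in mem_sharing_list.
    mem_level_of_operands_per_memory = [
        {"I1": 0},                    # register file serving a single operand: not shared
        {"I2": 1, "O": 1},            # one SRAM serving two operands: shared
        {"I1": 1, "I2": 2, "O": 2},   # top-level DRAM serving all operands: shared
    ]

    mem_sharing_list: list[dict[str, int]] = []
    for operand_mem_share in mem_level_of_operands_per_memory:
        if len(operand_mem_share) > 1 and operand_mem_share not in mem_sharing_list:
            mem_sharing_list.append(operand_mem_share)

    print(mem_sharing_list)
    # [{'I2': 1, 'O': 1}, {'I1': 1, 'I2': 2, 'O': 2}]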
51 changes: 19 additions & 32 deletions zigzag/hardware/architecture/memory_hierarchy.py
@@ -1,23 +1,27 @@
from collections import defaultdict
from typing import Any, Iterator

import networkx as nx
from networkx import DiGraph
from typing import Any

from zigzag.datatypes import MemoryOperand
from zigzag.hardware.architecture.memory_instance import MemoryInstance
from zigzag.hardware.architecture.memory_level import MemoryLevel, ServedMemDimensions
from zigzag.hardware.architecture.memory_port import PortAllocation
from zigzag.hardware.architecture.operational_array import OperationalArrayABC
from zigzag.utils import json_repr_handler
from zigzag.utils import DiGraphWrapper, json_repr_handler


class MemoryHierarchy(DiGraph):
class MemoryHierarchy(DiGraphWrapper[MemoryLevel]):
"""! Class that represents a memory hierarchy as a directed networkx graph.
The memory hierarchy graph is directed, with the root nodes representing the lowest level
in the memory hierarchy.
"""

name: str
operational_array: OperationalArrayABC
operands: set[MemoryOperand]
nb_levels: dict[MemoryOperand, int]
mem_level_list: list[MemoryLevel]
memory_level_id: int

def __init__(
self,
operational_array: OperationalArrayABC,
@@ -31,7 +35,7 @@ def __init__(
@param nodes: a list of MemoryLevels. Entries need to be provided from lowest to highest memory level.
"""
super().__init__(**attr) # type: ignore
self.name: str = name
self.name = name # type: ignore
self.operational_array = operational_array
# Initialize the set that will store all memory operands
self.operands: set[MemoryOperand] = set()
@@ -71,7 +75,7 @@ def add_memory(
# Compute which memory level this is for all the operands
mem_level_of_operands: dict[MemoryOperand, int] = {}
for mem_op in operands:
nb_levels_so_far = len([node for node in self.memory_nodes if mem_op in node.operands])
nb_levels_so_far = len([node for node in self.node_list if mem_op in node.operands])
mem_level_of_operands[mem_op] = nb_levels_so_far

memory_level = MemoryLevel(
@@ -94,11 +98,11 @@ def add_memory(
to_edge_from.add(m)

# Add the node to the graph
self.__add_node(memory_level)
self.add_node(memory_level)

for sink_node in to_edge_from:
# Add an edge from this sink node to the current node
self.__add_edge(sink_node, memory_level)
self.add_edge(sink_node, memory_level)

def get_memory_levels(self, mem_op: MemoryOperand) -> list[MemoryLevel]:
"""! Returns a list of memories in the memory hierarchy for the memory operand.
@@ -113,19 +117,19 @@ def get_operands(self) -> set[MemoryOperand]:

def get_inner_memories(self) -> list[MemoryLevel]:
"""! Returns the inner-most memory levels for all memory operands."""
return [node for node, in_degree in self.in_degree() if in_degree == 0] # type: ignore
return [node for node, in_degree in self.in_degree() if in_degree == 0]

def get_outer_memories(self) -> list[MemoryLevel]:
"""! Returns the outer-most memory levels for all memory operands."""
return [node for node, out_degree in self.out_degree() if out_degree == 0] # type: ignore
return [node for node, out_degree in self.out_degree() if out_degree == 0]

def get_top_memories(self) -> tuple[list[MemoryLevel], int]:
"""! Returns the 'top'-most MemoryLevels, where 'the' level of MemoryLevel is considered to be the largest
level it has across its assigned operands
@return (list_of_memories_on_top_level, top_level)
"""
level_to_mems: defaultdict[int, list[MemoryLevel]] = defaultdict(lambda: [])
for node in self.memory_nodes:
for node in self.node_list:
level_to_mems[max(node.mem_level_of_operands.values())].append(node)
top_level = max(level_to_mems.keys())
return level_to_mems[top_level], top_level
@@ -136,7 +140,7 @@ def get_operator_top_level(self, operand: MemoryOperand) -> tuple[list[MemoryLev
'The' level of a MemoryLevel is considered to be the largest level it has across its assigned operands.
"""
level_to_mems: dict[int, list[MemoryLevel]] = defaultdict(lambda: [])
for node in self.memory_nodes:
for node in self.node_list:
if operand in node.operands:
level_to_mems[max(node.mem_level_of_operands.values())].append(node)
top_level = max(level_to_mems.keys()) if level_to_mems else -1
@@ -151,23 +155,6 @@ def get_operand_top_level(self, operand: MemoryOperand) -> MemoryLevel:
return mem
raise ValueError(f"Operand {operand} not found in any of the memory instances.")

def topological_sort(self) -> Iterator[MemoryLevel]:
"""! Wrap `DiGraph.topological_sort` with correct type annotation"""
return nx.topological_sort(self) # type: ignore

def __add_node(self, node: MemoryLevel) -> None:
"""! Wrap `DiGraph.add_node` with correct type annotation"""
self.add_node(node) # type: ignore

def __add_edge(self, sink_node: MemoryLevel, source_node: MemoryLevel):
"""! Wrap `DiGraph.add_edge` with correct type annotation"""
self.add_edge(sink_node, source_node) # type: ignore

@property
def memory_nodes(self) -> list[MemoryLevel]:
"""! Wrap `DiGraph.nodes()` with custom type annotation"""
return list(self.nodes()) # type: ignore

def __jsonrepr__(self):
"""! JSON Representation of this object to save it to a json file."""
return json_repr_handler(list(self.topological_sort()))
Expand All @@ -176,5 +163,5 @@ def __eq__(self, other: object) -> bool:
return (
isinstance(other, MemoryHierarchy)
and self.nb_levels == other.nb_levels
and all([self_ml == other_ml for self_ml, other_ml in zip(self.memory_nodes, other.memory_nodes)])
and all([self_ml == other_ml for self_ml, other_ml in zip(self.node_list, other.node_list)])
)
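MemoryHierarchy now derives from DiGraphWrapper[MemoryLevel], which the commit message says was moved from Stream into ZigZag; that is why the per-class add_node/add_edge/topological_sort/memory_nodes wrappers above could be deleted. The wrapper itself is not part of this diff, so the following is only a minimal sketch of what such a generically typed wrapper could look like, assuming just the members used in this file (in_degree, out_degree and nodes are inherited from networkx.DiGraph; the method names add_node, add_edge, topological_sort and node_list mirror the calls in the diff):

    from typing import Generic, Iterator, TypeVar

    import networkx as nx
    from networkx import DiGraph

    T = TypeVar("T")


    class DiGraphWrapper(DiGraph, Generic[T]):
        """Hypothetical sketch: a thin, typed convenience layer over networkx.DiGraph."""

        def add_node(self, node: T) -> None:
            # narrows the untyped networkx signature to the node type T
            super().add_node(node)  # type: ignore

        def add_edge(self, u: T, v: T) -> None:
            super().add_edge(u, v)  # type: ignore

        def topological_sort(self) -> Iterator[T]:
            return nx.topological_sort(self)  # type: ignore

        @property
        def node_list(self) -> list[T]:
            return list(self.nodes())  # type: ignore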
25 changes: 24 additions & 1 deletion zigzag/hardware/architecture/memory_instance.py
@@ -5,6 +5,24 @@
class MemoryInstance:
"""A single instance within the memory hierarchy, without information about connectivity."""

name: str
size: int
r_bw: int
w_bw: int
r_cost: float
w_cost: float
area: float
r_port: int
w_port: int
rw_port: int
latency: int
min_r_granularity: int
min_w_granularity: int
mem_type: str
auto_cost_extraction: bool
double_buffering_support: bool
shared_memory_group_id: int

def __init__(
self,
name: str,
@@ -23,6 +41,7 @@ def __init__(
mem_type: str = "sram",
auto_cost_extraction: bool = False,
double_buffering_support: bool = False,
shared_memory_group_id: int = -1,
):
"""
Collect all the basic information of a physical memory module.
@@ -42,6 +61,8 @@ def __init__(
@param mem_type (str): The type of memory. Used for CACTI cost extraction.
@param auto_cost_extraction (bool): Automatically extract the read cost, write cost and area using CACTI.
@param double_buffering_support (bool): Support for double buffering on this memory instance.
@param shared_memory_group_id: used to indicate whether two MemoryInstance instances represent the same, shared
memory between two cores (feature used in Stream).
"""
if auto_cost_extraction:
cacti_parser = CactiParser()
@@ -68,6 +89,7 @@ def __init__(
self.rw_port_nb = rw_port
self.latency = latency
self.double_buffering_support = double_buffering_support
self.shared_memory_group_id = shared_memory_group_id

self.r_bw_min: int = min_r_granularity if min_r_granularity is not None else r_bw
self.w_bw_min: int = min_w_granularity if min_w_granularity is not None else w_bw
@@ -84,7 +106,8 @@ def __eq__(self, other: object) -> bool:
return isinstance(other, MemoryInstance) and self.__dict__ == other.__dict__

def __hash__(self):
return id(self) # unique for every object within its lifetime
# id(self) # unique for every object within its lifetime
return hash(frozenset(self.__dict__.values()))

def __str__(self):
return f"MemoryInstance({self.name})"
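The __hash__ change above swaps identity hashing for content hashing, so that two separately constructed, value-equal instances also hash equal. A standalone illustration of the effect, using a simplified stand-in class rather than the real MemoryInstance (attribute set and values are invented for the example):

    class Mem:
        """Simplified stand-in for MemoryInstance, reduced to three attributes."""

        def __init__(self, name: str, size: int, shared_memory_group_id: int = -1):
            self.name = name
            self.size = size
            self.shared_memory_group_id = shared_memory_group_id

        def __eq__(self, other: object) -> bool:
            return isinstance(other, Mem) and self.__dict__ == other.__dict__

        def __hash__(self):
            # old behaviour was `return id(self)`: equal-valued objects hashed differently
            return hash(frozenset(self.__dict__.values()))


    a = Mem("sram_64KB", 65536 * 8, shared_memory_group_id=3)
    b = Mem("sram_64KB", 65536 * 8, shared_memory_group_id=3)
    assert a == b and hash(a) == hash(b)
    assert len({a, b}) == 1  # with the old id()-based hash this set would have had size 2

Hashing a frozenset of attribute values requires every attribute to be hashable and ignores which attribute holds which value, but that is acceptable for a hash: equality still performs the precise per-attribute comparison.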
18 changes: 16 additions & 2 deletions zigzag/hardware/architecture/memory_level.py
@@ -48,6 +48,21 @@ class MemoryLevel:
"""Represents a single memory in the memory hierarchy, consisting of a memory instance and connectivity
information"""

memory_instance: MemoryInstance
operands: list[MemoryOperand]
mem_level_of_operands: dict[MemoryOperand, int]
oa_dim_sizes: dict[OADimension, int]
port_alloc: PortAllocation
served_dimensions: ServedMemDimensions
id: int
name: str
port_alloc_raw: PortAllocation
read_energy: float
write_energy: float
read_bw: float
write_bw: float
port_list: list[MemoryPort]

def __init__(
self,
memory_instance: MemoryInstance,
@@ -64,10 +79,9 @@ def __init__(
"""
self.memory_instance = memory_instance
self.operands = operands
self.operands = operands
self.mem_level_of_operands = mem_level_of_operands
self.oa_dim_sizes = operational_array.dimension_sizes
self.id: int = identifier
self.id = identifier
self.served_dimensions = served_dimensions
self.name = self.memory_instance.name

6 changes: 2 additions & 4 deletions zigzag/opt/loma/multipermute.py
@@ -58,12 +58,10 @@ class PermutationConstraint(ABC):
"""! An abstract class to represent a constraint on a permutation."""

@abstractmethod
def is_valid(self, permutation: list[Any]) -> bool:
...
def is_valid(self, permutation: list[Any]) -> bool: ...

@abstractmethod
def is_empty(self) -> bool:
...
def is_empty(self) -> bool: ...


class StaticPositionsConstraint(PermutationConstraint):
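The two abstract methods above (now formatted with Black 24's one-line `...` bodies) make up the whole constraint interface. A hypothetical concrete constraint is sketched below for context; it is invented for illustration and may differ from the real StaticPositionsConstraint in this file. The ABC is restated so the snippet is self-contained:

    from abc import ABC, abstractmethod
    from typing import Any


    class PermutationConstraint(ABC):
        """! An abstract class to represent a constraint on a permutation."""

        @abstractmethod
        def is_valid(self, permutation: list[Any]) -> bool: ...

        @abstractmethod
        def is_empty(self) -> bool: ...


    class FixedPositions(PermutationConstraint):
        """Pin given elements to fixed indices of the permutation (illustrative only)."""

        def __init__(self, positions: dict[int, Any]):
            self.positions = positions  # index -> required element

        def is_valid(self, permutation: list[Any]) -> bool:
            return all(permutation[i] == elem for i, elem in self.positions.items())

        def is_empty(self) -> bool:
            return not self.positions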