diff --git a/.github/workflows/python-test.yml b/.github/workflows/python-test.yml
index c0d3ca460..ae6deae1b 100644
--- a/.github/workflows/python-test.yml
+++ b/.github/workflows/python-test.yml
@@ -9,7 +9,7 @@ jobs:
       - name: Set up Python '3.11'
         uses: actions/setup-python@v4
         with:
-          python-version: '3.11'
+          python-version: "3.11"
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 710dd2b80..6a29e11c8 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -4,7 +4,7 @@ repos:
   #   hooks:
   #     - id: check-yaml
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.0.278
+    rev: v0.6.4
     hooks:
       - id: ruff
        types_or: [python, pyi]
@@ -18,7 +18,7 @@ repos:
           "120",
         ]
   - repo: https://github.com/psf/black
-    rev: 23.3.0
+    rev: 24.8.0
    hooks:
      - id: black
        args: [--line-length, "120"]
diff --git a/zigzag/cost_model/cost_model.py b/zigzag/cost_model/cost_model.py
index 89fb73c22..2541d5487 100644
--- a/zigzag/cost_model/cost_model.py
+++ b/zigzag/cost_model/cost_model.py
@@ -234,9 +234,7 @@ def __jsonrepr__(self):
             "layer": (
                 self.layer
                 if isinstance(self, CostModelEvaluation)
-                else self.layer_ids
-                if isinstance(self, CumulativeCME)
-                else None
+                else self.layer_ids if isinstance(self, CumulativeCME) else None
             ),
             "spatial_mapping": (self.spatial_mapping_int if isinstance(self, CostModelEvaluation) else None),
             "temporal_mapping": (self.temporal_mapping if isinstance(self, CostModelEvaluation) else None),
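Editorial note: the cost_model.py hunk above is pure Black 24.x restyling, which collapses a chained conditional expression onto one line when it fits the configured line length. A minimal sketch of the two equivalent layouts (the names here are illustrative, not from the codebase):

```python
# Black <= 23 put "else self.layer_ids" and "if isinstance(...)" on separate
# lines; Black 24.x keeps the chained conditional on one line when it fits.
is_a, is_b, x, y = False, True, "a", "b"
label = x if is_a else y if is_b else None
assert label == "b"  # evaluation is unchanged; only the layout differs
```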
diff --git a/zigzag/hardware/architecture/core.py b/zigzag/hardware/architecture/core.py
index aaa21af5f..7f574afba 100644
--- a/zigzag/hardware/architecture/core.py
+++ b/zigzag/hardware/architecture/core.py
@@ -13,6 +13,18 @@ class Core:
     on top.
     """
 
+    id: int
+    operational_array: OperationalArrayABC
+    memory_hierarchy: MemoryHierarchy
+    dataflows: SpatialMapping | None
+    mem_hierarchy_dict: dict[MemoryOperand, list[MemoryLevel]]
+    mem_size_dict: dict[MemoryOperand, list[int]]
+    mem_r_bw_dict: dict[MemoryOperand, list[int]]
+    mem_w_bw_dict: dict[MemoryOperand, list[int]]
+    mem_r_bw_min_dict: dict[MemoryOperand, list[int]]
+    mem_w_bw_min_dict: dict[MemoryOperand, list[int]]
+    mem_sharing_list: list[dict[MemoryOperand, int]]
+
     def __init__(
         self,
         core_id: int,
@@ -20,14 +32,10 @@ def __init__(
         memory_hierarchy: MemoryHierarchy,
         dataflows: SpatialMapping | None = None,
     ):
-        self.id = core_id
         self.id = core_id
         self.operational_array = operational_array
         self.memory_hierarchy = memory_hierarchy
-        self.mem_hierarchy_dict: dict[MemoryOperand, list[MemoryLevel]] = {}
-
-        self.dataflows = dataflows
-        self.mem_hierarchy_dict: dict[MemoryOperand, list[MemoryLevel]] = {}
+        self.mem_hierarchy_dict = {}
 
         self.dataflows = dataflows
         self.recalculate_memory_hierarchy_information()
@@ -40,63 +48,55 @@ def get_memory_level(self, mem_op: MemoryOperand, mem_lv: int) -> MemoryLevel:
 
     def recalculate_memory_hierarchy_information(self):
         self.__generate_memory_hierarchy_dict()
-        self.__generate_memory_sharing_list()
+        self.__generate_mem_sharing_list()
 
     def __generate_memory_hierarchy_dict(self):
         mem_operands = self.memory_hierarchy.nb_levels.keys()
-        mem_hierarchy_dict: dict[MemoryOperand, list[MemoryLevel]] = {}
-        mem_size_dict: dict[MemoryOperand, list[int]] = {}
-        mem_r_bw_dict: dict[MemoryOperand, list[int]] = {}
-        mem_w_bw_dict: dict[MemoryOperand, list[int]] = {}
-        mem_r_bw_min_dict: dict[MemoryOperand, list[int]] = {}
-        mem_w_bw_min_dict: dict[MemoryOperand, list[int]] = {}
+        self.mem_hierarchy_dict = {}
+        self.mem_size_dict = {}
+        self.mem_r_bw_dict = {}
+        self.mem_w_bw_dict = {}
+        self.mem_r_bw_min_dict = {}
+        self.mem_w_bw_min_dict = {}
         for mem_op in mem_operands:
-            mem_hierarchy_dict[mem_op] = [
+            self.mem_hierarchy_dict[mem_op] = [
                 node for node in self.memory_hierarchy.topological_sort() if mem_op in node.operands
             ]
-            mem_size_dict[mem_op] = [
+            self.mem_size_dict[mem_op] = [
                 node.memory_instance.size
                 for node in self.memory_hierarchy.topological_sort()
                 if mem_op in node.operands
             ]
-            mem_r_bw_dict[mem_op] = [
+            self.mem_r_bw_dict[mem_op] = [
                 node.memory_instance.r_bw
                 for node in self.memory_hierarchy.topological_sort()
                 if mem_op in node.operands
             ]
-            mem_w_bw_dict[mem_op] = [
+            self.mem_w_bw_dict[mem_op] = [
                 node.memory_instance.w_bw
                 for node in self.memory_hierarchy.topological_sort()
                 if mem_op in node.operands
             ]
-            mem_r_bw_min_dict[mem_op] = [
+            self.mem_r_bw_min_dict[mem_op] = [
                 node.memory_instance.r_bw_min
                 for node in self.memory_hierarchy.topological_sort()
                 if mem_op in node.operands
             ]
-            mem_w_bw_min_dict[mem_op] = [
+            self.mem_w_bw_min_dict[mem_op] = [
                 node.memory_instance.w_bw_min
                 for node in self.memory_hierarchy.topological_sort()
                 if mem_op in node.operands
             ]
-        self.mem_hierarchy_dict = mem_hierarchy_dict
-        self.mem_size_dict = mem_size_dict
-        self.mem_r_bw_dict = mem_r_bw_dict
-        self.mem_w_bw_dict = mem_w_bw_dict
-        self.mem_r_bw_min_dict = mem_r_bw_min_dict
-        self.mem_w_bw_min_dict = mem_w_bw_min_dict
-
-    def __generate_memory_sharing_list(self):
+
+    def __generate_mem_sharing_list(self):
         """! Generates a list of dictionary that indicates which operand's which memory levels are sharing the same
         physical memory"""
-        memory_sharing_list: list[dict[MemoryOperand, int]] = []
+        self.mem_sharing_list = []
         for mem_lv in self.mem_hierarchy_dict.values():
             for mem in mem_lv:
                 operand_mem_share = mem.mem_level_of_operands
-                if len(operand_mem_share) > 1 and operand_mem_share not in memory_sharing_list:
-                    memory_sharing_list.append(operand_mem_share)
-
-        self.mem_sharing_list = memory_sharing_list
+                if len(operand_mem_share) > 1 and operand_mem_share not in self.mem_sharing_list:
+                    self.mem_sharing_list.append(operand_mem_share)
 
     def get_top_memory_instance(self, mem_op: MemoryOperand) -> MemoryInstance:
         if mem_op not in self.memory_hierarchy.get_operands():
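Editorial note: the Core changes replace method-local temporaries with instance attributes that are declared once at class level, a pattern type checkers pick up without `__init__` annotation noise. A minimal sketch of the pattern (names are illustrative, not zigzag API):

```python
class Pipeline:
    # Declared (not assigned) here: these become entries in __annotations__
    # for type checkers; they are not class-level values.
    stages: list[str]
    depth: int

    def __init__(self, stages: list[str]) -> None:
        self.stages = stages
        self.depth = len(stages)

assert Pipeline(["fetch", "decode"]).depth == 2
assert "stages" not in Pipeline.__dict__  # annotation only, no class attribute
```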
""" + name: str + operational_array: OperationalArrayABC + operansd: set[MemoryOperand] + nb_levels: dict[MemoryOperand, int] + mem_level_list: list[MemoryLevel] + memory_level_id: int + def __init__( self, operational_array: OperationalArrayABC, @@ -31,7 +35,7 @@ def __init__( @param nodes: a list of MemoryLevels. Entries need to be provided from lowest to highest memory level. """ super().__init__(**attr) # type: ignore - self.name: str = name + self.name = name # type: ignore self.operational_array = operational_array # Initialize the set that will store all memory operands self.operands: set[MemoryOperand] = set() @@ -71,7 +75,7 @@ def add_memory( # Compute which memory level this is for all the operands mem_level_of_operands: dict[MemoryOperand, int] = {} for mem_op in operands: - nb_levels_so_far = len([node for node in self.memory_nodes if mem_op in node.operands]) + nb_levels_so_far = len([node for node in self.node_list if mem_op in node.operands]) mem_level_of_operands[mem_op] = nb_levels_so_far memory_level = MemoryLevel( @@ -94,11 +98,11 @@ def add_memory( to_edge_from.add(m) # Add the node to the graph - self.__add_node(memory_level) + self.add_node(memory_level) for sink_node in to_edge_from: # Add an edge from this sink node to the current node - self.__add_edge(sink_node, memory_level) + self.add_edge(sink_node, memory_level) def get_memory_levels(self, mem_op: MemoryOperand) -> list[MemoryLevel]: """! Returns a list of memories in the memory hierarchy for the memory operand. @@ -113,11 +117,11 @@ def get_operands(self) -> set[MemoryOperand]: def get_inner_memories(self) -> list[MemoryLevel]: """! Returns the inner-most memory levels for all memory operands.""" - return [node for node, in_degree in self.in_degree() if in_degree == 0] # type: ignore + return [node for node, in_degree in self.in_degree() if in_degree == 0] def get_outer_memories(self) -> list[MemoryLevel]: """! Returns the outer-most memory levels for all memory operands.""" - return [node for node, out_degree in self.out_degree() if out_degree == 0] # type: ignore + return [node for node, out_degree in self.out_degree() if out_degree == 0] def get_top_memories(self) -> tuple[list[MemoryLevel], int]: """! Returns the 'top'-most MemoryLevels, where 'the' level of MemoryLevel is considered to be the largest @@ -125,7 +129,7 @@ def get_top_memories(self) -> tuple[list[MemoryLevel], int]: @return (list_of_memories_on_top_level, top_level) """ level_to_mems: defaultdict[int, list[MemoryLevel]] = defaultdict(lambda: []) - for node in self.memory_nodes: + for node in self.node_list: level_to_mems[max(node.mem_level_of_operands.values())].append(node) top_level = max(level_to_mems.keys()) return level_to_mems[top_level], top_level @@ -136,7 +140,7 @@ def get_operator_top_level(self, operand: MemoryOperand) -> tuple[list[MemoryLev 'The' level of a MemoryLevel is considered to be the largest level it has across its assigned operands. """ level_to_mems: dict[int, list[MemoryLevel]] = defaultdict(lambda: []) - for node in self.memory_nodes: + for node in self.node_list: if operand in node.operands: level_to_mems[max(node.mem_level_of_operands.values())].append(node) top_level = max(level_to_mems.keys()) if level_to_mems else -1 @@ -151,23 +155,6 @@ def get_operand_top_level(self, operand: MemoryOperand) -> MemoryLevel: return mem raise ValueError(f"Operand {operand} not found in any of the memory instances.") - def topological_sort(self) -> Iterator[MemoryLevel]: - """! 
diff --git a/zigzag/hardware/architecture/memory_instance.py b/zigzag/hardware/architecture/memory_instance.py
index 45cb70231..5a9ed55df 100644
--- a/zigzag/hardware/architecture/memory_instance.py
+++ b/zigzag/hardware/architecture/memory_instance.py
@@ -5,6 +5,24 @@
 class MemoryInstance:
     """A single instance within the memory hierarchy, without information about connectivity."""
 
+    name: str
+    size: int
+    r_bw: int
+    w_bw: int
+    r_cost: float
+    w_cost: float
+    area: float
+    r_port: int
+    w_port: int
+    rw_port: int
+    latency: int
+    min_r_granularity: int
+    min_w_granularity: int
+    mem_type: str
+    auto_cost_extraction: bool
+    double_buffering_support: bool
+    shared_memory_group_id: int
+
     def __init__(
         self,
         name: str,
@@ -23,6 +41,7 @@ def __init__(
         mem_type: str = "sram",
         auto_cost_extraction: bool = False,
         double_buffering_support: bool = False,
+        shared_memory_group_id: int = -1,
     ):
         """
         Collect all the basic information of a physical memory module.
@@ -42,6 +61,8 @@ def __init__(
         @param mem_type (str): The type of memory. Used for CACTI cost extraction.
         @param auto_cost_extraction (bool): Automatically extract the read cost, write cost and area using CACTI.
         @param double_buffering_support (bool): Support for double buffering on this memory instance.
+        @param shared_memory_group_id: used to indicate whether two MemoryInstance instances represent the same, shared
+        memory between two cores (feature used in Stream).
         """
         if auto_cost_extraction:
             cacti_parser = CactiParser()
@@ -68,6 +89,7 @@ def __init__(
         self.rw_port_nb = rw_port
         self.latency = latency
         self.double_buffering_support = double_buffering_support
+        self.shared_memory_group_id = shared_memory_group_id
 
         self.r_bw_min: int = min_r_granularity if min_r_granularity is not None else r_bw
         self.w_bw_min: int = min_w_granularity if min_w_granularity is not None else w_bw
@@ -84,7 +106,8 @@ def __eq__(self, other: object) -> bool:
         return isinstance(other, MemoryInstance) and self.__dict__ == other.__dict__
 
     def __hash__(self):
-        return id(self)  # unique for every object within its lifetime
+        # Previously id(self): unique per object, but inconsistent with __eq__.
+        return hash(frozenset(self.__dict__.values()))
 
     def __str__(self):
         return f"MemoryInstance({self.name})"
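Editorial note: the new `__hash__` makes hashing consistent with the `__dict__`-based `__eq__`: two instances built from identical parameters now hash alike, which shared-memory deduplication across cores can rely on. A hedged sketch (assuming the remaining constructor parameters are optional with the defaults shown in this diff):

```python
a = MemoryInstance(name="sram_256KB", size=256 * 1024 * 8, r_bw=128, w_bw=128)
b = MemoryInstance(name="sram_256KB", size=256 * 1024 * 8, r_bw=128, w_bw=128)
assert a == b and hash(a) == hash(b)  # equal objects now hash equal
# Note: frozenset() of the attribute values collapses duplicates (e.g. when
# r_bw == w_bw), which weakens the hash slightly but keeps it valid, since
# all attributes are hashable scalars (str/int/float/bool).
```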
""" if auto_cost_extraction: cacti_parser = CactiParser() @@ -68,6 +89,7 @@ def __init__( self.rw_port_nb = rw_port self.latency = latency self.double_buffering_support = double_buffering_support + self.shared_memory_group_id = shared_memory_group_id self.r_bw_min: int = min_r_granularity if min_r_granularity is not None else r_bw self.w_bw_min: int = min_w_granularity if min_w_granularity is not None else w_bw @@ -84,7 +106,8 @@ def __eq__(self, other: object) -> bool: return isinstance(other, MemoryInstance) and self.__dict__ == other.__dict__ def __hash__(self): - return id(self) # unique for every object within its lifetime + # id(self) # unique for every object within its lifetime + return hash(frozenset(self.__dict__.values())) def __str__(self): return f"MemoryInstance({self.name})" diff --git a/zigzag/hardware/architecture/memory_level.py b/zigzag/hardware/architecture/memory_level.py index cba22a6bf..8f6f68932 100644 --- a/zigzag/hardware/architecture/memory_level.py +++ b/zigzag/hardware/architecture/memory_level.py @@ -48,6 +48,21 @@ class MemoryLevel: """Represents a single memory in the memory hierarchy, consisting of a memory instance and connectivity information""" + memory_instance: MemoryInstance + operands: list[MemoryOperand] + mem_level_of_operands: dict[MemoryOperand, int] + oa_dim_sizes: dict[OADimension, int] + port_alloc: PortAllocation + served_dimensions: ServedMemDimensions + id: int + name: str + port_alloc_raw: PortAllocation + read_energy: float + write_energy: float + read_bw: float + write_bw: float + port_list: list[MemoryPort] + def __init__( self, memory_instance: MemoryInstance, @@ -64,10 +79,9 @@ def __init__( """ self.memory_instance = memory_instance self.operands = operands - self.operands = operands self.mem_level_of_operands = mem_level_of_operands self.oa_dim_sizes = operational_array.dimension_sizes - self.id: int = identifier + self.id = identifier self.served_dimensions = served_dimensions self.name = self.memory_instance.name diff --git a/zigzag/opt/loma/multipermute.py b/zigzag/opt/loma/multipermute.py index 89a2a1b88..acfc42193 100644 --- a/zigzag/opt/loma/multipermute.py +++ b/zigzag/opt/loma/multipermute.py @@ -58,12 +58,10 @@ class PermutationConstraint(ABC): """! An abstract class to represent a constraint on a permutation.""" @abstractmethod - def is_valid(self, permutation: list[Any]) -> bool: - ... + def is_valid(self, permutation: list[Any]) -> bool: ... @abstractmethod - def is_empty(self) -> bool: - ... + def is_empty(self) -> bool: ... class StaticPositionsConstraint(PermutationConstraint): diff --git a/zigzag/parser/accelerator_factory.py b/zigzag/parser/accelerator_factory.py index e797f1929..656979fd0 100644 --- a/zigzag/parser/accelerator_factory.py +++ b/zigzag/parser/accelerator_factory.py @@ -45,7 +45,7 @@ def __init__(self, data: dict[str, Any]): """! Generate an `Core` instance from the validated user-provided data.""" self.data = data - def create(self, core_id: int = 1) -> Core: + def create(self, core_id: int = 1, shared_mem_group_id: int | None = None) -> Core: """! Create a Core instance from the user-provided data. NOTE the memory instances must be defined from lowest to highest. 
""" @@ -53,9 +53,12 @@ def create(self, core_id: int = 1) -> Core: mem_graph = MemoryHierarchy(operational_array) dataflows = self.create_dataflows() + shared_mem_group_id = core_id if shared_mem_group_id is None else shared_mem_group_id for mem_name in self.data["memories"]: - memory_factory = MemoryFactory(mem_name, self.data["memories"][mem_name]) + memory_factory = MemoryFactory( + mem_name, self.data["memories"][mem_name], shared_mem_group_id=shared_mem_group_id + ) memory_factory.add_memory_to_graph(mem_graph) return Core( @@ -143,9 +146,10 @@ def __create_dataflow_single_oa_dim(self, mapping_data: list[str]) -> MappingSin class MemoryFactory: """! Create MemoryInstances and adds them to memory hierarchy.""" - def __init__(self, name: str, mem_data: dict[str, Any]): + def __init__(self, name: str, mem_data: dict[str, Any], shared_mem_group_id: int = -1): self.data = mem_data self.name = name + self.shared_mem_group_id = shared_mem_group_id def create_memory_instance(self) -> MemoryInstance: return MemoryInstance( @@ -164,6 +168,7 @@ def create_memory_instance(self) -> MemoryInstance: min_r_granularity=self.data["min_r_granularity"], min_w_granularity=self.data["min_w_granularity"], auto_cost_extraction=self.data["auto_cost_extraction"], + shared_memory_group_id=self.shared_mem_group_id, ) def add_memory_to_graph(self, mem_graph: MemoryHierarchy) -> None: diff --git a/zigzag/parser/onnx/onnx_operator_parser.py b/zigzag/parser/onnx/onnx_operator_parser.py index 3e7d38e6c..2d0295525 100644 --- a/zigzag/parser/onnx/onnx_operator_parser.py +++ b/zigzag/parser/onnx/onnx_operator_parser.py @@ -34,8 +34,7 @@ def __init__( self.accelerator = accelerator @abstractmethod - def run(self) -> LayerNodeABC: - ... + def run(self) -> LayerNodeABC: ... def get_input_output_weight_data_type(self): """! 
diff --git a/zigzag/stages/stage.py b/zigzag/stages/stage.py
index 372a0b846..d04a12329 100644
--- a/zigzag/stages/stage.py
+++ b/zigzag/stages/stage.py
@@ -29,8 +29,7 @@ def __init__(
         )
 
     @abstractmethod
-    def run(self) -> Generator[tuple[CostModelEvaluationABC, Any], None, None]:
-        ...
+    def run(self) -> Generator[tuple[CostModelEvaluationABC, Any], None, None]: ...
 
     def __iter__(self):
         return self.run()
@@ -45,5 +44,4 @@ def is_leaf(self) -> bool:
 
 @runtime_checkable
 class StageCallable(Protocol):
-    def __call__(self, list_of_callables: list["StageCallable"], **kwagrs: Any) -> Stage:
-        ...
+    def __call__(self, list_of_callables: list["StageCallable"], **kwargs: Any) -> Stage: ...
diff --git a/zigzag/stages/workload_iterator.py b/zigzag/stages/workload_iterator.py
index 7150ef259..2dfeb5e96 100644
--- a/zigzag/stages/workload_iterator.py
+++ b/zigzag/stages/workload_iterator.py
@@ -5,8 +5,7 @@
 from zigzag.hardware.architecture.imc_array import ImcArray
 from zigzag.stages.stage import Stage, StageCallable
 from zigzag.workload.layer_node import LayerNode
-from zigzag.workload.layer_node_abc import LayerNodeABC
-from zigzag.workload.workload_abc import WorkloadABC
+from zigzag.workload.workload_abc import WorkloadABC, WorkloadNoDummyABC
 
 logger = logging.getLogger(__name__)
 
@@ -18,7 +17,7 @@ def __init__(
         self,
         list_of_callables: list[StageCallable],
         *,
-        workload: WorkloadABC[LayerNodeABC],
+        workload: WorkloadABC | WorkloadNoDummyABC,
         accelerator: Accelerator,
         **kwargs: Any,
     ):
diff --git a/zigzag/utils.py b/zigzag/utils.py
index 3fc43c003..93f1501f4 100644
--- a/zigzag/utils.py
+++ b/zigzag/utils.py
@@ -1,16 +1,19 @@
 import logging
 import pickle
 from copy import deepcopy
-from hashlib import sha512
-from typing import Any
+from hashlib import sha512  # type: ignore
+from typing import Any, Generic, Iterator, Literal, Sequence, TypeVar, no_type_check, overload
 
+import networkx as nx
 import numpy as np
 import yaml
+from networkx import DiGraph
+from typeguard import typeguard_ignore  # type: ignore
 
 
 def hash_sha512(data: Any) -> int:
     """! Hashes the input data using SHA-512"""
-    return int(sha512(pickle.dumps(data)).hexdigest(), 16)
+    return int(sha512(pickle.dumps(data)).hexdigest(), 16)  # type: ignore
 
 
 def pickle_deepcopy(to_copy: Any) -> Any:
@@ -33,7 +36,7 @@ def pickle_load(path: str):
     return obj
 
 
-def open_yaml(path: str):
+def open_yaml(path: str) -> dict[str, Any] | list[dict[str, Any]]:
     with open(path, encoding="utf-8") as f:
         data = yaml.safe_load(f)
     return data
@@ -80,3 +83,112 @@ def filter(self, record: logging.LogRecord):
         else:
             self.recorded_messages.add(message)
             return True
+
+
+T = TypeVar("T")
+
+
+@no_type_check
+class DiGraphWrapper(Generic[T], DiGraph):
+    """Wraps the DiGraph class with type annotations for the nodes"""
+
+    @overload
+    def in_edges(self, node: T, data: Literal[False]) -> list[tuple[T, T]]: ...
+
+    @overload
+    def in_edges(self, node: T, data: Literal[True]) -> list[tuple[T, T, dict[str, Any]]]: ...
+
+    @overload
+    def in_edges(self, node: T) -> list[tuple[T, T]]: ...
+
+    def in_edges(  # type: ignore # pylint: disable=W0246
+        self,
+        node: T,
+        data: bool = False,
+    ) -> list[tuple[T, T]] | list[tuple[T, T, dict[str, Any]]]:
+        return super().in_edges(node, data)  # type: ignore
+
+    @overload
+    def out_edges(self, node: T, data: Literal[True]) -> list[tuple[T, T, dict[str, Any]]]: ...
+
+    @overload
+    def out_edges(self, node: T, data: Literal[False]) -> list[tuple[T, T]]: ...
+
+    @overload
+    def out_edges(self, node: T) -> list[tuple[T, T]]: ...
+
+    def out_edges(  # type: ignore # pylint: disable=W0246
+        self,
+        node: T,
+        data: bool = False,
+    ) -> list[tuple[T, T]] | list[tuple[T, T, dict[str, Any]]]:
+        return super().out_edges(node, data)  # type: ignore
+
+    @overload
+    def in_degree(self) -> Iterator[tuple[T, int]]: ...
+
+    @overload
+    def in_degree(self, node: T) -> int: ...
+
+    def in_degree(self, node: T | None = None) -> int | Iterator[tuple[T, int]]:  # type: ignore
+        # Mirrors out_degree below: exploit_data_locality_stages.py calls
+        # in_degree(node), so the no-argument-only signature would raise at runtime.
+        if node:
+            return super().in_degree(node)  # type: ignore
+        return super().in_degree()  # type: ignore
+
+    @overload
+    def out_degree(self, node: Literal[None]) -> Iterator[tuple[T, int]]: ...
+
+    @overload
+    def out_degree(self) -> Iterator[tuple[T, int]]: ...
+
+    @overload
+    def out_degree(self, node: T) -> int: ...
+
+    def out_degree(self, node: T | None = None) -> int | Iterator[tuple[T, int]]:  # type: ignore
+        if node:
+            return super().out_degree(node)  # type: ignore
+        return super().out_degree()  # type: ignore
+
+    def successors(self, node: T) -> Iterator[T]:  # type: ignore # pylint: disable=W0246
+        return super().successors(node)  # type: ignore
+
+    def predecessors(self, node: T) -> Iterator[T]:  # type: ignore # pylint: disable=W0246
+        return super().predecessors(node)  # type: ignore
+
+    @typeguard_ignore
+    def topological_sort(self) -> Iterator[T]:
+        return nx.topological_sort(self)  # type: ignore
+
+    def add_node(self, node: T) -> None:  # type: ignore # pylint: disable=W0246
+        super().add_node(node)  # type: ignore
+
+    def add_nodes_from(self, node: Sequence[T]) -> None:  # pylint: disable=W0246
+        super().add_nodes_from(node)  # type: ignore
+
+    def remove_nodes_from(self, nodes: Iterator[T]) -> None:  # pylint: disable=W0246
+        super().remove_nodes_from(nodes)  # type: ignore
+
+    def add_edge(self, edge_from: T, edge_to: T) -> None:  # type: ignore # pylint: disable=W0246
+        super().add_edge(edge_from, edge_to)  # type: ignore
+
+    def add_edges_from(  # type: ignore # pylint: disable=W0246
+        self,
+        edges: Sequence[tuple[T, T] | tuple[T, T, Any]],
+    ) -> None:
+        super().add_edges_from(edges)  # type: ignore
+
+    def all_simple_paths(self, producer: T, consumer: T) -> Iterator[list[T]]:
+        return nx.all_simple_paths(self, source=producer, target=consumer)  # type: ignore
+
+    def shortest_path(self, producer: T, consumer: T) -> list[T]:
+        return nx.shortest_path(self, producer, consumer)  # type: ignore
+
+    @property
+    def node_list(self) -> list[T]:
+        return list(self.nodes())  # type: ignore
+
+    def get_node_with_id(self, node_id: int) -> T:
+        for node in self.node_list:
+            if node.id == node_id:  # type: ignore
+                return node
+        raise ValueError(f"Node with id {node_id} not found.")
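Editorial note: `DiGraphWrapper` centralizes the per-class `# type: ignore` shims that `MemoryHierarchy` and `WorkloadABC` previously carried. A minimal sketch of subclassing it with a custom node type (the `Task` class is hypothetical; assumes this PR's zigzag is installed):

```python
from zigzag.utils import DiGraphWrapper

class Task:
    def __init__(self, task_id: int) -> None:
        self.id = task_id  # the `id` attribute is what get_node_with_id looks up

class TaskGraph(DiGraphWrapper[Task]):
    """Nodes are typed as Task in every inherited graph operation."""

load, compute = Task(0), Task(1)
g = TaskGraph()
g.add_edge(load, compute)  # adds both nodes and the edge
assert list(g.topological_sort()) == [load, compute]
assert g.get_node_with_id(1) is compute
assert g.node_list == [load, compute]
```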
diff --git a/zigzag/visualization/graph/memory_hierarchy.py b/zigzag/visualization/graph/memory_hierarchy.py
index 9cde4b90a..7da5f40c1 100644
--- a/zigzag/visualization/graph/memory_hierarchy.py
+++ b/zigzag/visualization/graph/memory_hierarchy.py
@@ -1,11 +1,11 @@
 import matplotlib.pyplot as plt
 import networkx as nx
-from networkx import DiGraph
 
+from zigzag.hardware.architecture.memory_hierarchy import MemoryHierarchy
 from zigzag.hardware.architecture.memory_level import MemoryLevel
 
 
-def visualize_memory_hierarchy_graph(graph: DiGraph, save_path: str = ""):
+def visualize_memory_hierarchy_graph(graph: MemoryHierarchy, save_path: str = ""):
     """
     Visualizes a memory hierarchy graph.
     """
diff --git a/zigzag/workload/dnn_workload.py b/zigzag/workload/dnn_workload.py
index 595856057..5dab006f5 100644
--- a/zigzag/workload/dnn_workload.py
+++ b/zigzag/workload/dnn_workload.py
@@ -2,10 +2,10 @@
 from typing import Any
 
 from zigzag.workload.layer_node import LayerNode
-from zigzag.workload.workload_abc import WorkloadABC
+from zigzag.workload.workload_abc import WorkloadNoDummyABC
 
 
-class DNNWorkload(WorkloadABC[LayerNode]):
+class DNNWorkload(WorkloadNoDummyABC):
     """Extends the ABC for workloads. For user-defined workloads (from yaml), the DummyNodes are not saved so this
     class only holds (non-dummy) LayerNodes"""
 
@@ -13,7 +13,7 @@ def __init__(self, nodes: list[LayerNode], **attr: Any):
         """!
         @return (self): Directed Graph with nodes the layers and edges the connections between layers.
         """
-        super().__init__(**attr)
+        super().__init__(**attr)  # type: ignore
 
         layer_id_to_obj: dict[int, LayerNode] = {}
         self.layer_node_list = nodes
@@ -21,7 +21,7 @@ def __init__(self, nodes: list[LayerNode], **attr: Any):
 
         for layer_node in nodes:
             layer_id_to_obj[layer_node.id] = layer_node
-            self.add_workload_node(layer_node)
+            self.add_node(layer_node)
             # Find all of its operand sources and add edges accordingly
             edges: list[tuple[LayerNode, LayerNode]] = []
             for _, parent_id in layer_node.input_operand_source.items():
@@ -30,8 +30,8 @@ def __init__(self, nodes: list[LayerNode], **attr: Any):
                 parent_layer = layer_id_to_obj[parent_id]
                 edges.append((parent_layer, layer_node))
 
-            self.add_workload_edges_from(edges)
+            self.add_edges_from(edges)
 
-    def get_copy_no_dummy(self) -> WorkloadABC[LayerNode]:
+    def get_copy_no_dummy(self) -> "DNNWorkload":
         """Return a copy. DNNWorkloads don't contain DummyNodes in the first place."""
         return deepcopy(self)
""" - super().__init__(**attr) + super().__init__(**attr) # type: ignore layer_id_to_obj: dict[int, LayerNode] = {} self.layer_node_list = nodes @@ -21,7 +21,7 @@ def __init__(self, nodes: list[LayerNode], **attr: Any): for layer_node in nodes: layer_id_to_obj[layer_node.id] = layer_node - self.add_workload_node(layer_node) + self.add_node(layer_node) # Find all of its operand sources and add edges accordingly edges: list[tuple[LayerNode, LayerNode]] = [] for _, parent_id in layer_node.input_operand_source.items(): @@ -30,8 +30,8 @@ def __init__(self, nodes: list[LayerNode], **attr: Any): parent_layer = layer_id_to_obj[parent_id] edges.append((parent_layer, layer_node)) - self.add_workload_edges_from(edges) + self.add_edges_from(edges) - def get_copy_no_dummy(self) -> WorkloadABC[LayerNode]: + def get_copy_no_dummy(self) -> "DNNWorkload": """Return a copy. DNNWorkloads don't contain DummyNodes in the first place.""" return deepcopy(self) diff --git a/zigzag/workload/layer_attributes.py b/zigzag/workload/layer_attributes.py index 82a232b8e..1956f0c6a 100644 --- a/zigzag/workload/layer_attributes.py +++ b/zigzag/workload/layer_attributes.py @@ -110,6 +110,9 @@ def final_output_precision(self) -> int: class MemoryOperandLinks(LayerAttribute): """! Links LayerOperand to MemoryOperand.""" + layer_operands: list[LayerOperand] + mem_operands: list[MemoryOperand] + def __init__(self, data: dict[LayerOperand, MemoryOperand]): self.data = data # Class variables are computed and stored once to improve runtime performance @@ -231,25 +234,25 @@ def to_legacy_format(self): return self.data def get_constraints(self) -> list[PermutationConstraint]: - static_posistions_dict: dict[int, LayerDim] = {} - static_posistions_and_sizes_dict: dict[int, tuple[LayerDim, int]] = {} + static_positions_dict: dict[int, LayerDim] = {} + static_positions_and_sizes_dict: dict[int, tuple[LayerDim, int]] = {} outer_loop = False for count, (layer_dim, factor) in enumerate(self.data): if (layer_dim == Constants.UNKNOWN_DIM_OPERATOR) and (factor is None): outer_loop = True elif factor is None: if not outer_loop: - static_posistions_dict[count] = layer_dim + static_positions_dict[count] = layer_dim else: - static_posistions_dict[count - len(self.data)] = layer_dim + static_positions_dict[count - len(self.data)] = layer_dim else: if not outer_loop: - static_posistions_and_sizes_dict[count] = (layer_dim, factor) + static_positions_and_sizes_dict[count] = (layer_dim, factor) else: - static_posistions_and_sizes_dict[count - len(self.data)] = (layer_dim, factor) - static_positions = StaticPositionsConstraint(static_posistions_dict) - static_posistions_and_sizes = StaticPositionsAndSizesConstraint(static_posistions_and_sizes_dict) - return [static_positions, static_posistions_and_sizes] + static_positions_and_sizes_dict[count - len(self.data)] = (layer_dim, factor) + static_positions = StaticPositionsConstraint(static_positions_dict) + static_positions_and_sizes = StaticPositionsAndSizesConstraint(static_positions_and_sizes_dict) + return [static_positions, static_positions_and_sizes] class LayerPadding(LayerAttribute): diff --git a/zigzag/workload/layer_node.py b/zigzag/workload/layer_node.py index 2084f06c0..ae756a8eb 100644 --- a/zigzag/workload/layer_node.py +++ b/zigzag/workload/layer_node.py @@ -124,6 +124,27 @@ class LayerNodeAttributes: class LayerNode(LayerNodeABC): """! 
diff --git a/zigzag/workload/layer_node.py b/zigzag/workload/layer_node.py
index 2084f06c0..ae756a8eb 100644
--- a/zigzag/workload/layer_node.py
+++ b/zigzag/workload/layer_node.py
@@ -124,6 +124,27 @@ class LayerNodeAttributes:
 class LayerNode(LayerNodeABC):
     """! Represents a single layer in a workload."""
 
+    type: str
+    equation: LayerEquation
+    layer_dim_sizes: LayerDimSizes
+    operand_precision: LayerOperandPrecision
+    dimension_relations: list[LayerDimRelation]
+    spatial_mapping: SpatialMapping
+    spatial_mapping_hint: SpatialMappingHint
+    core_allocation: list[int]
+    core_allocation_is_fixed: bool
+    memory_operand_links: MemoryOperandLinks
+    temporal_ordering: LayerTemporalOrdering
+    padding: LayerPadding
+    constant_operands: list[LayerOperand]
+    input_operand_source: InputOperandSource
+    layer_operands: list[LayerOperand]
+    output_operand: LayerOperand
+    input_operands: list[LayerOperand]
+    layer_dims: list[LayerDim]
+    pr_loop: PrLoop
+    pr_layer_dim_sizes: LayerDimSizes | None
+
     def __init__(self, layer_id: int, node_name: str, node_attr: LayerNodeAttributes):
         """
         To construct each layer node, algorithm equation/dimension/indirect relation are parsed.
@@ -153,8 +174,8 @@ def __init__(self, layer_id: int, node_name: str, node_attr: LayerNodeAttributes
 
         # Derived attributes
         self.layer_operands = self.equation.get_contained_operands()
-        self.output_operand: LayerOperand = self.layer_operands[0]
-        self.input_operands: list[LayerOperand] = self.layer_operands[1:]
+        self.output_operand = self.layer_operands[0]
+        self.input_operands = self.layer_operands[1:]
 
         self.layer_dims = list(self.layer_dim_sizes.layer_dims)
 
         self.pr_loop, pr_loop_list, self.pr_scaling_factors = self.build_pr_funcs()
diff --git a/zigzag/workload/onnx_workload.py b/zigzag/workload/onnx_workload.py
index a92ed3c10..1b7f31097 100644
--- a/zigzag/workload/onnx_workload.py
+++ b/zigzag/workload/onnx_workload.py
@@ -1,18 +1,20 @@
 from copy import deepcopy
 from typing import Any
 
+from typeguard import typeguard_ignore
+
 from zigzag.workload.dummy_node import DummyNode
 from zigzag.workload.layer_node import LayerNode
 from zigzag.workload.layer_node_abc import LayerNodeABC
-from zigzag.workload.workload_abc import WorkloadABC
+from zigzag.workload.workload_abc import WorkloadABC, WorkloadNoDummyABC
 
 
-class ONNXWorkload(WorkloadABC[LayerNodeABC]):
+class ONNXWorkload(WorkloadABC):
     """Represents a workload graph parsed from ONNX"""
 
     def __init__(self, **attr: Any):
         """! Collect all the algorithmic workload information here."""
-        super().__init__(**attr)
+        super().__init__(**attr)  # type: ignore
 
         self.node_id_to_obj: dict[int, LayerNodeABC] = {}
 
@@ -22,14 +24,15 @@ def add(self, node_id: int, node_obj: LayerNodeABC):
         """
         self.node_id_to_obj[node_id] = node_obj
 
-        self.add_workload_node(node_obj)
+        self.add_node(node_obj)
         edges: list[tuple[LayerNodeABC, LayerNodeABC]] = []
         for parent_id in node_obj.input_operand_source.values():
             parent_node_obj = self.node_id_to_obj[parent_id]
             edges.append((parent_node_obj, node_obj))
-            self.add_workload_edges_from(edges)
+            self.add_edges_from(edges)
 
-    def get_copy_no_dummy(self) -> WorkloadABC[LayerNode]:
+    @typeguard_ignore
+    def get_copy_no_dummy(self) -> WorkloadNoDummyABC:
         """! Remove dummy nodes (layers) in the graph
         Redirect the outgoing edges of dummy nodes to non-dummy nodes
         Method: for each dummy node, add edges between its predecessor nodes and successor nodes;
         then remove the dummy node.
@@ -38,14 +41,11 @@ def get_copy_no_dummy(self) -> WorkloadABC[LayerNode]:
         dummy_nodes = [node for node in workload_copy.node_list if isinstance(node, DummyNode)]
 
         for dummy_node in dummy_nodes:
-            for successor_node in workload_copy.get_successors_for_layer(dummy_node):
-                for predecessor_node in workload_copy.get_predecessors_for_layer(dummy_node):
-                    workload_copy.add_workload_edge(predecessor_node, successor_node)
-
-            workload_copy.remove_workload_nodes_from(iter(dummy_nodes))
+            for successor_node in workload_copy.successors(dummy_node):
+                for predecessor_node in workload_copy.predecessors(dummy_node):
+                    workload_copy.add_edge(predecessor_node, successor_node)
 
-        # Typecast
-        workload_result: WorkloadABC[LayerNode] = workload_copy  # type: ignore
+        workload_copy.remove_nodes_from(iter(dummy_nodes))
 
-        assert all([isinstance(x, LayerNode) for x in workload_result.node_list])
-        return workload_result
+        assert all([isinstance(x, LayerNode) for x in workload_copy.node_list])
+        return workload_copy  # type: ignore
diff --git a/zigzag/workload/workload_abc.py b/zigzag/workload/workload_abc.py
index e1d142f46..b4f45ee65 100644
--- a/zigzag/workload/workload_abc.py
+++ b/zigzag/workload/workload_abc.py
@@ -1,57 +1,18 @@
 from abc import ABCMeta
-from typing import Any, Generic, Iterator, Sequence, TypeVar
-
-import networkx as nx
-from networkx import DiGraph
+from typing import TypeVar
 
+from zigzag.utils import DiGraphWrapper
 from zigzag.workload.layer_node import LayerNode
 from zigzag.workload.layer_node_abc import LayerNodeABC
 
 T = TypeVar("T", bound=LayerNodeABC)
 
 
-class WorkloadABC(DiGraph, Generic[T], metaclass=ABCMeta):
+class WorkloadABC(DiGraphWrapper[LayerNodeABC], metaclass=ABCMeta):
     """! Abstract Base Class for workloads, parameterizable with type T, which must be a (subclass from)
     LayerNodeABC"""
 
-    def __init__(self, **attr: Any):
-        super().__init__(**attr)  # type: ignore
-
-    def topological_sort(self) -> Iterator[T]:
-        return nx.topological_sort(self)  # type: ignore
-
-    def add_workload_node(self, node: T) -> None:
-        self.add_node(node)  # type: ignore
-
-    def remove_workload_nodes_from(self, nodes: Iterator[T]) -> None:
-        self.remove_nodes_from(nodes)  # type: ignore
-
-    def add_workload_edge(self, edge_from: T, edge_to: T) -> None:
-        self.add_edge(edge_from, edge_to)  # type: ignore
-
-    def add_workload_edges_from(self, edges: Sequence[tuple[T, T]]) -> None:
-        self.add_edges_from(edges)  # type: ignore
-
-    def get_out_degree_for_layer(self, layer: T) -> int:
-        return self.out_degree(layer)  # type: ignore
-
-    def get_in_degree_for_layer(self, layer: T) -> int:
-        return self.in_degree(layer)  # type: ignore
-
-    def get_successors_for_layer(self, layer: T) -> Iterator[T]:
-        return self.successors(layer)  # type: ignore
-
-    def get_predecessors_for_layer(self, layer: T) -> Iterator[T]:
-        return self.predecessors(layer)  # type: ignore
-
-    def get_node_with_id(self, node_id: int) -> T:
-        for node in self.node_list:
-            if node.id == node_id:
-                return node
-        raise ValueError(f"Node with id {node_id} not found in workload")
+    def get_copy_no_dummy(self) -> "WorkloadNoDummyABC": ...
 
-    def get_copy_no_dummy(self) -> "WorkloadABC[LayerNode]":
-        ...
 
-    @property
-    def node_list(self) -> list[T]:
-        return list(self.nodes())  # type: ignore
+
+class WorkloadNoDummyABC(DiGraphWrapper[LayerNode], metaclass=ABCMeta):
+    """Abstract base class for workloads with only simulatable nodes"""
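Editorial note: the rewritten `get_copy_no_dummy` is now ordinary networkx-style graph surgery, and the node removal correctly moves outside the per-dummy loop. A minimal sketch of the same rewiring on a toy graph (strings stand in for LayerNode/DummyNode; the real method filters with `isinstance(node, DummyNode)`):

```python
from zigzag.utils import DiGraphWrapper

g = DiGraphWrapper[str]()
g.add_edges_from([("conv1", "reshape"), ("reshape", "conv2")])
dummy_nodes = ["reshape"]  # stands in for the isinstance(node, DummyNode) filter

for dummy in dummy_nodes:
    # Bypass the dummy: connect every predecessor to every successor ...
    for successor in list(g.successors(dummy)):
        for predecessor in list(g.predecessors(dummy)):
            g.add_edge(predecessor, successor)
# ... then drop all dummies in one go, outside the loop (as the new code does).
g.remove_nodes_from(iter(dummy_nodes))

assert list(g.topological_sort()) == ["conv1", "conv2"]
```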