outer_tmap_loop_dimensions order must match scheduling order
RobinGeens committed Jan 2, 2025
1 parent 2e1107b · commit a9701c6
Showing 4 changed files with 125 additions and 35 deletions.
@@ -128,7 +128,7 @@ def extract_steady_state_per_stack(self):
sink_nodes: list[ComputationNode] = sorted(
n for n in sg.nodes() if len(get_real_successors(n, sg)) == 0 # type: ignore
)
sink_layer_ids = sg.get_sink_layer_ids()
sink_layer_ids = set(n.id for n in sink_nodes)
sink_layer_nodes = [tuple(sorted(n for n in sink_nodes if n.id == layer_id)) for layer_id in sink_layer_ids]
interlaced = [tuple(filter(lambda x: x is not None, t)) for t in itertools.zip_longest(*sink_layer_nodes)]
computed: set[ComputationNode] = set()
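
As a quick illustration of the grouping and interlacing above, a minimal sketch where strings stand in for ComputationNodes (values are illustrative, not part of the commit):

import itertools

# Sink tiles grouped per layer id; layers can have different tile counts
sink_layer_nodes = [("a0", "a1", "a2"), ("b0", "b1")]
interlaced = [
    tuple(x for x in t if x is not None)
    for t in itertools.zip_longest(*sink_layer_nodes)
]
print(interlaced)  # [('a0', 'b0'), ('a1', 'b1'), ('a2',)]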
@@ -413,6 +413,7 @@ def schedule_allocation(self, allocation: ALLOCATION_T) -> StreamCostModelEvalua
kwargs["accelerator"] = self.accelerator
kwargs["workload"] = unpartitioned_sub_workload
kwargs["scheduling_order"] = scheduling_order
kwargs["layer_stacks"] = self.layer_stacks
kwargs["tiled_workload_path"] = self.tiled_workload_post_co_path
kwargs["cost_lut_path"] = self.cost_lut_post_co_path
kwargs["latency_attr"] = self.latency_attr
1 change: 1 addition & 0 deletions stream/stages/generation/scheduling_order_generation.py
@@ -38,6 +38,7 @@ def run(self):
self.kwargs["accelerator"] = self.accelerator
self.kwargs["workload"] = self.workload
self.kwargs["scheduling_order"] = self.scheduling_order
# self.kwargs["layer_stacks"] = self.layer_stacks # TODO is already in kwargs
sub_stage = self.list_of_callables[0](
self.list_of_callables[1:],
**self.kwargs,
144 changes: 113 additions & 31 deletions stream/stages/generation/tiled_workload_generation.py
@@ -1,3 +1,4 @@
from collections import defaultdict
import logging
import os
from copy import deepcopy
@@ -7,6 +8,7 @@
from rtree import index
from zigzag.datatypes import Constants, LayerDim, LayerOperand
from zigzag.utils import pickle_deepcopy, pickle_load, pickle_save
from zigzag.workload.layer_attributes import LayerDimSizes

from stream.cost_model.group_allocation import GroupIdManager
from stream.hardware.architecture.accelerator import Accelerator
@@ -55,26 +57,26 @@ def __init__(
super().__init__(list_of_callables, **kwargs)
self.workload = workload
self.accelerator = accelerator
self.layer_stacks = kwargs.get("layer_stacks", [])

# Save for each of the workload's nodes the tiles that will be generated
self.tiles_dict: dict[ComputationNode, list[ComputationNode]] = {}

# Memoize the numpy tensors for dependency generation
self.numpy_tensors = {}

self.tiled_workload_path = tiled_workload_path

def run(self):
all_unique_tiles: list[ComputationNode] = []
# For each node get all the tiles and the edges between them
all_tiles = []
all_edges = []
all_tiles: list[ComputationNode] = []
all_edges: list[tuple[ComputationNode, ComputationNode, dict[str, int]]] = []
for node in self.workload.topological_sort():
# If other node types shouldn't be included in tiled workload graph, add here
if not isinstance(node, ComputationNode):
continue
outer_temporal_loops = self.get_outer_tmap_loop_dimensions(node)
tiles, unique_tiles = self.get_tiles(node, outer_temporal_loops)
mandatory_divisors = self.get_mandatory_divisors(node)
tiles, unique_tiles = self.get_tiles(node, outer_temporal_loops, mandatory_divisors)
logger.info(f"{node}: Outer loops {outer_temporal_loops}.")
logger.info(f"{node}: Generated {len(tiles)} tile(s).")
self.tiles_dict[node] = tiles
@@ -180,6 +182,11 @@ def get_all_node_pairs(G: ONNXWorkload) -> tuple[tuple[ComputationNode, Computat
def get_outer_tmap_loop_dimensions(self, node: ComputationNode) -> list[TemporalLoop]:
"""Get the temporal loops that are outside a CN for this node.
NOTE the order of this list matters! The order in which sub-tiles are generated should match the scheduling
order: first generate all tiles within the same intra-core split (by splitting inter-core),
i.e. tiles with sub-id 0, 1, ..., (nb_inter_tiles - 1) should have the same intra-core split and be
allocated to different cores.
Args:
node: node for which to return outer-cn loops
@@ -190,9 +197,8 @@ def get_outer_tmap_loop_dimensions(self, node: ComputationNode) -> list[Temporal
# inter core tiling is not set by CO yet
tiling_to_split = node.intra_core_tiling
else:
# inter core tiling is ok, also split into these tiles
tiling_to_split = node.intra_core_tiling + node.inter_core_tiling

# inter core tiling is ok, also split into these tiles. NOTE: this list is ordered
tiling_to_split = node.inter_core_tiling + node.intra_core_tiling
outer_loops = convert_outer_cn_loops(tiling_to_split, node)

# In case no valid intra core tiling is found: add an arbitrary tiling of size 1
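
To see why inter-core loops must come first, here is a minimal sketch of the sub-id decoding that get_tiles applies below (loop 0 varies fastest; the sizes are illustrative, not part of the commit):

from math import prod

stop_values = [2, 3]  # [inter-core split, intra-core split], in tiling_to_split order
for n in range(prod(stop_values)):
    values = [(n // prod(stop_values[:i])) % size for i, size in enumerate(stop_values)]
    print(n, values)
# 0 [0, 0]
# 1 [1, 0]  -> sub-ids 0 and 1 share the intra-core split and go to different cores
# 2 [0, 1]
# 3 [1, 1]
# 4 [0, 2]
# 5 [1, 2]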
@@ -223,34 +229,109 @@ def get_non_type_predecessors(self, node: Node, types: list[type]) -> list[Node]
preds += skip_node_preds
return preds

@staticmethod
def get_mandatory_divisors(self, node: ComputationNode) -> dict[LayerDim, set[int]]:
"""Get the factors by which the smaller tiles' dimensions must be divisible.
Tile dimensions must be divisible by all the inter-core tiling factors of the nodes within the same layer stack.
This ensures dependencies between tiles within the stack do not cross the layer stack boundaries.
# TODO can nodes within the same stack have different intra-core tiling? This is not accounted for
"""
# # These divisors accumulate: e.g. if a dim must be divisible by 2 and 4, it must be divisible by 8
# divisors_multiplicative: dict[LayerDim, int] = defaultdict(lambda: 1)

divisors: dict[LayerDim, set[int]] = defaultdict(lambda: set())

# # Must be divisible by inter- and intra-core tiling factors (multiplicative)
# for dim, factor in node.intra_core_tiling + node.inter_core_tiling:
# if isinstance(factor, int):
# divisors_multiplicative[dim] *= factor

# # Multiplied divisors become one lcm divisor
# for dim, factor in divisors_multiplicative.items():
# divisors_lcm[dim].add(factor)

# Must be divisible by inter-core tiling factors of all nodes in the same layer stack (least common multiple)
# Find nodes in stack
try:
curr_stack = next(stack for stack in self.layer_stacks if node.id in stack)
except StopIteration:
# No stack found
return divisors
if len(curr_stack) == 1:
return divisors
other_nodes_in_stack = [
n
for n in self.workload.node_list
if n.id in curr_stack and n.id != node.id and isinstance(n, ComputationNode)
]

for curr_node in other_nodes_in_stack:
assert len(curr_node.inter_core_tiling) == len(
set(dim for dim, _ in curr_node.inter_core_tiling)
), "Inter-core tiling contains duplicate dimensions. The divisors for this node must be multiplied"

for layer_dim, factor in curr_node.inter_core_tiling:
if isinstance(factor, int):
divisors[layer_dim].add(factor)
return divisors
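
A worked example of the rule above, where plain ints and strings stand in for node ids and LayerDims (values illustrative): if layer 0 shares a stack with layer 1, and layer 1 splits K over 4 cores, then layer 0 must keep K divisible by 4 so the dependencies stay inside the stack.

from collections import defaultdict

layer_stacks = [(0, 1)]
inter_core_tiling = {0: [("K", 2)], 1: [("K", 4)]}  # per-layer (dim, factor) pairs

node_id = 0
divisors: dict[str, set[int]] = defaultdict(set)
stack = next(s for s in layer_stacks if node_id in s)
for other_id in (i for i in stack if i != node_id):
    for dim, factor in inter_core_tiling[other_id]:
        if isinstance(factor, int):
            divisors[dim].add(factor)
print(dict(divisors))  # {'K': {4}}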

def get_tiles(
original_node: ComputationNode, outer_temporal_loops: list[TemporalLoop]
self,
original_node: ComputationNode,
outer_temporal_loops: list[TemporalLoop],
mandatory_divisors: dict[LayerDim, set[int]] = {},
) -> tuple[list[ComputationNode], list[ComputationNode]]:

# Take away the outer_temporal_loops to create tiled CNs for this node
def get_total_outer_size(dim: LayerDim):
return prod([loop.size for loop in outer_temporal_loops if loop.dimension == dim])

def get_lcm(n: int, divisors: set[int]) -> int:
    """Round n up to the smallest value that is divisible by all the divisors in the set."""
    # Rounding up to a multiple of lcm(divisors) guarantees divisibility by every divisor at
    # once; padding one divisor at a time can undo an earlier divisor (e.g. n=4, divisors={3, 4}
    # would give 6, then 8, and 8 is not divisible by 3). Requires `from math import lcm`.
    factor = lcm(*divisors) if divisors else 1
    return ceil(n / factor) * factor

def pad_until_divisible(layer_dim: LayerDim, n: int) -> int:
    """Return x >= n such that x is divisible by `total_outer_size`, and `x // total_outer_size` is
    divisible by all mandatory divisors (coming from the inter-core tiling of other nodes within
    the same stack)"""
    total_outer_size = get_total_outer_size(layer_dim)
    inner_size = ceil(n / total_outer_size)
    # .get avoids a KeyError when the caller relies on the plain-dict default argument
    inner_size_padded = get_lcm(inner_size, mandatory_divisors.get(layer_dim, set()))
    x = inner_size_padded * total_outer_size
    return x

# Pad the layer_dim_sizes so each dimension is divisible by the outer loop sizes and the mandatory divisors
tile_attrs = original_node.extract_node_attr()
for dim, size in tile_attrs.layer_dim_sizes.items():
new_size = pad_until_divisible(dim, size)
if size != new_size:
tile_attrs.layer_dim_sizes[dim] = new_size
logger.warning(f"Padded layer dimension {dim}: {size} -> {new_size} to be divisible by tiling factors")

# Save these extended sizes for later
extended_layer_dim_sizes = deepcopy(tile_attrs.layer_dim_sizes)

# Take away the outer_temporal_loops to create tiled CNs for this node
for loop in outer_temporal_loops:
outer_dim, outer_size = loop.unpack()
node_dim_size: int = tile_attrs.layer_dim_sizes[outer_dim]
q, rem = divmod(node_dim_size, outer_size) # returns x//y, x%y
# Make sure that the outer_dim is divisible by the outer_size
if rem != 0:
# Pad the dimension to a multiple of outer_size
node_dim_size = (q + 1) * outer_size
q += 1

assert rem == 0, "Should be guaranteed through mandatory divisors"
# # Make sure that the outer_dim is divisible by the outer_size
# if rem != 0:
# # Pad the dimension to a multiple of outer_size
# node_dim_size = (q + 1) * outer_size
# q += 1
tile_attrs.layer_dim_sizes[outer_dim] = q

# Reconstruct the total, padded layer_dim_sizes as padded tile size * outer_sizes
extended_layer_dim_sizes = deepcopy(tile_attrs.layer_dim_sizes)
for loop in outer_temporal_loops:
outer_dim, outer_size = loop.unpack()
extended_layer_dim_sizes[outer_dim] *= outer_size
# # Reconstruct the total, padded layer_dim_sizes as padded tile size * outer_sizes
# extended_layer_dim_sizes = deepcopy(tile_attrs.layer_dim_sizes)
# for loop in outer_temporal_loops:
# outer_dim, outer_size = loop.unpack()
# extended_layer_dim_sizes[outer_dim] *= outer_size

# Loop dimension + size of the tiles (called span here)
tile_span = tile_attrs.layer_dim_sizes
loop_dims = original_node.layer_dims
stop_values = [temporal_loop.size for temporal_loop in outer_temporal_loops]
nb_cns = int(prod(stop_values))
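
Putting the padding and the stripping together, a sketch with illustrative numbers (not from the commit): the padded dimension always divides cleanly into outer iterations times an integer tile span.

from math import ceil, lcm, prod

layer_dim_sizes = {"K": 130}
outer_loops = [("K", 4)]          # total outer size for K is 4
mandatory_divisors = {"K": {4}}   # e.g. from another node in the same stack

total_outer = prod(size for dim, size in outer_loops if dim == "K")
inner = ceil(layer_dim_sizes["K"] / total_outer)   # 33
factor = lcm(*mandatory_divisors["K"])             # 4
layer_dim_sizes["K"] = (ceil(inner / factor) * factor) * total_outer  # 130 -> 144

q, rem = divmod(layer_dim_sizes["K"], 4)
assert rem == 0 and q == 36  # tile span is exact, as the assert in the loop requires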

@@ -281,18 +362,16 @@ def get_tiles(
for n in range(nb_cns):
outer_loop_values: list[int] = []
for i, outer_loop in enumerate(outer_temporal_loops):
loop_dim = outer_loop.dimension
stop_value = outer_loop.size
m = prod(stop_values[:i])
outer_loop_values.append(int((n // m) % stop_value))
outer_loop_values.append((n // m) % stop_value)

dim_min_max: LoopRanges = {}
for loop_dim in loop_dims:
# find all outer-cn loops that iterate over this loop_dim
# and multiply their loop values by their mult_factor
for loop_dim in original_node.layer_dims:
# multiply all outer-cn loop values that iterate over this loop_dim by their mult_factor
dim_min = 0
for i, outer_loop in enumerate(outer_temporal_loops):
dim = outer_loop.dimension
stop_value = outer_loop.size
dim, stop_value = outer_loop.unpack()
if dim == loop_dim:
# current loop value of this outer-cn loop
loop_val = outer_loop_values[i]
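
The fold here hides the mult_factor bookkeeping, so the following is a hedged reconstruction rather than the commit's code: each outer loop on a dimension advances the tile's start offset by its loop value times the span covered by the loops inside it (numbers illustrative).

tile_span_k = 12                    # per-tile size of K after stripping
outer_loops = [("K", 4), ("K", 3)]  # loop 0 varies fastest
n = 7                               # sub-id of this tile
values = [n % 4, (n // 4) % 3]      # [3, 1]

dim_min, mult = 0, tile_span_k
for (dim, stop), val in zip(outer_loops, values):
    if dim == "K":
        dim_min += val * mult       # offset contributed by this loop
        mult *= stop                # outer loops step over span * inner stop values
print((dim_min, dim_min + tile_span_k))  # (84, 96)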
@@ -372,6 +451,9 @@ def get_tiles(

@staticmethod
def get_intra_edges(nodes: list[ComputationNode]):
"""
# TODO Why do we need this?
"""
# Get all the group ids
group_ids = sorted(set([n.group for n in nodes]))
intra_edges: list[tuple[ComputationNode, ComputationNode, dict[str, int]]] = []
@@ -873,7 +955,7 @@ def get_layer_split_factors_k(self):
split_factors[node] = split_factor
return split_factors

def load_cached_tiled_workload(self):
def load_cached_tiled_workload(self) -> ComputationNodeWorkload | None:
if os.path.exists(self.tiled_workload_path):
return pickle_load(self.tiled_workload_path)
return None
12 changes: 9 additions & 3 deletions stream/workload/onnx_workload.py
@@ -33,12 +33,18 @@ class ComputationNodeWorkload(DiGraphWrapper[ComputationNode]):
"""Workload graph with only ComputationNodes"""

def get_sink_layer_ids(self):
"""Return the ids of layers where ALL sub-nodes have out-degree 0"""
"""Return the ids of layers where ALL sub-nodes have out-degree 0
# TODO this might not work yet! When there is intra-core tiling, edges between nodes in the same layer
# TODO (with bits==0) are added, meaning some nodes have an out-degree > 0
# TODO -> use get_real_nb_predecessors instead? or remove the empty intra-core edges?
"""
out_degrees = self.out_degree()
layer_ids = set(n.id for n, _ in out_degrees)
# x: (node, out_degree)
sink_layer_ids = [
all(filter(lambda x: (x[0].id == curr_id and x[1] == 0), out_degrees)) for curr_id in layer_ids
curr_id
for curr_id in layer_ids
# x: (node, out_degree). Filter by id -> map to out_degree == 0 -> check if all are 0
if all(map(lambda x: x[1] == 0, filter(lambda x: x[0].id == curr_id, out_degrees)))
]
return sink_layer_ids
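
To sanity-check the rewritten comprehension, a minimal sketch with namedtuple nodes (illustrative data): the old version wrapped filter in all(), which yields booleans rather than ids, while the new one keeps exactly the layers whose sub-nodes all have out-degree 0.

from collections import namedtuple

Node = namedtuple("Node", "id")
out_degrees = [(Node(0), 1), (Node(0), 0), (Node(1), 0), (Node(1), 0)]
layer_ids = {n.id for n, _ in out_degrees}

sink_layer_ids = [
    curr_id
    for curr_id in layer_ids
    if all(deg == 0 for node, deg in out_degrees if node.id == curr_id)
]
print(sorted(sink_layer_ids))  # [1]; layer 0 still has an outgoing edge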

