use latency_attr to determine latency value to use from cme; add perfetto waco visualization
asyms committed Dec 18, 2024
1 parent 7803915 commit 52b4274
Showing 7 changed files with 220 additions and 71 deletions.
9 changes: 8 additions & 1 deletion stream/opt/allocation/constraint_optimization/allocation.py
@@ -26,6 +26,7 @@ def get_optimal_allocations(
iterations: int,
gap: float = 0.5,
time_limit: int = 600,
latency_attr: str = "latency_total1",
) -> ALLOCATION_T:
core_ids = sorted((core.id for core in accelerator.cores.node_list if core.id != accelerator.offchip_core_id))
core_capacities = get_core_capacities(accelerator, MemoryOperand("I2"), core_ids)
@@ -34,7 +35,13 @@
ids = convert_ids(nodes)

latencies, possible_allocation_splits = get_latencies(
nodes, core_ids, accelerator, cost_lut, impossible_lat=0, ids=ids
nodes,
core_ids,
accelerator,
cost_lut,
impossible_lat=0,
ids=ids,
latency_attr=latency_attr,
)
energies = get_energies(nodes, core_ids, accelerator, cost_lut, impossible_energy=0, ids=ids)
output_operand = LayerOperand("O")
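With the new keyword, callers can steer the constraint optimizer toward a different latency definition. A hypothetical invocation (only gap, time_limit, and latency_attr are confirmed by the signature above; the leading positional parameters are hidden above this hunk, so their names are assumed):

```python
# Hypothetical call; the leading positional arguments are assumed, not shown in the diff.
allocation = get_optimal_allocations(
    workload,                       # assumed: the steady-state nodes to allocate
    accelerator,
    cost_lut,
    iterations=10,
    gap=0.5,                        # allowed optimality gap for the solver
    time_limit=600,                 # solver time budget in seconds
    latency_attr="latency_total2",  # which CME latency attribute drives the model
)
```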
103 changes: 102 additions & 1 deletion stream/opt/allocation/constraint_optimization/utils.py
@@ -3,6 +3,7 @@
from zigzag.datatypes import LayerDim, LayerOperand, UnrollFactor

from stream.hardware.architecture.accelerator import Accelerator
from stream.hardware.architecture.core import Core
from stream.utils import CostModelEvaluationLUT
from stream.workload.computation.computation_node import ComputationNode

@@ -50,6 +51,7 @@ def get_latencies(
cost_lut: CostModelEvaluationLUT,
impossible_lat: float = 1e11,
ids: dict[ComputationNode, int] = {},
latency_attr: str = "latency_total1",
) -> tuple[dict[tuple[int, str, int], int], dict]:
if not ids:
ids = {node: node.id for node in nodes}
@@ -74,7 +76,7 @@
inter_core_tiling_dims = [layer_dim for layer_dim, _ in node.inter_core_tiling]
inter_core_tiling_size = get_loop_size(temporal_loops, inter_core_tiling_dims)
inter_core_tiling_sizes[(node_id, core_name)] = inter_core_tiling_size
lat = cme.latency_total1
lat = getattr(cme, latency_attr)
possible_allocations[node_id].append(core_name)
except ValueError:
lat = impossible_lat
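The getattr swap above is the core of this commit: latency_attr names which latency attribute to read from a ZigZag cost model evaluation instead of hard-coding latency_total1. A minimal stand-in sketch (the attribute values, and the exact semantics of each latency_total* field, are assumptions, not taken from zigzag):

```python
class FakeCME:
    """Stand-in for zigzag's CostModelEvaluation, for illustration only."""

    latency_total1 = 1200  # assumed: latency including data onloading
    latency_total2 = 1500  # assumed: latency_total1 plus output offloading


cme = FakeCME()
for latency_attr in ("latency_total1", "latency_total2"):
    print(latency_attr, getattr(cme, latency_attr))  # 1200, then 1500
```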
@@ -130,3 +132,102 @@ def get_energies(
energies[(ids[node], core_name)] = en

return energies


def get_k_splits(allocation):
k_splits: dict[int, list[Core]] = {}
for _, core, id in allocation:
k_splits[id] = k_splits.get(id, []) + [core]
return k_splits


def get_node_latencies(allocation, cost_lut, accelerator, k_splits, latency_attr):
node_latencies = {}
core_names = sorted(set([a for _, a, _ in allocation]))
core_ids = [int(core_name.split(" ")[-1]) for core_name in core_names]
for _, a, id in allocation:
node = next(n for n in cost_lut.get_nodes() if n.id == id[0])
latencies, _ = get_latencies([node], core_ids, accelerator, cost_lut, latency_attr=latency_attr)
nb_k_splits = len(k_splits[id])
lat = latencies[(node.id, a, nb_k_splits)]
node_latencies[id, a] = lat
return node_latencies


def get_layer_ids(allocation):
layer_ids: set[int] = set()
for _, _, id in allocation:
layer_ids.add(id[0])
layer_ids = sorted(layer_ids)
return layer_ids


def get_timesteps(allocation) -> list[int]:
return [item[0] for item in allocation]


def get_resources(allocation) -> set[int]:
return set(item[1] for item in allocation)


def get_node_timesteps(allocation):
node_timesteps = {}
for t, a, id in allocation:
node_timesteps[id, a] = t
return node_timesteps


def get_timestep_latencies(allocation, node_latencies, timesteps):
timestep_latencies = {t: 0 for t in range(max(timesteps) + 1)}
for t, a, id in allocation:
timestep_latencies[t] = max(timestep_latencies.get(t, 0), node_latencies[id, a])
return timestep_latencies


def get_node_start_timesteps(k_splits, node_timesteps, timestep_latencies):
starts = {}
for id, allocations in k_splits.items():
for a in allocations:
start = get_start_time_of_node(id, a, node_timesteps, timestep_latencies)
starts[id, a] = start
return starts


def get_start_time_of_node(id, a, timesteps, timestep_latencies, t_start=0):
node_timestep = timesteps[id, a]
for t in range(node_timestep):
t_end = t_start + timestep_latencies[t]
t_start = t_end
return t_start


def calculate_total_latency(allocation, cost_lut, accelerator, iterations, latency_attr) -> tuple[int, str]:
k_splits = get_k_splits(allocation)
timesteps = get_timesteps(allocation)
node_latencies = get_node_latencies(allocation, cost_lut, accelerator, k_splits, latency_attr)
timestep_latencies = get_timestep_latencies(allocation, node_latencies, timesteps)
node_timesteps = get_node_timesteps(allocation)
starts = get_node_start_timesteps(k_splits, node_timesteps, timestep_latencies)
total_timestep_latency = sum(timestep_latencies.values())
cores = sorted(set(k[1] for k in starts))
overlap = compute_iterations_overlap(timestep_latencies, node_timesteps, starts, total_timestep_latency, cores)
total_lat = iterations * total_timestep_latency - (iterations - 1) * overlap
total_lat_str = f"total_lat = N * T - (N - 1) * overlap --> {total_lat} = {iterations} * {total_timestep_latency} - {iterations-1} * {overlap}"
return total_lat, total_lat_str


def compute_iterations_overlap(timestep_latencies, node_timesteps, starts, T, cores):
slacks = {}
for core in cores:
relevant_starts = [v for k, v in starts.items() if k[1] == core]
earliest_start = min(relevant_starts)
latest_start = max(relevant_starts)
latest_id_core = next((k for k, v in starts.items() if v == latest_start and k[1] == core))
latest_timestep = node_timesteps[latest_id_core]
timestep_latency = timestep_latencies[latest_timestep]
latest_end = latest_start + timestep_latency
slack = T - latest_end + earliest_start
assert slack >= 0
slacks[core] = slack
overlap = min(slacks.values())
return overlap
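To make the aggregation concrete, here is the closed-form total latency that calculate_total_latency logs, evaluated on toy numbers (the values are invented):

```python
# Toy illustration of: total_lat = N * T - (N - 1) * overlap,
# where T is the summed per-timestep latency of one steady-state iteration and
# overlap is the smallest per-core slack, i.e. how far consecutive iterations
# can be pipelined into each other.
iterations = 4   # N
T = 1_000        # cycles for one full iteration
overlap = 250    # min over cores of the slack computed above
total_lat = iterations * T - (iterations - 1) * overlap
assert total_lat == 3_250  # 4 * 1000 - 3 * 250
```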
4 changes: 3 additions & 1 deletion stream/opt/allocation/genetic_algorithm/fitness_evaluator.py
@@ -35,6 +35,7 @@ def __init__(
layer_groups_flexible,
operands_to_prefetch: list[LayerOperand],
scheduling_order: list[tuple[int, int]],
latency_attr: str,
) -> None:
super().__init__(workload, accelerator, cost_lut)

@@ -44,6 +45,7 @@ def __init__(
self.layer_groups_flexible = layer_groups_flexible
self.operands_to_prefetch = operands_to_prefetch
self.scheduling_order = scheduling_order
self.latency_attr = latency_attr

def get_fitness(self, core_allocations: list[int], return_scme: bool = False):
"""Get the fitness of the given core_allocations
@@ -88,7 +90,7 @@ def set_node_core_allocations(self, core_allocations: list[int]):
assert equal_unique_node is not None, "Node not found in CostModelEvaluationLUT"
cme = self.cost_lut.get_cme(equal_unique_node, core)
onchip_energy = cme.energy_total # Initialize on-chip energy as total energy
latency = cme.latency_total1
latency = getattr(cme, self.latency_attr)
too_large_operands = get_too_large_operands(cme, self.accelerator, core_id=core_allocation)
# If there is a too_large_operand, we separate the off-chip energy.
offchip_energy = 0
19 changes: 16 additions & 3 deletions stream/stages/allocation/constraint_optimization_allocation.py
@@ -12,6 +12,7 @@
from stream.cost_model.cost_model import StreamCostModelEvaluation
from stream.hardware.architecture.accelerator import Accelerator
from stream.opt.allocation.constraint_optimization.allocation import ALLOCATION_T, get_optimal_allocations
from stream.opt.allocation.constraint_optimization.utils import calculate_total_latency
from stream.stages.estimation.stream_cost_model_evaluation import StreamCostModelEvaluationStage
from stream.stages.estimation.zigzag_core_mapping_estimation import ZigZagCoreMappingEstimationStage
from stream.stages.generation.tiled_workload_generation import (
@@ -20,7 +21,7 @@
from stream.stages.set_fixed_allocation_performance import SetFixedAllocationPerformanceStage
from stream.stages.stage import MainStage, Stage, StageCallable
from stream.utils import CostModelEvaluationLUT
from stream.visualization.constraint_optimization import visualize_waco
from stream.visualization.constraint_optimization import to_perfetto_json, visualize_waco
from stream.workload.computation.computation_node import ComputationNode
from stream.workload.dnn_workload import DNNWorkloadStream
from stream.workload.mapping import TILING_T
@@ -192,13 +193,21 @@ def extract_steady_state_per_stack(self):
logger.info(f"Percentage of steady state macs: {nb_steady_state_macs}/{nb_macs} = {percentage_macs:.2f}%")

def find_best_allocation_per_stack(self):
total_ss_latency = 0
for stack, to_compute in self.ss_to_computes.items():
iterations = self.ss_iterations_per_stack[stack]
t_start = time()
optimal_allocation = self.find_best_allocation(to_compute, iterations, stack, self.co_time_limit)
ss_latency, _ = calculate_total_latency(
optimal_allocation, self.cost_lut, self.accelerator, iterations, self.latency_attr
)
t_end = time()
logger.info(f"Stack {stack}: {t_end - t_start:.3f} seconds")
logger.info(
f"Stack {stack}: Optimization took {t_end - t_start:.3f} seconds; Predicted steady-state latency: {ss_latency} cycles"
)
self.optimal_allocation_per_stack[stack] = optimal_allocation
total_ss_latency += ss_latency
logger.info(f"Total steady-state latency across stacks: {total_ss_latency} cycles")

def find_best_allocation(
self, to_compute: set[ComputationNode], iterations: int, stack: STACK_T = (0,), time_limit: int = 600
@@ -218,10 +227,14 @@ def find_best_allocation(
self.cost_lut,
iterations,
time_limit=time_limit,
latency_attr=self.latency_attr,
)
pickle_save(allocation, stack_allocations_path)
fig_path = stack_allocations_path.replace(".pickle", ".html")
visualize_waco(allocation, self.cost_lut, self.accelerator, fig_path, iterations)
visualize_waco(allocation, self.cost_lut, self.accelerator, iterations, self.latency_attr, fig_path)
json_path = stack_allocations_path.replace(".pickle", ".json")
to_perfetto_json(allocation, self.cost_lut, self.accelerator, iterations, self.latency_attr, json_path)

return allocation

def get_scheduling_order(self, unpartitioned_workload: DNNWorkloadStream) -> SCHEDULE_ORDER_T:
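Perfetto's UI opens traces in the Chrome JSON trace-event format, which is presumably what the new to_perfetto_json export targets; its actual schema is not part of this diff, so the following is only a plausible sketch that reuses the starts and node_latencies dictionaries from the utils above, with all field choices assumed:

```python
import json


def allocation_to_trace_json(starts, node_latencies, path):
    """Dump one steady-state iteration as Chrome trace events, one track per core."""
    core_tids = {core: tid for tid, core in enumerate(sorted({c for _, c in starts}))}
    events = [
        {
            "name": f"node {node_id}",
            "cat": "steady_state",
            "ph": "X",               # complete event: timestamp + duration
            "ts": start,             # cycles (Perfetto renders them as microseconds)
            "dur": node_latencies[node_id, core],
            "pid": 0,
            "tid": core_tids[core],  # one thread track per core
        }
        for (node_id, core), start in starts.items()
    ]
    with open(path, "w") as f:
        json.dump({"traceEvents": events}, f)
```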
3 changes: 3 additions & 0 deletions stream/stages/allocation/genetic_algorithm_allocation.py
@@ -39,6 +39,7 @@ def __init__(
nb_ga_individuals: int,
operands_to_prefetch: list[LayerOperand],
scheduling_order: list[tuple[int, int]],
latency_attr: str,
**kwargs: Any,
):
"""Initialize the InterCoreMappingStage.
@@ -59,6 +60,7 @@ def __init__(
self.nb_individuals = nb_ga_individuals
self.operands_to_prefetch = operands_to_prefetch
self.scheduling_order = scheduling_order
self.latency_attr = latency_attr

# Determine the set of all (layer, group) combinations to be allocated separately
self.layer_groups: list[tuple[int, int]] = sorted(set((n.id, n.group) for n in self.workload.node_list))
@@ -97,6 +99,7 @@ def __init__(
self.layer_groups_flexible,
self.operands_to_prefetch,
self.scheduling_order,
self.latency_attr,
)

# Extract the length of an individual.
2 changes: 1 addition & 1 deletion stream/stages/set_fixed_allocation_performance.py
@@ -32,7 +32,7 @@ def __init__(
self.accelerator = accelerator
self.workload = workload
self.cost_lut = cost_lut
self.latency_attr = kwargs.get("latency_attr", "latency_total2")
self.latency_attr = kwargs.get("latency_attr", "latency_total1")

def run(self):
logger.info("Start SetFixedAllocationPerformanceStage.")
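The changed default matters because callers that never pass latency_attr now get latency_total1 instead of latency_total2. A minimal sketch of the kwargs.get pattern used here:

```python
def resolve_latency_attr(**kwargs) -> str:
    # Mirrors the stage's lookup: an explicit kwarg wins, else the new default.
    return kwargs.get("latency_attr", "latency_total1")


assert resolve_latency_attr() == "latency_total1"
assert resolve_latency_attr(latency_attr="latency_total2") == "latency_total2"
```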
(diff for the remaining changed file did not load)
