Skip to content

Commit

Permalink
update SeachUnusedMemoryStage and keep top weight mem to be a mem tha…
Browse files Browse the repository at this point in the history
…t serves all hardware dims
  • Loading branch information
JiacongSun committed Nov 13, 2023
1 parent 79df7f8 commit aff0de7
Showing 1 changed file with 30 additions and 3 deletions.
33 changes: 30 additions & 3 deletions zigzag/classes/stages/SearchUnusedMemoryStage.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,16 @@ def update_top_mem_level(self):
if (
const_operand in served_operands
): # identify the top weight mem level
# We need to check if the current mem serve all oa dims, otherwise we will not decrease
# the mem_update_weight.
# The reason is if the current mem not serve all oa dims, the mapping will impact the memory
# utilization, so solely comparing with total memory size will be incorrect.
mem_serve_all_oa_dims = self.check_if_mem_serve_all_oa_dims(
mem, self.accelerator
)
if (
curr_mem_level < self.mem_update_weight
): # mem_update_weight is bigger than the top weight mem level
) and mem_serve_all_oa_dims: # mem_update_weight is bigger than the top weight mem level
self.mem_update_weight = curr_mem_level
break
else: ## node (layer) that is not a branch starting node or a branch final node
Expand Down Expand Up @@ -402,9 +409,18 @@ def update_top_mem_level(self):
self.update_IO_mem_level(
curr_id, output_operand, curr_mem_level
) # update output mem level
# For weight, we need to check if the current mem serve all oa dims, otherwise we will not
# decrease the mem_update_weight.
# The reason is if the current mem not serve all oa dims, the mapping will impact the memory
# utilization, so solely comparing with total memory size will be incorrect.
mem_serve_all_oa_dims = self.check_if_mem_serve_all_oa_dims(
mem, self.accelerator
)
if (
curr_mem_level < self.mem_update_weight
) and mem_serve_weight: # update weight mem level
(curr_mem_level < self.mem_update_weight)
and mem_serve_all_oa_dims
and mem_serve_weight
): # update weight mem level
self.mem_update_weight = curr_mem_level
## [OPTIONAL CHECK] assert check if there is -1 value in mem_update_list
## [NOTE] Until here, if there is still -1 value in mem_update_list, it means the size of top mem level for IO is not big enough.
Expand All @@ -414,6 +430,17 @@ def update_top_mem_level(self):
list(operand_dict.values())[0] >= 0
), "SearchUnusedMemoryStage fisnishes abnormally, there are still layers with top mem levels not figured out."

def check_if_mem_serve_all_oa_dims(self, mem, accelerator):
# check if mem serve all hardare dimensions
core = accelerator.cores[0]
operational_array = core.operational_array
oa_dim_nb = len(operational_array.dimensions)
mem_served_oa_dim_nb = len(mem.served_dimensions)
if mem_served_oa_dim_nb == oa_dim_nb:
return True
else:
return False

def update_mem_level_for_loading_data(self):
"""
[OPTIONAL FUNCTION] This is an optional function.
Expand Down

0 comments on commit aff0de7

Please sign in to comment.