Skip to content

Commit

Permalink
Merge pull request #42 from JiacongSun/master
Browse files Browse the repository at this point in the history
Fix Incorrect Spatial Mapping Saving and Update If Condition in SpatialMappingGenerateStage
  • Loading branch information
asyms authored Feb 12, 2024
2 parents fc08868 + 4fb6eb0 commit 9032261
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 21 deletions.
19 changes: 0 additions & 19 deletions zigzag/classes/stages/SpatialMappingConversionStage.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,25 +55,6 @@ def run(self):
spatial_mapping, spatial_mapping_int = self.convert_user_spatial_mapping(
user_spatial_mapping
)
# Since the spatial_mapping may be modified in the previous step,
# we have to update this change to self.layer
updated_user_spatial_mapping = {}
for oa_dim, sm_loop in user_spatial_mapping.items():
if self.is_nested_tuple(sm_loop): # a mix sm loop
sm_comb = []
for sub_sm_loop in sm_loop:
sm_layer_dim = sub_sm_loop[0]
for sm_element in spatial_mapping.spatial_loop_dim_size:
if sm_element[0] == sm_layer_dim:
sm_comb.append(sm_element)
sm_comb = tuple(sm_comb)
updated_user_spatial_mapping[oa_dim] = sm_comb
else:
sm_layer_dim = sm_loop[0]
for sm_element in spatial_mapping.spatial_loop_dim_size:
if sm_element[0] == sm_layer_dim:
updated_user_spatial_mapping[oa_dim] = sm_element
self.layer.user_spatial_mapping = updated_user_spatial_mapping

kwargs = self.kwargs.copy()
kwargs["spatial_mapping"] = spatial_mapping
Expand Down
5 changes: 3 additions & 2 deletions zigzag/classes/stages/SpatialMappingGeneratorStage.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,12 @@ def modify_innermost_input_mem_size(self, core_id, user_spatial_mapping):
if layer_op_to_mem_op[act_operand] in mem_ops:
act_innermost_mem_level = memory_level
act_served_oa_dim: set = memory_level.served_dimensions
act_served_oa_dim_name = list(act_served_oa_dim)[0].name
# check if act is not served in the innermost memories, or it is uti-casting for act.
# keep the spatial loop as it was if act is not served.
if "act_served_oa_dim" not in locals() or len(act_served_oa_dim) == 0:
if "act_served_oa_dim" not in locals() or len(act_served_oa_dim) != 1:
return input_mem_size_updated, self.accelerator
else:
act_served_oa_dim_name = list(act_served_oa_dim)[0].name
# get the mem scaling factor if OX, OY exist
mem_scaling_factor = 1
if (
Expand Down

0 comments on commit 9032261

Please sign in to comment.