diff --git a/zigzag/classes/stages/SpatialMappingConversionStage.py b/zigzag/classes/stages/SpatialMappingConversionStage.py index 617921e7..c6d29891 100644 --- a/zigzag/classes/stages/SpatialMappingConversionStage.py +++ b/zigzag/classes/stages/SpatialMappingConversionStage.py @@ -55,25 +55,6 @@ def run(self): spatial_mapping, spatial_mapping_int = self.convert_user_spatial_mapping( user_spatial_mapping ) - # Since the spatial_mapping may be modified in the previous step, - # we have to update this change to self.layer - updated_user_spatial_mapping = {} - for oa_dim, sm_loop in user_spatial_mapping.items(): - if self.is_nested_tuple(sm_loop): # a mix sm loop - sm_comb = [] - for sub_sm_loop in sm_loop: - sm_layer_dim = sub_sm_loop[0] - for sm_element in spatial_mapping.spatial_loop_dim_size: - if sm_element[0] == sm_layer_dim: - sm_comb.append(sm_element) - sm_comb = tuple(sm_comb) - updated_user_spatial_mapping[oa_dim] = sm_comb - else: - sm_layer_dim = sm_loop[0] - for sm_element in spatial_mapping.spatial_loop_dim_size: - if sm_element[0] == sm_layer_dim: - updated_user_spatial_mapping[oa_dim] = sm_element - self.layer.user_spatial_mapping = updated_user_spatial_mapping kwargs = self.kwargs.copy() kwargs["spatial_mapping"] = spatial_mapping diff --git a/zigzag/classes/stages/SpatialMappingGeneratorStage.py b/zigzag/classes/stages/SpatialMappingGeneratorStage.py index 8f7bf9fc..271f9350 100644 --- a/zigzag/classes/stages/SpatialMappingGeneratorStage.py +++ b/zigzag/classes/stages/SpatialMappingGeneratorStage.py @@ -215,11 +215,12 @@ def modify_innermost_input_mem_size(self, core_id, user_spatial_mapping): if layer_op_to_mem_op[act_operand] in mem_ops: act_innermost_mem_level = memory_level act_served_oa_dim: set = memory_level.served_dimensions - act_served_oa_dim_name = list(act_served_oa_dim)[0].name # check if act is not served in the innermost memories, or it is uti-casting for act. # keep the spatial loop as it was if act is not served. - if "act_served_oa_dim" not in locals() or len(act_served_oa_dim) == 0: + if "act_served_oa_dim" not in locals() or len(act_served_oa_dim) != 1: return input_mem_size_updated, self.accelerator + else: + act_served_oa_dim_name = list(act_served_oa_dim)[0].name # get the mem scaling factor if OX, OY exist mem_scaling_factor = 1 if (