diff --git a/zigzag/stages/SpatialMappingConversionStage.py b/zigzag/stages/SpatialMappingConversionStage.py index 106b3dba..e23e15de 100644 --- a/zigzag/stages/SpatialMappingConversionStage.py +++ b/zigzag/stages/SpatialMappingConversionStage.py @@ -162,7 +162,7 @@ def generate_limited_user_spatial_mapping( # Check 3: Adjust unrolling if it is not a divisor of the layer dimension size # and if there is no more mapping for this layer dimension - no_more_mapping_for_current_layer_dim = self.check_if_there_is_further_oa_mapping_for_current_layer_dim( + no_more_mapping_for_current_layer_dim = self.check_if_oa_dim_mapping_is_first_max( oa_dim=oa_dim, loop_dim_unrolled=layer_dim, user_spatial_mapping=user_spatial_mapping, @@ -240,32 +240,28 @@ def generate_mapping_per_mem_lvl(self, user_spatial_mapping: SpatialMapping) -> mapping_per_mem_lvl[layer_op].append(top_level_spatial_mapping) return mapping_per_mem_lvl - def check_if_there_is_further_oa_mapping_for_current_layer_dim( + def check_if_oa_dim_mapping_is_first_max( self, oa_dim: OADimension, loop_dim_unrolled: LayerDim, user_spatial_mapping: SpatialMapping ): """! For the case when there is layer dimension that is mapped on multiple oa dimensions. We need to decide on which oa dimension to adjust the unrolling if the total unrolling size is not a multiple of the layer dimension size. - In this case, we decide to only adjust the unrolling size on the last oa dimension, - This function is to check if the current oa dimension is the last oa dimension for the current layer dim.""" - start_check_on_layer_dim_mapping = False - no_more_mapping_for_current_layer_dim = True - - for curr_oa_dim, mapping_this_oa_dim in user_spatial_mapping.items(): - if oa_dim == curr_oa_dim: - start_check_on_layer_dim_mapping = True - continue - - if start_check_on_layer_dim_mapping: - for layer_dim, _ in mapping_this_oa_dim.items(): - if loop_dim_unrolled == layer_dim: - no_more_mapping_for_current_layer_dim = False - - # early exit if the flag is already False - if not no_more_mapping_for_current_layer_dim: - break - - return no_more_mapping_for_current_layer_dim + In this case, we decide to only adjust the unrolling size of the first oa dimension with the largest unrolling. + This function is to check if the given oa_dim has the largest unrolling for the given loop_dim_unrolled.""" + + oa_dim_mapping_sizes: list[int] = [] + for mapping in user_spatial_mapping.values(): + layer_dim_mapping_size = mapping[loop_dim_unrolled] if loop_dim_unrolled in mapping.layer_dims else 0 + oa_dim_mapping_sizes.append(layer_dim_mapping_size) + max_mapping_size = max(oa_dim_mapping_sizes) + assert max_mapping_size > 0, f"Given {oa_dim=} is not present in {user_spatial_mapping=}" + first_oa_dim_with_max_mapping = next( + curr_oa_dim + for curr_oa_dim, mapping in user_spatial_mapping.items() + if loop_dim_unrolled in mapping.layer_dims and mapping[loop_dim_unrolled] == max_mapping_size + ) + should_be_limited = oa_dim == first_oa_dim_with_max_mapping + return should_be_limited def calc_unrolled_loop_size_on_early_oa_dims( self, oa_dim: OADimension, loop_dim_unrolled: LayerDim, user_spatial_mapping: SpatialMapping