diff --git a/ai_edge_quantizer/model_modifier.py b/ai_edge_quantizer/model_modifier.py index 434cc7b..07c8198 100644 --- a/ai_edge_quantizer/model_modifier.py +++ b/ai_edge_quantizer/model_modifier.py @@ -102,6 +102,13 @@ def _process_constant_map( buffer_size += len(buffer.data) return buffer_size + def _pad_bytearray(self, bytearr: bytearray): + """Pad the bytearray to 16 bytes.""" + remainder = len(bytearr) % 16 + if remainder != 0: + padding_size = 16 - remainder + bytearr.extend(b'\0' * padding_size) + # TODO: b/333797307 - support > 2GB output model def _serialize_large_model( self, quantized_model: schema_py_generated.ModelT @@ -123,12 +130,11 @@ def _serialize_large_model( buffer.data = None buffer.offset = 1 buffer.size = 1 - dummy_bytearray = flatbuffer_utils.convert_object_to_bytearray( - quantized_model + dummy_bytearray = bytearray( + flatbuffer_utils.convert_object_to_bytearray(quantized_model) ) # calculate the correct buffer size and offset - while len(dummy_bytearray) % 16: - dummy_bytearray += b'\0' + self._pad_bytearray(dummy_bytearray) for buffer_idx, buffer in enumerate(quantized_model.buffers): buffer_data = self._constant_map[buffer_idx] if buffer_data is None: @@ -136,23 +142,20 @@ def _serialize_large_model( buffer.offset = len(dummy_bytearray) buffer.size = len(buffer_data) dummy_bytearray += buffer_data - while len(dummy_bytearray) % 16: - dummy_bytearray += b'\0' + self._pad_bytearray(dummy_bytearray) del dummy_bytearray # build new tflite file with correct buffer offset - model_bytearray = flatbuffer_utils.convert_object_to_bytearray( - quantized_model + model_bytearray = bytearray( + flatbuffer_utils.convert_object_to_bytearray(quantized_model) ) - while len(model_bytearray) % 16: - model_bytearray += b'\0' + self._pad_bytearray(model_bytearray) for buffer_idx, _ in enumerate(quantized_model.buffers): buffer_data = self._constant_map[buffer_idx] if buffer_data is None: continue model_bytearray += buffer_data - while len(model_bytearray) % 16: - model_bytearray += b'\0' + self._pad_bytearray(model_bytearray) return model_bytearray def _serialize_small_model(