Skip to content

Commit

Permalink
Reduce quantization time for large models
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 689148013
  • Loading branch information
v-dziuba authored and copybara-github committed Oct 23, 2024
1 parent 729b76e commit 1d6688a
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions ai_edge_quantizer/model_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,13 @@ def _process_constant_map(
buffer_size += len(buffer.data)
return buffer_size

def _pad_bytearray(self, bytearr: bytearray):
"""Pad the bytearray to 16 bytes."""
remainder = len(bytearr) % 16
if remainder != 0:
padding_size = 16 - remainder
bytearr.extend(b'\0' * padding_size)

# TODO: b/333797307 - support > 2GB output model
def _serialize_large_model(
self, quantized_model: schema_py_generated.ModelT
Expand All @@ -123,36 +130,32 @@ def _serialize_large_model(
buffer.data = None
buffer.offset = 1
buffer.size = 1
dummy_bytearray = flatbuffer_utils.convert_object_to_bytearray(
quantized_model
dummy_bytearray = bytearray(
flatbuffer_utils.convert_object_to_bytearray(quantized_model)
)
# calculate the correct buffer size and offset
while len(dummy_bytearray) % 16:
dummy_bytearray += b'\0'
self._pad_bytearray(dummy_bytearray)
for buffer_idx, buffer in enumerate(quantized_model.buffers):
buffer_data = self._constant_map[buffer_idx]
if buffer_data is None:
continue
buffer.offset = len(dummy_bytearray)
buffer.size = len(buffer_data)
dummy_bytearray += buffer_data
while len(dummy_bytearray) % 16:
dummy_bytearray += b'\0'
self._pad_bytearray(dummy_bytearray)
del dummy_bytearray

# build new tflite file with correct buffer offset
model_bytearray = flatbuffer_utils.convert_object_to_bytearray(
quantized_model
model_bytearray = bytearray(
flatbuffer_utils.convert_object_to_bytearray(quantized_model)
)
while len(model_bytearray) % 16:
model_bytearray += b'\0'
self._pad_bytearray(model_bytearray)
for buffer_idx, _ in enumerate(quantized_model.buffers):
buffer_data = self._constant_map[buffer_idx]
if buffer_data is None:
continue
model_bytearray += buffer_data
while len(model_bytearray) % 16:
model_bytearray += b'\0'
self._pad_bytearray(model_bytearray)
return model_bytearray

def _serialize_small_model(
Expand Down

0 comments on commit 1d6688a

Please sign in to comment.