diff --git a/ai_edge_quantizer/calibrator.py b/ai_edge_quantizer/calibrator.py
index f56bdb9..c91066d 100644
--- a/ai_edge_quantizer/calibrator.py
+++ b/ai_edge_quantizer/calibrator.py
@@ -41,6 +41,7 @@ class Calibrator:
   def __init__(
       self,
       float_tflite: Union[str, bytes],
+      num_threads: int = 16,
   ):
     self._flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)
 
@@ -50,7 +51,7 @@ def __init__(
           " the model (e.g., if it is already quantized)."
       )
     self._tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
-        float_tflite
+        float_tflite, use_xnnpack=True, num_threads=num_threads
     )
     # Tensor name to tensor content.
     self._tensor_content_map: dict[str, Any] = {}
diff --git a/ai_edge_quantizer/model_validator.py b/ai_edge_quantizer/model_validator.py
index 7ec9bdf..6ab81cc 100644
--- a/ai_edge_quantizer/model_validator.py
+++ b/ai_edge_quantizer/model_validator.py
@@ -207,7 +207,8 @@ def _setup_validation_interpreter(
     model: bytes,
     signature_input: dict[str, Any],
     signature_key: Optional[str],
-    use_reference_kernel: bool,
+    use_xnnpack: bool,
+    num_threads: int,
 ) -> tuple[Any, int, dict[str, Any]]:
   """Setup the interpreter for validation given a signature key.
 
@@ -216,15 +217,15 @@
   Args:
     model: Model to be validated in bytes.
     signature_input: A dictionary of input tensor name and its value.
     signature_key: The signature key to be used for invoking the models. If
       the model only has one signature, this can be set to None.
-    use_reference_kernel: Whether to use the reference kernel for the
-      interpreter.
+    use_xnnpack: Whether to use xnnpack for the interpreter.
+    num_threads: The number of threads to use for the interpreter.
 
   Returns:
     A tuple of interpreter, subgraph_index and tensor_name_to_details.
   """
   interpreter = utils.create_tfl_interpreter(
-      tflite_model=model, use_reference_kernel=use_reference_kernel
+      tflite_model=model, use_xnnpack=use_xnnpack, num_threads=num_threads
   )
   utils.invoke_interpreter_signature(
       interpreter, signature_input, signature_key
@@ -247,7 +248,8 @@ def compare_model(
     test_data: dict[str, Iterable[dict[str, Any]]],
     error_metric: str,
     compare_fn: Callable[[Any, Any], float],
-    use_reference_kernel: bool = False,
+    use_xnnpack: bool = True,
+    num_threads: int = 16,
 ) -> ComparisonResult:
   """Compares model tensors over a model signature using the compare_fn.
 
@@ -266,8 +268,8 @@ def compare_model(
     compare_fn: a comparison function to be used for calculating the
       statistics, this function must be taking in two ArrayLike strcuture and
       output a single float value.
-    use_reference_kernel: Whether to use the reference kernel for the
-      interpreter.
+    use_xnnpack: Whether to use xnnpack for the interpreter.
+    num_threads: The number of threads to use for the interpreter.
 
   Returns:
     A ComparisonResult object.
@@ -282,12 +284,17 @@ def compare_model(
           reference_model,
           signature_input,
           signature_key,
-          use_reference_kernel,
+          use_xnnpack,
+          num_threads,
       )
   )
   targ_interpreter, targ_subgraph_index, targ_tensor_name_to_details = (
       _setup_validation_interpreter(
-          target_model, signature_input, signature_key, use_reference_kernel
+          target_model,
+          signature_input,
+          signature_key,
+          use_xnnpack,
+          num_threads,
       )
   )
   # Compare the cached tensor values.
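(Reviewer note, not part of the patch: per the docstring above, compare_fn only needs to accept two ArrayLike structures and return a single float. A minimal sketch of such a function, here computing mean squared error; the name mse_compare_fn is illustrative, not from this patch.)

import numpy as np

def mse_compare_fn(reference, target) -> float:
  # Any callable of shape (ArrayLike, ArrayLike) -> float satisfies the
  # compare_fn contract documented in compare_model above.
  reference = np.asarray(reference, dtype=np.float64)
  target = np.asarray(target, dtype=np.float64)
  return float(np.mean((reference - target) ** 2))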
diff --git a/ai_edge_quantizer/quantizer.py b/ai_edge_quantizer/quantizer.py
index cc1765c..670b15c 100644
--- a/ai_edge_quantizer/quantizer.py
+++ b/ai_edge_quantizer/quantizer.py
@@ -216,6 +216,7 @@ def calibrate(
       self,
       calibration_data: dict[str, Iterable[_SignatureInput]],
       previous_calibration_result: Optional[_CalibrationResult] = None,
+      num_threads: int = 16,
   ) -> _CalibrationResult:
     """Calibrates the float model (required by static range quantization).
 
@@ -223,6 +224,7 @@ def calibrate(
       calibration_data: Calibration data for a model signature.
       previous_calibration_result: Previous calibration result to be loaded.
         The calibration process will be resumed from the previous result.
+      num_threads: Number of threads to use for calibration.
 
     Returns:
       Calibration result ({tensor_name: tensor QSVs (e.g.,min/max)}).
@@ -233,7 +235,7 @@
     if not self.need_calibration:
       return {}
 
-    calib = calibrator.Calibrator(self.float_model)
+    calib = calibrator.Calibrator(self.float_model, num_threads=num_threads)
     if previous_calibration_result is not None:
      calib.load_model_qsvs(previous_calibration_result)
     calib.calibrate(calibration_data, self._recipe_manager)
@@ -297,7 +299,8 @@ def validate(
       self,
       test_data: Optional[dict[str, Iterable[_SignatureInput]]] = None,
       error_metrics: str = 'mse',
-      use_reference_kernel: bool = False,
+      use_xnnpack: bool = True,
+      num_threads: int = 16,
   ) -> model_validator.ComparisonResult:
     """Numerical validation of the quantized model for a model signature.
 
@@ -314,7 +317,8 @@ def validate(
         data that will be used for validation. If set to None, random normal
         distributed data will be used for all signatures in the model.
       error_metrics: Error metrics to be used for comparison.
-      use_reference_kernel: Whether to use the reference kernel for validation.
+      use_xnnpack: Whether to use the xnnpack library for validation.
+      num_threads: Number of threads to use for validation.
 
     Returns:
       The comparison result.
@@ -330,7 +334,8 @@ def validate(
         test_data,
         error_metrics,
         validation_utils.get_validation_func(error_metrics),
-        use_reference_kernel=use_reference_kernel,
+        use_xnnpack=use_xnnpack,
+        num_threads=num_threads,
     )
 
   def _get_quantization_params(
diff --git a/ai_edge_quantizer/utils/tfl_interpreter_utils.py b/ai_edge_quantizer/utils/tfl_interpreter_utils.py
index 4d46874..a27a326 100644
--- a/ai_edge_quantizer/utils/tfl_interpreter_utils.py
+++ b/ai_edge_quantizer/utils/tfl_interpreter_utils.py
@@ -30,15 +30,16 @@
 def create_tfl_interpreter(
     tflite_model: Union[str, bytes],
     allocate_tensors: bool = True,
-    use_reference_kernel: bool = False,
+    use_xnnpack: bool = True,
+    num_threads: int = 16,
 ) -> tfl.Interpreter:
   """Creates a TFLite interpreter from a model file.
 
   Args:
     tflite_model: Model file path or bytes.
     allocate_tensors: Whether to allocate tensors.
-    use_reference_kernel: Whether to use the reference kernel for the
-      interpreter.
+    use_xnnpack: Whether to use the XNNPACK delegate for the interpreter.
+    num_threads: The number of threads to use for the interpreter.
 
   Returns:
     A TFLite interpreter.
@@ -47,12 +48,13 @@ def create_tfl_interpreter(
     with gfile.GFile(tflite_model, "rb") as f:
       tflite_model = f.read()
 
-  if use_reference_kernel:
-    op_resolver = tfl.OpResolverType.BUILTIN_REF
+  if use_xnnpack:
+    op_resolver = tfl.OpResolverType.BUILTIN
   else:
     op_resolver = tfl.OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES
   tflite_interpreter = tfl.Interpreter(
       model_content=bytes(tflite_model),
+      num_threads=num_threads,
       experimental_op_resolver_type=op_resolver,
       experimental_preserve_all_tensors=True,
   )
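(Reviewer note, not part of the patch: with use_xnnpack=True the BUILTIN resolver lets TFLite apply its default delegates, XNNPACK among them, while use_xnnpack=False falls back to BUILTIN_WITHOUT_DEFAULT_DELEGATES; the old reference-kernel path via BUILTIN_REF is no longer reachable. Below is a rough sketch of how the new knobs surface through the public API. It assumes a Quantizer instance qt that has already been constructed from a float model and loaded with a recipe, plus calibration_data and test_data dicts in the shapes documented above; those setup steps are not shown in this patch.)

# Calibrate on 8 threads instead of the new default of 16.
calibration_result = qt.calibrate(calibration_data, num_threads=8)

# Validate with the XNNPACK delegate enabled (now the default); pass
# use_xnnpack=False to compare against the non-delegated builtin kernels.
validation_result = qt.validate(test_data, num_threads=8)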