diff --git a/torchchat/utils/quantize.py b/torchchat/utils/quantize.py index 77b03fcba9..abca48d25e 100644 --- a/torchchat/utils/quantize.py +++ b/torchchat/utils/quantize.py @@ -51,6 +51,9 @@ ) +# Holds the error raised while loading the a8wxdq quantizer; None if it loaded successfully. +a8wxdq_load_error: Optional[Exception] = None + ######################################################################### ### torchchat quantization API ### @@ -76,6 +79,10 @@ def quantize_model( quantize_options = json.loads(quantize_options) for quantizer, q_kwargs in quantize_options.items(): + # Test if the a8wxdq quantizer is available; surface the load error if not. + if quantizer == "linear:a8wxdq" and a8wxdq_load_error is not None: + raise Exception(f"Note: Failed to load torchao experimental a8wxdq quantizer with error: {a8wxdq_load_error}") + if ( quantizer not in quantizer_class_dict and quantizer not in ao_quantizer_class_dict @@ -899,4 +906,4 @@ def quantized_model(self) -> nn.Module: print("Slow fallback kernels will be used.") except Exception as e: - print(f"Failed to load torchao experimental a8wxdq quantizer with error: {e}") + a8wxdq_load_error = e