From 9bbbc8763ab2a486c91ec7179763f474b50d21ae Mon Sep 17 00:00:00 2001
From: Jack-Khuu
Date: Mon, 30 Sep 2024 09:02:21 -0700
Subject: [PATCH] Raise a8wxdq load errors only when quant scheme is used
 (#1231)

* Show a8wxdq load error only when the quant is used

* Update Error check
---
 torchchat/utils/quantize.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/torchchat/utils/quantize.py b/torchchat/utils/quantize.py
index 77b03fcba9..abca48d25e 100644
--- a/torchchat/utils/quantize.py
+++ b/torchchat/utils/quantize.py
@@ -51,6 +51,9 @@
 )
 
 
+# Error (if any) raised while loading the torchao experimental a8wxdq quantizer; None if it loaded.
+a8wxdq_load_error: Optional[Exception] = None
+
 #########################################################################
 ###                  torchchat quantization API                       ###
 
@@ -76,6 +79,10 @@ def quantize_model(
         quantize_options = json.loads(quantize_options)
 
     for quantizer, q_kwargs in quantize_options.items():
+        # Test if a8wxdq quantizer is available; Surface error if not.
+        if quantizer == "linear:a8wxdq" and a8wxdq_load_error is not None:
+            raise Exception(f"Note: Failed to load torchao experimental a8wxdq quantizer with error: {a8wxdq_load_error}")
+
         if (
             quantizer not in quantizer_class_dict
             and quantizer not in ao_quantizer_class_dict
@@ -899,4 +906,4 @@ def quantized_model(self) -> nn.Module:
             print("Slow fallback kernels will be used.")
 
 except Exception as e:
-    print(f"Failed to load torchao experimental a8wxdq quantizer with error: {e}")
+    a8wxdq_load_error = e