diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 2ac354e6d7..3e34f1d465 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -26,4 +26,5 @@ repos:
         alias: ruff-isolated
         args:
           - --isolated
-          - select F821,F823,W191
+          - --select
+          - F821,F823,W191
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index b13f1d16a5..15166aca0d 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -426,6 +426,12 @@ def from_float(cls, weight):

         # avoid circular dep
         from torchao.dtypes import to_affine_quantized_intx
+        from torchao.quantization.quant_api import (
+            _int8_symm_per_token_reduced_range_quant,
+        )
+
+        # input settings
+        input_quant_func = _int8_symm_per_token_reduced_range_quant

         # weight settings
         mapping_type = MappingType.SYMMETRIC
@@ -436,32 +442,9 @@ def get_weight_block_size(x):
         target_dtype = torch.int8
         eps = torch.finfo(torch.float32).eps
         zero_point_dtype = torch.int64
-
-        # input settings
-        def get_per_token_block_size(x):
-            block_size = list(x.shape)
-            for i in range(len(block_size) - 1):
-                block_size[i] = 1
-            return block_size
-
-        input_mapping_type = MappingType.SYMMETRIC
-        input_target_dtype = torch.int8
-        input_eps = 1e-5
-        input_quant_min = -127
-        input_quant_max = 127
         _layout = cls.layout
-        input_quant_func = lambda x: to_affine_quantized_intx(
-            x,
-            input_mapping_type,
-            get_per_token_block_size(x),
-            input_target_dtype,
-            eps=input_eps,
-            quant_min=input_quant_min,
-            quant_max=input_quant_max,
-            scale_dtype=torch.float32 if x.dtype == torch.float16 else None,
-        )
-
         block_size = get_weight_block_size(weight)
+
         weight = to_affine_quantized_intx(
             weight,
             mapping_type,
@@ -937,6 +920,7 @@ def get_per_token_block_size(x):

         input_target_dtype = torch.float8_e4m3fn
         _layout = Float8Layout(mm_config=Float8MMConfig(use_fast_accum=True))
+        # TODO: make this serializable
         input_quant_func = lambda x: _input_activation_quant_func_fp8(
             x=x,
             activation_granularity=cls.activation_granularity,
@@ -980,6 +964,7 @@ def get_weight_block_size(x):

         input_target_dtype = torch.float8_e4m3fn
         _layout = Float8Layout(mm_config=Float8MMConfig(use_fast_accum=True))
+        # TODO: make this serializable
        input_quant_func = lambda x: _input_activation_quant_func_fp8(
             x=x,
             activation_granularity=cls.activation_granularity,
@@ -1287,3 +1272,7 @@ def finalize_autoquant():
         model(*example_input)

     return model
+
+
+if TORCH_VERSION_AT_LEAST_2_5:
+    torch.serialization.add_safe_globals(ALL_AUTOQUANT_CLASS_LIST)
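
Not part of the patch: a minimal sketch of what the new add_safe_globals(ALL_AUTOQUANT_CLASS_LIST) registration is meant to enable on torch >= 2.5, namely round-tripping an autoquantized model's state_dict through torch.save / torch.load with weights_only=True. The model, shapes, dtype, and file name below are illustrative placeholders, not part of this change.

# Illustrative sketch (assumptions: torch >= 2.5, CUDA available, torchao installed).
import torch
import torchao

model = torch.nn.Sequential(torch.nn.Linear(64, 64)).to("cuda").to(torch.bfloat16)
model = torchao.autoquant(model)
# Run once so autoquant benchmarks candidates and finalizes the chosen quantized weights.
model(torch.randn(8, 64, device="cuda", dtype=torch.bfloat16))

torch.save(model.state_dict(), "autoquant_sd.pt")
# weights_only=True loading relies on the autoquant tensor subclasses being registered
# as safe globals, which is what the add_safe_globals call appended in this diff does.
state_dict = torch.load("autoquant_sd.pt", weights_only=True)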