From 04a25e7a877a6865185961927bf379bbebcbcf6f Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Tue, 26 Nov 2024 19:39:12 -0800 Subject: [PATCH 01/40] Benchamarking (#1353) Benchamarking (#1353) Summary: Add benchmarks for experimental torchao kernels. Differential Revision: D66512859 --- torchao/_models/llama/generate.py | 25 ++++++++++++++++- torchao/experimental/temp_build.py | 43 ++++++++++++++++++++++++++++++ torchao/quantization/README.md | 9 +++++++ 3 files changed, 76 insertions(+), 1 deletion(-) create mode 100644 torchao/experimental/temp_build.py diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 862f5d186d..3e32ee356e 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -217,7 +217,6 @@ def main( float8_weight_only, float8_dynamic_activation_float8_weight, ) - from torchao.prototype.quantization.autoquant_v2 import autoquant_v2 from torchao.utils import unwrap_tensor_subclass from torchao.quantization.granularity import PerTensor, PerRow @@ -297,6 +296,29 @@ def main( dtype = _NBITS_TO_DTYPE[nbits] group_size = int(_quant_args[2]) quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq)) + elif "int8_dynamic_activation_intx_weight" in quantization: + from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight + assert precision == torch.float32, "int8_dynamic_activation_intx_weight requires fp32 precision" + + # Build kernels in temp location, and load them in torch + # This requires an ARM CPU + from torchao.experimental.temp_build import temp_build_and_load_torchao_ops + temp_build_and_load_torchao_ops(cmake_lists_path=os.path.dirname(os.path.realpath(__file__)) + "/../../experimental") + + # Quantize model + _quant_args = quantization.split("-") + nbit = int(_quant_args[1]) + assert nbit >= 1 and nbit <= 8, "nbits must be 1 to 8" + group_size = int(_quant_args[2]) + has_weight_zeros = bool(_quant_args[3]) + quantize_( + model, + int8_dynamic_activation_intx_weight( + group_size=group_size, + nbit=nbit, + has_weight_zeros=has_weight_zeros, + ), + ) elif "float8wo" in quantization: quantize_(model, float8_weight_only()) elif "float8dq" in quantization: @@ -309,6 +331,7 @@ def main( granularity = PerTensor() quantize_(model, float8_dynamic_activation_float8_weight(granularity=granularity)) elif "autoquant_v2" in quantization: + from torchao.prototype.quantization.autoquant_v2 import autoquant_v2 from torchao._models._eval import InputRecorder from torchao._models.llama.model import prepare_inputs_for_model diff --git a/torchao/experimental/temp_build.py b/torchao/experimental/temp_build.py new file mode 100644 index 0000000000..fb9d413037 --- /dev/null +++ b/torchao/experimental/temp_build.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import glob +import subprocess +import tempfile +import torch + +def cmake_build_torchao_ops(cmake_lists_path, temp_build_dir): + from distutils.sysconfig import get_python_lib + print("Building torchao ops for ATen target") + cmake_prefix_path = get_python_lib() + subprocess.run( + [ + "cmake", + "-DCMAKE_PREFIX_PATH=" + cmake_prefix_path, + "-DCMAKE_INSTALL_PREFIX=" + temp_build_dir.name, + "-S " + cmake_lists_path, + "-B " + temp_build_dir.name, + ] + ) + subprocess.run( + [ + "cmake", + "--build", + temp_build_dir.name, + "-j 16", + "--target install", + "--config Release", + ] + ) + +def temp_build_and_load_torchao_ops(cmake_lists_path): + temp_build_dir = tempfile.TemporaryDirectory() + cmake_build_torchao_ops(cmake_lists_path, temp_build_dir) + libs = glob.glob(f"{temp_build_dir.name}/lib/libtorchao_ops_aten.*") + libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs)) + assert len(libs) == 1 + torch.ops.load_library(libs[0]) + print(f"TorchAO ops are loaded from {libs[0]}") diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 022fe7d916..3c2eeb08f6 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -333,7 +333,16 @@ We're trying to develop kernels for low bit quantization for intx quantization f You try can out these apis with the `quantize_` api as above alongside the constructor `uintx_weight_only` an example can be found in in `torchao/_models/llama/generate.py`. +### int8_dynamic_activation_intx_weight Quantization +We have kernels that do 8-bit dynamic quantization of activations and uintx groupwise quantization of weights. These kernels are experimental and can only be run on a device with an ARM CPU (e.g., a Mac computers with Apple silicon). The benchmarks below were run on an M1 Mac Pro, with 8 perf cores, and 2 efficiency cores, and 32GB of RAM. In all cases, torch.compile was used. +| Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) | +| ------------- | -------------------------------------------------| --------------| ------------------------| ---------------- | ----------------| +| Llama-3.1-8B | Base (bfloat16) | 1.24 | 18.62 | NA | 15.01 | +| | int8_dynamic_activation_intx_weight-4-256-false | 16.03 | 65.81 | NA | 4.11 | +| | int8_dynamic_activation_intx_weight-3-256-false | 18.94 | 59.97 | NA | 3.17 | + +You try can out these apis with the `quantize_` api as above alongside the constructor `int8_dynamic_activation_intx_weight`. An example can be found in `torchao/_models/llama/generate.py`. ### Automatic Inductor Configuration The `quantize_` and `autoquant` apis now automatically use our recommended inductor configuration setings. You can mimic the same configuration settings for your own experiments by using the `torchao.quantization.utils.recommended_inductor_config_setter` to replicate our recommended configuration settings. Alternatively if you wish to disable these recommended settings, you can use the key word argument `set_inductor_config` and set it to false in the `quantize_` or `autoquant` apis to prevent assignment of those configuration settings. You can also overwrite these configuration settings after they are assigned if you so desire, as long as they are overwritten before passing any inputs to the torch.compiled model. This means that previous flows which referenced a variety of inductor configurations that needed to be set are now outdated, though continuing to manually set those same inductor configurations is unlikely to cause any issues. From 543209b0592b6116fd16420ef7f24fa102a4998f Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 26 Nov 2024 19:54:33 -0800 Subject: [PATCH 02/40] Add floating point options for autoquant and add accuracy measurement (#1355) * Add floating point options for autoquant and add accuracy measurement Summary: * This PR adds float32/float16/bfloat16 as a list of options for autoquant, it converts input/weight/bias/output to the specified dtype * Also adds min_sqnr (https://en.wikipedia.org/wiki/Signal-to-quantization-noise_ratio) to allow users to filter out the quantization methods that has large numerical impact compared to original output Note that we use random generated input activation right now, we can improve this by adding the support for using real inputs Test Plan: python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-fp Reviewers: Subscribers: Tasks: Tags: * update docstring * fix * ruff * skip if no cuda --- examples/sam2_amg_server/server.py | 23 +++- test/integration/test_integration.py | 17 +++ torchao/_models/llama/generate.py | 2 + torchao/quantization/__init__.py | 2 + torchao/quantization/autoquant.py | 165 ++++++++++++++++++++++++--- 5 files changed, 190 insertions(+), 19 deletions(-) diff --git a/examples/sam2_amg_server/server.py b/examples/sam2_amg_server/server.py index d779411c93..066b339c21 100644 --- a/examples/sam2_amg_server/server.py +++ b/examples/sam2_amg_server/server.py @@ -371,6 +371,7 @@ def main(checkpoint_path, baseline=False, fast=False, furious=False, + use_autoquant=False, unittest=False, benchmark=False, profile=None, @@ -399,13 +400,13 @@ def main(checkpoint_path, from torchao._models.sam2.build_sam import build_sam2 from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator from torchao._models.sam2.utils.amg import rle_to_mask - + device = "cuda" sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type) - + logging.info(f"Loading model {sam2_checkpoint} with config {model_cfg}") sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False) - + logging.info(f"Using {points_per_batch} points_per_batch") mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle") @@ -416,6 +417,18 @@ def main(checkpoint_path, if furious: set_furious(mask_generator) + # since autoquant is replicating what furious mode is doing, don't use these two together + elif use_autoquant: + from torchao import autoquant + from torchao.quantization import DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST + mask_generator.predictor.model = autoquant(mask_generator.predictor.model, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40) + + mask_generator.predictor.model.image_encoder = mask_generator.predictor.model.image_encoder.to(torch.float16, min_sqnr=40) + # NOTE: Not baseline feature + mask_generator.predictor._transforms_device = mask_generator.predictor.device + torch.set_float32_matmul_precision('high') + + with open('dog.jpg', 'rb') as f: image_tensor = file_bytes_to_image_tensor(bytearray(f.read())) @@ -494,7 +507,7 @@ async def upload_rle(image: UploadFile = File(...)): await request_queue.put((image_tensor, response_future)) masks = await response_future return masks_to_rle_dict(masks) - + @app.post("/upload") async def upload_image(image: UploadFile = File(...)): image_tensor = file_bytes_to_image_tensor(bytearray(await image.read())) @@ -512,7 +525,7 @@ async def upload_image(image: UploadFile = File(...)): plt.savefig(buf, format='png') buf.seek(0) return StreamingResponse(buf, media_type="image/png") - + # uvicorn.run(app, host=host, port=port, log_level="info") uvicorn.run(app, host=host, port=port) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index ac2403d6dc..663db20b7b 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -1514,6 +1514,23 @@ def forward(self, x): assert not isinstance(model.lin1.weight.weight, AutoQuantizableLinearWeight) model(x_in) + @parameterized.expand(list(itertools.product(["cuda"], COMMON_DTYPES))) + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_autoquant_min_sqnr(self, device, dtype): + m, k, n = 128, 128, 128 + example_input = torch.randn(m, k, device=device, dtype=dtype) + model = torch.nn.Sequential( + torch.nn.ReLU(), + torch.nn.Linear(k,n), + torch.nn.ReLU(), + ).to(device).to(dtype) + out = model(example_input) + torchao.autoquant(model, min_sqnr=60) + out2 = model(example_input) + sqnr = SQNR(out, out2) + # without setting min_sqnr to 60, we get around 45-50 final sqnr + # setting min_sqnr for individual linear to be 60 allows us to achieve >= 50 final sqnr + self.assertTrue(sqnr >= 50, f"sqnr: {sqnr}") diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 3e32ee356e..550c10febb 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -402,6 +402,8 @@ def main( model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs) elif "autoquant-float8" == quantization: model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs) + if "autoquant-fp" == quantization: + model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, example_input=inputs) else: model = autoquant(model, manual=True, example_input=inputs) diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index ff66e23cc9..344bdeea41 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -11,6 +11,7 @@ from .autoquant import ( DEFAULT_AUTOQUANT_CLASS_LIST, + DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, DEFAULT_INT4_AUTOQUANT_CLASS_LIST, OTHER_AUTOQUANT_CLASS_LIST, autoquant, @@ -89,6 +90,7 @@ "autoquant", "DEFAULT_AUTOQUANT_CLASS_LIST", "DEFAULT_INT4_AUTOQUANT_CLASS_LIST", + "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST", "OTHER_AUTOQUANT_CLASS_LIST", # top level API - manual "quantize_", diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 87cb5e2655..1731b6cf39 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -18,7 +18,10 @@ MappingType, ZeroPointDomain, ) -from torchao.quantization.utils import quantize_activation_per_token_absmax +from torchao.quantization.utils import ( + compute_error, + quantize_activation_per_token_absmax, +) from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5 from .granularity import ( @@ -36,6 +39,7 @@ "autoquant", "DEFAULT_AUTOQUANT_CLASS_LIST", "DEFAULT_INT4_AUTOQUANT_CLASS_LIST", + "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST", "OTHER_AUTOQUANT_CLASS_LIST", ] @@ -69,7 +73,15 @@ class AutoQuantizableLinearWeight(torch.Tensor): """ @staticmethod - def __new__(cls, weight, qtensor_class_list, *args, mode=["relu", None], **kwargs): + def __new__( + cls, + weight, + qtensor_class_list, + *args, + mode=["relu", None], + min_sqnr=None, + **kwargs, + ): kwargs["device"] = weight.device kwargs["layout"] = ( kwargs.get("layout") if kwargs.get("layout", False) else weight.layout @@ -82,12 +94,19 @@ def __new__(cls, weight, qtensor_class_list, *args, mode=["relu", None], **kwarg return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] def __init__( - self, weight, qtensor_class_list, *args, mode=["relu", None], **kwargs + self, + weight, + qtensor_class_list, + *args, + mode=["relu", None], + min_sqnr=None, + **kwargs, ): self.weight = weight self.qtensor_class_list = qtensor_class_list self.logged_data = {} self.mode = mode + self.min_sqnr = min_sqnr def __repr__(self): return ( @@ -123,9 +142,25 @@ def tune_autoquant(self, q_cls, shapes_and_dtype, best_time): else torch.randn(bias_shape, dtype=act_dtype, device=self.device) ) try: - res = q_cls._autoquant_test( - act_mat, self.weight, bias, best_time, self.mode + ref_output = AQDefaultLinearWeight._quantized_linear_op( + act_mat, self.weight, bias ) + q_output = q_cls._quantized_linear_op( + act_mat, q_cls.from_float(self.weight), bias + ) + if ( + self.min_sqnr is not None + and (sqnr := compute_error(q_output, ref_output)) + < self.min_sqnr + ): + print( + f"skipping q_cls: {q_cls} because the sqnr is too small, minimum expected sqnr: {self.min_sqnr}, got {sqnr}" + ) + res = torch.inf + else: + res = q_cls._autoquant_test( + act_mat, self.weight, bias, best_time, self.mode + ) except Exception as e: print( f"warning: failed to autoquant {q_cls.__name__} for shape: {shapes_and_dtype} due to {e}" @@ -141,7 +176,7 @@ def to_quantized(self, error_on_unseen, **kwargs): ) elif (self.logged_data == {}) and not error_on_unseen: # default back to non-quantized weight if not seen - self = AQFloatLinearWeight.from_float(self.weight) + self = AQDefaultLinearWeight.from_float(self.weight) return self # only want to print shape (at start) and final result (at end) @@ -194,34 +229,49 @@ def count_shapes(self, do_print=True): print( f">time (all shapes): {cur_time:0.4f}ms for {q_cls}, prev_best: {best_time:0.4f}ms" ) - if best_time >= cur_time: + if cur_time != torch.inf and best_time >= cur_time: best_time = cur_time best_cls = q_cls # if no new benchmarking was done, don't print the final result, it will be the same as for another layer if ran_new_benchmarks: print(f"best_cls={best_cls}\n") + + if best_cls is None: + best_cls = AQDefaultLinearWeight + # TODO handle random cls args/kwargs? or should they be curried? self = best_cls.from_float(self.weight) return self def _apply_fn_to_data(self, fn): return self.__class__( - fn(self.weight), self.qtensor_class_list, dtype=self.dtype, mode=self.mode + fn(self.weight), + self.qtensor_class_list, + dtype=self.dtype, + mode=self.mode, + min_sqnr=self.min_sqnr, ) def __tensor_flatten__(self): - return ["weight"], [self.qtensor_class_list, self.mode, self.dtype, self.shape] + return ["weight"], [ + self.qtensor_class_list, + self.mode, + self.min_sqnr, + self.dtype, + self.shape, + ] @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None ): weight = tensor_data_dict["weight"] - qtensor_class_list, mode, dtype, shape = tensor_attributes + qtensor_class_list, mode, min_sqnr, dtype, shape = tensor_attributes return cls( weight, qtensor_class_list, - mode, + mode=mode, + min_sqnr=min_sqnr, shape=shape if outer_size is None else outer_size, dtype=dtype, strides=outer_stride, @@ -608,7 +658,7 @@ class AQInt4G256WeightOnlyQuantizedLinearWeight( group_size: int = 256 -class AQFloatLinearWeight(torch.Tensor, AQMixin): +class AQDefaultLinearWeight(torch.Tensor, AQMixin): """ A class to be used in concert with AutoQuantizableLinearWeight to provide a default/non-quantized option. Only implements the bare minimum needed to work with the @@ -629,6 +679,81 @@ def from_float(cls, weight): return weight +class AQFloat32LinearWeight(torch.Tensor, AQMixin): + """ + AutoQuantizable version for float32 precision weight + + (also converts input activation and bias to float32, and restores the original precision after + linear) + """ + + def __init__(self): + super().__init__() + + @staticmethod + def _quantized_linear_op(act_mat, w_qtensor, bias): + orig_dtype = act_mat.dtype + return torch.nn.functional.linear( + act_mat.to(torch.float32), + w_qtensor, + bias.to(torch.float32) if bias is not None else bias, + ).to(dtype=orig_dtype) + + @classmethod + def from_float(cls, weight): + return weight.to(torch.float32) + + +class AQBFloat16LinearWeight(torch.Tensor, AQMixin): + """ + AutoQuantizable version for bfloat16 precision weight + + (also converts input activation and bias to bfloat16, and restores the original precision after + linear) + """ + + def __init__(self): + super().__init__() + + @staticmethod + def _quantized_linear_op(act_mat, w_qtensor, bias): + orig_dtype = act_mat.dtype + return torch.nn.functional.linear( + act_mat.to(torch.bfloat16), + w_qtensor, + bias.to(torch.bfloat16) if bias is not None else bias, + ).to(dtype=orig_dtype) + + @classmethod + def from_float(cls, weight): + return weight.to(torch.bfloat16) + + +class AQFloat16LinearWeight(torch.Tensor, AQMixin): + """ + AutoQuantizable version for float16 precision weight + + (also converts input activation and bias to float16, and restores the original precision after + linear) + """ + + def __init__(self): + super().__init__() + + @staticmethod + def _quantized_linear_op(act_mat, w_qtensor, bias): + orig_dtype = act_mat.dtype + return torch.nn.functional.linear( + act_mat.to(torch.float16), + w_qtensor, + bias.to(torch.float16) if bias is not None else bias, + ).to(dtype=orig_dtype) + + @classmethod + def from_float(cls, weight): + return weight.to(torch.float16) + + class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): """ AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight for target_dtype=torch.float8_e4m3fn @@ -742,7 +867,7 @@ def get_weight_block_size(x): # here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison DEFAULT_AUTOQUANT_CLASS_LIST = [ - AQFloatLinearWeight, + AQDefaultLinearWeight, AQInt8WeightOnlyQuantizedLinearWeight, AQInt8WeightOnlyQuantizedLinearWeight2, # AQInt8WeightOnlyQuantizedLinearWeight3, @@ -751,11 +876,17 @@ def get_weight_block_size(x): ] DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [ - AQFloatLinearWeight, + AQDefaultLinearWeight, AQInt8DynamicallyQuantizedLinearWeight, AQInt4G64WeightOnlyQuantizedLinearWeight, ] +DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST = [ + AQFloat32LinearWeight, + AQBFloat16LinearWeight, + AQFloat16LinearWeight, +] + OTHER_AUTOQUANT_CLASS_LIST = [ AQFloat8WeightOnlyQuantizedLinearWeight, AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, @@ -779,6 +910,7 @@ def _change_linears_to_autoquantizable(model, **kwargs): "qtensor_class_list", DEFAULT_AUTOQUANT_CLASS_LIST ) kwargs["mode"] = kwargs.get("mode", ["relu", None]) + kwargs["min_sqnr"] = kwargs.get("min_sqnr", None) from torchao.quantization.quant_api import ( _get_subclass_inserter, _replace_with_custom_fn_if_matches_filter, @@ -853,6 +985,7 @@ def autoquant( manual=False, set_inductor_config=True, supress_autoquant_errors=True, + min_sqnr=None, **aq_kwargs, ): """ @@ -887,6 +1020,9 @@ def autoquant( the user to call model.finalize_autoquant (True) so inputs with several shapes/dtypes can be logged. set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True) supress_autoquant_errors (bool, optional): Whether to suppress errors during autoquantization. (defaults to True) + min_sqnr (float, optional): minimum acceptable signal to quantization noise ration (https://en.wikipedia.org/wiki/Signal-to-quantization-noise_ratio) for output of quantized layer v.s. non-quantized layer, this is used to filter + out quantization methods that causes too large numerical impact, user can start with a resaonable + number like 40 and adjust depending on the result **aq_kwargs: Additional keyword arguments for the autoquantization process. Returns: @@ -919,6 +1055,7 @@ def autoquant( filter_fn=filter_fn, qtensor_class_list=qtensor_class_list, mode=mode, + min_sqnr=min_sqnr, **aq_kwargs, ) From 719440e208e3444fa905522e49b458f51f05626b Mon Sep 17 00:00:00 2001 From: YanbingJiang Date: Wed, 27 Nov 2024 13:23:48 +0800 Subject: [PATCH 03/40] Add Int4CPULayout and update int4 woq (#1278) * Add Int4CPULayout and update int4 woq * Apply automatic Ruff fixes * Fix CI * Remote nightly * Apply automatic Ruff fixes --------- Co-authored-by: github-actions[bot] --- .github/workflows/regression_test.yml | 1 + test/dtypes/test_affine_quantized.py | 55 ++-- test/integration/test_integration.py | 18 +- test/quantization/test_quant_primitives.py | 10 +- torchao/dtypes/__init__.py | 2 + torchao/dtypes/uintx/__init__.py | 2 + .../dtypes/uintx/tensor_core_tiled_layout.py | 266 +++++++++++++++++- torchao/prototype/hqq/README.md | 2 +- torchao/prototype/hqq/hqq_tinygemm_linear.py | 28 +- torchao/quantization/GPTQ.py | 86 ++++-- torchao/quantization/qat/linear.py | 20 +- torchao/quantization/quant_api.py | 3 +- torchao/quantization/subclass.py | 30 +- torchao/quantization/utils.py | 18 +- 14 files changed, 448 insertions(+), 93 deletions(-) diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml index d9649b7f7e..0488e6d922 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_test.yml @@ -70,6 +70,7 @@ jobs: torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121' gpu-arch-type: "cuda" gpu-arch-version: "12.1" + - name: CPU 2.3 runs-on: linux.4xlarge torch-spec: 'torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu' diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index e049500e3b..9e9144c601 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -8,7 +8,7 @@ run_tests, ) -from torchao.dtypes import SemiSparseLayout +from torchao.dtypes import Int4CPULayout, SemiSparseLayout from torchao.quantization import ( float8_weight_only, int4_weight_only, @@ -17,12 +17,12 @@ int8_weight_only, ) from torchao.quantization.quant_primitives import MappingType -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6 is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) -def get_quantization_functions(do_sparse: bool, do_int4: bool): +def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"): base_functions = [ int8_weight_only(), int8_dynamic_activation_int4_weight(), @@ -30,7 +30,12 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool): int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC), ] if do_int4: - base_functions.append(int4_weight_only(group_size=32)) + if device == "cpu" and TORCH_VERSION_AT_LEAST_2_6: + base_functions.append( + int4_weight_only(group_size=32, layout=Int4CPULayout()) + ) + else: + base_functions.append(int4_weight_only(group_size=32)) if do_sparse: base_functions.append( @@ -152,30 +157,28 @@ class TestAffineQuantizedBasic(TestCase): COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) COMMON_DTYPES = [torch.bfloat16] - @common_utils.parametrize("apply_quant", get_quantization_functions(False, True)) @common_utils.parametrize("device", COMMON_DEVICES) @common_utils.parametrize("dtype", COMMON_DTYPES) - def test_flatten_unflatten(self, apply_quant, device, dtype): - if device == "cpu": - self.skipTest(f"Temporarily skipping for {device}") - - linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) - ql = apply_quant(linear) - lp_tensor = ql.weight - tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__() - tensor_data_dict = { - name: getattr(lp_tensor, name) for name in tensor_data_name_dict - } - outer_size = lp_tensor.size() - outer_stride = lp_tensor.stride() - reconstructed = type(lp_tensor).__tensor_unflatten__( - tensor_data_dict, tensor_attributes, outer_size, outer_stride - ) - example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),) - ref = ql(*example_inputs) - ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False) - reconstruct_res = ql(*example_inputs) - self.assertEqual(reconstruct_res, ref) + def test_flatten_unflatten(self, device, dtype): + apply_quant_list = get_quantization_functions(False, True, device) + for apply_quant in apply_quant_list: + linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) + ql = apply_quant(linear) + lp_tensor = ql.weight + tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__() + tensor_data_dict = { + name: getattr(lp_tensor, name) for name in tensor_data_name_dict + } + outer_size = lp_tensor.size() + outer_stride = lp_tensor.stride() + reconstructed = type(lp_tensor).__tensor_unflatten__( + tensor_data_dict, tensor_attributes, outer_size, outer_stride + ) + example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),) + ref = ql(*example_inputs) + ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False) + reconstruct_res = ql(*example_inputs) + self.assertEqual(reconstruct_res, ref) common_utils.instantiate_parametrized_tests(TestAffineQuantized) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 663db20b7b..df20c5f03b 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -19,7 +19,7 @@ from torchao.quantization.dynamic_quant import ( DynamicallyPerAxisQuantizedLinear, ) -from torchao.dtypes import TensorCoreTiledLayout +from torchao.dtypes import TensorCoreTiledLayout, Int4CPULayout from torchao.quantization.quant_api import ( int4_weight_only, int8_weight_only, @@ -93,6 +93,7 @@ is_fbcode, benchmark_model ) +from torchao.dtypes.utils import is_device logger = logging.getLogger("INFO") @@ -133,7 +134,10 @@ def _int8da_int8w_api(mod): change_linear_weights_to_int8_dqtensors(mod) def _int4wo_api(mod): - if TORCH_VERSION_AT_LEAST_2_4: + if is_device(next(mod.parameters()).device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + quantize_(mod, int4_weight_only(layout=Int4CPULayout()), set_inductor_config=False) + unwrap_tensor_subclass(mod) + elif TORCH_VERSION_AT_LEAST_2_4: quantize_(mod, int4_weight_only(), set_inductor_config=False) if not TORCH_VERSION_AT_LEAST_2_5: unwrap_tensor_subclass(mod) @@ -935,10 +939,16 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype): self.skipTest(f"Temporarily skipping for {device}") if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") + layout_list = [] + if device == 'cpu' and TORCH_VERSION_AT_LEAST_2_6: + layout_list.append(Int4CPULayout()) + else: + for inner_k_tiles in [4, 2]: + layout_list.append(TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles)) for test_shape in ([(256, 256, 16)] + ([(256, 256, 8)] if device=='cuda' else [])): for groupsize in [64, 32]: - for inner_k_tiles in [4, 2]: - kwargs = {"groupsize": groupsize, "layout": TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles)} + for layout in layout_list: + kwargs = {"groupsize": groupsize, "layout": layout} def api(mod): kwargs_copy = kwargs.copy() diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 4e0663eb87..78556772d1 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -33,6 +33,7 @@ TORCH_VERSION_AT_LEAST_2_6, is_fbcode, ) +from torchao.dtypes.utils import is_device _SEED = 1234 torch.manual_seed(_SEED) @@ -102,7 +103,8 @@ def _groupwise_affine_quantize_tensor_from_qparams( .reshape_as(w) ) if TORCH_VERSION_AT_LEAST_2_5: - w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8) + if not (is_device(w.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6): + w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8) return w_int4x8 @@ -524,8 +526,10 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self): groupsize = 128 if TORCH_VERSION_AT_LEAST_2_5: - input_uint8 = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8) - w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input_uint8, scales, zeros, n_bit, groupsize) + input_tmp = input + if not (is_device(input.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6): + input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8) + w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input_tmp, scales, zeros, n_bit, groupsize) else: w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize) w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index d1fbacdcb4..00305db348 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -16,6 +16,7 @@ from .nf4tensor import NF4Tensor, to_nf4 from .uintx import ( BlockSparseLayout, + Int4CPULayout, MarlinQQQLayout, MarlinSparseLayout, SemiSparseLayout, @@ -48,4 +49,5 @@ "UintxLayout", "MarlinQQQTensor", "MarlinQQQLayout", + "Int4CPULayout", ] diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index a6059f93a3..8fba2bb678 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -11,6 +11,7 @@ SemiSparseLayout, ) from .tensor_core_tiled_layout import ( + Int4CPULayout, TensorCoreTiledLayout, ) from .uintx_layout import ( @@ -23,5 +24,6 @@ "MarlinSparseLayout", "SemiSparseLayout", "TensorCoreTiledLayout", + "Int4CPULayout", "MarlinQQQLayout", ] diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py index ced3fc8922..df79b653e8 100644 --- a/torchao/dtypes/uintx/tensor_core_tiled_layout.py +++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py @@ -13,7 +13,12 @@ ) from torchao.dtypes.utils import AQTTensorImpl, Layout, is_device from torchao.quantization.quant_primitives import ZeroPointDomain, _get_reduction_params -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, fill_defaults, find_multiple +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + TORCH_VERSION_AT_LEAST_2_6, + fill_defaults, + find_multiple, +) aten = torch.ops.aten @@ -71,9 +76,14 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): # groupwise int4 quantization groupsize = weight_tensor.block_size[1] - y = torch.ops.aten._weight_int4pack_mm( - act_mat.contiguous(), packed_weight, groupsize, scale_and_zero - ) + if is_device(input_tensor.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + y = torch.ops.aten._weight_int4pack_mm_for_cpu( + act_mat.contiguous(), packed_weight, groupsize, scale_and_zero + ) + else: + y = torch.ops.aten._weight_int4pack_mm( + act_mat.contiguous(), packed_weight, groupsize, scale_and_zero + ) # remove out_feature padding orig_out_features = weight_tensor.shape[-2] @@ -383,3 +393,251 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: def get_layout(self) -> Layout: return self._layout + + +@dataclass(frozen=True) +class Int4CPULayout(Layout): + """Only for PyTorch version at least 2.6""" + + pass + + +@register_layout(Int4CPULayout) +class Int4CPUAQTTensorImpl(AQTTensorImpl): + """ + TensorImpl for int4 CPU layout for affine quantized tensor, this is for int4 only, + used by tinygemm kernels `_weight_int4pack_mm_for_cpu` + It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 2-d tensor of + dimension: [n][k / 2] (uint8 dtype) + (unpacked Tensor shape is n * k) + Note: we also pack scale and zero point together here for tinygemm kernel + Note: technically Int4 CPU layout should be the layout for the underlying packed weight + (int Tensor) but since the scale and zero_point are also packed into the same tensor here which is not used + in plain layout, we just created a layout for AQT right now, this could be improved if we split out + int4 aqt into a separate tensor subclass + fields: + packed_weight (torch.Tensor): the 2-d packed tensor in a Int4 CPU layout + scale_and_zero (torch.Tensor): the combined scale Tensor used to map between floating point tensor to quantized tensor and zero_point Tensor + """ + + def __new__( + cls, + packed_weight: torch.Tensor, + scale_and_zero: torch.Tensor, + transposed: bool, + _layout: Layout, + ): + kwargs = {} + kwargs["device"] = packed_weight.device + kwargs["layout"] = ( + kwargs.get("layout") + if kwargs.get("layout", False) + else packed_weight.layout + ) + kwargs["dtype"] = packed_weight.dtype + kwargs["requires_grad"] = False + shape = packed_weight.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + packed_weight: torch.Tensor, + scale_and_zero: torch.Tensor, + transposed: bool, + _layout: Layout, + ): + self.packed_weight = packed_weight + self.scale_and_zero = scale_and_zero + self.transposed = False + self._layout = _layout + + def __tensor_flatten__(self): + return ["packed_weight", "scale_and_zero"], [self.transposed, self._layout] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + packed_weight, scale_and_zero = ( + tensor_data_dict["packed_weight"], + tensor_data_dict["scale_and_zero"], + ) + ( + transposed, + _layout, + ) = tensor_attributes + return cls(packed_weight, scale_and_zero, transposed, _layout) + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + _layout: Layout, + ): + assert isinstance(_layout, Int4CPULayout) + + if TORCH_VERSION_AT_LEAST_2_6: + assert ( + int_data.dtype == torch.int32 + ), "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype" + packed_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu( + int_data, + 1, # TODO:remove + ) + elif TORCH_VERSION_AT_LEAST_2_5: + int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) + assert ( + int_data.dtype == torch.uint8 + ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype" + packed_weight = torch.ops.aten._convert_weight_to_int4pack( + int_data, _layout.inner_k_tiles + ) + else: + assert ( + int_data.dtype == torch.int32 + ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" + packed_weight = torch.ops.aten._convert_weight_to_int4pack( + int_data, _layout.inner_k_tiles + ) + + scale = scale.reshape(int_data.shape[0], -1) + zero_point = zero_point.reshape(int_data.shape[0], -1) + from torchao.quantization.utils import pack_tinygemm_scales_and_zeros + + scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point) + return cls(packed_weight, scale_and_zero, False, _layout) + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + device = kwargs["device"] + if not is_device(torch.device(self.device).type, device): + raise ValueError( + f"Int4CPUAQTTensorImpl does not support conversion from {self.device} to {device}" + ) + return self.__class__( + self.packed_weight.to(device), + self.scale_and_zero.to(device), + self.transposed, + self._layout, + ) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.packed_weight), + fn(self.scale_and_zero), + self.transposed, + self._layout, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + if func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + if func is aten.t.default: + """we don't need to repack the weight and just rely on external + shape being changed and record the status of transpose/no-transpose + """ + transposed = Int4CPUAQTTensorImpl( + args[0].packed_weight, + args[0].scale_and_zero, + not args[0].transposed, + args[0]._layout, + ) + return return_and_correct_aliasing(func, args, kwargs, transposed) + + if func is aten.slice.Tensor: + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + if dim == 0: + int_data, scale, zero_point = self.get_plain() + int_data = aten.slice.Tensor(int_data, dim, start, end, step) + # this is to handle padding + int_data = self._layout.post_process(int_data) + sliced = self.from_plain(int_data, scale, zero_point, self._layout) + return return_and_correct_aliasing(func, args, kwargs, sliced) + elif dim == 1: + int_data, scale, zero_point = self.get_plain() + assert step == 1, "Only step == 1 is supported in slicing right now" + data_len = int_data.shape[dim] + scale_len = scale.shape[dim] + ratio = data_len / scale_len + start_scale = int(start / ratio) + end_scale = int(end / ratio) + + int_data = aten.slice.Tensor(int_data, dim, start, end, step) + # this is to handle padding + int_data = self._layout.post_process(int_data) + scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step) + zero_point = aten.slice.Tensor( + zero_point, dim, start_scale, end_scale, step + ) + sliced = self.from_plain(int_data, scale, zero_point, self._layout) + return sliced + else: + raise NotImplementedError( + f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" + ) + + raise NotImplementedError( + f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + __torch_function__ = torch._C._disabled_torch_function_impl + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + from torchao.quantization.quant_primitives import ( + ZeroPointDomain, + quantize_affine, + ) + from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros + + scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) + + cur_shape = self.shape + assert len(cur_shape) == 2 + original_shape = (cur_shape[0], cur_shape[1] * 2) + eye_shape = original_shape[1] + groupsize = int(original_shape[1] / scale.shape[-2]) + block_size = (1, groupsize) + device = self.device + original_dtype = torch.bfloat16 + target_dtype = torch.int32 + quant_min = 0 + quant_max = 15 + zero_point_domain = ZeroPointDomain.FLOAT + assert len(block_size) == 2 and block_size[0] == 1 + dequantized = torch.ops.aten._weight_int4pack_mm_for_cpu( + torch.eye(eye_shape, device=device, dtype=original_dtype), + self.packed_weight, + groupsize, + self.scale_and_zero, + ) + dequantized = dequantized.t().contiguous() + # TODO: move this to `unpack_tinygemm_scales_and_zeros`? + scale = scale.reshape(scale.shape[:-1]).contiguous() + zero = zero.reshape(zero.shape[:-1]).contiguous() + int_data = quantize_affine( + dequantized, + block_size, + scale, + zero, + target_dtype, + quant_min, + quant_max, + zero_point_domain, + ) + return int_data, scale, zero + + def get_layout(self) -> Layout: + return self._layout diff --git a/torchao/prototype/hqq/README.md b/torchao/prototype/hqq/README.md index 8bf1d34260..1bdbcd96e1 100644 --- a/torchao/prototype/hqq/README.md +++ b/torchao/prototype/hqq/README.md @@ -83,7 +83,7 @@ Initial benchmarking (on `A6000`) demonstrates promising results, scaling well f - Times are in `ms`, see `benchmarks/benchmark_hqq.py`. - `hqq_ref` is the base `HQQ_Linear` [module](https://github.com/mobiusml/hqq/blob/6d50eee4bcdd99cc10716f1297c5b2803d2b6da4/hqq/core/quantize.py#L349) that is unfused (dequantization followed by call to torch.matmul). -- `tinygemm` calls `torch.ops.aten._weight_int4pack_mm`. Implementation is a custom HQQLinear layer that wraps the preprocessing necessary for this kernel, adapted from a benchmark script posted by @mobicham from `CUDA-mode` Discord discussions. +- `tinygemm` calls `torch.ops.aten._weight_int4pack_mm` or `torch.ops.aten._weight_int4pack_mm_for_cpu`. Implementation is a custom HQQLinear layer that wraps the preprocessing necessary for this kernel, adapted from a benchmark script posted by @mobicham from `CUDA-mode` Discord discussions. GPU details: diff --git a/torchao/prototype/hqq/hqq_tinygemm_linear.py b/torchao/prototype/hqq/hqq_tinygemm_linear.py index 8abdad039a..743c6128a7 100644 --- a/torchao/prototype/hqq/hqq_tinygemm_linear.py +++ b/torchao/prototype/hqq/hqq_tinygemm_linear.py @@ -12,7 +12,8 @@ from hqq.core.utils import * import torch.nn.functional as F -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6 +from torchao.dtypes.utils import is_device class HQQLinearTorchWeightOnlyInt4(torch.nn.Module): @@ -162,9 +163,14 @@ def process_hqq_quants(self, W_q, meta): W_q_torch, scales_torch, zeros_torch = self.hqq_quants_to_torch_quants( W_q=W_q, scales=scales, zeros=zeros, shape=shape, nbits=self.nbits ) - self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( - W_q_torch, self.inner_k_tiles - ) + if is_device(W_q.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu( + W_q_torch, self.inner_k_tiles + ) + else: + self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( + W_q_torch, self.inner_k_tiles + ) self.scales_and_zeros = self.pack_scales_and_zeros(scales_torch, zeros_torch) del W_q_torch, scales_torch, zeros_torch @@ -200,7 +206,8 @@ def hqq_quants_to_torch_quants( .contiguous() ) if TORCH_VERSION_AT_LEAST_2_5: - W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8) + if not is_device(W_q.device.type, "cpu"): + W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8) # group_dequantize_tensor_from_qparams # W_r = W_q*scales + min_val @@ -232,9 +239,14 @@ def pack_scales_and_zeros(self, scales, zeros): def matmul(self, x): origin_x_size = x.size() x = x.reshape(-1, origin_x_size[-1]) - c = torch.ops.aten._weight_int4pack_mm( - x, self.weight_int4pack, self.groupsize, self.scales_and_zeros - ) + if is_device(x.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + c = torch.ops.aten._weight_int4pack_mm_for_cpu( + x, self.weight_int4pack, self.groupsize, self.scales_and_zeros + ) + else: + c = torch.ops.aten._weight_int4pack_mm( + x, self.weight_int4pack, self.groupsize, self.scales_and_zeros + ) new_shape = origin_x_size[:-1] + (self.out_features,) c = c.reshape(new_shape) return c diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index dc68f59ceb..c169271e8f 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -17,8 +17,10 @@ import torch.nn.functional as F from torch.utils._pytree import tree_flatten, tree_unflatten +from torchao.dtypes.utils import is_device from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, + TORCH_VERSION_AT_LEAST_2_6, find_multiple, ) @@ -537,12 +539,20 @@ def linear_forward_int4( ): origin_x_size = x.size() x = x.reshape(-1, origin_x_size[-1]) - c = torch.ops.aten._weight_int4pack_mm( - x.to(precision), - weight_int4pack, - groupsize, - scales_and_zeros.to(scales_precision), - ).to(dtype=x.dtype) + if is_device(x.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + c = torch.ops.aten._weight_int4pack_mm_for_cpu( + x.to(precision), + weight_int4pack, + groupsize, + scales_and_zeros.to(scales_precision), + ).to(dtype=x.dtype) + else: + c = torch.ops.aten._weight_int4pack_mm( + x.to(precision), + weight_int4pack, + groupsize, + scales_and_zeros.to(scales_precision), + ).to(dtype=x.dtype) new_shape = origin_x_size[:-1] + (out_features,) c = c.reshape(new_shape) return c @@ -591,19 +601,32 @@ def __init__( assert ( in_features % (inner_k_tiles * 16) == 0 ), "require in_features % (innerKTiles * 16) == 0" - self.register_buffer( - "weight", - torch.zeros( - ( - out_features // 8, - in_features // (inner_k_tiles * 16), - 32, - inner_k_tiles // 2, + if is_device(device.type, "cpu"): + self.register_buffer( + "weight", + torch.zeros( + ( + out_features, + in_features // 2, + ), + dtype=torch.uint8, + device=device, ), - dtype=torch.int32, - device=device, - ), - ) + ) + else: + self.register_buffer( + "weight", + torch.zeros( + ( + out_features // 8, + in_features // (inner_k_tiles * 16), + 32, + inner_k_tiles // 2, + ), + dtype=torch.int32, + device=device, + ), + ) self.dtype = dtype self.register_buffer( "scales_and_zeros", @@ -760,9 +783,19 @@ def _create_quantized_state_dict( self.precision, # dtype for scales_and_zeros ) # TODO: just get the device from mod.weight.device? - weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( - w_int4x8.to(self.device), self.inner_k_tiles - ) + if ( + is_device(w_int4x8.device.type, "cpu") + and TORCH_VERSION_AT_LEAST_2_6 + ): + weight_int4pack = ( + torch.ops.aten._convert_weight_to_int4pack_for_cpu( + w_int4x8.to(self.device), self.inner_k_tiles + ) + ) + else: + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( + w_int4x8.to(self.device), self.inner_k_tiles + ) cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to(self.device) cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to( self.device @@ -846,9 +879,14 @@ def make_names_and_values_dict_func(q, qparams): # how much we need to pad the weight delta_k = int((new_k - k) / 2) q = q.to(self.device) - final_q = torch.ops.aten._convert_weight_to_int4pack( - F.pad(q, pad=(0, delta_k)), inner_k_tiles - ) + if is_device(self.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + final_q = torch.ops.aten._convert_weight_to_int4pack_for_cpu( + F.pad(q, pad=(0, delta_k)), inner_k_tiles + ) + else: + final_q = torch.ops.aten._convert_weight_to_int4pack( + F.pad(q, pad=(0, delta_k)), inner_k_tiles + ) scales = qparams[0].to(torch.bfloat16).to(self.device) zeros = qparams[1].to(torch.bfloat16).to(self.device) scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py index cbe6296407..d5f2dca5b4 100644 --- a/torchao/quantization/qat/linear.py +++ b/torchao/quantization/qat/linear.py @@ -9,6 +9,7 @@ import torch import torch.nn.functional as F +from torchao.dtypes.utils import is_device from torchao.quantization.GPTQ import ( Int8DynActInt4WeightLinear, WeightOnlyInt4Linear, @@ -23,6 +24,7 @@ ) from torchao.quantization.unified import TwoStepQuantizer from torchao.quantization.utils import get_group_qparams_symmetric +from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 from .api import FakeQuantizeConfig from .fake_quantizer import FakeQuantizer @@ -363,6 +365,7 @@ def _convert_qat_linear_4w(self, module: torch.nn.Module): inner_k_tiles=inner_k_tiles, precision=child.weight.dtype, scales_precision=config.scale_precision, + device=next(child.parameters()).device, ) setattr(module, name, quantized_linear) @@ -373,10 +376,19 @@ def _convert_qat_linear_4w(self, module: torch.nn.Module): n_bit, config.group_size, ) - q_weight = torch.ops.aten._convert_weight_to_int4pack( - q_weight.to(child.weight.device), - child.inner_k_tiles, - ) + if ( + is_device(q_weight.device.type, "cpu") + and TORCH_VERSION_AT_LEAST_2_6 + ): + q_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu( + q_weight.to(child.weight.device), + child.inner_k_tiles, + ) + else: + q_weight = torch.ops.aten._convert_weight_to_int4pack( + q_weight.to(child.weight.device), + child.inner_k_tiles, + ) quantized_linear.weight = q_weight quantized_linear.scales_and_zeros = scales_and_zeros else: diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index c730ec9046..ddeb4ef2fb 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -630,7 +630,8 @@ def int4_weight_only( "tensor_core_tiled" layout for speedup with tinygemm kernel Note: - This is targeting `tinygemm` int4mm kernel (`torch.ops.aten._weight_int4pack_mm`), the main difference + This is targeting `tinygemm` int4mm kernel (`torch.ops.aten._weight_int4pack_mm` + and `torch.ops.aten._weight_int4pack_mm_for_cpu`), the main difference of quantization algorithm compared to the more traditional type of integer quantization is the following: 1). zero_point is in floating point domain instead of integer domain (`zero_point_domain`=`ZeroPointDomain.FLOAT`) 2). floating point zero does not have to be exactly representable (`preserve_zero`=False in `choose_qparams_affine`) diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index 036109bc8d..9715d99e08 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -8,6 +8,7 @@ import torch from torch.utils._python_dispatch import return_and_correct_aliasing +from torchao.dtypes.utils import is_device from torchao.quantization.utils import ( dequantize_per_channel, dynamically_quantize_per_channel, @@ -15,7 +16,7 @@ quant_int8_dynamic_per_token_linear, unpack_tinygemm_scales_and_zeros, ) -from torchao.utils import find_multiple +from torchao.utils import TORCH_VERSION_AT_LEAST_2_6, find_multiple __all__ = [ "Int8DynamicallyQuantizedLinearWeight", @@ -458,12 +459,20 @@ def _quantized_op(act_mat, w_qtensor, bias): act_mat = torch.nn.functional.pad(act_mat, (0, pad_size - act_mat.shape[-1])) # matmul - y = aten._weight_int4pack_mm( - act_mat.contiguous(), - w_qtensor.int_data, - w_qtensor.groupsize, - w_qtensor.scales_and_zeros, - ) + if is_device(act_mat.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + y = aten._weight_int4pack_mm_for_cpu( + act_mat.contiguous(), + w_qtensor.int_data, + w_qtensor.groupsize, + w_qtensor.scales_and_zeros, + ) + else: + y = aten._weight_int4pack_mm( + act_mat.contiguous(), + w_qtensor.int_data, + w_qtensor.groupsize, + w_qtensor.scales_and_zeros, + ) # remove out_feature padding orig_out_features = ( @@ -609,5 +618,10 @@ def to_qtensor_components(cls, input_float, groupsize=128, inner_k_tiles=8): input_int4x8, scales_and_zeros = groupwise_affine_quantize_tensor( input_float, 4, groupsize, dtype=input_float.dtype ) - int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles) + if is_device(input_float.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + int_data = aten._convert_weight_to_int4pack_for_cpu( + input_int4x8, inner_k_tiles + ) + else: + int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles) return int_data, scales_and_zeros, False, groupsize, inner_k_tiles diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 9083dd7621..e1cf98b549 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -9,6 +9,7 @@ import torch from torch.utils._python_dispatch import TorchDispatchMode +from torchao.dtypes.utils import is_device from torchao.kernel import ( int_scaled_matmul, ) @@ -19,7 +20,7 @@ dequantize_affine, quantize_affine, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6 __all__ = [ "compute_error", @@ -402,13 +403,8 @@ def groupwise_affine_quantize_tensor_from_qparams( zero_point_domain=ZeroPointDomain.FLOAT, ) if TORCH_VERSION_AT_LEAST_2_5 and w.shape[-1] > 1: - int_data_device_type = int_data.device.type - # Move to cpu, until issue with MPS memory management of temporary tensors is resolved - if int_data_device_type == "mps": - int_data = int_data.cpu() - int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) - if int_data_device_type == "mps": - int_data = int_data.to(device="mps") + if not (is_device(int_data.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6): + int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) return int_data @@ -422,8 +418,10 @@ def groupwise_affine_dequantize_tensor_from_qparams( assert groupsize > 1 assert w_int4x8.dim() == 2 # need to handle single column case so check for dtype/size from groupwise_affine_quantize_tensor_from_qparams path - if TORCH_VERSION_AT_LEAST_2_5 and ( - w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1 + if ( + TORCH_VERSION_AT_LEAST_2_5 + and (w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1) + and not (is_device(w_int4x8.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6) ): data = w_int4x8.to(torch.int32) high_bits = data >> 4 From ed76e9caec342ad20133b3ddd84d6940af1adc64 Mon Sep 17 00:00:00 2001 From: cpuhrsch Date: Wed, 27 Nov 2024 00:51:30 -0800 Subject: [PATCH 04/40] Reduce startup time for SAM2 AMG by using torch.export (#1358) --- examples/sam2_amg_server/cli.py | 16 +- examples/sam2_amg_server/server.py | 196 ++++++++++++++++-- .../_models/sam2/automatic_mask_generator.py | 4 +- torchao/_models/sam2/sam2_image_predictor.py | 2 +- 4 files changed, 197 insertions(+), 21 deletions(-) diff --git a/examples/sam2_amg_server/cli.py b/examples/sam2_amg_server/cli.py index 2fead4b5a4..9cf5bdc8f3 100644 --- a/examples/sam2_amg_server/cli.py +++ b/examples/sam2_amg_server/cli.py @@ -6,6 +6,8 @@ from server import model_type_to_paths from server import MODEL_TYPES_TO_MODEL from server import set_fast +from server import set_aot_fast +from server import load_aot_fast from server import set_furious from torchao._models.sam2.build_sam import build_sam2 from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator @@ -22,17 +24,20 @@ def main_docstring(): """ -def main_headless(checkpoint_path, model_type, input_bytes, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False): +def main_headless(checkpoint_path, model_type, input_bytes, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False, load_fast=""): device = "cuda" sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type) if verbose: print(f"Loading model {sam2_checkpoint} with config {model_cfg}") sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False) mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle") - if fast: - set_fast(mask_generator) if furious: set_furious(mask_generator) + if load_fast: + load_aot_fast(mask_generator, load_fast) + if fast: + set_fast(mask_generator, load_fast) + image_tensor = file_bytes_to_image_tensor(input_bytes) if verbose: print(f"Loaded image of size {tuple(image_tensor.shape)} and generating mask.") @@ -50,7 +55,7 @@ def main_headless(checkpoint_path, model_type, input_bytes, points_per_batch=102 buf.seek(0) return buf.getvalue() -def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False): +def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False, load_fast=""): input_bytes = bytearray(open(input_path, 'rb').read()) output_bytes = main_headless(checkpoint_path, model_type, @@ -59,7 +64,8 @@ def main(checkpoint_path, model_type, input_path, output_path, points_per_batch= output_format=output_format, verbose=verbose, fast=fast, - furious=furious) + furious=furious, + load_fast=load_fast) with open(output_path, "wb") as file: file.write(output_bytes) diff --git a/examples/sam2_amg_server/server.py b/examples/sam2_amg_server/server.py index 066b339c21..ba1aed7a00 100644 --- a/examples/sam2_amg_server/server.py +++ b/examples/sam2_amg_server/server.py @@ -332,14 +332,175 @@ def model_type_to_paths(checkpoint_path, model_type): model_cfg = f"configs/sam2.1/{MODEL_TYPES_TO_CONFIG[model_type]}" return sam2_checkpoint, model_cfg -def set_fast(mask_generator): - # TODO: Using CUDA graphs can cause numerical differences? - mask_generator.predictor.model.image_encoder = torch.compile( - mask_generator.predictor.model.image_encoder, - mode="max-autotune", - fullgraph=True, - dynamic=False, + +def aot_compile(model_directory, name, fn, sample_args): + path = Path(model_directory) / Path(f"{name}.pt2") + print(f"Saving at {path=}") + options = { + "max_autotune": True, + "triton.cudagraphs": True, + } + + exported = torch.export.export_for_inference(fn, sample_args) + output_path = torch._inductor.aoti_compile_and_package( + exported, + package_path=str(path), + inductor_configs=options, ) + return output_path + + +def aot_load(path): + return torch._export.aot_load(path, "cuda") + +class FunctionModel(torch.nn.Module): + + def __init__(self, module, fn_name): + super().__init__() + self.module = module + self.fn_name = fn_name + + def forward(self, *args): + return getattr(self.module, self.fn_name)(*args) + + +def set_aot_fast(mask_generator, model_directory): + example_input = torch.empty(1, 3, 1024, 1024) + example_input = example_input.to(mask_generator.predictor._image_dtype) + example_input = (example_input.to(mask_generator.predictor.device),) + aot_compile(model_directory, + "sam2_image_encoder", + mask_generator.predictor.model.image_encoder, + example_input) + + # NOTE: THIS DOESN'T WORK YET! + # example_input_0_0 = torch.empty(1, 32, 256, 256, dtype=torch.float16, device=mask_generator.predictor.device) + # example_input_0_1 = torch.empty(1, 64, 128, 128, dtype=torch.float16, device=mask_generator.predictor.device) + # example_input_1 = torch.empty(1, 256, 64, 64, dtype=torch.float32, device=mask_generator.predictor.device) + # example_input_2 = torch.empty(1024, 1, 2, dtype=torch.float32, device=mask_generator.predictor.device) + # example_input_3 = torch.empty(1024, 1, dtype=torch.int32, device=mask_generator.predictor.device) + # example_input = ([example_input_0_0, example_input_0_1], + # example_input_1, + # example_input_2, + # example_input_3, + # None, + # None, + # True, + # True, + # -1) + # mask_generator.forward = mask_generator.predictor._predict_masks_with_features + # mask_generator(*example_input) + # aot_compile("sam2__predict_masks_with_features", + # mask_generator, + # example_input) + + # example_input_2 = torch.empty(1024, 1, 2, dtype=torch.float32, device=mask_generator.predictor.device) + # example_input_3 = torch.empty(1024, 1, dtype=torch.int32, device=mask_generator.predictor.device) + # aot_compile("sam2_sam_prompt_encoder", + # mask_generator.predictor.model.sam_prompt_encoder, + # ((example_input_2, example_input_3), + # None, + # None)) + + # NOTE: THIS DOESN'T WORK YET! + # example_input_0 = torch.empty(1, 256, 64, 64, dtype=torch.float32, device=mask_generator.predictor.device) + # example_input_1 = torch.empty(1, 256, 64, 64, dtype=torch.float32, device=mask_generator.predictor.device) + # example_input_2 = torch.empty(1024, 2, 256, dtype=torch.float32, device=mask_generator.predictor.device) + # example_input_3 = torch.empty(1024, 256, 64, 64, dtype=torch.float32, device=mask_generator.predictor.device) + + # example_input_4_0 = torch.empty(1, 32, 256, 256, dtype=torch.float16, device=mask_generator.predictor.device) + # example_input_4_1 = torch.empty(1, 64, 128, 128, dtype=torch.float16, device=mask_generator.predictor.device) + + # example_input = (example_input_0, + # example_input_1, + # example_input_2, + # example_input_3, + # True, + # True, + # [example_input_4_0, example_input_4_1]) + # print("Example") + # mask_generator.predictor.model.sam_mask_decoder(*example_input) + # print("Example done") + # aot_compile("sam2_sam_mask_decoder", + # mask_generator.predictor.model.sam_mask_decoder, + # example_input) + + # example_input_0 = torch.empty(1024, 256, 64, 64, dtype=torch.float16, device=mask_generator.predictor.device) + # example_input_1 = torch.empty(1024, 256, 64, 64, dtype=torch.float16, device=mask_generator.predictor.device) + # example_input_2 = torch.empty(1024, 8, 256, dtype=torch.float16, device=mask_generator.predictor.device) + # example_input = (example_input_0, example_input_1, example_input_2) + + # mask_generator.predictor.model.sam_mask_decoder.transformer(*example_input) + # aot_compile("sam2_sam_mask_decoder_transformer", + # mask_generator.predictor.model.sam_mask_decoder.transformer, + # example_input) + + + + +class LoadedModel(torch.nn.Module): + + def __init__(self, aoti_compiled_model): + super().__init__() + self.aoti_compiled_model = aoti_compiled_model + + def forward(self, *args): + return self.aoti_compiled_model(*args) + +class LoadedDecoder(torch.nn.Module): + + def __init__(self, aoti_compiled_model, other): + super().__init__() + self.aoti_compiled_model = aoti_compiled_model + self.other = other + + def forward(self, *args): + return self.aoti_compiled_model(*args) + + def get_dense_pe(self, *args, **kwargs) -> torch.Tensor: + return self.other.get_dense_pe(*args, **kwargs) + +def load_aot_fast(mask_generator, model_directory): + t0 = time.time() + path = Path(model_directory) / Path(f"sam2_image_encoder.pt2") + assert path.exists(), f"Expected {path} to exist." + print(f"Start load from {path}") + pkg = torch._inductor.aoti_load_package(str(path)) + pkg_m = LoadedModel(pkg) + mask_generator.predictor.model.image_encoder = pkg_m + + # NOTE: This doesn't work yet! + # pkg = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "sam2__predict_masks_with_features.pt2")) + # pkg_m = LoadedModel(pkg) + # mask_generator.predictor._predict_masks_with_features = pkg_m.forward + + # pkg = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "sam2_sam_prompt_encoder.pt2")) + # pkg_m = LoadedDecoder(pkg, mask_generator.predictor.model.sam_prompt_encoder) + # mask_generator.predictor.model.sam_prompt_encoder = pkg_m + + # NOTE: This doesn't work yet! + # pkg = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "sam2_sam_mask_decoder.pt2")) + # pkg_m = LoadedModel(pkg) + # pkg_m.conv_s0 = mask_generator.predictor.model.sam_mask_decoder.conv_s0 + # pkg_m.conv_s1 = mask_generator.predictor.model.sam_mask_decoder.conv_s1 + # mask_generator.predictor.model.sam_mask_decoder = pkg_m + + # pkg = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "sam2_sam_mask_decoder_transformer.pt2")) + # pkg_m = LoadedModel(pkg) + # mask_generator.predictor.model.sam_mask_decoder.transformer = pkg_m + + print(f"End load. Took {time.time() - t0}s") + + +def set_fast(mask_generator, load_fast=""): + if load_fast == "": + # TODO: Using CUDA graphs can cause numerical differences? + mask_generator.predictor.model.image_encoder = torch.compile( + mask_generator.predictor.model.image_encoder, + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) mask_generator.predictor._predict_masks = torch.compile( mask_generator.predictor._predict_masks, @@ -381,7 +542,9 @@ def main(checkpoint_path, port=5000, host="127.0.0.1", dry=False, - batch_size=1): + batch_size=1, + load_fast="", + save_fast=""): if verbose: logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', @@ -410,25 +573,32 @@ def main(checkpoint_path, logging.info(f"Using {points_per_batch} points_per_batch") mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle") + if load_fast != "": + load_aot_fast(mask_generator, load_fast) + + if save_fast != "": + assert load_fast == "", "Can't save compiled models while loading them with --load-fast." + assert not baseline, "--fast cannot be combined with baseline. code to be torch.compile(fullgraph=True) compatible." + print(f"Saving compiled models under directory {save_fast}") + set_aot_fast(mask_generator, save_fast) + if fast: assert not baseline, "--fast cannot be combined with baseline. code to be torch.compile(fullgraph=True) compatible." - set_fast(mask_generator) + set_fast(mask_generator, load_fast) if furious: set_furious(mask_generator) - # since autoquant is replicating what furious mode is doing, don't use these two together elif use_autoquant: from torchao import autoquant from torchao.quantization import DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST - mask_generator.predictor.model = autoquant(mask_generator.predictor.model, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40) + mask_generator.predictor.model.image_encoder = autoquant(mask_generator.predictor.model.image_encoder, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40) - mask_generator.predictor.model.image_encoder = mask_generator.predictor.model.image_encoder.to(torch.float16, min_sqnr=40) + # mask_generator.predictor.model.image_encoder = mask_generator.predictor.model.image_encoder.to(torch.float16, min_sqnr=40) # NOTE: Not baseline feature mask_generator.predictor._transforms_device = mask_generator.predictor.device torch.set_float32_matmul_precision('high') - with open('dog.jpg', 'rb') as f: image_tensor = file_bytes_to_image_tensor(bytearray(f.read())) diff --git a/torchao/_models/sam2/automatic_mask_generator.py b/torchao/_models/sam2/automatic_mask_generator.py index db544a9b61..891a2602ba 100644 --- a/torchao/_models/sam2/automatic_mask_generator.py +++ b/torchao/_models/sam2/automatic_mask_generator.py @@ -36,7 +36,7 @@ ) -class SAM2AutomaticMaskGenerator: +class SAM2AutomaticMaskGenerator(torch.nn.Module): def __init__( self, model: SAM2Base, @@ -105,7 +105,7 @@ def __init__( use_m2m (bool): Whether to add a one step refinement using previous mask predictions. multimask_output (bool): Whether to output multimask at each point of the grid. """ - + super().__init__() assert (points_per_side is None) != ( point_grids is None ), "Exactly one of points_per_side or point_grid must be provided." diff --git a/torchao/_models/sam2/sam2_image_predictor.py b/torchao/_models/sam2/sam2_image_predictor.py index 8fe01995ee..f404fe00e4 100644 --- a/torchao/_models/sam2/sam2_image_predictor.py +++ b/torchao/_models/sam2/sam2_image_predictor.py @@ -17,7 +17,7 @@ from torchao._models.sam2.utils.transforms import SAM2Transforms -class SAM2ImagePredictor: +class SAM2ImagePredictor(torch.nn.Module): def __init__( self, sam_model: SAM2Base, From c45d975c3b053e3a8e246ce4f3be4d06f9d73074 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Thu, 28 Nov 2024 11:59:25 -0800 Subject: [PATCH 05/40] Add support for quantize_() with Float8Linear module (#1344) --- test/float8/test_base.py | 19 +++++++++++++++++-- torchao/quantization/quant_api.py | 7 +++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index d00b96d3bb..245abe0d02 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -14,7 +14,7 @@ import torch import torch.nn as nn -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89 if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -531,6 +531,21 @@ def test_inference_mode(self): with torch.inference_mode(mode=True): m(x) + @unittest.skipIf(not is_sm_89(), "CUDA arch 8.9 not available") + def test_quantize(self): + x = torch.randn(32, 32, device="cuda") + m = nn.Sequential(nn.Linear(32, 32)).cuda() + m = convert_to_float8_training(m) + assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear" + from torchao.quantization.quant_api import float8_weight_only, quantize_ + + quantize_(m, float8_weight_only()) + assert ( + m[0].weight.tensor_impl.float8_data.dtype == torch.float8_e4m3fn + ), "Post quantization dtype should be torch.float8_e4m3fn" + with torch.no_grad(): + m(x) + class TestScaledMM: @unittest.skipIf( @@ -576,7 +591,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum): if base_dtype in {torch.bfloat16, torch.float16}: atol, rtol = 7e-2, 7e-2 else: - atol, rtol = 2e-3, 2e-3 + atol, rtol = 3e-3, 3e-3 torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) @unittest.skipIf(not is_cuda_8_9, "CUDA not available") diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index ddeb4ef2fb..60a7341e39 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -39,6 +39,7 @@ to_affine_quantized_intx, to_marlinqqq_quantized_intx, ) +from torchao.float8.float8_linear import Float8Linear from torchao.float8.inference import Float8MMConfig from torchao.quantization.linear_activation_weight_observed_tensor import ( LinearActivationWeightObservedTensor, @@ -222,6 +223,12 @@ def _replace_with_custom_fn_if_matches_filter( Returns: None """ + if isinstance(model, Float8Linear): + with torch.device("meta"): + new_module = nn.Linear(model.in_features, model.out_features) + new_module.weight = model.weight + new_module.bias = model.bias + model = new_module if filter_fn(model, cur_fqn[:-1]): if device is not None: model.to(device=device) # move to device before quantization From e06fa8d250e8b21e1c9daca03f02030adb90d9a8 Mon Sep 17 00:00:00 2001 From: "Swift.Sun" Date: Fri, 29 Nov 2024 12:56:52 +0800 Subject: [PATCH 06/40] Benchmark intel xpu (#1259) * support xpu * fix intel gpu peak mem * update benchmark for llama7b xpu * gupdate llama8b3.1 for intel GPU * update readme --- torchao/_models/llama/generate.py | 46 +++++++++++++++++++------------ torchao/quantization/README.md | 20 ++++++++++++-- 2 files changed, 47 insertions(+), 19 deletions(-) diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 550c10febb..d617ceb304 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -20,12 +20,14 @@ def device_sync(device): if "cuda" in device: torch.cuda.synchronize(device) + elif "xpu" in device: + torch.xpu.synchronize(device) elif ("cpu" in device) or ("mps" in device): pass else: print(f"device={device} is not yet suppported") -default_device = 'cuda' if torch.cuda.is_available() else 'cpu' +default_device = 'cuda' if torch.cuda.is_available() else 'xpu' if torch.xpu.is_available() else 'cpu' # support running without installing as a package wd = Path(__file__).parent.parent.resolve() @@ -440,10 +442,13 @@ def main( prefill = torch.compile(prefill, fullgraph=True, dynamic=True) if memory_profile: - if device != "cuda": - print("Memory profiling only works on CUDA") - else: + if device == "cuda": torch.cuda.memory._record_memory_history(True,trace_alloc_max_entries=250000, trace_alloc_record_context=True) + elif device == "xpu": + torch.xpu.memory._record_memory_history(True,trace_alloc_max_entries=250000, trace_alloc_record_context=True) + else: + print("Memory profiling only works on CUDA or XPU devices") + aggregate_metrics = { 'tokens_per_sec': [], } @@ -453,6 +458,8 @@ def main( if i==0: if device == "cuda": torch.cuda.reset_peak_memory_stats() # MKG + elif device == "xpu": + torch.xpu.reset_peak_memory_stats() # MKG device_sync(device=device) # MKG if i >= 0 and interactive: prompt = input("What is your prompt? ") @@ -520,24 +527,29 @@ def callback(x): print(f"Bandwidth achieved: {model_size * tokens_sec:.02f} GB/s") if memory_profile and i==0: - if device != "cuda": - print("Memory profiling only works on CUDA") - else: + if device == "cuda": snapshot = torch.cuda.memory._snapshot() - with open(f"{memory_profile}.pickle", 'wb') as f: - from pickle import dump - dump(snapshot, f) - print( - f"\nmemory profile {memory_profile}.pickle saved, to convert that to a usable file, use", - "python pytorch/torch/cuda/_memory_viz.py trace_plot -o .html" - ) - break - + elif device == "xpu": + snapshot = torch.xpu.memory._snapshot() + else: + print("Memory profiling only works on CUDA or XPU devices") + + with open(f"{memory_profile}.pickle", 'wb') as f: + from pickle import dump + dump(snapshot, f) + print( + f"\nmemory profile {memory_profile}.pickle saved, to convert that to a usable file, use", + "python pytorch/torch/cuda/_memory_viz.py trace_plot -o .html" + ) + break print("==========") tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item() bandwidth = model_size * tokpersec - mem = torch.cuda.max_memory_reserved() /1e9 + if device == "cuda": + mem = torch.cuda.max_memory_reserved() /1e9 + elif device == "xpu": + mem = torch.xpu.max_memory_reserved() /1e9 print(f"Average tokens/sec: {tokpersec:.2f}") if batch_size > 1: print(f"Average tokens/sec including batches {batch_size*tokpersec:.2f}") diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 3c2eeb08f6..3fc2cb5ef0 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -3,7 +3,7 @@ Typically quantization algorithms will have different schemes for how the activa ## Benchmarks Benchmarks and evaluation are run on a machine with a single NVIDIA-A100-80GB GPU using the scripts for [generation](../_models/llama/generate.py) and [eval](../_models/llama/eval.py). Evaluation was done using the lm_eval library for tasks/data. The models used were meta-llama/Llama-2-7b-chat-hf and meta-llama/Meta-Llama-3-8B. - +### CUDA backend | Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) | | ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- | | Llama-2-7B | Base (bfloat16) | 12.212 | 107.38 | 1418.93 | 13.88 | 13.21 | @@ -20,9 +20,16 @@ Benchmarks and evaluation are run on a machine with a single NVIDIA-A100-80GB GP | | int4wo-64 | 8.316 | 180.80 | 763.33 | 6.88 | 4.22 | | | int4wo-64-GPTQ | 7.921 | 180.80 | 763.33 | 6.88 | 4.22 | | | autoquant-int4hqq | 8.110 | 188.41 | 800.58 | 7.14 | 4.25 | +### XPU backend +| Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) | +| ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- | +| Llama-2-7B | Base (bfloat16) | NA | 42.20 | 557.71 | 13.89 | 13.21 | +| | int8dq | NA | 9.87 | 65.35 | 14.60 | 6.62 | +| | int8wo | NA | 66.24 | 438.61 | 14.60 | 6.62 + -Benchmarks and evaluation for model meta-llama/Meta-Llama-3.1-8B are run on a machine with a single NVIDIA-H100 GPU using the scripts for [generation](../_models/llama/generate.py) and [eval](../_models/llama/eval.py). Evaluation was done using the lm_eval library for tasks/data. +### CUDA backend | Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) | | ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- | | Llama-3.1-8B | Base (bfloat16) | 7.54 | 126.90 | 1904.75 | 16.75 | 15.01 | @@ -31,6 +38,15 @@ Benchmarks and evaluation for model meta-llama/Meta-Llama-3.1-8B are run on a ma | | float8wo | 7.60 | 178.46 | 1339.93 | 12.09 | 7.51 | | | float8dq (PerTensor) | 7.62 | 116.40 | 873.58 | 11.14 | 7.51 | | | float8dq (Per Row) | 7.61 | 154.63 | 1161.47 | 11.14 | 7.51 | +### XPU backend +| Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) | +| ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- | +| Llama-3-8.1B | Base (bfloat16) | 7.441 | 40.36 | 605.77 | 16.35 | 15.01 | +| | int8dq | 7.581 | 13.60 | 102.28 | 18.69 | 7.52 | +| | int8wo | 7.447 | 59.49 | 447.27 | 18.60 | 7.52 + + +Benchmarks and evaluation for model meta-llama/Meta-Llama-3.1-8B are run on a machine with a single NVIDIA-H100 GPU or Intel-Max1100 using the scripts for [generation](../_models/llama/generate.py) and [eval](../_models/llama/eval.py). Evaluation was done using the lm_eval library for tasks/data. note: Int8 dynamic quantization works best on compute bound models like [SAM](https://github.com/pytorch-labs/segment-anything-fast) whereas Llama with batchsize=1 tends to be memory bound, thus the rather low performance. From aeb19443fc7c5351667fb1c819eaf7e412ed31d0 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 30 Nov 2024 01:27:42 -0800 Subject: [PATCH 07/40] Update README.md: Fix bibtex and sglang links (#1361) --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 1af5a7013c..6ba0e3be4c 100644 --- a/README.md +++ b/README.md @@ -178,7 +178,7 @@ We're also fortunate to be integrated into some of the leading open-source libra 3. Mobius HQQ backend leveraged our int4 kernels to get [195 tok/s on a 4090](https://github.com/mobiusml/hqq#faster-inference) 4. [TorchTune](https://github.com/pytorch/torchtune) for our QLoRA and QAT recipes 5. [torchchat](https://github.com/pytorch/torchchat) for post training quantization -6. [SGLang](https://github.com/sgl-project/sglang/pull/1341) for LLM inference quantization +6. SGLang for LLM serving: [usage](https://github.com/sgl-project/sglang/blob/4f2ee48ed1c66ee0e189daa4120581de324ee814/docs/backend/backend.md?plain=1#L83) and the major [PR](https://github.com/sgl-project/sglang/pull/1341). ## Videos * [Keynote talk at GPU MODE IRL](https://youtu.be/FH5wiwOyPX4?si=VZK22hHz25GRzBG1&t=1009) @@ -205,4 +205,5 @@ If you find the torchao library useful, please cite it in your work as below. license = {BSD-3-Clause}, month = oct, year = {2024} +} ``` From 22bec74ad1ba260707c700b0dec100ec77db512e Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Sat, 30 Nov 2024 21:47:47 -0800 Subject: [PATCH 08/40] Update hardware check conditions (#1356) --- test/dtypes/test_affine_quantized.py | 10 +++-- test/dtypes/test_affine_quantized_float.py | 37 +++++++++++++------ test/float8/test_base.py | 34 ++++++++++------- test/float8/test_compile.py | 32 ++++++++-------- test/float8/test_fsdp2/test_fsdp2.py | 5 +-- .../test_fsdp2/test_fsdp2_fp8_comm_only.py | 5 +-- test/float8/test_numerics_integration.py | 17 ++++++--- test/integration/test_integration.py | 10 ++--- test/kernel/test_autotuner.py | 4 +- test/prototype/mx_formats/test_mx_linear.py | 9 ++--- test/prototype/mx_formats/test_mx_tensor.py | 7 +--- torchao/quantization/quant_api.py | 12 +++--- torchao/utils.py | 8 ++-- 13 files changed, 106 insertions(+), 84 deletions(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 9e9144c601..43d57b7d12 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -17,9 +17,11 @@ int8_weight_only, ) from torchao.quantization.quant_primitives import MappingType -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6 - -is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + TORCH_VERSION_AT_LEAST_2_6, + is_sm_at_least_89, +) def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"): @@ -42,7 +44,7 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cu int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()) ) - if is_cuda_8_9: + if is_sm_at_least_89(): base_functions.append(float8_weight_only()) return base_functions diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 74c130dc5e..4d8312b427 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -37,13 +37,14 @@ MappingType, choose_qparams_affine, ) +from torchao.utils import ( + is_sm_at_least_89, + is_sm_at_least_90, +) random.seed(0) torch.manual_seed(0) -is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) -is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) - class ToyLinearModel(torch.nn.Module): def __init__(self, in_features, out_features): @@ -59,12 +60,14 @@ def forward(self, x): class TestAffineQuantizedFloat8Compile(InductorTestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"]) @common_utils.parametrize("compile", [True, False]) @common_utils.parametrize( - "granularity", [PerTensor(), PerRow()] if is_H100 else [PerTensor()] + "granularity", [PerTensor(), PerRow()] if is_sm_at_least_90() else [PerTensor()] ) # Inputs are (M,..), K, N @common_utils.parametrize( @@ -134,12 +137,16 @@ def test_fp8_linear_variants( compute_error(output_original, output_quantized) > 20 ), f"Quantization error is too high got a SQNR of {error}" - @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) def test_invalid_granularity(self): with pytest.raises(ValueError, match="Invalid granularity specification"): float8_dynamic_activation_float8_weight(granularity="invalid") - @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) def test_mismatched_granularity(self): with pytest.raises( ValueError, @@ -147,7 +154,9 @@ def test_mismatched_granularity(self): ): float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow())) - @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) def test_unsupported_granularity(self): class UnsupportedGranularity: pass @@ -158,7 +167,9 @@ class UnsupportedGranularity: ) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) def test_per_row_with_float32(self): with pytest.raises( AssertionError, @@ -170,7 +181,9 @@ def test_per_row_with_float32(self): ) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"]) def test_serialization(self, mode: str): # Create and quantize the model @@ -240,7 +253,9 @@ def test_serialization(self, mode: str): ), f"Scales do not match for {layer_name}" @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) def test_fp8_weight_dimension_warning(self): # Create model with incompatible dimensions (not multiples of 16) model = ToyLinearModel(10, 25).cuda() # 10x25 and 25x10 weights diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 245abe0d02..f61ff3738f 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -14,7 +14,11 @@ import torch import torch.nn as nn -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89 +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + is_sm_at_least_89, + is_sm_at_least_90, +) if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -60,10 +64,6 @@ torch.manual_seed(0) -is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) -is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) - - def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool: assert torch.all(a._scale == b._scale).item(), "scales are not identical" assert torch.all(a._data == b._data).item(), "data is not identical" @@ -219,7 +219,7 @@ def test_axiswise_reshape(self): ], ) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - @unittest.skipIf(not is_cuda_9_0, "Requires CUDA capability >= 9.0") + @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0") def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity): a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda") b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda") @@ -333,7 +333,9 @@ def _test_linear_impl( # verify initialization flags got updated assert m_fp8.is_amax_initialized, "Amax was not properly initialized" - @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True]) + @pytest.mark.parametrize( + "emulate", [True, False] if is_sm_at_least_89() else [True] + ) @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) @pytest.mark.parametrize( "scaling_type_input", @@ -415,7 +417,9 @@ def test_linear_from_recipe( config, ) - @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True]) + @pytest.mark.parametrize( + "emulate", [True, False] if is_sm_at_least_89() else [True] + ) @pytest.mark.parametrize( "linear_dtype", [torch.float16, torch.bfloat16, torch.float32] ) @@ -462,7 +466,9 @@ def test_autocast_outputs( @pytest.mark.parametrize( "linear_dtype", [torch.float16, torch.bfloat16, torch.float32] ) - @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True]) + @pytest.mark.parametrize( + "emulate", [True, False] if is_sm_at_least_89() else [True] + ) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool): m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype) @@ -523,7 +529,7 @@ def test_repr(self): s = m.__repr__() assert "i:dyn_ten,w:del_ten,go:dyn_ten" in s - @unittest.skipIf(not is_cuda_8_9, "CUDA 8.9 not available") + @unittest.skipIf(not is_sm_at_least_89(), "CUDA 8.9 not available") def test_inference_mode(self): x = torch.randn(32, 32, device="cuda") m = nn.Sequential(nn.Linear(32, 32)).cuda() @@ -531,7 +537,7 @@ def test_inference_mode(self): with torch.inference_mode(mode=True): m(x) - @unittest.skipIf(not is_sm_89(), "CUDA arch 8.9 not available") + @unittest.skipIf(not is_sm_at_least_89(), "CUDA arch 8.9 not available") def test_quantize(self): x = torch.randn(32, 32, device="cuda") m = nn.Sequential(nn.Linear(32, 32)).cuda() @@ -549,7 +555,7 @@ def test_quantize(self): class TestScaledMM: @unittest.skipIf( - not is_cuda_8_9, + not is_sm_at_least_89(), "CUDA not available", ) @pytest.mark.parametrize( @@ -594,7 +600,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum): atol, rtol = 3e-3, 3e-3 torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) - @unittest.skipIf(not is_cuda_8_9, "CUDA not available") + @unittest.skipIf(not is_sm_at_least_89(), "CUDA not available") def test_different_configs_error(self): x_fp32 = torch.randn(16, 16, device="cuda") x_scale = torch.tensor(1.0, device="cuda") @@ -630,7 +636,7 @@ def test_different_configs_error(self): a @ b @unittest.skipIf( - not is_cuda_8_9, + not is_sm_at_least_89(), "CUDA not available", ) @pytest.mark.parametrize( diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index ced5db7ff3..6d21686e32 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -11,7 +11,11 @@ import pytest -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + is_sm_at_least_89, + is_sm_at_least_90, +) if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -46,10 +50,6 @@ from torchao.float8.float8_utils import e4m3_dtype from torchao.testing.float8.test_utils import get_test_float8_linear_config -# TODO(future PR): standardize IS_H100 with the rest of the codebase -is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) -is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) - def _test_compile_base( backend: str, @@ -99,7 +99,7 @@ def _test_compile_base( "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], ) -@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True]) +@pytest.mark.parametrize("emulate", [False, True] if is_sm_at_least_89() else [True]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_eager_only( @@ -126,7 +126,7 @@ def test_eager_only( @pytest.mark.parametrize("fullgraph", [True]) -@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True]) +@pytest.mark.parametrize("emulate", [False, True] if is_sm_at_least_89() else [True]) @pytest.mark.parametrize( "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) @@ -177,7 +177,7 @@ def test_aot_eager( [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], ) @unittest.skipIf( - not torch.cuda.is_available() or not is_cuda_8_9, + not torch.cuda.is_available() or not is_sm_at_least_89(), "CUDA with float8 support not available", ) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) @@ -215,7 +215,9 @@ def test_inductor_from_config_params( Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP, ], ) -@unittest.skipIf(not is_H100, "CUDA with capability 9.0 or greater not available") +@unittest.skipIf( + not is_sm_at_least_90(), "CUDA with capability 9.0 or greater not available" +) def test_inductor_from_recipe(recipe_name): torch._dynamo.reset() config = recipe_name_to_linear_config(recipe_name) @@ -253,7 +255,7 @@ def forward(self, x): # TODO(future): figure out why the test below fails on CUDA capability 8.9 @unittest.skipIf( - not torch.cuda.is_available() or not is_H100, + not torch.cuda.is_available() or not is_sm_at_least_90(), "CUDA with capability 9.0 or greater not available", ) def test_float8_with_graph_break_in_the_middle(self): @@ -269,7 +271,7 @@ def test_float8_with_graph_break_in_the_middle(self): torch.testing.assert_close(y_eager, y_compiled) @unittest.skipIf( - not torch.cuda.is_available() or not is_cuda_8_9, + not torch.cuda.is_available() or not is_sm_at_least_89(), "CUDA with float8 support not available", ) def test_float8_graph_input(self): @@ -293,7 +295,7 @@ def to_float(x): torch.testing.assert_close(y2_eager, y2_compiled) @unittest.skipIf( - not torch.cuda.is_available() or not is_cuda_8_9, + not torch.cuda.is_available() or not is_sm_at_least_89(), "CUDA with float8 support not available", ) def test_float8_graph_output(self): @@ -323,7 +325,7 @@ def test_float8_graph_output(self): @unittest.skipIf( - not torch.cuda.is_available() or not is_cuda_8_9, + not torch.cuda.is_available() or not is_sm_at_least_89(), "CUDA with float8 support not available", ) def test_sync_amax_func(): @@ -364,7 +366,7 @@ def __exit__(self, *args): @unittest.skipIf( - not torch.cuda.is_available() or not is_cuda_8_9, + not torch.cuda.is_available() or not is_sm_at_least_89(), "CUDA with float8 support not available", ) def test_sync_amax_func_cuda_graph_success(): @@ -396,7 +398,7 @@ def test_sync_amax_func_cuda_graph_success(): @unittest.skipIf( - not is_cuda_8_9, + not is_sm_at_least_89(), "CUDA not available", ) @pytest.mark.parametrize( diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py index c3e31816ad..fbe5c9b508 100644 --- a/test/float8/test_fsdp2/test_fsdp2.py +++ b/test/float8/test_fsdp2/test_fsdp2.py @@ -6,7 +6,7 @@ import pytest -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_89 if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -40,8 +40,7 @@ from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor from torchao.testing.float8.fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp -is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) -if not is_cuda_8_9: +if not is_sm_at_least_89(): pytest.skip("Unsupported CUDA device capability version", allow_module_level=True) diff --git a/test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py b/test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py index d5c0d7b853..d2e9a51c7f 100644 --- a/test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py +++ b/test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py @@ -3,7 +3,7 @@ import pytest -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_89 if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -30,8 +30,7 @@ from torchao.float8.float8_tensor import GemmInputRole from torchao.testing.float8.fsdp2_utils import check_parity_fp8_comm_only -is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) -if not is_cuda_8_9: +if not is_sm_at_least_89(): pytest.skip("Unsupported CUDA device capability version", allow_module_level=True) diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py index e9028c8712..311964d831 100644 --- a/test/float8/test_numerics_integration.py +++ b/test/float8/test_numerics_integration.py @@ -11,7 +11,11 @@ import pytest -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + is_sm_at_least_89, + is_sm_at_least_90, +) if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -34,9 +38,6 @@ from torchao.float8.float8_utils import IS_ROCM, compute_error from torchao.testing.float8.test_utils import get_test_float8_linear_config -is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) -is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) - torch.manual_seed(0) @@ -176,7 +177,9 @@ def _test_impl(self, config: Float8LinearConfig) -> None: "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], ) - @pytest.mark.skipif(not is_cuda_8_9, reason="requires SM89 compatible machine") + @pytest.mark.skipif( + not is_sm_at_least_89(), reason="requires SM89 compatible machine" + ) @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack") def test_encoder_fw_bw_from_config_params( self, @@ -199,7 +202,9 @@ def test_encoder_fw_bw_from_config_params( Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP, ], ) - @pytest.mark.skipif(not is_cuda_9_0, reason="requires SM90 compatible machine") + @pytest.mark.skipif( + not is_sm_at_least_90(), reason="requires SM90 compatible machine" + ) @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack") def test_encoder_fw_bw_from_recipe( self, diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index df20c5f03b..10f2d157f9 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -91,7 +91,8 @@ TORCH_VERSION_AT_LEAST_2_6, unwrap_tensor_subclass, is_fbcode, - benchmark_model + benchmark_model, + is_sm_at_least_90, ) from torchao.dtypes.utils import is_device @@ -105,7 +106,6 @@ COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy() -is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) def _int8wo_api(mod): if TORCH_VERSION_AT_LEAST_2_4: @@ -779,7 +779,7 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") - @unittest.skipIf(not is_H100, "Need H100 to run") + @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") def test_aq_float8_weight_only_quant_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype @@ -799,7 +799,7 @@ def test_autoquantizable_flatten_unflatten(self): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") - @unittest.skipIf(not is_H100, "Need H100 to run") + @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") def test_aq_float8_dynamic_quant_rowwise_scaling_subclass(self, device, dtype): if dtype != torch.bfloat16: with self.assertRaisesRegex(AssertionError, "PerRow quantization only works for bfloat16 precision"): @@ -813,7 +813,7 @@ def test_aq_float8_dynamic_quant_rowwise_scaling_subclass(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") - @unittest.skipIf(not is_H100, "Need H100 to run") + @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype diff --git a/test/kernel/test_autotuner.py b/test/kernel/test_autotuner.py index 4ed0974172..3e8c9b0a04 100644 --- a/test/kernel/test_autotuner.py +++ b/test/kernel/test_autotuner.py @@ -13,10 +13,10 @@ import pytest import torch from parameterized import parameterized +from torchao.utils import is_sm_at_least_90 logging.basicConfig(level=logging.INFO) -is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) class TestQuantFlow(unittest.TestCase): @@ -56,7 +56,7 @@ def test_int_mm(self, device, dtype): ("cuda", torch.float16), ] ) - @unittest.skipIf(not is_H100, "Needs H100") + @unittest.skipIf(not is_sm_at_least_90(), "Needs H100") def test_int_mm_float8(self, device, dtype): from torchao.kernel import intmm diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index bc9b02deb5..4cac940313 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -20,11 +20,8 @@ ) from torchao.quantization.utils import compute_error -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_89 -# trying to outsmart flake8 -__has_cuda = torch.cuda.is_available() -IS_CUDA_GE_89 = __has_cuda and torch.cuda.get_device_capability() >= (8, 9) torch.manual_seed(2) @@ -102,7 +99,7 @@ def test_linear_compile(elem_dtype, bias): Verify that compile does not change numerics of MX linear fw + bw """ if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): - if not IS_CUDA_GE_89: + if not is_sm_at_least_89(): pytest.skip("CUDA capability >= 8.9 required for float8 in triton") input_shape = (2, 4) grad_shape = (2, 6) @@ -173,7 +170,7 @@ def test_inference_compile_simple(elem_dtype): Smoke test for inference compile """ if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): - if not IS_CUDA_GE_89: + if not is_sm_at_least_89(): pytest.skip("CUDA capability >= 8.9 required for float8 in triton") m = nn.Sequential(nn.Linear(4, 6, bias=False, dtype=torch.bfloat16)) m = m.cuda() diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 964a575411..522785ae6f 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -24,11 +24,8 @@ ) from torchao.quantization.utils import compute_error -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_89 -# trying to outsmart flake8 -__has_cuda = torch.cuda.is_available() -IS_CUDA_GE_89 = __has_cuda and torch.cuda.get_device_capability() >= (8, 9) torch.manual_seed(2) @@ -225,7 +222,7 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros): Verifies that compile does not change numerics of MX casts """ if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): - if not IS_CUDA_GE_89: + if not is_sm_at_least_89(): # separate ifs because flake8 is outsmarting me pytest.skip("CUDA capability >= 8.9 required for float8 in triton") diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 60a7341e39..96ccb1889c 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -53,8 +53,8 @@ TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, is_MI300, - is_sm_89, - is_sm_90, + is_sm_at_least_89, + is_sm_at_least_90, ) from .autoquant import AutoQuantizableLinearWeight, autoquant @@ -864,11 +864,11 @@ def _normalize_granularity( for _granularity in processed_granularity: if isinstance(_granularity, PerTensor): assert ( - is_sm_89() or is_MI300() + is_sm_at_least_89() or is_MI300() ), "PerTensor quantization only works for CUDA>=8.9 and MI300+" elif isinstance(_granularity, PerRow): assert ( - is_sm_90() or is_MI300() + is_sm_at_least_90() or is_MI300() ), "PerRow quantization only works for CUDA>=9.0 and MI300+" else: raise ValueError(f"Invalid granularity type: {_granularity}") @@ -966,7 +966,7 @@ def float8_dynamic_activation_float8_weight( """ assert ( - is_sm_89() or is_MI300() + is_sm_at_least_89() or is_MI300() ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+" if mm_config is None: mm_config = Float8MMConfig(use_fast_accum=True) @@ -1023,7 +1023,7 @@ def float8_static_activation_float8_weight( mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation. """ assert ( - is_sm_89() or is_MI300() + is_sm_at_least_89() or is_MI300() ), "Float8 static activation quantization is only supported on CUDA 8.9 and above" if mm_config is None: mm_config = Float8MMConfig(use_fast_accum=True) diff --git a/torchao/utils.py b/torchao/utils.py index ba91fb3fe0..d56191ed6b 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -33,8 +33,8 @@ "TORCH_VERSION_AFTER_2_4", "TORCH_VERSION_AFTER_2_5", "is_MI300", - "is_sm_89", - "is_sm_90", + "is_sm_at_least_89", + "is_sm_at_least_90", ] @@ -612,7 +612,7 @@ def is_MI300(): return False -def is_sm_89(): +def is_sm_at_least_89(): return ( torch.cuda.is_available() and torch.version.cuda @@ -620,7 +620,7 @@ def is_sm_89(): ) -def is_sm_90(): +def is_sm_at_least_90(): return ( torch.cuda.is_available() and torch.version.cuda From cfabd6dd2dd3b172c296a6b063b00ec3f15e21a2 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Sun, 1 Dec 2024 19:35:44 -0800 Subject: [PATCH 09/40] Lint fixes test/quantization (#1359) --- ruff.toml | 2 +- test/quantization/test_galore_quant.py | 15 +- test/quantization/test_marlin_qqq.py | 8 +- test/quantization/test_qat.py | 459 +++++++++++++-------- test/quantization/test_quant_api.py | 287 ++++++++----- test/quantization/test_quant_primitives.py | 406 +++++++++++++----- 6 files changed, 793 insertions(+), 384 deletions(-) diff --git a/ruff.toml b/ruff.toml index 09d0a1ec97..13026345a6 100644 --- a/ruff.toml +++ b/ruff.toml @@ -9,7 +9,7 @@ include = [ "torchao/sparsity/**/*.py", "torchao/prototype/low_bit_optim/**.py", "test/float8/**/*.py", - "test/quantization/test_observer.py", + "test/quantization/**/*.py", "test/dtypes/**/*.py", "test/prototype/low_bit_optim/**.py", "torchao/utils.py", diff --git a/test/quantization/test_galore_quant.py b/test/quantization/test_galore_quant.py index 37709c4128..3eb9b0a2c5 100644 --- a/test/quantization/test_galore_quant.py +++ b/test/quantization/test_galore_quant.py @@ -3,13 +3,16 @@ import pytest # Skip entire test if triton is not available, otherwise CI failure -try: - import triton -except ImportError: - pytest.skip("triton is not installed", allow_module_level=True) - -from bitsandbytes.functional import create_dynamic_map, quantize_blockwise, dequantize_blockwise +try: # noqa: F401 + import triton # noqa: F401 +except ImportError: # noqa: F401 + pytest.skip("triton is not installed", allow_module_level=True) # noqa: F401 import torch +from bitsandbytes.functional import ( + create_dynamic_map, + dequantize_blockwise, + quantize_blockwise, +) from torchao.prototype.galore.kernels import ( triton_dequant_blockwise, diff --git a/test/quantization/test_marlin_qqq.py b/test/quantization/test_marlin_qqq.py index 0dcaaf9c8c..ebdf2281e0 100644 --- a/test/quantization/test_marlin_qqq.py +++ b/test/quantization/test_marlin_qqq.py @@ -1,4 +1,5 @@ import copy +import unittest import pytest import torch @@ -19,9 +20,12 @@ choose_qparams_and_quantize_affine_qqq, ) from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode -import unittest -@unittest.skipIf(is_fbcode(), "Skipping the test in fbcode since we don't have TARGET file for kernels") + +@unittest.skipIf( + is_fbcode(), + "Skipping the test in fbcode since we don't have TARGET file for kernels", +) class TestMarlinQQQ(TestCase): def setUp(self): super().setUp() diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 29f833c9ab..3a998635aa 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -13,9 +13,8 @@ import torch import torch.nn.functional as F from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 -from torchao.dtypes import ( - TensorCoreTiledLayout, -) + +from torchao.quantization.GPTQ import _replace_linear_8da4w, _replace_linear_int4 from torchao.quantization.granularity import ( PerAxis, PerGroup, @@ -26,33 +25,26 @@ ComposableQATQuantizer, FakeQuantizeConfig, ) -from torchao.quantization.qat.fake_quantizer import ( - FakeQuantizer, -) from torchao.quantization.qat.embedding import ( FakeQuantizedEmbedding, ) from torchao.quantization.qat.linear import ( FakeQuantizedLinear, + Int4WeightOnlyQATLinear, Int8DynActInt4WeightQATLinear, - Int4WeightOnlyQATLinear ) from torchao.quantization.qat.utils import ( _choose_qparams_per_token_asymmetric, _fake_quantize_per_channel_group, _fake_quantize_per_token, - _get_qmin_qmax, _GenericFakeQuantize, -) -from torchao.quantization.quant_api import ( - int4_weight_only, - quantize_, + _get_qmin_qmax, ) from torchao.quantization.quant_primitives import ( - fake_quantize_affine, MappingType, TorchAODType, ZeroPointDomain, + fake_quantize_affine, ) from torchao.quantization.unified import ( TwoStepQuantizer, @@ -65,17 +57,12 @@ from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, -) - -from torchao.quantization.GPTQ import ( - _replace_linear_8da4w, - _replace_linear_int4 ) # TODO: put this in a common test utils file _CUDA_IS_AVAILABLE = torch.cuda.is_available() + class Sub(torch.nn.Module): def __init__(self): super().__init__() @@ -87,6 +74,7 @@ def example_inputs(self): def forward(self, x): return self.linear(x) + class M(torch.nn.Module): def __init__(self): super().__init__() @@ -103,6 +91,7 @@ def forward(self, x): x = self.linear2(x) return x + class M2(torch.nn.Module): def __init__(self): super().__init__() @@ -118,7 +107,9 @@ def forward(self, x): class TestQAT(unittest.TestCase): SEED = 123 - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_fake_quantize_per_channel_group(self): n_bit = 4 (qmin, qmax) = _get_qmin_qmax(n_bit) @@ -132,20 +123,40 @@ def test_fake_quantize_per_channel_group(self): # fake quant op out = _fake_quantize_per_channel_group( - x, s, zp, qmin, qmax, group_size, + x, + s, + zp, + qmin, + qmax, + group_size, ) out.sum().backward() # compare against PTQ ops out_ptq = torch.ops.quantized_decomposed.quantize_per_channel_group( - x2, s, zp, qmin, qmax, torch.int8, group_size, + x2, + s, + zp, + qmin, + qmax, + torch.int8, + group_size, ) out_ptq = torch.ops.quantized_decomposed.dequantize_per_channel_group( - out_ptq, s, zp, qmin, qmax, torch.int8, group_size, torch.float32, + out_ptq, + s, + zp, + qmin, + qmax, + torch.int8, + group_size, + torch.float32, ) torch.testing.assert_close(out, out_ptq, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_fake_quantize_per_token(self): (qmin, qmax) = _get_qmin_qmax(8) @@ -161,10 +172,21 @@ def test_fake_quantize_per_token(self): # compare against PTQ ops out_ptq = torch.ops.quantized_decomposed.quantize_per_token( - x2, s, zp, qmin, qmax, torch.int8, + x2, + s, + zp, + qmin, + qmax, + torch.int8, ) out_ptq = torch.ops.quantized_decomposed.dequantize_per_token( - out_ptq, s, zp, qmin, qmax, torch.int8, torch.float32, + out_ptq, + s, + zp, + qmin, + qmax, + torch.int8, + torch.float32, ) torch.testing.assert_close(out, out_ptq, atol=0, rtol=0) @@ -182,9 +204,10 @@ def _set_ptq_weight( WeightOnlyInt4Linear, ) from torchao.quantization.qat.linear import ( - Int8DynActInt4WeightQATLinear, Int4WeightOnlyQATLinear, + Int8DynActInt4WeightQATLinear, ) + n_bit = 4 (qmin, qmax) = _get_qmin_qmax(n_bit) group_size = qat_linear.weight_fake_quantizer.config.group_size @@ -193,7 +216,13 @@ def _set_ptq_weight( fp32_weight = qat_linear.weight (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size) q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group( - fp32_weight, s, zp, qmin, qmax, torch.int8, group_size, + fp32_weight, + s, + zp, + qmin, + qmax, + torch.int8, + group_size, ) ptq_linear.weight = q_weight ptq_linear.scales = s @@ -201,28 +230,39 @@ def _set_ptq_weight( elif isinstance(ptq_linear, WeightOnlyInt4Linear): assert isinstance(qat_linear, Int4WeightOnlyQATLinear) (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( - qat_linear.weight, n_bit, group_size, + qat_linear.weight, + n_bit, + group_size, ) q_weight = torch.ops.aten._convert_weight_to_int4pack( - q_weight.to("cuda"), qat_linear.inner_k_tiles, + q_weight.to("cuda"), + qat_linear.inner_k_tiles, ) ptq_linear.weight = q_weight ptq_linear.scales_and_zeros = scales_and_zeros else: raise ValueError("Unknown ptq_linear type: %s" % type(ptq_linear)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_8da4w_linear(self): - from torchao.quantization.qat.linear import Int8DynActInt4WeightQATLinear from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear + from torchao.quantization.qat.linear import Int8DynActInt4WeightQATLinear group_size = 128 torch.manual_seed(self.SEED) qat_linear = Int8DynActInt4WeightQATLinear( - 256, 688, bias=False, groupsize=group_size, + 256, + 688, + bias=False, + groupsize=group_size, ) ptq_linear = Int8DynActInt4WeightLinear( - 256, 688, bias=False, groupsize=group_size, + 256, + 688, + bias=False, + groupsize=group_size, ) # Force the weights to be the same @@ -236,10 +276,12 @@ def test_qat_8da4w_linear(self): ptq_out = ptq_linear(x2) torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_8da4w_quantizer(self): - from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer + from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer group_size = 16 torch.manual_seed(self.SEED) @@ -268,9 +310,13 @@ def test_qat_8da4w_quantizer(self): converted_state_dict = converted_model.state_dict() self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys()) for k in ptq_state_dict.keys(): - torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0) + torch.testing.assert_close( + ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0 + ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_8da4w_quantizer_meta_weights(self): from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer @@ -282,7 +328,9 @@ def test_qat_8da4w_quantizer_meta_weights(self): qat_model = qat_quantizer.prepare(m) self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values())) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_8da4w_quantizer_disable_fake_quant(self): """ Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward. @@ -341,7 +389,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self): qat_out2 = qat_model2(*x2) torch.testing.assert_close(qat_out, qat_out2, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): """ Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward. @@ -363,8 +413,12 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): nn_model.sub.linear.weight = torch.nn.Parameter(qat_model.sub.linear.weight) # Simulate training for both models - optimizer1 = torch.optim.SGD(nn_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5) - optimizer2 = torch.optim.SGD(qat_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5) + optimizer1 = torch.optim.SGD( + nn_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5 + ) + optimizer2 = torch.optim.SGD( + qat_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5 + ) loss_fn1 = torch.nn.CrossEntropyLoss() loss_fn2 = torch.nn.CrossEntropyLoss() example_inputs = nn_model.example_inputs() @@ -382,9 +436,15 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): optimizer2.step() # After 1 training step, weights should match exactly - torch.testing.assert_close(nn_model.linear1.weight, qat_model.linear1.weight, atol=0, rtol=0) - torch.testing.assert_close(nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0) - torch.testing.assert_close(nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0) + torch.testing.assert_close( + nn_model.linear1.weight, qat_model.linear1.weight, atol=0, rtol=0 + ) + torch.testing.assert_close( + nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0 + ) + torch.testing.assert_close( + nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0 + ) def _test_qat_quantized_gradients(self, quantizer): """ @@ -394,7 +454,9 @@ def _test_qat_quantized_gradients(self, quantizer): torch.manual_seed(self.SEED) m = M() model = quantizer.prepare(m) - optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5) + optimizer = torch.optim.SGD( + model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5 + ) loss_fn = torch.nn.CrossEntropyLoss() # Simulate training @@ -426,13 +488,18 @@ def _test_qat_quantized_gradients(self, quantizer): optimizer.step() current_step += 1 - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_8da4w_quantizer_gradients(self): from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer + quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16) self._test_qat_quantized_gradients(quantizer) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_generic_fake_quantize(self): """ Test that the generic fake quantize used in 8da4w QAT matches @@ -443,7 +510,9 @@ def test_qat_generic_fake_quantize(self): py_input = torch.randn(16, 64).float().requires_grad_() py_s = torch.randn(16).float() py_zp = torch.randint(qmax, size=(16,), dtype=torch.int32) - py_out = torch.fake_quantize_per_channel_affine(py_input, py_s, py_zp, 0, qmin, qmax) + py_out = torch.fake_quantize_per_channel_affine( + py_input, py_s, py_zp, 0, qmin, qmax + ) py_out.sum().backward() ao_input = copy.deepcopy(py_input) @@ -451,7 +520,9 @@ def test_qat_generic_fake_quantize(self): block_size = (1, ao_input.shape[-1]) ao_s = copy.deepcopy(py_s) ao_zp = copy.deepcopy(py_zp) - ao_out = _GenericFakeQuantize.apply(ao_input, block_size, ao_s, ao_zp, qmin, qmax) + ao_out = _GenericFakeQuantize.apply( + ao_input, block_size, ao_s, ao_zp, qmin, qmax + ) ao_out.sum().backward() torch.testing.assert_close(py_out, ao_out, atol=0, rtol=0) @@ -485,10 +556,14 @@ def test_qat_4w_primitives(self): # PTQ (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( - weight, n_bit, group_size, scales_precision, + weight, + n_bit, + group_size, + scales_precision, ) q_weight = torch.ops.aten._convert_weight_to_int4pack( - q_weight.to(device), inner_k_tiles, + q_weight.to(device), + inner_k_tiles, ) ptq_out = torch.ops.aten._weight_int4pack_mm( x, q_weight, group_size, scales_and_zeros @@ -497,9 +572,12 @@ def test_qat_4w_primitives(self): # QAT block_size = (1, group_size) quant_min = 0 - quant_max = 2 ** n_bit - 1 + quant_max = 2**n_bit - 1 scales, zero_points = get_groupwise_affine_qparams( - weight, n_bit, group_size, scales_precision, + weight, + n_bit, + group_size, + scales_precision, ) w_fq = fake_quantize_affine( weight, @@ -509,27 +587,37 @@ def test_qat_4w_primitives(self): torch.int32, quant_min, quant_max, - zero_point_domain = ZeroPointDomain.FLOAT, + zero_point_domain=ZeroPointDomain.FLOAT, ) qat_out = torch.nn.functional.linear(x, w_fq) self._assert_close_4w(qat_out, ptq_out) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") def test_qat_4w_linear(self): - from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear from torchao.quantization.GPTQ import WeightOnlyInt4Linear + from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear group_size = 128 device = torch.device("cuda") dtype = torch.bfloat16 torch.manual_seed(self.SEED) qat_linear = Int4WeightOnlyQATLinear( - 256, 688, bias=False, groupsize=group_size, device=device, + 256, + 688, + bias=False, + groupsize=group_size, + device=device, ) ptq_linear = WeightOnlyInt4Linear( - 256, 688, bias=False, groupsize=group_size, device=device, + 256, + 688, + bias=False, + groupsize=group_size, + device=device, ) # Force the weights to be the same @@ -543,17 +631,22 @@ def test_qat_4w_linear(self): ptq_out = ptq_linear(x2) self._assert_close_4w(qat_out, ptq_out) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_4w_quantizer_gradients(self): from torchao.quantization.qat import Int4WeightOnlyQATQuantizer + quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8) self._test_qat_quantized_gradients(quantizer) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") def test_qat_4w_quantizer(self): - from torchao.quantization.qat import Int4WeightOnlyQATQuantizer from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer + from torchao.quantization.qat import Int4WeightOnlyQATQuantizer group_size = 32 inner_k_tiles = 8 @@ -563,10 +656,12 @@ def test_qat_4w_quantizer(self): m = M().to(device).to(dtype) m2 = copy.deepcopy(m) qat_quantizer = Int4WeightOnlyQATQuantizer( - groupsize=group_size, inner_k_tiles=inner_k_tiles, + groupsize=group_size, + inner_k_tiles=inner_k_tiles, ) ptq_quantizer = Int4WeightOnlyQuantizer( - groupsize=group_size, inner_k_tiles=inner_k_tiles, + groupsize=group_size, + inner_k_tiles=inner_k_tiles, ) qat_model = qat_quantizer.prepare(m) ptq_model = ptq_quantizer.quantize(m2) @@ -589,13 +684,16 @@ def test_qat_4w_quantizer(self): converted_state_dict = converted_model.state_dict() self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys()) for k in ptq_state_dict.keys(): - torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0) + torch.testing.assert_close( + ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0 + ) class _MyQATQuantizer(TwoStepQuantizer): """ Dummy quantizer that attaches a certain value to each nn.Linear's `_temp_quantizer_values` attribute. """ + ATTR_NAME = "_temp_quantizer_values" def __init__(self, value: str): @@ -626,19 +724,24 @@ def test_composable_qat_quantizer(self): self.assertEqual(values_list, ["quantizer1", "quantizer2"]) composable_quantizer.convert(model) values_list = getattr(model.linear1, self._MyQATQuantizer.ATTR_NAME) - self.assertEqual(values_list, ["quantizer1", "quantizer2", "quantizer1", "quantizer2"]) + self.assertEqual( + values_list, ["quantizer1", "quantizer2", "quantizer1", "quantizer2"] + ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_4w_embedding(self): from torchao.quantization.qat import Int4WeightOnlyEmbeddingQATQuantizer + model = M2() x = model.example_inputs() - out = model(*x) + model(*x) quantizer = Int4WeightOnlyEmbeddingQATQuantizer() prepared = quantizer.prepare(model) - prepared_out = prepared(*x) + prepared(*x) converted = quantizer.convert(model) - converted_out = converted(*x) + converted(*x) def test_fake_quantize_config_granularity(self): """ @@ -685,7 +788,9 @@ def test_fake_quantize_config_granularity_error_cases(self): Test incorrect settings of `FakeQuantizeConfig`'s granularity. """ # no granularity provided - with self.assertRaisesRegex(ValueError, "`granularity` or `group_size` must be set"): + with self.assertRaisesRegex( + ValueError, "`granularity` or `group_size` must be set" + ): FakeQuantizeConfig(torch.int8) # group_size with conflicting granularity @@ -718,8 +823,12 @@ def test_fake_quantize_config_mapping_type(self): """ # symmetric symmetric_config1 = FakeQuantizeConfig(torch.int8, "per_token") - symmetric_config2 = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=True) - symmetric_config3 = FakeQuantizeConfig(torch.int8, "per_token", MappingType.SYMMETRIC) + symmetric_config2 = FakeQuantizeConfig( + torch.int8, "per_token", is_symmetric=True + ) + symmetric_config3 = FakeQuantizeConfig( + torch.int8, "per_token", MappingType.SYMMETRIC + ) self.assertEqual(symmetric_config1.mapping_type, MappingType.SYMMETRIC) self.assertEqual(symmetric_config2.mapping_type, MappingType.SYMMETRIC) self.assertEqual(symmetric_config3.mapping_type, MappingType.SYMMETRIC) @@ -728,8 +837,12 @@ def test_fake_quantize_config_mapping_type(self): self.assertTrue(symmetric_config3.is_symmetric) # asymmetric - asymmetric_config1 = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) - asymmetric_config2 = FakeQuantizeConfig(torch.int8, "per_token", MappingType.ASYMMETRIC) + asymmetric_config1 = FakeQuantizeConfig( + torch.int8, "per_token", is_symmetric=False + ) + asymmetric_config2 = FakeQuantizeConfig( + torch.int8, "per_token", MappingType.ASYMMETRIC + ) self.assertEqual(asymmetric_config1.mapping_type, MappingType.ASYMMETRIC) self.assertEqual(asymmetric_config2.mapping_type, MappingType.ASYMMETRIC) self.assertFalse(asymmetric_config1.is_symmetric) @@ -743,11 +856,15 @@ def test_fake_quantize_config_mapping_type(self): # bad config1: both mapping_type and is_symmetric are set msg = "Cannot set both `mapping_type` and `is_symmetric`" with self.assertRaisesRegex(ValueError, msg): - FakeQuantizeConfig(torch.int8, "per_token", MappingType.SYMMETRIC, is_symmetric=False) + FakeQuantizeConfig( + torch.int8, "per_token", MappingType.SYMMETRIC, is_symmetric=False + ) # bad config2: not supported with self.assertRaisesRegex(ValueError, "not supported"): - FakeQuantizeConfig(torch.int8, "per_token", MappingType.SYMMETRIC_NO_CLIPPING_ERR) + FakeQuantizeConfig( + torch.int8, "per_token", MappingType.SYMMETRIC_NO_CLIPPING_ERR + ) def test_fake_quantize_config_dtype(self): """ @@ -781,7 +898,9 @@ def test_fake_quantize_config_dtype(self): FakeQuantizeConfig(TorchAODType.INT7, "per_token") FakeQuantizeConfig(torch.int8, "per_token") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_fake_quantized_linear_8da4w(self): """ Test that we can express int8 dynamic activations + int4 weights with `FakeQuantizedLinear`. @@ -792,7 +911,9 @@ def test_fake_quantized_linear_8da4w(self): 256, 688, bias=False, - activation_config=FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False), + activation_config=FakeQuantizeConfig( + torch.int8, "per_token", is_symmetric=False + ), weight_config=FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size), ) @@ -801,7 +922,9 @@ def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: Baseline for int8 dynamic per token asymmetric + int4 per group symmetric quant. """ # activations - (s, zp) = _choose_qparams_per_token_asymmetric(x, torch.float32, torch.int32) + (s, zp) = _choose_qparams_per_token_asymmetric( + x, torch.float32, torch.int32 + ) (qmin, qmax) = _get_qmin_qmax(8) x_fq = _fake_quantize_per_token(x, s, zp, qmin, qmax) @@ -809,7 +932,9 @@ def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: (s, zp) = get_group_qparams_symmetric(weight, 4, group_size, torch.float32) zp = zp.to(torch.int32) (qmin, qmax) = _get_qmin_qmax(4) - w_fq = _fake_quantize_per_channel_group(weight, s, zp, qmin, qmax, group_size) + w_fq = _fake_quantize_per_channel_group( + weight, s, zp, qmin, qmax, group_size + ) return F.linear(x_fq, w_fq) # Compare linear values @@ -820,7 +945,9 @@ def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: baseline_out = linear_forward_8da4w(x2, fq_linear.weight) torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_fake_quantized_linear_4w(self): """ Test that we can express int4 weight only (tinygemm) with `FakeQuantizedLinear`. @@ -849,7 +976,13 @@ def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: (s, zp) = get_groupwise_affine_qparams(weight, 4, group_size, torch.float32) zp = zp.to(torch.int32) w_fq = _fake_quantize_per_channel_group( - weight, s, zp, qmin, qmax, group_size, zero_point_domain=ZeroPointDomain.FLOAT, + weight, + s, + zp, + qmin, + qmax, + group_size, + zero_point_domain=ZeroPointDomain.FLOAT, ) return F.linear(x, w_fq) @@ -860,50 +993,78 @@ def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: fq_out = fq_linear(x) baseline_out = linear_forward_4w(x2, fq_linear.weight) torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0) - - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_replace_linear_8da4w(self): - module = torch.nn.ModuleList([ - torch.nn.Linear(in_features=256, out_features=50, bias=True) - ]) - _replace_linear_8da4w(module, 256, False, torch.float32, torch.float32, Int8DynActInt4WeightQATLinear, copy_weights=True) - assert(not isinstance(module[0], Int8DynActInt4WeightQATLinear) and isinstance(module[0], torch.nn.Linear)) - module = torch.nn.ModuleList([ - torch.nn.Linear(in_features=256, out_features=50, bias=False) - ]) - _replace_linear_8da4w(module, 256, False, torch.float32, torch.float32, Int8DynActInt4WeightQATLinear, copy_weights=True) - assert(isinstance(module[0], Int8DynActInt4WeightQATLinear)) - - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + module = torch.nn.ModuleList( + [torch.nn.Linear(in_features=256, out_features=50, bias=True)] + ) + _replace_linear_8da4w( + module, + 256, + False, + torch.float32, + torch.float32, + Int8DynActInt4WeightQATLinear, + copy_weights=True, + ) + assert not isinstance(module[0], Int8DynActInt4WeightQATLinear) and isinstance( + module[0], torch.nn.Linear + ) + module = torch.nn.ModuleList( + [torch.nn.Linear(in_features=256, out_features=50, bias=False)] + ) + _replace_linear_8da4w( + module, + 256, + False, + torch.float32, + torch.float32, + Int8DynActInt4WeightQATLinear, + copy_weights=True, + ) + assert isinstance(module[0], Int8DynActInt4WeightQATLinear) + + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_replace_linear_int4(self): - module = torch.nn.ModuleList([ - torch.nn.Linear(in_features=256, out_features=50, bias=True) - ]) + module = torch.nn.ModuleList( + [torch.nn.Linear(in_features=256, out_features=50, bias=True)] + ) _replace_linear_int4( - module, - 256, + module, + 256, 8, - padding_allowed=True, - precision=torch.bfloat16, - scales_precision=torch.bfloat16, - linear_class=Int4WeightOnlyQATLinear, - copy_weights=True) - assert(not isinstance(module[0], Int4WeightOnlyQATLinear) and isinstance(module[0], torch.nn.Linear)) - module = torch.nn.ModuleList([ - torch.nn.Linear(in_features=256, out_features=50, bias=False) - ]) + padding_allowed=True, + precision=torch.bfloat16, + scales_precision=torch.bfloat16, + linear_class=Int4WeightOnlyQATLinear, + copy_weights=True, + ) + assert not isinstance(module[0], Int4WeightOnlyQATLinear) and isinstance( + module[0], torch.nn.Linear + ) + module = torch.nn.ModuleList( + [torch.nn.Linear(in_features=256, out_features=50, bias=False)] + ) _replace_linear_int4( - module, - 256, + module, + 256, 8, - padding_allowed=True, - precision=torch.bfloat16, - scales_precision=torch.bfloat16, - linear_class=Int4WeightOnlyQATLinear, - copy_weights=True) - assert(isinstance(module[0], Int4WeightOnlyQATLinear)) - - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + padding_allowed=True, + precision=torch.bfloat16, + scales_precision=torch.bfloat16, + linear_class=Int4WeightOnlyQATLinear, + copy_weights=True, + ) + assert isinstance(module[0], Int4WeightOnlyQATLinear) + + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_fake_quantized_embedding_4w(self): """ Test that we can express int4 per group symmetric weight only fake quantization @@ -926,7 +1087,9 @@ def embedding_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: (s, zp) = get_group_qparams_symmetric(weight, 4, group_size, torch.float32) zp = zp.to(torch.int32) (qmin, qmax) = _get_qmin_qmax(4) - w_fq = _fake_quantize_per_channel_group(weight, s, zp, qmin, qmax, group_size) + w_fq = _fake_quantize_per_channel_group( + weight, s, zp, qmin, qmax, group_size + ) return F.embedding(x, w_fq) # Compare embedding values @@ -937,59 +1100,15 @@ def embedding_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: baseline_out = embedding_forward_4w(x2, fq_embedding.weight) torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_prototype_bc(self): """ Just to make sure we can import all the old prototype paths. We will remove this test in the near future when we actually break BC. """ - from torchao.quantization.prototype.qat import ( - disable_4w_fake_quant, - disable_8da4w_fake_quant, - enable_4w_fake_quant, - enable_8da4w_fake_quant, - ComposableQATQuantizer, - Int8DynActInt4WeightQATLinear, - Int4WeightOnlyEmbeddingQATQuantizer, - Int4WeightOnlyQATQuantizer, - Int8DynActInt4WeightQATQuantizer, - ) - from torchao.quantization.prototype.qat._module_swap_api import ( - disable_4w_fake_quant_module_swap, - enable_4w_fake_quant_module_swap, - disable_8da4w_fake_quant_module_swap, - enable_8da4w_fake_quant_module_swap, - Int4WeightOnlyQATQuantizerModuleSwap, - Int8DynActInt4WeightQATQuantizerModuleSwap, - ) - from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( - AffineFakeQuantizedTensor, - to_affine_fake_quantized, - ) - from torchao.quantization.prototype.qat.api import ( - ComposableQATQuantizer, - FakeQuantizeConfig, - ) - from torchao.quantization.prototype.qat.embedding import ( - FakeQuantizedEmbedding, - Int4WeightOnlyEmbeddingQATQuantizer, - Int4WeightOnlyEmbedding, - Int4WeightOnlyQATEmbedding, - ) - from torchao.quantization.prototype.qat.fake_quantizer import ( - FakeQuantizer, - ) - from torchao.quantization.prototype.qat.linear import ( - disable_4w_fake_quant, - disable_8da4w_fake_quant, - enable_4w_fake_quant, - enable_8da4w_fake_quant, - FakeQuantizedLinear, - Int4WeightOnlyQATLinear, - Int4WeightOnlyQATQuantizer, - Int8DynActInt4WeightQATLinear, - Int8DynActInt4WeightQATQuantizer, - ) + if __name__ == "__main__": unittest.main() diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 458cd07810..eb5f1337d1 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -6,81 +6,86 @@ # mypy: ignore-errors # This test takes a long time to run +import copy +import gc +import tempfile import unittest +from pathlib import Path + import torch -import os from torch.ao.quantization.quantize_pt2e import ( - prepare_pt2e, convert_pt2e, + prepare_pt2e, ) from torch.ao.quantization.quantizer.xnnpack_quantizer import ( XNNPACKQuantizer, get_symmetric_quantization_config, ) +from torch.testing._internal import common_utils +from torch.testing._internal.common_utils import TestCase -import torchao +from torchao import quantize_ +from torchao._models.llama.model import Transformer, prepare_inputs_for_model +from torchao._models.llama.tokenizer import get_tokenizer from torchao.dtypes import ( AffineQuantizedTensor, ) from torchao.quantization import ( LinearActivationQuantizedTensor, ) -from torchao.quantization.quant_primitives import ( - MappingType, - ZeroPointDomain, -) -from torchao.quantization.subclass import ( - Int8WeightOnlyQuantizedLinearWeight, - Int4WeightOnlyQuantizedLinearWeight, -) -from torchao import quantize_ from torchao.quantization.quant_api import ( - _replace_with_custom_fn_if_matches_filter, Quantizer, TwoStepQuantizer, - int8_dynamic_activation_int4_weight, + _replace_with_custom_fn_if_matches_filter, int4_weight_only, - int8_weight_only, + int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_weight, + int8_weight_only, +) +from torchao.quantization.quant_primitives import ( + MappingType, +) +from torchao.quantization.subclass import ( + Int4WeightOnlyQuantizedLinearWeight, + Int8WeightOnlyQuantizedLinearWeight, ) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, + unwrap_tensor_subclass, ) -from pathlib import Path -from torchao._models.llama.tokenizer import get_tokenizer -from torchao._models.llama.model import Transformer, prepare_inputs_for_model -from torchao.utils import unwrap_tensor_subclass -import copy -import tempfile -import gc -from torch.testing._internal.common_utils import TestCase -from torch.testing._internal import common_utils def dynamic_quant(model, example_inputs): m = torch.export.export(model, example_inputs).module() - quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True)) + quantizer = XNNPACKQuantizer().set_global( + get_symmetric_quantization_config(is_dynamic=True) + ) m = prepare_pt2e(m, quantizer) m = convert_pt2e(m) return m + def capture_and_prepare(model, example_inputs): m = torch.export.export(model, example_inputs) - quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True)) + quantizer = XNNPACKQuantizer().set_global( + get_symmetric_quantization_config(is_dynamic=True) + ) m = prepare_pt2e(m, quantizer) # TODO: we can run the weight observer in convert_pt2e so that user don't need to run this m(*example_inputs) return m -class XNNPackDynamicQuantizer(TwoStepQuantizer): +class XNNPackDynamicQuantizer(TwoStepQuantizer): def prepare(self, model: torch.nn.Module) -> torch.nn.Module: _replace_with_custom_fn_if_matches_filter( model, - lambda linear_mod: capture_and_prepare(linear_mod, (torch.randn(1, linear_mod.in_features))), + lambda linear_mod: capture_and_prepare( + linear_mod, (torch.randn(1, linear_mod.in_features)) + ), lambda mod, fqn: isinstance(mod, torch.nn.Linear), ) return model @@ -93,11 +98,13 @@ def convert(self, model: torch.nn.Module) -> torch.nn.Module: ) return model + class TorchCompileDynamicQuantizer(Quantizer): def quantize(self, model: torch.nn.Module) -> torch.nn.Module: quantize_(model, int8_dynamic_activation_int8_weight()) return model + class ToyLinearModel(torch.nn.Module): def __init__(self, m=64, n=32, k=64): super().__init__() @@ -105,7 +112,11 @@ def __init__(self, m=64, n=32, k=64): self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float) def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"): - return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),) + return ( + torch.randn( + batch_size, self.linear1.in_features, dtype=dtype, device=device + ), + ) def forward(self, x): x = self.linear1(x) @@ -118,9 +129,11 @@ def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs The deprecated implementation for int8 dynamic quant API, used as a reference for numerics and performance """ - from torchao.quantization.quant_api import _in_features_greater_than_16 - from torchao.quantization.quant_api import _is_linear - from torchao.quantization.quant_api import _get_subclass_inserter + from torchao.quantization.quant_api import ( + _get_subclass_inserter, + _in_features_greater_than_16, + _is_linear, + ) from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight if filter_fn is None: @@ -129,37 +142,49 @@ def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs ) _replace_with_custom_fn_if_matches_filter( - model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn + model, + _get_subclass_inserter( + Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs + ), + filter_fn, ) + def _get_ref_change_linear_weights_to_woqtensors(deprecated_tenosr_subclass): def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs): """ The deprecated implementation for weight only quant API, used as a reference for numerics and performance """ - from torchao.quantization.quant_api import _is_linear - from torchao.quantization.quant_api import _get_subclass_inserter + from torchao.quantization.quant_api import _get_subclass_inserter, _is_linear filter_fn = kwargs.pop("filter_fn", _is_linear) _replace_with_custom_fn_if_matches_filter( model, - _get_subclass_inserter(deprecated_tenosr_subclass, enable_parametrization=True, **kwargs), + _get_subclass_inserter( + deprecated_tenosr_subclass, enable_parametrization=True, **kwargs + ), filter_fn, ) return _ref_change_linear_weights_to_woqtensors -_ref_change_linear_weights_to_int8_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight) -_ref_change_linear_weights_to_int4_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight) + +_ref_change_linear_weights_to_int8_woqtensors = ( + _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight) +) +_ref_change_linear_weights_to_int4_woqtensors = ( + _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight) +) + class TestQuantFlow(TestCase): def test_dynamic_quant_gpu_singleline(self): m = ToyLinearModel().eval() example_inputs = m.example_inputs() quantize_(m, int8_dynamic_activation_int8_weight()) - quantized = m(*example_inputs) + m(*example_inputs) # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64 # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {}) # m = torch.compile(m, mode="max-autotune") @@ -182,7 +207,9 @@ def test_dynamic_quant_gpu_unified_api_unified_impl(self): compiled = m(*example_inputs) torch.testing.assert_close(quantized, compiled, atol=0, rtol=0) - @unittest.skip("FAILED test/quantization/test_quant_api.py::TestQuantFlow::test_dynamic_quant_gpu_unified_api_eager_mode_impl - AssertionError: Tensor-likes are not equal!") + @unittest.skip( + "FAILED test/quantization/test_quant_api.py::TestQuantFlow::test_dynamic_quant_gpu_unified_api_eager_mode_impl - AssertionError: Tensor-likes are not equal!" + ) def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self): quantizer = TorchCompileDynamicQuantizer() m = ToyLinearModel().eval() @@ -196,10 +223,8 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "only works for torch 2.4+") def test_int8_wo_quant_save_load(self): - from torchao.quantization.quant_api import ( - change_linear_weights_to_int8_woqtensors, - ) m = ToyLinearModel().eval().cpu() + def api(model): quantize_(model, int8_weight_only()) unwrap_tensor_subclass(model) @@ -223,10 +248,12 @@ def api(model): torch.testing.assert_close(ref, res.cpu()) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch verion is 2.3 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch verion is 2.3 or lower" + ) def test_8da4w_quantizer(self): - from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear + from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer quantizer = Int8DynActInt4WeightQuantizer(groupsize=32) m = ToyLinearModel().eval() @@ -242,8 +269,9 @@ def test_8da4w_quantizer(self): # https://github.com/pytorch-labs/gpt-fast/blob/6253c6bb054e658d67566150f87329b87815ae63/scripts/convert_hf_checkpoint.py @unittest.skip("skipping until we get checkpoints for gpt-fast") def test_8da4w_gptq_quantizer(self): - from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer from torchao._models._eval import InputRecorder, TransformerEvalWrapper + from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer + # should be similar to TorchCompileDynamicQuantizer precision = torch.bfloat16 device = "cpu" @@ -268,16 +296,20 @@ def test_8da4w_gptq_quantizer(self): input_prep_func = prepare_inputs_for_model pad_calibration_inputs = False - inputs = InputRecorder( - tokenizer, - calibration_seq_length, - input_prep_func, - pad_calibration_inputs, - model.config.vocab_size, - ).record_inputs( - calibration_tasks, - calibration_limit, - ).get_inputs() + inputs = ( + InputRecorder( + tokenizer, + calibration_seq_length, + input_prep_func, + pad_calibration_inputs, + model.config.vocab_size, + ) + .record_inputs( + calibration_tasks, + calibration_limit, + ) + .get_inputs() + ) quantizer = Int8DynActInt4WeightGPTQQuantizer( blocksize, @@ -287,7 +319,7 @@ def test_8da4w_gptq_quantizer(self): ) model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length) model = quantizer.quantize(model, inputs) - result=TransformerEvalWrapper( + result = TransformerEvalWrapper( model, tokenizer, model.config.block_size, @@ -298,15 +330,17 @@ def test_8da4w_gptq_quantizer(self): 1, ) - assert result['results']['wikitext']['word_perplexity,none'] < 7.88, ( - f"accuracy regressed from 7.87 to {result['results']['wikitext']['word_perplexity,none']}" - ) + assert ( + result["results"]["wikitext"]["word_perplexity,none"] < 7.88 + ), f"accuracy regressed from 7.87 to {result['results']['wikitext']['word_perplexity,none']}" @unittest.skip("skipping until we get checkpoints for gpt-fast") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch verion is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch verion is 2.4 or lower" + ) def test_8da4w_quantizer_eval(self): - from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer from torchao._models._eval import TransformerEvalWrapper + from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer precision = torch.bfloat16 device = "cpu" @@ -325,7 +359,7 @@ def test_8da4w_quantizer_eval(self): quantizer = Int8DynActInt4WeightQuantizer(groupsize=128, precision=precision) q_model = quantizer.quantize(model) - result=TransformerEvalWrapper( + result = TransformerEvalWrapper( q_model, tokenizer, q_model.config.block_size, @@ -335,14 +369,18 @@ def test_8da4w_quantizer_eval(self): ["wikitext"], 1, ) - assert result['results']['wikitext']['word_perplexity,none'] < 8.24, ( - f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}" - ) + assert ( + result["results"]["wikitext"]["word_perplexity,none"] < 8.24 + ), f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}" @unittest.skip("skipping until we get checkpoints for gpt-fast") def test_gptq_quantizer_int4_weight_only(self): + from torchao._models._eval import ( + MultiTensorInputRecorder, + TransformerEvalWrapper, + ) from torchao.quantization.GPTQ_MT import Int4WeightOnlyGPTQQuantizer - from torchao._models._eval import MultiTensorInputRecorder, TransformerEvalWrapper + precision = torch.bfloat16 device = "cuda" checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth") @@ -367,18 +405,21 @@ def test_gptq_quantizer_int4_weight_only(self): calibration_seq_length = 100 input_prep_func = prepare_inputs_for_model pad_calibration_inputs = False - inputs = MultiTensorInputRecorder( - tokenizer, - calibration_seq_length, - input_prep_func, - pad_calibration_inputs, - model.config.vocab_size, - device="cpu", - ).record_inputs( - calibration_tasks, - calibration_limit, - ).get_inputs() - + inputs = ( + MultiTensorInputRecorder( + tokenizer, + calibration_seq_length, + input_prep_func, + pad_calibration_inputs, + model.config.vocab_size, + device="cpu", + ) + .record_inputs( + calibration_tasks, + calibration_limit, + ) + .get_inputs() + ) quantizer = Int4WeightOnlyGPTQQuantizer( blocksize, @@ -398,14 +439,15 @@ def test_gptq_quantizer_int4_weight_only(self): ["wikitext"], None, ) - assert result['results']['wikitext']['word_perplexity,none'] < 7.77, ( - f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}" - ) + assert ( + result["results"]["wikitext"]["word_perplexity,none"] < 7.77 + ), f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}" @unittest.skip("skipping until we get checkpoints for gpt-fast") def test_quantizer_int4_weight_only(self): - from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer from torchao._models._eval import TransformerEvalWrapper + from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer + precision = torch.bfloat16 device = "cuda" checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth") @@ -435,13 +477,14 @@ def test_quantizer_int4_weight_only(self): ["wikitext"], 1, ) - assert result['results']['wikitext']['word_perplexity,none'] < 8.24, ( - f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}" - ) + assert ( + result["results"]["wikitext"]["word_perplexity,none"] < 8.24 + ), f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}" @unittest.skip("skipping until we get checkpoints for gpt-fast") def test_eval_wrapper(self): from torchao._models._eval import TransformerEvalWrapper + precision = torch.bfloat16 device = "cuda" checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth") @@ -456,7 +499,7 @@ def test_eval_wrapper(self): tokenizer_path, "Llama-2-7b-chat-hf", ) - result=TransformerEvalWrapper( + result = TransformerEvalWrapper( model, tokenizer, model.config.block_size, @@ -466,17 +509,20 @@ def test_eval_wrapper(self): ["wikitext"], 1, ) - assert result['results']['wikitext']['word_perplexity,none']<7.77, ( - f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}" - ) + assert ( + result["results"]["wikitext"]["word_perplexity,none"] < 7.77 + ), f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}" # EVAL IS CURRENTLY BROKEN FOR LLAMA 3, VERY LOW ACCURACY @unittest.skip("skipping until we get checkpoints for gpt-fast") def test_eval_wrapper_llama3(self): from torchao._models._eval import TransformerEvalWrapper + precision = torch.bfloat16 device = "cuda" - checkpoint_path = Path(".../gpt-fast/checkpoints/meta-llama/Meta-Llama-3-8B/model.pth") + checkpoint_path = Path( + ".../gpt-fast/checkpoints/meta-llama/Meta-Llama-3-8B/model.pth" + ) model = Transformer.from_name(checkpoint_path.parent.name) checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) model.load_state_dict(checkpoint, assign=True) @@ -498,30 +544,43 @@ def test_eval_wrapper_llama3(self): ["wikitext"], 1, ) - assert result['results']['wikitext']['word_perplexity,none'] < 8.24, ( - f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}" - ) + assert ( + result["results"]["wikitext"]["word_perplexity,none"] < 8.24 + ), f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}" # TODO: move to a separate test file @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") - @common_utils.parametrize("mapping_type", [MappingType.SYMMETRIC, MappingType.SYMMETRIC_NO_CLIPPING_ERR]) + @common_utils.parametrize( + "mapping_type", [MappingType.SYMMETRIC, MappingType.SYMMETRIC_NO_CLIPPING_ERR] + ) def test_quantized_tensor_subclass_8da4w(self, mapping_type): group_size = 32 m = ToyLinearModel().eval() m_copy = copy.deepcopy(m) example_inputs = m.example_inputs() - quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size, mapping_type=mapping_type)) + quantize_( + m, + int8_dynamic_activation_int4_weight( + group_size=group_size, mapping_type=mapping_type + ), + ) assert isinstance(m.linear1.weight, LinearActivationQuantizedTensor) assert isinstance(m.linear2.weight, LinearActivationQuantizedTensor) - assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor) - assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor) + assert isinstance( + m.linear1.weight.original_weight_tensor, AffineQuantizedTensor + ) + assert isinstance( + m.linear2.weight.original_weight_tensor, AffineQuantizedTensor + ) # reference - from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear + from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer - quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size, mapping_type=mapping_type) + quantizer = Int8DynActInt4WeightQuantizer( + groupsize=group_size, mapping_type=mapping_type + ) m_copy = quantizer.quantize(m_copy) assert isinstance(m_copy.linear1, Int8DynActInt4WeightLinear) assert isinstance(m_copy.linear2, Int8DynActInt4WeightLinear) @@ -552,7 +611,6 @@ def test_quantized_tensor_subclass_int4(self): self.assertTrue(torch.equal(res, ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_quantized_tensor_subclass_int8_wo(self): @@ -568,13 +626,11 @@ def test_quantized_tensor_subclass_int8_wo(self): # reference _ref_change_linear_weights_to_int8_woqtensors(m_copy) - res = m(*example_inputs) ref = m_copy(*example_inputs) self.assertTrue(torch.equal(res, ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.5 and below") @@ -583,13 +639,19 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self): m = ToyLinearModel(1024, 1024, 2048).eval().to(torch.bfloat16).to("cuda") m_copy = copy.deepcopy(m) # setting batch_size to 20 to be compatible with the kernel - example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda") + example_inputs = m.example_inputs( + batch_size=20, dtype=torch.bfloat16, device="cuda" + ) quantize_(m, int8_dynamic_activation_int8_weight()) assert isinstance(m.linear1.weight, LinearActivationQuantizedTensor) assert isinstance(m.linear2.weight, LinearActivationQuantizedTensor) - assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor) - assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor) + assert isinstance( + m.linear1.weight.original_weight_tensor, AffineQuantizedTensor + ) + assert isinstance( + m.linear2.weight.original_weight_tensor, AffineQuantizedTensor + ) # reference _ref_change_linear_weights_to_int8_dqtensors(m_copy) @@ -601,6 +663,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self): # workaround for export path from torchao.utils import unwrap_tensor_subclass + m_unwrapped = unwrap_tensor_subclass(m) m = torch.export.export(m_unwrapped, example_inputs).module() @@ -630,12 +693,10 @@ def test_quantized_tensor_subclass_save_load(self): res = m_copy(*example_inputs) self.assertEqual(res, ref) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_int8wo_quantized_model_to_device(self): m = ToyLinearModel().eval().to(torch.bfloat16) - m_copy = copy.deepcopy(m) example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cpu") quantize_(m, int8_weight_only()) @@ -654,7 +715,6 @@ def test_int4wo_quantized_model_to_device(self): devices = ["cuda", "cuda:0"] for device in devices: m = ToyLinearModel().eval().to(torch.bfloat16).to(device) - m_copy = copy.deepcopy(m) example_inputs = m.example_inputs(dtype=torch.bfloat16, device=device) quantize_(m, int4_weight_only()) @@ -678,7 +738,7 @@ def test_quantized_tensor_subclass_save_load_map_location(self): f.seek(0) state_dict = torch.load(f.name, map_location="cpu", mmap=True) - with torch.device('meta'): + with torch.device("meta"): m_copy = ToyLinearModel().eval() m_copy.load_state_dict(state_dict, assign=True) @@ -710,12 +770,13 @@ def reset_memory(): assert param.is_cuda self.assertLess(memory_streaming, memory_baseline) -class TestMultiTensorFlow(TestCase): +class TestMultiTensorFlow(TestCase): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_multitensor_add_tensors(self): from torchao.quantization.GPTQ_MT import MultiTensor + tensor1 = torch.randn(3, 3) tensor2 = torch.randn(3, 3) mt = MultiTensor(tensor1) @@ -728,6 +789,7 @@ def test_multitensor_add_tensors(self): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_multitensor_pad_unpad(self): from torchao.quantization.GPTQ_MT import MultiTensor + tensor1 = torch.randn(3, 3) mt = MultiTensor(tensor1) mt.pad_to_length(3) @@ -739,14 +801,13 @@ def test_multitensor_pad_unpad(self): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_multitensor_inplace_operation(self): from torchao.quantization.GPTQ_MT import MultiTensor + tensor1 = torch.ones(3, 3) mt = MultiTensor(tensor1) mt += 1 # In-place addition self.assertTrue(torch.equal(mt.values[0], torch.full((3, 3), 2))) - - common_utils.instantiate_parametrized_tests(TestQuantFlow) diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 78556772d1..a3fef29fea 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -7,25 +7,27 @@ # mypy: ignore-errors # This test takes a long time to run import unittest + import torch + +from torchao.dtypes.utils import is_device from torchao.quantization.quant_primitives import ( + MappingType, + ZeroPointDomain, + choose_qparams_affine, + dequantize_affine, fake_quantize_affine, fake_quantize_affine_cachemask, quantize_affine, - dequantize_affine, - choose_qparams_affine, - MappingType, - ZeroPointDomain, ) + # TODO: remove test for utils? from torchao.quantization.utils import ( get_group_qparams_symmetric, - get_groupwise_affine_qparams, - groupwise_affine_quantize_tensor_from_qparams, groupwise_affine_dequantize_tensor_from_qparams, + groupwise_affine_quantize_tensor_from_qparams, quantize_activation_per_token_absmax, ) - from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, @@ -33,11 +35,11 @@ TORCH_VERSION_AT_LEAST_2_6, is_fbcode, ) -from torchao.dtypes.utils import is_device _SEED = 1234 torch.manual_seed(_SEED) + # Helper function to run a function twice # and verify that the result is the same. # Adds some verification to avoid side effects. @@ -48,9 +50,12 @@ def check_idempotent(self, fn, *args, **kwargs): output0 = fn(*args, **kwargs) assert torch.is_tensor(output0) output1 = fn(*args, **kwargs) - self.assertTrue(torch.equal(output0, output1), f"Expected given function {fn} to be idempotent.") + self.assertTrue( + torch.equal(output0, output1), f"Expected given function {fn} to be idempotent." + ) return output1 + # Legacy tinygemm ops def _get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16): if groupsize > w.shape[-1]: @@ -71,6 +76,7 @@ def _get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat1 dtype=dtype ).reshape(w.shape[0], -1) + def _groupwise_affine_quantize_tensor_from_qparams( w, scales, @@ -108,6 +114,7 @@ def _groupwise_affine_quantize_tensor_from_qparams( return w_int4x8 + def _groupwise_affine_dequantize_tensor_from_qparams( w_int4x8, scales, @@ -138,7 +145,9 @@ def _groupwise_affine_dequantize_tensor_from_qparams( class TestQuantPrimitives(unittest.TestCase): SEED = 123 - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower" + ) def test_get_group_qparams_symmetric(self): """ Test that `get_group_qparams_symmetric` produces the exact same scales as @@ -147,7 +156,6 @@ def test_get_group_qparams_symmetric(self): n_bit = 4 qmin = -(2 ** (n_bit - 1)) qmax = 2 ** (n_bit - 1) - 1 - eps = torch.finfo(torch.float32).eps groupsize = 256 torch.manual_seed(self.SEED) weight = torch.randn(100, 256).to(torch.float16) @@ -160,14 +168,16 @@ def test_get_group_qparams_symmetric(self): quant_max=qmax, # This is needed to ensure `min_val` and `max_val` are fp16, # otherwise they default to fp32 and the qparams will be slightly off - factory_kwargs={"dtype": torch.float16} + factory_kwargs={"dtype": torch.float16}, ) obs(weight) (scale_obs, _) = obs.calculate_qparams() scale_obs = scale_obs.reshape(weight.shape[0], -1) # assert that scales are identical - (scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize, precision=torch.float16) + (scale_ao, _) = get_group_qparams_symmetric( + weight, n_bit, groupsize, precision=torch.float16 + ) torch.testing.assert_close(scale_obs, scale_ao, rtol=0, atol=0) def test_choose_qparams_group_sym(self): @@ -180,9 +190,19 @@ def test_choose_qparams_group_sym(self): block_size = (1, 2) eps = torch.finfo(torch.float32).eps precision = torch.float32 - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps, scale_dtype=precision, zero_point_dtype=precision) + scale, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + eps=eps, + scale_dtype=precision, + zero_point_dtype=precision, + ) - scale_ref, zp_ref = get_group_qparams_symmetric(input, n_bit=8, groupsize=2, precision=precision, mapping_type=mapping_type) + scale_ref, zp_ref = get_group_qparams_symmetric( + input, n_bit=8, groupsize=2, precision=precision, mapping_type=mapping_type + ) self.assertTrue(torch.equal(scale, scale_ref)) self.assertTrue(torch.equal(zero_point, zp_ref)) @@ -197,13 +217,26 @@ def test_choose_qparams_group_sym_no_clipping_err(self): block_size = (1, 2) eps = torch.finfo(torch.float32).eps precision = torch.float32 - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps, scale_dtype=precision, zero_point_dtype=precision) + scale, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + eps=eps, + scale_dtype=precision, + zero_point_dtype=precision, + ) - scale_ref, zp_ref = get_group_qparams_symmetric(input, n_bit=8, groupsize=2, precision=precision, mapping_type=mapping_type) + scale_ref, zp_ref = get_group_qparams_symmetric( + input, n_bit=8, groupsize=2, precision=precision, mapping_type=mapping_type + ) self.assertTrue(torch.equal(scale, scale_ref)) self.assertTrue(torch.equal(zero_point, zp_ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower") + + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower" + ) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_choose_qparams_token_asym(self): input = torch.randn(10, 10) @@ -211,11 +244,29 @@ def test_choose_qparams_token_asym(self): dtype = torch.int8 block_size = (1, 10) if TORCH_VERSION_AT_LEAST_2_6: - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float64, zero_point_dtype=torch.int64) + scale, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + eps=torch.finfo(torch.float32).eps, + scale_dtype=torch.float64, + zero_point_dtype=torch.int64, + ) else: - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) + scale, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + eps=torch.finfo(torch.float32).eps, + ) - scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(input, dtype) + scale_ref, zp_ref = ( + torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric( + input, dtype + ) + ) scale_ref = scale_ref.squeeze() zp_ref = zp_ref.squeeze() @@ -229,12 +280,15 @@ def test_choose_qparams_tensor_asym(self): dtype = torch.int8 block_size = (10, 10) eps = torch.finfo(torch.float32).eps - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps) - + scale, zero_point = choose_qparams_affine( + input, mapping_type, block_size, dtype, eps=eps + ) quant_min = -128 quant_max = 127 - scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams(input, quant_min, quant_max, eps, dtype) + scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams( + input, quant_min, quant_max, eps, dtype + ) scale_ref = scale_ref.squeeze() zp_ref = zp_ref.squeeze() @@ -248,18 +302,24 @@ def test_choose_qparams_tensor_sym(self): dtype = torch.int8 block_size = (10, 10) eps = torch.finfo(torch.float32).eps - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps) + scale, zero_point = choose_qparams_affine( + input, mapping_type, block_size, dtype, eps=eps + ) quant_min = -128 quant_max = 127 - scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_symmetric(input, quant_min, quant_max, eps, dtype) + scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_symmetric( + input, quant_min, quant_max, eps, dtype + ) scale_ref = scale_ref.squeeze() zp_ref = zp_ref.squeeze() self.assertTrue(torch.equal(scale, scale_ref)) self.assertTrue(torch.equal(zero_point, zp_ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_quantize_activation_per_token_abs_max(self): input = torch.randn(10, 10) quantized_ref, scale_ref = quantize_activation_per_token_absmax(input) @@ -272,21 +332,35 @@ def test_quantize_activation_per_token_abs_max(self): eps = 1e-5 quant_min = -127 quant_max = 127 - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, quant_min, quant_max, eps=eps, scale_dtype=torch.float) + scale, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + quant_min, + quant_max, + eps=eps, + scale_dtype=torch.float, + ) - quantized = quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max) + quantized = quantize_affine( + input, block_size, scale, zero_point, dtype, quant_min, quant_max + ) self.assertTrue(torch.equal(quantized, quantized_ref)) self.assertTrue(torch.equal(scale, scale_ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_quantize_activation_per_token_abs_max_zero_input(self): input = torch.zeros(10, 10) # make sure it still works quantized_ref, scale_ref = quantize_activation_per_token_absmax(input) - - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_quantize_activation_per_token_abs_max_dtype(self): input = torch.zeros(10, 10, dtype=torch.bfloat16) quantized_ref, scale_ref = quantize_activation_per_token_absmax(input) @@ -300,18 +374,30 @@ def test_quantize_activation_per_token_abs_max_dtype(self): quantized_ref, scale_ref = quantize_activation_per_token_absmax(input) self.assertTrue(scale_ref.dtype, torch.float32) - - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_group_sym(self): input = torch.randn(10, 10) mapping_type = MappingType.SYMMETRIC dtype = torch.int8 block_size = (1, 2) - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) + scale, zero_point = choose_qparams_affine( + input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps + ) quantized = quantize_affine(input, block_size, scale, zero_point, dtype) - dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32) + dequantized = check_idempotent( + self, + dequantize_affine, + quantized, + block_size, + scale, + zero_point, + dtype, + output_dtype=torch.float32, + ) group_size = 2 quant_min = -128 @@ -320,23 +406,43 @@ def test_quantize_dequantize_group_sym(self): input, scale, zero_point, quant_min, quant_max, torch.int8, group_size ) dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_channel_group( - quantized_ref, scale, zero_point, quant_min, quant_max, torch.int8, group_size, output_dtype=torch.float32 + quantized_ref, + scale, + zero_point, + quant_min, + quant_max, + torch.int8, + group_size, + output_dtype=torch.float32, ) self.assertTrue(torch.equal(quantized, quantized_ref)) self.assertTrue(torch.equal(dequantized, dequantized_ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_channel_asym(self): input = torch.randn(10, 10) mapping_type = MappingType.ASYMMETRIC dtype = torch.int8 block_size = (10, 1) - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) + scale, zero_point = choose_qparams_affine( + input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps + ) output_dtype = torch.float32 quantized = quantize_affine(input, block_size, scale, zero_point, dtype) - dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=output_dtype) + dequantized = check_idempotent( + self, + dequantize_affine, + quantized, + block_size, + scale, + zero_point, + dtype, + output_dtype=output_dtype, + ) axis = 1 quant_min = -128 @@ -345,12 +451,21 @@ def test_quantize_dequantize_channel_asym(self): input, scale, zero_point, axis, quant_min, quant_max, torch.int8 ) dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_channel( - quantized_ref, scale, zero_point, axis, quant_min, quant_max, torch.int8, out_dtype=output_dtype + quantized_ref, + scale, + zero_point, + axis, + quant_min, + quant_max, + torch.int8, + out_dtype=output_dtype, ) self.assertTrue(torch.equal(quantized, quantized_ref)) self.assertTrue(torch.equal(dequantized, dequantized_ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_tensor_asym(self): input = torch.randn(10, 10) @@ -358,32 +473,61 @@ def test_quantize_dequantize_tensor_asym(self): dtype = torch.int8 block_size = (10, 10) output_dtype = torch.float32 - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) + scale, zero_point = choose_qparams_affine( + input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps + ) quantized = quantize_affine(input, block_size, scale, zero_point, dtype) - dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=output_dtype) + dequantized = check_idempotent( + self, + dequantize_affine, + quantized, + block_size, + scale, + zero_point, + dtype, + output_dtype=output_dtype, + ) - axis = 1 quant_min = -128 quant_max = 127 quantized_ref = torch.ops.quantized_decomposed.quantize_per_tensor( input, scale, zero_point, quant_min, quant_max, torch.int8 ) dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_tensor( - quantized_ref, scale, zero_point, quant_min, quant_max, torch.int8, out_dtype=output_dtype + quantized_ref, + scale, + zero_point, + quant_min, + quant_max, + torch.int8, + out_dtype=output_dtype, ) self.assertTrue(torch.equal(quantized, quantized_ref)) self.assertTrue(torch.equal(dequantized, dequantized_ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_channel_asym_4d(self): input = torch.randn(3, 3, 10, 10) mapping_type = MappingType.ASYMMETRIC dtype = torch.int8 block_size = (3, 3, 1, 10) - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) + scale, zero_point = choose_qparams_affine( + input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps + ) quantized = quantize_affine(input, block_size, scale, zero_point, dtype) - dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32) + dequantized = check_idempotent( + self, + dequantize_affine, + quantized, + block_size, + scale, + zero_point, + dtype, + output_dtype=torch.float32, + ) axis = 2 quant_min = -128 @@ -392,20 +536,40 @@ def test_quantize_dequantize_channel_asym_4d(self): input, scale, zero_point, axis, quant_min, quant_max, torch.int8 ) dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_channel( - quantized_ref, scale, zero_point, axis, quant_min, quant_max, torch.int8, out_dtype=torch.float32 + quantized_ref, + scale, + zero_point, + axis, + quant_min, + quant_max, + torch.int8, + out_dtype=torch.float32, ) self.assertTrue(torch.equal(quantized, quantized_ref)) self.assertTrue(torch.equal(dequantized, dequantized_ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower" + ) def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self): input = torch.randn(3, 3, 10, 10) mapping_type = MappingType.ASYMMETRIC dtype = torch.int8 block_size = (3, 3, 2, 2) - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) + scale, zero_point = choose_qparams_affine( + input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps + ) quantized = quantize_affine(input, block_size, scale, zero_point, dtype) - dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32) + dequantized = check_idempotent( + self, + dequantize_affine, + quantized, + block_size, + scale, + zero_point, + dtype, + output_dtype=torch.float32, + ) # we don't have corresponding ops in existing primitives, so just make sure it runs and it's close to float torch.testing.assert_close(dequantized, input, rtol=2, atol=0.02) @@ -414,11 +578,15 @@ def test_choose_qparams_tensor_asym_eps(self): mapping_type = MappingType.ASYMMETRIC dtype = torch.int8 block_size = (10, 10) - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype) + scale, zero_point = choose_qparams_affine( + input, mapping_type, block_size, dtype + ) eps = torch.finfo(torch.float32).eps self.assertEqual(scale, eps) - @unittest.skipIf(not torch.cuda.is_available(), "skipping when cuda is not available") + @unittest.skipIf( + not torch.cuda.is_available(), "skipping when cuda is not available" + ) def test_get_group_qparams_symmetric_memory(self): """Check the memory usage of the op""" weight = torch.randn(1024, 1024).to(device="cuda") @@ -430,18 +598,20 @@ def test_get_group_qparams_symmetric_memory(self): self.assertTrue(after_choose_qparams_mem_use < 1.2 * original_mem_use) def test_raises(self): - """Make sure some errors are raised when user requested an unsupported type of quantization - """ + """Make sure some errors are raised when user requested an unsupported type of quantization""" input = torch.randn(10, 10) mapping_type = MappingType.ASYMMETRIC dtype = torch.int8 block_size = (10, 10) - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype) - + scale, zero_point = choose_qparams_affine( + input, mapping_type, block_size, dtype + ) # make sure we can't quantize int32 tensors: with self.assertRaisesRegex(AssertionError, "Unsupported input dtype:"): - _ = quantize_affine(input.to(torch.int32), block_size, scale, zero_point, dtype) + _ = quantize_affine( + input.to(torch.int32), block_size, scale, zero_point, dtype + ) # block_size and scale/zero_point shape mismatch block_size = (1, 1) @@ -460,7 +630,10 @@ def test_not_preserve_zero_not_supported(self): eps = 1e-6 scale_dtype = torch.bfloat16 zero_point_dtype = torch.bfloat16 - with self.assertRaisesRegex(ValueError, "preserve_zero == False is not supported for symmetric quantization"): + with self.assertRaisesRegex( + ValueError, + "preserve_zero == False is not supported for symmetric quantization", + ): choose_qparams_affine( input, mapping_type, @@ -474,11 +647,12 @@ def test_not_preserve_zero_not_supported(self): preserve_zero=False, ) - def test_get_groupwise_affine_qparams(self): input = torch.randn(10, 256) n_bit = 4 - scale_ref, zero_point_ref = _get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16) + scale_ref, zero_point_ref = _get_groupwise_affine_qparams( + input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16 + ) mapping_type = MappingType.ASYMMETRIC dtype = torch.int8 @@ -488,20 +662,19 @@ def test_get_groupwise_affine_qparams(self): eps = 1e-6 scale_dtype = torch.bfloat16 zero_point_dtype = torch.bfloat16 - scale, zero_point = \ - choose_qparams_affine( - input, - mapping_type, - block_size, - dtype, - quant_min, - quant_max, - eps, - scale_dtype=scale_dtype, - zero_point_dtype=zero_point_dtype, - preserve_zero=False, - zero_point_domain=ZeroPointDomain.FLOAT, - ) + scale, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + quant_min, + quant_max, + eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + preserve_zero=False, + zero_point_domain=ZeroPointDomain.FLOAT, + ) self.assertTrue(torch.equal(scale, scale_ref)) self.assertTrue(torch.equal(zero_point, zero_point_ref)) @@ -513,8 +686,12 @@ def test_groupwise_affine_quantize_tensor_from_qparams(self): n_bit = 4 groupsize = 128 - w_int4x8 = groupwise_affine_quantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize) - w_int4x8_ref = _groupwise_affine_quantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize) + w_int4x8 = groupwise_affine_quantize_tensor_from_qparams( + input, scales, zeros, n_bit, groupsize + ) + w_int4x8_ref = _groupwise_affine_quantize_tensor_from_qparams( + input, scales, zeros, n_bit, groupsize + ) self.assertTrue(torch.equal(w_int4x8, w_int4x8_ref)) @@ -529,14 +706,22 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self): input_tmp = input if not (is_device(input.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6): input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8) - w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input_tmp, scales, zeros, n_bit, groupsize) + w_bf16 = groupwise_affine_dequantize_tensor_from_qparams( + input_tmp, scales, zeros, n_bit, groupsize + ) else: - w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize) - w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize) + w_bf16 = groupwise_affine_dequantize_tensor_from_qparams( + input, scales, zeros, n_bit, groupsize + ) + w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams( + input, scales, zeros, n_bit, groupsize + ) self.assertTrue(torch.equal(w_bf16, w_bf16_ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_fake_quantize_affine(self): input = torch.randn(10, 10) @@ -548,14 +733,31 @@ def test_fake_quantize_affine(self): eps = 1e-5 quant_min = -127 quant_max = 127 - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, quant_min, quant_max, eps=eps, scale_dtype=torch.float) + scale, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + quant_min, + quant_max, + eps=eps, + scale_dtype=torch.float, + ) - quantized = quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max) - dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, quant_min, quant_max) - fake_quantized = fake_quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max) + quantized = quantize_affine( + input, block_size, scale, zero_point, dtype, quant_min, quant_max + ) + dequantized = dequantize_affine( + quantized, block_size, scale, zero_point, dtype, quant_min, quant_max + ) + fake_quantized = fake_quantize_affine( + input, block_size, scale, zero_point, dtype, quant_min, quant_max + ) torch.testing.assert_close(dequantized, fake_quantized) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_fake_quantize_affine_cachemask(self): input = torch.randn(10, 10) @@ -567,16 +769,36 @@ def test_fake_quantize_affine_cachemask(self): eps = 1e-5 quant_min = -127 quant_max = 127 - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, quant_min, quant_max, eps=eps, scale_dtype=torch.float) + scale, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + quant_min, + quant_max, + eps=eps, + scale_dtype=torch.float, + ) - quantized = quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max) - dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, quant_min, quant_max) + quantized = quantize_affine( + input, block_size, scale, zero_point, dtype, quant_min, quant_max + ) + dequantized = dequantize_affine( + quantized, block_size, scale, zero_point, dtype, quant_min, quant_max + ) (fake_quantized, mask) = fake_quantize_affine_cachemask( - input, block_size, scale, zero_point, dtype, quant_min, quant_max, + input, + block_size, + scale, + zero_point, + dtype, + quant_min, + quant_max, ) expected_mask = torch.full(input.shape, True) torch.testing.assert_close(dequantized, fake_quantized) torch.testing.assert_close(expected_mask, mask) + if __name__ == "__main__": unittest.main() From a558f7e90012d61cd636bb2068e027cc0062e5ca Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Sun, 1 Dec 2024 20:58:50 -0800 Subject: [PATCH 10/40] Lint fixes test sparsity (#1360) --- ruff.toml | 1 + test/sparsity/test_fast_sparse_training.py | 35 ++++++++---- test/sparsity/test_marlin.py | 63 +++++++++++++++------- test/sparsity/test_wanda.py | 8 ++- 4 files changed, 75 insertions(+), 32 deletions(-) diff --git a/ruff.toml b/ruff.toml index 13026345a6..0d02de24ed 100644 --- a/ruff.toml +++ b/ruff.toml @@ -11,6 +11,7 @@ include = [ "test/float8/**/*.py", "test/quantization/**/*.py", "test/dtypes/**/*.py", + "test/sparsity/**/*.py", "test/prototype/low_bit_optim/**.py", "torchao/utils.py", diff --git a/test/sparsity/test_fast_sparse_training.py b/test/sparsity/test_fast_sparse_training.py index f2d5686fd3..e3f5626d49 100644 --- a/test/sparsity/test_fast_sparse_training.py +++ b/test/sparsity/test_fast_sparse_training.py @@ -1,19 +1,18 @@ -import logging -import unittest import copy +import unittest import torch -import torch.nn.functional as F from torch import nn from torch.testing._internal.common_utils import TestCase from torchao.sparsity.training import ( + SemiSparseLinear, swap_linear_with_semi_sparse_linear, swap_semi_sparse_linear_with_linear, - SemiSparseLinear ) from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_fbcode + class ToyModel(nn.Module): def __init__(self): super().__init__() @@ -26,8 +25,8 @@ def forward(self, x): x = self.linear2(x) return x -class TestRuntimeSemiStructuredSparsity(TestCase): +class TestRuntimeSemiStructuredSparsity(TestCase): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "pytorch 2.4+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(is_fbcode(), "broken in fbcode") @@ -35,6 +34,7 @@ class TestRuntimeSemiStructuredSparsity(TestCase): def test_runtime_weight_sparsification(self): # need this import inside to not break 2.2 tests from torch.sparse import SparseSemiStructuredTensorCUSPARSELT + input = torch.rand((128, 128)).half().cuda() grad = torch.rand((128, 128)).half().cuda() model = ToyModel().half().cuda() @@ -42,7 +42,9 @@ def test_runtime_weight_sparsification(self): for name, mod in model.named_modules(): if isinstance(mod, torch.nn.Linear): - sparse = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(mod.weight.detach()).to_dense() + sparse = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort( + mod.weight.detach() + ).to_dense() mod.weight = nn.Parameter(sparse) dense_result = model(input) @@ -62,8 +64,12 @@ def test_runtime_weight_sparsification(self): sparse_result.backward(grad) # check grad - assert torch.allclose(model.linear1.weight.grad, model_c.linear1.weight.grad, rtol=1e-1, atol=1e-1) - assert torch.allclose(model.linear2.weight.grad, model_c.linear2.weight.grad, rtol=1e-1, atol=1e-1) + assert torch.allclose( + model.linear1.weight.grad, model_c.linear1.weight.grad, rtol=1e-1, atol=1e-1 + ) + assert torch.allclose( + model.linear2.weight.grad, model_c.linear2.weight.grad, rtol=1e-1, atol=1e-1 + ) # check that swap back works swap_semi_sparse_linear_with_linear(model_c) @@ -77,6 +83,7 @@ def test_runtime_weight_sparsification(self): def test_runtime_weight_sparsification_compile(self): # need this import inside to not break 2.2 tests from torch.sparse import SparseSemiStructuredTensorCUSPARSELT + input = torch.rand((128, 128)).half().cuda() grad = torch.rand((128, 128)).half().cuda() model = ToyModel().half().cuda() @@ -84,7 +91,9 @@ def test_runtime_weight_sparsification_compile(self): for name, mod in model.named_modules(): if isinstance(mod, torch.nn.Linear): - sparse = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(mod.weight.detach()).to_dense() + sparse = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort( + mod.weight.detach() + ).to_dense() mod.weight = nn.Parameter(sparse) model = torch.compile(model, fullgraph=True) @@ -106,8 +115,12 @@ def test_runtime_weight_sparsification_compile(self): sparse_result.backward(grad) # check grad - assert torch.allclose(model.linear1.weight.grad, model_c.linear1.weight.grad, rtol=1e-1, atol=1e-1) - assert torch.allclose(model.linear2.weight.grad, model_c.linear2.weight.grad, rtol=1e-1, atol=1e-1) + assert torch.allclose( + model.linear1.weight.grad, model_c.linear1.weight.grad, rtol=1e-1, atol=1e-1 + ) + assert torch.allclose( + model.linear2.weight.grad, model_c.linear2.weight.grad, rtol=1e-1, atol=1e-1 + ) # check that swap back works swap_semi_sparse_linear_with_linear(model_c) diff --git a/test/sparsity/test_marlin.py b/test/sparsity/test_marlin.py index 173afd7dab..4da7304a24 100644 --- a/test/sparsity/test_marlin.py +++ b/test/sparsity/test_marlin.py @@ -1,28 +1,24 @@ -import torch import copy -import pytest +import pytest +import torch from torch import nn from torch.testing._internal.common_utils import TestCase, run_tests -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 + from torchao.dtypes import MarlinSparseLayout -from torchao.sparsity.sparse_api import apply_fake_sparsity from torchao.quantization.quant_api import int4_weight_only, quantize_ -from torchao.sparsity.marlin import ( - pack_to_marlin_24, - unpack_from_marlin_24, - inject_24 -) from torchao.quantization.quant_primitives import ( + MappingType, + ZeroPointDomain, choose_qparams_affine, quantize_affine, - ZeroPointDomain, - MappingType, ) +from torchao.sparsity.marlin import inject_24, pack_to_marlin_24, unpack_from_marlin_24 +from torchao.sparsity.sparse_api import apply_fake_sparsity +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 class SparseMarlin24(TestCase): - def setUp(self): super().setUp() torch.manual_seed(0) @@ -53,7 +49,9 @@ def test_quant_sparse_marlin_layout_eager(self): quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout())) sparse_result = self.model(self.input) - assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close" + assert torch.allclose( + dense_result, sparse_result, atol=3e-1 + ), "Results are not close" @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+") @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") @@ -71,7 +69,9 @@ def test_quant_sparse_marlin_layout_compile(self): self.model.forward = torch.compile(self.model.forward, fullgraph=True) sparse_result = self.model(self.input) - assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close" + assert torch.allclose( + dense_result, sparse_result, atol=3e-1 + ), "Results are not close" @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") def test_pack_unpack_equivalence(self): @@ -94,9 +94,30 @@ def test_pack_unpack_equivalence(self): # Inject 2:4 sparsity mask w_24, _ = inject_24(w, *w.shape) - # Quantize weights - scales, zeros = choose_qparams_affine(w_24, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain) - w_q_24 = quantize_affine(w_24, block_size, scales, zeros, target_dtype, quant_min, quant_max, zero_point_domain) + # Quantize weights + scales, zeros = choose_qparams_affine( + w_24, + mapping_type, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + scale_dtype, + zero_point_dtype, + preserve_zero, + zero_point_domain, + ) + w_q_24 = quantize_affine( + w_24, + block_size, + scales, + zeros, + target_dtype, + quant_min, + quant_max, + zero_point_domain, + ) scales = scales.reshape(-1, w_q_24.shape[1]) # Test pack/unpack equivalence @@ -107,8 +128,12 @@ def test_pack_unpack_equivalence(self): q_w_comp, packed_scales, meta, shape, group_size, num_bits ) - assert torch.equal(w_q_24, unpacked_q_w), "Unpacked weights do not match original weights" - assert torch.equal(scales, unpacked_scales), "Unpacked scales do not match original scales" + assert torch.equal( + w_q_24, unpacked_q_w + ), "Unpacked weights do not match original weights" + assert torch.equal( + scales, unpacked_scales + ), "Unpacked scales do not match original scales" if __name__ == "__main__": diff --git a/test/sparsity/test_wanda.py b/test/sparsity/test_wanda.py index fcb94036aa..e02ea9822a 100644 --- a/test/sparsity/test_wanda.py +++ b/test/sparsity/test_wanda.py @@ -3,12 +3,13 @@ import torch from torch import nn -from torchao.sparsity import WandaSparsifier from torch.ao.pruning import FakeSparsity from torch.nn.utils.parametrize import is_parametrized from torch.testing._internal.common_pruning import SimpleLinear from torch.testing._internal.common_utils import TestCase +from torchao.sparsity import WandaSparsifier + logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO ) @@ -29,7 +30,9 @@ def test_prepare(self): assert hasattr(module.parametrizations["weight"][0], "mask") # Check parametrization exists and is correct assert is_parametrized(module, "weight") - assert type(module.parametrizations.weight[0]) == FakeSparsity + assert isinstance( + module.parametrizations.weight[0], FakeSparsity + ), "FakeSparsity not found" # check activation observer is present assert hasattr(module, "activation_post_process") @@ -110,5 +113,6 @@ def test_two_layer_mlp_unstructured(self): sparsifier.squash_mask() + if __name__ == "__main__": unittest.main() From 2e36daa263daceeaf1680fd03f017fd451541168 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Mon, 2 Dec 2024 08:02:47 -0800 Subject: [PATCH 11/40] add option to use SAC in float8 training profiling script (#1354) Summary: SAC is often used in training jobs, adding it here to speed up debugging of microbenchmarks with SAC. Test Plan: ``` python benchmarks/float8/profile_linear_float8.py ~/local/tmp/20241120_test --dtype_filter float8 --scaling_type_input delayed --scaling_type_weight delayed --scaling_type_grad_output delayed --enable_activation_checkpointing True ``` Reviewers: Subscribers: Tasks: Tags: --- benchmarks/float8/profile_linear_float8.py | 34 ++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/benchmarks/float8/profile_linear_float8.py b/benchmarks/float8/profile_linear_float8.py index f4f2813a37..e545ea4665 100644 --- a/benchmarks/float8/profile_linear_float8.py +++ b/benchmarks/float8/profile_linear_float8.py @@ -6,6 +6,7 @@ import copy import io +import functools import os import random from contextlib import nullcontext, redirect_stdout @@ -22,6 +23,11 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch.utils.checkpoint import ( + checkpoint, + create_selective_checkpoint_contexts, + CheckpointPolicy, +) from torchao.float8.config import ( CastConfig, Float8LinearConfig, @@ -254,6 +260,22 @@ def profile_function( return prof +# set up AC for max(abs(tensor)) +# context: https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.create_selective_checkpoint_contexts +ops_to_save = [ + torch.ops.aten.abs.default, + torch.ops.aten.max.default, +] + +def policy_fn(ctx, op, *args, **kwargs): + if op in ops_to_save: + return CheckpointPolicy.MUST_SAVE + else: + return CheckpointPolicy.PREFER_RECOMPUTE + +context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) + + def main( profile_path_prefix: pathlib.Path, compile: bool = True, @@ -265,6 +287,7 @@ def main( dtype_filter: str = "both", add_inductor_metadata_to_trace: bool = True, enable_sync_amax_history: bool = True, + enable_activation_checkpointing: bool = False, ): assert model_type in ("linear", "ln_linear", "norm_ffn_norm", "norm_ffn_norm_small"), "unsupported" assert dtype_filter in ("both", "float8", "bfloat16") @@ -294,6 +317,7 @@ def main( print(f"Compile is set to | {compile}") print(f"model_type is set to | {model_type}") print(f"scaling_repr is set to | {scaling_repr}") + print(f"enable_activation_checkpointing is set to {enable_activation_checkpointing}") device = "cuda" ref_dtype = torch.bfloat16 @@ -338,11 +362,17 @@ def main( convert_to_float8_training(m_float8, config=config) def ref_forw_backward(x): - out = m_ref(x) + if enable_activation_checkpointing: + out = checkpoint(m_ref, x, use_reentrant=False, context_fn=context_fn) + else: + out = m_ref(x) out.sum().backward() def float8_forw(x): - out = m_float8(x) + if enable_activation_checkpointing: + out = checkpoint(m_float8, x, use_reentrant=False, context_fn=context_fn) + else: + out = m_float8(x) return out sync_amax_history = sync_float8_amax_and_scale_history From 65b885fb3b6010d16c078f71036362a7b4041316 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Tue, 3 Dec 2024 01:04:04 +0900 Subject: [PATCH 12/40] check `scale.ndim` before applying `t`/`transpose` (#1339) * check `scale.ndim` before applying `t`/`transpose` because (a) `scale` could be 0D/1D and `transpose` and (b) the args and kwargs of `torch.ops.aten.transpose.int` would supply `dim0` and `dim1`, leading to cause dim canonicalization to fail. e.g. [`torch._prims_common.canonicalize_dims`](https://github.com/pytorch/pytorch/blob/07906f2/torch/_prims_common/__init__.py#L704) Signed-off-by: Masaki Kozuki * add test of `.t()` and `.transpose(0, 1)` Signed-off-by: Masaki Kozuki * change cond to transpose scale Signed-off-by: Masaki Kozuki --------- Signed-off-by: Masaki Kozuki --- test/float8/test_base.py | 19 +++++++++++++++++++ torchao/float8/float8_ops.py | 5 ++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index f61ff3738f..ba6281deaf 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -141,6 +141,25 @@ def test_copy_(self): fp8_b.copy_(fp8_a) torch.testing.assert_close(fp8_a._data, fp8_b._data) + def test_transpose(self): + a = torch.rand((16, 16), dtype=torch.bfloat16) + for axiswise_dim in (None, 0, -1): + scale_a = tensor_to_scale(a, e4m3_dtype) + fp8_a = hp_tensor_and_scale_to_float8( + a, scale_a, e4m3_dtype, axiswise_dim=axiswise_dim + ) + fp8_b = hp_tensor_and_scale_to_float8( + a, scale_a, e4m3_dtype, axiswise_dim=axiswise_dim + ) + + fp8_a_transposed = fp8_a.transpose(0, 1) + fp8_b_t = fp8_b.t() + + torch.testing.assert_close( + (fp8_a_transposed._data, fp8_a_transposed._scale), + (fp8_b_t._data, fp8_b_t._scale), + ) + @pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)]) @pytest.mark.parametrize("axiswise_dim", [0, -1]) def test_axiswise_dynamic_cast(self, shape, axiswise_dim): diff --git a/torchao/float8/float8_ops.py b/torchao/float8/float8_ops.py index 921d50e093..2af4160de4 100644 --- a/torchao/float8/float8_ops.py +++ b/torchao/float8/float8_ops.py @@ -85,7 +85,10 @@ def float8_desugar_data_and_scale_op(aten_op, args, kwargs=None): ) def float8_transpose(aten_op, args, kwargs=None): new_data = aten_op(args[0]._data, *args[1:], **kwargs) - new_scale = aten_op(args[0]._scale, *args[1:], **kwargs) + if args[0]._scale.ndim > 1: + new_scale = aten_op(args[0]._scale, *args[1:], **kwargs) + else: + new_scale = args[0]._scale if aten_op == aten.transpose.int: _assert_tensorwise_scale(aten_op, args[0]._scale) From 541fe9cf2982593ee0440ec192f7d8ae0a490159 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Mon, 2 Dec 2024 08:17:30 -0800 Subject: [PATCH 13/40] bump main version to 0.8 (#1364) Summary: The v0.7 release has just been cut, bumping the version to v0.8 on main. Test Plan: CI Reviewers: Subscribers: Tasks: Tags: --- version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.txt b/version.txt index faef31a435..a3df0a6959 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.7.0 +0.8.0 From c6e32b88d225cba984de43847603e8c76cdb6e4c Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Mon, 2 Dec 2024 12:45:40 -0800 Subject: [PATCH 14/40] Lint fixes torchao files (#1366) Lint fixes --- ruff.toml | 5 +- torchao/_executorch_ops.py | 37 +++++++-- torchao/ops.py | 165 ++++++++++++++++++++++++++++--------- 3 files changed, 160 insertions(+), 47 deletions(-) diff --git a/ruff.toml b/ruff.toml index 0d02de24ed..45bb63d900 100644 --- a/ruff.toml +++ b/ruff.toml @@ -8,13 +8,14 @@ include = [ "torchao/dtypes/**/*.py", "torchao/sparsity/**/*.py", "torchao/prototype/low_bit_optim/**.py", + "torchao/utils.py", + "torchao/ops.py", + "torchao/_executorch_ops.py", "test/float8/**/*.py", "test/quantization/**/*.py", "test/dtypes/**/*.py", "test/sparsity/**/*.py", "test/prototype/low_bit_optim/**.py", - "torchao/utils.py", - ] lint.ignore = ["E731"] diff --git a/torchao/_executorch_ops.py b/torchao/_executorch_ops.py index 6a1a66ab77..3cf94ee53d 100644 --- a/torchao/_executorch_ops.py +++ b/torchao/_executorch_ops.py @@ -10,9 +10,14 @@ def _quantized_decomposed_quantize_per_channel_group_wrapper(*args, **kwargs): in PyTorch 2.3+ and recently changed signatures. """ from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 + if TORCH_VERSION_AT_LEAST_2_3: - return torch.ops.quantized_decomposed.quantize_per_channel_group(*args, **kwargs) - raise ImportError("Need torch.ops.quantized_decomposed.quantize_per_channel_group, which is only available with PyTorch 2.3 or later.") + return torch.ops.quantized_decomposed.quantize_per_channel_group( + *args, **kwargs + ) + raise ImportError( + "Need torch.ops.quantized_decomposed.quantize_per_channel_group, which is only available with PyTorch 2.3 or later." + ) def _quantized_decomposed_choose_qparams_per_token_asymmetric_wrapper(*args, **kwargs): @@ -24,9 +29,14 @@ def _quantized_decomposed_choose_qparams_per_token_asymmetric_wrapper(*args, **k in PyTorch 2.3+ and recently changed signatures. """ from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 + if TORCH_VERSION_AT_LEAST_2_3: - return torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(*args, **kwargs) - raise ImportError("Need torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric, which is only available with PyTorch 2.3 or later.") + return torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric( + *args, **kwargs + ) + raise ImportError( + "Need torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric, which is only available with PyTorch 2.3 or later." + ) def _quantized_decomposed_dequantize_per_channel_group_wrapper(*args, **kwargs): @@ -38,9 +48,14 @@ def _quantized_decomposed_dequantize_per_channel_group_wrapper(*args, **kwargs): in PyTorch 2.3+ and recently changed signatures. """ from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 + if TORCH_VERSION_AT_LEAST_2_3: - return torch.ops.quantized_decomposed.dequantize_per_channel_group(*args, **kwargs) - raise ImportError("Need torch.ops.quantized_decomposed.dequantize_per_channel_group, which is only available with PyTorch 2.3 or later.") + return torch.ops.quantized_decomposed.dequantize_per_channel_group( + *args, **kwargs + ) + raise ImportError( + "Need torch.ops.quantized_decomposed.dequantize_per_channel_group, which is only available with PyTorch 2.3 or later." + ) def _quantized_decomposed_quantize_per_token_wrapper(*args, **kwargs): @@ -52,9 +67,12 @@ def _quantized_decomposed_quantize_per_token_wrapper(*args, **kwargs): in PyTorch 2.3+ and recently changed signatures. """ from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 + if TORCH_VERSION_AT_LEAST_2_3: return torch.ops.quantized_decomposed.quantize_per_token(*args, **kwargs) - raise ImportError("Need torch.ops.quantized_decomposed.quantize_per_token, which is only available with PyTorch 2.3 or later.") + raise ImportError( + "Need torch.ops.quantized_decomposed.quantize_per_token, which is only available with PyTorch 2.3 or later." + ) def _quantized_decomposed_dequantize_per_token_wrapper(*args, **kwargs): @@ -66,6 +84,9 @@ def _quantized_decomposed_dequantize_per_token_wrapper(*args, **kwargs): in PyTorch 2.3+ and recently changed signatures. """ from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 + if TORCH_VERSION_AT_LEAST_2_3: return torch.ops.quantized_decomposed.dequantize_per_token(*args, **kwargs) - raise ImportError("Need torch.ops.quantized_decomposed.dequantize_per_token, which is only available with PyTorch 2.3 or later.") + raise ImportError( + "Need torch.ops.quantized_decomposed.dequantize_per_token, which is only available with PyTorch 2.3 or later." + ) diff --git a/torchao/ops.py b/torchao/ops.py index 9713f68eb2..2774deb08a 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -3,13 +3,22 @@ from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 - lib = torch.library.Library("torchao", "FRAGMENT") -lib.define("quant_llm_linear(int EXPONENT, int MANTISSA, Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor") -lib.define("unpack_tensor_core_tiled_layout(Tensor packed_w, int inner_k_tiles) -> Tensor") -lib.define("dequantize_tensor_core_tiled_layout(Tensor packed_w, Tensor scales_and_zeros, int group_size, int inner_k_tiles) -> Tensor") -lib.define("marlin_24_gemm(Tensor x, Tensor weight_marlin, Tensor meta, Tensor s, Tensor workspace, int bits, int size_m, int size_n, int size_k) -> Tensor") -lib.define("marlin_qqq_gemm(Tensor x, Tensor weight_marlin, Tensor s_tok, Tensor s_ch, Tensor s_group, Tensor workspace, int size_m, int size_n, int size_k) -> Tensor") +lib.define( + "quant_llm_linear(int EXPONENT, int MANTISSA, Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor" +) +lib.define( + "unpack_tensor_core_tiled_layout(Tensor packed_w, int inner_k_tiles) -> Tensor" +) +lib.define( + "dequantize_tensor_core_tiled_layout(Tensor packed_w, Tensor scales_and_zeros, int group_size, int inner_k_tiles) -> Tensor" +) +lib.define( + "marlin_24_gemm(Tensor x, Tensor weight_marlin, Tensor meta, Tensor s, Tensor workspace, int bits, int size_m, int size_n, int size_k) -> Tensor" +) +lib.define( + "marlin_qqq_gemm(Tensor x, Tensor weight_marlin, Tensor s_tok, Tensor s_ch, Tensor s_group, Tensor workspace, int size_m, int size_n, int size_k) -> Tensor" +) def register_custom_op(name): @@ -18,6 +27,7 @@ def decorator(func): return torch.library.register_fake(f"{name}")(func) else: return torch.library.impl_abstract(f"{name}")(func) + return decorator @@ -43,7 +53,9 @@ def quant_llm_linear( Returns output of linear layer """ - return torch.ops.torchao.quant_llm_linear.default(EXPONENT, MANTISSA, _in_feats, _weights, _scales, splitK) + return torch.ops.torchao.quant_llm_linear.default( + EXPONENT, MANTISSA, _in_feats, _weights, _scales, splitK + ) @register_custom_op("torchao::quant_llm_linear") @@ -55,12 +67,29 @@ def _( _scales: Tensor, splitK: int = 1, ) -> Tensor: - torch._check(_in_feats.dim() == 2, lambda: f"input should be a 2d tensor, got {_in_feats.dim()}D") - torch._check(_in_feats.dtype in (torch.float16, torch.bfloat16), lambda: f"weight must be FP16 or BF16, got {_in_feats.dtype}") - torch._check(_weights.dim() == 2, lambda: f"weight should be a 2d tensor, got {_weights.dim()}D") - torch._check(_weights.dtype is torch.uint8, lambda: f"weight must be UINT8, got {_weights.dtype}") - torch._check(_scales.dim() == 1, lambda: f"scale should be a 2d tensor, got {_scales.dim()}D") - torch._check(_scales.dtype in (torch.float16, torch.bfloat16), lambda: f"scale must be FP16 or BF16, got {_scales.dtype}") + torch._check( + _in_feats.dim() == 2, + lambda: f"input should be a 2d tensor, got {_in_feats.dim()}D", + ) + torch._check( + _in_feats.dtype in (torch.float16, torch.bfloat16), + lambda: f"weight must be FP16 or BF16, got {_in_feats.dtype}", + ) + torch._check( + _weights.dim() == 2, + lambda: f"weight should be a 2d tensor, got {_weights.dim()}D", + ) + torch._check( + _weights.dtype is torch.uint8, + lambda: f"weight must be UINT8, got {_weights.dtype}", + ) + torch._check( + _scales.dim() == 1, lambda: f"scale should be a 2d tensor, got {_scales.dim()}D" + ) + torch._check( + _scales.dtype in (torch.float16, torch.bfloat16), + lambda: f"scale must be FP16 or BF16, got {_scales.dtype}", + ) BS, IC = _in_feats.shape OC, _ = _weights.shape @@ -71,7 +100,6 @@ def _( return _in_feats.new_empty((BS, OC)) - def unpack_tensor_core_tiled_layout(packed_w: Tensor, inner_k_tiles: int) -> Tensor: """ Unpacks weights that were packed with `torch.ops.aten._convert_weight_to_int4pack` to original tensor of shape `N x K`. @@ -115,7 +143,10 @@ def _(packed_w: Tensor, inner_k_tiles: int) -> Tensor: return torch.empty((N, K), dtype=torch.int32, device=packed_w.device) -def dequantize_tensor_core_tiled_layout(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles: int) -> Tensor: + +def dequantize_tensor_core_tiled_layout( + packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles: int +) -> Tensor: """ Dequantizes by: - Unpacking weights that were packed with `torch.ops.aten._convert_weight_to_int4pack` to original tensor of shape `N x K` @@ -143,7 +174,9 @@ def dequantize_tensor_core_tiled_layout(packed_w: Tensor, scales_and_zeros: Tens @register_custom_op("torchao::dequantize_tensor_core_tiled_layout") -def _(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles: int) -> Tensor: +def _( + packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles: int +) -> Tensor: # packed_w preconditions torch._check( packed_w.dim() == 4, @@ -166,12 +199,28 @@ def _(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles K = packed_w.size(1) * inner_k_tiles * 16 # scales_and_zeros preconditions - torch._check(scales_and_zeros.dtype is torch.bfloat16, lambda: "scales_and_zeros must be bfloat16") - torch._check(scales_and_zeros.dim() == 3, lambda: "scales_and_zeros must be 3D, got {scales_and_zeros.dim()}") - torch._check(group_size == 32 or group_size == 64 or group_size == 128 or group_size == 256, lambda: "qGroupSize must be 32, 64, 128, or 256") - torch._check(scales_and_zeros.size(0) == K // group_size, lambda: "scales_and_zeros must have K // qGroupSize at dim 0") - torch._check(scales_and_zeros.size(1) == N, lambda: "scales_and_zeros must have N at dim 1") - torch._check(scales_and_zeros.size(2) == 2, lambda: "scales_and_zeros must have 2 at dim 2") + torch._check( + scales_and_zeros.dtype is torch.bfloat16, + lambda: "scales_and_zeros must be bfloat16", + ) + torch._check( + scales_and_zeros.dim() == 3, + lambda: "scales_and_zeros must be 3D, got {scales_and_zeros.dim()}", + ) + torch._check( + group_size == 32 or group_size == 64 or group_size == 128 or group_size == 256, + lambda: "qGroupSize must be 32, 64, 128, or 256", + ) + torch._check( + scales_and_zeros.size(0) == K // group_size, + lambda: "scales_and_zeros must have K // qGroupSize at dim 0", + ) + torch._check( + scales_and_zeros.size(1) == N, lambda: "scales_and_zeros must have N at dim 1" + ) + torch._check( + scales_and_zeros.size(2) == 2, lambda: "scales_and_zeros must have 2 at dim 2" + ) return torch.empty((N, K), dtype=torch.bfloat16, device=packed_w.device) @@ -224,27 +273,55 @@ def _( MAX_PARALLELISM = 64 # Verify num_bits - torch._check(bits == 4 or bits == 8, lambda: f"num_bits must be 4 or 8. Got = {bits}") + torch._check( + bits == 4 or bits == 8, lambda: f"num_bits must be 4 or 8. Got = {bits}" + ) pack_factor = 32 // bits # Verify M - torch._check(size_m == x.size(0), lambda: f"Shape mismatch: x.size(0) = {x.size(0)}, size_m = {size_m}") + torch._check( + size_m == x.size(0), + lambda: f"Shape mismatch: x.size(0) = {x.size(0)}, size_m = {size_m}", + ) # Verify K - torch._check(size_k == x.size(1), lambda: f"Shape mismatch: x.size(1) = {x.size(1)}, size_k = {size_k}") - torch._check(size_k % TILE_SIZE == 0, lambda: f"size_k = {size_k} is not divisible by tile_size = {TILE_SIZE}") - torch._check((size_k // TILE_SIZE // 2) == weight_marlin.size(0), lambda: f"Shape mismatch: weight_marlin.size(0) = {weight_marlin.size(0)}, size_k = {size_k}, tile_size = {TILE_SIZE}") + torch._check( + size_k == x.size(1), + lambda: f"Shape mismatch: x.size(1) = {x.size(1)}, size_k = {size_k}", + ) + torch._check( + size_k % TILE_SIZE == 0, + lambda: f"size_k = {size_k} is not divisible by tile_size = {TILE_SIZE}", + ) + torch._check( + (size_k // TILE_SIZE // 2) == weight_marlin.size(0), + lambda: f"Shape mismatch: weight_marlin.size(0) = {weight_marlin.size(0)}, size_k = {size_k}, tile_size = {TILE_SIZE}", + ) # Verify N - torch._check(s.size(1) == size_n, lambda: f"s.size(1) = {s.size(1)}, size_n = {size_n}") - torch._check(weight_marlin.size(1) % TILE_SIZE == 0, lambda: f"weight_marlin.size(1) = {weight_marlin.size(1)} is not divisible by tile_size = {TILE_SIZE}") + torch._check( + s.size(1) == size_n, lambda: f"s.size(1) = {s.size(1)}, size_n = {size_n}" + ) + torch._check( + weight_marlin.size(1) % TILE_SIZE == 0, + lambda: f"weight_marlin.size(1) = {weight_marlin.size(1)} is not divisible by tile_size = {TILE_SIZE}", + ) actual_size_n = (weight_marlin.size(1) // TILE_SIZE) * pack_factor - torch._check(size_n == actual_size_n, lambda: f"size_n = {size_n}, actual_size_n = {actual_size_n}") + torch._check( + size_n == actual_size_n, + lambda: f"size_n = {size_n}, actual_size_n = {actual_size_n}", + ) # Verify meta - torch._check(meta.size(0) == size_k // 8 // 2 // 2, lambda: f"meta.size(0) = {meta.size(0)} is not size_k / 8 / 2 / 2 = {size_k // 8 // 2 // 2}") - torch._check(meta.size(1) == size_n * 2, lambda: f"meta.size(1) = {meta.size(1)} is not size_n * 2 = {size_n * 2}") + torch._check( + meta.size(0) == size_k // 8 // 2 // 2, + lambda: f"meta.size(0) = {meta.size(0)} is not size_k / 8 / 2 / 2 = {size_k // 8 // 2 // 2}", + ) + torch._check( + meta.size(1) == size_n * 2, + lambda: f"meta.size(1) = {meta.size(1)} is not size_n * 2 = {size_n * 2}", + ) # Verify A device and strides torch._check(x.is_cuda, lambda: "x is not on GPU") @@ -252,7 +329,9 @@ def _( # Verify B device and strides torch._check(weight_marlin.is_cuda, lambda: "weight_marlin is not on GPU") - torch._check(weight_marlin.is_contiguous(), lambda: "weight_marlin is not contiguous") + torch._check( + weight_marlin.is_contiguous(), lambda: "weight_marlin is not contiguous" + ) # Verify meta device and strides torch._check(meta.is_cuda, lambda: "meta is not on GPU") @@ -265,15 +344,27 @@ def _( # Verify groupsize groupsize = -1 if s.size(0) > 1: - torch._check(size_k % s.size(0) == 0, lambda: f"size_k = {size_k} is not divisible by s.size(0) = {s.size(0)}") + torch._check( + size_k % s.size(0) == 0, + lambda: f"size_k = {size_k} is not divisible by s.size(0) = {s.size(0)}", + ) groupsize = size_k // s.size(0) groupsize //= 2 # Because of 24 - torch._check(groupsize == -1 or groupsize == 64, lambda: f"Unexpected groupsize = {groupsize}") + torch._check( + groupsize == -1 or groupsize == 64, + lambda: f"Unexpected groupsize = {groupsize}", + ) # Verify workspace size - torch._check(size_n % MIN_THREAD_N == 0, lambda: f"size_n = {size_n} is not divisible by min_thread_n = {MIN_THREAD_N}") + torch._check( + size_n % MIN_THREAD_N == 0, + lambda: f"size_n = {size_n} is not divisible by min_thread_n = {MIN_THREAD_N}", + ) min_workspace_size = (size_n // MIN_THREAD_N) * MAX_PARALLELISM - torch._check(workspace.numel() >= min_workspace_size, lambda: f"workspace.numel = {workspace.numel()} is below min_workspace_size = {min_workspace_size}") + torch._check( + workspace.numel() >= min_workspace_size, + lambda: f"workspace.numel = {workspace.numel()} is below min_workspace_size = {min_workspace_size}", + ) return torch.empty((x.size(0), s.size(1)), dtype=x.dtype, device=x.device) From 2f97b0955953fa1a46594a27f0df2bc48d93e79d Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Mon, 2 Dec 2024 16:19:52 -0500 Subject: [PATCH 15/40] Fix find_multiple import in GPTQ.py (#1367) This file relied on torchao/quantization/utils.py to import the function from torchao/utils.py, but that was removed in #1244. We actually don't need these relative imports since we already have an absolute import at the top of GPTQ.py anyway, so we should just remove these imports. --- torchao/quantization/GPTQ.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index c169271e8f..cb7c8d0481 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -580,8 +580,6 @@ def __init__( super().__init__() self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles) if self.padding: - from .utils import find_multiple - self.origin_in_features = in_features in_features = find_multiple(in_features, 1024) @@ -761,8 +759,6 @@ def _create_quantized_state_dict( if self.padding_allowed: import torch.nn.functional as F - from .utils import find_multiple - logging.warn( f"warning: {fqn} is padded to satisfy in_features % 1024 == 0" ) @@ -1151,8 +1147,6 @@ def _create_quantized_state_dict( if self.padding_allowed: import torch.nn.functional as F - from .utils import find_multiple - logging.warn( f"warning: {fqn} is padded to satisfy in_features % 1024 == 0" ) From 63d142ce8644242251e9017c32f1154a4066f436 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Mon, 2 Dec 2024 14:19:10 -0800 Subject: [PATCH 16/40] Lint fixes torchao/profiler and torchao/testing (#1368) --- ruff.toml | 2 + torchao/profiler/__init__.py | 2 - torchao/testing/float8/dtensor_utils.py | 1 - torchao/testing/float8/fsdp2_utils.py | 30 +++++--- torchao/testing/float8/test_utils.py | 6 +- torchao/testing/utils.py | 98 ++++++++++++++----------- 6 files changed, 79 insertions(+), 60 deletions(-) diff --git a/ruff.toml b/ruff.toml index 45bb63d900..b20cab030c 100644 --- a/ruff.toml +++ b/ruff.toml @@ -7,6 +7,8 @@ include = [ "torchao/quantization/**/*.py", "torchao/dtypes/**/*.py", "torchao/sparsity/**/*.py", + "torchao/profiler/**/*.py", + "torchao/testing/**/*.py", "torchao/prototype/low_bit_optim/**.py", "torchao/utils.py", "torchao/ops.py", diff --git a/torchao/profiler/__init__.py b/torchao/profiler/__init__.py index e748438e87..976d4e3a05 100644 --- a/torchao/profiler/__init__.py +++ b/torchao/profiler/__init__.py @@ -1,4 +1,3 @@ - # Re-exports from .device_spec import CUDADeviceSpec, DeviceSpec from .performance_counter import ( @@ -20,4 +19,3 @@ "DeviceSpec", "total_model_params", ] - diff --git a/torchao/testing/float8/dtensor_utils.py b/torchao/testing/float8/dtensor_utils.py index 1fab31d850..84e4095263 100644 --- a/torchao/testing/float8/dtensor_utils.py +++ b/torchao/testing/float8/dtensor_utils.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -import torch import torch.nn as nn import torch.nn.functional as F diff --git a/torchao/testing/float8/fsdp2_utils.py b/torchao/testing/float8/fsdp2_utils.py index 7744ae4e92..af46b7fa71 100644 --- a/torchao/testing/float8/fsdp2_utils.py +++ b/torchao/testing/float8/fsdp2_utils.py @@ -1,16 +1,13 @@ -import contextlib -from typing import List, Optional +from typing import List import torch import torch.distributed as dist import torch.nn as nn -import torchao.float8.config as config from torchao.float8.config import ( Float8LinearConfig, ScalingType, ) - from torchao.float8.float8_linear_utils import ( linear_requires_sync, sync_float8_amax_and_scale_history, @@ -52,7 +49,11 @@ def check_parity_no_mp( ): precompute_float8_dynamic_scale_for_fsdp(model) - test_cls.assertEqual(losses[0], losses[1], msg = f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}") + test_cls.assertEqual( + losses[0], + losses[1], + msg=f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}", + ) def check_parity_bf16_mp( @@ -87,7 +88,11 @@ def check_parity_bf16_mp( ref_model.parameters(), ref_model_bf16.parameters() ): param_bf16.detach().copy_(param_fp32) - test_cls.assertEqual(losses[0], losses[1], msg = f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}") + test_cls.assertEqual( + losses[0], + losses[1], + msg=f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}", + ) def check_parity_fp8_comm_only( @@ -104,7 +109,6 @@ def check_parity_fp8_comm_only( for iter_idx in range(10): losses: List[torch.Tensor] = [] for model, optim in ((ref_model, ref_optim), (fsdp_model, fsdp_optim)): - optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) losses.append(model(local_inp).sum()) losses[-1].backward() @@ -123,9 +127,15 @@ def check_parity_fp8_comm_only( and config.cast_config_weight.scaling_type is ScalingType.DYNAMIC ): precompute_float8_dynamic_scale_for_fsdp(model) - + if compile: # When compile, the ref loss and fsdp loss are not exactly the same, only check the loss values are valid for now. - assert (torch.isfinite(losses[0]).any() and torch.isfinite(losses[1]).any()), f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}" + assert ( + torch.isfinite(losses[0]).any() and torch.isfinite(losses[1]).any() + ), f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}" else: - test_cls.assertEqual(losses[0], losses[1], f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}") + test_cls.assertEqual( + losses[0], + losses[1], + f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}", + ) diff --git a/torchao/testing/float8/test_utils.py b/torchao/testing/float8/test_utils.py index 7f37c3f30a..7b8ac121b6 100644 --- a/torchao/testing/float8/test_utils.py +++ b/torchao/testing/float8/test_utils.py @@ -1,9 +1,9 @@ import torch + from torchao.float8.config import ( - ScalingGranularity, - ScalingType, - CastConfig, + CastConfig, Float8LinearConfig, + ScalingType, ) diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index 39edc50085..d88241783f 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -1,15 +1,19 @@ -import unittest -import functools import copy -import torch -import torchao -import os +import functools +import unittest +import torch +from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard from torch.testing._internal import common_utils -from torchao.dtypes import AffineQuantizedTensor -from torchao.dtypes import to_affine_quantized_intx +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, +) + +import torchao +from torchao.dtypes import AffineQuantizedTensor, to_affine_quantized_intx +from torchao.quantization import int8_weight_only, quantize_ from torchao.quantization.quant_primitives import MappingType -from torchao.quantization import quantize_, int8_weight_only from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 """ @@ -36,10 +40,9 @@ class MyTestCase(TorchAOBasicTestCase): unittest.main() """ + # copied from https://github.com/pytorch/pytorch/blob/941d094dd1b507dacf06ddc6ed3485a9537e09b7/test/inductor/test_torchinductor.py#L11389 -def copy_tests( - my_cls, other_cls, suffix, test_failures=None, xfail_prop=None -): # noqa: B902 +def copy_tests(my_cls, other_cls, suffix, test_failures=None, xfail_prop=None): # noqa: B902 for name, value in my_cls.__dict__.items(): if name.startswith("test_"): # You cannot copy functions in Python, so we use closures here to @@ -70,7 +73,6 @@ def new_test(self, value=value): setattr(other_cls, f"{name}_{suffix}", new_test) - class TorchAOBasicTestCase(common_utils.TestCase): COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] @@ -90,17 +92,21 @@ def test_flatten_unflatten(self): hp_tensor = torch.randn(4, 128) lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs) tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__() - tensor_data_dict = {name: getattr(lp_tensor, name) for name in tensor_data_name_dict} + tensor_data_dict = { + name: getattr(lp_tensor, name) for name in tensor_data_name_dict + } outer_size = lp_tensor.size() outer_stride = lp_tensor.stride() - reconstructed = self.TENSOR_SUBCLASS.__tensor_unflatten__(tensor_data_dict, tensor_attributes, outer_size, outer_stride) + reconstructed = self.TENSOR_SUBCLASS.__tensor_unflatten__( + tensor_data_dict, tensor_attributes, outer_size, outer_stride + ) self.assertEqual(lp_tensor.dequantize(), reconstructed.dequantize()) @common_utils.parametrize("device", COMMON_DEVICES) @common_utils.parametrize("dtype", COMMON_DTYPES) def test_hp_tensor_device_dtype(self, device, dtype): hp_tensor = torch.randn(4, 128, device=device, dtype=dtype) - lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs) + self.FACTORY_FN(hp_tensor, **self.kwargs) @common_utils.parametrize("device1", COMMON_DEVICES) @common_utils.parametrize("device2", COMMON_DEVICES) @@ -141,7 +147,10 @@ def test_linear(self, device, dtype): hp_act_tensor = torch.randn(32, 128, device=device, dtype=dtype) hp_res = torch.nn.functional.linear(hp_act_tensor, hp_tensor) lp_res = torch.nn.functional.linear(hp_act_tensor, lp_tensor) - self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR) + self.assertGreater( + torchao.quantization.utils.compute_error(hp_res, lp_res), + self.LINEAR_MIN_SQNR, + ) class TorchAOCompileTestCase(common_utils.TestCase): @@ -165,6 +174,7 @@ class TorchAOCompileTestCase(common_utils.TestCase): def test_input_output_tensor_subclass(self, device, dtype): hp_tensor = torch.randn(4, 128, device=device, dtype=dtype) lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs) + def f(tensor): return tensor @@ -179,6 +189,7 @@ def f(tensor): def test_input_tensor_subclass(self, device, dtype): hp_tensor = torch.randn(4, 128, device=device, dtype=dtype) lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs) + def f(tensor): return tensor.dequantize() @@ -192,6 +203,7 @@ def f(tensor): @common_utils.parametrize("dtype", COMMON_DTYPES) def test_output_tensor_subclass(self, device, dtype): hp_tensor = torch.randn(4, 128, device=device, dtype=dtype) + def f(hp_tensor): return self.FACTORY_FN(hp_tensor, **self.kwargs) @@ -201,7 +213,12 @@ def f(hp_tensor): self.assertTrue(isinstance(f(hp_tensor), self.TENSOR_SUBCLASS)) # bfloat16 seems to result in much larger numerical differences if dtype != torch.bfloat16: - self.assertGreater(torchao.quantization.utils.compute_error(ref.dequantize(), compiled.dequantize()), self.COMPILE_MIN_SQNR) + self.assertGreater( + torchao.quantization.utils.compute_error( + ref.dequantize(), compiled.dequantize() + ), + self.COMPILE_MIN_SQNR, + ) @common_utils.parametrize("device", COMMON_DEVICES) @common_utils.parametrize("dtype", COMMON_DTYPES) @@ -211,22 +228,18 @@ def test_linear_compile(self, device, dtype): hp_act_tensor = torch.randn(32, 128, device=device, dtype=dtype) hp_res = torch.nn.functional.linear(hp_act_tensor, hp_tensor) - l = torch.nn.Linear(128, 4, bias=False, device=device, dtype=dtype) - l.weight = torch.nn.Parameter(lp_tensor) - lp_res = torch.compile(l)(hp_act_tensor) - self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR) + linear = torch.nn.Linear(128, 4, bias=False, device=device, dtype=dtype) + linear.weight = torch.nn.Parameter(lp_tensor) + lp_res = torch.compile(linear)(hp_act_tensor) + self.assertGreater( + torchao.quantization.utils.compute_error(hp_res, lp_res), + self.LINEAR_MIN_SQNR, + ) -import torch.distributed as dist -from torch.distributed._tensor import DTensor, Replicate, Shard, DeviceMesh -from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorTestBase, - with_comms, - NUM_DEVICES, -) class TorchAOTensorParallelTestCase(DTensorTestBase): - """Basic test case for tensor subclasses - """ + """Basic test case for tensor subclasses""" + COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] TENSOR_SUBCLASS = AffineQuantizedTensor @@ -247,9 +260,7 @@ def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module: # Construct DTensor from local shard dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)]) # Replace parameter in module - m.linear.weight = torch.nn.Parameter( - dtensor, requires_grad=False - ) + m.linear.weight = torch.nn.Parameter(dtensor, requires_grad=False) return m @staticmethod @@ -266,9 +277,7 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module: # Construct DTensor from local shard dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)]) # Replace parameter in module - m.linear.weight = torch.nn.Parameter( - dtensor, requires_grad=False - ) + m.linear.weight = torch.nn.Parameter(dtensor, requires_grad=False) return m def quantize(self, m: torch.nn.Module) -> torch.nn.Module: @@ -289,7 +298,9 @@ def test_tp(self, dtype): class M(torch.nn.Module): def __init__(self, in_features, out_features, **kwargs) -> None: super().__init__(**kwargs) - self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda") + self.linear = torch.nn.Linear( + in_features, out_features, bias=False, device="cuda" + ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.linear(x) @@ -301,12 +312,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: proj_up = M(1024, 2048).to(device).to(dtype) proj_dn = M(2048, 1024).to(device).to(dtype) example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype) - y = proj_dn(proj_up(example_input)) + proj_dn(proj_up(example_input)) # Quantize the model up_quant = self.quantize(proj_up) dn_quant = self.quantize(proj_dn) - y_q = dn_quant(up_quant(example_input)) + dn_quant(up_quant(example_input)) mesh = self.build_device_mesh() mesh.device_type = "cuda" @@ -316,11 +327,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: dn_dist = self.rowwise_shard(dn_quant, mesh) # We need to turn inputs into DTensor form as well -- just a format change - input_dtensor = DTensor.from_local( - example_input, mesh, [Replicate()] - ) + input_dtensor = DTensor.from_local(example_input, mesh, [Replicate()]) - y_d = dn_dist(up_dist(input_dtensor)) + dn_dist(up_dist(input_dtensor)) if not TORCH_VERSION_AT_LEAST_2_6: # Need torch 2.6 to support compiled tensor parallelism @@ -329,7 +338,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: up_compiled = torch.compile(up_dist) y_up = up_compiled(input_dtensor) dn_compiled = torch.compile(dn_dist) - y_dn = dn_compiled(y_up) + dn_compiled(y_up) + common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase) common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase) From 8a51e1ace6d19ec69e1d67f8825004304cb1e440 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 2 Dec 2024 17:40:22 -0800 Subject: [PATCH 17/40] Fix bfloat16/float16/float32 options (#1369) * Fix bfloat16/float16/float32 options Summary: There was some problems with previous implementation of bfloat16/float16/float32 since it does not convert activation to the correct dtype after quantization, this PR fixes it Test Plan: llama: ``` python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-fp ``` same2: ``` server: python server.py ~/checkpoints/sam2 large --port 8000 --host localhost --fast --use_autoquant client: time xargs -I {} curl -s -w "\n" -X POST http://localhost:8000/upload_rle -F 'image=@{}' < sav_val_image_paths_shuf_1000 > results/sav_val_masks_baseline_shuf_1000 ``` Reviewers: Subscribers: Tasks: Tags: * ruff --- examples/sam2_amg_server/server.py | 23 +-- torchao/_models/llama/generate.py | 20 ++- .../prototype/quantization/autoquant_v2.py | 143 +++++++++++++++++- torchao/quantization/autoquant.py | 138 ++++++++++++----- 4 files changed, 266 insertions(+), 58 deletions(-) diff --git a/examples/sam2_amg_server/server.py b/examples/sam2_amg_server/server.py index ba1aed7a00..4c81342ff6 100644 --- a/examples/sam2_amg_server/server.py +++ b/examples/sam2_amg_server/server.py @@ -468,7 +468,7 @@ def load_aot_fast(mask_generator, model_directory): pkg = torch._inductor.aoti_load_package(str(path)) pkg_m = LoadedModel(pkg) mask_generator.predictor.model.image_encoder = pkg_m - + # NOTE: This doesn't work yet! # pkg = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "sam2__predict_masks_with_features.pt2")) # pkg_m = LoadedModel(pkg) @@ -526,6 +526,18 @@ def set_furious(mask_generator): # NOTE: Not baseline feature mask_generator.predictor.model.sam_mask_decoder._src_dtype = torch.float16 +def set_autoquant(mask_generator): + from torchao import autoquant + from torchao.quantization import DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST + # NOTE: Not baseline feature + mask_generator.predictor.model.image_encoder = autoquant(mask_generator.predictor.model.image_encoder, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40) + mask_generator.predictor._transforms_device = mask_generator.predictor.device + torch.set_float32_matmul_precision('high') + # NOTE: this fails when we run + # python server.py ~/checkpoints/sam2 large --port 8000 --host localhost --fast --use_autoquant --unittest + # https://gist.github.com/jerryzh168/d337cb5de0a1dec306069fe48ac8225e + # mask_generator.predictor.model.sam_mask_decoder = autoquant(mask_generator.predictor.model.sam_mask_decoder, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40) + def main(checkpoint_path, model_type, @@ -590,14 +602,7 @@ def main(checkpoint_path, set_furious(mask_generator) # since autoquant is replicating what furious mode is doing, don't use these two together elif use_autoquant: - from torchao import autoquant - from torchao.quantization import DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST - mask_generator.predictor.model.image_encoder = autoquant(mask_generator.predictor.model.image_encoder, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40) - - # mask_generator.predictor.model.image_encoder = mask_generator.predictor.model.image_encoder.to(torch.float16, min_sqnr=40) - # NOTE: Not baseline feature - mask_generator.predictor._transforms_device = mask_generator.predictor.device - torch.set_float32_matmul_precision('high') + set_autoquant(mask_generator) with open('dog.jpg', 'rb') as f: image_tensor = file_bytes_to_image_tensor(bytearray(f.read())) diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index d617ceb304..9619721614 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -357,11 +357,19 @@ def main( ) if "autoquant_v2-int4" == quantization: - model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs) + model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs, batch_size=calibration_seq_length) elif "autoquant_v2-float8" == quantization: - model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs) + model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs, batch_size=calibration_seq_length) + elif "autoquant_v2-fp" == quantization: + model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, example_input=inputs, batch_size=calibration_seq_length) + elif "autoquant_v2-all" == quantization: + all_qtensor_classes = torchao.prototype.quantization.autoquant_v2.DEFAULT_AUTOQUANT_CLASS_LIST + torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST + torchao.prototype.quantization.autoquant_v2.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST + if torchao.utils.is_sm_89(): + # this is fp8 related subclasses, should rename + all_qtensor_classes += torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST + model = autoquant_v2(model, manual=True, qtensor_class_list = all_qtensor_classes, example_input=inputs, batch_size=calibration_seq_length) else: - model = autoquant_v2(model, manual=True, example_input=inputs) + model = autoquant_v2(model, manual=True, example_input=inputs, batch_size=calibration_seq_length) print("running generate") generate( @@ -406,6 +414,12 @@ def main( model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs) if "autoquant-fp" == quantization: model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, example_input=inputs) + if "autoquant-all" == quantization: + all_qtensor_classes = torchao.quantization.DEFAULT_AUTOQUANT_CLASS_LIST + torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST + torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST + if torchao.utils.is_sm_89(): + # this is fp8 related subclasses, should rename + all_qtensor_classes += torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST + model = autoquant(model, manual=True, qtensor_class_list = all_qtensor_classes, example_input=inputs) else: model = autoquant(model, manual=True, example_input=inputs) diff --git a/torchao/prototype/quantization/autoquant_v2.py b/torchao/prototype/quantization/autoquant_v2.py index a11fe861e4..bf6dbb2a46 100644 --- a/torchao/prototype/quantization/autoquant_v2.py +++ b/torchao/prototype/quantization/autoquant_v2.py @@ -30,7 +30,7 @@ from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5, - benchmark_model, + TorchAOBaseTensor, ) from torchao.quantization.granularity import ( @@ -61,6 +61,7 @@ "autoquant_v2", "DEFAULT_AUTOQUANT_CLASS_LIST", "DEFAULT_INT4_AUTOQUANT_CLASS_LIST", + "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST", "OTHER_AUTOQUANT_CLASS_LIST", "_is_linear", ] @@ -288,7 +289,7 @@ def to_quantized(self, error_on_unseen, **kwargs): ) elif (self.logged_data == {}) and not error_on_unseen: # default back to non-quantized weight if not seen - self = AQFloatLinearWeight.from_float(self.weight) + self = AQDefaultLinearWeight.from_float(self.weight) return self # only want to print shape (at start) and final result (at end) @@ -360,7 +361,7 @@ def count_shapes(self, do_print=True): print(f"best_cls={best_cls}\n") # TODO handle random cls args/kwargs? or should they be curried? if best_cls is None: - best_cls = AQFloatLinearWeight + best_cls = AQDefaultLinearWeight self = best_cls.from_float(self.weight) return self @@ -802,7 +803,7 @@ class AQInt4G256WeightOnlyQuantizedLinearWeight( group_size: int = 256 -class AQFloatLinearWeight(torch.Tensor, AQMixin): +class AQDefaultLinearWeight(torch.Tensor, AQMixin): """ A class to be used in concert with AutoQuantizableLinearWeight to provide a default/non-quantized option. Only implements the bare minimum needed to work with the @@ -823,6 +824,130 @@ def from_float(cls, weight): return weight +class Float32Tensor(TorchAOBaseTensor): + """ Tensor subclass tensor for fp32 dtype + """ + def __init__(self, weight): + self.weight = weight.to(torch.float32) + + @staticmethod + def _quantized_linear_op(act_mat, w_qtensor, bias): + _DTYPE = torch.float32 + orig_dtype = act_mat.dtype + return torch.nn.functional.linear( + act_mat.to(_DTYPE), + w_qtensor.weight, + bias.to(_DTYPE) if bias is not None else bias, + ).to(dtype=orig_dtype) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.weight), + ) + + @classmethod + def from_float(cls, weight): + return cls(weight) + +@Float32Tensor.implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) + +@Float32Tensor.implements(aten.detach.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + +@Float32Tensor.implements(aten.clone.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + +@Float32Tensor.implements(aten._to_copy.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), + ) + + +class BFloat16Tensor(Float32Tensor): + def __init__(self, weight): + self.weight = weight.to(torch.bfloat16) + + @staticmethod + def _quantized_linear_op(act_mat, w_qtensor, bias): + _DTYPE = torch.bfloat16 + orig_dtype = act_mat.dtype + return torch.nn.functional.linear( + act_mat.to(_DTYPE), + w_qtensor.weight, + bias.to(_DTYPE) if bias is not None else bias, + ).to(dtype=orig_dtype) + + +class Float16Tensor(Float32Tensor): + def __init__(self, weight): + self.weight = weight.to(torch.float16) + + @staticmethod + def _quantized_linear_op(act_mat, w_qtensor, bias): + _DTYPE = torch.float16 + orig_dtype = act_mat.dtype + return torch.nn.functional.linear( + act_mat.to(_DTYPE), + w_qtensor.weight, + bias.to(_DTYPE) if bias is not None else bias, + ).to(dtype=orig_dtype) + + +class AQFloat32LinearWeight(Float32Tensor, AQMixin): + """ + AutoQuantizable version for float32 precision weight + + (also converts input activation and bias to float32, and restores the original precision after + linear) + """ + @classmethod + def from_float(cls, weight): + return super(AQFloat32LinearWeight, cls).from_float(weight) + + +class AQBFloat16LinearWeight(BFloat16Tensor, AQMixin): + """ + AutoQuantizable version for bfloat16 precision weight + + (also converts input activation and bias to bfloat16, and restores the original precision after + linear) + """ + @classmethod + def from_float(cls, weight): + return super(AQBFloat16LinearWeight, cls).from_float(weight) + + +class AQFloat16LinearWeight(Float16Tensor, AQMixin): + """ + AutoQuantizable version for float16 precision weight + + (also converts input activation and bias to float16, and restores the original precision after + linear) + """ + @classmethod + def from_float(cls, weight): + return super(AQFloat16LinearWeight, cls).from_float(weight) + + class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): """ AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight for target_dtype=torch.float8_e4m3fn @@ -936,7 +1061,7 @@ def get_weight_block_size(x): # here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison DEFAULT_AUTOQUANT_CLASS_LIST = [ - AQFloatLinearWeight, + AQDefaultLinearWeight, AQInt8WeightOnlyQuantizedLinearWeight, AQInt8WeightOnlyQuantizedLinearWeight2, # AQInt8WeightOnlyQuantizedLinearWeight3, @@ -945,11 +1070,17 @@ def get_weight_block_size(x): ] DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [ - AQFloatLinearWeight, + AQDefaultLinearWeight, AQInt8DynamicallyQuantizedLinearWeight, AQInt4G64WeightOnlyQuantizedLinearWeight, ] +DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST = [ + AQFloat32LinearWeight, + AQBFloat16LinearWeight, + AQFloat16LinearWeight, +] + OTHER_AUTOQUANT_CLASS_LIST = [ AQFloat8WeightOnlyQuantizedLinearWeight, AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 1731b6cf39..b486683290 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -22,7 +22,11 @@ compute_error, quantize_activation_per_token_absmax, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_3, + TORCH_VERSION_AT_LEAST_2_5, + TorchAOBaseTensor, +) from .granularity import ( PerRow, @@ -679,79 +683,133 @@ def from_float(cls, weight): return weight -class AQFloat32LinearWeight(torch.Tensor, AQMixin): - """ - AutoQuantizable version for float32 precision weight +class Float32Tensor(TorchAOBaseTensor): + """Tensor subclass tensor for fp32 dtype""" - (also converts input activation and bias to float32, and restores the original precision after - linear) - """ - - def __init__(self): - super().__init__() + def __init__(self, weight): + self.weight = weight.to(torch.float32) @staticmethod def _quantized_linear_op(act_mat, w_qtensor, bias): + _DTYPE = torch.float32 orig_dtype = act_mat.dtype return torch.nn.functional.linear( - act_mat.to(torch.float32), - w_qtensor, - bias.to(torch.float32) if bias is not None else bias, + act_mat.to(_DTYPE), + w_qtensor.weight, + bias.to(_DTYPE) if bias is not None else bias, ).to(dtype=orig_dtype) + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.weight), + ) + @classmethod def from_float(cls, weight): - return weight.to(torch.float32) + return cls(weight) -class AQBFloat16LinearWeight(torch.Tensor, AQMixin): - """ - AutoQuantizable version for bfloat16 precision weight +@Float32Tensor.implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) - (also converts input activation and bias to bfloat16, and restores the original precision after - linear) - """ - def __init__(self): - super().__init__() +@Float32Tensor.implements(aten.detach.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + +@Float32Tensor.implements(aten.clone.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + +@Float32Tensor.implements(aten._to_copy.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), + ) + + +class BFloat16Tensor(Float32Tensor): + def __init__(self, weight): + self.weight = weight.to(torch.bfloat16) @staticmethod def _quantized_linear_op(act_mat, w_qtensor, bias): + _DTYPE = torch.bfloat16 orig_dtype = act_mat.dtype return torch.nn.functional.linear( - act_mat.to(torch.bfloat16), - w_qtensor, - bias.to(torch.bfloat16) if bias is not None else bias, + act_mat.to(_DTYPE), + w_qtensor.weight, + bias.to(_DTYPE) if bias is not None else bias, ).to(dtype=orig_dtype) + +class Float16Tensor(Float32Tensor): + def __init__(self, weight): + self.weight = weight.to(torch.float16) + + @staticmethod + def _quantized_linear_op(act_mat, w_qtensor, bias): + _DTYPE = torch.float16 + orig_dtype = act_mat.dtype + return torch.nn.functional.linear( + act_mat.to(_DTYPE), + w_qtensor.weight, + bias.to(_DTYPE) if bias is not None else bias, + ).to(dtype=orig_dtype) + + +class AQFloat32LinearWeight(Float32Tensor, AQMixin): + """ + AutoQuantizable version for float32 precision weight + + (also converts input activation and bias to float32, and restores the original precision after + linear) + """ + @classmethod def from_float(cls, weight): - return weight.to(torch.bfloat16) + return super(AQFloat32LinearWeight, cls).from_float(weight) -class AQFloat16LinearWeight(torch.Tensor, AQMixin): +class AQBFloat16LinearWeight(BFloat16Tensor, AQMixin): """ - AutoQuantizable version for float16 precision weight + AutoQuantizable version for bfloat16 precision weight - (also converts input activation and bias to float16, and restores the original precision after + (also converts input activation and bias to bfloat16, and restores the original precision after linear) """ - def __init__(self): - super().__init__() + @classmethod + def from_float(cls, weight): + return super(AQBFloat16LinearWeight, cls).from_float(weight) - @staticmethod - def _quantized_linear_op(act_mat, w_qtensor, bias): - orig_dtype = act_mat.dtype - return torch.nn.functional.linear( - act_mat.to(torch.float16), - w_qtensor, - bias.to(torch.float16) if bias is not None else bias, - ).to(dtype=orig_dtype) + +class AQFloat16LinearWeight(Float16Tensor, AQMixin): + """ + AutoQuantizable version for float16 precision weight + + (also converts input activation and bias to float16, and restores the original precision after + linear) + """ @classmethod def from_float(cls, weight): - return weight.to(torch.float16) + return super(AQFloat16LinearWeight, cls).from_float(weight) class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): From ebb908623cdd88508b5ab13123281655d6a81548 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Tue, 3 Dec 2024 08:48:06 -0800 Subject: [PATCH 18/40] Move profiler -> prototype (#1370) --- test/{profiler => prototype}/test_device_spec.py | 2 +- test/{profiler => prototype}/test_performance_counter.py | 4 ++-- test/{profiler => prototype}/utils.py | 2 +- torchao/_models/llama/perf_profile.py | 6 +++--- torchao/{ => prototype}/profiler/__init__.py | 0 torchao/{ => prototype}/profiler/device_spec.py | 0 torchao/{ => prototype}/profiler/performance_counter.py | 0 torchao/{ => prototype}/profiler/utils.py | 0 8 files changed, 7 insertions(+), 7 deletions(-) rename test/{profiler => prototype}/test_device_spec.py (97%) rename test/{profiler => prototype}/test_performance_counter.py (99%) rename test/{profiler => prototype}/utils.py (98%) rename torchao/{ => prototype}/profiler/__init__.py (100%) rename torchao/{ => prototype}/profiler/device_spec.py (100%) rename torchao/{ => prototype}/profiler/performance_counter.py (100%) rename torchao/{ => prototype}/profiler/utils.py (100%) diff --git a/test/profiler/test_device_spec.py b/test/prototype/test_device_spec.py similarity index 97% rename from test/profiler/test_device_spec.py rename to test/prototype/test_device_spec.py index 1ede428fe0..dd159f5336 100644 --- a/test/profiler/test_device_spec.py +++ b/test/prototype/test_device_spec.py @@ -8,7 +8,7 @@ import torch from utils import patch_device -from torchao.profiler.device_spec import ( +from torchao.prototype.profiler.device_spec import ( _AVAILABLE_GPU_SPECS, CUDADeviceSpec, get_chip_name, diff --git a/test/profiler/test_performance_counter.py b/test/prototype/test_performance_counter.py similarity index 99% rename from test/profiler/test_performance_counter.py rename to test/prototype/test_performance_counter.py index 2cd1a33581..6ece2c6398 100644 --- a/test/profiler/test_performance_counter.py +++ b/test/prototype/test_performance_counter.py @@ -30,8 +30,8 @@ qkv_proj_io_check, ) -from torchao.profiler.device_spec import CUDADeviceSpec, DeviceSpec -from torchao.profiler.performance_counter import ( +from torchao.prototype.profiler.device_spec import CUDADeviceSpec, DeviceSpec +from torchao.prototype.profiler.performance_counter import ( CUDAPerformanceTimer, PerformanceCounterMode, PerformanceStats, diff --git a/test/profiler/utils.py b/test/prototype/utils.py similarity index 98% rename from test/profiler/utils.py rename to test/prototype/utils.py index 7b2b999809..8c402b8114 100644 --- a/test/profiler/utils.py +++ b/test/prototype/utils.py @@ -5,7 +5,7 @@ import torch -from torchao.profiler import PerformanceTimer +from torchao.prototype.profiler import PerformanceTimer @contextmanager diff --git a/torchao/_models/llama/perf_profile.py b/torchao/_models/llama/perf_profile.py index 1a0d4e36c0..f613982221 100644 --- a/torchao/_models/llama/perf_profile.py +++ b/torchao/_models/llama/perf_profile.py @@ -2,9 +2,9 @@ ## Performance Profiling Example -An minimal version of `gpt-fast generate.py` that demonstrates usage of `torchao.profiler.TransformerPerformanceCounter`. +An minimal version of `gpt-fast generate.py` that demonstrates usage of `torchao.prototype.profiler.TransformerPerformanceCounter`. - Outputs from gpt-fast are prefixed with GPT-Fast -- Outputs from `torchao.profiler.TransformerPerformanceCounter` are prefixed with `TransformerPerfCounter`. +- Outputs from `torchao.prototype.profiler.TransformerPerformanceCounter` are prefixed with `TransformerPerfCounter`. ## Usage ```python @@ -118,7 +118,7 @@ from torchao._models.llama.model import Transformer from torchao._models.llama.tokenizer import get_tokenizer -from torchao.profiler import ( +from torchao.prototype.profiler import ( CUDADeviceSpec, TransformerPerformanceCounter, total_model_params, diff --git a/torchao/profiler/__init__.py b/torchao/prototype/profiler/__init__.py similarity index 100% rename from torchao/profiler/__init__.py rename to torchao/prototype/profiler/__init__.py diff --git a/torchao/profiler/device_spec.py b/torchao/prototype/profiler/device_spec.py similarity index 100% rename from torchao/profiler/device_spec.py rename to torchao/prototype/profiler/device_spec.py diff --git a/torchao/profiler/performance_counter.py b/torchao/prototype/profiler/performance_counter.py similarity index 100% rename from torchao/profiler/performance_counter.py rename to torchao/prototype/profiler/performance_counter.py diff --git a/torchao/profiler/utils.py b/torchao/prototype/profiler/utils.py similarity index 100% rename from torchao/profiler/utils.py rename to torchao/prototype/profiler/utils.py From b7630f19dff77c8f99538fa6d2e25a57e0bccdce Mon Sep 17 00:00:00 2001 From: cpuhrsch Date: Tue, 3 Dec 2024 15:09:57 -0800 Subject: [PATCH 19/40] SAM2 AMG load_fast fix (#1374) --- examples/sam2_amg_server/server.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/sam2_amg_server/server.py b/examples/sam2_amg_server/server.py index 4c81342ff6..060a5ad5dd 100644 --- a/examples/sam2_amg_server/server.py +++ b/examples/sam2_amg_server/server.py @@ -588,6 +588,12 @@ def main(checkpoint_path, if load_fast != "": load_aot_fast(mask_generator, load_fast) + if furious: + set_furious(mask_generator) + # since autoquant is replicating what furious mode is doing, don't use these two together + elif use_autoquant: + set_autoquant(mask_generator) + if save_fast != "": assert load_fast == "", "Can't save compiled models while loading them with --load-fast." assert not baseline, "--fast cannot be combined with baseline. code to be torch.compile(fullgraph=True) compatible." @@ -598,12 +604,6 @@ def main(checkpoint_path, assert not baseline, "--fast cannot be combined with baseline. code to be torch.compile(fullgraph=True) compatible." set_fast(mask_generator, load_fast) - if furious: - set_furious(mask_generator) - # since autoquant is replicating what furious mode is doing, don't use these two together - elif use_autoquant: - set_autoquant(mask_generator) - with open('dog.jpg', 'rb') as f: image_tensor = file_bytes_to_image_tensor(bytearray(f.read())) From 1a0dbf1c41ad1c6f28d6501e1134b30ea2f2590d Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Tue, 3 Dec 2024 21:05:05 -0500 Subject: [PATCH 20/40] Add TTFT benchmarks + update sparsity benchmarks (#1140) This PR adds in TTFT token benchmarks to torchAO, and also updates the benchmarking script to handle sparsity a bit nicer + use the 2:4 sparse checkpoints that are available. Additionally also adds in padding support for int8 dynamic quant + 2:4 sparsity, which we were missing before. --- scripts/prepare.sh | 4 + test/prototype/test_sparse_api.py | 3 + torchao/_models/llama/benchmarks.sh | 21 +++- torchao/_models/llama/generate.py | 116 ++++++++++++++++++--- torchao/dtypes/uintx/semi_sparse_layout.py | 8 +- 5 files changed, 136 insertions(+), 16 deletions(-) diff --git a/scripts/prepare.sh b/scripts/prepare.sh index db426e3b11..9cbc8295ee 100644 --- a/scripts/prepare.sh +++ b/scripts/prepare.sh @@ -2,7 +2,11 @@ python scripts/download.py --repo_id meta-llama/Llama-2-7b-chat-hf python scripts/download.py --repo_id meta-llama/Meta-Llama-3-8B python scripts/download.py --repo_id meta-llama/Meta-Llama-3.1-8B python scripts/download.py --repo_id meta-llama/Llama-3.2-3B +python scripts/download.py --repo_id nm-testing/SparseLlama-3-8B-pruned_50.2of4 python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-2-7b-chat-hf python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3-8B python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3.1-8B python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-3.2-3B +# neuralmagic doesn't come with tokenizer, so we need to copy it over +mkdir -p checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/original && cp checkpoints/meta-llama/Meta-Llama-3-8B/original/tokenizer.model checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/original/tokenizer.model +python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4 diff --git a/test/prototype/test_sparse_api.py b/test/prototype/test_sparse_api.py index 757eb9f913..f3cdbe8386 100644 --- a/test/prototype/test_sparse_api.py +++ b/test/prototype/test_sparse_api.py @@ -50,6 +50,9 @@ def test_sparse(self): sparsify_(model, semi_sparse_weight()) sparse_result = model(input) + if compile: + model = torch.compile(model) + torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3) diff --git a/torchao/_models/llama/benchmarks.sh b/torchao/_models/llama/benchmarks.sh index 63733c736d..c8cd4bf39c 100644 --- a/torchao/_models/llama/benchmarks.sh +++ b/torchao/_models/llama/benchmarks.sh @@ -52,7 +52,7 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --wr python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --sparsity semi-structured --precision float16 --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt @@ -62,7 +62,7 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --wr python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16 -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --sparsity semi-structured --precision float16 --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt @@ -79,3 +79,20 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 1 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 32 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 128 + +# TTFT benchmarks +export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt --prefill_size 8000 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq --write_result benchmark_results.txt --prefill_size 8000 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8wo --write_result benchmark_results.txt --prefill_size 8000 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq --sparsity semi-structured --write_result benchmark_results.txt --prefill_size 8000 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization float8dq --write_result benchmark_results.txt --prefill_size 8000 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization float8wo --write_result benchmark_results.txt --prefill_size 8000 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int4wo-64 --write_result benchmark_results.txt --prefill_size 8000 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization sparse-marlin --write_result benchmark_results.txt --prefill_size 8000 --precision float16 --sparsity semi-structured + +# 2:4 sparse model +export MODEL_REPO=nm-testing/SparseLlama-3-8B-pruned_50.2of4 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --sparsity semi-structured --precision float16 --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --sparsity semi-structured --precision float16 --write_result benchmark_results.txt diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 9619721614..065cc9c56d 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -17,6 +17,29 @@ from torchao.quantization.quant_primitives import MappingType from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False + +class HostEvent: + def __init__(self): + self.event_time = None + + def record(self): + self.event_time = time.perf_counter() + + def elapsed_time(self, other_event): + if self.event_time is None: + raise ValueError("Event not recorded!") + # return ms to match cuda event + return abs(other_event.event_time - self.event_time) * 1000 + +def device_timer(device): + if "cuda" in device: + return torch.cuda.Event(enable_timing=True) + elif ("cpu" in device) or ("mps" in device): + return HostEvent() + else: + print(f"device={device} is not yet suppported") + def device_sync(device): if "cuda" in device: torch.cuda.synchronize(device) @@ -98,6 +121,10 @@ def generate( kv_cache_quantization: bool = False, cache_size: Optional[int] = None, linear_causal_mask: bool=False, + prefill_start_event: Optional[torch.cuda.Event]=None, + prefill_end_event: Optional[torch.cuda.Event]=None, + decode_start_event: Optional[torch.cuda.Event]=None, + decode_end_event: Optional[torch.cuda.Event]=None, **sampling_kwargs ) -> torch.Tensor: """ @@ -128,12 +155,21 @@ def generate( model.setup_caches(max_batch_size=batch_size, max_seq_length=cache_size, kv_cache_quantization=kv_cache_quantization, linear_causal_mask=linear_causal_mask, prompt_length=T) # execute prefill + if prefill_start_event is not None: + prefill_start_event.record() next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs).clone() seq[:, T] = next_token.squeeze() + if prefill_end_event is not None: + prefill_end_event.record() + # execute token generation + if decode_start_event is not None: + decode_start_event.record() input_pos = torch.tensor([T], device=device, dtype=torch.int) generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs) seq = torch.cat((seq[:, :T+1], *generated_tokens), dim=-1) + if decode_end_event is not None: + decode_end_event.record() return seq @@ -157,6 +193,7 @@ def _load_model(checkpoint_path, device, precision): B_INST, E_INST = "[INST]", "[/INST]" def main( + prefill_size: Optional[int] = None, prompt: str = "Hello, my name is", interactive: bool = False, num_samples: int = 5, @@ -166,6 +203,7 @@ def main( temperature: float = 0.8, checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), quantization: Optional[str] = None, + sparsity: Optional[str] = None, kv_cache_quantization: bool = False, cache_size: Optional[int] = None, linear_causal_mask: bool=False, @@ -181,6 +219,10 @@ def main( """Generates text samples based on a pre-trained Transformer model and tokenizer. """ + if prefill_size is not None and prefill_size > 0: + # create prompt of prefill size + prompt = "prompt " * (int(prefill_size)-3) + torchao.quantization.utils.recommended_inductor_config_setter() assert checkpoint_path.is_file(), checkpoint_path @@ -205,6 +247,14 @@ def main( torch.manual_seed(1234) + def ffn_only(mod, fqn): + return isinstance(mod, torch.nn.Linear) and "feed_forward" in fqn + + def not_ffn_only(mod, fqn): + return isinstance(mod, torch.nn.Linear) and not ffn_only(mod, fqn) + + def ffn_or_attn_only(mod, fqn): + return isinstance(mod, torch.nn.Linear) and ("feed_forward" in fqn or "attention" in fqn) if quantization: from torchao.quantization import ( @@ -228,9 +278,14 @@ def main( apply_spinquant(model) if "int8wo" in quantization: quantize_(model, int8_weight_only()) - elif "int8dq" in quantization: - quantize_(model, int8_dynamic_activation_int8_weight()) - elif "int4wo" in quantization: + if "int8dq" in quantization: + if sparsity and "semi" in sparsity: + from torchao.dtypes import SemiSparseLayout + quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), filter_fn=ffn_only) + quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=not_ffn_only) + else: + quantize_(model, int8_dynamic_activation_int8_weight()) + if "int4wo" in quantization: if "hqq" in quantization: use_hqq=True else: @@ -250,9 +305,9 @@ def main( layout=MarlinQQQLayout(), ), ) - else: + elif "semi" in sparsity: from torchao.dtypes import MarlinSparseLayout - quantize_(model, int4_weight_only(layout=MarlinSparseLayout())) + quantize_(model, int4_weight_only(layout=MarlinSparseLayout()), filter_fn=ffn_or_attn_only) if "fp6" in quantization: quantize_(model, fpx_weight_only(3, 2)) elif "embed-int8wo" in quantization: @@ -440,6 +495,13 @@ def main( if not TORCH_VERSION_AT_LEAST_2_5: unwrap_tensor_subclass(model) + # standalone sparsity + elif sparsity: + from torchao.sparsity import semi_sparse_weight, sparsify_ + if "semi" in sparsity: + #TODO there is a bug here, need to fix + sparsify_(model.to(device), semi_sparse_weight(), filter_fn=ffn_only) + model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9 if save: @@ -465,6 +527,9 @@ def main( aggregate_metrics = { 'tokens_per_sec': [], + 'time': [], + 'decode_tokens_per_sec': [], + 'prefill_time': [], } start = -1 if compile else 0 @@ -499,6 +564,8 @@ def callback(x): else: callback = lambda x : x t0 = time.perf_counter() + prefill_start_event, prefill_end_event = device_timer(device), device_timer(device) + decode_start_event, decode_end_event = device_timer(device), device_timer(device) import contextlib if (i != num_samples - 1 or not profile): prof = contextlib.nullcontext() @@ -518,6 +585,10 @@ def callback(x): kv_cache_quantization=kv_cache_quantization, cache_size=cache_size, linear_causal_mask=linear_causal_mask, + prefill_start_event=prefill_start_event, + prefill_end_event=prefill_end_event, + decode_start_event=decode_start_event, + decode_end_event=decode_end_event, ) if i == -1: print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") @@ -527,7 +598,7 @@ def callback(x): device_sync(device=device) # MKG t = time.perf_counter() - t0 - if not interactive: + if not interactive and prefill_size is None: tok_list = y[0].tolist() # truncate text after end of string token tokens = tok_list if not tokenizer.eos_id() in tok_list else tok_list[:tok_list.index(tokenizer.eos_id())] @@ -537,7 +608,14 @@ def callback(x): tokens_generated = (y.size(-1) - prompt_length) tokens_sec = tokens_generated / t aggregate_metrics['tokens_per_sec'].append(tokens_sec) - print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec") + aggregate_metrics['time'].append(t) + decode_time = decode_start_event.elapsed_time(decode_end_event) / 1000 + decode_tokens_sec = tokens_generated / decode_time + aggregate_metrics['decode_tokens_per_sec'].append(decode_tokens_sec) + prefill_time = prefill_start_event.elapsed_time(prefill_end_event) / 1000 + aggregate_metrics['prefill_time'].append(prefill_time) + print(f"Sample {i+1} | overall time {t:.04f} s {tokens_sec:.02f} tokens/sec", + f"| prefill time {prefill_time:.04f} s decode {decode_tokens_sec:.02f} tokens/sec") print(f"Bandwidth achieved: {model_size * tokens_sec:.02f} GB/s") if memory_profile and i==0: @@ -558,8 +636,15 @@ def callback(x): break print("==========") + #ignore first sample for warmup tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item() + ttft = torch.mean(torch.tensor(aggregate_metrics['prefill_time'])).item() + decode_tokpersec = torch.mean(torch.tensor(aggregate_metrics['decode_tokens_per_sec'])).item() bandwidth = model_size * tokpersec + mem = torch.cuda.max_memory_reserved() /1e9 + print(f"Average overall tokens/sec: {tokpersec:.2f}") + print(f"Average decode tokens/sec: {decode_tokens_sec:.04f} s") + print(f"Average TTFT: {ttft:.04f} s") if device == "cuda": mem = torch.cuda.max_memory_reserved() /1e9 elif device == "xpu": @@ -571,15 +656,17 @@ def callback(x): print(f"Peak Memory Usage: {mem:.02f} GB") print(f"Model Size: {model_size:.02f} GB") if write_result: - result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB " - result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} " + result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, tok/s_decode={decode_tokpersec:6.2f}, ttft={ttft:5.4f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB " + result_txt += f"quant: {quantization}, sparse: {sparsity}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} " result_txt += f"repro: python generate.py " result_txt += f"--quantization {quantization} " if quantization else "" + result_txt += f"--sparsity {sparsity} " if sparsity else "" result_txt += f"--checkpoint_path {checkpoint_path} " result_txt += f"--device {device} " result_txt += f"--precision {precision} " result_txt += f"--compile " if compile else "" result_txt += f"--compile_prefill " if compile_prefill else "" + result_txt += f"--prefill_size {prefill_size}" if prefill_size else "" result_txt += f"--profile {profile} " if profile else "" result_txt += f"--profile {memory_profile} " if memory_profile else "" result_txt += f"--interactive " if interactive else "" @@ -601,7 +688,7 @@ def callback(x): if __name__ == '__main__': import argparse parser = argparse.ArgumentParser(description='Your CLI description.') - + parser.add_argument('--prefill_size', type=int, default=0, help='Whether to run in ttft mode') parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.') parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode') parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.') @@ -617,6 +704,11 @@ def callback(x): +'embed-int8wo, marlin_qqq' ) ) + parser.add_argument('-s', '--sparsity', type=str, + help=( + 'Which sparsity techniques to apply: semi-structured' + ) + ) parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache') parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size') parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)') @@ -631,6 +723,6 @@ def callback(x): args = parser.parse_args() main( - args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k, - args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result + args.prefill_size, args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k, + args.temperature, args.checkpoint_path, args.quantization, args.sparsity, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result ) diff --git a/torchao/dtypes/uintx/semi_sparse_layout.py b/torchao/dtypes/uintx/semi_sparse_layout.py index e2c94a7a38..d832731657 100644 --- a/torchao/dtypes/uintx/semi_sparse_layout.py +++ b/torchao/dtypes/uintx/semi_sparse_layout.py @@ -41,13 +41,17 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl( w_vals_int8 = weight_tensor.tensor_impl.int_data w_scales = weight_tensor.tensor_impl.scale tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) + # must pad + row, col = tmp.shape + from torch.sparse import SparseSemiStructuredTensorCUSPARSELT + tmp_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(tmp) # we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm( w_vals_int8, - tmp.t(), + tmp_padded.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16, - ).t() + ).t()[:row, :] y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape( *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1] ) From 04d611a1abdc0f1507603c8cbb37a0f8c6707c00 Mon Sep 17 00:00:00 2001 From: sanchitintel Date: Wed, 4 Dec 2024 08:29:49 -0800 Subject: [PATCH 21/40] Unskip `test_int8_dynamic_quant_subclass_api` in `test_integration.py` (#1375) This UT passes for both CPU & CUDA --- test/integration/test_integration.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 10f2d157f9..6aae8b2e31 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -885,7 +885,6 @@ def _test_lin_weight_subclass_api_impl( @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_4, "skip because there is some bug in inductor codegen") def test_int8_dynamic_quant_subclass_api(self, device, dtype): self._test_lin_weight_subclass_api_impl( _int8da_int8w_api, device, 35, test_dtype=dtype From 53d24866f052459c07d5023399f1caf8ee1a8c69 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Wed, 4 Dec 2024 09:54:23 -0800 Subject: [PATCH 22/40] fix lint (#1379) Summary: run `ruff format` to fix lint on main branch Test Plan: CI Reviewers: Subscribers: Tasks: Tags: --- torchao/dtypes/uintx/semi_sparse_layout.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchao/dtypes/uintx/semi_sparse_layout.py b/torchao/dtypes/uintx/semi_sparse_layout.py index d832731657..a554fd9bc6 100644 --- a/torchao/dtypes/uintx/semi_sparse_layout.py +++ b/torchao/dtypes/uintx/semi_sparse_layout.py @@ -44,6 +44,7 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl( # must pad row, col = tmp.shape from torch.sparse import SparseSemiStructuredTensorCUSPARSELT + tmp_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(tmp) # we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm( @@ -51,7 +52,7 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl( tmp_padded.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16, - ).t()[:row, :] + ).t()[:row, :] y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape( *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1] ) From 6a177c90d88eebab292d8655cfcf69e6236c5f2b Mon Sep 17 00:00:00 2001 From: Luca Wehrstedt Date: Wed, 4 Dec 2024 22:13:02 +0100 Subject: [PATCH 23/40] [float8] Re-enable slow-accum in the bwd of axis-wise scaling schemes (#1377) --- torchao/float8/config.py | 23 ------------------ torchao/float8/float8_python_api.py | 36 +++++++++++++++++++---------- 2 files changed, 24 insertions(+), 35 deletions(-) diff --git a/torchao/float8/config.py b/torchao/float8/config.py index 175ab03f3c..6a092d5f38 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -170,7 +170,6 @@ class Float8LinearConfig: # # Per-gemm configuration for gemms calculating `output`, `grad_input` and # `grad_weight` - # TODO(this PR): throw warning if fast_accum False is used with axiswise scaling # gemm_config_output: Float8GemmConfig = Float8GemmConfig(use_fast_accum=True) gemm_config_grad_input: Float8GemmConfig = Float8GemmConfig() @@ -317,21 +316,10 @@ def recipe_name_to_linear_config( cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) - # The current rowwise CUTLASS kernels in `torch._scaled_mm` are only - # fast with `use_fast_accum=True`. Note that rowwise scaling is more - # accurate than tensorwise scaling, so the overall impact on accuracy - # of tensorwise vs rowwise taking this flag into account will vary. - gc_o = Float8GemmConfig(use_fast_accum=True) - gc_gi = Float8GemmConfig(use_fast_accum=True) - gc_gw = Float8GemmConfig(use_fast_accum=True) - return Float8LinearConfig( cast_config_input=cc_i, cast_config_weight=cc_w, cast_config_grad_output=cc_go, - gemm_config_output=gc_o, - gemm_config_grad_input=gc_gi, - gemm_config_grad_weight=gc_gw, ) elif recipe_name is Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP: @@ -359,14 +347,6 @@ def recipe_name_to_linear_config( cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED) cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED) - # The current rowwise CUTLASS kernels in `torch._scaled_mm` are only - # fast with `use_fast_accum=True`. Note that rowwise scaling is more - # accurate than tensorwise scaling, so the overall impact on accuracy - # of tensorwise vs rowwise taking this flag into account will vary. - gc_o = Float8GemmConfig(use_fast_accum=True) - gc_gi = Float8GemmConfig(use_fast_accum=True) - gc_gw = Float8GemmConfig(use_fast_accum=True) - return Float8LinearConfig( cast_config_input=cc_i, cast_config_weight=cc_w, @@ -374,9 +354,6 @@ def recipe_name_to_linear_config( cast_config_input_for_grad_weight=cc_i_gw, cast_config_weight_for_grad_input=cc_w_gi, cast_config_grad_output_for_grad_weight=cc_go_gw, - gemm_config_output=gc_o, - gemm_config_grad_input=gc_gi, - gemm_config_grad_weight=gc_gw, ) else: diff --git a/torchao/float8/float8_python_api.py b/torchao/float8/float8_python_api.py index 6608dba958..402ce2eb0f 100644 --- a/torchao/float8/float8_python_api.py +++ b/torchao/float8/float8_python_api.py @@ -37,19 +37,25 @@ def addmm_float8_unwrapped( a_inverse_scale = a_scale.reciprocal() b_inverse_scale = b_scale.reciprocal() - if output_dtype == torch.float32 and bias is not None: + post_inverse_scale = None + if ( + a_scale.shape == (a_data.shape[0], 1) + and b_scale.shape == (1, b_data.shape[1]) + and not use_fast_accum + ): + # The rowwise CUTLASS-based kernel is so slow without fast-accum that + # we'd rather use the tensorwise cuBLAS-based kernel and do the scaling + # manually afterwards (hoping Inductor will be able to fuse it). + post_inverse_scale = a_inverse_scale * b_inverse_scale + a_inverse_scale = a_inverse_scale.new_ones(()) + b_inverse_scale = a_inverse_scale.new_ones(()) + + post_bias = None + if output_dtype == torch.float32: # Bias is not supported by _scaled_mm when output is fp32 - output = torch._scaled_mm( - a_data, - b_data, - scale_a=a_inverse_scale, - scale_b=b_inverse_scale, - scale_result=output_scale, - out_dtype=output_dtype, - use_fast_accum=use_fast_accum, - ) - output += bias - return output + post_bias = bias + bias = None + output = torch._scaled_mm( a_data, b_data, @@ -60,4 +66,10 @@ def addmm_float8_unwrapped( out_dtype=output_dtype, use_fast_accum=use_fast_accum, ) + + if post_inverse_scale is not None: + output *= post_inverse_scale + if post_bias is not None: + output += post_bias + return output From 4f8021f2a013203cb3fc410a196321e60ea754ae Mon Sep 17 00:00:00 2001 From: Luca Wehrstedt Date: Wed, 4 Dec 2024 22:14:54 +0100 Subject: [PATCH 24/40] [float8] Allow specifying arbitrary dtype for each tensor (#1378) --- test/float8/test_base.py | 6 +- test/float8/test_compile.py | 2 +- test/float8/test_dtensor.py | 8 +-- torchao/float8/config.py | 82 ++++++++++++++------- torchao/float8/float8_linear.py | 85 ++++++++++++---------- torchao/float8/float8_linear_utils.py | 18 +++-- torchao/float8/float8_scaling_utils.py | 33 +++++---- torchao/float8/float8_tensor.py | 5 +- torchao/float8/float8_tensor_parallel.py | 21 ++++-- torchao/float8/float8_utils.py | 10 +-- torchao/float8/fsdp_utils.py | 92 ++++++++++++++++++------ 11 files changed, 229 insertions(+), 133 deletions(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index ba6281deaf..58df3a343c 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -30,6 +30,8 @@ Float8LinearRecipeName, ScalingGranularity, ScalingType, + e4m3_dtype, + e5m2_dtype, recipe_name_to_linear_config, ) from torchao.float8.float8_linear import Float8Linear @@ -53,8 +55,6 @@ from torchao.float8.float8_utils import ( FP8_TYPES, compute_error, - e4m3_dtype, - e5m2_dtype, fp8_tensor_statistics, tensor_to_scale, ) @@ -546,7 +546,7 @@ def test_repr(self): config=config, ) s = m.__repr__() - assert "i:dyn_ten,w:del_ten,go:dyn_ten" in s + assert "i:dyn_ten_e4m3,w:del_ten_e4m3,go:dyn_ten_e5m2" in s @unittest.skipIf(not is_sm_at_least_89(), "CUDA 8.9 not available") def test_inference_mode(self): diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index 6d21686e32..9a9e555cb2 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -30,6 +30,7 @@ Float8LinearConfig, Float8LinearRecipeName, ScalingType, + e4m3_dtype, recipe_name_to_linear_config, ) from torchao.float8.float8_linear import Float8Linear @@ -47,7 +48,6 @@ LinearMMConfig, ScaledMMConfig, ) -from torchao.float8.float8_utils import e4m3_dtype from torchao.testing.float8.test_utils import get_test_float8_linear_config diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py index 5985a3f5b5..41b21e4406 100644 --- a/test/float8/test_dtensor.py +++ b/test/float8/test_dtensor.py @@ -31,9 +31,9 @@ from tqdm import tqdm from torchao.float8 import Float8LinearConfig -from torchao.float8.config import CastConfig, ScalingType +from torchao.float8.config import CastConfig, ScalingType, e4m3_dtype from torchao.float8.float8_linear_utils import convert_to_float8_training -from torchao.float8.float8_scaling_utils import NoopFwToFloat8E5M2BwDynamic +from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic from torchao.float8.float8_tensor import ( Float8Tensor, GemmInputRole, @@ -45,7 +45,7 @@ Float8RowwiseParallel, PrepareFloat8ModuleInput, ) -from torchao.float8.float8_utils import e4m3_dtype, tensor_to_scale +from torchao.float8.float8_utils import tensor_to_scale from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor from torchao.testing.float8.dtensor_utils import ToyModel @@ -173,7 +173,7 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16): ) out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8) - out = NoopFwToFloat8E5M2BwDynamic.apply(out, LinearMMConfig()) + out = NoopFwToFloat8BwDynamic.apply(out, LinearMMConfig(), fp8_dtype) assert isinstance(out, DTensor), f"Expected DTensor, got {type(out)}" loss = torch.sum(torch.abs(out - dist_target)) loss.backward() diff --git a/torchao/float8/config.py b/torchao/float8/config.py index 6a092d5f38..d4a5516154 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -53,6 +53,35 @@ def short_str(self): return "axs" +@dataclass +class Float8TypeConfig: + """ + Configuration for selecting the preferred float8 type pair, either e4m3fn/e5m2 or e4m3fnuz/e5m2fnuz. + + Currently, ROCm only supports fnuz variants. + """ + + # The preferred e4m3 type. + e4m3_dtype = torch.float8_e4m3fn + + # The preferred e5m2 type. + e5m2_dtype = torch.float8_e5m2 + + def __post_init__(self): + if torch.version.hip and torch.cuda.is_available(): + prop = torch.cuda.get_device_properties(0) + MI300_ARCH = ("gfx940", "gfx941", "gfx942") + if prop.gcnArchName.split(":")[0] in MI300_ARCH: + self.e4m3_dtype = torch.float8_e4m3fnuz + self.e5m2_dtype = torch.float8_e5m2fnuz + + +# User defined type for using the individual F8 type based on config +type_config = Float8TypeConfig() +e4m3_dtype = type_config.e4m3_dtype +e5m2_dtype = type_config.e5m2_dtype + + @dataclass(frozen=True) class CastConfig: """ @@ -62,9 +91,11 @@ class CastConfig: scaling_type: ScalingType = ScalingType.DYNAMIC scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE static_scale: Optional[torch.Tensor] = None + target_dtype: Optional[torch.dtype] = None def short_str(self): - return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}" + dtype = {e4m3_dtype: "e4m3", e5m2_dtype: "e5m2"}[self.target_dtype] + return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}_{dtype}" def __post_init__(self): if self.scaling_type is ScalingType.STATIC: @@ -75,6 +106,9 @@ def __post_init__(self): assert ( self.scaling_type is ScalingType.DYNAMIC ), "only dynamic scaling type is supported for axiswise scaling granularity" + assert self.target_dtype is None or ( + self.target_dtype.is_floating_point and self.target_dtype.itemsize == 1 + ), "must specify a 8-bit floating-point dtype" @dataclass(frozen=True) @@ -101,29 +135,6 @@ def __post_init__(self): ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now." -@dataclass -class Float8TypeConfig: - """ - Configuration for selecting the preferred float8 type pair, either e4m3fn/e5m2 or e4m3fnuz/e5m2fnuz. - - Currently, ROCm only supports fnuz variants. - """ - - # The preferred e4m3 type. - e4m3_dtype = torch.float8_e4m3fn - - # The preferred e5m2 type. - e5m2_dtype = torch.float8_e5m2 - - def __post_init__(self): - if torch.version.hip and torch.cuda.is_available(): - prop = torch.cuda.get_device_properties(0) - MI300_ARCH = ("gfx940", "gfx941", "gfx942") - if prop.gcnArchName.split(":")[0] in MI300_ARCH: - self.e4m3_dtype = torch.float8_e4m3fnuz - self.e5m2_dtype = torch.float8_e5m2fnuz - - @dataclass(frozen=True) class Float8GemmConfig: """ @@ -276,6 +287,20 @@ def __post_init__(self): is_disabled_1 == is_disabled_2 ), f"incompatible operand precision for {gemm_name}" + for cc1, cc2, operand_name, default_dtype in [ + (cc_i, cc_i_gw, "input", e4m3_dtype), + (cc_w, cc_w_gi, "weight", e4m3_dtype), + (cc_go, cc_go_gw, "grad_output", e5m2_dtype), + ]: + # Override the dataclass being frozen + if cc1.target_dtype is None: + object.__setattr__(cc1, "target_dtype", default_dtype) + if cc2.target_dtype is None: + object.__setattr__(cc2, "target_dtype", default_dtype) + assert ( + cc1.target_dtype == cc2.target_dtype + ), f"{operand_name} must be cast to the same dtype in both matmuls it's used in" + if self.use_fp8_all_gather_only: assert self.enable_fsdp_float8_all_gather, "use_fp8_all_gather_only requires enable_fsdp_float8_all_gather to be True" @@ -334,18 +359,23 @@ def recipe_name_to_linear_config( # * `input`, `weight` and `grad_output` now only need to be scaled # axiswise across a single dim compared to vanilla all-axiswise, # which is more amenable to fast kernels + # * the e4m3 dtype is used across the board, including for gradients # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1 cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise - cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) + cc_go = CastConfig( + scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype + ) cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE) # grad_weight_hp = input_t_hp @ grad_output_hp cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED) - cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED) + cc_go_gw = CastConfig( + scaling_type=ScalingType.DISABLED, target_dtype=e4m3_dtype + ) return Float8LinearConfig( cast_config_input=cc_i, diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index 776de917f1..d412519c36 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -15,9 +15,9 @@ from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType from torchao.float8.distributed_utils import tensor_already_casted_to_fp8 from torchao.float8.float8_scaling_utils import ( - NoopFwToFloat8E5M2BwDelayed, - NoopFwToFloat8E5M2BwDynamic, - NoopFwToFloat8E5M2BwStatic, + NoopFwToFloat8BwDelayed, + NoopFwToFloat8BwDynamic, + NoopFwToFloat8BwStatic, _maybe_initialize_amaxes_scales_for_float8_cast, get_maybe_axiswise_dim, hp_tensor_to_float8_delayed, @@ -32,8 +32,6 @@ hp_tensor_and_scale_to_float8, ) from torchao.float8.float8_utils import ( - e4m3_dtype, - e5m2_dtype, tensor_to_amax, tensor_to_scale, ) @@ -136,7 +134,7 @@ def forward( else: input_maybe_fp8 = hp_tensor_to_float8_dynamic( input_hp, - e4m3_dtype, + c.cast_config_input.target_dtype, linear_mm_config, gemm_input_role=GemmInputRole.INPUT, scaling_granularity=c.cast_config_input.scaling_granularity, @@ -150,7 +148,7 @@ def forward( else: weight_maybe_fp8_t = hp_tensor_to_float8_dynamic( weight_hp_t, - e4m3_dtype, + c.cast_config_weight.target_dtype, linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, scaling_granularity=c.cast_config_weight.scaling_granularity, @@ -186,7 +184,7 @@ def backward(ctx, grad_output): else: grad_output_reshaped_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic( grad_output_reshaped, - e5m2_dtype, + c.cast_config_grad_output.target_dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.GRAD_OUTPUT, scaling_granularity=c.cast_config_grad_output.scaling_granularity, @@ -204,7 +202,7 @@ def backward(ctx, grad_output): # the entire tensor. weight_t_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic( weight_hp_t, - e4m3_dtype, + c.cast_config_weight_for_grad_input.target_dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, scaling_granularity=c.cast_config_weight_for_grad_input.scaling_granularity, @@ -236,7 +234,7 @@ def backward(ctx, grad_output): else: grad_output_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic( grad_output_reshaped, - e5m2_dtype, + c.cast_config_grad_output_for_grad_weight.target_dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.GRAD_OUTPUT, scaling_granularity=c.cast_config_grad_output_for_grad_weight.scaling_granularity, @@ -250,7 +248,7 @@ def backward(ctx, grad_output): else: input_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic( input_hp_reshaped, - e4m3_dtype, + c.cast_config_input_for_grad_weight.target_dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, scaling_granularity=c.cast_config_input_for_grad_weight.scaling_granularity, @@ -347,11 +345,11 @@ def create_buffers(self): # Default values for history buffers, see above TODO history_len = self.config.delayed_scaling_config.history_len device = self.weight.device - # TODO(future PR): dtype values below don't have the other float8 - # flavors, fix it - default_input = torch.finfo(torch.float8_e4m3fn).max - default_weight = torch.finfo(torch.float8_e4m3fn).max - default_grad_output = torch.finfo(torch.float8_e5m2).max + default_input = torch.finfo(self.config.cast_config_input.target_dtype).max + default_weight = torch.finfo(self.config.cast_config_weight.target_dtype).max + default_grad_output = torch.finfo( + self.config.cast_config_grad_output.target_dtype + ).max # Note: for now, create all the buffers if any are needed, to postpone # the work to make the scale and amax syncing and history calculation @@ -438,14 +436,14 @@ def cast_input_to_float8( self.fp8_amax_history_input, self.fp8_scale_input, scale_fn_name, - e4m3_dtype, + self.config.cast_config_input.target_dtype, is_amax_initialized, reduce_amax=True, ) input_fp8 = hp_tensor_to_float8_delayed( input, self.fp8_scale_input, - e4m3_dtype, + self.config.cast_config_input.target_dtype, self.fp8_amax_input, linear_mm_config=self.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, @@ -453,14 +451,17 @@ def cast_input_to_float8( elif self.scaling_type_input is ScalingType.DYNAMIC: input_fp8 = hp_tensor_to_float8_dynamic( input, - e4m3_dtype, + self.config.cast_config_input.target_dtype, self.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, ) else: assert self.scaling_type_input is ScalingType.STATIC input_fp8 = hp_tensor_to_float8_static( - input, self.fp8_static_scale_input, e4m3_dtype, self.linear_mm_config + input, + self.fp8_static_scale_input, + self.config.cast_config_input.target_dtype, + self.linear_mm_config, ) return input_fp8 @@ -476,14 +477,14 @@ def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]: self.fp8_amax_history_weight, self.fp8_scale_weight, scale_fn_name, - e4m3_dtype, + self.config.cast_config_weight.target_dtype, self.is_amax_initialized, reduce_amax=True, ) self.fp8_amax_weight.fill_(tensor_to_amax(weight)) return self.fp8_scale_weight elif self.scaling_type_weight is ScalingType.DYNAMIC: - return tensor_to_scale(weight, e4m3_dtype) + return tensor_to_scale(weight, self.config.cast_config_weight.target_dtype) else: assert self.scaling_type_weight is ScalingType.STATIC return self.fp8_static_scale_weight @@ -499,7 +500,7 @@ def cast_weight_to_float8_t( weight_fp8 = hp_tensor_and_scale_to_float8( weight, weight_scale, - e4m3_dtype, + self.config.cast_config_weight.target_dtype, self.linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, ) @@ -514,7 +515,7 @@ def cast_weight_to_original_t(self, weight: torch.Tensor): def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor: if self.scaling_type_grad_output is ScalingType.DELAYED: scale_fn_name = self.config.delayed_scaling_config.scale_fn_name - output = NoopFwToFloat8E5M2BwDelayed.apply( + output = NoopFwToFloat8BwDelayed.apply( output, self.fp8_amax_grad_output, self.fp8_amax_history_grad_output, @@ -522,15 +523,21 @@ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor: scale_fn_name, self.is_amax_initialized, self.linear_mm_config, + self.config.cast_config_grad_output.target_dtype, ) elif self.scaling_type_grad_output is ScalingType.DYNAMIC: - output = NoopFwToFloat8E5M2BwDynamic.apply(output, self.linear_mm_config) + output = NoopFwToFloat8BwDynamic.apply( + output, + self.linear_mm_config, + self.config.cast_config_grad_output.target_dtype, + ) else: assert self.scaling_type_grad_output is ScalingType.STATIC - output = NoopFwToFloat8E5M2BwStatic.apply( + output = NoopFwToFloat8BwStatic.apply( output, self.fp8_static_scale_grad_output, self.linear_mm_config, + self.config.cast_config_grad_output.target_dtype, ) return output @@ -547,19 +554,16 @@ def float8_post_forward(self): return def forward_fp8_matmul(self, input: torch.Tensor) -> torch.Tensor: - has_any_axiswise_scaling = ( - self.config.cast_config_input.scaling_granularity - is ScalingGranularity.AXISWISE - or self.config.cast_config_weight.scaling_granularity - is ScalingGranularity.AXISWISE - or self.config.cast_config_grad_output.scaling_granularity - is ScalingGranularity.AXISWISE - or self.config.cast_config_input_for_grad_weight.scaling_granularity - is ScalingGranularity.AXISWISE - or self.config.cast_config_weight_for_grad_input.scaling_granularity - is ScalingGranularity.AXISWISE - or self.config.cast_config_grad_output_for_grad_weight.scaling_granularity - is ScalingGranularity.AXISWISE + has_any_axiswise_scaling = any( + cc.scaling_granularity is ScalingGranularity.AXISWISE + for cc in [ + self.config.cast_config_input, + self.config.cast_config_weight, + self.config.cast_config_grad_output, + self.config.cast_config_input_for_grad_weight, + self.config.cast_config_weight_for_grad_input, + self.config.cast_config_grad_output_for_grad_weight, + ] ) if not has_any_axiswise_scaling: @@ -682,6 +686,7 @@ def from_float( WeightWithDynamicFloat8CastTensor( new_mod.weight, new_mod.linear_mm_config, + new_mod.config.cast_config_weight.target_dtype, ) ) elif config.cast_config_weight.scaling_type is ScalingType.DELAYED: @@ -692,6 +697,7 @@ def from_float( new_mod.fp8_amax_history_weight, new_mod.fp8_scale_weight, new_mod.linear_mm_config, + new_mod.config.cast_config_weight.target_dtype, new_mod.is_amax_initialized, ) ) @@ -702,6 +708,7 @@ def from_float( new_mod.weight, new_mod.fp8_static_scale_weight, new_mod.linear_mm_config, + new_mod.config.cast_config_weight.target_dtype, ) ) diff --git a/torchao/float8/float8_linear_utils.py b/torchao/float8/float8_linear_utils.py index c4fc88eb37..64d2f7bc63 100644 --- a/torchao/float8/float8_linear_utils.py +++ b/torchao/float8/float8_linear_utils.py @@ -15,8 +15,6 @@ from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_utils import ( amax_history_to_scale_stack, - e4m3_dtype, - e5m2_dtype, ) log = logging.getLogger(__name__) @@ -227,6 +225,9 @@ def inner_func(): fp8_weight_amax_history_stack = [None] * len(fp8_layers) fp8_grad_output_amax_history_stack = [None] * len(fp8_layers) + input_dtypes = set() + weight_dtypes = set() + grad_output_dtypes = set() scale_fn_recipes = set() for idx, child in enumerate(fp8_layers): @@ -238,8 +239,15 @@ def inner_func(): fp8_weight_amax_history_stack[idx] = child.fp8_amax_history_weight fp8_grad_output_amax_history_stack[idx] = child.fp8_amax_history_grad_output + input_dtypes.add(child.config.cast_config_input.target_dtype) + weight_dtypes.add(child.config.cast_config_weight.target_dtype) + grad_output_dtypes.add(child.config.cast_config_grad_output.target_dtype) scale_fn_recipes.add(child.config.delayed_scaling_config.scale_fn_name) + (input_dtype,) = input_dtypes + (weight_dtype,) = weight_dtypes + (grad_output_dtype,) = grad_output_dtypes + if len(scale_fn_recipes) != 1: raise ValueError( f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}" @@ -297,13 +305,13 @@ def inner_func(): # Calculate the new scales from the updated history stacks new_input_scales = amax_history_to_scale_stack( - fp8_input_amax_history_stack, e4m3_dtype, scale_fn_recipe + fp8_input_amax_history_stack, input_dtype, scale_fn_recipe ) new_weight_scales = amax_history_to_scale_stack( - fp8_weight_amax_history_stack, e4m3_dtype, scale_fn_recipe + fp8_weight_amax_history_stack, weight_dtype, scale_fn_recipe ) new_grad_output_scales = amax_history_to_scale_stack( - fp8_grad_output_amax_history_stack, e5m2_dtype, scale_fn_recipe + fp8_grad_output_amax_history_stack, grad_output_dtype, scale_fn_recipe ) # Iterate through the layers and update the scales diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py index c8fe61c8a4..3a9841e625 100644 --- a/torchao/float8/float8_scaling_utils.py +++ b/torchao/float8/float8_scaling_utils.py @@ -22,7 +22,6 @@ ) from torchao.float8.float8_utils import ( amax_history_to_scale, - e5m2_dtype, tensor_to_amax, tensor_to_scale, ) @@ -182,7 +181,7 @@ def _maybe_initialize_amaxes_scales_for_float8_cast( @torch._dynamo.allow_in_graph -class NoopFwToFloat8E5M2BwDelayed(torch.autograd.Function): +class NoopFwToFloat8BwDelayed(torch.autograd.Function): """ Forward: no-op Backward: convert to float8_e5m2 with delayed scaling, initialize if needed @@ -198,6 +197,7 @@ def forward( scale_fn_name, is_amax_initialized, linear_mm_config: LinearMMConfig, + target_dtype: torch.dtype, ): ctx.save_for_backward( fp8_amax_grad_output, fp8_amax_history_grad_output, fp8_scale_grad_output @@ -205,6 +205,7 @@ def forward( ctx.scale_fn_name = scale_fn_name ctx.is_amax_initialized = is_amax_initialized ctx.linear_mm_config = linear_mm_config + ctx.target_dtype = target_dtype return tensor @staticmethod @@ -223,7 +224,7 @@ def backward(ctx, go): fp8_amax_history_grad_output, fp8_scale_grad_output, scale_fn_name, - e5m2_dtype, + ctx.target_dtype, is_amax_initialized, reduce_amax=True, ) @@ -233,16 +234,16 @@ def backward(ctx, go): res = hp_tensor_and_scale_to_float8( go, fp8_scale_grad_output, - e5m2_dtype, + ctx.target_dtype, ctx.linear_mm_config, GemmInputRole.GRAD_OUTPUT, ) - empty_grads = None, None, None, None, None, None + empty_grads = None, None, None, None, None, None, None return res, *empty_grads @torch._dynamo.allow_in_graph -class NoopFwToFloat8E5M2BwDynamic(torch.autograd.Function): +class NoopFwToFloat8BwDynamic(torch.autograd.Function): """ Forward: no-op Backward: convert to float8_e5m2 with dynamic scaling @@ -253,27 +254,29 @@ def forward( ctx, tensor, linear_mm_config: LinearMMConfig, + target_dtype: torch.dtype, ): ctx.linear_mm_config = linear_mm_config + ctx.target_dtype = target_dtype return tensor @staticmethod def backward(ctx, gradY): if tensor_already_casted_to_fp8(gradY): - return gradY, None - gradY_scale = tensor_to_scale(gradY, e5m2_dtype) + return gradY, None, None + gradY_scale = tensor_to_scale(gradY, ctx.target_dtype) fp8_tensor = hp_tensor_and_scale_to_float8( gradY, gradY_scale, - e5m2_dtype, + ctx.target_dtype, ctx.linear_mm_config, GemmInputRole.GRAD_OUTPUT, ) - return fp8_tensor, None + return fp8_tensor, None, None @torch._dynamo.allow_in_graph -class NoopFwToFloat8E5M2BwStatic(torch.autograd.Function): +class NoopFwToFloat8BwStatic(torch.autograd.Function): """ Forward: no-op Backward: convert to float8_e5m2 with static scaling @@ -285,21 +288,23 @@ def forward( tensor, scale, linear_mm_config: LinearMMConfig, + target_dtype: torch.dtype, ): ctx.save_for_backward(scale) ctx.linear_mm_config = linear_mm_config + ctx.target_dtype = target_dtype return tensor @staticmethod def backward(ctx, gradY): if tensor_already_casted_to_fp8(gradY): - return gradY, None + return gradY, None, None, None (gradY_scale,) = ctx.saved_tensors fp8_tensor = hp_tensor_and_scale_to_float8( gradY, gradY_scale, - e5m2_dtype, + ctx.target_dtype, ctx.linear_mm_config, GemmInputRole.GRAD_OUTPUT, ) - return fp8_tensor, None, None + return fp8_tensor, None, None, None diff --git a/torchao/float8/float8_tensor.py b/torchao/float8/float8_tensor.py index 20f40330a8..fe2498e2b0 100644 --- a/torchao/float8/float8_tensor.py +++ b/torchao/float8/float8_tensor.py @@ -10,7 +10,6 @@ from torch.distributed._tensor import DTensor from torchao.float8.float8_utils import ( - e4m3_dtype, to_fp8_saturated, ) @@ -133,7 +132,7 @@ def forward( ctx, tensor: torch.Tensor, scale: torch.Tensor, - float8_dtype=e4m3_dtype, + float8_dtype: torch.dtype, linear_mm_config: Optional[LinearMMConfig] = None, gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, axiswise_dim: Optional[int] = None, @@ -213,7 +212,7 @@ def backward(ctx, g): def hp_tensor_and_scale_to_float8( hp_tensor: torch.Tensor, s: torch.Tensor, - float8_dtype=e4m3_dtype, + float8_dtype: torch.dtype, linear_mm_config: Optional[LinearMMConfig] = None, gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, axiswise_dim: Optional[int] = None, diff --git a/torchao/float8/float8_tensor_parallel.py b/torchao/float8/float8_tensor_parallel.py index a3fc4ce7e5..37cb67c7e7 100644 --- a/torchao/float8/float8_tensor_parallel.py +++ b/torchao/float8/float8_tensor_parallel.py @@ -8,13 +8,12 @@ RowwiseParallel, ) -from torchao.float8.config import ScalingType +from torchao.float8.config import ScalingType, e4m3_dtype from torchao.float8.float8_scaling_utils import ( - NoopFwToFloat8E5M2BwDynamic, + NoopFwToFloat8BwDynamic, hp_tensor_to_float8_dynamic, ) from torchao.float8.float8_tensor import GemmInputRole -from torchao.float8.float8_utils import e4m3_dtype # subclass the ColwiseParallel and RowwiseParallel classes # to add the float8 support @@ -49,7 +48,7 @@ def _prepare_input_fn( input_tensor = hp_tensor_to_float8_dynamic( input_tensor, - e4m3_dtype, + mod.config.cast_config_input.target_dtype, mod.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, ) # DTensor(Float8Tensor) @@ -70,7 +69,11 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me ) # DTensor(torch.Tensor) # fwd noop bwd cast to DTensor(Float8Tensor) - outputs = NoopFwToFloat8E5M2BwDynamic.apply(outputs, mod.linear_mm_config) + outputs = NoopFwToFloat8BwDynamic.apply( + outputs, + mod.linear_mm_config, + mod.config.cast_config_grad_output.target_dtype, + ) # back to local tensor return outputs.to_local() if use_local_output else outputs @@ -103,7 +106,7 @@ def _prepare_input_fn( input_tensor = hp_tensor_to_float8_dynamic( input_tensor, - e4m3_dtype, + mod.config.cast_config_input.target_dtype, mod.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, ) # DTensor(Float8Tensor) @@ -123,7 +126,11 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me outputs = outputs.redistribute(placements=output_layouts, async_op=True) # fwd noop bwd cast to DTensor(Float8Tensor) - outputs = NoopFwToFloat8E5M2BwDynamic.apply(outputs, mod.linear_mm_config) + outputs = NoopFwToFloat8BwDynamic.apply( + outputs, + mod.linear_mm_config, + mod.config.cast_config_grad_output.target_dtype, + ) # back to local tensor if use_local_output is True return outputs.to_local() if use_local_output else outputs diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 29319f3814..90927659f8 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -10,7 +10,7 @@ import torch.distributed as dist from torch.distributed._functional_collectives import AsyncCollectiveTensor, all_reduce -from torchao.float8.config import Float8TypeConfig, ScalingGranularity +from torchao.float8.config import ScalingGranularity # Helpful visualizer for debugging (only supports fp32): # https://www.h-schmidt.net/FloatConverter/IEEE754.html @@ -28,12 +28,6 @@ } -# User defined type for using the individual F8 type based on config -type_config = Float8TypeConfig() -e4m3_dtype = type_config.e4m3_dtype -e5m2_dtype = type_config.e5m2_dtype - - @torch.no_grad() def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype): """Converts the amax value of a tensor to the fp8 scale. @@ -173,7 +167,7 @@ def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: def fp8_tensor_statistics( - tensor: torch.Tensor, float8_dtype=e4m3_dtype + tensor: torch.Tensor, float8_dtype: torch.dtype ) -> Tuple[int, ...]: """Calculate FP8 tensor stats diff --git a/torchao/float8/fsdp_utils.py b/torchao/float8/fsdp_utils.py index 8c60995a86..9fde9922af 100644 --- a/torchao/float8/fsdp_utils.py +++ b/torchao/float8/fsdp_utils.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import math -from typing import Any, List, Optional, Tuple +from typing import Any, List, Optional, Set, Tuple import torch import torch.nn as nn @@ -22,7 +22,7 @@ LinearMMConfig, hp_tensor_and_scale_to_float8, ) -from torchao.float8.float8_utils import EPS, e4m3_dtype +from torchao.float8.float8_utils import EPS @torch.no_grad() @@ -54,9 +54,14 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None: and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor) ] weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears] + target_dtypes: Set[torch.dtype] = { + float8_linear.config.cast_config_weight.target_dtype + for float8_linear in float8_linears + } if not weights: return + (target_dtype,) = target_dtypes # inf-norm is equivalent to max(abs(w)) max_weights = torch._foreach_norm(weights, ord=math.inf) # Partial @@ -69,7 +74,7 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None: # upcast to float64 to ensure same numeric between compile and eager origin_dtype = amax_tensor.dtype amax_tensor = amax_tensor.to(torch.float64) - scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor # Replicate + scale_tensor = torch.finfo(target_dtype).max / amax_tensor # Replicate if origin_dtype is torch.float16: scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max) local_scale_tensor = scale_tensor.to_local().to(torch.float32) @@ -134,6 +139,7 @@ def __new__( cls, tensor: torch.Tensor, linear_mm_config: LinearMMConfig, + dtype: torch.dtype, precomputed_scale: Optional[torch.Tensor] = None, ): return torch.Tensor._make_wrapper_subclass( @@ -153,10 +159,12 @@ def __init__( self, tensor: torch.Tensor, linear_mm_config: LinearMMConfig, + dtype: torch.dtype, precomputed_scale: Optional[torch.Tensor] = None, ): self._tensor = tensor self._linear_mm_config = linear_mm_config + self._dtype = dtype # for dynamic scaling # `precompute_float8_dynamic_scale_for_fsdp` calculates scales # for all float8 parameters after optimizer step @@ -166,9 +174,10 @@ def __init__( def __torch_dispatch__(cls, func, types, args, kwargs=None): if func == torch.ops.aten.detach.default: return WeightWithDynamicFloat8CastTensor( - args[0]._tensor, args[0]._linear_mm_config + args[0]._tensor, args[0]._linear_mm_config, args[0]._dtype ) mm_config: Optional[LinearMMConfig] = None + dtype: Optional[torch.dtype] = None def unwrap(t): nonlocal mm_config @@ -176,6 +185,11 @@ def unwrap(t): mm_config = t._linear_mm_config else: assert t._linear_mm_config == mm_config + nonlocal dtype + if dtype is None: + dtype = t._dtype + else: + assert t._dtype == dtype return t._tensor args, kwargs = pytree.tree_map_only( @@ -185,40 +199,42 @@ def unwrap(t): if func not in _ops_to_preserve_subclass: return out return pytree.tree_map_only( - torch.Tensor, lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config), out + torch.Tensor, + lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config, dtype), + out, ) def __tensor_flatten__(self): + tensors = ["_tensor"] if self._precomputed_scale: - return ["_tensor", "_precomputed_scale"], self._linear_mm_config - else: - return ["_tensor"], self._linear_mm_config + tensors.append("_precomputed_scale") + return tensors, {"mm_config": self._linear_mm_config, "dtype": self._dtype} @staticmethod def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): - mm_config = flatten_spec return WeightWithDynamicFloat8CastTensor( inner_tensors["_tensor"], - mm_config, + flatten_spec["mm_config"], + flatten_spec["dtype"], getattr(inner_tensors, "_precomputed_scale", None), ) def __repr__(self): - return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, linear_mm_config={self._linear_mm_config})" + return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, linear_mm_config={self._linear_mm_config}, dtype={self._dtype})" def fsdp_pre_all_gather(self, mesh): if self._precomputed_scale is not None: float8_tensor = hp_tensor_and_scale_to_float8( self._tensor, self._precomputed_scale, - torch.float8_e4m3fn, + self._dtype, self._linear_mm_config, GemmInputRole.WEIGHT, ) else: float8_tensor = hp_tensor_to_float8_dynamic( self._tensor, - e4m3_dtype, + self._dtype, self._linear_mm_config, reduce_amax=True, gemm_input_role=GemmInputRole.WEIGHT, @@ -268,6 +284,7 @@ def __new__( amax_history_buffer: torch.Tensor, scale_buffer: torch.Tensor, linear_mm_config: LinearMMConfig, + dtype: torch.dtype, is_amax_initialized: bool, ): return torch.Tensor._make_wrapper_subclass( @@ -290,6 +307,7 @@ def __init__( amax_history_buffer: torch.Tensor, scale_buffer: torch.Tensor, linear_mm_config: LinearMMConfig, + dtype: torch.dtype, is_amax_initialized: bool, ): self._tensor = tensor @@ -297,6 +315,7 @@ def __init__( self._amax_history_buffer = amax_history_buffer self._scale_buffer = scale_buffer self._linear_mm_config = linear_mm_config + self._dtype = dtype # Note: is_amax_initialized is not a buffer to avoid data dependent # control flow visible to dynamo @@ -312,9 +331,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): args[0]._amax_history_buffer, args[0]._scale_buffer, args[0]._linear_mm_config, + args[0]._dtype, args[0].is_amax_initialized, ) mm_config: Optional[LinearMMConfig] = None + dtype: Optional[torch.dtype] = None amax_buffer: Optional[torch.Tensor] = None amax_history_buffer: Optional[torch.Tensor] = None scale_buffer: Optional[torch.Tensor] = None @@ -326,6 +347,11 @@ def unwrap(t): mm_config = t._linear_mm_config else: assert t._linear_mm_config == mm_config + nonlocal dtype + if dtype is None: + dtype = t._dtype + else: + assert t._dtype == dtype nonlocal amax_buffer if amax_buffer is None: amax_buffer = t._amax_buffer @@ -354,6 +380,7 @@ def unwrap(t): amax_history_buffer, scale_buffer, mm_config, + dtype, is_amax_initialized, ), out, @@ -369,6 +396,7 @@ def __tensor_flatten__(self): ], { "mm_config": self._linear_mm_config, + "dtype": self._dtype, "is_amax_initialized": self.is_amax_initialized, }, ) @@ -381,11 +409,12 @@ def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride): inner_tensors["_amax_history_buffer"], inner_tensors["_scale_buffer"], metadata["mm_config"], + metadata["dtype"], metadata["is_amax_initialized"], ) def __repr__(self): - return f"WeightWithDelayedFloat8CastTensor(tensor={self._tensor}, amax_buffer={self._amax_buffer}, scale_buffer={self._scale_buffer}, mm_config={self._linear_mm_config})" + return f"WeightWithDelayedFloat8CastTensor(tensor={self._tensor}, amax_buffer={self._amax_buffer}, scale_buffer={self._scale_buffer}, mm_config={self._linear_mm_config}, dtype={self._dtype})" def fsdp_pre_all_gather(self, mesh): # initialize if needed @@ -401,7 +430,7 @@ def fsdp_pre_all_gather(self, mesh): self._amax_history_buffer, self._scale_buffer, "max", # TODO(before land): read this from parent - e4m3_dtype, + self._dtype, self.is_amax_initialized, reduce_amax=True, ) @@ -410,7 +439,7 @@ def fsdp_pre_all_gather(self, mesh): float8_tensor = hp_tensor_to_float8_delayed( self._tensor, self._scale_buffer, - e4m3_dtype, + self._dtype, self._amax_buffer, self._linear_mm_config, GemmInputRole.WEIGHT, @@ -447,6 +476,7 @@ def __new__( tensor: torch.Tensor, static_scale: torch.Tensor, linear_mm_config: LinearMMConfig, + dtype: torch.dtype, ): return torch.Tensor._make_wrapper_subclass( cls, @@ -466,19 +496,25 @@ def __init__( tensor: torch.Tensor, static_scale: torch.Tensor, linear_mm_config: LinearMMConfig, + dtype: torch.dtype, ): self._tensor = tensor self._static_scale = static_scale self._linear_mm_config = linear_mm_config + self._dtype = dtype @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): if func == torch.ops.aten.detach.default: return WeightWithStaticFloat8CastTensor( - args[0]._tensor, args[0]._static_scale, args[0]._linear_mm_config + args[0]._tensor, + args[0]._static_scale, + args[0]._linear_mm_config, + args[0]._dtype, ) static_scale: Optional[torch.Tensor] = None mm_config: Optional[LinearMMConfig] = None + dtype: Optional[torch.dtype] = None def unwrap(t): nonlocal static_scale @@ -489,6 +525,11 @@ def unwrap(t): mm_config = t._linear_mm_config else: assert t._linear_mm_config == mm_config + nonlocal dtype + if dtype is None: + dtype = t._dtype + else: + assert t._dtype == dtype return t._tensor args, kwargs = pytree.tree_map_only( @@ -499,30 +540,35 @@ def unwrap(t): return out return pytree.tree_map_only( torch.Tensor, - lambda x: WeightWithStaticFloat8CastTensor(x, static_scale, mm_config), + lambda x: WeightWithStaticFloat8CastTensor( + x, static_scale, mm_config, dtype + ), out, ) def __tensor_flatten__(self): - return ["_tensor", "_static_scale"], self._linear_mm_config + return ["_tensor", "_static_scale"], { + "mm_config": self._linear_mm_config, + "dtype": self._dtype, + } @staticmethod def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): - mm_config = flatten_spec return WeightWithStaticFloat8CastTensor( inner_tensors["_tensor"], inner_tensors["_static_scale"], - mm_config, + flatten_spec["mm_config"], + flatten_spec["dtype"], ) def __repr__(self): - return f"WeightWithStaticFloat8CastTensor(tensor={self._tensor}, static_scale={self._static_scale}, linear_mm_config={self._linear_mm_config})" + return f"WeightWithStaticFloat8CastTensor(tensor={self._tensor}, static_scale={self._static_scale}, linear_mm_config={self._linear_mm_config}, dtype={self.dtype})" def fsdp_pre_all_gather(self, mesh): float8_tensor = hp_tensor_and_scale_to_float8( self._tensor, self._static_scale, - torch.float8_e4m3fn, + self._dtype, self._linear_mm_config, GemmInputRole.WEIGHT, ) From abff563ba515576fc48cd4ac0feb923dd65dc267 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Thu, 5 Dec 2024 10:56:44 -0500 Subject: [PATCH 25/40] Add script to pre-process raw release notes (#1380) Summary: This commit adds a helper script to pre-process the raw release notes produced by github. This script produces a template with the standard sections and pre-sorts all commits into categories based on github labels and keywords in the commit titles. ``` python scripts/clean_release_notes.py raw_release_notes.txt ``` Test Plan: Manual testing on v0.7.0 release notes --- scripts/clean_release_notes.py | 225 +++++++++++++++++++++++++++++++++ 1 file changed, 225 insertions(+) create mode 100644 scripts/clean_release_notes.py diff --git a/scripts/clean_release_notes.py b/scripts/clean_release_notes.py new file mode 100644 index 0000000000..1055f288d0 --- /dev/null +++ b/scripts/clean_release_notes.py @@ -0,0 +1,225 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# ============================================================= +# This script automatically cleans up the raw release notes +# generated by github by doing an initial pass to sort the +# commits. The output still requires manual reviewing. +# +# This script uses PyGithub. If you don't have it yet, please +# install it using: +# +# pip install PyGithub +# +# We expect the following format for the input release notes: +# +# ## What's Changed +# * commit1_title by @userX in https://github.com/pytorch/ao/pull/123 +# * commit2_title by @userY in https://github.com/pytorch/ao/pull/234 +# * commit3_title by @userZ in https://github.com/pytorch/ao/pull/345 +# +# ## New Contributors +# * @userX made their first contribution in https://github.com/pytorch/ao/pull/123 +# * @userY made their first contribution in https://github.com/pytorch/ao/pull/234 +# +# Example output: +# +# ## Highlights +# +# We are excited to announce the X.Y.Z release of torchao! This release adds support for A, B, C, D! +# +# ### Highlight Feature 1 +# +# ### Highlight Feature 2 +# +# ## BC-Breaking +# +# ## Deprecation +# +# ## New Features +# * commit1_title (https://github.com/pytorch/ao/pull/123) +# +# ## Improvement +# * commit2_title (https://github.com/pytorch/ao/pull/234) +# +# ## Bug Fixes +# * commit3_title (https://github.com/pytorch/ao/pull/345) +# +# ## Performance +# +# ## Documentation +# +# ## Developers +# +# ## New Contributors +# * @userX made their first contribution in https://github.com/pytorch/ao/pull/123 +# * @userY made their first contribution in https://github.com/pytorch/ao/pull/234 +# +# ============================================================= + + +import os +import re +import sys +from typing import Dict, List, Optional + +try: + from github import Github +except ImportError as err: + raise ValueError("PyGithub not installed, please run 'pip install PyGithub'") from err + +if len(sys.argv) != 2: + print("Usage: python clean_release_notes.py [raw_release_notes.txt]") + sys.exit(1) + +input_file = sys.argv[1] +output_file = input_file + ".out" +VERBOSE = os.getenv("VERBOSE", "true").lower() == "true" +GITHUB_LABEL_TO_CATEGORY = { + "topic: bc-breaking": "BC Breaking", + "topic: deprecation": "Deprecation", + "topic: new feature": "New Features", + "topic: improvement": "Improvement", + "topic: bug fix": "Bug Fixes", + "topic: performance": "Performance", + "topic: documentation": "Documentation", + "topic: for developer": "Developers", +} + + +def clean_release_notes(): + """ + Main entry point for this script. + + This function pre-processes the raw release notes and produces a template + with all the standard sections and pre-sorts the commits into different + categories based on github labels and commit title keywords. + """ + + # Write the header section + with open(output_file, "w") as out_f: + out_f.write("## Highlights\n\n") + out_f.write("We are excited to announce the X.Y.Z release of torchao! This release adds support for A, B, C, D!\n\n") + out_f.write("### Highlight Feature 1\n\n") + out_f.write("### Highlight Feature 2\n\n") + + # Sort commits into different categories and write them to output file + # For lines after the commits, just copy them to the output file as is + commit_lines = [] + commit_start = False + commits_by_category = { + "BC Breaking": [], + "Deprecations": [], + "New Features": [], + "Improvement": [], + "Bug Fixes": [], + "Performance": [], + "Documentation": [], + "Developers": [], + } + with open(input_file, "r") as in_f, open(output_file, "a") as out_f: + for line in in_f.readlines(): + if line.startswith("## What's Changed"): + commit_start = True + elif commit_start and line.startswith("*"): + commit_lines.append(line) + elif commit_start: + # End of commits, fetch PR labels based on commits collected so far + commit_start = False + pr_number_to_label = fetch_pr_labels(commit_lines) + # Assign each commit to a category + for commit_line in commit_lines: + category = get_commit_category(commit_line, pr_number_to_label) + if category is not None: + commits_by_category[category].append(commit_line) + # Write all commits to the output file by category + for category, commits in commits_by_category.items(): + out_f.write("## %s\n\n" % category) + for commit_line in commits: + out_f.write(format_commit(commit_line)) + out_f.write("\n") + else: + # Not a commit, just copy to the output file + out_f.write(line) + print("Wrote to %s." % output_file) + + +def parse_pr_number(commit_line: str) -> int: + """ + Helper function to parse PR number from commit line. + """ + return int(re.match(".*pytorch/ao/pull/(.*)", commit_line).groups()[0]) + + +def fetch_pr_labels(commit_lines: List[str]) -> Dict[int, str]: + """ + Fetch the relevant github labels starting with "topic: " from all PRs. + If such a label exists for a given PR, store the first one. + """ + pr_number_to_label = {} + all_pr_numbers = [parse_pr_number(line) for line in commit_lines] + smallest_pr_number = min(all_pr_numbers) + repo = Github().get_repo("pytorch/ao") + + # This call fetches 30 PRs at a time in descending order of when the PR was created + pulls = repo.get_pulls(state="closed") + for pr in pulls: + if pr.number < smallest_pr_number: + break + labels = [l.name for l in pr.labels if l.name.startswith("topic: ")] + if len(labels) > 0: + if VERBOSE: + print("Found label for PR %s: '%s'" % (pr.number, labels[0])) + pr_number_to_label[pr.number] = labels[0] + return pr_number_to_label + + +def get_commit_category(commit_line: str, pr_number_to_label: Dict[int, str]) -> Optional[str]: + """ + Assign the commit to a category based on: + (1) The github label if it exists + (2) Keywords in the PR title + + If the commit is not meant to be user facing, remove None. + Otherwise, return "Improvement" by default. + """ + pr_number = parse_pr_number(commit_line) + if pr_number in pr_number_to_label: + label = pr_number_to_label[pr_number] + if label == "topic: not user facing": + return None + if label in GITHUB_LABEL_TO_CATEGORY: + return GITHUB_LABEL_TO_CATEGORY[label] + elif any(x in commit_line.lower() for x in ["revert", "version.txt"]): + return None + elif any(x in commit_line.lower() for x in ["doc", "readme", "tutorial", "typo", "example", "spelling"]): + return "Documentation" + elif any(x in commit_line.lower() for x in ["test", "lint", " ci", "nightl"]): + return "Developers" + elif " fix" in commit_line.lower(): + return "Bug Fixes" + elif " add" in commit_line.lower(): + return "New Features" + else: + return "Improvement" + + +def format_commit(commit_line: str) -> str: + """ + Format the commit line as follows: + Before: * commit title by @userX in https://github.com/pytorch/ao/pull/123 + After: * Commit title (https://github.com/pytorch/ao/pull/123) + """ + # Remove author, put PR link in parentheses + commit_line = re.sub(" by @.* in (.*)", " (\g<1>)", commit_line) + # Capitalize first letter + commit_line = commit_line.lstrip("* ") + commit_line = "* " + commit_line[0].upper() + commit_line[1:] + return commit_line + + +if __name__ == "__main__": + clean_release_notes() From 8a805d08898e5c961fb9b4f6ab61ffd5d5bdbca5 Mon Sep 17 00:00:00 2001 From: cpuhrsch Date: Thu, 5 Dec 2024 17:42:43 -0800 Subject: [PATCH 26/40] SAM2: video_profile.py (#1386) --- examples/sam2_amg_server/video_profile.py | 427 ++++++++++++++++++ torchao/_models/sam2/build_sam.py | 2 +- .../sam2/modeling/backbones/image_encoder.py | 3 +- .../sam2/modeling/position_encoding.py | 12 +- .../_models/sam2/modeling/sam/transformer.py | 5 +- torchao/_models/sam2/modeling/sam2_base.py | 17 +- torchao/_models/sam2/sam2_video_predictor.py | 5 +- 7 files changed, 451 insertions(+), 20 deletions(-) create mode 100644 examples/sam2_amg_server/video_profile.py diff --git a/examples/sam2_amg_server/video_profile.py b/examples/sam2_amg_server/video_profile.py new file mode 100644 index 0000000000..400e879a0e --- /dev/null +++ b/examples/sam2_amg_server/video_profile.py @@ -0,0 +1,427 @@ +import argparse +import time +import os +from datetime import datetime + +import numpy as np +import torch +from PIL import Image, ImageDraw +from torchao._models.sam2.build_sam import build_sam2_video_predictor +from server import MODEL_TYPES_TO_MODEL +from server import model_type_to_paths +from pathlib import Path + +from torch._inductor import config as inductorconfig +inductorconfig.triton.unique_kernel_names = True +inductorconfig.coordinate_descent_tuning = True +inductorconfig.coordinate_descent_check_all_directions = True + +from torch.nn.attention import SDPBackend, sdpa_kernel + +# timer.py +import time +from collections import defaultdict + + +class CodeTimer: + def __init__(self): + self.start_times = {} + self.elapsed_times = defaultdict(list) + self.enabled = False + + def tic(self, section_name): + self.start_times[section_name] = time.time() + + def toc(self, section_name): + if section_name in self.start_times: + elapsed_time = time.time() - self.start_times[section_name] + self.elapsed_times[section_name].append(elapsed_time) + del self.start_times[section_name] + + def get_average_time(self, section_name, warmup: int = 1): + times = self.elapsed_times.get(section_name, []) + times = times[warmup:] + return sum(times) / len(times) if times else 0.0 + + def reset(self): + self.start_times.clear() + self.elapsed_times.clear() + + def print_all_timings(self, warmup: int = 5): + if not self.elapsed_times: + print("No timings recorded.") + return + print("Average timings for all sections:") + for section_name in self.elapsed_times: + average_time = self.get_average_time(section_name, warmup) + print(f"{section_name}, {average_time*1000.0:.6f}") + + +global_timer = CodeTimer() + + +def max_memory_allocated(): + max_memory_allocated_bytes = torch.cuda.max_memory_allocated() + _, total_memory = torch.cuda.mem_get_info() + max_memory_allocated_percentage = int(100 * (max_memory_allocated_bytes / total_memory)) + max_memory_allocated_bytes = max_memory_allocated_bytes >> 20 + print(f"max_memory_allocated_bytes: {max_memory_allocated_bytes}MiB or {max_memory_allocated_percentage}%") + + +def synthesize_video_data( + out_dir: str, + radius: int, + seed: int, + speed: int, + width: int, + height: int, + n_frames: int, + x: int, + y: int, + synthesize_overwrite: bool, +): + circle_color = (255, 0, 0) # red + + os.makedirs(out_dir, exist_ok=True) + + np.random.seed(seed) + # Initial position and velocity + x = np.random.randint(radius, width - radius) + y = np.random.randint(radius, height - radius) + vx = np.random.choice([-1, 1]) * speed + vy = np.random.choice([-1, 1]) * speed + + # TODO: If these frames exist, they will not be deleted in subsequent runs with less frames. + print(f"Generate {n_frames} frames") + if not synthesize_overwrite and len(os.listdir(out_dir)) > 0: + raise ValueError("Expected folder to be empty unless --synthesize-overwrite is specified.") + # Generate 100 frames + for i in range(n_frames): + # Create a new image with a black background + img = Image.new("RGB", (width, height), (0, 0, 0)) + draw = ImageDraw.Draw(img) + # Draw the circle at its current position + draw.ellipse( + [(x - radius, y - radius), (x + radius, y + radius)], fill=circle_color + ) + # Save the image as a JPEG file + filename = f"{i:03d}.jpg" + img.save(os.path.join(out_dir, filename)) + # Update the circle's position for the next frame + x += vx + y += vy + # Bounce off the edges + if x - radius < 0 or x + radius > width: + vx *= -1 + if y - radius < 0 or y + radius > height: + vy *= -1 + + +def profiler_runner(path, fn, *args, **kwargs): + if path is None: + path = os.path.join( + os.path.expanduser("~/traces"), + f'{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}.json.gz', + ) + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=True, + ) as prof: + result = fn(*args, **kwargs) + prof.export_chrome_trace(path) + print(f"Exported trace to {path}") + return result + + +def main_loop(predictor, inference_state, time_profile=True, accumulate_result=False, count_result=False): + results = [] + num_output_frames = 0 + with sdpa_kernel([SDPBackend.CUDNN_ATTENTION, SDPBackend.FLASH_ATTENTION]): + with torch.autograd.profiler.record_function("main_loop"): + for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video( + inference_state + ): + if accumulate_result: + results.append(out_mask_logits) + if count_result: + num_output_frames += 1 + assert not (accumulate_result and count_result) + if accumulate_result: + return torch.cat(results) + if count_result: + return num_output_frames + + +def run_test( + checkpoint_path: str, + model_type: str, + profile: bool, + video_dir: str, + radius: int, + seed: int, + speed: int, + width: int, + height: int, + n_frames: int, + use_compile: bool, + frame_batch_size: int, + synthesize: bool, + synthesize_overwrite: bool, + store_output: str, + compare_output: str, + print_all_timings: bool, +): + np.random.seed(seed) + start_x = np.random.randint(radius, width - radius) + start_y = np.random.randint(radius, height - radius) + if synthesize: + synthesize_video_data( + out_dir=video_dir, + radius=radius, + seed=seed, + speed=speed, + width=width, + height=height, + n_frames=n_frames, + x=start_x, + y=start_y, + synthesize_overwrite=synthesize_overwrite, + ) + + # use bfloat16 for the entire notebook + torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type) + + device = "cuda:0" + # hydra_overrides_extra = ["++model.compile_image_encoder=true"] + predictor = build_sam2_video_predictor( + model_cfg, + sam2_checkpoint, + device=device, + # hydra_overrides_extra=hydra_overrides_extra, + ) + predictor._frame_batch_size = frame_batch_size + + inference_state = predictor.init_state( + video_path=video_dir, async_loading_frames=False + ) + _, out_obj_ids, out_mask_logits = predictor.add_new_points( + inference_state=inference_state, + frame_idx=0, + obj_id=1, + points=np.array([[start_x, start_y]], dtype=np.float32), + labels=np.array([1], dtype=np.int32), + ) + + if use_compile: + print("Using torch.compile") + predictor.image_encoder.trunk.forward = torch.compile( + predictor.image_encoder.trunk.forward, + # mode="max-autotune-no-cudagraphs", + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) + + predictor.sam_prompt_encoder.forward = torch.compile( + predictor.sam_prompt_encoder.forward, + # mode="max-autotune-no-cudagraphs", + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) + + predictor.sam_mask_decoder.transformer = torch.compile( + predictor.sam_mask_decoder.transformer, + mode="max-autotune", + # mode="max-autotune-no-cudagraphs", + fullgraph=True, + dynamic=False, + ) + + predictor._forward_sam_heads = torch.compile( + predictor._forward_sam_heads, + mode="max-autotune", + # mode="max-autotune-no-cudagraphs", + fullgraph=True, + dynamic=False, + ) + + predictor.memory_attention = torch.compile( + predictor.memory_attention, + # mode="max-autotune", + # mode="max-autotune-no-cudagraphs", + fullgraph=True, + dynamic=True, + ) + + predictor.memory_encoder.forward = torch.compile( + predictor.memory_encoder.forward, + mode="max-autotune", + # mode="max-autotune-no-cudagraphs", + fullgraph=True, + dynamic=False, + ) + + print("\nWarm-up round and gather outputs.") + global_timer.reset() + result = main_loop(predictor=predictor, inference_state=inference_state, accumulate_result=True) + if store_output: + print(f"Writing results to {store_output}") + torch.save(result, store_output) + if compare_output: + print(f"Comparing to results from {compare_output}") + ref_result = torch.load(compare_output) + torch.testing.assert_close(result, ref_result) + print("Passed comparison!") + if print_all_timings: + global_timer.print_all_timings() + + global_timer.reset() + print("\nProfile round.") + if profile is None: + main_loop(predictor=predictor, inference_state=inference_state) + else: + profiler_runner( + profile, + main_loop, + predictor=predictor, + inference_state=inference_state, + ) + if print_all_timings: + global_timer.print_all_timings() + + print("\nFinal timing and memory usage round.") + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + global_timer.reset() + t0 = time.time() + num_output_frames = main_loop(predictor=predictor, inference_state=inference_state, count_result=True) + t = time.time() - t0 + print(f"main_loop took {t}s for {num_output_frames} frames at {num_output_frames / t}fps") + max_memory_allocated() + if print_all_timings: + global_timer.print_all_timings() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "checkpoint_path", + type=str, + help="Path to folder containing checkpoints from https://github.com/facebookresearch/sam2?tab=readme-ov-file#download-checkpoints", + ) + parser.add_argument( + "model_type", + type=str, + help=f"Choose one of {list(MODEL_TYPES_TO_MODEL.keys())}", + ) + parser.add_argument( + "--video_dir", + type=str, + default="/tmp/segment-anything-2/synth_video", + help="Directory to store the synthetic video", + ) + parser.add_argument( + "--profile", + type=str, + dest="profile", + help="If specified stores profile at given path.", + ) + parser.add_argument( + "--radius", + type=int, + default=50, + help="Radius of the circle for synthetic video", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Seed for initial position and velocity", + ) + parser.add_argument( + "--speed", type=int, default=20, help="Speed of the circle for synthetic video" + ) + parser.add_argument( + "--width", type=int, default=1024, help="Width of the synthetic video" + ) + parser.add_argument( + "--height", type=int, default=1024, help="Height of the synthetic video" + ) + parser.add_argument( + "--n_frames", + type=int, + default=200, + help="Number of frames in the synthetic video", + ) + parser.add_argument( + "--use-compile", + action="store_true", + dest="use_compile", + help="Use torch.compile to speed things up. First iteration will be much slower.", + ) + parser.add_argument( + "--frame_batch_size", + type=int, + default=1, + help="frame_batch_size", + ) + parser.add_argument( + "--synthesize", + action="store_true", + dest="synthesize", + help="Synthesize data for the benchmark.", + ) + parser.add_argument( + "--synthesize-overwrite", + action="store_true", + dest="synthesize_overwrite", + help="Overwrite data if it already exists when synthesizing.", + ) + parser.add_argument( + "--store-output", + type=str, + default="", + help="Pass a .pt file to store outputs in.", + ) + parser.add_argument( + "--compare-output", + type=str, + default="", + help="Pass a .pt file to load for comparison.", + ) + parser.add_argument( + "--print-all-timings", + action="store_true", + dest="print_all_timings", + help="Use torch.compile to speed things up. First iteration will be much slower.", + ) + + args = parser.parse_args() + + run_test( + args.checkpoint_path, + args.model_type, + profile=args.profile, + video_dir=args.video_dir, + radius=args.radius, + seed=args.seed, + speed=args.speed, + width=args.width, + height=args.height, + n_frames=args.n_frames, + use_compile=args.use_compile, + frame_batch_size=args.frame_batch_size, + synthesize=args.synthesize, + synthesize_overwrite=args.synthesize_overwrite, + store_output=args.store_output, + compare_output=args.compare_output, + print_all_timings=args.print_all_timings, + ) diff --git a/torchao/_models/sam2/build_sam.py b/torchao/_models/sam2/build_sam.py index 470cbfff99..d6847ede83 100644 --- a/torchao/_models/sam2/build_sam.py +++ b/torchao/_models/sam2/build_sam.py @@ -107,7 +107,7 @@ def build_sam2_video_predictor( **kwargs, ): hydra_overrides = [ - "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", + "++model._target_=torchao._models.sam2.sam2_video_predictor.SAM2VideoPredictor", ] if apply_postprocessing: hydra_overrides_extra = hydra_overrides_extra.copy() diff --git a/torchao/_models/sam2/modeling/backbones/image_encoder.py b/torchao/_models/sam2/modeling/backbones/image_encoder.py index 37e9266bc9..3f3a938857 100644 --- a/torchao/_models/sam2/modeling/backbones/image_encoder.py +++ b/torchao/_models/sam2/modeling/backbones/image_encoder.py @@ -28,7 +28,8 @@ def __init__( def forward(self, sample: torch.Tensor): # Forward through backbone - features, pos = self.neck(self.trunk(sample)) + with torch.autograd.profiler.record_function("self.neck(self.trunk(sample))"): + features, pos = self.neck(self.trunk(sample)) if self.scalp > 0: # Discard the lowest resolution features features, pos = features[: -self.scalp], pos[: -self.scalp] diff --git a/torchao/_models/sam2/modeling/position_encoding.py b/torchao/_models/sam2/modeling/position_encoding.py index 5ba359d8d2..f4cd77fd4b 100644 --- a/torchao/_models/sam2/modeling/position_encoding.py +++ b/torchao/_models/sam2/modeling/position_encoding.py @@ -164,18 +164,18 @@ def forward_with_coords( # 3. https://github.com/lucidrains/rotary-embedding-torch -def init_t_xy(end_x: int, end_y: int): - t = torch.arange(end_x * end_y, dtype=torch.float32) +def init_t_xy(end_x: int, end_y: int, device=None): + t = torch.arange(end_x * end_y, dtype=torch.float32, device=device) t_x = (t % end_x).float() t_y = torch.div(t, end_x, rounding_mode="floor").float() return t_x, t_y -def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): - freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) - freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) +def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0, device=None): + freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4, device=device)[: (dim // 4)].float() / dim)) + freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4, device=device)[: (dim // 4)].float() / dim)) - t_x, t_y = init_t_xy(end_x, end_y) + t_x, t_y = init_t_xy(end_x, end_y, device=device) freqs_x = torch.outer(t_x, freqs_x) freqs_y = torch.outer(t_y, freqs_y) freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) diff --git a/torchao/_models/sam2/modeling/sam/transformer.py b/torchao/_models/sam2/modeling/sam/transformer.py index 2e3d85ccd4..5574cb3fa2 100644 --- a/torchao/_models/sam2/modeling/sam/transformer.py +++ b/torchao/_models/sam2/modeling/sam/transformer.py @@ -325,9 +325,10 @@ def forward( # Apply rotary position encoding w = h = math.sqrt(q.shape[-2]) - self.freqs_cis = self.freqs_cis.to(q.device) + # NOTE: Disabling this. + # self.freqs_cis = self.freqs_cis.to(q.device) if self.freqs_cis.shape[0] != q.shape[-2]: - self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device) + self.freqs_cis = self.compute_cis(end_x=w, end_y=h, device=q.device) # .to(q.device) if q.shape[-2] != k.shape[-2]: assert self.rope_k_repeat diff --git a/torchao/_models/sam2/modeling/sam2_base.py b/torchao/_models/sam2/modeling/sam2_base.py index 20874e0581..f467c448a6 100644 --- a/torchao/_models/sam2/modeling/sam2_base.py +++ b/torchao/_models/sam2/modeling/sam2_base.py @@ -628,7 +628,7 @@ def _prepare_memory_conditioned_features( if self.add_tpos_enc_to_obj_ptrs: t_diff_max = max_obj_ptrs_in_encoder - 1 tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim - obj_pos = torch.tensor(pos_list, device=device) + obj_pos = torch.tensor(pos_list).pin_memory().to(device=device, non_blocking=True) obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim) obj_pos = self.obj_ptr_tpos_proj(obj_pos) obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim) @@ -709,8 +709,8 @@ def _encode_new_memory( maskmem_out = self.memory_encoder( pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied ) - maskmem_features = maskmem_out["vision_features"] - maskmem_pos_enc = maskmem_out["vision_pos_enc"] + maskmem_features = maskmem_out["vision_features"].clone() + maskmem_pos_enc = [m.clone() for m in maskmem_out["vision_pos_enc"]] # add a no-object embedding to the spatial memory to indicate that the frame # is predicted to be occluded (i.e. no object is appearing in the frame) if self.no_obj_embed_spatial is not None: @@ -809,6 +809,7 @@ def _encode_memory_in_output( current_out["maskmem_features"] = None current_out["maskmem_pos_enc"] = None + @torch.autograd.profiler.record_function("track_step") def track_step( self, frame_idx, @@ -854,13 +855,13 @@ def track_step( object_score_logits, ) = sam_outputs - current_out["pred_masks"] = low_res_masks - current_out["pred_masks_high_res"] = high_res_masks - current_out["obj_ptr"] = obj_ptr + current_out["pred_masks"] = low_res_masks.clone() + current_out["pred_masks_high_res"] = high_res_masks.clone() + current_out["obj_ptr"] = obj_ptr.clone() if not self.training: # Only add this in inference (to avoid unused param in activation checkpointing; # it's mainly used in the demo to encode spatial memories w/ consolidated masks) - current_out["object_score_logits"] = object_score_logits + current_out["object_score_logits"] = object_score_logits.clone() # Finally run the memory encoder on the predicted mask to encode # it into a new memory feature (that can be used in future frames) @@ -870,7 +871,7 @@ def track_step( point_inputs, run_mem_encoder, high_res_masks, - object_score_logits, + object_score_logits.clone(), current_out, ) diff --git a/torchao/_models/sam2/sam2_video_predictor.py b/torchao/_models/sam2/sam2_video_predictor.py index c7e01ccf97..cbd69005e4 100644 --- a/torchao/_models/sam2/sam2_video_predictor.py +++ b/torchao/_models/sam2/sam2_video_predictor.py @@ -11,8 +11,8 @@ from tqdm import tqdm -from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base -from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames +from torchao._models.sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base +from torchao._models.sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames class SAM2VideoPredictor(SAM2Base): @@ -909,6 +909,7 @@ def _get_image_feature(self, inference_state, frame_idx, batch_size): features = (expanded_image,) + features return features + @torch.autograd.profiler.record_function("_run_single_frame_inference") def _run_single_frame_inference( self, inference_state, From 23db9bf72a544773ec5fcc9f8b2bc36bc4dcc17d Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Fri, 6 Dec 2024 16:48:06 -0800 Subject: [PATCH 27/40] Move MarlinQQQTensor out of AQT (#1385) --- torchao/dtypes/README.md | 19 ++++++ torchao/dtypes/__init__.py | 4 +- torchao/dtypes/affine_quantized_tensor.py | 56 ----------------- torchao/dtypes/affine_quantized_tensor_ops.py | 2 +- torchao/dtypes/uintx/__init__.py | 6 +- ...lin_qqq_layout.py => marlin_qqq_tensor.py} | 62 +++++++++++++++++++ 6 files changed, 89 insertions(+), 60 deletions(-) create mode 100644 torchao/dtypes/README.md rename torchao/dtypes/uintx/{marlin_qqq_layout.py => marlin_qqq_tensor.py} (79%) diff --git a/torchao/dtypes/README.md b/torchao/dtypes/README.md new file mode 100644 index 0000000000..c1124c648f --- /dev/null +++ b/torchao/dtypes/README.md @@ -0,0 +1,19 @@ +# README + +## File Structure of the `dtypes` Folder + +The `dtypes` folder contains several important files and subfolders that are organized as follows: + +- **affine_quantized_tensor.py**: This is the main file, from which the subfolders `uintx` and `floatx` inherit. It contains the base tensor subclass `AffineQuantizedTensor` and code for layout and tensorImpl registration. + +- **affine_quantized_tensor_ops.py**: This file defines all the overriden aten ops and different dispatch kernels related to affine quantized tensors. + +- **utils.py**: A utility file that provides helper functions and common utilities used across different files in the `dtypes` folder. + +- **nf4tensor.py**: This file is specific to the NF4 tensor implementation, and layouts. + +### Subfolders + +- **uintx**: A subfolder that contains layouts and tensor subclasses inheriting from `affine_quantized_tensor.py`. It is specialized for handling unsigned integer quantized tensors. + +- **floatx**: Similar to `uintx`, this subfolder contains layouts and tensor subclasses that inherit from `affine_quantized_tensor.py`, but it is focused on floating-point quantized tensors. diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 00305db348..c7d98cb56e 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -1,14 +1,12 @@ from . import affine_quantized_tensor_ops from .affine_quantized_tensor import ( AffineQuantizedTensor, - MarlinQQQTensor, to_affine_quantized_floatx, to_affine_quantized_floatx_static, # experimental, will be merged into floatx in the future to_affine_quantized_fpx, to_affine_quantized_intx, to_affine_quantized_intx_static, - to_marlinqqq_quantized_intx, ) from .floatx import ( Float8Layout, @@ -18,10 +16,12 @@ BlockSparseLayout, Int4CPULayout, MarlinQQQLayout, + MarlinQQQTensor, MarlinSparseLayout, SemiSparseLayout, TensorCoreTiledLayout, UintxLayout, + to_marlinqqq_quantized_intx, ) from .utils import ( Layout, diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 93d2766d1e..7aca25ecc5 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -16,10 +16,8 @@ choose_qparams_affine, choose_qparams_affine_floatx, choose_qparams_and_quantize_affine_hqq, - choose_qparams_and_quantize_affine_qqq, dequantize_affine, dequantize_affine_floatx, - dequantize_affine_qqq, quantize_affine, quantize_affine_floatx, ) @@ -33,14 +31,12 @@ __all__ = [ "AffineQuantizedTensor", - "MarlinQQQTensor", "register_layout", "to_affine_quantized_intx", "to_affine_quantized_floatx", "to_affine_quantized_intx_static", "to_affine_quantized_floatx_static", "to_affine_quantized_fpx", - "to_marlinqqq_quantized_intx", ] @@ -459,57 +455,6 @@ def _apply_fn_to_data(self, fn): # 2 - we're given non-floats - quantizing long to int8 is crazy -class MarlinQQQTensor(AffineQuantizedTensor): - """ - MarlinQQQ quantized tensor subclass which inherits AffineQuantizedTensor class. - - To see what happens during choose_qparams_and_quantize_affine_qqq, quantization and dequantization for marlin qqq quantization, - please checkout https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py - and check the two quant primitive ops: choose_qparams_and_quantize_affine_qqq and dequantize_affine_qqq - """ - - def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: - if output_dtype is None: - output_dtype = self.dtype - - int_data, s_group, s_channel = self.tensor_impl.get_plain() - nbits = int(math.log2(self.quant_max - self.quant_min + 1)) - group_size = max(self.block_size) - return dequantize_affine_qqq( - int_data, s_group, s_channel, nbits, group_size, output_dtype - ) - - @classmethod - def from_hp_to_intx( - cls, - input_float: torch.Tensor, - block_size: Tuple[int, ...], - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, - zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, - _layout: Optional[Layout] = None, - ): - original_shape = input_float.shape - input_float = _layout.pre_process(input_float) - nbits = int(math.log2(quant_max - quant_min + 1)) - group_size = max(block_size) - data, s_group, s_channel, _ = choose_qparams_and_quantize_affine_qqq( - input_float, nbits, group_size - ) - data = _layout.post_process(data) - tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) - tensor_impl = tensor_impl_ctr(data, s_group, s_channel, _layout) - return cls( - tensor_impl, - block_size, - original_shape, - quant_min, - quant_max, - zero_point_domain, - dtype=input_float.dtype, - ) - - ###################################################### # Layout and TensorImpl Subclass Registration # ###################################################### @@ -522,7 +467,6 @@ def from_hp_to_intx( to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static # experimental will be merged in to floatx to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx -to_marlinqqq_quantized_intx = MarlinQQQTensor.from_hp_to_intx if TORCH_VERSION_AT_LEAST_2_5: # Allow a model with AffineQuantizedTensor weights to be loaded with `weights_only=True` diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index bd7ff7d333..8938e7472c 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -20,7 +20,7 @@ _linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl, ) -from torchao.dtypes.uintx.marlin_qqq_layout import ( +from torchao.dtypes.uintx.marlin_qqq_tensor import ( _linear_int8_act_int4_weight_marlin_qqq_check, _linear_int8_act_int4_weight_marlin_qqq_impl, ) diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index 8fba2bb678..4b1f3d39c8 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -1,8 +1,10 @@ from .block_sparse_layout import ( BlockSparseLayout, ) -from .marlin_qqq_layout import ( +from .marlin_qqq_tensor import ( MarlinQQQLayout, + MarlinQQQTensor, + to_marlinqqq_quantized_intx, ) from .marlin_sparse_layout import ( MarlinSparseLayout, @@ -26,4 +28,6 @@ "TensorCoreTiledLayout", "Int4CPULayout", "MarlinQQQLayout", + "MarlinQQQTensor", + "to_marlinqqq_quantized_intx", ] diff --git a/torchao/dtypes/uintx/marlin_qqq_layout.py b/torchao/dtypes/uintx/marlin_qqq_tensor.py similarity index 79% rename from torchao/dtypes/uintx/marlin_qqq_layout.py rename to torchao/dtypes/uintx/marlin_qqq_tensor.py index c3b2a78394..b75d959b41 100644 --- a/torchao/dtypes/uintx/marlin_qqq_layout.py +++ b/torchao/dtypes/uintx/marlin_qqq_tensor.py @@ -1,5 +1,7 @@ import logging +import math from dataclasses import dataclass +from typing import Optional, Tuple import torch from torch.utils._python_dispatch import ( @@ -8,18 +10,75 @@ from torchao.dtypes.affine_quantized_tensor import ( AffineQuantizedTensor, + get_tensor_impl_constructor, register_layout, ) from torchao.dtypes.uintx.plain_layout import ( _aqt_is_int8_reduced_range, ) from torchao.dtypes.utils import AQTTensorImpl, Layout +from torchao.quantization.quant_primitives import ( + ZeroPointDomain, + choose_qparams_and_quantize_affine_qqq, + dequantize_affine_qqq, +) logger = logging.getLogger(__name__) aten = torch.ops.aten +class MarlinQQQTensor(AffineQuantizedTensor): + """ + MarlinQQQ quantized tensor subclass which inherits AffineQuantizedTensor class. + + To see what happens during choose_qparams_and_quantize_affine_qqq, quantization and dequantization for marlin qqq quantization, + please checkout https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py + and check the two quant primitive ops: choose_qparams_and_quantize_affine_qqq and dequantize_affine_qqq + """ + + def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: + if output_dtype is None: + output_dtype = self.dtype + + int_data, s_group, s_channel = self.tensor_impl.get_plain() + nbits = int(math.log2(self.quant_max - self.quant_min + 1)) + group_size = max(self.block_size) + return dequantize_affine_qqq( + int_data, s_group, s_channel, nbits, group_size, output_dtype + ) + + @classmethod + def from_hp_to_intx( + cls, + input_float: torch.Tensor, + block_size: Tuple[int, ...], + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + _layout: Optional[Layout] = None, + ): + original_shape = input_float.shape + input_float = _layout.pre_process(input_float) + nbits = int(math.log2(quant_max - quant_min + 1)) + group_size = max(block_size) + data, s_group, s_channel, _ = choose_qparams_and_quantize_affine_qqq( + input_float, nbits, group_size + ) + data = _layout.post_process(data) + tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) + tensor_impl = tensor_impl_ctr(data, s_group, s_channel, _layout) + return cls( + tensor_impl, + block_size, + original_shape, + quant_min, + quant_max, + zero_point_domain, + dtype=input_float.dtype, + ) + + @dataclass(frozen=True) class MarlinQQQLayout(Layout): pass @@ -279,3 +338,6 @@ def _linear_int8_act_int4_weight_marlin_qqq_impl(input_tensor, weight_tensor, bi if bias is not None: out += bias.to(out.dtype) return out + + +to_marlinqqq_quantized_intx = MarlinQQQTensor.from_hp_to_intx From a6f867607d91286268b8e7abe1af8144bd40a63d Mon Sep 17 00:00:00 2001 From: Phil Butler Date: Mon, 9 Dec 2024 18:41:15 -0500 Subject: [PATCH 28/40] Update SAM AMG README with more descriptive install instructions (#1337) --- examples/sam2_amg_server/README.md | 22 ++++++++++++++- .../sam2/configs/sam2/sam2_hiera_b+.yaml | 28 +++++++++---------- .../sam2/configs/sam2/sam2_hiera_s.yaml | 28 +++++++++---------- .../sam2/configs/sam2/sam2_hiera_t.yaml | 28 +++++++++---------- 4 files changed, 63 insertions(+), 43 deletions(-) diff --git a/examples/sam2_amg_server/README.md b/examples/sam2_amg_server/README.md index 43fc2b2528..c09b012c26 100644 --- a/examples/sam2_amg_server/README.md +++ b/examples/sam2_amg_server/README.md @@ -41,12 +41,32 @@ The 'ao' mode is a copy of the baseline with modifications to make the code more ### 0. Download checkpoints and install requirements ``` -pip install -r requirements.txt +# From the top-level "ao" directory + +# If necessary, create and activate a virtual environment +# Ex: +python -m venv venv && source venv/bin/activate + +# Install requirements for this example +pip install -r examples/sam2_amg_server/requirements.txt + +# If you have an older version of torch in your current environment, uninstall it first +pip uninstall torch + +# Install torch nightly +pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124 + +# Build ao from source for now +python setup.py develop + +# On your mark, get set... +cd examples/sam2_amg_server/ ``` Download `sam2.1_hiera_large.pt` from https://github.com/facebookresearch/sam2?tab=readme-ov-file#download-checkpoints and put it into `~/checkpoints/sam2` ### 1. Create a random subset of 1000 images +Using images with corresponding mask annotations, like from the Segment Anything Video (SA-V) [Dataset](https://github.com/facebookresearch/sam2/tree/main/sav_dataset#download-the-dataset) is suggested, to later compare any drop in accuracy using `--furious` (using `torch.float16`). ``` find sav_val -type f > sav_val_image_paths shuf -n 1000 sav_val_image_paths > sav_val_image_paths_shuf_1000 diff --git a/torchao/_models/sam2/configs/sam2/sam2_hiera_b+.yaml b/torchao/_models/sam2/configs/sam2/sam2_hiera_b+.yaml index 58f3eb8155..b3ba469471 100644 --- a/torchao/_models/sam2/configs/sam2/sam2_hiera_b+.yaml +++ b/torchao/_models/sam2/configs/sam2/sam2_hiera_b+.yaml @@ -2,18 +2,18 @@ # Model model: - _target_: sam2.modeling.sam2_base.SAM2Base + _target_: torchao._models.sam2.modeling.sam2_base.SAM2Base image_encoder: - _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + _target_: torchao._models.sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: - _target_: sam2.modeling.backbones.hieradet.Hiera + _target_: torchao._models.sam2.modeling.backbones.hieradet.Hiera embed_dim: 112 num_heads: 2 neck: - _target_: sam2.modeling.backbones.image_encoder.FpnNeck + _target_: torchao._models.sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: - _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null @@ -24,17 +24,17 @@ model: fpn_interp_model: nearest memory_attention: - _target_: sam2.modeling.memory_attention.MemoryAttention + _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: - _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: - _target_: sam2.modeling.sam.transformer.RoPEAttention + _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] embedding_dim: 256 @@ -45,7 +45,7 @@ model: pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: - _target_: sam2.modeling.sam.transformer.RoPEAttention + _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] rope_k_repeat: True @@ -57,23 +57,23 @@ model: num_layers: 4 memory_encoder: - _target_: sam2.modeling.memory_encoder.MemoryEncoder + _target_: torchao._models.sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: - _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: - _target_: sam2.modeling.memory_encoder.MaskDownSampler + _target_: torchao._models.sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: - _target_: sam2.modeling.memory_encoder.Fuser + _target_: torchao._models.sam2.modeling.memory_encoder.Fuser layer: - _target_: sam2.modeling.memory_encoder.CXBlock + _target_: torchao._models.sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 diff --git a/torchao/_models/sam2/configs/sam2/sam2_hiera_s.yaml b/torchao/_models/sam2/configs/sam2/sam2_hiera_s.yaml index 26e5d4d39f..b051d3be63 100644 --- a/torchao/_models/sam2/configs/sam2/sam2_hiera_s.yaml +++ b/torchao/_models/sam2/configs/sam2/sam2_hiera_s.yaml @@ -2,21 +2,21 @@ # Model model: - _target_: sam2.modeling.sam2_base.SAM2Base + _target_: torchao._models.sam2.modeling.sam2_base.SAM2Base image_encoder: - _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + _target_: torchao._models.sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: - _target_: sam2.modeling.backbones.hieradet.Hiera + _target_: torchao._models.sam2.modeling.backbones.hieradet.Hiera embed_dim: 96 num_heads: 1 stages: [1, 2, 11, 2] global_att_blocks: [7, 10, 13] window_pos_embed_bkg_spatial_size: [7, 7] neck: - _target_: sam2.modeling.backbones.image_encoder.FpnNeck + _target_: torchao._models.sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: - _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null @@ -27,17 +27,17 @@ model: fpn_interp_model: nearest memory_attention: - _target_: sam2.modeling.memory_attention.MemoryAttention + _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: - _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: - _target_: sam2.modeling.sam.transformer.RoPEAttention + _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] embedding_dim: 256 @@ -48,7 +48,7 @@ model: pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: - _target_: sam2.modeling.sam.transformer.RoPEAttention + _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] rope_k_repeat: True @@ -60,23 +60,23 @@ model: num_layers: 4 memory_encoder: - _target_: sam2.modeling.memory_encoder.MemoryEncoder + _target_: torchao._models.sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: - _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: - _target_: sam2.modeling.memory_encoder.MaskDownSampler + _target_: torchao._models.sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: - _target_: sam2.modeling.memory_encoder.Fuser + _target_: torchao._models.sam2.modeling.memory_encoder.Fuser layer: - _target_: sam2.modeling.memory_encoder.CXBlock + _target_: torchao._models.sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 diff --git a/torchao/_models/sam2/configs/sam2/sam2_hiera_t.yaml b/torchao/_models/sam2/configs/sam2/sam2_hiera_t.yaml index a62c903aaa..6b108e708f 100644 --- a/torchao/_models/sam2/configs/sam2/sam2_hiera_t.yaml +++ b/torchao/_models/sam2/configs/sam2/sam2_hiera_t.yaml @@ -2,21 +2,21 @@ # Model model: - _target_: sam2.modeling.sam2_base.SAM2Base + _target_: torchao._models.sam2.modeling.sam2_base.SAM2Base image_encoder: - _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + _target_: torchao._models.sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: - _target_: sam2.modeling.backbones.hieradet.Hiera + _target_: torchao._models.sam2.modeling.backbones.hieradet.Hiera embed_dim: 96 num_heads: 1 stages: [1, 2, 7, 2] global_att_blocks: [5, 7, 9] window_pos_embed_bkg_spatial_size: [7, 7] neck: - _target_: sam2.modeling.backbones.image_encoder.FpnNeck + _target_: torchao._models.sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: - _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null @@ -27,17 +27,17 @@ model: fpn_interp_model: nearest memory_attention: - _target_: sam2.modeling.memory_attention.MemoryAttention + _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: - _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: - _target_: sam2.modeling.sam.transformer.RoPEAttention + _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] embedding_dim: 256 @@ -48,7 +48,7 @@ model: pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: - _target_: sam2.modeling.sam.transformer.RoPEAttention + _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] rope_k_repeat: True @@ -60,23 +60,23 @@ model: num_layers: 4 memory_encoder: - _target_: sam2.modeling.memory_encoder.MemoryEncoder + _target_: torchao._models.sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: - _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: - _target_: sam2.modeling.memory_encoder.MaskDownSampler + _target_: torchao._models.sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: - _target_: sam2.modeling.memory_encoder.Fuser + _target_: torchao._models.sam2.modeling.memory_encoder.Fuser layer: - _target_: sam2.modeling.memory_encoder.CXBlock + _target_: torchao._models.sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 From f258d82b38601bfd86e5079242206749f36c2d4c Mon Sep 17 00:00:00 2001 From: Huy Do Date: Tue, 10 Dec 2024 15:17:21 -0800 Subject: [PATCH 29/40] Add llama benchmark run to GHA (#1398) * Add llama benchmark run to github action Summary: Added 1. output json result following https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database 2. benchmark run for llama model 3. upload the result to S3 Test Plan: CI/dashboard Reviewers: Subscribers: Tasks: Tags: * Update dashboard_perf_test.yml to run benchmark * Fix script path * Run ruff and ufmt * Fix model path * mkdir * Ready to land --------- Co-authored-by: Jerry Zhang --- .github/workflows/dashboard_perf_test.yml | 46 ++ torchao/_models/llama/generate.py | 711 ++++++++++++++++------ 2 files changed, 574 insertions(+), 183 deletions(-) create mode 100644 .github/workflows/dashboard_perf_test.yml diff --git a/.github/workflows/dashboard_perf_test.yml b/.github/workflows/dashboard_perf_test.yml new file mode 100644 index 0000000000..c2933be107 --- /dev/null +++ b/.github/workflows/dashboard_perf_test.yml @@ -0,0 +1,46 @@ +name: A100-perf-nightly + +on: + workflow_dispatch: + schedule: + - cron: 0 7 * * 0-6 + +jobs: + benchmark: + runs-on: linux.aws.a100 + strategy: + matrix: + torch-spec: + - '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu124' + steps: + - uses: actions/checkout@v3 + + - name: Setup miniconda + uses: pytorch/test-infra/.github/actions/setup-miniconda@main + with: + python-version: "3.9" + + - name: Run benchmark + shell: bash + run: | + set -eux + ${CONDA_RUN} python -m pip install --upgrade pip + ${CONDA_RUN} pip install ${{ matrix.torch-spec }} + ${CONDA_RUN} pip install -r dev-requirements.txt + ${CONDA_RUN} pip install . + + export CHECKPOINT_PATH=checkpoints + export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf + ${CONDA_RUN} python scripts/download.py --repo_id meta-llama/Llama-2-7b-chat-hf --hf_token ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + ${CONDA_RUN} python scripts/convert_hf_checkpoint.py --checkpoint_dir "${CHECKPOINT_PATH}/${MODEL_REPO}" + + mkdir -p ${{ runner.temp }}/benchmark-results + ${CONDA_RUN} python torchao/_models/llama/generate.py --checkpoint_path "${CHECKPOINT_PATH}/${MODEL_REPO}/model.pth" --compile --output_json_path ${{ runner.temp }}/benchmark-results/benchmark-results.json + + - name: Upload the benchmark results to OSS benchmark database for the dashboard + uses: pytorch/test-infra/.github/actions/upload-benchmark-results@main + with: + benchmark-results-dir: ${{ runner.temp }}/benchmark-results + dry-run: false + schema-version: v3 + github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 065cc9c56d..7570700c65 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -3,22 +3,26 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import json import os +import platform import sys import time +from datetime import datetime from pathlib import Path from typing import Optional, Tuple -from datetime import datetime + import torch -import torchao import torch._dynamo.config import torch._inductor.config -from torchao.utils import get_model_size_in_bytes + +import torchao from torchao.quantization.quant_primitives import MappingType -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import get_model_size_in_bytes, TORCH_VERSION_AT_LEAST_2_5 torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False + class HostEvent: def __init__(self): self.event_time = None @@ -32,6 +36,15 @@ def elapsed_time(self, other_event): # return ms to match cuda event return abs(other_event.event_time - self.event_time) * 1000 + +def get_arch_name() -> str: + if torch.cuda.is_available(): + return torch.cuda.get_device_name() + else: + # This returns x86_64 or arm64 (for aarch64) + return platform.machine() + + def device_timer(device): if "cuda" in device: return torch.cuda.Event(enable_timing=True) @@ -40,6 +53,7 @@ def device_timer(device): else: print(f"device={device} is not yet suppported") + def device_sync(device): if "cuda" in device: torch.cuda.synchronize(device) @@ -50,19 +64,63 @@ def device_sync(device): else: print(f"device={device} is not yet suppported") -default_device = 'cuda' if torch.cuda.is_available() else 'xpu' if torch.xpu.is_available() else 'cpu' + +def write_json_result(output_json_path, headers, row): + """ + Write the result into JSON format, so that it can be uploaded to the benchmark database + to be displayed on OSS dashboard. The JSON format is defined at + https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database + """ + mapping_headers = {headers[i]: v for i, v in enumerate(row)} + record = { + "benchmark": { + "name": "TorchAO benchmark", + "mode": "inference", + "dtype": mapping_headers["dtype"], + "extra_info": { + "device": mapping_headers["device"], + "arch": mapping_headers["arch"], + }, + }, + "model": { + "name": mapping_headers["name"], + "type": "model", + "origins": ["pytorch"], + }, + "metric": { + "name": mapping_headers["metric"], + "benchmark_values": [mapping_headers["actual"]], + "target_value": mapping_headers["target"], + }, + } + + with open(f"{os.path.splitext(output_json_path)[0]}.json", "a") as f: + print(json.dumps(record), file=f) + + +default_device = ( + "cuda" + if torch.cuda.is_available() + else "xpu" + if torch.xpu.is_available() + else "cpu" +) # support running without installing as a package wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) -from torchao._models.llama.model import Transformer, prepare_inputs_for_model +from torchao._models.llama.model import prepare_inputs_for_model, Transformer from torchao._models.llama.tokenizer import get_tokenizer -def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization + +def multinomial_sample_one_no_sync( + probs_sort, +): # Does multinomial sampling without a cuda synchronization q = torch.empty_like(probs_sort).exponential_(1) return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): logits = logits / max(temperature, 1e-5) @@ -73,23 +131,38 @@ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = Non probs = torch.nn.functional.softmax(logits, dim=-1) return probs + def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): probs = logits_to_probs(logits[:, -1], temperature, top_k) idx_next = multinomial_sample_one_no_sync(probs) return idx_next, probs -def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor: + +def prefill( + model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs +) -> torch.Tensor: # input_pos: [B, S] logits = model(x, input_pos) return sample(logits, **sampling_kwargs)[0] -def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]: + +def decode_one_token( + model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs +) -> Tuple[torch.Tensor, torch.Tensor]: # input_pos: [B, 1] assert input_pos.shape[-1] == 1 logits = model(x, input_pos) return sample(logits, **sampling_kwargs) -def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs): + +def decode_n_tokens( + model: Transformer, + cur_token: torch.Tensor, + input_pos: torch.Tensor, + num_new_tokens: int, + callback=lambda _: _, + **sampling_kwargs, +): new_tokens, new_probs = [], [] for i in range(num_new_tokens): with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): @@ -109,6 +182,7 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc def model_forward(model, x, input_pos): return model(x, input_pos) + @torch.no_grad() def generate( model: Transformer, @@ -117,15 +191,15 @@ def generate( batch_size: int, *, interactive: bool, - callback = lambda x: x, + callback=lambda x: x, kv_cache_quantization: bool = False, cache_size: Optional[int] = None, - linear_causal_mask: bool=False, - prefill_start_event: Optional[torch.cuda.Event]=None, - prefill_end_event: Optional[torch.cuda.Event]=None, - decode_start_event: Optional[torch.cuda.Event]=None, - decode_end_event: Optional[torch.cuda.Event]=None, - **sampling_kwargs + linear_causal_mask: bool = False, + prefill_start_event: Optional[torch.cuda.Event] = None, + prefill_end_event: Optional[torch.cuda.Event] = None, + decode_start_event: Optional[torch.cuda.Event] = None, + decode_end_event: Optional[torch.cuda.Event] = None, + **sampling_kwargs, ) -> torch.Tensor: """ Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. @@ -136,12 +210,14 @@ def generate( T = prompt.size(-1) # calculate how many tokens to generate based on max_new_tokens and model's upper bound (block_size) - max_seq_length = min(T + max_new_tokens, model.config.block_size) if not interactive else 350 + max_seq_length = ( + min(T + max_new_tokens, model.config.block_size) if not interactive else 350 + ) new_tokens = max_seq_length - T # format model input prompt, input_pos = prepare_inputs_for_model(prompt) - prompt = prompt.repeat(batch_size, 1) # expand prompt based on batchsize + prompt = prompt.repeat(batch_size, 1) # expand prompt based on batchsize # full prompt+output will be stored in seq seq = torch.empty(batch_size, max_seq_length, dtype=prompt.dtype, device=device) @@ -151,13 +227,23 @@ def generate( with torch.device(device): if cache_size is None: cache_size = max_seq_length - assert cache_size >= max_seq_length, "need cache_size to be greater than max_new_tokens + size-of-prompt" - model.setup_caches(max_batch_size=batch_size, max_seq_length=cache_size, kv_cache_quantization=kv_cache_quantization, linear_causal_mask=linear_causal_mask, prompt_length=T) + assert ( + cache_size >= max_seq_length + ), "need cache_size to be greater than max_new_tokens + size-of-prompt" + model.setup_caches( + max_batch_size=batch_size, + max_seq_length=cache_size, + kv_cache_quantization=kv_cache_quantization, + linear_causal_mask=linear_causal_mask, + prompt_length=T, + ) # execute prefill if prefill_start_event is not None: prefill_start_event.record() - next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs).clone() + next_token = prefill( + model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs + ).clone() seq[:, T] = next_token.squeeze() if prefill_end_event is not None: prefill_end_event.record() @@ -166,19 +252,28 @@ def generate( if decode_start_event is not None: decode_start_event.record() input_pos = torch.tensor([T], device=device, dtype=torch.int) - generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs) - seq = torch.cat((seq[:, :T+1], *generated_tokens), dim=-1) + generated_tokens, _ = decode_n_tokens( + model, + next_token.view(batch_size, -1), + input_pos, + new_tokens - 1, + callback=callback, + **sampling_kwargs, + ) + seq = torch.cat((seq[:, : T + 1], *generated_tokens), dim=-1) if decode_end_event is not None: decode_end_event.record() return seq + def encode_tokens(tokenizer, string, bos=True, device=default_device): tokens = tokenizer.encode(string) if bos: tokens = [tokenizer.bos_id()] + tokens return torch.tensor(tokens, dtype=torch.int, device=device) + def _load_model(checkpoint_path, device, precision): checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) if "model" in checkpoint and "stories" in str(checkpoint_path): @@ -190,8 +285,10 @@ def _load_model(checkpoint_path, device, precision): return model.eval() + B_INST, E_INST = "[INST]", "[/INST]" + def main( prefill_size: Optional[int] = None, prompt: str = "Hello, my name is", @@ -201,12 +298,14 @@ def main( batch_size: int = 1, top_k: int = 200, temperature: float = 0.8, - checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), + checkpoint_path: Path = Path( + "checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth" + ), quantization: Optional[str] = None, sparsity: Optional[str] = None, kv_cache_quantization: bool = False, cache_size: Optional[int] = None, - linear_causal_mask: bool=False, + linear_causal_mask: bool = False, save: bool = False, compile: bool = True, compile_prefill: bool = False, @@ -215,13 +314,13 @@ def main( device=default_device, precision=torch.bfloat16, write_result: Optional[Path] = None, + output_json_path: Optional[Path] = None, ) -> None: - """Generates text samples based on a pre-trained Transformer model and tokenizer. - """ + """Generates text samples based on a pre-trained Transformer model and tokenizer.""" if prefill_size is not None and prefill_size > 0: - # create prompt of prefill size - prompt = "prompt " * (int(prefill_size)-3) + # create prompt of prefill size + prompt = "prompt " * (int(prefill_size) - 3) torchao.quantization.utils.recommended_inductor_config_setter() @@ -236,8 +335,7 @@ def main( t0 = time.time() model = _load_model(checkpoint_path, device, precision) - - device_sync(device=device) # MKG + device_sync(device=device) # MKG print(f"Time to load model: {time.time() - t0:.02f} seconds") tokenizer = get_tokenizer(tokenizer_path, checkpoint_path) @@ -248,54 +346,70 @@ def main( torch.manual_seed(1234) def ffn_only(mod, fqn): - return isinstance(mod, torch.nn.Linear) and "feed_forward" in fqn + return isinstance(mod, torch.nn.Linear) and "feed_forward" in fqn def not_ffn_only(mod, fqn): return isinstance(mod, torch.nn.Linear) and not ffn_only(mod, fqn) def ffn_or_attn_only(mod, fqn): - return isinstance(mod, torch.nn.Linear) and ("feed_forward" in fqn or "attention" in fqn) + return isinstance(mod, torch.nn.Linear) and ( + "feed_forward" in fqn or "attention" in fqn + ) if quantization: from torchao.quantization import ( - quantize_, autoquant, - int8_weight_only, - int8_dynamic_activation_int8_weight, + float8_dynamic_activation_float8_weight, + float8_weight_only, + fpx_weight_only, int4_weight_only, int8_dynamic_activation_int4_weight, - fpx_weight_only, + int8_dynamic_activation_int8_weight, + int8_weight_only, + quantize_, uintx_weight_only, - float8_weight_only, - float8_dynamic_activation_float8_weight, ) - from torchao.utils import unwrap_tensor_subclass - from torchao.quantization.granularity import PerTensor, PerRow + from torchao.quantization.granularity import PerRow, PerTensor from torchao.utils import unwrap_tensor_subclass + if "spinquant" in quantization: from torchao.prototype.spinquant import apply_spinquant + apply_spinquant(model) if "int8wo" in quantization: quantize_(model, int8_weight_only()) if "int8dq" in quantization: if sparsity and "semi" in sparsity: from torchao.dtypes import SemiSparseLayout - quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), filter_fn=ffn_only) - quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=not_ffn_only) + + quantize_( + model, + int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), + filter_fn=ffn_only, + ) + quantize_( + model, int8_dynamic_activation_int8_weight(), filter_fn=not_ffn_only + ) else: quantize_(model, int8_dynamic_activation_int8_weight()) if "int4wo" in quantization: if "hqq" in quantization: - use_hqq=True + use_hqq = True else: - use_hqq=False - group_size=int(quantization.split("-")[1]) - assert group_size in [32,64,128,256], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}" + use_hqq = False + group_size = int(quantization.split("-")[1]) + assert group_size in [ + 32, + 64, + 128, + 256, + ], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}" quantize_(model, int4_weight_only(group_size=group_size)) if "marlin" in quantization: if "qqq" in quantization: from torchao.dtypes import MarlinQQQLayout + quantize_( model, int8_dynamic_activation_int4_weight( @@ -307,25 +421,41 @@ def ffn_or_attn_only(mod, fqn): ) elif "semi" in sparsity: from torchao.dtypes import MarlinSparseLayout - quantize_(model, int4_weight_only(layout=MarlinSparseLayout()), filter_fn=ffn_or_attn_only) + + quantize_( + model, + int4_weight_only(layout=MarlinSparseLayout()), + filter_fn=ffn_or_attn_only, + ) if "fp6" in quantization: quantize_(model, fpx_weight_only(3, 2)) elif "embed-int8wo" in quantization: - quantize_(model, int8_weight_only(group_size=64), filter_fn=lambda x, *args: isinstance(x, torch.nn.Embedding)) + quantize_( + model, + int8_weight_only(group_size=64), + filter_fn=lambda x, *args: isinstance(x, torch.nn.Embedding), + ) elif quantization.startswith("awq"): from torchao._models._eval import TransformerEvalWrapper from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - from torchao.prototype.awq.example import get_calib_dataset + if not TORCH_VERSION_AT_LEAST_2_3: print("Awq requires torch2.3+") exit() - from torchao.prototype.awq import insert_awq_observer_, awq_uintx, AWQObservedLinear + from torchao.prototype.awq import ( + awq_uintx, + AWQObservedLinear, + insert_awq_observer_, + ) + quant_dtype = quantization.split("-")[1] group_size = int(quantization.split("-")[2]) quant_dtype = getattr(torch, quant_dtype, torch.uint8) - model=model.to(device) + model = model.to(device) # get calibration data - insert_awq_observer_(model, 1, 256, quant_dtype=quant_dtype, group_size=group_size) + insert_awq_observer_( + model, 1, 256, quant_dtype=quant_dtype, group_size=group_size + ) TransformerEvalWrapper( model=model.to(device), tokenizer=tokenizer, @@ -333,12 +463,18 @@ def ffn_or_attn_only(mod, fqn): input_prep_func=prepare_inputs_for_model, device=device, ).run_eval( - tasks=['wikitext'], + tasks=["wikitext"], limit=1, ) is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) use_hqq = "hqq" in quantization - quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size, use_hqq=use_hqq), is_observed_linear) + quantize_( + model, + awq_uintx( + quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq + ), + is_observed_linear, + ) elif "uintx" in quantization: # uintx-nbits-group_size, e.g. "uintx-2-64" if "hqq" in quantization: @@ -349,18 +485,36 @@ def ffn_or_attn_only(mod, fqn): _quant_args = quantization.split("-") nbits = int(_quant_args[1]) assert nbits >= 1 and nbits <= 8, "nbits must be 1 to 8" - _NBITS_TO_DTYPE = {1: torch.uint1, 2: torch.uint2, 3: torch.uint3, 4: torch.uint4, 5: torch.uint5, 6: torch.uint6, 7: torch.uint7, 8: torch.uint8} + _NBITS_TO_DTYPE = { + 1: torch.uint1, + 2: torch.uint2, + 3: torch.uint3, + 4: torch.uint4, + 5: torch.uint5, + 6: torch.uint6, + 7: torch.uint7, + 8: torch.uint8, + } dtype = _NBITS_TO_DTYPE[nbits] group_size = int(_quant_args[2]) quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq)) elif "int8_dynamic_activation_intx_weight" in quantization: - from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight - assert precision == torch.float32, "int8_dynamic_activation_intx_weight requires fp32 precision" + from torchao.experimental.quant_api import ( + int8_dynamic_activation_intx_weight, + ) + + assert ( + precision == torch.float32 + ), "int8_dynamic_activation_intx_weight requires fp32 precision" # Build kernels in temp location, and load them in torch # This requires an ARM CPU from torchao.experimental.temp_build import temp_build_and_load_torchao_ops - temp_build_and_load_torchao_ops(cmake_lists_path=os.path.dirname(os.path.realpath(__file__)) + "/../../experimental") + + temp_build_and_load_torchao_ops( + cmake_lists_path=os.path.dirname(os.path.realpath(__file__)) + + "/../../experimental" + ) # Quantize model _quant_args = quantization.split("-") @@ -380,31 +534,38 @@ def ffn_or_attn_only(mod, fqn): quantize_(model, float8_weight_only()) elif "float8dq" in quantization: granularity = str(quantization.split("-")[-1]) - if granularity=="tensor": + if granularity == "tensor": granularity = PerTensor() - elif granularity=="row": + elif granularity == "row": granularity = PerRow() else: granularity = PerTensor() - quantize_(model, float8_dynamic_activation_float8_weight(granularity=granularity)) + quantize_( + model, float8_dynamic_activation_float8_weight(granularity=granularity) + ) elif "autoquant_v2" in quantization: - from torchao.prototype.quantization.autoquant_v2 import autoquant_v2 from torchao._models._eval import InputRecorder from torchao._models.llama.model import prepare_inputs_for_model + from torchao.prototype.quantization.autoquant_v2 import autoquant_v2 calibration_seq_length = 256 calibration_limit = 1 - inputs = InputRecorder( - tokenizer, - calibration_seq_length, - prepare_inputs_for_model, - False, # pad_calibration_inputs - model.config.vocab_size, - device="cuda" - ).record_inputs( - ["wikitext"], - 1, - ).get_inputs()[0].values[0] + inputs = ( + InputRecorder( + tokenizer, + calibration_seq_length, + prepare_inputs_for_model, + False, # pad_calibration_inputs + model.config.vocab_size, + device="cuda", + ) + .record_inputs( + ["wikitext"], + 1, + ) + .get_inputs()[0] + .values[0] + ) inputs = prepare_inputs_for_model(inputs) with torch.device("cuda"): model.setup_caches( @@ -412,19 +573,54 @@ def ffn_or_attn_only(mod, fqn): ) if "autoquant_v2-int4" == quantization: - model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs, batch_size=calibration_seq_length) + model = autoquant_v2( + model, + manual=True, + qtensor_class_list=torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, + example_input=inputs, + batch_size=calibration_seq_length, + ) elif "autoquant_v2-float8" == quantization: - model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs, batch_size=calibration_seq_length) + model = autoquant_v2( + model, + manual=True, + qtensor_class_list=torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST, + example_input=inputs, + batch_size=calibration_seq_length, + ) elif "autoquant_v2-fp" == quantization: - model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, example_input=inputs, batch_size=calibration_seq_length) + model = autoquant_v2( + model, + manual=True, + qtensor_class_list=torchao.prototype.quantization.autoquant_v2.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, + example_input=inputs, + batch_size=calibration_seq_length, + ) elif "autoquant_v2-all" == quantization: - all_qtensor_classes = torchao.prototype.quantization.autoquant_v2.DEFAULT_AUTOQUANT_CLASS_LIST + torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST + torchao.prototype.quantization.autoquant_v2.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST + all_qtensor_classes = ( + torchao.prototype.quantization.autoquant_v2.DEFAULT_AUTOQUANT_CLASS_LIST + + torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST + + torchao.prototype.quantization.autoquant_v2.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST + ) if torchao.utils.is_sm_89(): # this is fp8 related subclasses, should rename - all_qtensor_classes += torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST - model = autoquant_v2(model, manual=True, qtensor_class_list = all_qtensor_classes, example_input=inputs, batch_size=calibration_seq_length) + all_qtensor_classes += ( + torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST + ) + model = autoquant_v2( + model, + manual=True, + qtensor_class_list=all_qtensor_classes, + example_input=inputs, + batch_size=calibration_seq_length, + ) else: - model = autoquant_v2(model, manual=True, example_input=inputs, batch_size=calibration_seq_length) + model = autoquant_v2( + model, + manual=True, + example_input=inputs, + batch_size=calibration_seq_length, + ) print("running generate") generate( @@ -446,17 +642,22 @@ def ffn_or_attn_only(mod, fqn): calibration_seq_length = 256 calibration_limit = 1 - inputs = InputRecorder( - tokenizer, - calibration_seq_length, - prepare_inputs_for_model, - False, # pad_calibration_inputs - model.config.vocab_size, - device="cuda" - ).record_inputs( - ["wikitext"], - 1, - ).get_inputs()[0].values[0] + inputs = ( + InputRecorder( + tokenizer, + calibration_seq_length, + prepare_inputs_for_model, + False, # pad_calibration_inputs + model.config.vocab_size, + device="cuda", + ) + .record_inputs( + ["wikitext"], + 1, + ) + .get_inputs()[0] + .values[0] + ) inputs = prepare_inputs_for_model(inputs) with torch.device("cuda"): model.setup_caches( @@ -464,17 +665,43 @@ def ffn_or_attn_only(mod, fqn): ) if "autoquant-int4" == quantization: - model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs) + model = autoquant( + model, + manual=True, + qtensor_class_list=torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, + example_input=inputs, + ) elif "autoquant-float8" == quantization: - model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs) + model = autoquant( + model, + manual=True, + qtensor_class_list=torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST, + example_input=inputs, + ) if "autoquant-fp" == quantization: - model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, example_input=inputs) + model = autoquant( + model, + manual=True, + qtensor_class_list=torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, + example_input=inputs, + ) if "autoquant-all" == quantization: - all_qtensor_classes = torchao.quantization.DEFAULT_AUTOQUANT_CLASS_LIST + torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST + torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST + all_qtensor_classes = ( + torchao.quantization.DEFAULT_AUTOQUANT_CLASS_LIST + + torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST + + torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST + ) if torchao.utils.is_sm_89(): # this is fp8 related subclasses, should rename - all_qtensor_classes += torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST - model = autoquant(model, manual=True, qtensor_class_list = all_qtensor_classes, example_input=inputs) + all_qtensor_classes += ( + torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST + ) + model = autoquant( + model, + manual=True, + qtensor_class_list=all_qtensor_classes, + example_input=inputs, + ) else: model = autoquant(model, manual=True, example_input=inputs) @@ -498,8 +725,9 @@ def ffn_or_attn_only(mod, fqn): # standalone sparsity elif sparsity: from torchao.sparsity import semi_sparse_weight, sparsify_ + if "semi" in sparsity: - #TODO there is a bug here, need to fix + # TODO there is a bug here, need to fix sparsify_(model.to(device), semi_sparse_weight(), filter_fn=ffn_only) model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9 @@ -507,39 +735,48 @@ def ffn_or_attn_only(mod, fqn): if save: output_dir = str(checkpoint_path.cwd()) filename = str(checkpoint_path.name).split(".")[0] - torch.save(model.state_dict(), os.path.join(output_dir, filename + f"-{quantization}.pt")) + torch.save( + model.state_dict(), + os.path.join(output_dir, filename + f"-{quantization}.pt"), + ) if compile: print("Compiling Model") global decode_one_token, prefill - decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) + decode_one_token = torch.compile( + decode_one_token, mode="reduce-overhead", fullgraph=True + ) if compile_prefill: prefill = torch.compile(prefill, fullgraph=True, dynamic=True) if memory_profile: if device == "cuda": - torch.cuda.memory._record_memory_history(True,trace_alloc_max_entries=250000, trace_alloc_record_context=True) + torch.cuda.memory._record_memory_history( + True, trace_alloc_max_entries=250000, trace_alloc_record_context=True + ) elif device == "xpu": - torch.xpu.memory._record_memory_history(True,trace_alloc_max_entries=250000, trace_alloc_record_context=True) + torch.xpu.memory._record_memory_history( + True, trace_alloc_max_entries=250000, trace_alloc_record_context=True + ) else: print("Memory profiling only works on CUDA or XPU devices") - + aggregate_metrics = { - 'tokens_per_sec': [], - 'time': [], - 'decode_tokens_per_sec': [], - 'prefill_time': [], + "tokens_per_sec": [], + "time": [], + "decode_tokens_per_sec": [], + "prefill_time": [], } start = -1 if compile else 0 for i in range(start, num_samples): - if i==0: + if i == 0: if device == "cuda": - torch.cuda.reset_peak_memory_stats() # MKG + torch.cuda.reset_peak_memory_stats() # MKG elif device == "xpu": - torch.xpu.reset_peak_memory_stats() # MKG - device_sync(device=device) # MKG + torch.xpu.reset_peak_memory_stats() # MKG + device_sync(device=device) # MKG if i >= 0 and interactive: prompt = input("What is your prompt? ") if is_chat: @@ -548,8 +785,9 @@ def ffn_or_attn_only(mod, fqn): if interactive and i >= 0: buffer = [] - period_id = tokenizer.encode('.')[0] + period_id = tokenizer.encode(".")[0] done_generating = False + def callback(x): nonlocal done_generating if done_generating: @@ -558,16 +796,22 @@ def callback(x): if x.item() == tokenizer.eos_id(): done_generating = True if len(buffer) == 4 or done_generating: - print(''.join(buffer), end='', flush=True) + print("".join(buffer), end="", flush=True) buffer.clear() # print(, end='', flush=True) + else: - callback = lambda x : x + callback = lambda x: x t0 = time.perf_counter() - prefill_start_event, prefill_end_event = device_timer(device), device_timer(device) - decode_start_event, decode_end_event = device_timer(device), device_timer(device) + prefill_start_event, prefill_end_event = device_timer(device), device_timer( + device + ) + decode_start_event, decode_end_event = device_timer(device), device_timer( + device + ) import contextlib - if (i != num_samples - 1 or not profile): + + if i != num_samples - 1 or not profile: prof = contextlib.nullcontext() else: torch.profiler._utils._init_for_cuda_graphs() @@ -595,60 +839,69 @@ def callback(x): continue if hasattr(prof, "export_chrome_trace"): prof.export_chrome_trace(f"{profile}.json") - device_sync(device=device) # MKG + device_sync(device=device) # MKG t = time.perf_counter() - t0 if not interactive and prefill_size is None: - tok_list = y[0].tolist() - # truncate text after end of string token - tokens = tok_list if not tokenizer.eos_id() in tok_list else tok_list[:tok_list.index(tokenizer.eos_id())] - print(tokenizer.decode(tokens)) + tok_list = y[0].tolist() + # truncate text after end of string token + tokens = ( + tok_list + if tokenizer.eos_id() not in tok_list + else tok_list[: tok_list.index(tokenizer.eos_id())] + ) + print(tokenizer.decode(tokens)) else: print() - tokens_generated = (y.size(-1) - prompt_length) + tokens_generated = y.size(-1) - prompt_length tokens_sec = tokens_generated / t - aggregate_metrics['tokens_per_sec'].append(tokens_sec) - aggregate_metrics['time'].append(t) + aggregate_metrics["tokens_per_sec"].append(tokens_sec) + aggregate_metrics["time"].append(t) decode_time = decode_start_event.elapsed_time(decode_end_event) / 1000 decode_tokens_sec = tokens_generated / decode_time - aggregate_metrics['decode_tokens_per_sec'].append(decode_tokens_sec) + aggregate_metrics["decode_tokens_per_sec"].append(decode_tokens_sec) prefill_time = prefill_start_event.elapsed_time(prefill_end_event) / 1000 - aggregate_metrics['prefill_time'].append(prefill_time) - print(f"Sample {i+1} | overall time {t:.04f} s {tokens_sec:.02f} tokens/sec", - f"| prefill time {prefill_time:.04f} s decode {decode_tokens_sec:.02f} tokens/sec") + aggregate_metrics["prefill_time"].append(prefill_time) + print( + f"Sample {i+1} | overall time {t:.04f} s {tokens_sec:.02f} tokens/sec", + f"| prefill time {prefill_time:.04f} s decode {decode_tokens_sec:.02f} tokens/sec", + ) print(f"Bandwidth achieved: {model_size * tokens_sec:.02f} GB/s") - if memory_profile and i==0: + if memory_profile and i == 0: if device == "cuda": snapshot = torch.cuda.memory._snapshot() elif device == "xpu": snapshot = torch.xpu.memory._snapshot() else: print("Memory profiling only works on CUDA or XPU devices") - - with open(f"{memory_profile}.pickle", 'wb') as f: + + with open(f"{memory_profile}.pickle", "wb") as f: from pickle import dump + dump(snapshot, f) print( f"\nmemory profile {memory_profile}.pickle saved, to convert that to a usable file, use", - "python pytorch/torch/cuda/_memory_viz.py trace_plot -o .html" + "python pytorch/torch/cuda/_memory_viz.py trace_plot -o .html", ) break print("==========") - #ignore first sample for warmup - tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item() - ttft = torch.mean(torch.tensor(aggregate_metrics['prefill_time'])).item() - decode_tokpersec = torch.mean(torch.tensor(aggregate_metrics['decode_tokens_per_sec'])).item() + # ignore first sample for warmup + tokpersec = torch.mean(torch.tensor(aggregate_metrics["tokens_per_sec"])).item() + ttft = torch.mean(torch.tensor(aggregate_metrics["prefill_time"])).item() + decode_tokpersec = torch.mean( + torch.tensor(aggregate_metrics["decode_tokens_per_sec"]) + ).item() bandwidth = model_size * tokpersec - mem = torch.cuda.max_memory_reserved() /1e9 + mem = torch.cuda.max_memory_reserved() / 1e9 print(f"Average overall tokens/sec: {tokpersec:.2f}") print(f"Average decode tokens/sec: {decode_tokens_sec:.04f} s") print(f"Average TTFT: {ttft:.04f} s") - if device == "cuda": - mem = torch.cuda.max_memory_reserved() /1e9 + if device == "cuda": + mem = torch.cuda.max_memory_reserved() / 1e9 elif device == "xpu": - mem = torch.xpu.max_memory_reserved() /1e9 + mem = torch.xpu.max_memory_reserved() / 1e9 print(f"Average tokens/sec: {tokpersec:.2f}") if batch_size > 1: print(f"Average tokens/sec including batches {batch_size*tokpersec:.2f}") @@ -658,71 +911,163 @@ def callback(x): if write_result: result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, tok/s_decode={decode_tokpersec:6.2f}, ttft={ttft:5.4f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB " result_txt += f"quant: {quantization}, sparse: {sparsity}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} " - result_txt += f"repro: python generate.py " + result_txt += "repro: python generate.py " result_txt += f"--quantization {quantization} " if quantization else "" result_txt += f"--sparsity {sparsity} " if sparsity else "" result_txt += f"--checkpoint_path {checkpoint_path} " result_txt += f"--device {device} " result_txt += f"--precision {precision} " - result_txt += f"--compile " if compile else "" - result_txt += f"--compile_prefill " if compile_prefill else "" + result_txt += "--compile " if compile else "" + result_txt += "--compile_prefill " if compile_prefill else "" result_txt += f"--prefill_size {prefill_size}" if prefill_size else "" result_txt += f"--profile {profile} " if profile else "" result_txt += f"--profile {memory_profile} " if memory_profile else "" - result_txt += f"--interactive " if interactive else "" + result_txt += "--interactive " if interactive else "" result_txt += f"--num_samples {num_samples} " result_txt += f"--max_new_tokens {max_new_tokens} " result_txt += f"--batch_size {batch_size} " result_txt += f"--top_k {top_k} " result_txt += f"--temperature {temperature} " result_txt += f"--cache_size {cache_size}" if cache_size else "" - result_txt += f"--kv_cache_quantization " if kv_cache_quantization else "" - result_txt += f"--linear_causal_mask " if linear_causal_mask else "" + result_txt += "--kv_cache_quantization " if kv_cache_quantization else "" + result_txt += "--linear_causal_mask " if linear_causal_mask else "" - f=open(write_result, "a") + f = open(write_result, "a") f.write(result_txt) f.close() + headers = ["name", "dtype", "device", "arch", "metric", "actual", "target"] + name = checkpoint_path.parent.name + arch = get_arch_name() + dtype = quantization or str(precision) + memory_result = [name, dtype, device, arch, "mem/s", bandwidth, None] + performance_result = [name, dtype, device, arch, "tok/s", tokpersec, None] + if output_json_path: + write_json_result(output_json_path, headers, memory_result) + write_json_result(output_json_path, headers, performance_result) -if __name__ == '__main__': +if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser(description='Your CLI description.') - parser.add_argument('--prefill_size', type=int, default=0, help='Whether to run in ttft mode') - parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.') - parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode') - parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.') - parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.') - parser.add_argument('--batch_size', type=int, default=1, help='Batch size to benchmark with') - parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.') - parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') - parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') - parser.add_argument('-q', '--quantization', type=str, - help=( - 'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-, int4wo--hqq, autoquant, ' - +'autoquant-int4, autoquant-float8, uintx--, uintx---hqq, sparse-marlin, spinquant, ' - +'embed-int8wo, marlin_qqq' - ) + + parser = argparse.ArgumentParser(description="Your CLI description.") + parser.add_argument( + "--prefill_size", type=int, default=0, help="Whether to run in ttft mode" + ) + parser.add_argument( + "--prompt", type=str, default="Hello, my name is", help="Input prompt." ) - parser.add_argument('-s', '--sparsity', type=str, + parser.add_argument( + "--interactive", + action="store_true", + help="Whether to launch in interactive mode", + ) + parser.add_argument("--num_samples", type=int, default=5, help="Number of samples.") + parser.add_argument( + "--max_new_tokens", type=int, default=200, help="Maximum number of new tokens." + ) + parser.add_argument( + "--batch_size", type=int, default=1, help="Batch size to benchmark with" + ) + parser.add_argument("--top_k", type=int, default=200, help="Top-k for sampling.") + parser.add_argument( + "--temperature", type=float, default=0.8, help="Temperature for sampling." + ) + parser.add_argument( + "--checkpoint_path", + type=Path, + default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), + help="Model checkpoint path.", + ) + parser.add_argument( + "-q", + "--quantization", + type=str, help=( - 'Which sparsity techniques to apply: semi-structured' - ) + "Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-, int4wo--hqq, autoquant, " + + "autoquant-int4, autoquant-float8, uintx--, uintx---hqq, sparse-marlin, spinquant, " + + "embed-int8wo, marlin_qqq" + ), + ) + parser.add_argument( + "-s", + "--sparsity", + type=str, + help=("Which sparsity techniques to apply: semi-structured"), + ) + parser.add_argument( + "--kv_cache_quantization", + action="store_true", + help="Whether to quantize the KV cache", + ) + parser.add_argument( + "--cache_size", + type=int, + default=None, + help="Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size", + ) + parser.add_argument( + "--linear_causal_mask", + action="store_true", + help="Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)", + ) + parser.add_argument( + "--save", action="store_true", help="Whether to save the quantized model." + ) + parser.add_argument( + "--compile", action="store_true", help="Whether to compile the model." + ) + parser.add_argument( + "--compile_prefill", + action="store_true", + help="Whether to compile the prefill (improves prefill perf, but higher compile times)", + ) + parser.add_argument("--profile", type=Path, default=None, help="Profile path.") + parser.add_argument( + "--memory_profile", type=Path, default=None, help="filename for memory profile." + ) + parser.add_argument( + "--device", type=str, default=default_device, help="Device to use" + ) + parser.add_argument( + "--precision", + type=lambda x: getattr(torch, x.split(".")[-1]), + default=torch.bfloat16, + help="dtype precision to use", + ) + parser.add_argument( + "--write_result", type=Path, default=None, help="Path where to write the result" + ) + parser.add_argument( + "--output_json_path", + type=Path, + default=None, + help="Path where to write the json result for dashboard", ) - parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache') - parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size') - parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)') - parser.add_argument('--save', action='store_true', help='Whether to save the quantized model.') - parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') - parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)') - parser.add_argument('--profile', type=Path, default=None, help='Profile path.') - parser.add_argument('--memory_profile', type=Path, default=None, help='filename for memory profile.') - parser.add_argument('--device', type=str, default=default_device, help='Device to use') - parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use') - parser.add_argument('--write_result', type=Path, default=None, help='Path where to write the result') args = parser.parse_args() main( - args.prefill_size, args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k, - args.temperature, args.checkpoint_path, args.quantization, args.sparsity, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result + args.prefill_size, + args.prompt, + args.interactive, + args.num_samples, + args.max_new_tokens, + args.batch_size, + args.top_k, + args.temperature, + args.checkpoint_path, + args.quantization, + args.sparsity, + args.kv_cache_quantization, + args.cache_size, + args.linear_causal_mask, + args.save, + args.compile, + args.compile_prefill, + args.profile, + args.memory_profile, + args.device, + args.precision, + args.write_result, + args.output_json_path, ) From cac526145711745782fd16fe6b1471a55d891a60 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 11 Dec 2024 14:04:27 -0800 Subject: [PATCH 30/40] Add exhaustive config option to intmm kernel (#1392) * Add exhaustive config option to intmm kernel Summary: similar to https://github.com/pytorch/pytorch/pull/126220 we added exhaustive option for int8mm and scaled_mm kernels in torchao Note that there seems to be native int8mm and scaled_mm support in pytorch: https://github.com/pytorch/pytorch/blob/0610b9730e27d066e26396a2d655ba0d98c2012d/torch/_inductor/kernel/mm.py#L305 for int8mm and https://github.com/pytorch/pytorch/blob/0610b9730e27d066e26396a2d655ba0d98c2012d/torch/_inductor/kernel/mm_scaled.py#L575 for scaled mm maybe we should use that at some point. Test Plan: ``` cd benchmarks TORCHAO_AUTOTUNER_ENABLE=1 python intmm.py --file_path intmm_shapes.csv TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE=EXHAUSTIVE TORCHAO_AUTOTUNER_ENABLE=1 python intmm.py --file_path intmm_shapes.csv ``` Reviewers: Subscribers: Tasks: Tags: * remove unused * enable all autoquant qtensor * guard float8 qtensor subclass * guard exhaustive config torch version --- torchao/_models/sam/eval_combo.py | 4 ++ torchao/kernel/README.md | 3 + torchao/kernel/intmm.py | 9 +-- torchao/kernel/intmm_triton.py | 65 ++++++++++++------- .../prototype/quantization/autoquant_v2.py | 10 +++ torchao/quantization/__init__.py | 2 + torchao/quantization/autoquant.py | 19 ++++++ 7 files changed, 83 insertions(+), 29 deletions(-) diff --git a/torchao/_models/sam/eval_combo.py b/torchao/_models/sam/eval_combo.py index 9c05d00b26..afc625a47d 100644 --- a/torchao/_models/sam/eval_combo.py +++ b/torchao/_models/sam/eval_combo.py @@ -350,6 +350,8 @@ def mlp_only(mod, name): autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST) elif "autoquant_v2-float8" == compress: autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST) + elif "autoquant_v2-all" == compress: + autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.prototype.quantization.autoquant_v2.ALL_AUTOQUANT_CLASS_LIST) else: autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True) @@ -362,6 +364,8 @@ def mlp_only(mod, name): autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST) elif "autoquant-float8" == compress: autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST) + elif "autoquant-all" == compress: + autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.ALL_AUTOQUANT_CLASS_LIST) else: autoquant(predictor.model.image_encoder, example_input=example_input, manual=True) predictor.model.image_encoder(example_input) diff --git a/torchao/kernel/README.md b/torchao/kernel/README.md index ab97d148f2..903bca5a68 100644 --- a/torchao/kernel/README.md +++ b/torchao/kernel/README.md @@ -6,6 +6,9 @@ Set this to a nonzero value to enable the kernels generated by the autotuner. This is turned off by default, because it is still an experimental feature and also can take a long time to run. +`TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE=EXHAUSTIVE` +Use this to enable exhaustive search for both int8mm and scaled_mm kernels. + Searching a new config can take a long time and we'll save the updated data in `data.pkl`. If you'd like to contributed updated configs for your hardware or shapes, please open a pull request. `TORCHAO_AUTOTUNER_DATA_PATH=torchao/kernel/configs/data_a100.pkl` diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index 81c2550246..afc5bcfa3f 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -10,7 +10,8 @@ from torchao.kernel import intmm_triton else: intmm_triton = None -except ImportError: +except ImportError as e: + print("import error:", e) # On cpu-only builds might not be available. intmm_triton = None @@ -56,7 +57,7 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: and j_is_nonzero_multiple_of_8 and k_is_nonzero_multiple_of_8 ) - + if device_cpu or bad_dimensions_for_cublas: # fallback path return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to( @@ -75,8 +76,8 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: try: return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) except Exception: - # fallback path, would run on H100 for float8 dtypes - # Exception on H100 float8 dtype : "addmm_cuda" not implemented for 'Float8_e4m3fn' + # fallback path, would run on H100 for float8 dtypes + # Exception on H100 float8 dtype : "addmm_cuda" not implemented for 'Float8_e4m3fn' return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32) else: def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: diff --git a/torchao/kernel/intmm_triton.py b/torchao/kernel/intmm_triton.py index 4e84d9cd3c..f6f42e2f53 100644 --- a/torchao/kernel/intmm_triton.py +++ b/torchao/kernel/intmm_triton.py @@ -7,35 +7,50 @@ import triton.language as tl from torchao.kernel.autotuner import get_best_config_fn +from torchao.utils import TORCH_VERSION_AFTER_2_5 -int8_powers_of_two = [32, 64, 128, 256] -int8_mm_kernel_configs = sum( - [ - # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps" +# TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE=EXHAUSTIVE to enable exhaustive option +int8_mm_kernel_configs = ( + sum( [ - (i, j, k, 1, 1), - (i, j, k, 1, 2), - (i, j, k, 2, 2), - (i, j, k, 1, 4), - (i, j, k, 2, 4), - (i, j, k, 3, 4), - (i, j, k, 4, 4), - (i, j, k, 1, 8), - (i, j, k, 2, 8), - (i, j, k, 3, 8), - (i, j, k, 4, 8), - (i, j, k, 5, 8), - (i, j, k, 6, 8), - (i, j, k, 7, 8), - (i, j, k, 8, 8), - ] - for (i, j, k) in itertools.product( - int8_powers_of_two, int8_powers_of_two, int8_powers_of_two - ) - ], - [], + # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps" + [ + (i, j, k, 1, 1), + (i, j, k, 1, 2), + (i, j, k, 2, 2), + (i, j, k, 1, 4), + (i, j, k, 2, 4), + (i, j, k, 3, 4), + (i, j, k, 4, 4), + (i, j, k, 1, 8), + (i, j, k, 2, 8), + (i, j, k, 3, 8), + (i, j, k, 4, 8), + (i, j, k, 5, 8), + (i, j, k, 6, 8), + (i, j, k, 7, 8), + (i, j, k, 8, 8), + ] + for (i, j, k) in itertools.product( + [32, 64, 128, 256], repeat=3 + ) + ], + [] + ) ) +if TORCH_VERSION_AFTER_2_5: + if torch._inductor.config.max_autotune_gemm_search_space == "EXHAUSTIVE": + int8_mm_kernel_configs = [ + (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) + for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product( + [16, 32, 64, 128, 256], repeat=3 + ) + for num_stages in [1, 2, 3, 4, 5, 6, 7, 8] + for num_warps in [2, 4, 8] + ] + + # Baseline configs from pytorch/pytorch # https://github.com/pytorch/pytorch/blob/7718a1cd4f8e0b794c18a31ebd6353d6273c534e/torch/_inductor/kernel/mm_common.py#L132-L147 # int8_mm_kernel_configs = [ diff --git a/torchao/prototype/quantization/autoquant_v2.py b/torchao/prototype/quantization/autoquant_v2.py index bf6dbb2a46..977c1fd288 100644 --- a/torchao/prototype/quantization/autoquant_v2.py +++ b/torchao/prototype/quantization/autoquant_v2.py @@ -31,6 +31,8 @@ TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor, + is_sm_at_least_89, + is_sm_at_least_90, ) from torchao.quantization.granularity import ( @@ -63,6 +65,7 @@ "DEFAULT_INT4_AUTOQUANT_CLASS_LIST", "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST", "OTHER_AUTOQUANT_CLASS_LIST", + "ALL_AUTOQUANT_CLASS_LIST", "_is_linear", ] @@ -1087,6 +1090,13 @@ def get_weight_block_size(x): AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, ] +ALL_AUTOQUANT_CLASS_LIST = list(set(DEFAULT_AUTOQUANT_CLASS_LIST + DEFAULT_INT4_AUTOQUANT_CLASS_LIST + DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST)) +if is_sm_at_least_89(): + ALL_AUTOQUANT_CLASS_LIST += [AQFloat8WeightOnlyQuantizedLinearWeight, AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight] + +if is_sm_at_least_90(): + ALL_AUTOQUANT_CLASS_LIST += [AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight] + def _replace_with_custom_fn_if_matches_filter( model, diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 344bdeea41..8b46d97dc6 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -10,6 +10,7 @@ ) from .autoquant import ( + ALL_AUTOQUANT_CLASS_LIST, DEFAULT_AUTOQUANT_CLASS_LIST, DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, DEFAULT_INT4_AUTOQUANT_CLASS_LIST, @@ -92,6 +93,7 @@ "DEFAULT_INT4_AUTOQUANT_CLASS_LIST", "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST", "OTHER_AUTOQUANT_CLASS_LIST", + "ALL_AUTOQUANT_CLASS_LIST", # top level API - manual "quantize_", "int8_dynamic_activation_int4_weight", diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index b486683290..949d156349 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -26,6 +26,8 @@ TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor, + is_sm_at_least_89, + is_sm_at_least_90, ) from .granularity import ( @@ -45,6 +47,7 @@ "DEFAULT_INT4_AUTOQUANT_CLASS_LIST", "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST", "OTHER_AUTOQUANT_CLASS_LIST", + "ALL_AUTOQUANT_CLASS_LIST", ] @@ -951,6 +954,22 @@ def get_weight_block_size(x): AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, ] +ALL_AUTOQUANT_CLASS_LIST = list( + set( + DEFAULT_AUTOQUANT_CLASS_LIST + + DEFAULT_INT4_AUTOQUANT_CLASS_LIST + + DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST + ) +) +if is_sm_at_least_89(): + ALL_AUTOQUANT_CLASS_LIST += [ + AQFloat8WeightOnlyQuantizedLinearWeight, + AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, + ] + +if is_sm_at_least_90(): + ALL_AUTOQUANT_CLASS_LIST += [AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight] + def _change_linears_to_autoquantizable(model, **kwargs): """ From 63b30cab55a2ce1f759071f5542ee7f2982d9138 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 11 Dec 2024 15:20:23 -0800 Subject: [PATCH 31/40] Fix a bug in LinearActivationQuantizedTensor (#1400) * Fix a bug in LinearActivationQuantizedTensor Summary: quant_kwargs is not populated in some places Test Plan: python test/dtypes/test_affine_quantized_tensor_parallel.py Reviewers: Subscribers: Tasks: Tags: * ruff --- .../test_affine_quantized_tensor_parallel.py | 3 +++ .../linear_activation_quantized_tensor.py | 16 ++++++++++------ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/test/dtypes/test_affine_quantized_tensor_parallel.py b/test/dtypes/test_affine_quantized_tensor_parallel.py index 82d3d2501d..da20b930d3 100644 --- a/test/dtypes/test_affine_quantized_tensor_parallel.py +++ b/test/dtypes/test_affine_quantized_tensor_parallel.py @@ -181,6 +181,9 @@ class TestFloat8dqRowAffineQuantizedTensorParallel( def test_tp(self, dtype): return self._test_tp(dtype) + common_utils.instantiate_parametrized_tests( + TestFloat8woAffineQuantizedTensorParallel + ) common_utils.instantiate_parametrized_tests( TestFloat8dqTensorAffineQuantizedTensorParallel ) diff --git a/torchao/quantization/linear_activation_quantized_tensor.py b/torchao/quantization/linear_activation_quantized_tensor.py index 46b48393a3..e86b2f8e64 100644 --- a/torchao/quantization/linear_activation_quantized_tensor.py +++ b/torchao/quantization/linear_activation_quantized_tensor.py @@ -147,8 +147,8 @@ def _(func, types, args, kwargs): ) input_quant_func = weight_tensor.input_quant_func original_weight_tensor = weight_tensor.original_weight_tensor - aqt = input_quant_func(input_tensor) - return func(bias, aqt, original_weight_tensor) + qtensor = input_quant_func(input_tensor, **weight_tensor.quant_kwargs) + return func(bias, qtensor, original_weight_tensor) else: # aten.mm.default assert args[0].shape[-1] == args[1].shape[0], ( @@ -161,8 +161,8 @@ def _(func, types, args, kwargs): ) input_quant_func = weight_tensor.input_quant_func original_weight_tensor = weight_tensor.original_weight_tensor - aqt = input_quant_func(input_tensor) - return func(aqt, original_weight_tensor) + qtensor = input_quant_func(input_tensor, **weight_tensor.quant_kwargs) + return func(qtensor, original_weight_tensor) @implements(aten.detach.default) @@ -203,7 +203,9 @@ def _(func, types, args, kwargs): args, kwargs, LinearActivationQuantizedTensor( - func(args[0].original_weight_tensor, *args[1:]), args[0].input_quant_func + func(args[0].original_weight_tensor, *args[1:]), + args[0].input_quant_func, + args[0].quant_kwargs, ), ) @@ -216,7 +218,9 @@ def _(func, types, args, kwargs): args, kwargs, LinearActivationQuantizedTensor( - func(args[0].original_weight_tensor, *args[1:]), args[0].input_quant_func + func(args[0].original_weight_tensor, *args[1:]), + args[0].input_quant_func, + args[0].quant_kwargs, ), ) From 039cef4ad546716aa04cd54c461feb173f7fe403 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 11 Dec 2024 15:20:42 -0800 Subject: [PATCH 32/40] Add marlin and semi sparse + quant option to autoquant (#1399) * Add marlin and semi sparse + quant option to autoquant Summary: Added DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST for autoquant (v1) that contains: AQDefaultLinearWeight, AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight (float16 only) and AQInt8DynamicallyQuantizedSemiSparseLinearWeight Test Plan: tested on llama and sam python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress autoquant-sparse +cuda,vit_h,32,10271,12,25.575582921440905,39.099793074967025,0.5424332682384179,max-autotune,torch.bfloat16,autoquant-sparse,False,True,True,32,154,4928,None,None Baseline: around 22/23 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-sparse --precision float16 Average tokens/sec: 160.55 Base: Average tokens/sec: 110.47 Reviewers: Subscribers: Tasks: Tags: * ruff --- torchao/_models/llama/generate.py | 7 +++++ torchao/_models/sam/eval_combo.py | 2 ++ torchao/quantization/__init__.py | 2 ++ torchao/quantization/autoquant.py | 47 +++++++++++++++++++++++++++++-- torchao/quantization/quant_api.py | 3 ++ 5 files changed, 59 insertions(+), 2 deletions(-) diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 7570700c65..8ec6acccc9 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -685,6 +685,13 @@ def ffn_or_attn_only(mod, fqn): qtensor_class_list=torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, example_input=inputs, ) + if "autoquant-sparse" == quantization: + model = autoquant( + model, + manual=True, + qtensor_class_list = torchao.quantization.DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST, + example_input=inputs, + ) if "autoquant-all" == quantization: all_qtensor_classes = ( torchao.quantization.DEFAULT_AUTOQUANT_CLASS_LIST diff --git a/torchao/_models/sam/eval_combo.py b/torchao/_models/sam/eval_combo.py index afc625a47d..09a3448d6a 100644 --- a/torchao/_models/sam/eval_combo.py +++ b/torchao/_models/sam/eval_combo.py @@ -364,6 +364,8 @@ def mlp_only(mod, name): autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST) elif "autoquant-float8" == compress: autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST) + elif "autoquant-sparse" == compress: + autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST) elif "autoquant-all" == compress: autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.ALL_AUTOQUANT_CLASS_LIST) else: diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 8b46d97dc6..14dfbab52b 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -14,6 +14,7 @@ DEFAULT_AUTOQUANT_CLASS_LIST, DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, DEFAULT_INT4_AUTOQUANT_CLASS_LIST, + DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST, OTHER_AUTOQUANT_CLASS_LIST, autoquant, ) @@ -92,6 +93,7 @@ "DEFAULT_AUTOQUANT_CLASS_LIST", "DEFAULT_INT4_AUTOQUANT_CLASS_LIST", "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST", + "DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST", "OTHER_AUTOQUANT_CLASS_LIST", "ALL_AUTOQUANT_CLASS_LIST", # top level API - manual diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 949d156349..b8cd0125f0 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -6,9 +6,12 @@ from torchao.dtypes import ( AffineQuantizedTensor, Float8Layout, + MarlinSparseLayout, PlainLayout, + SemiSparseLayout, TensorCoreTiledLayout, ) +from torchao.dtypes.utils import Layout from torchao.float8.inference import Float8MMConfig from torchao.kernel import safe_int_mm from torchao.quantization.linear_activation_quantized_tensor import ( @@ -46,6 +49,7 @@ "DEFAULT_AUTOQUANT_CLASS_LIST", "DEFAULT_INT4_AUTOQUANT_CLASS_LIST", "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST", + "DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST", "OTHER_AUTOQUANT_CLASS_LIST", "ALL_AUTOQUANT_CLASS_LIST", ] @@ -406,6 +410,8 @@ class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedT AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight """ + layout: Layout = PlainLayout() + @classmethod def from_float(cls, weight): # TODO test if this is valid @@ -414,6 +420,9 @@ def from_float(cls, weight): # if in_features <= 16: # return weight + if weight.dim() != 2: + return weight + # avoid circular dep from torchao.dtypes import to_affine_quantized_intx @@ -439,7 +448,7 @@ def get_per_token_block_size(x): input_eps = 1e-5 input_quant_min = -127 input_quant_max = 127 - _layout = PlainLayout() + _layout = cls.layout input_quant_func = lambda x: to_affine_quantized_intx( x, input_mapping_type, @@ -526,6 +535,16 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): return res_f +class AQInt8DynamicallyQuantizedSemiSparseLinearWeight( + AQInt8DynamicallyQuantizedLinearWeight +): + layout: Layout = SemiSparseLayout() + + @classmethod + def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): + return super()._autoquant_test(act_mat, weight, bias, best_time, None) + + class AQInt8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): """ AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight @@ -613,14 +632,16 @@ class AQInt4G32WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): """ group_size: int = 32 + layout: Layout = TensorCoreTiledLayout(inner_k_tiles=8) @classmethod def from_float(cls, weight): group_size = cls.group_size - _layout = TensorCoreTiledLayout(inner_k_tiles=8) + _layout = cls.layout if weight.shape[-1] % group_size != 0: return weight + use_hqq = True mapping_type = MappingType.ASYMMETRIC block_size = (1, group_size) @@ -631,6 +652,13 @@ def from_float(cls, weight): preserve_zero = False zero_point_dtype = torch.bfloat16 zero_point_domain = ZeroPointDomain.FLOAT + + if isinstance(_layout, MarlinSparseLayout): + mapping_type = MappingType.SYMMETRIC + preserve_zero = True + zero_point_domain = ZeroPointDomain.INT + use_hqq = False + return super(AQInt4G32WeightOnlyQuantizedLinearWeight, cls).from_hp_to_intx( weight, mapping_type, @@ -665,6 +693,13 @@ class AQInt4G256WeightOnlyQuantizedLinearWeight( group_size: int = 256 +class AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight( + AQInt4G32WeightOnlyQuantizedLinearWeight +): + group_size: int = 128 + layout: Layout = MarlinSparseLayout() + + class AQDefaultLinearWeight(torch.Tensor, AQMixin): """ A class to be used in concert with AutoQuantizableLinearWeight to provide a @@ -949,16 +984,24 @@ def get_weight_block_size(x): ] OTHER_AUTOQUANT_CLASS_LIST = [ + AQDefaultLinearWeight, AQFloat8WeightOnlyQuantizedLinearWeight, AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, ] +DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST = [ + AQDefaultLinearWeight, + AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight, + AQInt8DynamicallyQuantizedSemiSparseLinearWeight, +] + ALL_AUTOQUANT_CLASS_LIST = list( set( DEFAULT_AUTOQUANT_CLASS_LIST + DEFAULT_INT4_AUTOQUANT_CLASS_LIST + DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST + + DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST ) ) if is_sm_at_least_89(): diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 96ccb1889c..99da86b87b 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -676,6 +676,9 @@ def apply_int4_weight_only_quant(weight): mapping_type = MappingType.SYMMETRIC preserve_zero = True zero_point_domain = ZeroPointDomain.INT + assert ( + group_size == 128 or group_size == weight.shape[-1] + ), f"MarlinSparseLayout only supports 128 group size or per channel quantization, got {group_size}" return to_affine_quantized_intx( weight, From 19b3bb5b95fa823136fdafe075960dea3e39cbd0 Mon Sep 17 00:00:00 2001 From: cpuhrsch Date: Wed, 11 Dec 2024 22:25:13 -0800 Subject: [PATCH 33/40] SAM2 video batching (#1395) --- examples/sam2_amg_server/video_profile.py | 99 ++- torchao/_models/sam2/map_tensor.py | 617 ++++++++++++++++++ .../sam2/modeling/backbones/image_encoder.py | 9 +- torchao/_models/sam2/sam2_video_predictor.py | 30 +- 4 files changed, 711 insertions(+), 44 deletions(-) create mode 100644 torchao/_models/sam2/map_tensor.py diff --git a/examples/sam2_amg_server/video_profile.py b/examples/sam2_amg_server/video_profile.py index 400e879a0e..e7874879d9 100644 --- a/examples/sam2_amg_server/video_profile.py +++ b/examples/sam2_amg_server/video_profile.py @@ -6,7 +6,6 @@ import numpy as np import torch from PIL import Image, ImageDraw -from torchao._models.sam2.build_sam import build_sam2_video_predictor from server import MODEL_TYPES_TO_MODEL from server import model_type_to_paths from pathlib import Path @@ -92,9 +91,9 @@ def synthesize_video_data( vy = np.random.choice([-1, 1]) * speed # TODO: If these frames exist, they will not be deleted in subsequent runs with less frames. - print(f"Generate {n_frames} frames") + print(f"Generate {n_frames} frames under path {out_dir}") if not synthesize_overwrite and len(os.listdir(out_dir)) > 0: - raise ValueError("Expected folder to be empty unless --synthesize-overwrite is specified.") + raise ValueError(f"Expected folder {out_dir} to be empty unless --synthesize-overwrite is specified.") # Generate 100 frames for i in range(n_frames): # Create a new image with a black background @@ -139,15 +138,14 @@ def profiler_runner(path, fn, *args, **kwargs): def main_loop(predictor, inference_state, time_profile=True, accumulate_result=False, count_result=False): results = [] num_output_frames = 0 - with sdpa_kernel([SDPBackend.CUDNN_ATTENTION, SDPBackend.FLASH_ATTENTION]): - with torch.autograd.profiler.record_function("main_loop"): - for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video( - inference_state - ): - if accumulate_result: - results.append(out_mask_logits) - if count_result: - num_output_frames += 1 + with torch.autograd.profiler.record_function("main_loop"): + for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video( + inference_state + ): + if accumulate_result: + results.append(out_mask_logits) + if count_result: + num_output_frames += 1 assert not (accumulate_result and count_result) if accumulate_result: return torch.cat(results) @@ -168,28 +166,31 @@ def run_test( n_frames: int, use_compile: bool, frame_batch_size: int, + batch_size: int, synthesize: bool, synthesize_overwrite: bool, store_output: str, compare_output: str, print_all_timings: bool, + use_baseline: bool, ): np.random.seed(seed) start_x = np.random.randint(radius, width - radius) start_y = np.random.randint(radius, height - radius) if synthesize: - synthesize_video_data( - out_dir=video_dir, - radius=radius, - seed=seed, - speed=speed, - width=width, - height=height, - n_frames=n_frames, - x=start_x, - y=start_y, - synthesize_overwrite=synthesize_overwrite, - ) + for i in range(batch_size): + synthesize_video_data( + out_dir=f"{video_dir}_{i}", + radius=radius, + seed=(seed + i), # Make sure every video is different + speed=speed, + width=width, + height=height, + n_frames=n_frames, + x=start_x, + y=start_y, + synthesize_overwrite=synthesize_overwrite, + ) # use bfloat16 for the entire notebook torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() @@ -198,6 +199,12 @@ def run_test( sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type) + build_sam2_video_predictor = None + if use_baseline: + from sam2.build_sam import build_sam2_video_predictor + else: + from torchao._models.sam2.build_sam import build_sam2_video_predictor + device = "cuda:0" # hydra_overrides_extra = ["++model.compile_image_encoder=true"] predictor = build_sam2_video_predictor( @@ -208,16 +215,24 @@ def run_test( ) predictor._frame_batch_size = frame_batch_size - inference_state = predictor.init_state( - video_path=video_dir, async_loading_frames=False - ) - _, out_obj_ids, out_mask_logits = predictor.add_new_points( - inference_state=inference_state, - frame_idx=0, - obj_id=1, - points=np.array([[start_x, start_y]], dtype=np.float32), - labels=np.array([1], dtype=np.int32), - ) + inference_states = [] + for i in range(batch_size): + print("i: ", i) + inference_state = predictor.init_state( + video_path=f"{video_dir}_{i}", async_loading_frames=False + ) + _, out_obj_ids, out_mask_logits = predictor.add_new_points( + inference_state=inference_state, + frame_idx=0, + obj_id=1, + points=np.array([[start_x, start_y]], dtype=np.float32), + labels=np.array([1], dtype=np.int32), + ) + inference_states.append(inference_state) + if batch_size == 1: + inference_state = inference_states[0] + else: + inference_state = predictor.batch_inference_states(inference_states) if use_compile: print("Using torch.compile") @@ -368,7 +383,13 @@ def run_test( help="Use torch.compile to speed things up. First iteration will be much slower.", ) parser.add_argument( - "--frame_batch_size", + "--batch-size", + type=int, + default=1, + help="batch_size", + ) + parser.add_argument( + "--frame-batch-size", type=int, default=1, help="frame_batch_size", @@ -403,6 +424,12 @@ def run_test( dest="print_all_timings", help="Use torch.compile to speed things up. First iteration will be much slower.", ) + parser.add_argument( + "--use-baseline", + action="store_true", + dest="use_baseline", + help="Use sam2 package instead of torchao._models.sam2", + ) args = parser.parse_args() @@ -419,9 +446,11 @@ def run_test( n_frames=args.n_frames, use_compile=args.use_compile, frame_batch_size=args.frame_batch_size, + batch_size=args.batch_size, synthesize=args.synthesize, synthesize_overwrite=args.synthesize_overwrite, store_output=args.store_output, compare_output=args.compare_output, print_all_timings=args.print_all_timings, + use_baseline=args.use_baseline, ) diff --git a/torchao/_models/sam2/map_tensor.py b/torchao/_models/sam2/map_tensor.py new file mode 100644 index 0000000000..a32424d99b --- /dev/null +++ b/torchao/_models/sam2/map_tensor.py @@ -0,0 +1,617 @@ +import contextlib +import torch +from torch.utils._pytree import tree_map +from typing import Dict +from torch.nested._internal.nested_tensor import nested_view_from_values_offsets +import functools + +MAP_TENSOR_ATEN_OP_TABLE = {} + + +def implements(aten_ops_or_torch_fns): + if not isinstance(aten_ops_or_torch_fns, (list, tuple)): + aten_ops_or_torch_fns = [aten_ops_or_torch_fns] + + def decorator(func): + for op in aten_ops_or_torch_fns: + + @functools.wraps(op) + def wrapper(f, types, args, kwargs): + return func(f, types, args, kwargs) + + MAP_TENSOR_ATEN_OP_TABLE[op] = wrapper + return func + + return decorator + + +@contextlib.contextmanager +def no_dispatch(): + guard = torch._C._DisableTorchDispatch() + try: + yield + finally: + del guard + + +def wrap_dim(i, dim): + if i < 0: + return dim + i + return i + + +def unwrap(t): + if isinstance(t, MapTensor): + with no_dispatch(): + return t.elems + else: + return t + + +def unwrap_i(t, i): + if isinstance(t, MapTensor): + with no_dispatch(): + return t.elems[i] + else: + return t + + +def unwrap_fn(t, fn): + if isinstance(t, MapTensor): + with no_dispatch(): + return fn(t.elems) + else: + return None + + +def wrap(t): + if isinstance(t, torch.Tensor): + return MapTensor(t) + else: + return t + + +@implements(torch.ops.aten.native_layer_norm.default) +def layer_norm_impl(func, types, args, kwargs=None): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 5, f"args: {unwrapped_args}" + norm_res = func(*unwrapped_args) + assert len(norm_res) == 3 + return tuple(wrap(a) for a in norm_res) + + +@implements(torch.ops.aten.add.Tensor) +def add_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 2, f"args: {unwrapped_args}" + if not isinstance(args[0], MapTensor) and isinstance(args[1], MapTensor): + if args[0].dim() == (args[1].dim() + 1): + return NotImplemented + return NotImplemented + return wrap(func(*unwrapped_args, **unwrapped_kwargs)) + + +@implements([torch.ops.aten.cat.default, + torch.ops.aten.stack.default]) +def cat_ops_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) <= 2, f"args: {unwrapped_args}" + # TODO: Use MapTensor type for filter + # First argument's dim + dim = unwrapped_args[0][0].dim() + size = unwrapped_args[0][0].size() + for a in unwrapped_args[0]: + if a.dim() > dim: + dim = a.dim() + size = a.size() + new_args = [] + for a in unwrapped_args[0]: + if a.dim() == dim: + new_args.append(a) + else: + assert a.dim() + 1 == dim + new_args.append(a.unsqueeze(0).expand((size[0],) + a.size())) + orig_dim = unwrapped_args[1] if len(unwrapped_args) == 2 else 0 + return wrap(func(new_args, wrap_dim(orig_dim, dim - 1) + 1)) + + +@implements(torch.ops.aten.select.int) +def select_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 3, f"args: {unwrapped_args}" + return wrap(func(unwrapped_args[0], unwrapped_args[1] + 1, unwrapped_args[2])) + + +@implements(torch.ops.aten.slice.Tensor) +def slice_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 4, f"args: {unwrapped_args}" + dim = unwrapped_args[0].dim() + return wrap(func(unwrapped_args[0], + wrap_dim(unwrapped_args[1], dim - 1) + 1, + unwrapped_args[2], + unwrapped_args[3])) + + +@implements([torch.ops.aten.mean.dim, + torch.ops.aten.max.dim, + torch.ops.aten.argmax.default, + torch.ops.aten.min.dim, + torch.ops.aten.any.dim, + torch.ops.aten.amax.default, + torch.ops.aten.amin.default, + torch.ops.aten.all.default, + torch.ops.aten.sum.dim_IntList]) +def reductions_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + # TODO: THIS MIGHT BE WRONG + if len(unwrapped_args) == 3 and len(unwrapped_kwargs) == 0: + assert len(unwrapped_args[1]) == 1 + dim = unwrapped_args[0].dim() + return wrap(func(unwrapped_args[0], + [wrap_dim(u, dim - 1) + 1 for u in unwrapped_args[1]], + unwrapped_args[2])) + if len(unwrapped_args) == 2 and len(unwrapped_kwargs) == 1: + assert len(unwrapped_args[1]) == 1 + dim = unwrapped_args[0].dim() + return wrap(func(unwrapped_args[0], + [wrap_dim(u, dim - 1) + 1 for u in unwrapped_args[1]], + **unwrapped_kwargs)) + if len(unwrapped_args) == 2 and len(unwrapped_kwargs) == 0 and type(unwrapped_args[1]) == list: + assert len(unwrapped_args[1]) == 1 + dim = unwrapped_args[0].dim() + return wrap(func(unwrapped_args[0], + [wrap_dim(u, dim - 1) + 1 for u in unwrapped_args[1]])) + if len(unwrapped_args) == 2 and len(unwrapped_kwargs) == 0 and type(unwrapped_args[1]) == int: + dim = unwrapped_args[0].dim() + return wrap(func(unwrapped_args[0], wrap_dim(unwrapped_args[1], dim - 1) + 1)) + if len(args) == 1 and len(kwargs) == 0: + return wrap(func(unwrapped_args[0])) + return NotImplemented + + +@implements([torch.ops.aten._unsafe_view.default, + torch.ops.aten.expand.default]) +def view_ops_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 2, f"args: {unwrapped_args}" + input_size = unwrapped_args[0].size() + bigger_size = list(input_size[:1]) + unwrapped_args[1] + return wrap(func(unwrapped_args[0], bigger_size)) + + +@implements(torch.ops.aten.view.default) +def view_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 2, f"args: {unwrapped_args}" + input_size = unwrapped_args[0].size() + bigger_size = list(input_size[:1]) + unwrapped_args[1] + if unwrapped_args[0].size() == tuple(bigger_size): + return wrap(args[0].elems) + return wrap(unwrapped_args[0].reshape(bigger_size)) + + +@implements([torch.ops.aten.mm.default, + torch.ops.aten.bmm.default]) +def mm_ops_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 2, f"args: {unwrapped_args}" + return wrap(torch.matmul(*unwrapped_args)) + + +@implements(torch.ops.aten.unsqueeze.default) +def unsqueeze_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 2, f"args: {unwrapped_args}" + new_i = unwrapped_args[1] + if new_i >= 0: + new_i += 1 + return wrap(func(unwrapped_args[0], new_i)) + + +@implements(torch.ops.aten.addmm.default) +def addmm_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 3, f"args: {unwrapped_args}" + return wrap(torch.matmul(unwrapped_args[1], unwrapped_args[2]) + unwrapped_args[0]) + + +@implements(torch.ops.aten.convolution.default) +def convolution_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 9, f"args: {unwrapped_args}" + a = unwrapped_args[0] + a = unwrapped_args[0].flatten(0, 1) + # TODO: It's scary that this .contiguous seems necessary, but I we're below composite conv + # which might expected contiguous output + resa = func(*((a,) + unwrapped_args[1:])).contiguous() + resb = resa.view((unwrapped_args[0].size(0), unwrapped_args[0].size(1)) + resa.size()[1:]) + return wrap(resb) + + +@implements(torch.ops.aten.upsample_bilinear2d.default) +def upsample_bilinear2d_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 3, f"args: {unwrapped_args}" + a = unwrapped_args[0] + a = unwrapped_args[0].flatten(0, 1) + # NOTE: It's scary that this .contiguous seems necessary, but we're below composite upsample + # which might expected contiguous output + resa = func(*((a,) + unwrapped_args[1:])).contiguous() + resb = resa.view((unwrapped_args[0].size(0), unwrapped_args[0].size(1)) + resa.size()[1:]) + return wrap(resb) + + +@implements(torch.ops.aten.transpose.int) +def transpose_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 3, f"args: {unwrapped_args}" + dim = unwrapped_args[0].dim() + return wrap(func(unwrapped_args[0], + wrap_dim(unwrapped_args[1], dim - 1) + 1, + wrap_dim(unwrapped_args[2], dim - 1) + 1)) + + +@implements(torch.ops.aten.unbind.int) +def unbind_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 2, f"args: {unwrapped_args}" + dim = unwrapped_args[0].dim() + return wrap(func(unwrapped_args[0], + wrap_dim(unwrapped_args[1], dim - 1) + 1)) + + +@implements(torch.ops.aten.permute.default) +def permute_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 2, f"args: {unwrapped_args}" + dim = unwrapped_args[0].dim() + return wrap(func(unwrapped_args[0], + ([0] + [wrap_dim(u, dim - 1) + 1 for u in unwrapped_args[1]]))) + + +@implements(torch.ops.aten._scaled_dot_product_efficient_attention.default) +def _scaled_dot_product_efficient_attention_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(args) == 5 + if all(isinstance(a, MapTensor) for a in args[:3]): + # assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 5, f"args: {unwrapped_args}" + assert unwrapped_args[0].dim() == 5 + assert unwrapped_args[1].dim() == 5 + assert unwrapped_args[2].dim() == 5 + sdpa_res = wrap(func(unwrapped_args[0].flatten(0, 1), + unwrapped_args[1].flatten(0, 1), + unwrapped_args[2].flatten(0, 1), + unwrapped_args[3], + unwrapped_args[4], **unwrapped_kwargs)) + return (wrap(sdpa_res[0].view(unwrapped_args[0].size())),) + sdpa_res[1:] + if isinstance(args[0], MapTensor) and not any(isinstance(a, MapTensor) for a in args[1:]): + # assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 5, f"args: {unwrapped_args}" + assert unwrapped_args[0].dim() == 5 + assert unwrapped_args[1].dim() == 4 + assert unwrapped_args[2].dim() == 4 + a0 = unwrapped_args[0] + a1_size = unwrapped_args[1].size() + a1 = unwrapped_args[1].unsqueeze(0).expand((a0.size(0),) + a1_size) + a2 = unwrapped_args[2].unsqueeze(0).expand((a0.size(0),) + a1_size) + sdpa_res = wrap(func(a0.flatten(0, 1), + a1.flatten(0, 1), + a2.flatten(0, 1), + unwrapped_args[3], + unwrapped_args[4], **unwrapped_kwargs)) + return (wrap(sdpa_res[0].view(unwrapped_args[0].size())),) + sdpa_res[1:] + if ((not isinstance(args[0], MapTensor)) and isinstance(args[1], MapTensor) and (not isinstance(args[2], MapTensor))): + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 5, f"args: {unwrapped_args}" + assert unwrapped_args[0].dim() == 4 + assert unwrapped_args[1].dim() == 5 + assert unwrapped_args[2].dim() == 4 + a1_size = unwrapped_args[1].size() + a0 = unwrapped_args[0].unsqueeze(0).expand((a1_size[0],) + unwrapped_args[0].size()[1:]) + a2 = unwrapped_args[2].unsqueeze(0).expand((a1_size[0],) + unwrapped_args[2].size()[1:]) + sdpa_res = wrap(func(a0.flatten(0, 1), + a1.flatten(0, 1), + a2.flatten(0, 1), + unwrapped_args[3], + unwrapped_args[4])) + return (wrap(sdpa_res[0].view(unwrapped_args[0].size())),) + sdpa_res[1:] + if ((not isinstance(args[0], MapTensor)) and isinstance(args[1], MapTensor) and isinstance(args[2], MapTensor)): + # assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 5, f"args: {unwrapped_args}" + assert unwrapped_args[0].dim() == 4 + assert unwrapped_args[1].dim() == 5 + assert unwrapped_args[2].dim() == 5 + a0_size = unwrapped_args[0].size() + a1_size = unwrapped_args[1].size() + a0 = unwrapped_args[0].unsqueeze(0).expand((a1_size[0],) + a0_size) + a1 = unwrapped_args[1] + a2 = unwrapped_args[2] + sdpa_res = wrap(func(a0.flatten(0, 1), + a1.flatten(0, 1), + a2.flatten(0, 1), + unwrapped_args[3], + unwrapped_args[4], **unwrapped_kwargs)) + return (wrap(sdpa_res[0].view((a1_size[0],) + a0_size)),) + sdpa_res[1:] + return NotImplemented + + +@implements(torch.ops.aten._scaled_dot_product_flash_attention.default) +def _scaled_dot_product_flash_attention_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(args) == 3 + assert len(unwrapped_kwargs) == 1 + assert len(unwrapped_args) == 3, f"args: {unwrapped_args}" + if all(isinstance(a, MapTensor) for a in args[:3]): + assert unwrapped_args[0].dim() == 5 + assert unwrapped_args[1].dim() == 5 + assert unwrapped_args[2].dim() == 5 + sdpa_res = wrap(func(unwrapped_args[0].flatten(0, 1), + unwrapped_args[1].flatten(0, 1), + unwrapped_args[2].flatten(0, 1), + **unwrapped_kwargs)) + return (wrap(sdpa_res[0].view(unwrapped_args[0].size())),) + sdpa_res[1:] + if isinstance(args[0], MapTensor) and not any(isinstance(a, MapTensor) for a in args[1:]): + assert unwrapped_args[0].dim() == 5 + assert unwrapped_args[1].dim() == 4 + assert unwrapped_args[2].dim() == 4 + a0 = unwrapped_args[0] + a1_size = unwrapped_args[1].size() + a1 = unwrapped_args[1].unsqueeze(0).expand((a0.size(0),) + a1_size) + a2 = unwrapped_args[2].unsqueeze(0).expand((a0.size(0),) + a1_size) + sdpa_res = wrap(func(a0.flatten(0, 1), + a1.flatten(0, 1), + a2.flatten(0, 1), + **unwrapped_kwargs)) + return (wrap(sdpa_res[0].view(unwrapped_args[0].size())),) + sdpa_res[1:] + if ((not isinstance(args[0], MapTensor)) and isinstance(args[1], MapTensor) and (not isinstance(args[2], MapTensor))): + assert unwrapped_args[0].dim() == 4 + assert unwrapped_args[1].dim() == 5 + assert unwrapped_args[2].dim() == 4 + a1_size = unwrapped_args[1].size() + a0 = unwrapped_args[0].unsqueeze(0).expand((a1_size[0],) + unwrapped_args[0].size()[1:]) + a2 = unwrapped_args[2].unsqueeze(0).expand((a1_size[0],) + unwrapped_args[2].size()[1:]) + sdpa_res = wrap(func(a0.flatten(0, 1), + a1.flatten(0, 1), + a2.flatten(0, 1), + **unwrapped_kwargs)) + return (wrap(sdpa_res[0].view(unwrapped_args[0].size())),) + sdpa_res[1:] + if ((not isinstance(args[0], MapTensor)) and isinstance(args[1], MapTensor) and isinstance(args[2], MapTensor)): + assert unwrapped_args[0].dim() == 4 + assert unwrapped_args[1].dim() == 5 + assert unwrapped_args[2].dim() == 5 + a0_size = unwrapped_args[0].size() + a1_size = unwrapped_args[1].size() + a0 = unwrapped_args[0].unsqueeze(0).expand((a1_size[0],) + a0_size) + a1 = unwrapped_args[1] + a2 = unwrapped_args[2] + sdpa_res = wrap(func(a0.flatten(0, 1), + a1.flatten(0, 1), + a2.flatten(0, 1), + **unwrapped_kwargs)) + return (wrap(sdpa_res[0].view((a1_size[0],) + a0_size)),) + sdpa_res[1:] + return NotImplemented + + +# torch.ops.aten._unsafe_index.Tensor is only needed by inductor for compile +@implements([torch.ops.aten._unsafe_index.Tensor, + torch.ops.aten.index.Tensor]) +def index_ops_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 2, f"args: {unwrapped_args}" + # if len(args[1]) == 1 and isinstance(args[1][0], MapTensor) and isinstance(args[0], MapTensor): + # return wrap(func(*unwrapped_args)) + if len(args[1]) == 1 and isinstance(args[1][0], MapTensor) and not isinstance(args[0], MapTensor): + tensors = [func(args[0], [args[1][0].elems[i]]) for i in range(len(args[1][0].elems))] + values = torch.cat(tensors) + lengths = torch.tensor([0] + [t.size(0) for t in tensors], pin_memory=True).to(values.device, non_blocking=True) + offsets = torch.cumsum(lengths, dim=0) + nt = nested_view_from_values_offsets(values, offsets) + assert nt.layout == torch.jagged + return wrap(nt) + if isinstance(args[0], MapTensor) and not isinstance(args[1][0], MapTensor) and len(args[1]) == 1: + return wrap(func(args[0].elems, [args[1][0].unsqueeze(0)])) + if isinstance(args[0], MapTensor) and not isinstance(args[1][0], MapTensor) and isinstance(args[1][1], MapTensor)and len(args[1]) == 2: + res = [] + for a0, a11 in zip(args[0].elems.unbind(), args[1][1].elems.unbind()): + res.append(func(a0, [args[1][0], a11])) + return wrap(torch.stack(res)) + if isinstance(args[0], MapTensor) and isinstance(args[1][0], MapTensor) and len(args[1]) == 1: + tensors = [func(args[0].elems[i], [args[1][0].elems[i]]) for i in range(len(args[0].elems))] + values = torch.cat(tensors) + lengths = torch.tensor([0] + [t.size(0) for t in tensors], pin_memory=True).to(values.device, non_blocking=True) + offsets = torch.cumsum(lengths, dim=0) + nt = nested_view_from_values_offsets(values, offsets) + assert nt.layout == torch.jagged + return wrap(nt) + a = unwrapped_args[0] + a = unwrapped_args[0].flatten(0, 1) + resa = func(a, args[1]) + resb = resa.view((unwrapped_args[0].size(0), unwrapped_args[0].size(1)) + resa.size()[1:]) + return wrap(resb) + + +# Prims +@implements(torch.ops.aten.dim.default) +def dim_impl(func, types, args, kwargs): + assert len(args) == 1 + assert len(kwargs) == 0 + ret_dim = func(args[0].elems) - 1 + assert ret_dim >= 0 + return ret_dim + + +@implements(torch.ops.aten.sym_size.default) +def sym_impl(func, types, args, kwargs): + assert len(args) == 1 + assert len(kwargs) == 0 + elems_size = func(args[0].elems) + assert len(elems_size) > 0 + return elems_size[1:] + + +@implements(torch.ops.aten.is_contiguous.default) +def is_contiguous_impl(func, types, args, kwargs): + assert len(args) == 1 + assert len(kwargs) == 0 + return func(args[0].elems) + + +@implements([torch.ops.aten.clamp.default, + torch.ops.aten.clone.default, + torch.ops.aten.cos.default, + torch.ops.aten.div.Tensor, + torch.ops.aten.eq.Scalar, + torch.ops.aten.gelu.default, + torch.ops.aten.mul.Tensor, + torch.ops.aten.pow.Tensor_Scalar, + torch.ops.aten.relu.default, + torch.ops.aten.sigmoid.default, + torch.ops.aten.sin.default, + torch.ops.aten.sqrt.default, + torch.ops.aten.sub.Tensor, + torch.ops.aten.unbind.int, + torch.ops.aten.where.self, + torch.ops.aten.zeros_like.default, + torch.ops.aten._to_copy.default, + torch.ops.aten.gt.Scalar, + torch.ops.aten.ge.Scalar, + torch.ops.aten.bitwise_not.default, + torch.ops.aten.lt.Tensor, + torch.ops.aten.bitwise_or.Tensor, + torch.ops.aten.eq.Tensor, + torch.ops.aten.abs.default, + torch.ops.aten.ne.Scalar, + torch.ops.aten.le.Tensor, + torch.ops.aten.view_as_complex.default, + torch.ops.aten.view_as_real.default, + torch.ops.aten.neg.default, + torch.ops.aten.le.Scalar, + torch.ops.aten.rsub.Scalar, + # Sketchy new in place ops + torch.ops.aten.bitwise_and_.Tensor, + torch.ops.aten.bitwise_or_.Tensor, + torch.ops.aten.le.Tensor, + torch.ops.aten.logical_and.default, + # in place ops + torch.ops.aten.add_.Tensor, + torch.ops.aten.copy_.default, + # Prims + torch.ops.prim.layout.default]) +def forwardables_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + return wrap(func(*unwrapped_args, **unwrapped_kwargs)) + + +def run_invariant_test(res, func, args, kwargs): + # Compares 0th element of list of results with + # func applied to 0th arg and kwarg. + # Rough test to maintain per-op accuracy. + if isinstance(res, torch.Tensor): + unwrapped_args_0 = tree_map(lambda x: unwrap_i(x, 0), args) + unwrapped_kwargs_0 = tree_map(lambda x: unwrap_i(x, 0), kwargs) + if func == torch.ops.aten.view.default: + res_0 = torch.ops.aten.reshape.default(*unwrapped_args_0, **unwrapped_kwargs_0) + else: + res_0 = func(*unwrapped_args_0, **unwrapped_kwargs_0) + if res.elems[0].size() != res_0.size(): + import pdb; pdb.set_trace() + if not torch.allclose(res.elems[0], res_0, atol=1e-3, rtol=1e-3): + import pdb; pdb.set_trace() + else: + pass + # print("res got type: ", type(res)) + # import pdb; pdb.set_trace() + return res + + +class MapTensor(torch.Tensor): + @staticmethod + def __new__(cls, elems): + # print("elems.layout: ", elems.layout) + return torch.Tensor._make_wrapper_subclass(cls, + elems.shape[1:], + dtype=elems.dtype, + device=elems.device, + layout=elems.layout, + dispatch_layout=True, + dispatch_sizes_strides_policy=("sizes" if elems.layout == torch.jagged else None), + storage_size=(elems._values.untyped_storage().size() if elems.layout == torch.jagged else None)) + + def __init__(self, elems): + self.elems = elems + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + if func in MAP_TENSOR_ATEN_OP_TABLE: + res = MAP_TENSOR_ATEN_OP_TABLE[func](func, types, args, kwargs) + # run_invariant_test(res, func, args, kwargs) + return res + return NotImplemented + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + + # flatten/unflatten is needed for compile + def __tensor_flatten__(self): + ctx = {} + inner_tensors = ["elems"] + return inner_tensors, ctx + + @staticmethod + def __tensor_unflatten__(inner_tensors: Dict, meta, outer_size, outer_stride): + from torch._subclasses.fake_tensor import FakeTensor + + # inner tensors: _values, _offsets, [_lengths], [_min_seqlen], [_max_seqlen] + assert len(inner_tensors) == 1, f"{inner_tensors}" + elems = inner_tensors["elems"] + + return MapTensor(elems) + + def __repr__(self): + return f"MapTensor({self.elems.size()})" + +# ts is a higher dim Tensor +def to_map_tensor(ts: torch.Tensor): + return MapTensor(ts) diff --git a/torchao/_models/sam2/modeling/backbones/image_encoder.py b/torchao/_models/sam2/modeling/backbones/image_encoder.py index 3f3a938857..7225316bc7 100644 --- a/torchao/_models/sam2/modeling/backbones/image_encoder.py +++ b/torchao/_models/sam2/modeling/backbones/image_encoder.py @@ -29,7 +29,14 @@ def __init__( def forward(self, sample: torch.Tensor): # Forward through backbone with torch.autograd.profiler.record_function("self.neck(self.trunk(sample))"): - features, pos = self.neck(self.trunk(sample)) + from torchao._models.sam2.map_tensor import MapTensor + from torchao._models.sam2.map_tensor import to_map_tensor + if isinstance(sample, MapTensor): + features, pos = self.neck(self.trunk(sample.elems.flatten(0, 1))) + features = [to_map_tensor(t.unsqueeze(1)) for t in features] + pos = [to_map_tensor(t.unsqueeze(1)) for t in pos] + else: + features, pos = self.neck(self.trunk(sample)) if self.scalp > 0: # Discard the lowest resolution features features, pos = features[: -self.scalp], pos[: -self.scalp] diff --git a/torchao/_models/sam2/sam2_video_predictor.py b/torchao/_models/sam2/sam2_video_predictor.py index cbd69005e4..46ab610556 100644 --- a/torchao/_models/sam2/sam2_video_predictor.py +++ b/torchao/_models/sam2/sam2_video_predictor.py @@ -40,7 +40,21 @@ def __init__( self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond - @torch.inference_mode() + @staticmethod + def batch_inference_states(inference_states: list): + assert all(dict == type(state) for state in inference_states) + num_states = len(inference_states) + assert num_states > 0 + import copy + batched_inference_state = copy.copy(inference_states[0]) + + from torchao._models.sam2.map_tensor import to_map_tensor + # NOTE: Making a build assumption only images differ + all_images = torch.stack([state["images"] for state in inference_states]) + batched_inference_state["images"] = to_map_tensor(all_images) + return batched_inference_state + + @torch.no_grad() def init_state( self, video_path, @@ -169,7 +183,7 @@ def _get_obj_num(self, inference_state): """Get the total number of unique object ids received so far in this session.""" return len(inference_state["obj_idx_to_id"]) - @torch.inference_mode() + @torch.no_grad() def add_new_points_or_box( self, inference_state, @@ -317,7 +331,7 @@ def add_new_points(self, *args, **kwargs): """Deprecated method. Please use `add_new_points_or_box` instead.""" return self.add_new_points_or_box(*args, **kwargs) - @torch.inference_mode() + @torch.no_grad() def add_new_mask( self, inference_state, @@ -589,7 +603,7 @@ def _get_empty_mask_ptr(self, inference_state, frame_idx): ) return current_out["obj_ptr"] - @torch.inference_mode() + @torch.no_grad() def propagate_in_video_preflight(self, inference_state): """Prepare inference_state and consolidate temporary outputs before tracking.""" # Tracking has started and we don't allow adding new objects until session is reset. @@ -659,7 +673,7 @@ def propagate_in_video_preflight(self, inference_state): input_frames_inds.update(mask_inputs_per_frame.keys()) assert all_consolidated_frame_inds == input_frames_inds - @torch.inference_mode() + @torch.no_grad() def propagate_in_video( self, inference_state, @@ -773,7 +787,7 @@ def _add_output_per_object( obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc] obj_output_dict[storage_key][frame_idx] = obj_out - @torch.inference_mode() + @torch.no_grad() def clear_all_prompts_in_frame( self, inference_state, frame_idx, obj_id, need_output=True ): @@ -844,7 +858,7 @@ def clear_all_prompts_in_frame( ) return frame_idx, obj_ids, video_res_masks - @torch.inference_mode() + @torch.no_grad() def reset_state(self, inference_state): """Remove all input points or mask in all frames throughout the video.""" self._reset_tracking_results(inference_state) @@ -1039,7 +1053,7 @@ def _get_maskmem_pos_enc(self, inference_state, current_out): expanded_maskmem_pos_enc = None return expanded_maskmem_pos_enc - @torch.inference_mode() + @torch.no_grad() def remove_object(self, inference_state, obj_id, strict=False, need_output=True): """ Remove an object id from the tracking state. If strict is True, we check whether From 7624ae88da66fd9551418db136b00ce9497cfbbe Mon Sep 17 00:00:00 2001 From: Huy Do Date: Thu, 12 Dec 2024 12:07:29 -0800 Subject: [PATCH 34/40] Enable ciflow/benchmark (#1404) Enable ciflow/benchmark on ao --- .github/pytorch-probot.yml | 2 ++ .github/workflows/dashboard_perf_test.yml | 3 +++ 2 files changed, 5 insertions(+) diff --git a/.github/pytorch-probot.yml b/.github/pytorch-probot.yml index 4cf21c2352..65cca3f10f 100644 --- a/.github/pytorch-probot.yml +++ b/.github/pytorch-probot.yml @@ -1 +1,3 @@ mergebot: True +ciflow_push_tags: +- ciflow/benchmark diff --git a/.github/workflows/dashboard_perf_test.yml b/.github/workflows/dashboard_perf_test.yml index c2933be107..62823e8895 100644 --- a/.github/workflows/dashboard_perf_test.yml +++ b/.github/workflows/dashboard_perf_test.yml @@ -1,6 +1,9 @@ name: A100-perf-nightly on: + push: + tags: + - ciflow/benchmark/* workflow_dispatch: schedule: - cron: 0 7 * * 0-6 From eed437f9eccd1b7fee41e5bc995a1bbbf8a355fb Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 12 Dec 2024 12:42:05 -0800 Subject: [PATCH 35/40] Update api_ref_quantization.rst (#1408) --- docs/source/api_ref_quantization.rst | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/docs/source/api_ref_quantization.rst b/docs/source/api_ref_quantization.rst index 5bc0a0674c..7f2b312e85 100644 --- a/docs/source/api_ref_quantization.rst +++ b/docs/source/api_ref_quantization.rst @@ -9,8 +9,8 @@ torchao.quantization .. autosummary:: :toctree: generated/ :nosignatures: - autoquant - + + autoquant quantize_ int8_dynamic_activation_int4_weight int8_dynamic_activation_int8_weight @@ -21,12 +21,9 @@ torchao.quantization float8_static_activation_float8_weight uintx_weight_only fpx_weight_only - to_linear_activation_quantized - swap_linear_with_smooth_fq_linear smooth_fq_linear_to_inference - choose_qparams_affine choose_qparams_affine_with_min_max choose_qparams_affine_floatx @@ -37,10 +34,8 @@ torchao.quantization choose_qparams_and_quantize_affine_hqq fake_quantize_affine fake_quantize_affine_cachemask - safe_int_mm int_scaled_matmul - MappingType ZeroPointDomain TorchAODType From 31234dbc3210a0833cd922856107dc004d088ce1 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 12 Dec 2024 12:42:17 -0800 Subject: [PATCH 36/40] Update index.rst (#1409) --- docs/source/index.rst | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index befe30570c..c008c80453 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,7 +1,11 @@ Welcome to the torchao Documentation ======================================= -`**torchao** `__ is a library for custom data types & optimizations. Quantize and sparsify weights, gradients, optimizers & activations for inference and training using native PyTorch. Please checkout torchao `README `__ for an overall introduction to the library and recent highlight and updates. The documentation here will focus on 1. API Reference 2. Developer / Researcher Contribution Guide 3. Tutorials. +`torchao `__ is a library for custom data types & optimizations. Quantize and sparsify weights, gradients, optimizers & activations for inference and training using native PyTorch. Please checkout torchao `README `__ for an overall introduction to the library and recent highlight and updates. The documentation here will focus on: + +1. API Reference +2. Developer Contribution Guide +3. Tutorials .. .. grid:: 3 @@ -96,7 +100,6 @@ Welcome to the torchao Documentation :glob: :maxdepth: 1 :caption: Tutorials - :hidden: serialization From ebc43034e665bcda759cf9ef9c2c207057c5eeb1 Mon Sep 17 00:00:00 2001 From: Manuel Candales <42380156+manuelcandales@users.noreply.github.com> Date: Thu, 12 Dec 2024 20:49:51 -0500 Subject: [PATCH 37/40] metal lowbit kernels: executorch ops Differential Revision: D65957345 Pull Request resolved: https://github.com/pytorch/ao/pull/1322 --- .../mps/codegen/gen_metal_shader_lib.py | 4 +- .../kernels/mps/src/MetalShaderLibrary.h | 64 ++++++++ .../kernels/mps/src/OperationUtils.h | 101 +------------ .../kernels/mps/src/OperationUtils.mm | 20 +++ torchao/experimental/kernels/mps/src/common.h | 51 +++++++ torchao/experimental/kernels/mps/src/lowbit.h | 21 +-- .../experimental/kernels/mps/test/Makefile | 4 +- .../kernels/mps/test/test_lowbit.mm | 4 +- torchao/experimental/ops/mps/CMakeLists.txt | 32 +++- ...r.mm => linear_fp_act_xbit_weight_aten.mm} | 43 +++++- .../linear_fp_act_xbit_weight_executorch.mm | 138 ++++++++++++++++++ .../experimental/ops/mps/test/test_lowbit.py | 2 +- .../ops/mps/test/test_quantizer.py | 2 +- torchao/experimental/quant_api.py | 25 ++-- 14 files changed, 364 insertions(+), 147 deletions(-) create mode 100644 torchao/experimental/kernels/mps/src/MetalShaderLibrary.h create mode 100644 torchao/experimental/kernels/mps/src/OperationUtils.mm create mode 100644 torchao/experimental/kernels/mps/src/common.h rename torchao/experimental/ops/mps/{aten/register.mm => linear_fp_act_xbit_weight_aten.mm} (78%) create mode 100644 torchao/experimental/ops/mps/linear_fp_act_xbit_weight_executorch.mm diff --git a/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py b/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py index eea7e42666..7764c0871f 100644 --- a/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py +++ b/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py @@ -37,9 +37,9 @@ */ #ifdef USE_ATEN -using namespace at::native::mps; +using at::native::mps::MetalShaderLibrary; #else -#include +#include #endif static MetalShaderLibrary metal_lowbit_quantized_lib(R"METAL_LOWBIT( diff --git a/torchao/experimental/kernels/mps/src/MetalShaderLibrary.h b/torchao/experimental/kernels/mps/src/MetalShaderLibrary.h new file mode 100644 index 0000000000..3aca35e699 --- /dev/null +++ b/torchao/experimental/kernels/mps/src/MetalShaderLibrary.h @@ -0,0 +1,64 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include + +class MetalShaderLibrary { + public: + MetalShaderLibrary(const std::string& src) : shaderSource(src) { + lib = compileLibraryFromSource(shaderSource); + } + MetalShaderLibrary(const MetalShaderLibrary&) = delete; + MetalShaderLibrary(MetalShaderLibrary&&) = delete; + + id getPipelineStateForFunc( + const std::string& fname) { + id func = loadFunc(fname); + + NSError* error = nil; + id device = get_metal_device(); + auto cpl = [device newComputePipelineStateWithFunction:func error:&error]; + if (cpl == nil) { + throw std::runtime_error( + "Failed to construct pipeline state: " + + std::string(error.description.UTF8String)); + } + return cpl; + + } + + private: + std::string shaderSource; + id lib = nil; + + id loadFunc(const std::string& func_name) const { + id func = [lib + newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]]; + if (func == nil) { + throw std::runtime_error("Can't get function:" + func_name); + } + return func; + } + + id compileLibraryFromSource( + const std::string& source) { + NSError* error = nil; + MTLCompileOptions* options = [MTLCompileOptions new]; + [options setLanguageVersion:MTLLanguageVersion3_1]; + NSString* kernel_source = [NSString stringWithUTF8String:source.c_str()]; + id device = get_metal_device(); + id library = [device newLibraryWithSource:kernel_source + options:options + error:&error]; + if (library == nil) { + throw std::runtime_error( + "Failed to compile: " + std::string(error.description.UTF8String)); + } + return library; + } +}; diff --git a/torchao/experimental/kernels/mps/src/OperationUtils.h b/torchao/experimental/kernels/mps/src/OperationUtils.h index 7cb902f23f..5a41b264af 100644 --- a/torchao/experimental/kernels/mps/src/OperationUtils.h +++ b/torchao/experimental/kernels/mps/src/OperationUtils.h @@ -6,101 +6,12 @@ #pragma once -#include -#include - -static void throw_exception(const std::string& str) { - std::cerr << str << std::endl; - throw std::runtime_error(str); -} - -inline void dispatch_block( - [[maybe_unused]] id queue, - void (^block)()) { - __block std::optional block_exception; - try { - block(); - } catch (...) { - block_exception = std::current_exception(); - } - if (block_exception) { - std::rethrow_exception(*block_exception); - } -} - -inline id getMetalDevice() { - @autoreleasepool { - NSArray* devices = [MTLCopyAllDevices() autorelease]; - if (devices.count == 0) { - throw_exception("Metal is not supported"); - } - return devices[0]; - } -} - -static id MTL_DEVICE = getMetalDevice(); - -static id compileLibraryFromSource( - id device, - const std::string& source) { - NSError* error = nil; - MTLCompileOptions* options = [MTLCompileOptions new]; - [options setLanguageVersion:MTLLanguageVersion3_1]; - NSString* kernel_source = [NSString stringWithUTF8String:source.c_str()]; - id library = [device newLibraryWithSource:kernel_source - options:options - error:&error]; - if (library == nil) { - throw_exception( - "Failed to compile: " + std::string(error.description.UTF8String)); - } - return library; -} - -class MetalShaderLibrary { - public: - MetalShaderLibrary(const std::string& src) : shaderSource(src) { - lib = compileLibraryFromSource(device, shaderSource); - } - MetalShaderLibrary(const MetalShaderLibrary&) = delete; - MetalShaderLibrary(MetalShaderLibrary&&) = delete; - - id getPipelineStateForFunc( - const std::string& fname) { - return get_compute_pipeline_state(load_func(fname)); - } - - private: - std::string shaderSource; - id device = MTL_DEVICE; - id lib = nil; - - id load_func(const std::string& func_name) const { - id func = [lib - newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]]; - if (func == nil) { - throw_exception("Can't get function:" + func_name); - } - return func; - } - - id get_compute_pipeline_state( - id func) const { - NSError* error = nil; - auto cpl = [device newComputePipelineStateWithFunction:func error:&error]; - if (cpl == nil) { - throw_exception( - "Failed to construct pipeline state: " + - std::string(error.description.UTF8String)); - } - return cpl; - } -}; +id getMetalDevice(); class MPSStream { public: MPSStream() { - _commandQueue = [MTL_DEVICE newCommandQueue]; + _commandQueue = [getMetalDevice() newCommandQueue]; } ~MPSStream() { @@ -136,14 +47,6 @@ class MPSStream { id _commandEncoder = nil; }; -inline void finalize_block(MPSStream* mpsStream) { - id encoder = mpsStream->commandEncoder(); - id cmdBuffer = mpsStream->commandBuffer(); - [encoder endEncoding]; - [cmdBuffer commit]; - [cmdBuffer waitUntilCompleted]; -} - inline MPSStream* getCurrentMPSStream() { return new MPSStream(); } diff --git a/torchao/experimental/kernels/mps/src/OperationUtils.mm b/torchao/experimental/kernels/mps/src/OperationUtils.mm new file mode 100644 index 0000000000..795c93225a --- /dev/null +++ b/torchao/experimental/kernels/mps/src/OperationUtils.mm @@ -0,0 +1,20 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include + +id getMetalDevice() { + @autoreleasepool { + NSArray* devices = [MTLCopyAllDevices() autorelease]; + if (devices.count == 0) { + throw std::runtime_error("Metal is not supported"); + } + static id MTL_DEVICE = devices[0]; + return MTL_DEVICE; + } +} diff --git a/torchao/experimental/kernels/mps/src/common.h b/torchao/experimental/kernels/mps/src/common.h new file mode 100644 index 0000000000..0710d37b3a --- /dev/null +++ b/torchao/experimental/kernels/mps/src/common.h @@ -0,0 +1,51 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#ifdef USE_ATEN +#include +using namespace at::native::mps; +#elif defined(USE_EXECUTORCH) +#include +using namespace executorch::backends::mps::delegate; +#else +#include +#endif + +inline void dispatch_block( + MPSStream* mpsStream, + void (^block)()) { +#if defined(USE_ATEN) + dispatch_sync_with_rethrow(mpsStream->queue(), block); +#elif defined(USE_EXECUTORCH) + dispatch_sync(mpsStream->queue(), block); +#else + (void)mpsStream; + block(); +#endif +} + +inline void optionally_wait_for_command_completion(MPSStream* mpsStream) { +#if defined(USE_ATEN) +#elif defined(USE_EXECUTORCH) + ET_CHECK(mpsStream->synchronize(SyncType::COMMIT_AND_WAIT) == executorch::runtime::Error::Ok); +#else + id encoder = mpsStream->commandEncoder(); + id cmdBuffer = mpsStream->commandBuffer(); + [encoder endEncoding]; + [cmdBuffer commit]; + [cmdBuffer waitUntilCompleted]; +#endif +} + +inline id get_metal_device() { +#if defined(USE_ATEN) || defined(USE_EXECUTORCH) + return MPSDevice::getInstance()->device(); +#else + return getMetalDevice(); +#endif +} diff --git a/torchao/experimental/kernels/mps/src/lowbit.h b/torchao/experimental/kernels/mps/src/lowbit.h index d37001350a..ae3951e217 100644 --- a/torchao/experimental/kernels/mps/src/lowbit.h +++ b/torchao/experimental/kernels/mps/src/lowbit.h @@ -9,24 +9,11 @@ #include #include +#include #include -#include +#include // metal_lowbit_quantized_lib #include -#include -#include -#include - -#ifdef USE_ATEN -#include -using namespace at::native::mps; -inline void finalize_block(MPSStream* mpsStream) {} -void (*dispatch_block)(dispatch_queue_t, dispatch_block_t) = - dispatch_sync_with_rethrow; -#else -#include -#endif - namespace torchao::kernels::mps::lowbit { namespace { @@ -103,7 +90,7 @@ inline void linear_lowbit_quant_weights_mps_impl( 0}; MPSStream* mpsStream = getCurrentMPSStream(); - dispatch_block(mpsStream->queue(), ^() { + dispatch_block(mpsStream, ^() { @autoreleasepool { id computeEncoder = mpsStream->commandEncoder(); id cpl = @@ -119,7 +106,7 @@ inline void linear_lowbit_quant_weights_mps_impl( length:sizeof(uint32_t) * sizes.size() atIndex:5]; dispatch_fn(computeEncoder, maxThreadsPerGroup, M, N, K); - finalize_block(mpsStream); + optionally_wait_for_command_completion(mpsStream); } }); } diff --git a/torchao/experimental/kernels/mps/test/Makefile b/torchao/experimental/kernels/mps/test/Makefile index e8213818c5..3c0da54f7c 100644 --- a/torchao/experimental/kernels/mps/test/Makefile +++ b/torchao/experimental/kernels/mps/test/Makefile @@ -1,7 +1,7 @@ all: test_lowbit -test_lowbit: test_lowbit.mm - clang++ -I${TORCHAO_ROOT} -O3 -std=c++17 -Wall -Wextra -o $@ $< -framework Metal -framework Foundation +test_lowbit: test_lowbit.mm ../src/OperationUtils.mm + clang++ -I${TORCHAO_ROOT} -O3 -std=c++17 -Wall -Wextra -o $@ $^ -framework Metal -framework Foundation run: test_lowbit ./test_lowbit diff --git a/torchao/experimental/kernels/mps/test/test_lowbit.mm b/torchao/experimental/kernels/mps/test/test_lowbit.mm index 2d86223034..7fb20d254a 100644 --- a/torchao/experimental/kernels/mps/test/test_lowbit.mm +++ b/torchao/experimental/kernels/mps/test/test_lowbit.mm @@ -31,7 +31,7 @@ id rc = [device newBufferWithLength:length options:MTLResourceStorageModeShared]; if (rc == nil) { - throw_exception( + throw std::runtime_error( "Can't allocate " + std::to_string(length) + " bytes on GPU"); } return rc; @@ -80,7 +80,7 @@ void reference_linear_lowbit_quant_weights_cpu( : M(m), K(k), N(n), qGroupSize(group_size) {} void init() { - allocBuffers(MTL_DEVICE); + allocBuffers(getMetalDevice()); T* a_ptr = reinterpret_cast([buf_A contents]); uint8_t* w_ptr = reinterpret_cast([buf_W contents]); diff --git a/torchao/experimental/ops/mps/CMakeLists.txt b/torchao/experimental/ops/mps/CMakeLists.txt index 044433ef95..1d41f75854 100644 --- a/torchao/experimental/ops/mps/CMakeLists.txt +++ b/torchao/experimental/ops/mps/CMakeLists.txt @@ -26,10 +26,13 @@ endif() find_package(Torch REQUIRED) # Generate metal_shader_lib.h by running gen_metal_shader_lib.py +set(METAL_SHADERS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/metal) +set(GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/codegen/gen_metal_shader_lib.py) set(GENERATED_METAL_SHADER_LIB ${CMAKE_INSTALL_PREFIX}/include/torchao/experimental/kernels/mps/src/metal_shader_lib.h) add_custom_command( OUTPUT ${GENERATED_METAL_SHADER_LIB} - COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/codegen/gen_metal_shader_lib.py ${GENERATED_METAL_SHADER_LIB} + COMMAND python ${GEN_SCRIPT} ${GENERATED_METAL_SHADER_LIB} + DEPENDS ${METAL_SHADERS_DIR}/*.metal ${GEN_SCRIPT} COMMENT "Generating metal_shader_lib.h using gen_metal_shader_lib.py" ) add_custom_target(generated_metal_shader_lib ALL DEPENDS ${GENERATED_METAL_SHADER_LIB}) @@ -41,7 +44,7 @@ message(STATUS "TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}") include_directories(${TORCHAO_INCLUDE_DIRS}) include_directories(${CMAKE_INSTALL_PREFIX}/include) -add_library(torchao_ops_mps_linear_fp_act_xbit_weight_aten SHARED aten/register.mm) +add_library(torchao_ops_mps_linear_fp_act_xbit_weight_aten OBJECT linear_fp_act_xbit_weight_aten.mm) add_dependencies(torchao_ops_mps_linear_fp_act_xbit_weight_aten generated_metal_shader_lib) target_include_directories(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE "${TORCH_INCLUDE_DIRS}") @@ -53,8 +56,25 @@ find_library(METAL_LIB Metal) find_library(FOUNDATION_LIB Foundation) target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE ${METAL_LIB} ${FOUNDATION_LIB}) -install( - TARGETS torchao_ops_mps_linear_fp_act_xbit_weight_aten - EXPORT _targets - DESTINATION lib +add_library(torchao_ops_mps_aten SHARED) +target_link_libraries(torchao_ops_mps_aten PRIVATE + torchao_ops_mps_linear_fp_act_xbit_weight_aten ) +install(TARGETS torchao_ops_mps_aten DESTINATION lib) + +if(TORCHAO_BUILD_EXECUTORCH_OPS) + include_directories(${CMAKE_INSTALL_PREFIX}/../..) + include_directories(${CMAKE_INSTALL_PREFIX}/schema/include) + include_directories(${CMAKE_INSTALL_PREFIX}/../third-party/flatbuffers/include) + add_library(torchao_ops_mps_linear_fp_act_xbit_weight_executorch OBJECT linear_fp_act_xbit_weight_executorch.mm) + add_dependencies(torchao_ops_mps_linear_fp_act_xbit_weight_executorch generated_metal_shader_lib) + target_compile_definitions(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE USE_EXECUTORCH=1) + target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE executorch executorch_core mpsdelegate) + target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE ${METAL_LIB} ${FOUNDATION_LIB}) + + add_library(torchao_ops_mps_executorch STATIC) + target_link_libraries(torchao_ops_mps_executorch PRIVATE + torchao_ops_mps_linear_fp_act_xbit_weight_executorch + ) + install(TARGETS torchao_ops_mps_executorch DESTINATION lib) +endif() diff --git a/torchao/experimental/ops/mps/aten/register.mm b/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_aten.mm similarity index 78% rename from torchao/experimental/ops/mps/aten/register.mm rename to torchao/experimental/ops/mps/linear_fp_act_xbit_weight_aten.mm index 92a3ba89f0..e11e55c5a0 100644 --- a/torchao/experimental/ops/mps/aten/register.mm +++ b/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_aten.mm @@ -70,12 +70,13 @@ void check_linear_mps_args( } template -Tensor linear_mps_kernel( +Tensor linear_mps_kernel_out( const Tensor& A, const Tensor& B, int64_t group_size, const Tensor& S, - const Tensor& Z) { + const Tensor& Z, + Tensor& C) { TORCH_CHECK( A.is_mps(), __func__, ": A is on ", A.device(), " but expected on mps"); TORCH_CHECK( @@ -84,6 +85,8 @@ Tensor linear_mps_kernel( S.is_mps(), __func__, ": S is on ", S.device(), " but expected on mps"); TORCH_CHECK( Z.is_mps(), __func__, ": Z is on ", Z.device(), " but expected on mps"); + TORCH_CHECK( + C.is_mps(), __func__, ": Z is on ", Z.device(), " but expected on mps"); check_linear_mps_args(A, B, group_size, S, Z); @@ -91,8 +94,6 @@ Tensor linear_mps_kernel( auto N = B.size(0); auto K = A.size(1); - auto C = at::empty({M, N}, A.options()); - LowBitQuantWeights::linear( getMTLBufferStorage(A), getMTLBufferStorage(B), @@ -108,6 +109,19 @@ Tensor linear_mps_kernel( return C; } +template +Tensor linear_mps_kernel( + const Tensor& A, + const Tensor& B, + int64_t group_size, + const Tensor& S, + const Tensor& Z) { + auto M = A.size(0); + auto N = B.size(0); + auto C = at::empty({M, N}, A.options()); + return linear_mps_kernel_out(A, B, group_size, S, Z, C); +} + template Tensor linear_mps_kernel_meta( const Tensor& A, @@ -169,6 +183,20 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) { "_linear_fp_act_6bit_weight(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z) -> Tensor"); m.def( "_linear_fp_act_7bit_weight(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z) -> Tensor"); + m.def( + "_linear_fp_act_1bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)"); + m.def( + "_linear_fp_act_2bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)"); + m.def( + "_linear_fp_act_3bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)"); + m.def( + "_linear_fp_act_4bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)"); + m.def( + "_linear_fp_act_5bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)"); + m.def( + "_linear_fp_act_6bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)"); + m.def( + "_linear_fp_act_7bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)"); } TORCH_LIBRARY_IMPL(torchao, CPU, m) { @@ -189,6 +217,13 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) { m.impl("_linear_fp_act_5bit_weight", &linear_mps_kernel<5>); m.impl("_linear_fp_act_6bit_weight", &linear_mps_kernel<6>); m.impl("_linear_fp_act_7bit_weight", &linear_mps_kernel<7>); + m.impl("_linear_fp_act_1bit_weight.out", &linear_mps_kernel_out<1>); + m.impl("_linear_fp_act_2bit_weight.out", &linear_mps_kernel_out<2>); + m.impl("_linear_fp_act_3bit_weight.out", &linear_mps_kernel_out<3>); + m.impl("_linear_fp_act_4bit_weight.out", &linear_mps_kernel_out<4>); + m.impl("_linear_fp_act_5bit_weight.out", &linear_mps_kernel_out<5>); + m.impl("_linear_fp_act_6bit_weight.out", &linear_mps_kernel_out<6>); + m.impl("_linear_fp_act_7bit_weight.out", &linear_mps_kernel_out<7>); } TORCH_LIBRARY_IMPL(torchao, Meta, m) { diff --git a/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_executorch.mm b/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_executorch.mm new file mode 100644 index 0000000000..2892a67245 --- /dev/null +++ b/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_executorch.mm @@ -0,0 +1,138 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include + +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::backends::mps::delegate::getMTLBufferStorage; +using ::executorch::runtime::KernelRuntimeContext; +using ::executorch::runtime::tensor_is_rank; + +namespace { + +std::string scalar_type_to_string(const ScalarType& scalar_type) { + switch (scalar_type) { + case ScalarType::Float: + return "float"; + case ScalarType::Half: + return "half"; + case ScalarType::BFloat16: + return "bfloat"; + default: + ET_CHECK_MSG( + false, "Unsupported type by lowbit quantized linear"); + return "undefined"; + } +} + +template +bool check_linear_mps_args( + const Tensor& A, + const Tensor& B, + int64_t group_size, + const Tensor& S, + const Tensor& Z) { + auto N = B.size(0); + auto K = A.size(1); + + ET_LOG_MSG_AND_RETURN_IF_FALSE( + A.scalar_type() == ScalarType::BFloat16 || + A.scalar_type() == ScalarType::Half || + A.scalar_type() == ScalarType::Float, + "Expect A to be either 32-bit or 16-bit float tensor."); + ET_LOG_MSG_AND_RETURN_IF_FALSE( + tensor_is_rank(A, 2), "Expect A to be 2D tensor."); + + ET_LOG_MSG_AND_RETURN_IF_FALSE( + B.scalar_type() == ScalarType::Byte, "Expect B to be uint8 tensor."); + ET_LOG_MSG_AND_RETURN_IF_FALSE( + B.size(1) == (K / 8) * nbit, "Expect B.size(1) == (K / 8) * nbit"); + + ET_LOG_MSG_AND_RETURN_IF_FALSE(K % 8 == 0, "Expect K to be multiple of 8"); + + ET_LOG_MSG_AND_RETURN_IF_FALSE( + group_size == 32 || group_size == 64 || group_size == 128 || + group_size == 256, + "Expect group_size to be 32, 64, 128 or 256"); + + ET_LOG_MSG_AND_RETURN_IF_FALSE( + S.dim() == 2 && S.size(1) == N, + "Expect S to be 2d tensor with shape [:, N]"); + + ET_LOG_MSG_AND_RETURN_IF_FALSE( + Z.dim() == 2 && Z.size(1) == N, + "Expect Z to be 2d tensor with shape [:, N]"); + + return true; +} + +template +Tensor& linear_mps_kernel_et_ctx_out( + KernelRuntimeContext& ctx, + const Tensor& A, + const Tensor& B, + int64_t group_size, + const Tensor& S, + const Tensor& Z, + Tensor& out) { + ET_KERNEL_CHECK( + ctx, + check_linear_mps_args(A, B, group_size, S, Z), + InvalidArgument, + out); + + auto M = A.size(0); + auto N = B.size(0); + auto K = A.size(1); + + torchao::kernels::mps::lowbit::LowBitQuantWeights::linear( + getMTLBufferStorage(A), + getMTLBufferStorage(B), + group_size, + getMTLBufferStorage(S), + getMTLBufferStorage(Z), + getMTLBufferStorage(out), + M, + K, + N, + scalar_type_to_string(A.scalar_type())); + + return out; +} + +} // namespace + +namespace { +EXECUTORCH_LIBRARY(torchao, "_linear_fp_act_1bit_weight.out", linear_mps_kernel_et_ctx_out<1>); +} + +namespace { +EXECUTORCH_LIBRARY(torchao, "_linear_fp_act_2bit_weight.out", linear_mps_kernel_et_ctx_out<2>); +} + +namespace { +EXECUTORCH_LIBRARY(torchao, "_linear_fp_act_3bit_weight.out", linear_mps_kernel_et_ctx_out<3>); +} + +namespace { +EXECUTORCH_LIBRARY(torchao, "_linear_fp_act_4bit_weight.out", linear_mps_kernel_et_ctx_out<4>); +} + +namespace { +EXECUTORCH_LIBRARY(torchao, "_linear_fp_act_5bit_weight.out", linear_mps_kernel_et_ctx_out<5>); +} + +namespace { +EXECUTORCH_LIBRARY(torchao, "_linear_fp_act_6bit_weight.out", linear_mps_kernel_et_ctx_out<6>); +} + +namespace { +EXECUTORCH_LIBRARY(torchao, "_linear_fp_act_7bit_weight.out", linear_mps_kernel_et_ctx_out<7>); +} diff --git a/torchao/experimental/ops/mps/test/test_lowbit.py b/torchao/experimental/ops/mps/test/test_lowbit.py index f4c460a368..acff5624c8 100644 --- a/torchao/experimental/ops/mps/test/test_lowbit.py +++ b/torchao/experimental/ops/mps/test/test_lowbit.py @@ -11,7 +11,7 @@ from parameterized import parameterized -libname = "libtorchao_ops_mps_linear_fp_act_xbit_weight_aten.dylib" +libname = "libtorchao_ops_mps_aten.dylib" libpath = os.path.abspath( os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname) ) diff --git a/torchao/experimental/ops/mps/test/test_quantizer.py b/torchao/experimental/ops/mps/test/test_quantizer.py index 00c08738c2..5b3331c6a8 100644 --- a/torchao/experimental/ops/mps/test/test_quantizer.py +++ b/torchao/experimental/ops/mps/test/test_quantizer.py @@ -17,7 +17,7 @@ from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer from torchao.experimental.quant_api import _quantize -libname = "libtorchao_ops_mps_linear_fp_act_xbit_weight_aten.dylib" +libname = "libtorchao_ops_mps_aten.dylib" libpath = os.path.abspath( os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname) ) diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py index be72a59aab..0904d1d174 100644 --- a/torchao/experimental/quant_api.py +++ b/torchao/experimental/quant_api.py @@ -469,21 +469,19 @@ def quantize(self, model: nn.Module) -> nn.Module: return model -from torchao.experimental._linear_8bit_act_xbit_weight_layout import Linear8BitActXBitWeightLayout -from torchao.quantization.quant_api import ( - _get_linear_subclass_inserter, - MappingType, - to_affine_quantized_intx, - ZeroPointDomain, -) - - def int8_dynamic_activation_intx_weight( group_size: int = 128, nbit: int = 4, has_weight_zeros: bool = False, target: str = "native", ): + from torchao.experimental._linear_8bit_act_xbit_weight_layout import Linear8BitActXBitWeightLayout + from torchao.quantization.quant_api import ( + _get_linear_subclass_inserter, + MappingType, + to_affine_quantized_intx, + ZeroPointDomain, + ) def apply(weight): assert weight.shape[-1] % group_size == 0 @@ -541,10 +539,11 @@ def quantize_and_pack_weights(self, weights, nbit, group_size): ) weight_scales = torch.transpose_copy(weight_scales, 1, 0) weight_zeros = torch.transpose_copy(weight_zeros, 1, 0) - self.weight_scales = weight_scales - self.weight_zeros = -weight_zeros * weight_scales - - self.packed_weights = self._pack_weights_op(weight_qvals.cpu()).to(device="mps") + weight_zeros = -weight_zeros * weight_scales + self.weight_scales = nn.Parameter(weight_scales, requires_grad=False) + self.weight_zeros = nn.Parameter(weight_zeros, requires_grad=False) + packed_weights = self._pack_weights_op(weight_qvals.cpu()).to(device="mps") + self.packed_weights = nn.Parameter(packed_weights, requires_grad=False) def forward(self, x): assert x.dim() >= 2 From 7d7c14e898eca3fe66138d2a9445755a9270b800 Mon Sep 17 00:00:00 2001 From: Manuel Candales <42380156+manuelcandales@users.noreply.github.com> Date: Fri, 13 Dec 2024 12:48:25 -0500 Subject: [PATCH 38/40] metal lowbit kernels: glob metal files in CMakeLists (#1410) --- torchao/experimental/ops/mps/CMakeLists.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchao/experimental/ops/mps/CMakeLists.txt b/torchao/experimental/ops/mps/CMakeLists.txt index 1d41f75854..820205fa27 100644 --- a/torchao/experimental/ops/mps/CMakeLists.txt +++ b/torchao/experimental/ops/mps/CMakeLists.txt @@ -27,12 +27,13 @@ find_package(Torch REQUIRED) # Generate metal_shader_lib.h by running gen_metal_shader_lib.py set(METAL_SHADERS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/metal) +file(GLOB METAL_FILES ${METAL_SHADERS_DIR}/*.metal) set(GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/codegen/gen_metal_shader_lib.py) set(GENERATED_METAL_SHADER_LIB ${CMAKE_INSTALL_PREFIX}/include/torchao/experimental/kernels/mps/src/metal_shader_lib.h) add_custom_command( OUTPUT ${GENERATED_METAL_SHADER_LIB} COMMAND python ${GEN_SCRIPT} ${GENERATED_METAL_SHADER_LIB} - DEPENDS ${METAL_SHADERS_DIR}/*.metal ${GEN_SCRIPT} + DEPENDS ${METAL_FILES} ${GEN_SCRIPT} COMMENT "Generating metal_shader_lib.h using gen_metal_shader_lib.py" ) add_custom_target(generated_metal_shader_lib ALL DEPENDS ${GENERATED_METAL_SHADER_LIB}) From a5557349fa6f0422ff1966dcd3539a9fe26063d1 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 13 Dec 2024 17:49:21 -0500 Subject: [PATCH 39/40] Clarify how users should fix ruff errors (#1416) Before: ``` Would reformat: test/quantization/test_qat.py Would reformat: torchao/quantization/qat/api.py 2 files would be reformatted, 138 files already formatted ``` After: ``` Would reformat: test/quantization/test_qat.py Would reformat: torchao/quantization/qat/api.py 2 files would be reformatted, 138 files already formatted Ruff check failed, please try again after running 'ruff format'. ``` --- .github/workflows/ruff_linter.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ruff_linter.yml b/.github/workflows/ruff_linter.yml index dec9bdef1a..027279721e 100644 --- a/.github/workflows/ruff_linter.yml +++ b/.github/workflows/ruff_linter.yml @@ -70,7 +70,10 @@ jobs: # please be careful when using this large changes means everyone needs to rebase ruff check --isolated --select F821,F823,W191 ruff check --select F,I - ruff format --check + ruff format --check || { + echo "Ruff check failed, please try again after running 'ruff format'." + exit 1 + } - name: Apply fixes to PR if: github.event_name == 'workflow_dispatch' From 46b8796412eb350d1923091892850582d32737d0 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 13 Dec 2024 17:49:45 -0500 Subject: [PATCH 40/40] Add back mistakenly deleted QAT BC import test (#1417) Summary: The unused imports in this test were mistakenly deleted in https://github.com/pytorch/ao/pull/1359. This commit adds them back. Test Plan: python test/quantization/test_qat.py --- test/quantization/test_qat.py | 47 +++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 3a998635aa..8862d88b54 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -1108,6 +1108,53 @@ def test_qat_prototype_bc(self): Just to make sure we can import all the old prototype paths. We will remove this test in the near future when we actually break BC. """ + from torchao.quantization.prototype.qat import ( # noqa: F401, F811, I001 + disable_4w_fake_quant, + disable_8da4w_fake_quant, + enable_4w_fake_quant, + enable_8da4w_fake_quant, + ComposableQATQuantizer, + Int8DynActInt4WeightQATLinear, + Int4WeightOnlyEmbeddingQATQuantizer, + Int4WeightOnlyQATQuantizer, + Int8DynActInt4WeightQATQuantizer, + ) + from torchao.quantization.prototype.qat._module_swap_api import ( # noqa: F401, F811 + disable_4w_fake_quant_module_swap, + enable_4w_fake_quant_module_swap, + disable_8da4w_fake_quant_module_swap, + enable_8da4w_fake_quant_module_swap, + Int4WeightOnlyQATQuantizerModuleSwap, + Int8DynActInt4WeightQATQuantizerModuleSwap, + ) + from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( # noqa: F401, F811 + AffineFakeQuantizedTensor, + to_affine_fake_quantized, + ) + from torchao.quantization.prototype.qat.api import ( # noqa: F401, F811 + ComposableQATQuantizer, + FakeQuantizeConfig, + ) + from torchao.quantization.prototype.qat.embedding import ( # noqa: F401, F811 + FakeQuantizedEmbedding, + Int4WeightOnlyEmbeddingQATQuantizer, + Int4WeightOnlyEmbedding, + Int4WeightOnlyQATEmbedding, + ) + from torchao.quantization.prototype.qat.fake_quantizer import ( # noqa: F401, F811 + FakeQuantizer, + ) + from torchao.quantization.prototype.qat.linear import ( # noqa: F401, F811 + disable_4w_fake_quant, + disable_8da4w_fake_quant, + enable_4w_fake_quant, + enable_8da4w_fake_quant, + FakeQuantizedLinear, + Int4WeightOnlyQATLinear, + Int4WeightOnlyQATQuantizer, + Int8DynActInt4WeightQATLinear, + Int8DynActInt4WeightQATQuantizer, + ) if __name__ == "__main__":