diff --git a/.github/pytorch-probot.yml b/.github/pytorch-probot.yml index 4cf21c2352..65cca3f10f 100644 --- a/.github/pytorch-probot.yml +++ b/.github/pytorch-probot.yml @@ -1 +1,3 @@ mergebot: True +ciflow_push_tags: +- ciflow/benchmark diff --git a/.github/workflows/dashboard_perf_test.yml b/.github/workflows/dashboard_perf_test.yml new file mode 100644 index 0000000000..62823e8895 --- /dev/null +++ b/.github/workflows/dashboard_perf_test.yml @@ -0,0 +1,49 @@ +name: A100-perf-nightly + +on: + push: + tags: + - ciflow/benchmark/* + workflow_dispatch: + schedule: + - cron: 0 7 * * 0-6 + +jobs: + benchmark: + runs-on: linux.aws.a100 + strategy: + matrix: + torch-spec: + - '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu124' + steps: + - uses: actions/checkout@v3 + + - name: Setup miniconda + uses: pytorch/test-infra/.github/actions/setup-miniconda@main + with: + python-version: "3.9" + + - name: Run benchmark + shell: bash + run: | + set -eux + ${CONDA_RUN} python -m pip install --upgrade pip + ${CONDA_RUN} pip install ${{ matrix.torch-spec }} + ${CONDA_RUN} pip install -r dev-requirements.txt + ${CONDA_RUN} pip install . + + export CHECKPOINT_PATH=checkpoints + export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf + ${CONDA_RUN} python scripts/download.py --repo_id meta-llama/Llama-2-7b-chat-hf --hf_token ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + ${CONDA_RUN} python scripts/convert_hf_checkpoint.py --checkpoint_dir "${CHECKPOINT_PATH}/${MODEL_REPO}" + + mkdir -p ${{ runner.temp }}/benchmark-results + ${CONDA_RUN} python torchao/_models/llama/generate.py --checkpoint_path "${CHECKPOINT_PATH}/${MODEL_REPO}/model.pth" --compile --output_json_path ${{ runner.temp }}/benchmark-results/benchmark-results.json + + - name: Upload the benchmark results to OSS benchmark database for the dashboard + uses: pytorch/test-infra/.github/actions/upload-benchmark-results@main + with: + benchmark-results-dir: ${{ runner.temp }}/benchmark-results + dry-run: false + schema-version: v3 + github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml index d9649b7f7e..0488e6d922 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_test.yml @@ -70,6 +70,7 @@ jobs: torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121' gpu-arch-type: "cuda" gpu-arch-version: "12.1" + - name: CPU 2.3 runs-on: linux.4xlarge torch-spec: 'torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu' diff --git a/.github/workflows/ruff_linter.yml b/.github/workflows/ruff_linter.yml index dec9bdef1a..027279721e 100644 --- a/.github/workflows/ruff_linter.yml +++ b/.github/workflows/ruff_linter.yml @@ -70,7 +70,10 @@ jobs: # please be careful when using this large changes means everyone needs to rebase ruff check --isolated --select F821,F823,W191 ruff check --select F,I - ruff format --check + ruff format --check || { + echo "Ruff check failed, please try again after running 'ruff format'." + exit 1 + } - name: Apply fixes to PR if: github.event_name == 'workflow_dispatch' diff --git a/README.md b/README.md index 1af5a7013c..6ba0e3be4c 100644 --- a/README.md +++ b/README.md @@ -178,7 +178,7 @@ We're also fortunate to be integrated into some of the leading open-source libra 3. Mobius HQQ backend leveraged our int4 kernels to get [195 tok/s on a 4090](https://github.com/mobiusml/hqq#faster-inference) 4. [TorchTune](https://github.com/pytorch/torchtune) for our QLoRA and QAT recipes 5. [torchchat](https://github.com/pytorch/torchchat) for post training quantization -6. [SGLang](https://github.com/sgl-project/sglang/pull/1341) for LLM inference quantization +6. SGLang for LLM serving: [usage](https://github.com/sgl-project/sglang/blob/4f2ee48ed1c66ee0e189daa4120581de324ee814/docs/backend/backend.md?plain=1#L83) and the major [PR](https://github.com/sgl-project/sglang/pull/1341). ## Videos * [Keynote talk at GPU MODE IRL](https://youtu.be/FH5wiwOyPX4?si=VZK22hHz25GRzBG1&t=1009) @@ -205,4 +205,5 @@ If you find the torchao library useful, please cite it in your work as below. license = {BSD-3-Clause}, month = oct, year = {2024} +} ``` diff --git a/benchmarks/float8/profile_linear_float8.py b/benchmarks/float8/profile_linear_float8.py index f4f2813a37..e545ea4665 100644 --- a/benchmarks/float8/profile_linear_float8.py +++ b/benchmarks/float8/profile_linear_float8.py @@ -6,6 +6,7 @@ import copy import io +import functools import os import random from contextlib import nullcontext, redirect_stdout @@ -22,6 +23,11 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch.utils.checkpoint import ( + checkpoint, + create_selective_checkpoint_contexts, + CheckpointPolicy, +) from torchao.float8.config import ( CastConfig, Float8LinearConfig, @@ -254,6 +260,22 @@ def profile_function( return prof +# set up AC for max(abs(tensor)) +# context: https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.create_selective_checkpoint_contexts +ops_to_save = [ + torch.ops.aten.abs.default, + torch.ops.aten.max.default, +] + +def policy_fn(ctx, op, *args, **kwargs): + if op in ops_to_save: + return CheckpointPolicy.MUST_SAVE + else: + return CheckpointPolicy.PREFER_RECOMPUTE + +context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) + + def main( profile_path_prefix: pathlib.Path, compile: bool = True, @@ -265,6 +287,7 @@ def main( dtype_filter: str = "both", add_inductor_metadata_to_trace: bool = True, enable_sync_amax_history: bool = True, + enable_activation_checkpointing: bool = False, ): assert model_type in ("linear", "ln_linear", "norm_ffn_norm", "norm_ffn_norm_small"), "unsupported" assert dtype_filter in ("both", "float8", "bfloat16") @@ -294,6 +317,7 @@ def main( print(f"Compile is set to | {compile}") print(f"model_type is set to | {model_type}") print(f"scaling_repr is set to | {scaling_repr}") + print(f"enable_activation_checkpointing is set to {enable_activation_checkpointing}") device = "cuda" ref_dtype = torch.bfloat16 @@ -338,11 +362,17 @@ def main( convert_to_float8_training(m_float8, config=config) def ref_forw_backward(x): - out = m_ref(x) + if enable_activation_checkpointing: + out = checkpoint(m_ref, x, use_reentrant=False, context_fn=context_fn) + else: + out = m_ref(x) out.sum().backward() def float8_forw(x): - out = m_float8(x) + if enable_activation_checkpointing: + out = checkpoint(m_float8, x, use_reentrant=False, context_fn=context_fn) + else: + out = m_float8(x) return out sync_amax_history = sync_float8_amax_and_scale_history diff --git a/docs/source/api_ref_quantization.rst b/docs/source/api_ref_quantization.rst index 5bc0a0674c..7f2b312e85 100644 --- a/docs/source/api_ref_quantization.rst +++ b/docs/source/api_ref_quantization.rst @@ -9,8 +9,8 @@ torchao.quantization .. autosummary:: :toctree: generated/ :nosignatures: - autoquant - + + autoquant quantize_ int8_dynamic_activation_int4_weight int8_dynamic_activation_int8_weight @@ -21,12 +21,9 @@ torchao.quantization float8_static_activation_float8_weight uintx_weight_only fpx_weight_only - to_linear_activation_quantized - swap_linear_with_smooth_fq_linear smooth_fq_linear_to_inference - choose_qparams_affine choose_qparams_affine_with_min_max choose_qparams_affine_floatx @@ -37,10 +34,8 @@ torchao.quantization choose_qparams_and_quantize_affine_hqq fake_quantize_affine fake_quantize_affine_cachemask - safe_int_mm int_scaled_matmul - MappingType ZeroPointDomain TorchAODType diff --git a/docs/source/index.rst b/docs/source/index.rst index befe30570c..c008c80453 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,7 +1,11 @@ Welcome to the torchao Documentation ======================================= -`**torchao** `__ is a library for custom data types & optimizations. Quantize and sparsify weights, gradients, optimizers & activations for inference and training using native PyTorch. Please checkout torchao `README `__ for an overall introduction to the library and recent highlight and updates. The documentation here will focus on 1. API Reference 2. Developer / Researcher Contribution Guide 3. Tutorials. +`torchao `__ is a library for custom data types & optimizations. Quantize and sparsify weights, gradients, optimizers & activations for inference and training using native PyTorch. Please checkout torchao `README `__ for an overall introduction to the library and recent highlight and updates. The documentation here will focus on: + +1. API Reference +2. Developer Contribution Guide +3. Tutorials .. .. grid:: 3 @@ -96,7 +100,6 @@ Welcome to the torchao Documentation :glob: :maxdepth: 1 :caption: Tutorials - :hidden: serialization diff --git a/examples/sam2_amg_server/README.md b/examples/sam2_amg_server/README.md index 43fc2b2528..c09b012c26 100644 --- a/examples/sam2_amg_server/README.md +++ b/examples/sam2_amg_server/README.md @@ -41,12 +41,32 @@ The 'ao' mode is a copy of the baseline with modifications to make the code more ### 0. Download checkpoints and install requirements ``` -pip install -r requirements.txt +# From the top-level "ao" directory + +# If necessary, create and activate a virtual environment +# Ex: +python -m venv venv && source venv/bin/activate + +# Install requirements for this example +pip install -r examples/sam2_amg_server/requirements.txt + +# If you have an older version of torch in your current environment, uninstall it first +pip uninstall torch + +# Install torch nightly +pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124 + +# Build ao from source for now +python setup.py develop + +# On your mark, get set... +cd examples/sam2_amg_server/ ``` Download `sam2.1_hiera_large.pt` from https://github.com/facebookresearch/sam2?tab=readme-ov-file#download-checkpoints and put it into `~/checkpoints/sam2` ### 1. Create a random subset of 1000 images +Using images with corresponding mask annotations, like from the Segment Anything Video (SA-V) [Dataset](https://github.com/facebookresearch/sam2/tree/main/sav_dataset#download-the-dataset) is suggested, to later compare any drop in accuracy using `--furious` (using `torch.float16`). ``` find sav_val -type f > sav_val_image_paths shuf -n 1000 sav_val_image_paths > sav_val_image_paths_shuf_1000 diff --git a/examples/sam2_amg_server/cli.py b/examples/sam2_amg_server/cli.py index 2fead4b5a4..9cf5bdc8f3 100644 --- a/examples/sam2_amg_server/cli.py +++ b/examples/sam2_amg_server/cli.py @@ -6,6 +6,8 @@ from server import model_type_to_paths from server import MODEL_TYPES_TO_MODEL from server import set_fast +from server import set_aot_fast +from server import load_aot_fast from server import set_furious from torchao._models.sam2.build_sam import build_sam2 from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator @@ -22,17 +24,20 @@ def main_docstring(): """ -def main_headless(checkpoint_path, model_type, input_bytes, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False): +def main_headless(checkpoint_path, model_type, input_bytes, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False, load_fast=""): device = "cuda" sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type) if verbose: print(f"Loading model {sam2_checkpoint} with config {model_cfg}") sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False) mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle") - if fast: - set_fast(mask_generator) if furious: set_furious(mask_generator) + if load_fast: + load_aot_fast(mask_generator, load_fast) + if fast: + set_fast(mask_generator, load_fast) + image_tensor = file_bytes_to_image_tensor(input_bytes) if verbose: print(f"Loaded image of size {tuple(image_tensor.shape)} and generating mask.") @@ -50,7 +55,7 @@ def main_headless(checkpoint_path, model_type, input_bytes, points_per_batch=102 buf.seek(0) return buf.getvalue() -def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False): +def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False, load_fast=""): input_bytes = bytearray(open(input_path, 'rb').read()) output_bytes = main_headless(checkpoint_path, model_type, @@ -59,7 +64,8 @@ def main(checkpoint_path, model_type, input_path, output_path, points_per_batch= output_format=output_format, verbose=verbose, fast=fast, - furious=furious) + furious=furious, + load_fast=load_fast) with open(output_path, "wb") as file: file.write(output_bytes) diff --git a/examples/sam2_amg_server/server.py b/examples/sam2_amg_server/server.py index d779411c93..060a5ad5dd 100644 --- a/examples/sam2_amg_server/server.py +++ b/examples/sam2_amg_server/server.py @@ -332,14 +332,175 @@ def model_type_to_paths(checkpoint_path, model_type): model_cfg = f"configs/sam2.1/{MODEL_TYPES_TO_CONFIG[model_type]}" return sam2_checkpoint, model_cfg -def set_fast(mask_generator): - # TODO: Using CUDA graphs can cause numerical differences? - mask_generator.predictor.model.image_encoder = torch.compile( - mask_generator.predictor.model.image_encoder, - mode="max-autotune", - fullgraph=True, - dynamic=False, + +def aot_compile(model_directory, name, fn, sample_args): + path = Path(model_directory) / Path(f"{name}.pt2") + print(f"Saving at {path=}") + options = { + "max_autotune": True, + "triton.cudagraphs": True, + } + + exported = torch.export.export_for_inference(fn, sample_args) + output_path = torch._inductor.aoti_compile_and_package( + exported, + package_path=str(path), + inductor_configs=options, ) + return output_path + + +def aot_load(path): + return torch._export.aot_load(path, "cuda") + +class FunctionModel(torch.nn.Module): + + def __init__(self, module, fn_name): + super().__init__() + self.module = module + self.fn_name = fn_name + + def forward(self, *args): + return getattr(self.module, self.fn_name)(*args) + + +def set_aot_fast(mask_generator, model_directory): + example_input = torch.empty(1, 3, 1024, 1024) + example_input = example_input.to(mask_generator.predictor._image_dtype) + example_input = (example_input.to(mask_generator.predictor.device),) + aot_compile(model_directory, + "sam2_image_encoder", + mask_generator.predictor.model.image_encoder, + example_input) + + # NOTE: THIS DOESN'T WORK YET! + # example_input_0_0 = torch.empty(1, 32, 256, 256, dtype=torch.float16, device=mask_generator.predictor.device) + # example_input_0_1 = torch.empty(1, 64, 128, 128, dtype=torch.float16, device=mask_generator.predictor.device) + # example_input_1 = torch.empty(1, 256, 64, 64, dtype=torch.float32, device=mask_generator.predictor.device) + # example_input_2 = torch.empty(1024, 1, 2, dtype=torch.float32, device=mask_generator.predictor.device) + # example_input_3 = torch.empty(1024, 1, dtype=torch.int32, device=mask_generator.predictor.device) + # example_input = ([example_input_0_0, example_input_0_1], + # example_input_1, + # example_input_2, + # example_input_3, + # None, + # None, + # True, + # True, + # -1) + # mask_generator.forward = mask_generator.predictor._predict_masks_with_features + # mask_generator(*example_input) + # aot_compile("sam2__predict_masks_with_features", + # mask_generator, + # example_input) + + # example_input_2 = torch.empty(1024, 1, 2, dtype=torch.float32, device=mask_generator.predictor.device) + # example_input_3 = torch.empty(1024, 1, dtype=torch.int32, device=mask_generator.predictor.device) + # aot_compile("sam2_sam_prompt_encoder", + # mask_generator.predictor.model.sam_prompt_encoder, + # ((example_input_2, example_input_3), + # None, + # None)) + + # NOTE: THIS DOESN'T WORK YET! + # example_input_0 = torch.empty(1, 256, 64, 64, dtype=torch.float32, device=mask_generator.predictor.device) + # example_input_1 = torch.empty(1, 256, 64, 64, dtype=torch.float32, device=mask_generator.predictor.device) + # example_input_2 = torch.empty(1024, 2, 256, dtype=torch.float32, device=mask_generator.predictor.device) + # example_input_3 = torch.empty(1024, 256, 64, 64, dtype=torch.float32, device=mask_generator.predictor.device) + + # example_input_4_0 = torch.empty(1, 32, 256, 256, dtype=torch.float16, device=mask_generator.predictor.device) + # example_input_4_1 = torch.empty(1, 64, 128, 128, dtype=torch.float16, device=mask_generator.predictor.device) + + # example_input = (example_input_0, + # example_input_1, + # example_input_2, + # example_input_3, + # True, + # True, + # [example_input_4_0, example_input_4_1]) + # print("Example") + # mask_generator.predictor.model.sam_mask_decoder(*example_input) + # print("Example done") + # aot_compile("sam2_sam_mask_decoder", + # mask_generator.predictor.model.sam_mask_decoder, + # example_input) + + # example_input_0 = torch.empty(1024, 256, 64, 64, dtype=torch.float16, device=mask_generator.predictor.device) + # example_input_1 = torch.empty(1024, 256, 64, 64, dtype=torch.float16, device=mask_generator.predictor.device) + # example_input_2 = torch.empty(1024, 8, 256, dtype=torch.float16, device=mask_generator.predictor.device) + # example_input = (example_input_0, example_input_1, example_input_2) + + # mask_generator.predictor.model.sam_mask_decoder.transformer(*example_input) + # aot_compile("sam2_sam_mask_decoder_transformer", + # mask_generator.predictor.model.sam_mask_decoder.transformer, + # example_input) + + + + +class LoadedModel(torch.nn.Module): + + def __init__(self, aoti_compiled_model): + super().__init__() + self.aoti_compiled_model = aoti_compiled_model + + def forward(self, *args): + return self.aoti_compiled_model(*args) + +class LoadedDecoder(torch.nn.Module): + + def __init__(self, aoti_compiled_model, other): + super().__init__() + self.aoti_compiled_model = aoti_compiled_model + self.other = other + + def forward(self, *args): + return self.aoti_compiled_model(*args) + + def get_dense_pe(self, *args, **kwargs) -> torch.Tensor: + return self.other.get_dense_pe(*args, **kwargs) + +def load_aot_fast(mask_generator, model_directory): + t0 = time.time() + path = Path(model_directory) / Path(f"sam2_image_encoder.pt2") + assert path.exists(), f"Expected {path} to exist." + print(f"Start load from {path}") + pkg = torch._inductor.aoti_load_package(str(path)) + pkg_m = LoadedModel(pkg) + mask_generator.predictor.model.image_encoder = pkg_m + + # NOTE: This doesn't work yet! + # pkg = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "sam2__predict_masks_with_features.pt2")) + # pkg_m = LoadedModel(pkg) + # mask_generator.predictor._predict_masks_with_features = pkg_m.forward + + # pkg = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "sam2_sam_prompt_encoder.pt2")) + # pkg_m = LoadedDecoder(pkg, mask_generator.predictor.model.sam_prompt_encoder) + # mask_generator.predictor.model.sam_prompt_encoder = pkg_m + + # NOTE: This doesn't work yet! + # pkg = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "sam2_sam_mask_decoder.pt2")) + # pkg_m = LoadedModel(pkg) + # pkg_m.conv_s0 = mask_generator.predictor.model.sam_mask_decoder.conv_s0 + # pkg_m.conv_s1 = mask_generator.predictor.model.sam_mask_decoder.conv_s1 + # mask_generator.predictor.model.sam_mask_decoder = pkg_m + + # pkg = torch._inductor.aoti_load_package(os.path.join(os.getcwd(), "sam2_sam_mask_decoder_transformer.pt2")) + # pkg_m = LoadedModel(pkg) + # mask_generator.predictor.model.sam_mask_decoder.transformer = pkg_m + + print(f"End load. Took {time.time() - t0}s") + + +def set_fast(mask_generator, load_fast=""): + if load_fast == "": + # TODO: Using CUDA graphs can cause numerical differences? + mask_generator.predictor.model.image_encoder = torch.compile( + mask_generator.predictor.model.image_encoder, + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) mask_generator.predictor._predict_masks = torch.compile( mask_generator.predictor._predict_masks, @@ -365,12 +526,25 @@ def set_furious(mask_generator): # NOTE: Not baseline feature mask_generator.predictor.model.sam_mask_decoder._src_dtype = torch.float16 +def set_autoquant(mask_generator): + from torchao import autoquant + from torchao.quantization import DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST + # NOTE: Not baseline feature + mask_generator.predictor.model.image_encoder = autoquant(mask_generator.predictor.model.image_encoder, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40) + mask_generator.predictor._transforms_device = mask_generator.predictor.device + torch.set_float32_matmul_precision('high') + # NOTE: this fails when we run + # python server.py ~/checkpoints/sam2 large --port 8000 --host localhost --fast --use_autoquant --unittest + # https://gist.github.com/jerryzh168/d337cb5de0a1dec306069fe48ac8225e + # mask_generator.predictor.model.sam_mask_decoder = autoquant(mask_generator.predictor.model.sam_mask_decoder, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40) + def main(checkpoint_path, model_type, baseline=False, fast=False, furious=False, + use_autoquant=False, unittest=False, benchmark=False, profile=None, @@ -380,7 +554,9 @@ def main(checkpoint_path, port=5000, host="127.0.0.1", dry=False, - batch_size=1): + batch_size=1, + load_fast="", + save_fast=""): if verbose: logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', @@ -399,22 +575,34 @@ def main(checkpoint_path, from torchao._models.sam2.build_sam import build_sam2 from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator from torchao._models.sam2.utils.amg import rle_to_mask - + device = "cuda" sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type) - + logging.info(f"Loading model {sam2_checkpoint} with config {model_cfg}") sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False) - + logging.info(f"Using {points_per_batch} points_per_batch") mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle") - if fast: - assert not baseline, "--fast cannot be combined with baseline. code to be torch.compile(fullgraph=True) compatible." - set_fast(mask_generator) + if load_fast != "": + load_aot_fast(mask_generator, load_fast) if furious: set_furious(mask_generator) + # since autoquant is replicating what furious mode is doing, don't use these two together + elif use_autoquant: + set_autoquant(mask_generator) + + if save_fast != "": + assert load_fast == "", "Can't save compiled models while loading them with --load-fast." + assert not baseline, "--fast cannot be combined with baseline. code to be torch.compile(fullgraph=True) compatible." + print(f"Saving compiled models under directory {save_fast}") + set_aot_fast(mask_generator, save_fast) + + if fast: + assert not baseline, "--fast cannot be combined with baseline. code to be torch.compile(fullgraph=True) compatible." + set_fast(mask_generator, load_fast) with open('dog.jpg', 'rb') as f: image_tensor = file_bytes_to_image_tensor(bytearray(f.read())) @@ -494,7 +682,7 @@ async def upload_rle(image: UploadFile = File(...)): await request_queue.put((image_tensor, response_future)) masks = await response_future return masks_to_rle_dict(masks) - + @app.post("/upload") async def upload_image(image: UploadFile = File(...)): image_tensor = file_bytes_to_image_tensor(bytearray(await image.read())) @@ -512,7 +700,7 @@ async def upload_image(image: UploadFile = File(...)): plt.savefig(buf, format='png') buf.seek(0) return StreamingResponse(buf, media_type="image/png") - + # uvicorn.run(app, host=host, port=port, log_level="info") uvicorn.run(app, host=host, port=port) diff --git a/examples/sam2_amg_server/video_profile.py b/examples/sam2_amg_server/video_profile.py new file mode 100644 index 0000000000..e7874879d9 --- /dev/null +++ b/examples/sam2_amg_server/video_profile.py @@ -0,0 +1,456 @@ +import argparse +import time +import os +from datetime import datetime + +import numpy as np +import torch +from PIL import Image, ImageDraw +from server import MODEL_TYPES_TO_MODEL +from server import model_type_to_paths +from pathlib import Path + +from torch._inductor import config as inductorconfig +inductorconfig.triton.unique_kernel_names = True +inductorconfig.coordinate_descent_tuning = True +inductorconfig.coordinate_descent_check_all_directions = True + +from torch.nn.attention import SDPBackend, sdpa_kernel + +# timer.py +import time +from collections import defaultdict + + +class CodeTimer: + def __init__(self): + self.start_times = {} + self.elapsed_times = defaultdict(list) + self.enabled = False + + def tic(self, section_name): + self.start_times[section_name] = time.time() + + def toc(self, section_name): + if section_name in self.start_times: + elapsed_time = time.time() - self.start_times[section_name] + self.elapsed_times[section_name].append(elapsed_time) + del self.start_times[section_name] + + def get_average_time(self, section_name, warmup: int = 1): + times = self.elapsed_times.get(section_name, []) + times = times[warmup:] + return sum(times) / len(times) if times else 0.0 + + def reset(self): + self.start_times.clear() + self.elapsed_times.clear() + + def print_all_timings(self, warmup: int = 5): + if not self.elapsed_times: + print("No timings recorded.") + return + print("Average timings for all sections:") + for section_name in self.elapsed_times: + average_time = self.get_average_time(section_name, warmup) + print(f"{section_name}, {average_time*1000.0:.6f}") + + +global_timer = CodeTimer() + + +def max_memory_allocated(): + max_memory_allocated_bytes = torch.cuda.max_memory_allocated() + _, total_memory = torch.cuda.mem_get_info() + max_memory_allocated_percentage = int(100 * (max_memory_allocated_bytes / total_memory)) + max_memory_allocated_bytes = max_memory_allocated_bytes >> 20 + print(f"max_memory_allocated_bytes: {max_memory_allocated_bytes}MiB or {max_memory_allocated_percentage}%") + + +def synthesize_video_data( + out_dir: str, + radius: int, + seed: int, + speed: int, + width: int, + height: int, + n_frames: int, + x: int, + y: int, + synthesize_overwrite: bool, +): + circle_color = (255, 0, 0) # red + + os.makedirs(out_dir, exist_ok=True) + + np.random.seed(seed) + # Initial position and velocity + x = np.random.randint(radius, width - radius) + y = np.random.randint(radius, height - radius) + vx = np.random.choice([-1, 1]) * speed + vy = np.random.choice([-1, 1]) * speed + + # TODO: If these frames exist, they will not be deleted in subsequent runs with less frames. + print(f"Generate {n_frames} frames under path {out_dir}") + if not synthesize_overwrite and len(os.listdir(out_dir)) > 0: + raise ValueError(f"Expected folder {out_dir} to be empty unless --synthesize-overwrite is specified.") + # Generate 100 frames + for i in range(n_frames): + # Create a new image with a black background + img = Image.new("RGB", (width, height), (0, 0, 0)) + draw = ImageDraw.Draw(img) + # Draw the circle at its current position + draw.ellipse( + [(x - radius, y - radius), (x + radius, y + radius)], fill=circle_color + ) + # Save the image as a JPEG file + filename = f"{i:03d}.jpg" + img.save(os.path.join(out_dir, filename)) + # Update the circle's position for the next frame + x += vx + y += vy + # Bounce off the edges + if x - radius < 0 or x + radius > width: + vx *= -1 + if y - radius < 0 or y + radius > height: + vy *= -1 + + +def profiler_runner(path, fn, *args, **kwargs): + if path is None: + path = os.path.join( + os.path.expanduser("~/traces"), + f'{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}.json.gz', + ) + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=True, + ) as prof: + result = fn(*args, **kwargs) + prof.export_chrome_trace(path) + print(f"Exported trace to {path}") + return result + + +def main_loop(predictor, inference_state, time_profile=True, accumulate_result=False, count_result=False): + results = [] + num_output_frames = 0 + with torch.autograd.profiler.record_function("main_loop"): + for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video( + inference_state + ): + if accumulate_result: + results.append(out_mask_logits) + if count_result: + num_output_frames += 1 + assert not (accumulate_result and count_result) + if accumulate_result: + return torch.cat(results) + if count_result: + return num_output_frames + + +def run_test( + checkpoint_path: str, + model_type: str, + profile: bool, + video_dir: str, + radius: int, + seed: int, + speed: int, + width: int, + height: int, + n_frames: int, + use_compile: bool, + frame_batch_size: int, + batch_size: int, + synthesize: bool, + synthesize_overwrite: bool, + store_output: str, + compare_output: str, + print_all_timings: bool, + use_baseline: bool, +): + np.random.seed(seed) + start_x = np.random.randint(radius, width - radius) + start_y = np.random.randint(radius, height - radius) + if synthesize: + for i in range(batch_size): + synthesize_video_data( + out_dir=f"{video_dir}_{i}", + radius=radius, + seed=(seed + i), # Make sure every video is different + speed=speed, + width=width, + height=height, + n_frames=n_frames, + x=start_x, + y=start_y, + synthesize_overwrite=synthesize_overwrite, + ) + + # use bfloat16 for the entire notebook + torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type) + + build_sam2_video_predictor = None + if use_baseline: + from sam2.build_sam import build_sam2_video_predictor + else: + from torchao._models.sam2.build_sam import build_sam2_video_predictor + + device = "cuda:0" + # hydra_overrides_extra = ["++model.compile_image_encoder=true"] + predictor = build_sam2_video_predictor( + model_cfg, + sam2_checkpoint, + device=device, + # hydra_overrides_extra=hydra_overrides_extra, + ) + predictor._frame_batch_size = frame_batch_size + + inference_states = [] + for i in range(batch_size): + print("i: ", i) + inference_state = predictor.init_state( + video_path=f"{video_dir}_{i}", async_loading_frames=False + ) + _, out_obj_ids, out_mask_logits = predictor.add_new_points( + inference_state=inference_state, + frame_idx=0, + obj_id=1, + points=np.array([[start_x, start_y]], dtype=np.float32), + labels=np.array([1], dtype=np.int32), + ) + inference_states.append(inference_state) + if batch_size == 1: + inference_state = inference_states[0] + else: + inference_state = predictor.batch_inference_states(inference_states) + + if use_compile: + print("Using torch.compile") + predictor.image_encoder.trunk.forward = torch.compile( + predictor.image_encoder.trunk.forward, + # mode="max-autotune-no-cudagraphs", + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) + + predictor.sam_prompt_encoder.forward = torch.compile( + predictor.sam_prompt_encoder.forward, + # mode="max-autotune-no-cudagraphs", + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) + + predictor.sam_mask_decoder.transformer = torch.compile( + predictor.sam_mask_decoder.transformer, + mode="max-autotune", + # mode="max-autotune-no-cudagraphs", + fullgraph=True, + dynamic=False, + ) + + predictor._forward_sam_heads = torch.compile( + predictor._forward_sam_heads, + mode="max-autotune", + # mode="max-autotune-no-cudagraphs", + fullgraph=True, + dynamic=False, + ) + + predictor.memory_attention = torch.compile( + predictor.memory_attention, + # mode="max-autotune", + # mode="max-autotune-no-cudagraphs", + fullgraph=True, + dynamic=True, + ) + + predictor.memory_encoder.forward = torch.compile( + predictor.memory_encoder.forward, + mode="max-autotune", + # mode="max-autotune-no-cudagraphs", + fullgraph=True, + dynamic=False, + ) + + print("\nWarm-up round and gather outputs.") + global_timer.reset() + result = main_loop(predictor=predictor, inference_state=inference_state, accumulate_result=True) + if store_output: + print(f"Writing results to {store_output}") + torch.save(result, store_output) + if compare_output: + print(f"Comparing to results from {compare_output}") + ref_result = torch.load(compare_output) + torch.testing.assert_close(result, ref_result) + print("Passed comparison!") + if print_all_timings: + global_timer.print_all_timings() + + global_timer.reset() + print("\nProfile round.") + if profile is None: + main_loop(predictor=predictor, inference_state=inference_state) + else: + profiler_runner( + profile, + main_loop, + predictor=predictor, + inference_state=inference_state, + ) + if print_all_timings: + global_timer.print_all_timings() + + print("\nFinal timing and memory usage round.") + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + global_timer.reset() + t0 = time.time() + num_output_frames = main_loop(predictor=predictor, inference_state=inference_state, count_result=True) + t = time.time() - t0 + print(f"main_loop took {t}s for {num_output_frames} frames at {num_output_frames / t}fps") + max_memory_allocated() + if print_all_timings: + global_timer.print_all_timings() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "checkpoint_path", + type=str, + help="Path to folder containing checkpoints from https://github.com/facebookresearch/sam2?tab=readme-ov-file#download-checkpoints", + ) + parser.add_argument( + "model_type", + type=str, + help=f"Choose one of {list(MODEL_TYPES_TO_MODEL.keys())}", + ) + parser.add_argument( + "--video_dir", + type=str, + default="/tmp/segment-anything-2/synth_video", + help="Directory to store the synthetic video", + ) + parser.add_argument( + "--profile", + type=str, + dest="profile", + help="If specified stores profile at given path.", + ) + parser.add_argument( + "--radius", + type=int, + default=50, + help="Radius of the circle for synthetic video", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Seed for initial position and velocity", + ) + parser.add_argument( + "--speed", type=int, default=20, help="Speed of the circle for synthetic video" + ) + parser.add_argument( + "--width", type=int, default=1024, help="Width of the synthetic video" + ) + parser.add_argument( + "--height", type=int, default=1024, help="Height of the synthetic video" + ) + parser.add_argument( + "--n_frames", + type=int, + default=200, + help="Number of frames in the synthetic video", + ) + parser.add_argument( + "--use-compile", + action="store_true", + dest="use_compile", + help="Use torch.compile to speed things up. First iteration will be much slower.", + ) + parser.add_argument( + "--batch-size", + type=int, + default=1, + help="batch_size", + ) + parser.add_argument( + "--frame-batch-size", + type=int, + default=1, + help="frame_batch_size", + ) + parser.add_argument( + "--synthesize", + action="store_true", + dest="synthesize", + help="Synthesize data for the benchmark.", + ) + parser.add_argument( + "--synthesize-overwrite", + action="store_true", + dest="synthesize_overwrite", + help="Overwrite data if it already exists when synthesizing.", + ) + parser.add_argument( + "--store-output", + type=str, + default="", + help="Pass a .pt file to store outputs in.", + ) + parser.add_argument( + "--compare-output", + type=str, + default="", + help="Pass a .pt file to load for comparison.", + ) + parser.add_argument( + "--print-all-timings", + action="store_true", + dest="print_all_timings", + help="Use torch.compile to speed things up. First iteration will be much slower.", + ) + parser.add_argument( + "--use-baseline", + action="store_true", + dest="use_baseline", + help="Use sam2 package instead of torchao._models.sam2", + ) + + args = parser.parse_args() + + run_test( + args.checkpoint_path, + args.model_type, + profile=args.profile, + video_dir=args.video_dir, + radius=args.radius, + seed=args.seed, + speed=args.speed, + width=args.width, + height=args.height, + n_frames=args.n_frames, + use_compile=args.use_compile, + frame_batch_size=args.frame_batch_size, + batch_size=args.batch_size, + synthesize=args.synthesize, + synthesize_overwrite=args.synthesize_overwrite, + store_output=args.store_output, + compare_output=args.compare_output, + print_all_timings=args.print_all_timings, + use_baseline=args.use_baseline, + ) diff --git a/ruff.toml b/ruff.toml index 09d0a1ec97..b20cab030c 100644 --- a/ruff.toml +++ b/ruff.toml @@ -7,13 +7,17 @@ include = [ "torchao/quantization/**/*.py", "torchao/dtypes/**/*.py", "torchao/sparsity/**/*.py", + "torchao/profiler/**/*.py", + "torchao/testing/**/*.py", "torchao/prototype/low_bit_optim/**.py", + "torchao/utils.py", + "torchao/ops.py", + "torchao/_executorch_ops.py", "test/float8/**/*.py", - "test/quantization/test_observer.py", + "test/quantization/**/*.py", "test/dtypes/**/*.py", + "test/sparsity/**/*.py", "test/prototype/low_bit_optim/**.py", - "torchao/utils.py", - ] lint.ignore = ["E731"] diff --git a/scripts/clean_release_notes.py b/scripts/clean_release_notes.py new file mode 100644 index 0000000000..1055f288d0 --- /dev/null +++ b/scripts/clean_release_notes.py @@ -0,0 +1,225 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# ============================================================= +# This script automatically cleans up the raw release notes +# generated by github by doing an initial pass to sort the +# commits. The output still requires manual reviewing. +# +# This script uses PyGithub. If you don't have it yet, please +# install it using: +# +# pip install PyGithub +# +# We expect the following format for the input release notes: +# +# ## What's Changed +# * commit1_title by @userX in https://github.com/pytorch/ao/pull/123 +# * commit2_title by @userY in https://github.com/pytorch/ao/pull/234 +# * commit3_title by @userZ in https://github.com/pytorch/ao/pull/345 +# +# ## New Contributors +# * @userX made their first contribution in https://github.com/pytorch/ao/pull/123 +# * @userY made their first contribution in https://github.com/pytorch/ao/pull/234 +# +# Example output: +# +# ## Highlights +# +# We are excited to announce the X.Y.Z release of torchao! This release adds support for A, B, C, D! +# +# ### Highlight Feature 1 +# +# ### Highlight Feature 2 +# +# ## BC-Breaking +# +# ## Deprecation +# +# ## New Features +# * commit1_title (https://github.com/pytorch/ao/pull/123) +# +# ## Improvement +# * commit2_title (https://github.com/pytorch/ao/pull/234) +# +# ## Bug Fixes +# * commit3_title (https://github.com/pytorch/ao/pull/345) +# +# ## Performance +# +# ## Documentation +# +# ## Developers +# +# ## New Contributors +# * @userX made their first contribution in https://github.com/pytorch/ao/pull/123 +# * @userY made their first contribution in https://github.com/pytorch/ao/pull/234 +# +# ============================================================= + + +import os +import re +import sys +from typing import Dict, List, Optional + +try: + from github import Github +except ImportError as err: + raise ValueError("PyGithub not installed, please run 'pip install PyGithub'") from err + +if len(sys.argv) != 2: + print("Usage: python clean_release_notes.py [raw_release_notes.txt]") + sys.exit(1) + +input_file = sys.argv[1] +output_file = input_file + ".out" +VERBOSE = os.getenv("VERBOSE", "true").lower() == "true" +GITHUB_LABEL_TO_CATEGORY = { + "topic: bc-breaking": "BC Breaking", + "topic: deprecation": "Deprecation", + "topic: new feature": "New Features", + "topic: improvement": "Improvement", + "topic: bug fix": "Bug Fixes", + "topic: performance": "Performance", + "topic: documentation": "Documentation", + "topic: for developer": "Developers", +} + + +def clean_release_notes(): + """ + Main entry point for this script. + + This function pre-processes the raw release notes and produces a template + with all the standard sections and pre-sorts the commits into different + categories based on github labels and commit title keywords. + """ + + # Write the header section + with open(output_file, "w") as out_f: + out_f.write("## Highlights\n\n") + out_f.write("We are excited to announce the X.Y.Z release of torchao! This release adds support for A, B, C, D!\n\n") + out_f.write("### Highlight Feature 1\n\n") + out_f.write("### Highlight Feature 2\n\n") + + # Sort commits into different categories and write them to output file + # For lines after the commits, just copy them to the output file as is + commit_lines = [] + commit_start = False + commits_by_category = { + "BC Breaking": [], + "Deprecations": [], + "New Features": [], + "Improvement": [], + "Bug Fixes": [], + "Performance": [], + "Documentation": [], + "Developers": [], + } + with open(input_file, "r") as in_f, open(output_file, "a") as out_f: + for line in in_f.readlines(): + if line.startswith("## What's Changed"): + commit_start = True + elif commit_start and line.startswith("*"): + commit_lines.append(line) + elif commit_start: + # End of commits, fetch PR labels based on commits collected so far + commit_start = False + pr_number_to_label = fetch_pr_labels(commit_lines) + # Assign each commit to a category + for commit_line in commit_lines: + category = get_commit_category(commit_line, pr_number_to_label) + if category is not None: + commits_by_category[category].append(commit_line) + # Write all commits to the output file by category + for category, commits in commits_by_category.items(): + out_f.write("## %s\n\n" % category) + for commit_line in commits: + out_f.write(format_commit(commit_line)) + out_f.write("\n") + else: + # Not a commit, just copy to the output file + out_f.write(line) + print("Wrote to %s." % output_file) + + +def parse_pr_number(commit_line: str) -> int: + """ + Helper function to parse PR number from commit line. + """ + return int(re.match(".*pytorch/ao/pull/(.*)", commit_line).groups()[0]) + + +def fetch_pr_labels(commit_lines: List[str]) -> Dict[int, str]: + """ + Fetch the relevant github labels starting with "topic: " from all PRs. + If such a label exists for a given PR, store the first one. + """ + pr_number_to_label = {} + all_pr_numbers = [parse_pr_number(line) for line in commit_lines] + smallest_pr_number = min(all_pr_numbers) + repo = Github().get_repo("pytorch/ao") + + # This call fetches 30 PRs at a time in descending order of when the PR was created + pulls = repo.get_pulls(state="closed") + for pr in pulls: + if pr.number < smallest_pr_number: + break + labels = [l.name for l in pr.labels if l.name.startswith("topic: ")] + if len(labels) > 0: + if VERBOSE: + print("Found label for PR %s: '%s'" % (pr.number, labels[0])) + pr_number_to_label[pr.number] = labels[0] + return pr_number_to_label + + +def get_commit_category(commit_line: str, pr_number_to_label: Dict[int, str]) -> Optional[str]: + """ + Assign the commit to a category based on: + (1) The github label if it exists + (2) Keywords in the PR title + + If the commit is not meant to be user facing, remove None. + Otherwise, return "Improvement" by default. + """ + pr_number = parse_pr_number(commit_line) + if pr_number in pr_number_to_label: + label = pr_number_to_label[pr_number] + if label == "topic: not user facing": + return None + if label in GITHUB_LABEL_TO_CATEGORY: + return GITHUB_LABEL_TO_CATEGORY[label] + elif any(x in commit_line.lower() for x in ["revert", "version.txt"]): + return None + elif any(x in commit_line.lower() for x in ["doc", "readme", "tutorial", "typo", "example", "spelling"]): + return "Documentation" + elif any(x in commit_line.lower() for x in ["test", "lint", " ci", "nightl"]): + return "Developers" + elif " fix" in commit_line.lower(): + return "Bug Fixes" + elif " add" in commit_line.lower(): + return "New Features" + else: + return "Improvement" + + +def format_commit(commit_line: str) -> str: + """ + Format the commit line as follows: + Before: * commit title by @userX in https://github.com/pytorch/ao/pull/123 + After: * Commit title (https://github.com/pytorch/ao/pull/123) + """ + # Remove author, put PR link in parentheses + commit_line = re.sub(" by @.* in (.*)", " (\g<1>)", commit_line) + # Capitalize first letter + commit_line = commit_line.lstrip("* ") + commit_line = "* " + commit_line[0].upper() + commit_line[1:] + return commit_line + + +if __name__ == "__main__": + clean_release_notes() diff --git a/scripts/prepare.sh b/scripts/prepare.sh index db426e3b11..9cbc8295ee 100644 --- a/scripts/prepare.sh +++ b/scripts/prepare.sh @@ -2,7 +2,11 @@ python scripts/download.py --repo_id meta-llama/Llama-2-7b-chat-hf python scripts/download.py --repo_id meta-llama/Meta-Llama-3-8B python scripts/download.py --repo_id meta-llama/Meta-Llama-3.1-8B python scripts/download.py --repo_id meta-llama/Llama-3.2-3B +python scripts/download.py --repo_id nm-testing/SparseLlama-3-8B-pruned_50.2of4 python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-2-7b-chat-hf python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3-8B python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3.1-8B python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-3.2-3B +# neuralmagic doesn't come with tokenizer, so we need to copy it over +mkdir -p checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/original && cp checkpoints/meta-llama/Meta-Llama-3-8B/original/tokenizer.model checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/original/tokenizer.model +python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4 diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index e049500e3b..43d57b7d12 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -8,7 +8,7 @@ run_tests, ) -from torchao.dtypes import SemiSparseLayout +from torchao.dtypes import Int4CPULayout, SemiSparseLayout from torchao.quantization import ( float8_weight_only, int4_weight_only, @@ -17,12 +17,14 @@ int8_weight_only, ) from torchao.quantization.quant_primitives import MappingType -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + TORCH_VERSION_AT_LEAST_2_6, + is_sm_at_least_89, +) -def get_quantization_functions(do_sparse: bool, do_int4: bool): +def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"): base_functions = [ int8_weight_only(), int8_dynamic_activation_int4_weight(), @@ -30,14 +32,19 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool): int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC), ] if do_int4: - base_functions.append(int4_weight_only(group_size=32)) + if device == "cpu" and TORCH_VERSION_AT_LEAST_2_6: + base_functions.append( + int4_weight_only(group_size=32, layout=Int4CPULayout()) + ) + else: + base_functions.append(int4_weight_only(group_size=32)) if do_sparse: base_functions.append( int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()) ) - if is_cuda_8_9: + if is_sm_at_least_89(): base_functions.append(float8_weight_only()) return base_functions @@ -152,30 +159,28 @@ class TestAffineQuantizedBasic(TestCase): COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) COMMON_DTYPES = [torch.bfloat16] - @common_utils.parametrize("apply_quant", get_quantization_functions(False, True)) @common_utils.parametrize("device", COMMON_DEVICES) @common_utils.parametrize("dtype", COMMON_DTYPES) - def test_flatten_unflatten(self, apply_quant, device, dtype): - if device == "cpu": - self.skipTest(f"Temporarily skipping for {device}") - - linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) - ql = apply_quant(linear) - lp_tensor = ql.weight - tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__() - tensor_data_dict = { - name: getattr(lp_tensor, name) for name in tensor_data_name_dict - } - outer_size = lp_tensor.size() - outer_stride = lp_tensor.stride() - reconstructed = type(lp_tensor).__tensor_unflatten__( - tensor_data_dict, tensor_attributes, outer_size, outer_stride - ) - example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),) - ref = ql(*example_inputs) - ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False) - reconstruct_res = ql(*example_inputs) - self.assertEqual(reconstruct_res, ref) + def test_flatten_unflatten(self, device, dtype): + apply_quant_list = get_quantization_functions(False, True, device) + for apply_quant in apply_quant_list: + linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) + ql = apply_quant(linear) + lp_tensor = ql.weight + tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__() + tensor_data_dict = { + name: getattr(lp_tensor, name) for name in tensor_data_name_dict + } + outer_size = lp_tensor.size() + outer_stride = lp_tensor.stride() + reconstructed = type(lp_tensor).__tensor_unflatten__( + tensor_data_dict, tensor_attributes, outer_size, outer_stride + ) + example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),) + ref = ql(*example_inputs) + ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False) + reconstruct_res = ql(*example_inputs) + self.assertEqual(reconstruct_res, ref) common_utils.instantiate_parametrized_tests(TestAffineQuantized) diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 74c130dc5e..4d8312b427 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -37,13 +37,14 @@ MappingType, choose_qparams_affine, ) +from torchao.utils import ( + is_sm_at_least_89, + is_sm_at_least_90, +) random.seed(0) torch.manual_seed(0) -is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) -is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) - class ToyLinearModel(torch.nn.Module): def __init__(self, in_features, out_features): @@ -59,12 +60,14 @@ def forward(self, x): class TestAffineQuantizedFloat8Compile(InductorTestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"]) @common_utils.parametrize("compile", [True, False]) @common_utils.parametrize( - "granularity", [PerTensor(), PerRow()] if is_H100 else [PerTensor()] + "granularity", [PerTensor(), PerRow()] if is_sm_at_least_90() else [PerTensor()] ) # Inputs are (M,..), K, N @common_utils.parametrize( @@ -134,12 +137,16 @@ def test_fp8_linear_variants( compute_error(output_original, output_quantized) > 20 ), f"Quantization error is too high got a SQNR of {error}" - @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) def test_invalid_granularity(self): with pytest.raises(ValueError, match="Invalid granularity specification"): float8_dynamic_activation_float8_weight(granularity="invalid") - @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) def test_mismatched_granularity(self): with pytest.raises( ValueError, @@ -147,7 +154,9 @@ def test_mismatched_granularity(self): ): float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow())) - @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) def test_unsupported_granularity(self): class UnsupportedGranularity: pass @@ -158,7 +167,9 @@ class UnsupportedGranularity: ) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) def test_per_row_with_float32(self): with pytest.raises( AssertionError, @@ -170,7 +181,9 @@ def test_per_row_with_float32(self): ) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"]) def test_serialization(self, mode: str): # Create and quantize the model @@ -240,7 +253,9 @@ def test_serialization(self, mode: str): ), f"Scales do not match for {layer_name}" @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9") + @unittest.skipIf( + not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" + ) def test_fp8_weight_dimension_warning(self): # Create model with incompatible dimensions (not multiples of 16) model = ToyLinearModel(10, 25).cuda() # 10x25 and 25x10 weights diff --git a/test/dtypes/test_affine_quantized_tensor_parallel.py b/test/dtypes/test_affine_quantized_tensor_parallel.py index 82d3d2501d..da20b930d3 100644 --- a/test/dtypes/test_affine_quantized_tensor_parallel.py +++ b/test/dtypes/test_affine_quantized_tensor_parallel.py @@ -181,6 +181,9 @@ class TestFloat8dqRowAffineQuantizedTensorParallel( def test_tp(self, dtype): return self._test_tp(dtype) + common_utils.instantiate_parametrized_tests( + TestFloat8woAffineQuantizedTensorParallel + ) common_utils.instantiate_parametrized_tests( TestFloat8dqTensorAffineQuantizedTensorParallel ) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index d00b96d3bb..58df3a343c 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -14,7 +14,11 @@ import torch import torch.nn as nn -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + is_sm_at_least_89, + is_sm_at_least_90, +) if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -26,6 +30,8 @@ Float8LinearRecipeName, ScalingGranularity, ScalingType, + e4m3_dtype, + e5m2_dtype, recipe_name_to_linear_config, ) from torchao.float8.float8_linear import Float8Linear @@ -49,8 +55,6 @@ from torchao.float8.float8_utils import ( FP8_TYPES, compute_error, - e4m3_dtype, - e5m2_dtype, fp8_tensor_statistics, tensor_to_scale, ) @@ -60,10 +64,6 @@ torch.manual_seed(0) -is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) -is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) - - def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool: assert torch.all(a._scale == b._scale).item(), "scales are not identical" assert torch.all(a._data == b._data).item(), "data is not identical" @@ -141,6 +141,25 @@ def test_copy_(self): fp8_b.copy_(fp8_a) torch.testing.assert_close(fp8_a._data, fp8_b._data) + def test_transpose(self): + a = torch.rand((16, 16), dtype=torch.bfloat16) + for axiswise_dim in (None, 0, -1): + scale_a = tensor_to_scale(a, e4m3_dtype) + fp8_a = hp_tensor_and_scale_to_float8( + a, scale_a, e4m3_dtype, axiswise_dim=axiswise_dim + ) + fp8_b = hp_tensor_and_scale_to_float8( + a, scale_a, e4m3_dtype, axiswise_dim=axiswise_dim + ) + + fp8_a_transposed = fp8_a.transpose(0, 1) + fp8_b_t = fp8_b.t() + + torch.testing.assert_close( + (fp8_a_transposed._data, fp8_a_transposed._scale), + (fp8_b_t._data, fp8_b_t._scale), + ) + @pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)]) @pytest.mark.parametrize("axiswise_dim", [0, -1]) def test_axiswise_dynamic_cast(self, shape, axiswise_dim): @@ -219,7 +238,7 @@ def test_axiswise_reshape(self): ], ) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - @unittest.skipIf(not is_cuda_9_0, "Requires CUDA capability >= 9.0") + @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0") def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity): a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda") b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda") @@ -333,7 +352,9 @@ def _test_linear_impl( # verify initialization flags got updated assert m_fp8.is_amax_initialized, "Amax was not properly initialized" - @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True]) + @pytest.mark.parametrize( + "emulate", [True, False] if is_sm_at_least_89() else [True] + ) @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) @pytest.mark.parametrize( "scaling_type_input", @@ -415,7 +436,9 @@ def test_linear_from_recipe( config, ) - @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True]) + @pytest.mark.parametrize( + "emulate", [True, False] if is_sm_at_least_89() else [True] + ) @pytest.mark.parametrize( "linear_dtype", [torch.float16, torch.bfloat16, torch.float32] ) @@ -462,7 +485,9 @@ def test_autocast_outputs( @pytest.mark.parametrize( "linear_dtype", [torch.float16, torch.bfloat16, torch.float32] ) - @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True]) + @pytest.mark.parametrize( + "emulate", [True, False] if is_sm_at_least_89() else [True] + ) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool): m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype) @@ -521,9 +546,9 @@ def test_repr(self): config=config, ) s = m.__repr__() - assert "i:dyn_ten,w:del_ten,go:dyn_ten" in s + assert "i:dyn_ten_e4m3,w:del_ten_e4m3,go:dyn_ten_e5m2" in s - @unittest.skipIf(not is_cuda_8_9, "CUDA 8.9 not available") + @unittest.skipIf(not is_sm_at_least_89(), "CUDA 8.9 not available") def test_inference_mode(self): x = torch.randn(32, 32, device="cuda") m = nn.Sequential(nn.Linear(32, 32)).cuda() @@ -531,10 +556,25 @@ def test_inference_mode(self): with torch.inference_mode(mode=True): m(x) + @unittest.skipIf(not is_sm_at_least_89(), "CUDA arch 8.9 not available") + def test_quantize(self): + x = torch.randn(32, 32, device="cuda") + m = nn.Sequential(nn.Linear(32, 32)).cuda() + m = convert_to_float8_training(m) + assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear" + from torchao.quantization.quant_api import float8_weight_only, quantize_ + + quantize_(m, float8_weight_only()) + assert ( + m[0].weight.tensor_impl.float8_data.dtype == torch.float8_e4m3fn + ), "Post quantization dtype should be torch.float8_e4m3fn" + with torch.no_grad(): + m(x) + class TestScaledMM: @unittest.skipIf( - not is_cuda_8_9, + not is_sm_at_least_89(), "CUDA not available", ) @pytest.mark.parametrize( @@ -576,10 +616,10 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum): if base_dtype in {torch.bfloat16, torch.float16}: atol, rtol = 7e-2, 7e-2 else: - atol, rtol = 2e-3, 2e-3 + atol, rtol = 3e-3, 3e-3 torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) - @unittest.skipIf(not is_cuda_8_9, "CUDA not available") + @unittest.skipIf(not is_sm_at_least_89(), "CUDA not available") def test_different_configs_error(self): x_fp32 = torch.randn(16, 16, device="cuda") x_scale = torch.tensor(1.0, device="cuda") @@ -615,7 +655,7 @@ def test_different_configs_error(self): a @ b @unittest.skipIf( - not is_cuda_8_9, + not is_sm_at_least_89(), "CUDA not available", ) @pytest.mark.parametrize( diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index ced5db7ff3..9a9e555cb2 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -11,7 +11,11 @@ import pytest -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + is_sm_at_least_89, + is_sm_at_least_90, +) if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -26,6 +30,7 @@ Float8LinearConfig, Float8LinearRecipeName, ScalingType, + e4m3_dtype, recipe_name_to_linear_config, ) from torchao.float8.float8_linear import Float8Linear @@ -43,13 +48,8 @@ LinearMMConfig, ScaledMMConfig, ) -from torchao.float8.float8_utils import e4m3_dtype from torchao.testing.float8.test_utils import get_test_float8_linear_config -# TODO(future PR): standardize IS_H100 with the rest of the codebase -is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) -is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) - def _test_compile_base( backend: str, @@ -99,7 +99,7 @@ def _test_compile_base( "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], ) -@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True]) +@pytest.mark.parametrize("emulate", [False, True] if is_sm_at_least_89() else [True]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_eager_only( @@ -126,7 +126,7 @@ def test_eager_only( @pytest.mark.parametrize("fullgraph", [True]) -@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True]) +@pytest.mark.parametrize("emulate", [False, True] if is_sm_at_least_89() else [True]) @pytest.mark.parametrize( "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] ) @@ -177,7 +177,7 @@ def test_aot_eager( [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], ) @unittest.skipIf( - not torch.cuda.is_available() or not is_cuda_8_9, + not torch.cuda.is_available() or not is_sm_at_least_89(), "CUDA with float8 support not available", ) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) @@ -215,7 +215,9 @@ def test_inductor_from_config_params( Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP, ], ) -@unittest.skipIf(not is_H100, "CUDA with capability 9.0 or greater not available") +@unittest.skipIf( + not is_sm_at_least_90(), "CUDA with capability 9.0 or greater not available" +) def test_inductor_from_recipe(recipe_name): torch._dynamo.reset() config = recipe_name_to_linear_config(recipe_name) @@ -253,7 +255,7 @@ def forward(self, x): # TODO(future): figure out why the test below fails on CUDA capability 8.9 @unittest.skipIf( - not torch.cuda.is_available() or not is_H100, + not torch.cuda.is_available() or not is_sm_at_least_90(), "CUDA with capability 9.0 or greater not available", ) def test_float8_with_graph_break_in_the_middle(self): @@ -269,7 +271,7 @@ def test_float8_with_graph_break_in_the_middle(self): torch.testing.assert_close(y_eager, y_compiled) @unittest.skipIf( - not torch.cuda.is_available() or not is_cuda_8_9, + not torch.cuda.is_available() or not is_sm_at_least_89(), "CUDA with float8 support not available", ) def test_float8_graph_input(self): @@ -293,7 +295,7 @@ def to_float(x): torch.testing.assert_close(y2_eager, y2_compiled) @unittest.skipIf( - not torch.cuda.is_available() or not is_cuda_8_9, + not torch.cuda.is_available() or not is_sm_at_least_89(), "CUDA with float8 support not available", ) def test_float8_graph_output(self): @@ -323,7 +325,7 @@ def test_float8_graph_output(self): @unittest.skipIf( - not torch.cuda.is_available() or not is_cuda_8_9, + not torch.cuda.is_available() or not is_sm_at_least_89(), "CUDA with float8 support not available", ) def test_sync_amax_func(): @@ -364,7 +366,7 @@ def __exit__(self, *args): @unittest.skipIf( - not torch.cuda.is_available() or not is_cuda_8_9, + not torch.cuda.is_available() or not is_sm_at_least_89(), "CUDA with float8 support not available", ) def test_sync_amax_func_cuda_graph_success(): @@ -396,7 +398,7 @@ def test_sync_amax_func_cuda_graph_success(): @unittest.skipIf( - not is_cuda_8_9, + not is_sm_at_least_89(), "CUDA not available", ) @pytest.mark.parametrize( diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py index 5985a3f5b5..41b21e4406 100644 --- a/test/float8/test_dtensor.py +++ b/test/float8/test_dtensor.py @@ -31,9 +31,9 @@ from tqdm import tqdm from torchao.float8 import Float8LinearConfig -from torchao.float8.config import CastConfig, ScalingType +from torchao.float8.config import CastConfig, ScalingType, e4m3_dtype from torchao.float8.float8_linear_utils import convert_to_float8_training -from torchao.float8.float8_scaling_utils import NoopFwToFloat8E5M2BwDynamic +from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic from torchao.float8.float8_tensor import ( Float8Tensor, GemmInputRole, @@ -45,7 +45,7 @@ Float8RowwiseParallel, PrepareFloat8ModuleInput, ) -from torchao.float8.float8_utils import e4m3_dtype, tensor_to_scale +from torchao.float8.float8_utils import tensor_to_scale from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor from torchao.testing.float8.dtensor_utils import ToyModel @@ -173,7 +173,7 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16): ) out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8) - out = NoopFwToFloat8E5M2BwDynamic.apply(out, LinearMMConfig()) + out = NoopFwToFloat8BwDynamic.apply(out, LinearMMConfig(), fp8_dtype) assert isinstance(out, DTensor), f"Expected DTensor, got {type(out)}" loss = torch.sum(torch.abs(out - dist_target)) loss.backward() diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py index c3e31816ad..fbe5c9b508 100644 --- a/test/float8/test_fsdp2/test_fsdp2.py +++ b/test/float8/test_fsdp2/test_fsdp2.py @@ -6,7 +6,7 @@ import pytest -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_89 if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -40,8 +40,7 @@ from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor from torchao.testing.float8.fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp -is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) -if not is_cuda_8_9: +if not is_sm_at_least_89(): pytest.skip("Unsupported CUDA device capability version", allow_module_level=True) diff --git a/test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py b/test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py index d5c0d7b853..d2e9a51c7f 100644 --- a/test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py +++ b/test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py @@ -3,7 +3,7 @@ import pytest -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_89 if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -30,8 +30,7 @@ from torchao.float8.float8_tensor import GemmInputRole from torchao.testing.float8.fsdp2_utils import check_parity_fp8_comm_only -is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) -if not is_cuda_8_9: +if not is_sm_at_least_89(): pytest.skip("Unsupported CUDA device capability version", allow_module_level=True) diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py index e9028c8712..311964d831 100644 --- a/test/float8/test_numerics_integration.py +++ b/test/float8/test_numerics_integration.py @@ -11,7 +11,11 @@ import pytest -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + is_sm_at_least_89, + is_sm_at_least_90, +) if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -34,9 +38,6 @@ from torchao.float8.float8_utils import IS_ROCM, compute_error from torchao.testing.float8.test_utils import get_test_float8_linear_config -is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) -is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) - torch.manual_seed(0) @@ -176,7 +177,9 @@ def _test_impl(self, config: Float8LinearConfig) -> None: "scaling_type_grad_output", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], ) - @pytest.mark.skipif(not is_cuda_8_9, reason="requires SM89 compatible machine") + @pytest.mark.skipif( + not is_sm_at_least_89(), reason="requires SM89 compatible machine" + ) @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack") def test_encoder_fw_bw_from_config_params( self, @@ -199,7 +202,9 @@ def test_encoder_fw_bw_from_config_params( Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP, ], ) - @pytest.mark.skipif(not is_cuda_9_0, reason="requires SM90 compatible machine") + @pytest.mark.skipif( + not is_sm_at_least_90(), reason="requires SM90 compatible machine" + ) @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack") def test_encoder_fw_bw_from_recipe( self, diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index ac2403d6dc..6aae8b2e31 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -19,7 +19,7 @@ from torchao.quantization.dynamic_quant import ( DynamicallyPerAxisQuantizedLinear, ) -from torchao.dtypes import TensorCoreTiledLayout +from torchao.dtypes import TensorCoreTiledLayout, Int4CPULayout from torchao.quantization.quant_api import ( int4_weight_only, int8_weight_only, @@ -91,8 +91,10 @@ TORCH_VERSION_AT_LEAST_2_6, unwrap_tensor_subclass, is_fbcode, - benchmark_model + benchmark_model, + is_sm_at_least_90, ) +from torchao.dtypes.utils import is_device logger = logging.getLogger("INFO") @@ -104,7 +106,6 @@ COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy() -is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) def _int8wo_api(mod): if TORCH_VERSION_AT_LEAST_2_4: @@ -133,7 +134,10 @@ def _int8da_int8w_api(mod): change_linear_weights_to_int8_dqtensors(mod) def _int4wo_api(mod): - if TORCH_VERSION_AT_LEAST_2_4: + if is_device(next(mod.parameters()).device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + quantize_(mod, int4_weight_only(layout=Int4CPULayout()), set_inductor_config=False) + unwrap_tensor_subclass(mod) + elif TORCH_VERSION_AT_LEAST_2_4: quantize_(mod, int4_weight_only(), set_inductor_config=False) if not TORCH_VERSION_AT_LEAST_2_5: unwrap_tensor_subclass(mod) @@ -775,7 +779,7 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") - @unittest.skipIf(not is_H100, "Need H100 to run") + @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") def test_aq_float8_weight_only_quant_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype @@ -795,7 +799,7 @@ def test_autoquantizable_flatten_unflatten(self): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") - @unittest.skipIf(not is_H100, "Need H100 to run") + @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") def test_aq_float8_dynamic_quant_rowwise_scaling_subclass(self, device, dtype): if dtype != torch.bfloat16: with self.assertRaisesRegex(AssertionError, "PerRow quantization only works for bfloat16 precision"): @@ -809,7 +813,7 @@ def test_aq_float8_dynamic_quant_rowwise_scaling_subclass(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") - @unittest.skipIf(not is_H100, "Need H100 to run") + @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run") def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype @@ -881,7 +885,6 @@ def _test_lin_weight_subclass_api_impl( @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_4, "skip because there is some bug in inductor codegen") def test_int8_dynamic_quant_subclass_api(self, device, dtype): self._test_lin_weight_subclass_api_impl( _int8da_int8w_api, device, 35, test_dtype=dtype @@ -935,10 +938,16 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype): self.skipTest(f"Temporarily skipping for {device}") if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") + layout_list = [] + if device == 'cpu' and TORCH_VERSION_AT_LEAST_2_6: + layout_list.append(Int4CPULayout()) + else: + for inner_k_tiles in [4, 2]: + layout_list.append(TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles)) for test_shape in ([(256, 256, 16)] + ([(256, 256, 8)] if device=='cuda' else [])): for groupsize in [64, 32]: - for inner_k_tiles in [4, 2]: - kwargs = {"groupsize": groupsize, "layout": TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles)} + for layout in layout_list: + kwargs = {"groupsize": groupsize, "layout": layout} def api(mod): kwargs_copy = kwargs.copy() @@ -1514,6 +1523,23 @@ def forward(self, x): assert not isinstance(model.lin1.weight.weight, AutoQuantizableLinearWeight) model(x_in) + @parameterized.expand(list(itertools.product(["cuda"], COMMON_DTYPES))) + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_autoquant_min_sqnr(self, device, dtype): + m, k, n = 128, 128, 128 + example_input = torch.randn(m, k, device=device, dtype=dtype) + model = torch.nn.Sequential( + torch.nn.ReLU(), + torch.nn.Linear(k,n), + torch.nn.ReLU(), + ).to(device).to(dtype) + out = model(example_input) + torchao.autoquant(model, min_sqnr=60) + out2 = model(example_input) + sqnr = SQNR(out, out2) + # without setting min_sqnr to 60, we get around 45-50 final sqnr + # setting min_sqnr for individual linear to be 60 allows us to achieve >= 50 final sqnr + self.assertTrue(sqnr >= 50, f"sqnr: {sqnr}") diff --git a/test/kernel/test_autotuner.py b/test/kernel/test_autotuner.py index 4ed0974172..3e8c9b0a04 100644 --- a/test/kernel/test_autotuner.py +++ b/test/kernel/test_autotuner.py @@ -13,10 +13,10 @@ import pytest import torch from parameterized import parameterized +from torchao.utils import is_sm_at_least_90 logging.basicConfig(level=logging.INFO) -is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) class TestQuantFlow(unittest.TestCase): @@ -56,7 +56,7 @@ def test_int_mm(self, device, dtype): ("cuda", torch.float16), ] ) - @unittest.skipIf(not is_H100, "Needs H100") + @unittest.skipIf(not is_sm_at_least_90(), "Needs H100") def test_int_mm_float8(self, device, dtype): from torchao.kernel import intmm diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index bc9b02deb5..4cac940313 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -20,11 +20,8 @@ ) from torchao.quantization.utils import compute_error -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_89 -# trying to outsmart flake8 -__has_cuda = torch.cuda.is_available() -IS_CUDA_GE_89 = __has_cuda and torch.cuda.get_device_capability() >= (8, 9) torch.manual_seed(2) @@ -102,7 +99,7 @@ def test_linear_compile(elem_dtype, bias): Verify that compile does not change numerics of MX linear fw + bw """ if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): - if not IS_CUDA_GE_89: + if not is_sm_at_least_89(): pytest.skip("CUDA capability >= 8.9 required for float8 in triton") input_shape = (2, 4) grad_shape = (2, 6) @@ -173,7 +170,7 @@ def test_inference_compile_simple(elem_dtype): Smoke test for inference compile """ if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): - if not IS_CUDA_GE_89: + if not is_sm_at_least_89(): pytest.skip("CUDA capability >= 8.9 required for float8 in triton") m = nn.Sequential(nn.Linear(4, 6, bias=False, dtype=torch.bfloat16)) m = m.cuda() diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 964a575411..522785ae6f 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -24,11 +24,8 @@ ) from torchao.quantization.utils import compute_error -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_89 -# trying to outsmart flake8 -__has_cuda = torch.cuda.is_available() -IS_CUDA_GE_89 = __has_cuda and torch.cuda.get_device_capability() >= (8, 9) torch.manual_seed(2) @@ -225,7 +222,7 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros): Verifies that compile does not change numerics of MX casts """ if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): - if not IS_CUDA_GE_89: + if not is_sm_at_least_89(): # separate ifs because flake8 is outsmarting me pytest.skip("CUDA capability >= 8.9 required for float8 in triton") diff --git a/test/profiler/test_device_spec.py b/test/prototype/test_device_spec.py similarity index 97% rename from test/profiler/test_device_spec.py rename to test/prototype/test_device_spec.py index 1ede428fe0..dd159f5336 100644 --- a/test/profiler/test_device_spec.py +++ b/test/prototype/test_device_spec.py @@ -8,7 +8,7 @@ import torch from utils import patch_device -from torchao.profiler.device_spec import ( +from torchao.prototype.profiler.device_spec import ( _AVAILABLE_GPU_SPECS, CUDADeviceSpec, get_chip_name, diff --git a/test/profiler/test_performance_counter.py b/test/prototype/test_performance_counter.py similarity index 99% rename from test/profiler/test_performance_counter.py rename to test/prototype/test_performance_counter.py index 2cd1a33581..6ece2c6398 100644 --- a/test/profiler/test_performance_counter.py +++ b/test/prototype/test_performance_counter.py @@ -30,8 +30,8 @@ qkv_proj_io_check, ) -from torchao.profiler.device_spec import CUDADeviceSpec, DeviceSpec -from torchao.profiler.performance_counter import ( +from torchao.prototype.profiler.device_spec import CUDADeviceSpec, DeviceSpec +from torchao.prototype.profiler.performance_counter import ( CUDAPerformanceTimer, PerformanceCounterMode, PerformanceStats, diff --git a/test/prototype/test_sparse_api.py b/test/prototype/test_sparse_api.py index 757eb9f913..f3cdbe8386 100644 --- a/test/prototype/test_sparse_api.py +++ b/test/prototype/test_sparse_api.py @@ -50,6 +50,9 @@ def test_sparse(self): sparsify_(model, semi_sparse_weight()) sparse_result = model(input) + if compile: + model = torch.compile(model) + torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3) diff --git a/test/profiler/utils.py b/test/prototype/utils.py similarity index 98% rename from test/profiler/utils.py rename to test/prototype/utils.py index 7b2b999809..8c402b8114 100644 --- a/test/profiler/utils.py +++ b/test/prototype/utils.py @@ -5,7 +5,7 @@ import torch -from torchao.profiler import PerformanceTimer +from torchao.prototype.profiler import PerformanceTimer @contextmanager diff --git a/test/quantization/test_galore_quant.py b/test/quantization/test_galore_quant.py index 37709c4128..3eb9b0a2c5 100644 --- a/test/quantization/test_galore_quant.py +++ b/test/quantization/test_galore_quant.py @@ -3,13 +3,16 @@ import pytest # Skip entire test if triton is not available, otherwise CI failure -try: - import triton -except ImportError: - pytest.skip("triton is not installed", allow_module_level=True) - -from bitsandbytes.functional import create_dynamic_map, quantize_blockwise, dequantize_blockwise +try: # noqa: F401 + import triton # noqa: F401 +except ImportError: # noqa: F401 + pytest.skip("triton is not installed", allow_module_level=True) # noqa: F401 import torch +from bitsandbytes.functional import ( + create_dynamic_map, + dequantize_blockwise, + quantize_blockwise, +) from torchao.prototype.galore.kernels import ( triton_dequant_blockwise, diff --git a/test/quantization/test_marlin_qqq.py b/test/quantization/test_marlin_qqq.py index 0dcaaf9c8c..ebdf2281e0 100644 --- a/test/quantization/test_marlin_qqq.py +++ b/test/quantization/test_marlin_qqq.py @@ -1,4 +1,5 @@ import copy +import unittest import pytest import torch @@ -19,9 +20,12 @@ choose_qparams_and_quantize_affine_qqq, ) from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode -import unittest -@unittest.skipIf(is_fbcode(), "Skipping the test in fbcode since we don't have TARGET file for kernels") + +@unittest.skipIf( + is_fbcode(), + "Skipping the test in fbcode since we don't have TARGET file for kernels", +) class TestMarlinQQQ(TestCase): def setUp(self): super().setUp() diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 29f833c9ab..8862d88b54 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -13,9 +13,8 @@ import torch import torch.nn.functional as F from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 -from torchao.dtypes import ( - TensorCoreTiledLayout, -) + +from torchao.quantization.GPTQ import _replace_linear_8da4w, _replace_linear_int4 from torchao.quantization.granularity import ( PerAxis, PerGroup, @@ -26,33 +25,26 @@ ComposableQATQuantizer, FakeQuantizeConfig, ) -from torchao.quantization.qat.fake_quantizer import ( - FakeQuantizer, -) from torchao.quantization.qat.embedding import ( FakeQuantizedEmbedding, ) from torchao.quantization.qat.linear import ( FakeQuantizedLinear, + Int4WeightOnlyQATLinear, Int8DynActInt4WeightQATLinear, - Int4WeightOnlyQATLinear ) from torchao.quantization.qat.utils import ( _choose_qparams_per_token_asymmetric, _fake_quantize_per_channel_group, _fake_quantize_per_token, - _get_qmin_qmax, _GenericFakeQuantize, -) -from torchao.quantization.quant_api import ( - int4_weight_only, - quantize_, + _get_qmin_qmax, ) from torchao.quantization.quant_primitives import ( - fake_quantize_affine, MappingType, TorchAODType, ZeroPointDomain, + fake_quantize_affine, ) from torchao.quantization.unified import ( TwoStepQuantizer, @@ -65,17 +57,12 @@ from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, -) - -from torchao.quantization.GPTQ import ( - _replace_linear_8da4w, - _replace_linear_int4 ) # TODO: put this in a common test utils file _CUDA_IS_AVAILABLE = torch.cuda.is_available() + class Sub(torch.nn.Module): def __init__(self): super().__init__() @@ -87,6 +74,7 @@ def example_inputs(self): def forward(self, x): return self.linear(x) + class M(torch.nn.Module): def __init__(self): super().__init__() @@ -103,6 +91,7 @@ def forward(self, x): x = self.linear2(x) return x + class M2(torch.nn.Module): def __init__(self): super().__init__() @@ -118,7 +107,9 @@ def forward(self, x): class TestQAT(unittest.TestCase): SEED = 123 - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_fake_quantize_per_channel_group(self): n_bit = 4 (qmin, qmax) = _get_qmin_qmax(n_bit) @@ -132,20 +123,40 @@ def test_fake_quantize_per_channel_group(self): # fake quant op out = _fake_quantize_per_channel_group( - x, s, zp, qmin, qmax, group_size, + x, + s, + zp, + qmin, + qmax, + group_size, ) out.sum().backward() # compare against PTQ ops out_ptq = torch.ops.quantized_decomposed.quantize_per_channel_group( - x2, s, zp, qmin, qmax, torch.int8, group_size, + x2, + s, + zp, + qmin, + qmax, + torch.int8, + group_size, ) out_ptq = torch.ops.quantized_decomposed.dequantize_per_channel_group( - out_ptq, s, zp, qmin, qmax, torch.int8, group_size, torch.float32, + out_ptq, + s, + zp, + qmin, + qmax, + torch.int8, + group_size, + torch.float32, ) torch.testing.assert_close(out, out_ptq, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_fake_quantize_per_token(self): (qmin, qmax) = _get_qmin_qmax(8) @@ -161,10 +172,21 @@ def test_fake_quantize_per_token(self): # compare against PTQ ops out_ptq = torch.ops.quantized_decomposed.quantize_per_token( - x2, s, zp, qmin, qmax, torch.int8, + x2, + s, + zp, + qmin, + qmax, + torch.int8, ) out_ptq = torch.ops.quantized_decomposed.dequantize_per_token( - out_ptq, s, zp, qmin, qmax, torch.int8, torch.float32, + out_ptq, + s, + zp, + qmin, + qmax, + torch.int8, + torch.float32, ) torch.testing.assert_close(out, out_ptq, atol=0, rtol=0) @@ -182,9 +204,10 @@ def _set_ptq_weight( WeightOnlyInt4Linear, ) from torchao.quantization.qat.linear import ( - Int8DynActInt4WeightQATLinear, Int4WeightOnlyQATLinear, + Int8DynActInt4WeightQATLinear, ) + n_bit = 4 (qmin, qmax) = _get_qmin_qmax(n_bit) group_size = qat_linear.weight_fake_quantizer.config.group_size @@ -193,7 +216,13 @@ def _set_ptq_weight( fp32_weight = qat_linear.weight (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size) q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group( - fp32_weight, s, zp, qmin, qmax, torch.int8, group_size, + fp32_weight, + s, + zp, + qmin, + qmax, + torch.int8, + group_size, ) ptq_linear.weight = q_weight ptq_linear.scales = s @@ -201,28 +230,39 @@ def _set_ptq_weight( elif isinstance(ptq_linear, WeightOnlyInt4Linear): assert isinstance(qat_linear, Int4WeightOnlyQATLinear) (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( - qat_linear.weight, n_bit, group_size, + qat_linear.weight, + n_bit, + group_size, ) q_weight = torch.ops.aten._convert_weight_to_int4pack( - q_weight.to("cuda"), qat_linear.inner_k_tiles, + q_weight.to("cuda"), + qat_linear.inner_k_tiles, ) ptq_linear.weight = q_weight ptq_linear.scales_and_zeros = scales_and_zeros else: raise ValueError("Unknown ptq_linear type: %s" % type(ptq_linear)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_8da4w_linear(self): - from torchao.quantization.qat.linear import Int8DynActInt4WeightQATLinear from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear + from torchao.quantization.qat.linear import Int8DynActInt4WeightQATLinear group_size = 128 torch.manual_seed(self.SEED) qat_linear = Int8DynActInt4WeightQATLinear( - 256, 688, bias=False, groupsize=group_size, + 256, + 688, + bias=False, + groupsize=group_size, ) ptq_linear = Int8DynActInt4WeightLinear( - 256, 688, bias=False, groupsize=group_size, + 256, + 688, + bias=False, + groupsize=group_size, ) # Force the weights to be the same @@ -236,10 +276,12 @@ def test_qat_8da4w_linear(self): ptq_out = ptq_linear(x2) torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_8da4w_quantizer(self): - from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer + from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer group_size = 16 torch.manual_seed(self.SEED) @@ -268,9 +310,13 @@ def test_qat_8da4w_quantizer(self): converted_state_dict = converted_model.state_dict() self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys()) for k in ptq_state_dict.keys(): - torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0) + torch.testing.assert_close( + ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0 + ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_8da4w_quantizer_meta_weights(self): from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer @@ -282,7 +328,9 @@ def test_qat_8da4w_quantizer_meta_weights(self): qat_model = qat_quantizer.prepare(m) self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values())) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_8da4w_quantizer_disable_fake_quant(self): """ Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward. @@ -341,7 +389,9 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self): qat_out2 = qat_model2(*x2) torch.testing.assert_close(qat_out, qat_out2, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): """ Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward. @@ -363,8 +413,12 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): nn_model.sub.linear.weight = torch.nn.Parameter(qat_model.sub.linear.weight) # Simulate training for both models - optimizer1 = torch.optim.SGD(nn_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5) - optimizer2 = torch.optim.SGD(qat_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5) + optimizer1 = torch.optim.SGD( + nn_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5 + ) + optimizer2 = torch.optim.SGD( + qat_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5 + ) loss_fn1 = torch.nn.CrossEntropyLoss() loss_fn2 = torch.nn.CrossEntropyLoss() example_inputs = nn_model.example_inputs() @@ -382,9 +436,15 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): optimizer2.step() # After 1 training step, weights should match exactly - torch.testing.assert_close(nn_model.linear1.weight, qat_model.linear1.weight, atol=0, rtol=0) - torch.testing.assert_close(nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0) - torch.testing.assert_close(nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0) + torch.testing.assert_close( + nn_model.linear1.weight, qat_model.linear1.weight, atol=0, rtol=0 + ) + torch.testing.assert_close( + nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0 + ) + torch.testing.assert_close( + nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0 + ) def _test_qat_quantized_gradients(self, quantizer): """ @@ -394,7 +454,9 @@ def _test_qat_quantized_gradients(self, quantizer): torch.manual_seed(self.SEED) m = M() model = quantizer.prepare(m) - optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5) + optimizer = torch.optim.SGD( + model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5 + ) loss_fn = torch.nn.CrossEntropyLoss() # Simulate training @@ -426,13 +488,18 @@ def _test_qat_quantized_gradients(self, quantizer): optimizer.step() current_step += 1 - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_8da4w_quantizer_gradients(self): from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer + quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16) self._test_qat_quantized_gradients(quantizer) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_generic_fake_quantize(self): """ Test that the generic fake quantize used in 8da4w QAT matches @@ -443,7 +510,9 @@ def test_qat_generic_fake_quantize(self): py_input = torch.randn(16, 64).float().requires_grad_() py_s = torch.randn(16).float() py_zp = torch.randint(qmax, size=(16,), dtype=torch.int32) - py_out = torch.fake_quantize_per_channel_affine(py_input, py_s, py_zp, 0, qmin, qmax) + py_out = torch.fake_quantize_per_channel_affine( + py_input, py_s, py_zp, 0, qmin, qmax + ) py_out.sum().backward() ao_input = copy.deepcopy(py_input) @@ -451,7 +520,9 @@ def test_qat_generic_fake_quantize(self): block_size = (1, ao_input.shape[-1]) ao_s = copy.deepcopy(py_s) ao_zp = copy.deepcopy(py_zp) - ao_out = _GenericFakeQuantize.apply(ao_input, block_size, ao_s, ao_zp, qmin, qmax) + ao_out = _GenericFakeQuantize.apply( + ao_input, block_size, ao_s, ao_zp, qmin, qmax + ) ao_out.sum().backward() torch.testing.assert_close(py_out, ao_out, atol=0, rtol=0) @@ -485,10 +556,14 @@ def test_qat_4w_primitives(self): # PTQ (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( - weight, n_bit, group_size, scales_precision, + weight, + n_bit, + group_size, + scales_precision, ) q_weight = torch.ops.aten._convert_weight_to_int4pack( - q_weight.to(device), inner_k_tiles, + q_weight.to(device), + inner_k_tiles, ) ptq_out = torch.ops.aten._weight_int4pack_mm( x, q_weight, group_size, scales_and_zeros @@ -497,9 +572,12 @@ def test_qat_4w_primitives(self): # QAT block_size = (1, group_size) quant_min = 0 - quant_max = 2 ** n_bit - 1 + quant_max = 2**n_bit - 1 scales, zero_points = get_groupwise_affine_qparams( - weight, n_bit, group_size, scales_precision, + weight, + n_bit, + group_size, + scales_precision, ) w_fq = fake_quantize_affine( weight, @@ -509,27 +587,37 @@ def test_qat_4w_primitives(self): torch.int32, quant_min, quant_max, - zero_point_domain = ZeroPointDomain.FLOAT, + zero_point_domain=ZeroPointDomain.FLOAT, ) qat_out = torch.nn.functional.linear(x, w_fq) self._assert_close_4w(qat_out, ptq_out) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") def test_qat_4w_linear(self): - from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear from torchao.quantization.GPTQ import WeightOnlyInt4Linear + from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear group_size = 128 device = torch.device("cuda") dtype = torch.bfloat16 torch.manual_seed(self.SEED) qat_linear = Int4WeightOnlyQATLinear( - 256, 688, bias=False, groupsize=group_size, device=device, + 256, + 688, + bias=False, + groupsize=group_size, + device=device, ) ptq_linear = WeightOnlyInt4Linear( - 256, 688, bias=False, groupsize=group_size, device=device, + 256, + 688, + bias=False, + groupsize=group_size, + device=device, ) # Force the weights to be the same @@ -543,17 +631,22 @@ def test_qat_4w_linear(self): ptq_out = ptq_linear(x2) self._assert_close_4w(qat_out, ptq_out) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_4w_quantizer_gradients(self): from torchao.quantization.qat import Int4WeightOnlyQATQuantizer + quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8) self._test_qat_quantized_gradients(quantizer) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") def test_qat_4w_quantizer(self): - from torchao.quantization.qat import Int4WeightOnlyQATQuantizer from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer + from torchao.quantization.qat import Int4WeightOnlyQATQuantizer group_size = 32 inner_k_tiles = 8 @@ -563,10 +656,12 @@ def test_qat_4w_quantizer(self): m = M().to(device).to(dtype) m2 = copy.deepcopy(m) qat_quantizer = Int4WeightOnlyQATQuantizer( - groupsize=group_size, inner_k_tiles=inner_k_tiles, + groupsize=group_size, + inner_k_tiles=inner_k_tiles, ) ptq_quantizer = Int4WeightOnlyQuantizer( - groupsize=group_size, inner_k_tiles=inner_k_tiles, + groupsize=group_size, + inner_k_tiles=inner_k_tiles, ) qat_model = qat_quantizer.prepare(m) ptq_model = ptq_quantizer.quantize(m2) @@ -589,13 +684,16 @@ def test_qat_4w_quantizer(self): converted_state_dict = converted_model.state_dict() self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys()) for k in ptq_state_dict.keys(): - torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0) + torch.testing.assert_close( + ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0 + ) class _MyQATQuantizer(TwoStepQuantizer): """ Dummy quantizer that attaches a certain value to each nn.Linear's `_temp_quantizer_values` attribute. """ + ATTR_NAME = "_temp_quantizer_values" def __init__(self, value: str): @@ -626,19 +724,24 @@ def test_composable_qat_quantizer(self): self.assertEqual(values_list, ["quantizer1", "quantizer2"]) composable_quantizer.convert(model) values_list = getattr(model.linear1, self._MyQATQuantizer.ATTR_NAME) - self.assertEqual(values_list, ["quantizer1", "quantizer2", "quantizer1", "quantizer2"]) + self.assertEqual( + values_list, ["quantizer1", "quantizer2", "quantizer1", "quantizer2"] + ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_4w_embedding(self): from torchao.quantization.qat import Int4WeightOnlyEmbeddingQATQuantizer + model = M2() x = model.example_inputs() - out = model(*x) + model(*x) quantizer = Int4WeightOnlyEmbeddingQATQuantizer() prepared = quantizer.prepare(model) - prepared_out = prepared(*x) + prepared(*x) converted = quantizer.convert(model) - converted_out = converted(*x) + converted(*x) def test_fake_quantize_config_granularity(self): """ @@ -685,7 +788,9 @@ def test_fake_quantize_config_granularity_error_cases(self): Test incorrect settings of `FakeQuantizeConfig`'s granularity. """ # no granularity provided - with self.assertRaisesRegex(ValueError, "`granularity` or `group_size` must be set"): + with self.assertRaisesRegex( + ValueError, "`granularity` or `group_size` must be set" + ): FakeQuantizeConfig(torch.int8) # group_size with conflicting granularity @@ -718,8 +823,12 @@ def test_fake_quantize_config_mapping_type(self): """ # symmetric symmetric_config1 = FakeQuantizeConfig(torch.int8, "per_token") - symmetric_config2 = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=True) - symmetric_config3 = FakeQuantizeConfig(torch.int8, "per_token", MappingType.SYMMETRIC) + symmetric_config2 = FakeQuantizeConfig( + torch.int8, "per_token", is_symmetric=True + ) + symmetric_config3 = FakeQuantizeConfig( + torch.int8, "per_token", MappingType.SYMMETRIC + ) self.assertEqual(symmetric_config1.mapping_type, MappingType.SYMMETRIC) self.assertEqual(symmetric_config2.mapping_type, MappingType.SYMMETRIC) self.assertEqual(symmetric_config3.mapping_type, MappingType.SYMMETRIC) @@ -728,8 +837,12 @@ def test_fake_quantize_config_mapping_type(self): self.assertTrue(symmetric_config3.is_symmetric) # asymmetric - asymmetric_config1 = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) - asymmetric_config2 = FakeQuantizeConfig(torch.int8, "per_token", MappingType.ASYMMETRIC) + asymmetric_config1 = FakeQuantizeConfig( + torch.int8, "per_token", is_symmetric=False + ) + asymmetric_config2 = FakeQuantizeConfig( + torch.int8, "per_token", MappingType.ASYMMETRIC + ) self.assertEqual(asymmetric_config1.mapping_type, MappingType.ASYMMETRIC) self.assertEqual(asymmetric_config2.mapping_type, MappingType.ASYMMETRIC) self.assertFalse(asymmetric_config1.is_symmetric) @@ -743,11 +856,15 @@ def test_fake_quantize_config_mapping_type(self): # bad config1: both mapping_type and is_symmetric are set msg = "Cannot set both `mapping_type` and `is_symmetric`" with self.assertRaisesRegex(ValueError, msg): - FakeQuantizeConfig(torch.int8, "per_token", MappingType.SYMMETRIC, is_symmetric=False) + FakeQuantizeConfig( + torch.int8, "per_token", MappingType.SYMMETRIC, is_symmetric=False + ) # bad config2: not supported with self.assertRaisesRegex(ValueError, "not supported"): - FakeQuantizeConfig(torch.int8, "per_token", MappingType.SYMMETRIC_NO_CLIPPING_ERR) + FakeQuantizeConfig( + torch.int8, "per_token", MappingType.SYMMETRIC_NO_CLIPPING_ERR + ) def test_fake_quantize_config_dtype(self): """ @@ -781,7 +898,9 @@ def test_fake_quantize_config_dtype(self): FakeQuantizeConfig(TorchAODType.INT7, "per_token") FakeQuantizeConfig(torch.int8, "per_token") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_fake_quantized_linear_8da4w(self): """ Test that we can express int8 dynamic activations + int4 weights with `FakeQuantizedLinear`. @@ -792,7 +911,9 @@ def test_fake_quantized_linear_8da4w(self): 256, 688, bias=False, - activation_config=FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False), + activation_config=FakeQuantizeConfig( + torch.int8, "per_token", is_symmetric=False + ), weight_config=FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size), ) @@ -801,7 +922,9 @@ def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: Baseline for int8 dynamic per token asymmetric + int4 per group symmetric quant. """ # activations - (s, zp) = _choose_qparams_per_token_asymmetric(x, torch.float32, torch.int32) + (s, zp) = _choose_qparams_per_token_asymmetric( + x, torch.float32, torch.int32 + ) (qmin, qmax) = _get_qmin_qmax(8) x_fq = _fake_quantize_per_token(x, s, zp, qmin, qmax) @@ -809,7 +932,9 @@ def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: (s, zp) = get_group_qparams_symmetric(weight, 4, group_size, torch.float32) zp = zp.to(torch.int32) (qmin, qmax) = _get_qmin_qmax(4) - w_fq = _fake_quantize_per_channel_group(weight, s, zp, qmin, qmax, group_size) + w_fq = _fake_quantize_per_channel_group( + weight, s, zp, qmin, qmax, group_size + ) return F.linear(x_fq, w_fq) # Compare linear values @@ -820,7 +945,9 @@ def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: baseline_out = linear_forward_8da4w(x2, fq_linear.weight) torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_fake_quantized_linear_4w(self): """ Test that we can express int4 weight only (tinygemm) with `FakeQuantizedLinear`. @@ -849,7 +976,13 @@ def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: (s, zp) = get_groupwise_affine_qparams(weight, 4, group_size, torch.float32) zp = zp.to(torch.int32) w_fq = _fake_quantize_per_channel_group( - weight, s, zp, qmin, qmax, group_size, zero_point_domain=ZeroPointDomain.FLOAT, + weight, + s, + zp, + qmin, + qmax, + group_size, + zero_point_domain=ZeroPointDomain.FLOAT, ) return F.linear(x, w_fq) @@ -860,50 +993,78 @@ def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: fq_out = fq_linear(x) baseline_out = linear_forward_4w(x2, fq_linear.weight) torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0) - - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_replace_linear_8da4w(self): - module = torch.nn.ModuleList([ - torch.nn.Linear(in_features=256, out_features=50, bias=True) - ]) - _replace_linear_8da4w(module, 256, False, torch.float32, torch.float32, Int8DynActInt4WeightQATLinear, copy_weights=True) - assert(not isinstance(module[0], Int8DynActInt4WeightQATLinear) and isinstance(module[0], torch.nn.Linear)) - module = torch.nn.ModuleList([ - torch.nn.Linear(in_features=256, out_features=50, bias=False) - ]) - _replace_linear_8da4w(module, 256, False, torch.float32, torch.float32, Int8DynActInt4WeightQATLinear, copy_weights=True) - assert(isinstance(module[0], Int8DynActInt4WeightQATLinear)) - - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + module = torch.nn.ModuleList( + [torch.nn.Linear(in_features=256, out_features=50, bias=True)] + ) + _replace_linear_8da4w( + module, + 256, + False, + torch.float32, + torch.float32, + Int8DynActInt4WeightQATLinear, + copy_weights=True, + ) + assert not isinstance(module[0], Int8DynActInt4WeightQATLinear) and isinstance( + module[0], torch.nn.Linear + ) + module = torch.nn.ModuleList( + [torch.nn.Linear(in_features=256, out_features=50, bias=False)] + ) + _replace_linear_8da4w( + module, + 256, + False, + torch.float32, + torch.float32, + Int8DynActInt4WeightQATLinear, + copy_weights=True, + ) + assert isinstance(module[0], Int8DynActInt4WeightQATLinear) + + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_replace_linear_int4(self): - module = torch.nn.ModuleList([ - torch.nn.Linear(in_features=256, out_features=50, bias=True) - ]) + module = torch.nn.ModuleList( + [torch.nn.Linear(in_features=256, out_features=50, bias=True)] + ) _replace_linear_int4( - module, - 256, + module, + 256, 8, - padding_allowed=True, - precision=torch.bfloat16, - scales_precision=torch.bfloat16, - linear_class=Int4WeightOnlyQATLinear, - copy_weights=True) - assert(not isinstance(module[0], Int4WeightOnlyQATLinear) and isinstance(module[0], torch.nn.Linear)) - module = torch.nn.ModuleList([ - torch.nn.Linear(in_features=256, out_features=50, bias=False) - ]) + padding_allowed=True, + precision=torch.bfloat16, + scales_precision=torch.bfloat16, + linear_class=Int4WeightOnlyQATLinear, + copy_weights=True, + ) + assert not isinstance(module[0], Int4WeightOnlyQATLinear) and isinstance( + module[0], torch.nn.Linear + ) + module = torch.nn.ModuleList( + [torch.nn.Linear(in_features=256, out_features=50, bias=False)] + ) _replace_linear_int4( - module, - 256, + module, + 256, 8, - padding_allowed=True, - precision=torch.bfloat16, - scales_precision=torch.bfloat16, - linear_class=Int4WeightOnlyQATLinear, - copy_weights=True) - assert(isinstance(module[0], Int4WeightOnlyQATLinear)) - - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + padding_allowed=True, + precision=torch.bfloat16, + scales_precision=torch.bfloat16, + linear_class=Int4WeightOnlyQATLinear, + copy_weights=True, + ) + assert isinstance(module[0], Int4WeightOnlyQATLinear) + + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_fake_quantized_embedding_4w(self): """ Test that we can express int4 per group symmetric weight only fake quantization @@ -926,7 +1087,9 @@ def embedding_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: (s, zp) = get_group_qparams_symmetric(weight, 4, group_size, torch.float32) zp = zp.to(torch.int32) (qmin, qmax) = _get_qmin_qmax(4) - w_fq = _fake_quantize_per_channel_group(weight, s, zp, qmin, qmax, group_size) + w_fq = _fake_quantize_per_channel_group( + weight, s, zp, qmin, qmax, group_size + ) return F.embedding(x, w_fq) # Compare embedding values @@ -937,13 +1100,15 @@ def embedding_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: baseline_out = embedding_forward_4w(x2, fq_embedding.weight) torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_qat_prototype_bc(self): """ Just to make sure we can import all the old prototype paths. We will remove this test in the near future when we actually break BC. """ - from torchao.quantization.prototype.qat import ( + from torchao.quantization.prototype.qat import ( # noqa: F401, F811, I001 disable_4w_fake_quant, disable_8da4w_fake_quant, enable_4w_fake_quant, @@ -954,7 +1119,7 @@ def test_qat_prototype_bc(self): Int4WeightOnlyQATQuantizer, Int8DynActInt4WeightQATQuantizer, ) - from torchao.quantization.prototype.qat._module_swap_api import ( + from torchao.quantization.prototype.qat._module_swap_api import ( # noqa: F401, F811 disable_4w_fake_quant_module_swap, enable_4w_fake_quant_module_swap, disable_8da4w_fake_quant_module_swap, @@ -962,24 +1127,24 @@ def test_qat_prototype_bc(self): Int4WeightOnlyQATQuantizerModuleSwap, Int8DynActInt4WeightQATQuantizerModuleSwap, ) - from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( + from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import ( # noqa: F401, F811 AffineFakeQuantizedTensor, to_affine_fake_quantized, ) - from torchao.quantization.prototype.qat.api import ( + from torchao.quantization.prototype.qat.api import ( # noqa: F401, F811 ComposableQATQuantizer, FakeQuantizeConfig, ) - from torchao.quantization.prototype.qat.embedding import ( + from torchao.quantization.prototype.qat.embedding import ( # noqa: F401, F811 FakeQuantizedEmbedding, Int4WeightOnlyEmbeddingQATQuantizer, Int4WeightOnlyEmbedding, Int4WeightOnlyQATEmbedding, ) - from torchao.quantization.prototype.qat.fake_quantizer import ( + from torchao.quantization.prototype.qat.fake_quantizer import ( # noqa: F401, F811 FakeQuantizer, ) - from torchao.quantization.prototype.qat.linear import ( + from torchao.quantization.prototype.qat.linear import ( # noqa: F401, F811 disable_4w_fake_quant, disable_8da4w_fake_quant, enable_4w_fake_quant, @@ -991,5 +1156,6 @@ def test_qat_prototype_bc(self): Int8DynActInt4WeightQATQuantizer, ) + if __name__ == "__main__": unittest.main() diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 458cd07810..eb5f1337d1 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -6,81 +6,86 @@ # mypy: ignore-errors # This test takes a long time to run +import copy +import gc +import tempfile import unittest +from pathlib import Path + import torch -import os from torch.ao.quantization.quantize_pt2e import ( - prepare_pt2e, convert_pt2e, + prepare_pt2e, ) from torch.ao.quantization.quantizer.xnnpack_quantizer import ( XNNPACKQuantizer, get_symmetric_quantization_config, ) +from torch.testing._internal import common_utils +from torch.testing._internal.common_utils import TestCase -import torchao +from torchao import quantize_ +from torchao._models.llama.model import Transformer, prepare_inputs_for_model +from torchao._models.llama.tokenizer import get_tokenizer from torchao.dtypes import ( AffineQuantizedTensor, ) from torchao.quantization import ( LinearActivationQuantizedTensor, ) -from torchao.quantization.quant_primitives import ( - MappingType, - ZeroPointDomain, -) -from torchao.quantization.subclass import ( - Int8WeightOnlyQuantizedLinearWeight, - Int4WeightOnlyQuantizedLinearWeight, -) -from torchao import quantize_ from torchao.quantization.quant_api import ( - _replace_with_custom_fn_if_matches_filter, Quantizer, TwoStepQuantizer, - int8_dynamic_activation_int4_weight, + _replace_with_custom_fn_if_matches_filter, int4_weight_only, - int8_weight_only, + int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_weight, + int8_weight_only, +) +from torchao.quantization.quant_primitives import ( + MappingType, +) +from torchao.quantization.subclass import ( + Int4WeightOnlyQuantizedLinearWeight, + Int8WeightOnlyQuantizedLinearWeight, ) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, + unwrap_tensor_subclass, ) -from pathlib import Path -from torchao._models.llama.tokenizer import get_tokenizer -from torchao._models.llama.model import Transformer, prepare_inputs_for_model -from torchao.utils import unwrap_tensor_subclass -import copy -import tempfile -import gc -from torch.testing._internal.common_utils import TestCase -from torch.testing._internal import common_utils def dynamic_quant(model, example_inputs): m = torch.export.export(model, example_inputs).module() - quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True)) + quantizer = XNNPACKQuantizer().set_global( + get_symmetric_quantization_config(is_dynamic=True) + ) m = prepare_pt2e(m, quantizer) m = convert_pt2e(m) return m + def capture_and_prepare(model, example_inputs): m = torch.export.export(model, example_inputs) - quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True)) + quantizer = XNNPACKQuantizer().set_global( + get_symmetric_quantization_config(is_dynamic=True) + ) m = prepare_pt2e(m, quantizer) # TODO: we can run the weight observer in convert_pt2e so that user don't need to run this m(*example_inputs) return m -class XNNPackDynamicQuantizer(TwoStepQuantizer): +class XNNPackDynamicQuantizer(TwoStepQuantizer): def prepare(self, model: torch.nn.Module) -> torch.nn.Module: _replace_with_custom_fn_if_matches_filter( model, - lambda linear_mod: capture_and_prepare(linear_mod, (torch.randn(1, linear_mod.in_features))), + lambda linear_mod: capture_and_prepare( + linear_mod, (torch.randn(1, linear_mod.in_features)) + ), lambda mod, fqn: isinstance(mod, torch.nn.Linear), ) return model @@ -93,11 +98,13 @@ def convert(self, model: torch.nn.Module) -> torch.nn.Module: ) return model + class TorchCompileDynamicQuantizer(Quantizer): def quantize(self, model: torch.nn.Module) -> torch.nn.Module: quantize_(model, int8_dynamic_activation_int8_weight()) return model + class ToyLinearModel(torch.nn.Module): def __init__(self, m=64, n=32, k=64): super().__init__() @@ -105,7 +112,11 @@ def __init__(self, m=64, n=32, k=64): self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float) def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"): - return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),) + return ( + torch.randn( + batch_size, self.linear1.in_features, dtype=dtype, device=device + ), + ) def forward(self, x): x = self.linear1(x) @@ -118,9 +129,11 @@ def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs The deprecated implementation for int8 dynamic quant API, used as a reference for numerics and performance """ - from torchao.quantization.quant_api import _in_features_greater_than_16 - from torchao.quantization.quant_api import _is_linear - from torchao.quantization.quant_api import _get_subclass_inserter + from torchao.quantization.quant_api import ( + _get_subclass_inserter, + _in_features_greater_than_16, + _is_linear, + ) from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight if filter_fn is None: @@ -129,37 +142,49 @@ def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs ) _replace_with_custom_fn_if_matches_filter( - model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn + model, + _get_subclass_inserter( + Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs + ), + filter_fn, ) + def _get_ref_change_linear_weights_to_woqtensors(deprecated_tenosr_subclass): def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs): """ The deprecated implementation for weight only quant API, used as a reference for numerics and performance """ - from torchao.quantization.quant_api import _is_linear - from torchao.quantization.quant_api import _get_subclass_inserter + from torchao.quantization.quant_api import _get_subclass_inserter, _is_linear filter_fn = kwargs.pop("filter_fn", _is_linear) _replace_with_custom_fn_if_matches_filter( model, - _get_subclass_inserter(deprecated_tenosr_subclass, enable_parametrization=True, **kwargs), + _get_subclass_inserter( + deprecated_tenosr_subclass, enable_parametrization=True, **kwargs + ), filter_fn, ) return _ref_change_linear_weights_to_woqtensors -_ref_change_linear_weights_to_int8_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight) -_ref_change_linear_weights_to_int4_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight) + +_ref_change_linear_weights_to_int8_woqtensors = ( + _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight) +) +_ref_change_linear_weights_to_int4_woqtensors = ( + _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight) +) + class TestQuantFlow(TestCase): def test_dynamic_quant_gpu_singleline(self): m = ToyLinearModel().eval() example_inputs = m.example_inputs() quantize_(m, int8_dynamic_activation_int8_weight()) - quantized = m(*example_inputs) + m(*example_inputs) # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64 # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {}) # m = torch.compile(m, mode="max-autotune") @@ -182,7 +207,9 @@ def test_dynamic_quant_gpu_unified_api_unified_impl(self): compiled = m(*example_inputs) torch.testing.assert_close(quantized, compiled, atol=0, rtol=0) - @unittest.skip("FAILED test/quantization/test_quant_api.py::TestQuantFlow::test_dynamic_quant_gpu_unified_api_eager_mode_impl - AssertionError: Tensor-likes are not equal!") + @unittest.skip( + "FAILED test/quantization/test_quant_api.py::TestQuantFlow::test_dynamic_quant_gpu_unified_api_eager_mode_impl - AssertionError: Tensor-likes are not equal!" + ) def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self): quantizer = TorchCompileDynamicQuantizer() m = ToyLinearModel().eval() @@ -196,10 +223,8 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "only works for torch 2.4+") def test_int8_wo_quant_save_load(self): - from torchao.quantization.quant_api import ( - change_linear_weights_to_int8_woqtensors, - ) m = ToyLinearModel().eval().cpu() + def api(model): quantize_(model, int8_weight_only()) unwrap_tensor_subclass(model) @@ -223,10 +248,12 @@ def api(model): torch.testing.assert_close(ref, res.cpu()) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch verion is 2.3 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch verion is 2.3 or lower" + ) def test_8da4w_quantizer(self): - from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear + from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer quantizer = Int8DynActInt4WeightQuantizer(groupsize=32) m = ToyLinearModel().eval() @@ -242,8 +269,9 @@ def test_8da4w_quantizer(self): # https://github.com/pytorch-labs/gpt-fast/blob/6253c6bb054e658d67566150f87329b87815ae63/scripts/convert_hf_checkpoint.py @unittest.skip("skipping until we get checkpoints for gpt-fast") def test_8da4w_gptq_quantizer(self): - from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer from torchao._models._eval import InputRecorder, TransformerEvalWrapper + from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer + # should be similar to TorchCompileDynamicQuantizer precision = torch.bfloat16 device = "cpu" @@ -268,16 +296,20 @@ def test_8da4w_gptq_quantizer(self): input_prep_func = prepare_inputs_for_model pad_calibration_inputs = False - inputs = InputRecorder( - tokenizer, - calibration_seq_length, - input_prep_func, - pad_calibration_inputs, - model.config.vocab_size, - ).record_inputs( - calibration_tasks, - calibration_limit, - ).get_inputs() + inputs = ( + InputRecorder( + tokenizer, + calibration_seq_length, + input_prep_func, + pad_calibration_inputs, + model.config.vocab_size, + ) + .record_inputs( + calibration_tasks, + calibration_limit, + ) + .get_inputs() + ) quantizer = Int8DynActInt4WeightGPTQQuantizer( blocksize, @@ -287,7 +319,7 @@ def test_8da4w_gptq_quantizer(self): ) model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length) model = quantizer.quantize(model, inputs) - result=TransformerEvalWrapper( + result = TransformerEvalWrapper( model, tokenizer, model.config.block_size, @@ -298,15 +330,17 @@ def test_8da4w_gptq_quantizer(self): 1, ) - assert result['results']['wikitext']['word_perplexity,none'] < 7.88, ( - f"accuracy regressed from 7.87 to {result['results']['wikitext']['word_perplexity,none']}" - ) + assert ( + result["results"]["wikitext"]["word_perplexity,none"] < 7.88 + ), f"accuracy regressed from 7.87 to {result['results']['wikitext']['word_perplexity,none']}" @unittest.skip("skipping until we get checkpoints for gpt-fast") - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch verion is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch verion is 2.4 or lower" + ) def test_8da4w_quantizer_eval(self): - from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer from torchao._models._eval import TransformerEvalWrapper + from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer precision = torch.bfloat16 device = "cpu" @@ -325,7 +359,7 @@ def test_8da4w_quantizer_eval(self): quantizer = Int8DynActInt4WeightQuantizer(groupsize=128, precision=precision) q_model = quantizer.quantize(model) - result=TransformerEvalWrapper( + result = TransformerEvalWrapper( q_model, tokenizer, q_model.config.block_size, @@ -335,14 +369,18 @@ def test_8da4w_quantizer_eval(self): ["wikitext"], 1, ) - assert result['results']['wikitext']['word_perplexity,none'] < 8.24, ( - f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}" - ) + assert ( + result["results"]["wikitext"]["word_perplexity,none"] < 8.24 + ), f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}" @unittest.skip("skipping until we get checkpoints for gpt-fast") def test_gptq_quantizer_int4_weight_only(self): + from torchao._models._eval import ( + MultiTensorInputRecorder, + TransformerEvalWrapper, + ) from torchao.quantization.GPTQ_MT import Int4WeightOnlyGPTQQuantizer - from torchao._models._eval import MultiTensorInputRecorder, TransformerEvalWrapper + precision = torch.bfloat16 device = "cuda" checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth") @@ -367,18 +405,21 @@ def test_gptq_quantizer_int4_weight_only(self): calibration_seq_length = 100 input_prep_func = prepare_inputs_for_model pad_calibration_inputs = False - inputs = MultiTensorInputRecorder( - tokenizer, - calibration_seq_length, - input_prep_func, - pad_calibration_inputs, - model.config.vocab_size, - device="cpu", - ).record_inputs( - calibration_tasks, - calibration_limit, - ).get_inputs() - + inputs = ( + MultiTensorInputRecorder( + tokenizer, + calibration_seq_length, + input_prep_func, + pad_calibration_inputs, + model.config.vocab_size, + device="cpu", + ) + .record_inputs( + calibration_tasks, + calibration_limit, + ) + .get_inputs() + ) quantizer = Int4WeightOnlyGPTQQuantizer( blocksize, @@ -398,14 +439,15 @@ def test_gptq_quantizer_int4_weight_only(self): ["wikitext"], None, ) - assert result['results']['wikitext']['word_perplexity,none'] < 7.77, ( - f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}" - ) + assert ( + result["results"]["wikitext"]["word_perplexity,none"] < 7.77 + ), f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}" @unittest.skip("skipping until we get checkpoints for gpt-fast") def test_quantizer_int4_weight_only(self): - from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer from torchao._models._eval import TransformerEvalWrapper + from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer + precision = torch.bfloat16 device = "cuda" checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth") @@ -435,13 +477,14 @@ def test_quantizer_int4_weight_only(self): ["wikitext"], 1, ) - assert result['results']['wikitext']['word_perplexity,none'] < 8.24, ( - f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}" - ) + assert ( + result["results"]["wikitext"]["word_perplexity,none"] < 8.24 + ), f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}" @unittest.skip("skipping until we get checkpoints for gpt-fast") def test_eval_wrapper(self): from torchao._models._eval import TransformerEvalWrapper + precision = torch.bfloat16 device = "cuda" checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth") @@ -456,7 +499,7 @@ def test_eval_wrapper(self): tokenizer_path, "Llama-2-7b-chat-hf", ) - result=TransformerEvalWrapper( + result = TransformerEvalWrapper( model, tokenizer, model.config.block_size, @@ -466,17 +509,20 @@ def test_eval_wrapper(self): ["wikitext"], 1, ) - assert result['results']['wikitext']['word_perplexity,none']<7.77, ( - f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}" - ) + assert ( + result["results"]["wikitext"]["word_perplexity,none"] < 7.77 + ), f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}" # EVAL IS CURRENTLY BROKEN FOR LLAMA 3, VERY LOW ACCURACY @unittest.skip("skipping until we get checkpoints for gpt-fast") def test_eval_wrapper_llama3(self): from torchao._models._eval import TransformerEvalWrapper + precision = torch.bfloat16 device = "cuda" - checkpoint_path = Path(".../gpt-fast/checkpoints/meta-llama/Meta-Llama-3-8B/model.pth") + checkpoint_path = Path( + ".../gpt-fast/checkpoints/meta-llama/Meta-Llama-3-8B/model.pth" + ) model = Transformer.from_name(checkpoint_path.parent.name) checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) model.load_state_dict(checkpoint, assign=True) @@ -498,30 +544,43 @@ def test_eval_wrapper_llama3(self): ["wikitext"], 1, ) - assert result['results']['wikitext']['word_perplexity,none'] < 8.24, ( - f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}" - ) + assert ( + result["results"]["wikitext"]["word_perplexity,none"] < 8.24 + ), f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}" # TODO: move to a separate test file @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") - @common_utils.parametrize("mapping_type", [MappingType.SYMMETRIC, MappingType.SYMMETRIC_NO_CLIPPING_ERR]) + @common_utils.parametrize( + "mapping_type", [MappingType.SYMMETRIC, MappingType.SYMMETRIC_NO_CLIPPING_ERR] + ) def test_quantized_tensor_subclass_8da4w(self, mapping_type): group_size = 32 m = ToyLinearModel().eval() m_copy = copy.deepcopy(m) example_inputs = m.example_inputs() - quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size, mapping_type=mapping_type)) + quantize_( + m, + int8_dynamic_activation_int4_weight( + group_size=group_size, mapping_type=mapping_type + ), + ) assert isinstance(m.linear1.weight, LinearActivationQuantizedTensor) assert isinstance(m.linear2.weight, LinearActivationQuantizedTensor) - assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor) - assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor) + assert isinstance( + m.linear1.weight.original_weight_tensor, AffineQuantizedTensor + ) + assert isinstance( + m.linear2.weight.original_weight_tensor, AffineQuantizedTensor + ) # reference - from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear + from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer - quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size, mapping_type=mapping_type) + quantizer = Int8DynActInt4WeightQuantizer( + groupsize=group_size, mapping_type=mapping_type + ) m_copy = quantizer.quantize(m_copy) assert isinstance(m_copy.linear1, Int8DynActInt4WeightLinear) assert isinstance(m_copy.linear2, Int8DynActInt4WeightLinear) @@ -552,7 +611,6 @@ def test_quantized_tensor_subclass_int4(self): self.assertTrue(torch.equal(res, ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_quantized_tensor_subclass_int8_wo(self): @@ -568,13 +626,11 @@ def test_quantized_tensor_subclass_int8_wo(self): # reference _ref_change_linear_weights_to_int8_woqtensors(m_copy) - res = m(*example_inputs) ref = m_copy(*example_inputs) self.assertTrue(torch.equal(res, ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.5 and below") @@ -583,13 +639,19 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self): m = ToyLinearModel(1024, 1024, 2048).eval().to(torch.bfloat16).to("cuda") m_copy = copy.deepcopy(m) # setting batch_size to 20 to be compatible with the kernel - example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda") + example_inputs = m.example_inputs( + batch_size=20, dtype=torch.bfloat16, device="cuda" + ) quantize_(m, int8_dynamic_activation_int8_weight()) assert isinstance(m.linear1.weight, LinearActivationQuantizedTensor) assert isinstance(m.linear2.weight, LinearActivationQuantizedTensor) - assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor) - assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor) + assert isinstance( + m.linear1.weight.original_weight_tensor, AffineQuantizedTensor + ) + assert isinstance( + m.linear2.weight.original_weight_tensor, AffineQuantizedTensor + ) # reference _ref_change_linear_weights_to_int8_dqtensors(m_copy) @@ -601,6 +663,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self): # workaround for export path from torchao.utils import unwrap_tensor_subclass + m_unwrapped = unwrap_tensor_subclass(m) m = torch.export.export(m_unwrapped, example_inputs).module() @@ -630,12 +693,10 @@ def test_quantized_tensor_subclass_save_load(self): res = m_copy(*example_inputs) self.assertEqual(res, ref) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_int8wo_quantized_model_to_device(self): m = ToyLinearModel().eval().to(torch.bfloat16) - m_copy = copy.deepcopy(m) example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cpu") quantize_(m, int8_weight_only()) @@ -654,7 +715,6 @@ def test_int4wo_quantized_model_to_device(self): devices = ["cuda", "cuda:0"] for device in devices: m = ToyLinearModel().eval().to(torch.bfloat16).to(device) - m_copy = copy.deepcopy(m) example_inputs = m.example_inputs(dtype=torch.bfloat16, device=device) quantize_(m, int4_weight_only()) @@ -678,7 +738,7 @@ def test_quantized_tensor_subclass_save_load_map_location(self): f.seek(0) state_dict = torch.load(f.name, map_location="cpu", mmap=True) - with torch.device('meta'): + with torch.device("meta"): m_copy = ToyLinearModel().eval() m_copy.load_state_dict(state_dict, assign=True) @@ -710,12 +770,13 @@ def reset_memory(): assert param.is_cuda self.assertLess(memory_streaming, memory_baseline) -class TestMultiTensorFlow(TestCase): +class TestMultiTensorFlow(TestCase): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_multitensor_add_tensors(self): from torchao.quantization.GPTQ_MT import MultiTensor + tensor1 = torch.randn(3, 3) tensor2 = torch.randn(3, 3) mt = MultiTensor(tensor1) @@ -728,6 +789,7 @@ def test_multitensor_add_tensors(self): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_multitensor_pad_unpad(self): from torchao.quantization.GPTQ_MT import MultiTensor + tensor1 = torch.randn(3, 3) mt = MultiTensor(tensor1) mt.pad_to_length(3) @@ -739,14 +801,13 @@ def test_multitensor_pad_unpad(self): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_multitensor_inplace_operation(self): from torchao.quantization.GPTQ_MT import MultiTensor + tensor1 = torch.ones(3, 3) mt = MultiTensor(tensor1) mt += 1 # In-place addition self.assertTrue(torch.equal(mt.values[0], torch.full((3, 3), 2))) - - common_utils.instantiate_parametrized_tests(TestQuantFlow) diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 4e0663eb87..a3fef29fea 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -7,25 +7,27 @@ # mypy: ignore-errors # This test takes a long time to run import unittest + import torch + +from torchao.dtypes.utils import is_device from torchao.quantization.quant_primitives import ( + MappingType, + ZeroPointDomain, + choose_qparams_affine, + dequantize_affine, fake_quantize_affine, fake_quantize_affine_cachemask, quantize_affine, - dequantize_affine, - choose_qparams_affine, - MappingType, - ZeroPointDomain, ) + # TODO: remove test for utils? from torchao.quantization.utils import ( get_group_qparams_symmetric, - get_groupwise_affine_qparams, - groupwise_affine_quantize_tensor_from_qparams, groupwise_affine_dequantize_tensor_from_qparams, + groupwise_affine_quantize_tensor_from_qparams, quantize_activation_per_token_absmax, ) - from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, @@ -37,6 +39,7 @@ _SEED = 1234 torch.manual_seed(_SEED) + # Helper function to run a function twice # and verify that the result is the same. # Adds some verification to avoid side effects. @@ -47,9 +50,12 @@ def check_idempotent(self, fn, *args, **kwargs): output0 = fn(*args, **kwargs) assert torch.is_tensor(output0) output1 = fn(*args, **kwargs) - self.assertTrue(torch.equal(output0, output1), f"Expected given function {fn} to be idempotent.") + self.assertTrue( + torch.equal(output0, output1), f"Expected given function {fn} to be idempotent." + ) return output1 + # Legacy tinygemm ops def _get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16): if groupsize > w.shape[-1]: @@ -70,6 +76,7 @@ def _get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat1 dtype=dtype ).reshape(w.shape[0], -1) + def _groupwise_affine_quantize_tensor_from_qparams( w, scales, @@ -102,10 +109,12 @@ def _groupwise_affine_quantize_tensor_from_qparams( .reshape_as(w) ) if TORCH_VERSION_AT_LEAST_2_5: - w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8) + if not (is_device(w.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6): + w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8) return w_int4x8 + def _groupwise_affine_dequantize_tensor_from_qparams( w_int4x8, scales, @@ -136,7 +145,9 @@ def _groupwise_affine_dequantize_tensor_from_qparams( class TestQuantPrimitives(unittest.TestCase): SEED = 123 - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower" + ) def test_get_group_qparams_symmetric(self): """ Test that `get_group_qparams_symmetric` produces the exact same scales as @@ -145,7 +156,6 @@ def test_get_group_qparams_symmetric(self): n_bit = 4 qmin = -(2 ** (n_bit - 1)) qmax = 2 ** (n_bit - 1) - 1 - eps = torch.finfo(torch.float32).eps groupsize = 256 torch.manual_seed(self.SEED) weight = torch.randn(100, 256).to(torch.float16) @@ -158,14 +168,16 @@ def test_get_group_qparams_symmetric(self): quant_max=qmax, # This is needed to ensure `min_val` and `max_val` are fp16, # otherwise they default to fp32 and the qparams will be slightly off - factory_kwargs={"dtype": torch.float16} + factory_kwargs={"dtype": torch.float16}, ) obs(weight) (scale_obs, _) = obs.calculate_qparams() scale_obs = scale_obs.reshape(weight.shape[0], -1) # assert that scales are identical - (scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize, precision=torch.float16) + (scale_ao, _) = get_group_qparams_symmetric( + weight, n_bit, groupsize, precision=torch.float16 + ) torch.testing.assert_close(scale_obs, scale_ao, rtol=0, atol=0) def test_choose_qparams_group_sym(self): @@ -178,9 +190,19 @@ def test_choose_qparams_group_sym(self): block_size = (1, 2) eps = torch.finfo(torch.float32).eps precision = torch.float32 - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps, scale_dtype=precision, zero_point_dtype=precision) + scale, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + eps=eps, + scale_dtype=precision, + zero_point_dtype=precision, + ) - scale_ref, zp_ref = get_group_qparams_symmetric(input, n_bit=8, groupsize=2, precision=precision, mapping_type=mapping_type) + scale_ref, zp_ref = get_group_qparams_symmetric( + input, n_bit=8, groupsize=2, precision=precision, mapping_type=mapping_type + ) self.assertTrue(torch.equal(scale, scale_ref)) self.assertTrue(torch.equal(zero_point, zp_ref)) @@ -195,13 +217,26 @@ def test_choose_qparams_group_sym_no_clipping_err(self): block_size = (1, 2) eps = torch.finfo(torch.float32).eps precision = torch.float32 - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps, scale_dtype=precision, zero_point_dtype=precision) + scale, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + eps=eps, + scale_dtype=precision, + zero_point_dtype=precision, + ) - scale_ref, zp_ref = get_group_qparams_symmetric(input, n_bit=8, groupsize=2, precision=precision, mapping_type=mapping_type) + scale_ref, zp_ref = get_group_qparams_symmetric( + input, n_bit=8, groupsize=2, precision=precision, mapping_type=mapping_type + ) self.assertTrue(torch.equal(scale, scale_ref)) self.assertTrue(torch.equal(zero_point, zp_ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower") + + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower" + ) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_choose_qparams_token_asym(self): input = torch.randn(10, 10) @@ -209,11 +244,29 @@ def test_choose_qparams_token_asym(self): dtype = torch.int8 block_size = (1, 10) if TORCH_VERSION_AT_LEAST_2_6: - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float64, zero_point_dtype=torch.int64) + scale, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + eps=torch.finfo(torch.float32).eps, + scale_dtype=torch.float64, + zero_point_dtype=torch.int64, + ) else: - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) + scale, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + eps=torch.finfo(torch.float32).eps, + ) - scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(input, dtype) + scale_ref, zp_ref = ( + torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric( + input, dtype + ) + ) scale_ref = scale_ref.squeeze() zp_ref = zp_ref.squeeze() @@ -227,12 +280,15 @@ def test_choose_qparams_tensor_asym(self): dtype = torch.int8 block_size = (10, 10) eps = torch.finfo(torch.float32).eps - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps) - + scale, zero_point = choose_qparams_affine( + input, mapping_type, block_size, dtype, eps=eps + ) quant_min = -128 quant_max = 127 - scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams(input, quant_min, quant_max, eps, dtype) + scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams( + input, quant_min, quant_max, eps, dtype + ) scale_ref = scale_ref.squeeze() zp_ref = zp_ref.squeeze() @@ -246,18 +302,24 @@ def test_choose_qparams_tensor_sym(self): dtype = torch.int8 block_size = (10, 10) eps = torch.finfo(torch.float32).eps - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps) + scale, zero_point = choose_qparams_affine( + input, mapping_type, block_size, dtype, eps=eps + ) quant_min = -128 quant_max = 127 - scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_symmetric(input, quant_min, quant_max, eps, dtype) + scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_symmetric( + input, quant_min, quant_max, eps, dtype + ) scale_ref = scale_ref.squeeze() zp_ref = zp_ref.squeeze() self.assertTrue(torch.equal(scale, scale_ref)) self.assertTrue(torch.equal(zero_point, zp_ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_quantize_activation_per_token_abs_max(self): input = torch.randn(10, 10) quantized_ref, scale_ref = quantize_activation_per_token_absmax(input) @@ -270,21 +332,35 @@ def test_quantize_activation_per_token_abs_max(self): eps = 1e-5 quant_min = -127 quant_max = 127 - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, quant_min, quant_max, eps=eps, scale_dtype=torch.float) + scale, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + quant_min, + quant_max, + eps=eps, + scale_dtype=torch.float, + ) - quantized = quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max) + quantized = quantize_affine( + input, block_size, scale, zero_point, dtype, quant_min, quant_max + ) self.assertTrue(torch.equal(quantized, quantized_ref)) self.assertTrue(torch.equal(scale, scale_ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_quantize_activation_per_token_abs_max_zero_input(self): input = torch.zeros(10, 10) # make sure it still works quantized_ref, scale_ref = quantize_activation_per_token_absmax(input) - - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_quantize_activation_per_token_abs_max_dtype(self): input = torch.zeros(10, 10, dtype=torch.bfloat16) quantized_ref, scale_ref = quantize_activation_per_token_absmax(input) @@ -298,18 +374,30 @@ def test_quantize_activation_per_token_abs_max_dtype(self): quantized_ref, scale_ref = quantize_activation_per_token_absmax(input) self.assertTrue(scale_ref.dtype, torch.float32) - - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_group_sym(self): input = torch.randn(10, 10) mapping_type = MappingType.SYMMETRIC dtype = torch.int8 block_size = (1, 2) - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) + scale, zero_point = choose_qparams_affine( + input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps + ) quantized = quantize_affine(input, block_size, scale, zero_point, dtype) - dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32) + dequantized = check_idempotent( + self, + dequantize_affine, + quantized, + block_size, + scale, + zero_point, + dtype, + output_dtype=torch.float32, + ) group_size = 2 quant_min = -128 @@ -318,23 +406,43 @@ def test_quantize_dequantize_group_sym(self): input, scale, zero_point, quant_min, quant_max, torch.int8, group_size ) dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_channel_group( - quantized_ref, scale, zero_point, quant_min, quant_max, torch.int8, group_size, output_dtype=torch.float32 + quantized_ref, + scale, + zero_point, + quant_min, + quant_max, + torch.int8, + group_size, + output_dtype=torch.float32, ) self.assertTrue(torch.equal(quantized, quantized_ref)) self.assertTrue(torch.equal(dequantized, dequantized_ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_channel_asym(self): input = torch.randn(10, 10) mapping_type = MappingType.ASYMMETRIC dtype = torch.int8 block_size = (10, 1) - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) + scale, zero_point = choose_qparams_affine( + input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps + ) output_dtype = torch.float32 quantized = quantize_affine(input, block_size, scale, zero_point, dtype) - dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=output_dtype) + dequantized = check_idempotent( + self, + dequantize_affine, + quantized, + block_size, + scale, + zero_point, + dtype, + output_dtype=output_dtype, + ) axis = 1 quant_min = -128 @@ -343,12 +451,21 @@ def test_quantize_dequantize_channel_asym(self): input, scale, zero_point, axis, quant_min, quant_max, torch.int8 ) dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_channel( - quantized_ref, scale, zero_point, axis, quant_min, quant_max, torch.int8, out_dtype=output_dtype + quantized_ref, + scale, + zero_point, + axis, + quant_min, + quant_max, + torch.int8, + out_dtype=output_dtype, ) self.assertTrue(torch.equal(quantized, quantized_ref)) self.assertTrue(torch.equal(dequantized, dequantized_ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_tensor_asym(self): input = torch.randn(10, 10) @@ -356,32 +473,61 @@ def test_quantize_dequantize_tensor_asym(self): dtype = torch.int8 block_size = (10, 10) output_dtype = torch.float32 - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) + scale, zero_point = choose_qparams_affine( + input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps + ) quantized = quantize_affine(input, block_size, scale, zero_point, dtype) - dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=output_dtype) + dequantized = check_idempotent( + self, + dequantize_affine, + quantized, + block_size, + scale, + zero_point, + dtype, + output_dtype=output_dtype, + ) - axis = 1 quant_min = -128 quant_max = 127 quantized_ref = torch.ops.quantized_decomposed.quantize_per_tensor( input, scale, zero_point, quant_min, quant_max, torch.int8 ) dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_tensor( - quantized_ref, scale, zero_point, quant_min, quant_max, torch.int8, out_dtype=output_dtype + quantized_ref, + scale, + zero_point, + quant_min, + quant_max, + torch.int8, + out_dtype=output_dtype, ) self.assertTrue(torch.equal(quantized, quantized_ref)) self.assertTrue(torch.equal(dequantized, dequantized_ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_channel_asym_4d(self): input = torch.randn(3, 3, 10, 10) mapping_type = MappingType.ASYMMETRIC dtype = torch.int8 block_size = (3, 3, 1, 10) - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) + scale, zero_point = choose_qparams_affine( + input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps + ) quantized = quantize_affine(input, block_size, scale, zero_point, dtype) - dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32) + dequantized = check_idempotent( + self, + dequantize_affine, + quantized, + block_size, + scale, + zero_point, + dtype, + output_dtype=torch.float32, + ) axis = 2 quant_min = -128 @@ -390,20 +536,40 @@ def test_quantize_dequantize_channel_asym_4d(self): input, scale, zero_point, axis, quant_min, quant_max, torch.int8 ) dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_channel( - quantized_ref, scale, zero_point, axis, quant_min, quant_max, torch.int8, out_dtype=torch.float32 + quantized_ref, + scale, + zero_point, + axis, + quant_min, + quant_max, + torch.int8, + out_dtype=torch.float32, ) self.assertTrue(torch.equal(quantized, quantized_ref)) self.assertTrue(torch.equal(dequantized, dequantized_ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower" + ) def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self): input = torch.randn(3, 3, 10, 10) mapping_type = MappingType.ASYMMETRIC dtype = torch.int8 block_size = (3, 3, 2, 2) - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) + scale, zero_point = choose_qparams_affine( + input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps + ) quantized = quantize_affine(input, block_size, scale, zero_point, dtype) - dequantized = check_idempotent(self, dequantize_affine, quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32) + dequantized = check_idempotent( + self, + dequantize_affine, + quantized, + block_size, + scale, + zero_point, + dtype, + output_dtype=torch.float32, + ) # we don't have corresponding ops in existing primitives, so just make sure it runs and it's close to float torch.testing.assert_close(dequantized, input, rtol=2, atol=0.02) @@ -412,11 +578,15 @@ def test_choose_qparams_tensor_asym_eps(self): mapping_type = MappingType.ASYMMETRIC dtype = torch.int8 block_size = (10, 10) - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype) + scale, zero_point = choose_qparams_affine( + input, mapping_type, block_size, dtype + ) eps = torch.finfo(torch.float32).eps self.assertEqual(scale, eps) - @unittest.skipIf(not torch.cuda.is_available(), "skipping when cuda is not available") + @unittest.skipIf( + not torch.cuda.is_available(), "skipping when cuda is not available" + ) def test_get_group_qparams_symmetric_memory(self): """Check the memory usage of the op""" weight = torch.randn(1024, 1024).to(device="cuda") @@ -428,18 +598,20 @@ def test_get_group_qparams_symmetric_memory(self): self.assertTrue(after_choose_qparams_mem_use < 1.2 * original_mem_use) def test_raises(self): - """Make sure some errors are raised when user requested an unsupported type of quantization - """ + """Make sure some errors are raised when user requested an unsupported type of quantization""" input = torch.randn(10, 10) mapping_type = MappingType.ASYMMETRIC dtype = torch.int8 block_size = (10, 10) - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype) - + scale, zero_point = choose_qparams_affine( + input, mapping_type, block_size, dtype + ) # make sure we can't quantize int32 tensors: with self.assertRaisesRegex(AssertionError, "Unsupported input dtype:"): - _ = quantize_affine(input.to(torch.int32), block_size, scale, zero_point, dtype) + _ = quantize_affine( + input.to(torch.int32), block_size, scale, zero_point, dtype + ) # block_size and scale/zero_point shape mismatch block_size = (1, 1) @@ -458,7 +630,10 @@ def test_not_preserve_zero_not_supported(self): eps = 1e-6 scale_dtype = torch.bfloat16 zero_point_dtype = torch.bfloat16 - with self.assertRaisesRegex(ValueError, "preserve_zero == False is not supported for symmetric quantization"): + with self.assertRaisesRegex( + ValueError, + "preserve_zero == False is not supported for symmetric quantization", + ): choose_qparams_affine( input, mapping_type, @@ -472,11 +647,12 @@ def test_not_preserve_zero_not_supported(self): preserve_zero=False, ) - def test_get_groupwise_affine_qparams(self): input = torch.randn(10, 256) n_bit = 4 - scale_ref, zero_point_ref = _get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16) + scale_ref, zero_point_ref = _get_groupwise_affine_qparams( + input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16 + ) mapping_type = MappingType.ASYMMETRIC dtype = torch.int8 @@ -486,20 +662,19 @@ def test_get_groupwise_affine_qparams(self): eps = 1e-6 scale_dtype = torch.bfloat16 zero_point_dtype = torch.bfloat16 - scale, zero_point = \ - choose_qparams_affine( - input, - mapping_type, - block_size, - dtype, - quant_min, - quant_max, - eps, - scale_dtype=scale_dtype, - zero_point_dtype=zero_point_dtype, - preserve_zero=False, - zero_point_domain=ZeroPointDomain.FLOAT, - ) + scale, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + quant_min, + quant_max, + eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + preserve_zero=False, + zero_point_domain=ZeroPointDomain.FLOAT, + ) self.assertTrue(torch.equal(scale, scale_ref)) self.assertTrue(torch.equal(zero_point, zero_point_ref)) @@ -511,8 +686,12 @@ def test_groupwise_affine_quantize_tensor_from_qparams(self): n_bit = 4 groupsize = 128 - w_int4x8 = groupwise_affine_quantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize) - w_int4x8_ref = _groupwise_affine_quantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize) + w_int4x8 = groupwise_affine_quantize_tensor_from_qparams( + input, scales, zeros, n_bit, groupsize + ) + w_int4x8_ref = _groupwise_affine_quantize_tensor_from_qparams( + input, scales, zeros, n_bit, groupsize + ) self.assertTrue(torch.equal(w_int4x8, w_int4x8_ref)) @@ -524,15 +703,25 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self): groupsize = 128 if TORCH_VERSION_AT_LEAST_2_5: - input_uint8 = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8) - w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input_uint8, scales, zeros, n_bit, groupsize) + input_tmp = input + if not (is_device(input.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6): + input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8) + w_bf16 = groupwise_affine_dequantize_tensor_from_qparams( + input_tmp, scales, zeros, n_bit, groupsize + ) else: - w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize) - w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize) + w_bf16 = groupwise_affine_dequantize_tensor_from_qparams( + input, scales, zeros, n_bit, groupsize + ) + w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams( + input, scales, zeros, n_bit, groupsize + ) self.assertTrue(torch.equal(w_bf16, w_bf16_ref)) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_fake_quantize_affine(self): input = torch.randn(10, 10) @@ -544,14 +733,31 @@ def test_fake_quantize_affine(self): eps = 1e-5 quant_min = -127 quant_max = 127 - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, quant_min, quant_max, eps=eps, scale_dtype=torch.float) + scale, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + quant_min, + quant_max, + eps=eps, + scale_dtype=torch.float, + ) - quantized = quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max) - dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, quant_min, quant_max) - fake_quantized = fake_quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max) + quantized = quantize_affine( + input, block_size, scale, zero_point, dtype, quant_min, quant_max + ) + dequantized = dequantize_affine( + quantized, block_size, scale, zero_point, dtype, quant_min, quant_max + ) + fake_quantized = fake_quantize_affine( + input, block_size, scale, zero_point, dtype, quant_min, quant_max + ) torch.testing.assert_close(dequantized, fake_quantized) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) def test_fake_quantize_affine_cachemask(self): input = torch.randn(10, 10) @@ -563,16 +769,36 @@ def test_fake_quantize_affine_cachemask(self): eps = 1e-5 quant_min = -127 quant_max = 127 - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, quant_min, quant_max, eps=eps, scale_dtype=torch.float) + scale, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + quant_min, + quant_max, + eps=eps, + scale_dtype=torch.float, + ) - quantized = quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max) - dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, quant_min, quant_max) + quantized = quantize_affine( + input, block_size, scale, zero_point, dtype, quant_min, quant_max + ) + dequantized = dequantize_affine( + quantized, block_size, scale, zero_point, dtype, quant_min, quant_max + ) (fake_quantized, mask) = fake_quantize_affine_cachemask( - input, block_size, scale, zero_point, dtype, quant_min, quant_max, + input, + block_size, + scale, + zero_point, + dtype, + quant_min, + quant_max, ) expected_mask = torch.full(input.shape, True) torch.testing.assert_close(dequantized, fake_quantized) torch.testing.assert_close(expected_mask, mask) + if __name__ == "__main__": unittest.main() diff --git a/test/sparsity/test_fast_sparse_training.py b/test/sparsity/test_fast_sparse_training.py index f2d5686fd3..e3f5626d49 100644 --- a/test/sparsity/test_fast_sparse_training.py +++ b/test/sparsity/test_fast_sparse_training.py @@ -1,19 +1,18 @@ -import logging -import unittest import copy +import unittest import torch -import torch.nn.functional as F from torch import nn from torch.testing._internal.common_utils import TestCase from torchao.sparsity.training import ( + SemiSparseLinear, swap_linear_with_semi_sparse_linear, swap_semi_sparse_linear_with_linear, - SemiSparseLinear ) from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_fbcode + class ToyModel(nn.Module): def __init__(self): super().__init__() @@ -26,8 +25,8 @@ def forward(self, x): x = self.linear2(x) return x -class TestRuntimeSemiStructuredSparsity(TestCase): +class TestRuntimeSemiStructuredSparsity(TestCase): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "pytorch 2.4+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(is_fbcode(), "broken in fbcode") @@ -35,6 +34,7 @@ class TestRuntimeSemiStructuredSparsity(TestCase): def test_runtime_weight_sparsification(self): # need this import inside to not break 2.2 tests from torch.sparse import SparseSemiStructuredTensorCUSPARSELT + input = torch.rand((128, 128)).half().cuda() grad = torch.rand((128, 128)).half().cuda() model = ToyModel().half().cuda() @@ -42,7 +42,9 @@ def test_runtime_weight_sparsification(self): for name, mod in model.named_modules(): if isinstance(mod, torch.nn.Linear): - sparse = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(mod.weight.detach()).to_dense() + sparse = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort( + mod.weight.detach() + ).to_dense() mod.weight = nn.Parameter(sparse) dense_result = model(input) @@ -62,8 +64,12 @@ def test_runtime_weight_sparsification(self): sparse_result.backward(grad) # check grad - assert torch.allclose(model.linear1.weight.grad, model_c.linear1.weight.grad, rtol=1e-1, atol=1e-1) - assert torch.allclose(model.linear2.weight.grad, model_c.linear2.weight.grad, rtol=1e-1, atol=1e-1) + assert torch.allclose( + model.linear1.weight.grad, model_c.linear1.weight.grad, rtol=1e-1, atol=1e-1 + ) + assert torch.allclose( + model.linear2.weight.grad, model_c.linear2.weight.grad, rtol=1e-1, atol=1e-1 + ) # check that swap back works swap_semi_sparse_linear_with_linear(model_c) @@ -77,6 +83,7 @@ def test_runtime_weight_sparsification(self): def test_runtime_weight_sparsification_compile(self): # need this import inside to not break 2.2 tests from torch.sparse import SparseSemiStructuredTensorCUSPARSELT + input = torch.rand((128, 128)).half().cuda() grad = torch.rand((128, 128)).half().cuda() model = ToyModel().half().cuda() @@ -84,7 +91,9 @@ def test_runtime_weight_sparsification_compile(self): for name, mod in model.named_modules(): if isinstance(mod, torch.nn.Linear): - sparse = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(mod.weight.detach()).to_dense() + sparse = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort( + mod.weight.detach() + ).to_dense() mod.weight = nn.Parameter(sparse) model = torch.compile(model, fullgraph=True) @@ -106,8 +115,12 @@ def test_runtime_weight_sparsification_compile(self): sparse_result.backward(grad) # check grad - assert torch.allclose(model.linear1.weight.grad, model_c.linear1.weight.grad, rtol=1e-1, atol=1e-1) - assert torch.allclose(model.linear2.weight.grad, model_c.linear2.weight.grad, rtol=1e-1, atol=1e-1) + assert torch.allclose( + model.linear1.weight.grad, model_c.linear1.weight.grad, rtol=1e-1, atol=1e-1 + ) + assert torch.allclose( + model.linear2.weight.grad, model_c.linear2.weight.grad, rtol=1e-1, atol=1e-1 + ) # check that swap back works swap_semi_sparse_linear_with_linear(model_c) diff --git a/test/sparsity/test_marlin.py b/test/sparsity/test_marlin.py index 173afd7dab..4da7304a24 100644 --- a/test/sparsity/test_marlin.py +++ b/test/sparsity/test_marlin.py @@ -1,28 +1,24 @@ -import torch import copy -import pytest +import pytest +import torch from torch import nn from torch.testing._internal.common_utils import TestCase, run_tests -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 + from torchao.dtypes import MarlinSparseLayout -from torchao.sparsity.sparse_api import apply_fake_sparsity from torchao.quantization.quant_api import int4_weight_only, quantize_ -from torchao.sparsity.marlin import ( - pack_to_marlin_24, - unpack_from_marlin_24, - inject_24 -) from torchao.quantization.quant_primitives import ( + MappingType, + ZeroPointDomain, choose_qparams_affine, quantize_affine, - ZeroPointDomain, - MappingType, ) +from torchao.sparsity.marlin import inject_24, pack_to_marlin_24, unpack_from_marlin_24 +from torchao.sparsity.sparse_api import apply_fake_sparsity +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 class SparseMarlin24(TestCase): - def setUp(self): super().setUp() torch.manual_seed(0) @@ -53,7 +49,9 @@ def test_quant_sparse_marlin_layout_eager(self): quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout())) sparse_result = self.model(self.input) - assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close" + assert torch.allclose( + dense_result, sparse_result, atol=3e-1 + ), "Results are not close" @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+") @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") @@ -71,7 +69,9 @@ def test_quant_sparse_marlin_layout_compile(self): self.model.forward = torch.compile(self.model.forward, fullgraph=True) sparse_result = self.model(self.input) - assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close" + assert torch.allclose( + dense_result, sparse_result, atol=3e-1 + ), "Results are not close" @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") def test_pack_unpack_equivalence(self): @@ -94,9 +94,30 @@ def test_pack_unpack_equivalence(self): # Inject 2:4 sparsity mask w_24, _ = inject_24(w, *w.shape) - # Quantize weights - scales, zeros = choose_qparams_affine(w_24, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain) - w_q_24 = quantize_affine(w_24, block_size, scales, zeros, target_dtype, quant_min, quant_max, zero_point_domain) + # Quantize weights + scales, zeros = choose_qparams_affine( + w_24, + mapping_type, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + scale_dtype, + zero_point_dtype, + preserve_zero, + zero_point_domain, + ) + w_q_24 = quantize_affine( + w_24, + block_size, + scales, + zeros, + target_dtype, + quant_min, + quant_max, + zero_point_domain, + ) scales = scales.reshape(-1, w_q_24.shape[1]) # Test pack/unpack equivalence @@ -107,8 +128,12 @@ def test_pack_unpack_equivalence(self): q_w_comp, packed_scales, meta, shape, group_size, num_bits ) - assert torch.equal(w_q_24, unpacked_q_w), "Unpacked weights do not match original weights" - assert torch.equal(scales, unpacked_scales), "Unpacked scales do not match original scales" + assert torch.equal( + w_q_24, unpacked_q_w + ), "Unpacked weights do not match original weights" + assert torch.equal( + scales, unpacked_scales + ), "Unpacked scales do not match original scales" if __name__ == "__main__": diff --git a/test/sparsity/test_wanda.py b/test/sparsity/test_wanda.py index fcb94036aa..e02ea9822a 100644 --- a/test/sparsity/test_wanda.py +++ b/test/sparsity/test_wanda.py @@ -3,12 +3,13 @@ import torch from torch import nn -from torchao.sparsity import WandaSparsifier from torch.ao.pruning import FakeSparsity from torch.nn.utils.parametrize import is_parametrized from torch.testing._internal.common_pruning import SimpleLinear from torch.testing._internal.common_utils import TestCase +from torchao.sparsity import WandaSparsifier + logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO ) @@ -29,7 +30,9 @@ def test_prepare(self): assert hasattr(module.parametrizations["weight"][0], "mask") # Check parametrization exists and is correct assert is_parametrized(module, "weight") - assert type(module.parametrizations.weight[0]) == FakeSparsity + assert isinstance( + module.parametrizations.weight[0], FakeSparsity + ), "FakeSparsity not found" # check activation observer is present assert hasattr(module, "activation_post_process") @@ -110,5 +113,6 @@ def test_two_layer_mlp_unstructured(self): sparsifier.squash_mask() + if __name__ == "__main__": unittest.main() diff --git a/torchao/_executorch_ops.py b/torchao/_executorch_ops.py index 6a1a66ab77..3cf94ee53d 100644 --- a/torchao/_executorch_ops.py +++ b/torchao/_executorch_ops.py @@ -10,9 +10,14 @@ def _quantized_decomposed_quantize_per_channel_group_wrapper(*args, **kwargs): in PyTorch 2.3+ and recently changed signatures. """ from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 + if TORCH_VERSION_AT_LEAST_2_3: - return torch.ops.quantized_decomposed.quantize_per_channel_group(*args, **kwargs) - raise ImportError("Need torch.ops.quantized_decomposed.quantize_per_channel_group, which is only available with PyTorch 2.3 or later.") + return torch.ops.quantized_decomposed.quantize_per_channel_group( + *args, **kwargs + ) + raise ImportError( + "Need torch.ops.quantized_decomposed.quantize_per_channel_group, which is only available with PyTorch 2.3 or later." + ) def _quantized_decomposed_choose_qparams_per_token_asymmetric_wrapper(*args, **kwargs): @@ -24,9 +29,14 @@ def _quantized_decomposed_choose_qparams_per_token_asymmetric_wrapper(*args, **k in PyTorch 2.3+ and recently changed signatures. """ from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 + if TORCH_VERSION_AT_LEAST_2_3: - return torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(*args, **kwargs) - raise ImportError("Need torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric, which is only available with PyTorch 2.3 or later.") + return torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric( + *args, **kwargs + ) + raise ImportError( + "Need torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric, which is only available with PyTorch 2.3 or later." + ) def _quantized_decomposed_dequantize_per_channel_group_wrapper(*args, **kwargs): @@ -38,9 +48,14 @@ def _quantized_decomposed_dequantize_per_channel_group_wrapper(*args, **kwargs): in PyTorch 2.3+ and recently changed signatures. """ from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 + if TORCH_VERSION_AT_LEAST_2_3: - return torch.ops.quantized_decomposed.dequantize_per_channel_group(*args, **kwargs) - raise ImportError("Need torch.ops.quantized_decomposed.dequantize_per_channel_group, which is only available with PyTorch 2.3 or later.") + return torch.ops.quantized_decomposed.dequantize_per_channel_group( + *args, **kwargs + ) + raise ImportError( + "Need torch.ops.quantized_decomposed.dequantize_per_channel_group, which is only available with PyTorch 2.3 or later." + ) def _quantized_decomposed_quantize_per_token_wrapper(*args, **kwargs): @@ -52,9 +67,12 @@ def _quantized_decomposed_quantize_per_token_wrapper(*args, **kwargs): in PyTorch 2.3+ and recently changed signatures. """ from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 + if TORCH_VERSION_AT_LEAST_2_3: return torch.ops.quantized_decomposed.quantize_per_token(*args, **kwargs) - raise ImportError("Need torch.ops.quantized_decomposed.quantize_per_token, which is only available with PyTorch 2.3 or later.") + raise ImportError( + "Need torch.ops.quantized_decomposed.quantize_per_token, which is only available with PyTorch 2.3 or later." + ) def _quantized_decomposed_dequantize_per_token_wrapper(*args, **kwargs): @@ -66,6 +84,9 @@ def _quantized_decomposed_dequantize_per_token_wrapper(*args, **kwargs): in PyTorch 2.3+ and recently changed signatures. """ from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 + if TORCH_VERSION_AT_LEAST_2_3: return torch.ops.quantized_decomposed.dequantize_per_token(*args, **kwargs) - raise ImportError("Need torch.ops.quantized_decomposed.dequantize_per_token, which is only available with PyTorch 2.3 or later.") + raise ImportError( + "Need torch.ops.quantized_decomposed.dequantize_per_token, which is only available with PyTorch 2.3 or later." + ) diff --git a/torchao/_models/llama/benchmarks.sh b/torchao/_models/llama/benchmarks.sh index 63733c736d..c8cd4bf39c 100644 --- a/torchao/_models/llama/benchmarks.sh +++ b/torchao/_models/llama/benchmarks.sh @@ -52,7 +52,7 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --wr python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --sparsity semi-structured --precision float16 --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt @@ -62,7 +62,7 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --wr python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16 -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --sparsity semi-structured --precision float16 --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt @@ -79,3 +79,20 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 1 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 32 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 128 + +# TTFT benchmarks +export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt --prefill_size 8000 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq --write_result benchmark_results.txt --prefill_size 8000 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8wo --write_result benchmark_results.txt --prefill_size 8000 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq --sparsity semi-structured --write_result benchmark_results.txt --prefill_size 8000 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization float8dq --write_result benchmark_results.txt --prefill_size 8000 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization float8wo --write_result benchmark_results.txt --prefill_size 8000 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int4wo-64 --write_result benchmark_results.txt --prefill_size 8000 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization sparse-marlin --write_result benchmark_results.txt --prefill_size 8000 --precision float16 --sparsity semi-structured + +# 2:4 sparse model +export MODEL_REPO=nm-testing/SparseLlama-3-8B-pruned_50.2of4 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --sparsity semi-structured --precision float16 --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --sparsity semi-structured --precision float16 --write_result benchmark_results.txt diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 862f5d186d..8ec6acccc9 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -3,41 +3,124 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import json import os +import platform import sys import time +from datetime import datetime from pathlib import Path from typing import Optional, Tuple -from datetime import datetime + import torch -import torchao import torch._dynamo.config import torch._inductor.config -from torchao.utils import get_model_size_in_bytes + +import torchao from torchao.quantization.quant_primitives import MappingType -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import get_model_size_in_bytes, TORCH_VERSION_AT_LEAST_2_5 + +torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False + + +class HostEvent: + def __init__(self): + self.event_time = None + + def record(self): + self.event_time = time.perf_counter() + + def elapsed_time(self, other_event): + if self.event_time is None: + raise ValueError("Event not recorded!") + # return ms to match cuda event + return abs(other_event.event_time - self.event_time) * 1000 + + +def get_arch_name() -> str: + if torch.cuda.is_available(): + return torch.cuda.get_device_name() + else: + # This returns x86_64 or arm64 (for aarch64) + return platform.machine() + + +def device_timer(device): + if "cuda" in device: + return torch.cuda.Event(enable_timing=True) + elif ("cpu" in device) or ("mps" in device): + return HostEvent() + else: + print(f"device={device} is not yet suppported") + def device_sync(device): if "cuda" in device: torch.cuda.synchronize(device) + elif "xpu" in device: + torch.xpu.synchronize(device) elif ("cpu" in device) or ("mps" in device): pass else: print(f"device={device} is not yet suppported") -default_device = 'cuda' if torch.cuda.is_available() else 'cpu' + +def write_json_result(output_json_path, headers, row): + """ + Write the result into JSON format, so that it can be uploaded to the benchmark database + to be displayed on OSS dashboard. The JSON format is defined at + https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database + """ + mapping_headers = {headers[i]: v for i, v in enumerate(row)} + record = { + "benchmark": { + "name": "TorchAO benchmark", + "mode": "inference", + "dtype": mapping_headers["dtype"], + "extra_info": { + "device": mapping_headers["device"], + "arch": mapping_headers["arch"], + }, + }, + "model": { + "name": mapping_headers["name"], + "type": "model", + "origins": ["pytorch"], + }, + "metric": { + "name": mapping_headers["metric"], + "benchmark_values": [mapping_headers["actual"]], + "target_value": mapping_headers["target"], + }, + } + + with open(f"{os.path.splitext(output_json_path)[0]}.json", "a") as f: + print(json.dumps(record), file=f) + + +default_device = ( + "cuda" + if torch.cuda.is_available() + else "xpu" + if torch.xpu.is_available() + else "cpu" +) # support running without installing as a package wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) -from torchao._models.llama.model import Transformer, prepare_inputs_for_model +from torchao._models.llama.model import prepare_inputs_for_model, Transformer from torchao._models.llama.tokenizer import get_tokenizer -def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization + +def multinomial_sample_one_no_sync( + probs_sort, +): # Does multinomial sampling without a cuda synchronization q = torch.empty_like(probs_sort).exponential_(1) return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): logits = logits / max(temperature, 1e-5) @@ -48,23 +131,38 @@ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = Non probs = torch.nn.functional.softmax(logits, dim=-1) return probs + def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): probs = logits_to_probs(logits[:, -1], temperature, top_k) idx_next = multinomial_sample_one_no_sync(probs) return idx_next, probs -def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor: + +def prefill( + model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs +) -> torch.Tensor: # input_pos: [B, S] logits = model(x, input_pos) return sample(logits, **sampling_kwargs)[0] -def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]: + +def decode_one_token( + model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs +) -> Tuple[torch.Tensor, torch.Tensor]: # input_pos: [B, 1] assert input_pos.shape[-1] == 1 logits = model(x, input_pos) return sample(logits, **sampling_kwargs) -def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs): + +def decode_n_tokens( + model: Transformer, + cur_token: torch.Tensor, + input_pos: torch.Tensor, + num_new_tokens: int, + callback=lambda _: _, + **sampling_kwargs, +): new_tokens, new_probs = [], [] for i in range(num_new_tokens): with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): @@ -84,6 +182,7 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc def model_forward(model, x, input_pos): return model(x, input_pos) + @torch.no_grad() def generate( model: Transformer, @@ -92,11 +191,15 @@ def generate( batch_size: int, *, interactive: bool, - callback = lambda x: x, + callback=lambda x: x, kv_cache_quantization: bool = False, cache_size: Optional[int] = None, - linear_causal_mask: bool=False, - **sampling_kwargs + linear_causal_mask: bool = False, + prefill_start_event: Optional[torch.cuda.Event] = None, + prefill_end_event: Optional[torch.cuda.Event] = None, + decode_start_event: Optional[torch.cuda.Event] = None, + decode_end_event: Optional[torch.cuda.Event] = None, + **sampling_kwargs, ) -> torch.Tensor: """ Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. @@ -107,12 +210,14 @@ def generate( T = prompt.size(-1) # calculate how many tokens to generate based on max_new_tokens and model's upper bound (block_size) - max_seq_length = min(T + max_new_tokens, model.config.block_size) if not interactive else 350 + max_seq_length = ( + min(T + max_new_tokens, model.config.block_size) if not interactive else 350 + ) new_tokens = max_seq_length - T # format model input prompt, input_pos = prepare_inputs_for_model(prompt) - prompt = prompt.repeat(batch_size, 1) # expand prompt based on batchsize + prompt = prompt.repeat(batch_size, 1) # expand prompt based on batchsize # full prompt+output will be stored in seq seq = torch.empty(batch_size, max_seq_length, dtype=prompt.dtype, device=device) @@ -122,25 +227,53 @@ def generate( with torch.device(device): if cache_size is None: cache_size = max_seq_length - assert cache_size >= max_seq_length, "need cache_size to be greater than max_new_tokens + size-of-prompt" - model.setup_caches(max_batch_size=batch_size, max_seq_length=cache_size, kv_cache_quantization=kv_cache_quantization, linear_causal_mask=linear_causal_mask, prompt_length=T) + assert ( + cache_size >= max_seq_length + ), "need cache_size to be greater than max_new_tokens + size-of-prompt" + model.setup_caches( + max_batch_size=batch_size, + max_seq_length=cache_size, + kv_cache_quantization=kv_cache_quantization, + linear_causal_mask=linear_causal_mask, + prompt_length=T, + ) # execute prefill - next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs).clone() + if prefill_start_event is not None: + prefill_start_event.record() + next_token = prefill( + model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs + ).clone() seq[:, T] = next_token.squeeze() + if prefill_end_event is not None: + prefill_end_event.record() + # execute token generation + if decode_start_event is not None: + decode_start_event.record() input_pos = torch.tensor([T], device=device, dtype=torch.int) - generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs) - seq = torch.cat((seq[:, :T+1], *generated_tokens), dim=-1) + generated_tokens, _ = decode_n_tokens( + model, + next_token.view(batch_size, -1), + input_pos, + new_tokens - 1, + callback=callback, + **sampling_kwargs, + ) + seq = torch.cat((seq[:, : T + 1], *generated_tokens), dim=-1) + if decode_end_event is not None: + decode_end_event.record() return seq + def encode_tokens(tokenizer, string, bos=True, device=default_device): tokens = tokenizer.encode(string) if bos: tokens = [tokenizer.bos_id()] + tokens return torch.tensor(tokens, dtype=torch.int, device=device) + def _load_model(checkpoint_path, device, precision): checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) if "model" in checkpoint and "stories" in str(checkpoint_path): @@ -152,9 +285,12 @@ def _load_model(checkpoint_path, device, precision): return model.eval() + B_INST, E_INST = "[INST]", "[/INST]" + def main( + prefill_size: Optional[int] = None, prompt: str = "Hello, my name is", interactive: bool = False, num_samples: int = 5, @@ -162,11 +298,14 @@ def main( batch_size: int = 1, top_k: int = 200, temperature: float = 0.8, - checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), + checkpoint_path: Path = Path( + "checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth" + ), quantization: Optional[str] = None, + sparsity: Optional[str] = None, kv_cache_quantization: bool = False, cache_size: Optional[int] = None, - linear_causal_mask: bool=False, + linear_causal_mask: bool = False, save: bool = False, compile: bool = True, compile_prefill: bool = False, @@ -175,9 +314,13 @@ def main( device=default_device, precision=torch.bfloat16, write_result: Optional[Path] = None, + output_json_path: Optional[Path] = None, ) -> None: - """Generates text samples based on a pre-trained Transformer model and tokenizer. - """ + """Generates text samples based on a pre-trained Transformer model and tokenizer.""" + + if prefill_size is not None and prefill_size > 0: + # create prompt of prefill size + prompt = "prompt " * (int(prefill_size) - 3) torchao.quantization.utils.recommended_inductor_config_setter() @@ -192,8 +335,7 @@ def main( t0 = time.time() model = _load_model(checkpoint_path, device, precision) - - device_sync(device=device) # MKG + device_sync(device=device) # MKG print(f"Time to load model: {time.time() - t0:.02f} seconds") tokenizer = get_tokenizer(tokenizer_path, checkpoint_path) @@ -203,43 +345,71 @@ def main( torch.manual_seed(1234) + def ffn_only(mod, fqn): + return isinstance(mod, torch.nn.Linear) and "feed_forward" in fqn + + def not_ffn_only(mod, fqn): + return isinstance(mod, torch.nn.Linear) and not ffn_only(mod, fqn) + + def ffn_or_attn_only(mod, fqn): + return isinstance(mod, torch.nn.Linear) and ( + "feed_forward" in fqn or "attention" in fqn + ) if quantization: from torchao.quantization import ( - quantize_, autoquant, - int8_weight_only, - int8_dynamic_activation_int8_weight, + float8_dynamic_activation_float8_weight, + float8_weight_only, + fpx_weight_only, int4_weight_only, int8_dynamic_activation_int4_weight, - fpx_weight_only, + int8_dynamic_activation_int8_weight, + int8_weight_only, + quantize_, uintx_weight_only, - float8_weight_only, - float8_dynamic_activation_float8_weight, ) - from torchao.prototype.quantization.autoquant_v2 import autoquant_v2 - from torchao.utils import unwrap_tensor_subclass - from torchao.quantization.granularity import PerTensor, PerRow + from torchao.quantization.granularity import PerRow, PerTensor from torchao.utils import unwrap_tensor_subclass + if "spinquant" in quantization: from torchao.prototype.spinquant import apply_spinquant + apply_spinquant(model) if "int8wo" in quantization: quantize_(model, int8_weight_only()) - elif "int8dq" in quantization: - quantize_(model, int8_dynamic_activation_int8_weight()) - elif "int4wo" in quantization: + if "int8dq" in quantization: + if sparsity and "semi" in sparsity: + from torchao.dtypes import SemiSparseLayout + + quantize_( + model, + int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), + filter_fn=ffn_only, + ) + quantize_( + model, int8_dynamic_activation_int8_weight(), filter_fn=not_ffn_only + ) + else: + quantize_(model, int8_dynamic_activation_int8_weight()) + if "int4wo" in quantization: if "hqq" in quantization: - use_hqq=True + use_hqq = True else: - use_hqq=False - group_size=int(quantization.split("-")[1]) - assert group_size in [32,64,128,256], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}" + use_hqq = False + group_size = int(quantization.split("-")[1]) + assert group_size in [ + 32, + 64, + 128, + 256, + ], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}" quantize_(model, int4_weight_only(group_size=group_size)) if "marlin" in quantization: if "qqq" in quantization: from torchao.dtypes import MarlinQQQLayout + quantize_( model, int8_dynamic_activation_int4_weight( @@ -249,27 +419,43 @@ def main( layout=MarlinQQQLayout(), ), ) - else: + elif "semi" in sparsity: from torchao.dtypes import MarlinSparseLayout - quantize_(model, int4_weight_only(layout=MarlinSparseLayout())) + + quantize_( + model, + int4_weight_only(layout=MarlinSparseLayout()), + filter_fn=ffn_or_attn_only, + ) if "fp6" in quantization: quantize_(model, fpx_weight_only(3, 2)) elif "embed-int8wo" in quantization: - quantize_(model, int8_weight_only(group_size=64), filter_fn=lambda x, *args: isinstance(x, torch.nn.Embedding)) + quantize_( + model, + int8_weight_only(group_size=64), + filter_fn=lambda x, *args: isinstance(x, torch.nn.Embedding), + ) elif quantization.startswith("awq"): from torchao._models._eval import TransformerEvalWrapper from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - from torchao.prototype.awq.example import get_calib_dataset + if not TORCH_VERSION_AT_LEAST_2_3: print("Awq requires torch2.3+") exit() - from torchao.prototype.awq import insert_awq_observer_, awq_uintx, AWQObservedLinear + from torchao.prototype.awq import ( + awq_uintx, + AWQObservedLinear, + insert_awq_observer_, + ) + quant_dtype = quantization.split("-")[1] group_size = int(quantization.split("-")[2]) quant_dtype = getattr(torch, quant_dtype, torch.uint8) - model=model.to(device) + model = model.to(device) # get calibration data - insert_awq_observer_(model, 1, 256, quant_dtype=quant_dtype, group_size=group_size) + insert_awq_observer_( + model, 1, 256, quant_dtype=quant_dtype, group_size=group_size + ) TransformerEvalWrapper( model=model.to(device), tokenizer=tokenizer, @@ -277,12 +463,18 @@ def main( input_prep_func=prepare_inputs_for_model, device=device, ).run_eval( - tasks=['wikitext'], + tasks=["wikitext"], limit=1, ) is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) use_hqq = "hqq" in quantization - quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size, use_hqq=use_hqq), is_observed_linear) + quantize_( + model, + awq_uintx( + quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq + ), + is_observed_linear, + ) elif "uintx" in quantization: # uintx-nbits-group_size, e.g. "uintx-2-64" if "hqq" in quantization: @@ -293,38 +485,87 @@ def main( _quant_args = quantization.split("-") nbits = int(_quant_args[1]) assert nbits >= 1 and nbits <= 8, "nbits must be 1 to 8" - _NBITS_TO_DTYPE = {1: torch.uint1, 2: torch.uint2, 3: torch.uint3, 4: torch.uint4, 5: torch.uint5, 6: torch.uint6, 7: torch.uint7, 8: torch.uint8} + _NBITS_TO_DTYPE = { + 1: torch.uint1, + 2: torch.uint2, + 3: torch.uint3, + 4: torch.uint4, + 5: torch.uint5, + 6: torch.uint6, + 7: torch.uint7, + 8: torch.uint8, + } dtype = _NBITS_TO_DTYPE[nbits] group_size = int(_quant_args[2]) quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq)) + elif "int8_dynamic_activation_intx_weight" in quantization: + from torchao.experimental.quant_api import ( + int8_dynamic_activation_intx_weight, + ) + + assert ( + precision == torch.float32 + ), "int8_dynamic_activation_intx_weight requires fp32 precision" + + # Build kernels in temp location, and load them in torch + # This requires an ARM CPU + from torchao.experimental.temp_build import temp_build_and_load_torchao_ops + + temp_build_and_load_torchao_ops( + cmake_lists_path=os.path.dirname(os.path.realpath(__file__)) + + "/../../experimental" + ) + + # Quantize model + _quant_args = quantization.split("-") + nbit = int(_quant_args[1]) + assert nbit >= 1 and nbit <= 8, "nbits must be 1 to 8" + group_size = int(_quant_args[2]) + has_weight_zeros = bool(_quant_args[3]) + quantize_( + model, + int8_dynamic_activation_intx_weight( + group_size=group_size, + nbit=nbit, + has_weight_zeros=has_weight_zeros, + ), + ) elif "float8wo" in quantization: quantize_(model, float8_weight_only()) elif "float8dq" in quantization: granularity = str(quantization.split("-")[-1]) - if granularity=="tensor": + if granularity == "tensor": granularity = PerTensor() - elif granularity=="row": + elif granularity == "row": granularity = PerRow() else: granularity = PerTensor() - quantize_(model, float8_dynamic_activation_float8_weight(granularity=granularity)) + quantize_( + model, float8_dynamic_activation_float8_weight(granularity=granularity) + ) elif "autoquant_v2" in quantization: from torchao._models._eval import InputRecorder from torchao._models.llama.model import prepare_inputs_for_model + from torchao.prototype.quantization.autoquant_v2 import autoquant_v2 calibration_seq_length = 256 calibration_limit = 1 - inputs = InputRecorder( - tokenizer, - calibration_seq_length, - prepare_inputs_for_model, - False, # pad_calibration_inputs - model.config.vocab_size, - device="cuda" - ).record_inputs( - ["wikitext"], - 1, - ).get_inputs()[0].values[0] + inputs = ( + InputRecorder( + tokenizer, + calibration_seq_length, + prepare_inputs_for_model, + False, # pad_calibration_inputs + model.config.vocab_size, + device="cuda", + ) + .record_inputs( + ["wikitext"], + 1, + ) + .get_inputs()[0] + .values[0] + ) inputs = prepare_inputs_for_model(inputs) with torch.device("cuda"): model.setup_caches( @@ -332,11 +573,54 @@ def main( ) if "autoquant_v2-int4" == quantization: - model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs) + model = autoquant_v2( + model, + manual=True, + qtensor_class_list=torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, + example_input=inputs, + batch_size=calibration_seq_length, + ) elif "autoquant_v2-float8" == quantization: - model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs) + model = autoquant_v2( + model, + manual=True, + qtensor_class_list=torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST, + example_input=inputs, + batch_size=calibration_seq_length, + ) + elif "autoquant_v2-fp" == quantization: + model = autoquant_v2( + model, + manual=True, + qtensor_class_list=torchao.prototype.quantization.autoquant_v2.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, + example_input=inputs, + batch_size=calibration_seq_length, + ) + elif "autoquant_v2-all" == quantization: + all_qtensor_classes = ( + torchao.prototype.quantization.autoquant_v2.DEFAULT_AUTOQUANT_CLASS_LIST + + torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST + + torchao.prototype.quantization.autoquant_v2.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST + ) + if torchao.utils.is_sm_89(): + # this is fp8 related subclasses, should rename + all_qtensor_classes += ( + torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST + ) + model = autoquant_v2( + model, + manual=True, + qtensor_class_list=all_qtensor_classes, + example_input=inputs, + batch_size=calibration_seq_length, + ) else: - model = autoquant_v2(model, manual=True, example_input=inputs) + model = autoquant_v2( + model, + manual=True, + example_input=inputs, + batch_size=calibration_seq_length, + ) print("running generate") generate( @@ -358,17 +642,22 @@ def main( calibration_seq_length = 256 calibration_limit = 1 - inputs = InputRecorder( - tokenizer, - calibration_seq_length, - prepare_inputs_for_model, - False, # pad_calibration_inputs - model.config.vocab_size, - device="cuda" - ).record_inputs( - ["wikitext"], - 1, - ).get_inputs()[0].values[0] + inputs = ( + InputRecorder( + tokenizer, + calibration_seq_length, + prepare_inputs_for_model, + False, # pad_calibration_inputs + model.config.vocab_size, + device="cuda", + ) + .record_inputs( + ["wikitext"], + 1, + ) + .get_inputs()[0] + .values[0] + ) inputs = prepare_inputs_for_model(inputs) with torch.device("cuda"): model.setup_caches( @@ -376,9 +665,50 @@ def main( ) if "autoquant-int4" == quantization: - model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs) + model = autoquant( + model, + manual=True, + qtensor_class_list=torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, + example_input=inputs, + ) elif "autoquant-float8" == quantization: - model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs) + model = autoquant( + model, + manual=True, + qtensor_class_list=torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST, + example_input=inputs, + ) + if "autoquant-fp" == quantization: + model = autoquant( + model, + manual=True, + qtensor_class_list=torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, + example_input=inputs, + ) + if "autoquant-sparse" == quantization: + model = autoquant( + model, + manual=True, + qtensor_class_list = torchao.quantization.DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST, + example_input=inputs, + ) + if "autoquant-all" == quantization: + all_qtensor_classes = ( + torchao.quantization.DEFAULT_AUTOQUANT_CLASS_LIST + + torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST + + torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST + ) + if torchao.utils.is_sm_89(): + # this is fp8 related subclasses, should rename + all_qtensor_classes += ( + torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST + ) + model = autoquant( + model, + manual=True, + qtensor_class_list=all_qtensor_classes, + example_input=inputs, + ) else: model = autoquant(model, manual=True, example_input=inputs) @@ -399,36 +729,61 @@ def main( if not TORCH_VERSION_AT_LEAST_2_5: unwrap_tensor_subclass(model) + # standalone sparsity + elif sparsity: + from torchao.sparsity import semi_sparse_weight, sparsify_ + + if "semi" in sparsity: + # TODO there is a bug here, need to fix + sparsify_(model.to(device), semi_sparse_weight(), filter_fn=ffn_only) + model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9 if save: output_dir = str(checkpoint_path.cwd()) filename = str(checkpoint_path.name).split(".")[0] - torch.save(model.state_dict(), os.path.join(output_dir, filename + f"-{quantization}.pt")) + torch.save( + model.state_dict(), + os.path.join(output_dir, filename + f"-{quantization}.pt"), + ) if compile: print("Compiling Model") global decode_one_token, prefill - decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) + decode_one_token = torch.compile( + decode_one_token, mode="reduce-overhead", fullgraph=True + ) if compile_prefill: prefill = torch.compile(prefill, fullgraph=True, dynamic=True) if memory_profile: - if device != "cuda": - print("Memory profiling only works on CUDA") + if device == "cuda": + torch.cuda.memory._record_memory_history( + True, trace_alloc_max_entries=250000, trace_alloc_record_context=True + ) + elif device == "xpu": + torch.xpu.memory._record_memory_history( + True, trace_alloc_max_entries=250000, trace_alloc_record_context=True + ) else: - torch.cuda.memory._record_memory_history(True,trace_alloc_max_entries=250000, trace_alloc_record_context=True) + print("Memory profiling only works on CUDA or XPU devices") + aggregate_metrics = { - 'tokens_per_sec': [], + "tokens_per_sec": [], + "time": [], + "decode_tokens_per_sec": [], + "prefill_time": [], } start = -1 if compile else 0 for i in range(start, num_samples): - if i==0: + if i == 0: if device == "cuda": - torch.cuda.reset_peak_memory_stats() # MKG - device_sync(device=device) # MKG + torch.cuda.reset_peak_memory_stats() # MKG + elif device == "xpu": + torch.xpu.reset_peak_memory_stats() # MKG + device_sync(device=device) # MKG if i >= 0 and interactive: prompt = input("What is your prompt? ") if is_chat: @@ -437,8 +792,9 @@ def main( if interactive and i >= 0: buffer = [] - period_id = tokenizer.encode('.')[0] + period_id = tokenizer.encode(".")[0] done_generating = False + def callback(x): nonlocal done_generating if done_generating: @@ -447,14 +803,22 @@ def callback(x): if x.item() == tokenizer.eos_id(): done_generating = True if len(buffer) == 4 or done_generating: - print(''.join(buffer), end='', flush=True) + print("".join(buffer), end="", flush=True) buffer.clear() # print(, end='', flush=True) + else: - callback = lambda x : x + callback = lambda x: x t0 = time.perf_counter() + prefill_start_event, prefill_end_event = device_timer(device), device_timer( + device + ) + decode_start_event, decode_end_event = device_timer(device), device_timer( + device + ) import contextlib - if (i != num_samples - 1 or not profile): + + if i != num_samples - 1 or not profile: prof = contextlib.nullcontext() else: torch.profiler._utils._init_for_cuda_graphs() @@ -472,47 +836,79 @@ def callback(x): kv_cache_quantization=kv_cache_quantization, cache_size=cache_size, linear_causal_mask=linear_causal_mask, + prefill_start_event=prefill_start_event, + prefill_end_event=prefill_end_event, + decode_start_event=decode_start_event, + decode_end_event=decode_end_event, ) if i == -1: print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") continue if hasattr(prof, "export_chrome_trace"): prof.export_chrome_trace(f"{profile}.json") - device_sync(device=device) # MKG + device_sync(device=device) # MKG t = time.perf_counter() - t0 - if not interactive: - tok_list = y[0].tolist() - # truncate text after end of string token - tokens = tok_list if not tokenizer.eos_id() in tok_list else tok_list[:tok_list.index(tokenizer.eos_id())] - print(tokenizer.decode(tokens)) + if not interactive and prefill_size is None: + tok_list = y[0].tolist() + # truncate text after end of string token + tokens = ( + tok_list + if tokenizer.eos_id() not in tok_list + else tok_list[: tok_list.index(tokenizer.eos_id())] + ) + print(tokenizer.decode(tokens)) else: print() - tokens_generated = (y.size(-1) - prompt_length) + tokens_generated = y.size(-1) - prompt_length tokens_sec = tokens_generated / t - aggregate_metrics['tokens_per_sec'].append(tokens_sec) - print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec") + aggregate_metrics["tokens_per_sec"].append(tokens_sec) + aggregate_metrics["time"].append(t) + decode_time = decode_start_event.elapsed_time(decode_end_event) / 1000 + decode_tokens_sec = tokens_generated / decode_time + aggregate_metrics["decode_tokens_per_sec"].append(decode_tokens_sec) + prefill_time = prefill_start_event.elapsed_time(prefill_end_event) / 1000 + aggregate_metrics["prefill_time"].append(prefill_time) + print( + f"Sample {i+1} | overall time {t:.04f} s {tokens_sec:.02f} tokens/sec", + f"| prefill time {prefill_time:.04f} s decode {decode_tokens_sec:.02f} tokens/sec", + ) print(f"Bandwidth achieved: {model_size * tokens_sec:.02f} GB/s") - if memory_profile and i==0: - if device != "cuda": - print("Memory profiling only works on CUDA") - else: + if memory_profile and i == 0: + if device == "cuda": snapshot = torch.cuda.memory._snapshot() - with open(f"{memory_profile}.pickle", 'wb') as f: - from pickle import dump - dump(snapshot, f) - print( - f"\nmemory profile {memory_profile}.pickle saved, to convert that to a usable file, use", - "python pytorch/torch/cuda/_memory_viz.py trace_plot -o .html" - ) - break + elif device == "xpu": + snapshot = torch.xpu.memory._snapshot() + else: + print("Memory profiling only works on CUDA or XPU devices") + with open(f"{memory_profile}.pickle", "wb") as f: + from pickle import dump + + dump(snapshot, f) + print( + f"\nmemory profile {memory_profile}.pickle saved, to convert that to a usable file, use", + "python pytorch/torch/cuda/_memory_viz.py trace_plot -o .html", + ) + break print("==========") - tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item() + # ignore first sample for warmup + tokpersec = torch.mean(torch.tensor(aggregate_metrics["tokens_per_sec"])).item() + ttft = torch.mean(torch.tensor(aggregate_metrics["prefill_time"])).item() + decode_tokpersec = torch.mean( + torch.tensor(aggregate_metrics["decode_tokens_per_sec"]) + ).item() bandwidth = model_size * tokpersec - mem = torch.cuda.max_memory_reserved() /1e9 + mem = torch.cuda.max_memory_reserved() / 1e9 + print(f"Average overall tokens/sec: {tokpersec:.2f}") + print(f"Average decode tokens/sec: {decode_tokens_sec:.04f} s") + print(f"Average TTFT: {ttft:.04f} s") + if device == "cuda": + mem = torch.cuda.max_memory_reserved() / 1e9 + elif device == "xpu": + mem = torch.xpu.max_memory_reserved() / 1e9 print(f"Average tokens/sec: {tokpersec:.2f}") if batch_size > 1: print(f"Average tokens/sec including batches {batch_size*tokpersec:.2f}") @@ -520,66 +916,165 @@ def callback(x): print(f"Peak Memory Usage: {mem:.02f} GB") print(f"Model Size: {model_size:.02f} GB") if write_result: - result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB " - result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} " - result_txt += f"repro: python generate.py " + result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, tok/s_decode={decode_tokpersec:6.2f}, ttft={ttft:5.4f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB " + result_txt += f"quant: {quantization}, sparse: {sparsity}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} " + result_txt += "repro: python generate.py " result_txt += f"--quantization {quantization} " if quantization else "" + result_txt += f"--sparsity {sparsity} " if sparsity else "" result_txt += f"--checkpoint_path {checkpoint_path} " result_txt += f"--device {device} " result_txt += f"--precision {precision} " - result_txt += f"--compile " if compile else "" - result_txt += f"--compile_prefill " if compile_prefill else "" + result_txt += "--compile " if compile else "" + result_txt += "--compile_prefill " if compile_prefill else "" + result_txt += f"--prefill_size {prefill_size}" if prefill_size else "" result_txt += f"--profile {profile} " if profile else "" result_txt += f"--profile {memory_profile} " if memory_profile else "" - result_txt += f"--interactive " if interactive else "" + result_txt += "--interactive " if interactive else "" result_txt += f"--num_samples {num_samples} " result_txt += f"--max_new_tokens {max_new_tokens} " result_txt += f"--batch_size {batch_size} " result_txt += f"--top_k {top_k} " result_txt += f"--temperature {temperature} " result_txt += f"--cache_size {cache_size}" if cache_size else "" - result_txt += f"--kv_cache_quantization " if kv_cache_quantization else "" - result_txt += f"--linear_causal_mask " if linear_causal_mask else "" + result_txt += "--kv_cache_quantization " if kv_cache_quantization else "" + result_txt += "--linear_causal_mask " if linear_causal_mask else "" - f=open(write_result, "a") + f = open(write_result, "a") f.write(result_txt) f.close() + headers = ["name", "dtype", "device", "arch", "metric", "actual", "target"] + name = checkpoint_path.parent.name + arch = get_arch_name() + dtype = quantization or str(precision) + memory_result = [name, dtype, device, arch, "mem/s", bandwidth, None] + performance_result = [name, dtype, device, arch, "tok/s", tokpersec, None] + if output_json_path: + write_json_result(output_json_path, headers, memory_result) + write_json_result(output_json_path, headers, performance_result) -if __name__ == '__main__': +if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser(description='Your CLI description.') - - parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.') - parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode') - parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.') - parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.') - parser.add_argument('--batch_size', type=int, default=1, help='Batch size to benchmark with') - parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.') - parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') - parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') - parser.add_argument('-q', '--quantization', type=str, + + parser = argparse.ArgumentParser(description="Your CLI description.") + parser.add_argument( + "--prefill_size", type=int, default=0, help="Whether to run in ttft mode" + ) + parser.add_argument( + "--prompt", type=str, default="Hello, my name is", help="Input prompt." + ) + parser.add_argument( + "--interactive", + action="store_true", + help="Whether to launch in interactive mode", + ) + parser.add_argument("--num_samples", type=int, default=5, help="Number of samples.") + parser.add_argument( + "--max_new_tokens", type=int, default=200, help="Maximum number of new tokens." + ) + parser.add_argument( + "--batch_size", type=int, default=1, help="Batch size to benchmark with" + ) + parser.add_argument("--top_k", type=int, default=200, help="Top-k for sampling.") + parser.add_argument( + "--temperature", type=float, default=0.8, help="Temperature for sampling." + ) + parser.add_argument( + "--checkpoint_path", + type=Path, + default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), + help="Model checkpoint path.", + ) + parser.add_argument( + "-q", + "--quantization", + type=str, help=( - 'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-, int4wo--hqq, autoquant, ' - +'autoquant-int4, autoquant-float8, uintx--, uintx---hqq, sparse-marlin, spinquant, ' - +'embed-int8wo, marlin_qqq' - ) + "Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-, int4wo--hqq, autoquant, " + + "autoquant-int4, autoquant-float8, uintx--, uintx---hqq, sparse-marlin, spinquant, " + + "embed-int8wo, marlin_qqq" + ), + ) + parser.add_argument( + "-s", + "--sparsity", + type=str, + help=("Which sparsity techniques to apply: semi-structured"), + ) + parser.add_argument( + "--kv_cache_quantization", + action="store_true", + help="Whether to quantize the KV cache", + ) + parser.add_argument( + "--cache_size", + type=int, + default=None, + help="Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size", + ) + parser.add_argument( + "--linear_causal_mask", + action="store_true", + help="Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)", + ) + parser.add_argument( + "--save", action="store_true", help="Whether to save the quantized model." + ) + parser.add_argument( + "--compile", action="store_true", help="Whether to compile the model." + ) + parser.add_argument( + "--compile_prefill", + action="store_true", + help="Whether to compile the prefill (improves prefill perf, but higher compile times)", + ) + parser.add_argument("--profile", type=Path, default=None, help="Profile path.") + parser.add_argument( + "--memory_profile", type=Path, default=None, help="filename for memory profile." + ) + parser.add_argument( + "--device", type=str, default=default_device, help="Device to use" + ) + parser.add_argument( + "--precision", + type=lambda x: getattr(torch, x.split(".")[-1]), + default=torch.bfloat16, + help="dtype precision to use", + ) + parser.add_argument( + "--write_result", type=Path, default=None, help="Path where to write the result" + ) + parser.add_argument( + "--output_json_path", + type=Path, + default=None, + help="Path where to write the json result for dashboard", ) - parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache') - parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size') - parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)') - parser.add_argument('--save', action='store_true', help='Whether to save the quantized model.') - parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') - parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)') - parser.add_argument('--profile', type=Path, default=None, help='Profile path.') - parser.add_argument('--memory_profile', type=Path, default=None, help='filename for memory profile.') - parser.add_argument('--device', type=str, default=default_device, help='Device to use') - parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use') - parser.add_argument('--write_result', type=Path, default=None, help='Path where to write the result') args = parser.parse_args() main( - args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k, - args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result + args.prefill_size, + args.prompt, + args.interactive, + args.num_samples, + args.max_new_tokens, + args.batch_size, + args.top_k, + args.temperature, + args.checkpoint_path, + args.quantization, + args.sparsity, + args.kv_cache_quantization, + args.cache_size, + args.linear_causal_mask, + args.save, + args.compile, + args.compile_prefill, + args.profile, + args.memory_profile, + args.device, + args.precision, + args.write_result, + args.output_json_path, ) diff --git a/torchao/_models/llama/perf_profile.py b/torchao/_models/llama/perf_profile.py index 1a0d4e36c0..f613982221 100644 --- a/torchao/_models/llama/perf_profile.py +++ b/torchao/_models/llama/perf_profile.py @@ -2,9 +2,9 @@ ## Performance Profiling Example -An minimal version of `gpt-fast generate.py` that demonstrates usage of `torchao.profiler.TransformerPerformanceCounter`. +An minimal version of `gpt-fast generate.py` that demonstrates usage of `torchao.prototype.profiler.TransformerPerformanceCounter`. - Outputs from gpt-fast are prefixed with GPT-Fast -- Outputs from `torchao.profiler.TransformerPerformanceCounter` are prefixed with `TransformerPerfCounter`. +- Outputs from `torchao.prototype.profiler.TransformerPerformanceCounter` are prefixed with `TransformerPerfCounter`. ## Usage ```python @@ -118,7 +118,7 @@ from torchao._models.llama.model import Transformer from torchao._models.llama.tokenizer import get_tokenizer -from torchao.profiler import ( +from torchao.prototype.profiler import ( CUDADeviceSpec, TransformerPerformanceCounter, total_model_params, diff --git a/torchao/_models/sam/eval_combo.py b/torchao/_models/sam/eval_combo.py index 9c05d00b26..09a3448d6a 100644 --- a/torchao/_models/sam/eval_combo.py +++ b/torchao/_models/sam/eval_combo.py @@ -350,6 +350,8 @@ def mlp_only(mod, name): autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST) elif "autoquant_v2-float8" == compress: autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST) + elif "autoquant_v2-all" == compress: + autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.prototype.quantization.autoquant_v2.ALL_AUTOQUANT_CLASS_LIST) else: autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True) @@ -362,6 +364,10 @@ def mlp_only(mod, name): autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST) elif "autoquant-float8" == compress: autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST) + elif "autoquant-sparse" == compress: + autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST) + elif "autoquant-all" == compress: + autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.ALL_AUTOQUANT_CLASS_LIST) else: autoquant(predictor.model.image_encoder, example_input=example_input, manual=True) predictor.model.image_encoder(example_input) diff --git a/torchao/_models/sam2/automatic_mask_generator.py b/torchao/_models/sam2/automatic_mask_generator.py index db544a9b61..891a2602ba 100644 --- a/torchao/_models/sam2/automatic_mask_generator.py +++ b/torchao/_models/sam2/automatic_mask_generator.py @@ -36,7 +36,7 @@ ) -class SAM2AutomaticMaskGenerator: +class SAM2AutomaticMaskGenerator(torch.nn.Module): def __init__( self, model: SAM2Base, @@ -105,7 +105,7 @@ def __init__( use_m2m (bool): Whether to add a one step refinement using previous mask predictions. multimask_output (bool): Whether to output multimask at each point of the grid. """ - + super().__init__() assert (points_per_side is None) != ( point_grids is None ), "Exactly one of points_per_side or point_grid must be provided." diff --git a/torchao/_models/sam2/build_sam.py b/torchao/_models/sam2/build_sam.py index 470cbfff99..d6847ede83 100644 --- a/torchao/_models/sam2/build_sam.py +++ b/torchao/_models/sam2/build_sam.py @@ -107,7 +107,7 @@ def build_sam2_video_predictor( **kwargs, ): hydra_overrides = [ - "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", + "++model._target_=torchao._models.sam2.sam2_video_predictor.SAM2VideoPredictor", ] if apply_postprocessing: hydra_overrides_extra = hydra_overrides_extra.copy() diff --git a/torchao/_models/sam2/configs/sam2/sam2_hiera_b+.yaml b/torchao/_models/sam2/configs/sam2/sam2_hiera_b+.yaml index 58f3eb8155..b3ba469471 100644 --- a/torchao/_models/sam2/configs/sam2/sam2_hiera_b+.yaml +++ b/torchao/_models/sam2/configs/sam2/sam2_hiera_b+.yaml @@ -2,18 +2,18 @@ # Model model: - _target_: sam2.modeling.sam2_base.SAM2Base + _target_: torchao._models.sam2.modeling.sam2_base.SAM2Base image_encoder: - _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + _target_: torchao._models.sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: - _target_: sam2.modeling.backbones.hieradet.Hiera + _target_: torchao._models.sam2.modeling.backbones.hieradet.Hiera embed_dim: 112 num_heads: 2 neck: - _target_: sam2.modeling.backbones.image_encoder.FpnNeck + _target_: torchao._models.sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: - _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null @@ -24,17 +24,17 @@ model: fpn_interp_model: nearest memory_attention: - _target_: sam2.modeling.memory_attention.MemoryAttention + _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: - _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: - _target_: sam2.modeling.sam.transformer.RoPEAttention + _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] embedding_dim: 256 @@ -45,7 +45,7 @@ model: pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: - _target_: sam2.modeling.sam.transformer.RoPEAttention + _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] rope_k_repeat: True @@ -57,23 +57,23 @@ model: num_layers: 4 memory_encoder: - _target_: sam2.modeling.memory_encoder.MemoryEncoder + _target_: torchao._models.sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: - _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: - _target_: sam2.modeling.memory_encoder.MaskDownSampler + _target_: torchao._models.sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: - _target_: sam2.modeling.memory_encoder.Fuser + _target_: torchao._models.sam2.modeling.memory_encoder.Fuser layer: - _target_: sam2.modeling.memory_encoder.CXBlock + _target_: torchao._models.sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 diff --git a/torchao/_models/sam2/configs/sam2/sam2_hiera_s.yaml b/torchao/_models/sam2/configs/sam2/sam2_hiera_s.yaml index 26e5d4d39f..b051d3be63 100644 --- a/torchao/_models/sam2/configs/sam2/sam2_hiera_s.yaml +++ b/torchao/_models/sam2/configs/sam2/sam2_hiera_s.yaml @@ -2,21 +2,21 @@ # Model model: - _target_: sam2.modeling.sam2_base.SAM2Base + _target_: torchao._models.sam2.modeling.sam2_base.SAM2Base image_encoder: - _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + _target_: torchao._models.sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: - _target_: sam2.modeling.backbones.hieradet.Hiera + _target_: torchao._models.sam2.modeling.backbones.hieradet.Hiera embed_dim: 96 num_heads: 1 stages: [1, 2, 11, 2] global_att_blocks: [7, 10, 13] window_pos_embed_bkg_spatial_size: [7, 7] neck: - _target_: sam2.modeling.backbones.image_encoder.FpnNeck + _target_: torchao._models.sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: - _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null @@ -27,17 +27,17 @@ model: fpn_interp_model: nearest memory_attention: - _target_: sam2.modeling.memory_attention.MemoryAttention + _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: - _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: - _target_: sam2.modeling.sam.transformer.RoPEAttention + _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] embedding_dim: 256 @@ -48,7 +48,7 @@ model: pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: - _target_: sam2.modeling.sam.transformer.RoPEAttention + _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] rope_k_repeat: True @@ -60,23 +60,23 @@ model: num_layers: 4 memory_encoder: - _target_: sam2.modeling.memory_encoder.MemoryEncoder + _target_: torchao._models.sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: - _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: - _target_: sam2.modeling.memory_encoder.MaskDownSampler + _target_: torchao._models.sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: - _target_: sam2.modeling.memory_encoder.Fuser + _target_: torchao._models.sam2.modeling.memory_encoder.Fuser layer: - _target_: sam2.modeling.memory_encoder.CXBlock + _target_: torchao._models.sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 diff --git a/torchao/_models/sam2/configs/sam2/sam2_hiera_t.yaml b/torchao/_models/sam2/configs/sam2/sam2_hiera_t.yaml index a62c903aaa..6b108e708f 100644 --- a/torchao/_models/sam2/configs/sam2/sam2_hiera_t.yaml +++ b/torchao/_models/sam2/configs/sam2/sam2_hiera_t.yaml @@ -2,21 +2,21 @@ # Model model: - _target_: sam2.modeling.sam2_base.SAM2Base + _target_: torchao._models.sam2.modeling.sam2_base.SAM2Base image_encoder: - _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + _target_: torchao._models.sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: - _target_: sam2.modeling.backbones.hieradet.Hiera + _target_: torchao._models.sam2.modeling.backbones.hieradet.Hiera embed_dim: 96 num_heads: 1 stages: [1, 2, 7, 2] global_att_blocks: [5, 7, 9] window_pos_embed_bkg_spatial_size: [7, 7] neck: - _target_: sam2.modeling.backbones.image_encoder.FpnNeck + _target_: torchao._models.sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: - _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null @@ -27,17 +27,17 @@ model: fpn_interp_model: nearest memory_attention: - _target_: sam2.modeling.memory_attention.MemoryAttention + _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: - _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: - _target_: sam2.modeling.sam.transformer.RoPEAttention + _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] embedding_dim: 256 @@ -48,7 +48,7 @@ model: pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: - _target_: sam2.modeling.sam.transformer.RoPEAttention + _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] rope_k_repeat: True @@ -60,23 +60,23 @@ model: num_layers: 4 memory_encoder: - _target_: sam2.modeling.memory_encoder.MemoryEncoder + _target_: torchao._models.sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: - _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: - _target_: sam2.modeling.memory_encoder.MaskDownSampler + _target_: torchao._models.sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: - _target_: sam2.modeling.memory_encoder.Fuser + _target_: torchao._models.sam2.modeling.memory_encoder.Fuser layer: - _target_: sam2.modeling.memory_encoder.CXBlock + _target_: torchao._models.sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 diff --git a/torchao/_models/sam2/map_tensor.py b/torchao/_models/sam2/map_tensor.py new file mode 100644 index 0000000000..a32424d99b --- /dev/null +++ b/torchao/_models/sam2/map_tensor.py @@ -0,0 +1,617 @@ +import contextlib +import torch +from torch.utils._pytree import tree_map +from typing import Dict +from torch.nested._internal.nested_tensor import nested_view_from_values_offsets +import functools + +MAP_TENSOR_ATEN_OP_TABLE = {} + + +def implements(aten_ops_or_torch_fns): + if not isinstance(aten_ops_or_torch_fns, (list, tuple)): + aten_ops_or_torch_fns = [aten_ops_or_torch_fns] + + def decorator(func): + for op in aten_ops_or_torch_fns: + + @functools.wraps(op) + def wrapper(f, types, args, kwargs): + return func(f, types, args, kwargs) + + MAP_TENSOR_ATEN_OP_TABLE[op] = wrapper + return func + + return decorator + + +@contextlib.contextmanager +def no_dispatch(): + guard = torch._C._DisableTorchDispatch() + try: + yield + finally: + del guard + + +def wrap_dim(i, dim): + if i < 0: + return dim + i + return i + + +def unwrap(t): + if isinstance(t, MapTensor): + with no_dispatch(): + return t.elems + else: + return t + + +def unwrap_i(t, i): + if isinstance(t, MapTensor): + with no_dispatch(): + return t.elems[i] + else: + return t + + +def unwrap_fn(t, fn): + if isinstance(t, MapTensor): + with no_dispatch(): + return fn(t.elems) + else: + return None + + +def wrap(t): + if isinstance(t, torch.Tensor): + return MapTensor(t) + else: + return t + + +@implements(torch.ops.aten.native_layer_norm.default) +def layer_norm_impl(func, types, args, kwargs=None): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 5, f"args: {unwrapped_args}" + norm_res = func(*unwrapped_args) + assert len(norm_res) == 3 + return tuple(wrap(a) for a in norm_res) + + +@implements(torch.ops.aten.add.Tensor) +def add_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 2, f"args: {unwrapped_args}" + if not isinstance(args[0], MapTensor) and isinstance(args[1], MapTensor): + if args[0].dim() == (args[1].dim() + 1): + return NotImplemented + return NotImplemented + return wrap(func(*unwrapped_args, **unwrapped_kwargs)) + + +@implements([torch.ops.aten.cat.default, + torch.ops.aten.stack.default]) +def cat_ops_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) <= 2, f"args: {unwrapped_args}" + # TODO: Use MapTensor type for filter + # First argument's dim + dim = unwrapped_args[0][0].dim() + size = unwrapped_args[0][0].size() + for a in unwrapped_args[0]: + if a.dim() > dim: + dim = a.dim() + size = a.size() + new_args = [] + for a in unwrapped_args[0]: + if a.dim() == dim: + new_args.append(a) + else: + assert a.dim() + 1 == dim + new_args.append(a.unsqueeze(0).expand((size[0],) + a.size())) + orig_dim = unwrapped_args[1] if len(unwrapped_args) == 2 else 0 + return wrap(func(new_args, wrap_dim(orig_dim, dim - 1) + 1)) + + +@implements(torch.ops.aten.select.int) +def select_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 3, f"args: {unwrapped_args}" + return wrap(func(unwrapped_args[0], unwrapped_args[1] + 1, unwrapped_args[2])) + + +@implements(torch.ops.aten.slice.Tensor) +def slice_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 4, f"args: {unwrapped_args}" + dim = unwrapped_args[0].dim() + return wrap(func(unwrapped_args[0], + wrap_dim(unwrapped_args[1], dim - 1) + 1, + unwrapped_args[2], + unwrapped_args[3])) + + +@implements([torch.ops.aten.mean.dim, + torch.ops.aten.max.dim, + torch.ops.aten.argmax.default, + torch.ops.aten.min.dim, + torch.ops.aten.any.dim, + torch.ops.aten.amax.default, + torch.ops.aten.amin.default, + torch.ops.aten.all.default, + torch.ops.aten.sum.dim_IntList]) +def reductions_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + # TODO: THIS MIGHT BE WRONG + if len(unwrapped_args) == 3 and len(unwrapped_kwargs) == 0: + assert len(unwrapped_args[1]) == 1 + dim = unwrapped_args[0].dim() + return wrap(func(unwrapped_args[0], + [wrap_dim(u, dim - 1) + 1 for u in unwrapped_args[1]], + unwrapped_args[2])) + if len(unwrapped_args) == 2 and len(unwrapped_kwargs) == 1: + assert len(unwrapped_args[1]) == 1 + dim = unwrapped_args[0].dim() + return wrap(func(unwrapped_args[0], + [wrap_dim(u, dim - 1) + 1 for u in unwrapped_args[1]], + **unwrapped_kwargs)) + if len(unwrapped_args) == 2 and len(unwrapped_kwargs) == 0 and type(unwrapped_args[1]) == list: + assert len(unwrapped_args[1]) == 1 + dim = unwrapped_args[0].dim() + return wrap(func(unwrapped_args[0], + [wrap_dim(u, dim - 1) + 1 for u in unwrapped_args[1]])) + if len(unwrapped_args) == 2 and len(unwrapped_kwargs) == 0 and type(unwrapped_args[1]) == int: + dim = unwrapped_args[0].dim() + return wrap(func(unwrapped_args[0], wrap_dim(unwrapped_args[1], dim - 1) + 1)) + if len(args) == 1 and len(kwargs) == 0: + return wrap(func(unwrapped_args[0])) + return NotImplemented + + +@implements([torch.ops.aten._unsafe_view.default, + torch.ops.aten.expand.default]) +def view_ops_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 2, f"args: {unwrapped_args}" + input_size = unwrapped_args[0].size() + bigger_size = list(input_size[:1]) + unwrapped_args[1] + return wrap(func(unwrapped_args[0], bigger_size)) + + +@implements(torch.ops.aten.view.default) +def view_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 2, f"args: {unwrapped_args}" + input_size = unwrapped_args[0].size() + bigger_size = list(input_size[:1]) + unwrapped_args[1] + if unwrapped_args[0].size() == tuple(bigger_size): + return wrap(args[0].elems) + return wrap(unwrapped_args[0].reshape(bigger_size)) + + +@implements([torch.ops.aten.mm.default, + torch.ops.aten.bmm.default]) +def mm_ops_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 2, f"args: {unwrapped_args}" + return wrap(torch.matmul(*unwrapped_args)) + + +@implements(torch.ops.aten.unsqueeze.default) +def unsqueeze_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 2, f"args: {unwrapped_args}" + new_i = unwrapped_args[1] + if new_i >= 0: + new_i += 1 + return wrap(func(unwrapped_args[0], new_i)) + + +@implements(torch.ops.aten.addmm.default) +def addmm_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 3, f"args: {unwrapped_args}" + return wrap(torch.matmul(unwrapped_args[1], unwrapped_args[2]) + unwrapped_args[0]) + + +@implements(torch.ops.aten.convolution.default) +def convolution_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 9, f"args: {unwrapped_args}" + a = unwrapped_args[0] + a = unwrapped_args[0].flatten(0, 1) + # TODO: It's scary that this .contiguous seems necessary, but I we're below composite conv + # which might expected contiguous output + resa = func(*((a,) + unwrapped_args[1:])).contiguous() + resb = resa.view((unwrapped_args[0].size(0), unwrapped_args[0].size(1)) + resa.size()[1:]) + return wrap(resb) + + +@implements(torch.ops.aten.upsample_bilinear2d.default) +def upsample_bilinear2d_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 3, f"args: {unwrapped_args}" + a = unwrapped_args[0] + a = unwrapped_args[0].flatten(0, 1) + # NOTE: It's scary that this .contiguous seems necessary, but we're below composite upsample + # which might expected contiguous output + resa = func(*((a,) + unwrapped_args[1:])).contiguous() + resb = resa.view((unwrapped_args[0].size(0), unwrapped_args[0].size(1)) + resa.size()[1:]) + return wrap(resb) + + +@implements(torch.ops.aten.transpose.int) +def transpose_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 3, f"args: {unwrapped_args}" + dim = unwrapped_args[0].dim() + return wrap(func(unwrapped_args[0], + wrap_dim(unwrapped_args[1], dim - 1) + 1, + wrap_dim(unwrapped_args[2], dim - 1) + 1)) + + +@implements(torch.ops.aten.unbind.int) +def unbind_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 2, f"args: {unwrapped_args}" + dim = unwrapped_args[0].dim() + return wrap(func(unwrapped_args[0], + wrap_dim(unwrapped_args[1], dim - 1) + 1)) + + +@implements(torch.ops.aten.permute.default) +def permute_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 2, f"args: {unwrapped_args}" + dim = unwrapped_args[0].dim() + return wrap(func(unwrapped_args[0], + ([0] + [wrap_dim(u, dim - 1) + 1 for u in unwrapped_args[1]]))) + + +@implements(torch.ops.aten._scaled_dot_product_efficient_attention.default) +def _scaled_dot_product_efficient_attention_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(args) == 5 + if all(isinstance(a, MapTensor) for a in args[:3]): + # assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 5, f"args: {unwrapped_args}" + assert unwrapped_args[0].dim() == 5 + assert unwrapped_args[1].dim() == 5 + assert unwrapped_args[2].dim() == 5 + sdpa_res = wrap(func(unwrapped_args[0].flatten(0, 1), + unwrapped_args[1].flatten(0, 1), + unwrapped_args[2].flatten(0, 1), + unwrapped_args[3], + unwrapped_args[4], **unwrapped_kwargs)) + return (wrap(sdpa_res[0].view(unwrapped_args[0].size())),) + sdpa_res[1:] + if isinstance(args[0], MapTensor) and not any(isinstance(a, MapTensor) for a in args[1:]): + # assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 5, f"args: {unwrapped_args}" + assert unwrapped_args[0].dim() == 5 + assert unwrapped_args[1].dim() == 4 + assert unwrapped_args[2].dim() == 4 + a0 = unwrapped_args[0] + a1_size = unwrapped_args[1].size() + a1 = unwrapped_args[1].unsqueeze(0).expand((a0.size(0),) + a1_size) + a2 = unwrapped_args[2].unsqueeze(0).expand((a0.size(0),) + a1_size) + sdpa_res = wrap(func(a0.flatten(0, 1), + a1.flatten(0, 1), + a2.flatten(0, 1), + unwrapped_args[3], + unwrapped_args[4], **unwrapped_kwargs)) + return (wrap(sdpa_res[0].view(unwrapped_args[0].size())),) + sdpa_res[1:] + if ((not isinstance(args[0], MapTensor)) and isinstance(args[1], MapTensor) and (not isinstance(args[2], MapTensor))): + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 5, f"args: {unwrapped_args}" + assert unwrapped_args[0].dim() == 4 + assert unwrapped_args[1].dim() == 5 + assert unwrapped_args[2].dim() == 4 + a1_size = unwrapped_args[1].size() + a0 = unwrapped_args[0].unsqueeze(0).expand((a1_size[0],) + unwrapped_args[0].size()[1:]) + a2 = unwrapped_args[2].unsqueeze(0).expand((a1_size[0],) + unwrapped_args[2].size()[1:]) + sdpa_res = wrap(func(a0.flatten(0, 1), + a1.flatten(0, 1), + a2.flatten(0, 1), + unwrapped_args[3], + unwrapped_args[4])) + return (wrap(sdpa_res[0].view(unwrapped_args[0].size())),) + sdpa_res[1:] + if ((not isinstance(args[0], MapTensor)) and isinstance(args[1], MapTensor) and isinstance(args[2], MapTensor)): + # assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 5, f"args: {unwrapped_args}" + assert unwrapped_args[0].dim() == 4 + assert unwrapped_args[1].dim() == 5 + assert unwrapped_args[2].dim() == 5 + a0_size = unwrapped_args[0].size() + a1_size = unwrapped_args[1].size() + a0 = unwrapped_args[0].unsqueeze(0).expand((a1_size[0],) + a0_size) + a1 = unwrapped_args[1] + a2 = unwrapped_args[2] + sdpa_res = wrap(func(a0.flatten(0, 1), + a1.flatten(0, 1), + a2.flatten(0, 1), + unwrapped_args[3], + unwrapped_args[4], **unwrapped_kwargs)) + return (wrap(sdpa_res[0].view((a1_size[0],) + a0_size)),) + sdpa_res[1:] + return NotImplemented + + +@implements(torch.ops.aten._scaled_dot_product_flash_attention.default) +def _scaled_dot_product_flash_attention_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(args) == 3 + assert len(unwrapped_kwargs) == 1 + assert len(unwrapped_args) == 3, f"args: {unwrapped_args}" + if all(isinstance(a, MapTensor) for a in args[:3]): + assert unwrapped_args[0].dim() == 5 + assert unwrapped_args[1].dim() == 5 + assert unwrapped_args[2].dim() == 5 + sdpa_res = wrap(func(unwrapped_args[0].flatten(0, 1), + unwrapped_args[1].flatten(0, 1), + unwrapped_args[2].flatten(0, 1), + **unwrapped_kwargs)) + return (wrap(sdpa_res[0].view(unwrapped_args[0].size())),) + sdpa_res[1:] + if isinstance(args[0], MapTensor) and not any(isinstance(a, MapTensor) for a in args[1:]): + assert unwrapped_args[0].dim() == 5 + assert unwrapped_args[1].dim() == 4 + assert unwrapped_args[2].dim() == 4 + a0 = unwrapped_args[0] + a1_size = unwrapped_args[1].size() + a1 = unwrapped_args[1].unsqueeze(0).expand((a0.size(0),) + a1_size) + a2 = unwrapped_args[2].unsqueeze(0).expand((a0.size(0),) + a1_size) + sdpa_res = wrap(func(a0.flatten(0, 1), + a1.flatten(0, 1), + a2.flatten(0, 1), + **unwrapped_kwargs)) + return (wrap(sdpa_res[0].view(unwrapped_args[0].size())),) + sdpa_res[1:] + if ((not isinstance(args[0], MapTensor)) and isinstance(args[1], MapTensor) and (not isinstance(args[2], MapTensor))): + assert unwrapped_args[0].dim() == 4 + assert unwrapped_args[1].dim() == 5 + assert unwrapped_args[2].dim() == 4 + a1_size = unwrapped_args[1].size() + a0 = unwrapped_args[0].unsqueeze(0).expand((a1_size[0],) + unwrapped_args[0].size()[1:]) + a2 = unwrapped_args[2].unsqueeze(0).expand((a1_size[0],) + unwrapped_args[2].size()[1:]) + sdpa_res = wrap(func(a0.flatten(0, 1), + a1.flatten(0, 1), + a2.flatten(0, 1), + **unwrapped_kwargs)) + return (wrap(sdpa_res[0].view(unwrapped_args[0].size())),) + sdpa_res[1:] + if ((not isinstance(args[0], MapTensor)) and isinstance(args[1], MapTensor) and isinstance(args[2], MapTensor)): + assert unwrapped_args[0].dim() == 4 + assert unwrapped_args[1].dim() == 5 + assert unwrapped_args[2].dim() == 5 + a0_size = unwrapped_args[0].size() + a1_size = unwrapped_args[1].size() + a0 = unwrapped_args[0].unsqueeze(0).expand((a1_size[0],) + a0_size) + a1 = unwrapped_args[1] + a2 = unwrapped_args[2] + sdpa_res = wrap(func(a0.flatten(0, 1), + a1.flatten(0, 1), + a2.flatten(0, 1), + **unwrapped_kwargs)) + return (wrap(sdpa_res[0].view((a1_size[0],) + a0_size)),) + sdpa_res[1:] + return NotImplemented + + +# torch.ops.aten._unsafe_index.Tensor is only needed by inductor for compile +@implements([torch.ops.aten._unsafe_index.Tensor, + torch.ops.aten.index.Tensor]) +def index_ops_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + assert len(unwrapped_kwargs) == 0 + assert len(unwrapped_args) == 2, f"args: {unwrapped_args}" + # if len(args[1]) == 1 and isinstance(args[1][0], MapTensor) and isinstance(args[0], MapTensor): + # return wrap(func(*unwrapped_args)) + if len(args[1]) == 1 and isinstance(args[1][0], MapTensor) and not isinstance(args[0], MapTensor): + tensors = [func(args[0], [args[1][0].elems[i]]) for i in range(len(args[1][0].elems))] + values = torch.cat(tensors) + lengths = torch.tensor([0] + [t.size(0) for t in tensors], pin_memory=True).to(values.device, non_blocking=True) + offsets = torch.cumsum(lengths, dim=0) + nt = nested_view_from_values_offsets(values, offsets) + assert nt.layout == torch.jagged + return wrap(nt) + if isinstance(args[0], MapTensor) and not isinstance(args[1][0], MapTensor) and len(args[1]) == 1: + return wrap(func(args[0].elems, [args[1][0].unsqueeze(0)])) + if isinstance(args[0], MapTensor) and not isinstance(args[1][0], MapTensor) and isinstance(args[1][1], MapTensor)and len(args[1]) == 2: + res = [] + for a0, a11 in zip(args[0].elems.unbind(), args[1][1].elems.unbind()): + res.append(func(a0, [args[1][0], a11])) + return wrap(torch.stack(res)) + if isinstance(args[0], MapTensor) and isinstance(args[1][0], MapTensor) and len(args[1]) == 1: + tensors = [func(args[0].elems[i], [args[1][0].elems[i]]) for i in range(len(args[0].elems))] + values = torch.cat(tensors) + lengths = torch.tensor([0] + [t.size(0) for t in tensors], pin_memory=True).to(values.device, non_blocking=True) + offsets = torch.cumsum(lengths, dim=0) + nt = nested_view_from_values_offsets(values, offsets) + assert nt.layout == torch.jagged + return wrap(nt) + a = unwrapped_args[0] + a = unwrapped_args[0].flatten(0, 1) + resa = func(a, args[1]) + resb = resa.view((unwrapped_args[0].size(0), unwrapped_args[0].size(1)) + resa.size()[1:]) + return wrap(resb) + + +# Prims +@implements(torch.ops.aten.dim.default) +def dim_impl(func, types, args, kwargs): + assert len(args) == 1 + assert len(kwargs) == 0 + ret_dim = func(args[0].elems) - 1 + assert ret_dim >= 0 + return ret_dim + + +@implements(torch.ops.aten.sym_size.default) +def sym_impl(func, types, args, kwargs): + assert len(args) == 1 + assert len(kwargs) == 0 + elems_size = func(args[0].elems) + assert len(elems_size) > 0 + return elems_size[1:] + + +@implements(torch.ops.aten.is_contiguous.default) +def is_contiguous_impl(func, types, args, kwargs): + assert len(args) == 1 + assert len(kwargs) == 0 + return func(args[0].elems) + + +@implements([torch.ops.aten.clamp.default, + torch.ops.aten.clone.default, + torch.ops.aten.cos.default, + torch.ops.aten.div.Tensor, + torch.ops.aten.eq.Scalar, + torch.ops.aten.gelu.default, + torch.ops.aten.mul.Tensor, + torch.ops.aten.pow.Tensor_Scalar, + torch.ops.aten.relu.default, + torch.ops.aten.sigmoid.default, + torch.ops.aten.sin.default, + torch.ops.aten.sqrt.default, + torch.ops.aten.sub.Tensor, + torch.ops.aten.unbind.int, + torch.ops.aten.where.self, + torch.ops.aten.zeros_like.default, + torch.ops.aten._to_copy.default, + torch.ops.aten.gt.Scalar, + torch.ops.aten.ge.Scalar, + torch.ops.aten.bitwise_not.default, + torch.ops.aten.lt.Tensor, + torch.ops.aten.bitwise_or.Tensor, + torch.ops.aten.eq.Tensor, + torch.ops.aten.abs.default, + torch.ops.aten.ne.Scalar, + torch.ops.aten.le.Tensor, + torch.ops.aten.view_as_complex.default, + torch.ops.aten.view_as_real.default, + torch.ops.aten.neg.default, + torch.ops.aten.le.Scalar, + torch.ops.aten.rsub.Scalar, + # Sketchy new in place ops + torch.ops.aten.bitwise_and_.Tensor, + torch.ops.aten.bitwise_or_.Tensor, + torch.ops.aten.le.Tensor, + torch.ops.aten.logical_and.default, + # in place ops + torch.ops.aten.add_.Tensor, + torch.ops.aten.copy_.default, + # Prims + torch.ops.prim.layout.default]) +def forwardables_impl(func, types, args, kwargs): + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + return wrap(func(*unwrapped_args, **unwrapped_kwargs)) + + +def run_invariant_test(res, func, args, kwargs): + # Compares 0th element of list of results with + # func applied to 0th arg and kwarg. + # Rough test to maintain per-op accuracy. + if isinstance(res, torch.Tensor): + unwrapped_args_0 = tree_map(lambda x: unwrap_i(x, 0), args) + unwrapped_kwargs_0 = tree_map(lambda x: unwrap_i(x, 0), kwargs) + if func == torch.ops.aten.view.default: + res_0 = torch.ops.aten.reshape.default(*unwrapped_args_0, **unwrapped_kwargs_0) + else: + res_0 = func(*unwrapped_args_0, **unwrapped_kwargs_0) + if res.elems[0].size() != res_0.size(): + import pdb; pdb.set_trace() + if not torch.allclose(res.elems[0], res_0, atol=1e-3, rtol=1e-3): + import pdb; pdb.set_trace() + else: + pass + # print("res got type: ", type(res)) + # import pdb; pdb.set_trace() + return res + + +class MapTensor(torch.Tensor): + @staticmethod + def __new__(cls, elems): + # print("elems.layout: ", elems.layout) + return torch.Tensor._make_wrapper_subclass(cls, + elems.shape[1:], + dtype=elems.dtype, + device=elems.device, + layout=elems.layout, + dispatch_layout=True, + dispatch_sizes_strides_policy=("sizes" if elems.layout == torch.jagged else None), + storage_size=(elems._values.untyped_storage().size() if elems.layout == torch.jagged else None)) + + def __init__(self, elems): + self.elems = elems + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + if func in MAP_TENSOR_ATEN_OP_TABLE: + res = MAP_TENSOR_ATEN_OP_TABLE[func](func, types, args, kwargs) + # run_invariant_test(res, func, args, kwargs) + return res + return NotImplemented + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + + # flatten/unflatten is needed for compile + def __tensor_flatten__(self): + ctx = {} + inner_tensors = ["elems"] + return inner_tensors, ctx + + @staticmethod + def __tensor_unflatten__(inner_tensors: Dict, meta, outer_size, outer_stride): + from torch._subclasses.fake_tensor import FakeTensor + + # inner tensors: _values, _offsets, [_lengths], [_min_seqlen], [_max_seqlen] + assert len(inner_tensors) == 1, f"{inner_tensors}" + elems = inner_tensors["elems"] + + return MapTensor(elems) + + def __repr__(self): + return f"MapTensor({self.elems.size()})" + +# ts is a higher dim Tensor +def to_map_tensor(ts: torch.Tensor): + return MapTensor(ts) diff --git a/torchao/_models/sam2/modeling/backbones/image_encoder.py b/torchao/_models/sam2/modeling/backbones/image_encoder.py index 37e9266bc9..7225316bc7 100644 --- a/torchao/_models/sam2/modeling/backbones/image_encoder.py +++ b/torchao/_models/sam2/modeling/backbones/image_encoder.py @@ -28,7 +28,15 @@ def __init__( def forward(self, sample: torch.Tensor): # Forward through backbone - features, pos = self.neck(self.trunk(sample)) + with torch.autograd.profiler.record_function("self.neck(self.trunk(sample))"): + from torchao._models.sam2.map_tensor import MapTensor + from torchao._models.sam2.map_tensor import to_map_tensor + if isinstance(sample, MapTensor): + features, pos = self.neck(self.trunk(sample.elems.flatten(0, 1))) + features = [to_map_tensor(t.unsqueeze(1)) for t in features] + pos = [to_map_tensor(t.unsqueeze(1)) for t in pos] + else: + features, pos = self.neck(self.trunk(sample)) if self.scalp > 0: # Discard the lowest resolution features features, pos = features[: -self.scalp], pos[: -self.scalp] diff --git a/torchao/_models/sam2/modeling/position_encoding.py b/torchao/_models/sam2/modeling/position_encoding.py index 5ba359d8d2..f4cd77fd4b 100644 --- a/torchao/_models/sam2/modeling/position_encoding.py +++ b/torchao/_models/sam2/modeling/position_encoding.py @@ -164,18 +164,18 @@ def forward_with_coords( # 3. https://github.com/lucidrains/rotary-embedding-torch -def init_t_xy(end_x: int, end_y: int): - t = torch.arange(end_x * end_y, dtype=torch.float32) +def init_t_xy(end_x: int, end_y: int, device=None): + t = torch.arange(end_x * end_y, dtype=torch.float32, device=device) t_x = (t % end_x).float() t_y = torch.div(t, end_x, rounding_mode="floor").float() return t_x, t_y -def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): - freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) - freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) +def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0, device=None): + freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4, device=device)[: (dim // 4)].float() / dim)) + freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4, device=device)[: (dim // 4)].float() / dim)) - t_x, t_y = init_t_xy(end_x, end_y) + t_x, t_y = init_t_xy(end_x, end_y, device=device) freqs_x = torch.outer(t_x, freqs_x) freqs_y = torch.outer(t_y, freqs_y) freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) diff --git a/torchao/_models/sam2/modeling/sam/transformer.py b/torchao/_models/sam2/modeling/sam/transformer.py index 2e3d85ccd4..5574cb3fa2 100644 --- a/torchao/_models/sam2/modeling/sam/transformer.py +++ b/torchao/_models/sam2/modeling/sam/transformer.py @@ -325,9 +325,10 @@ def forward( # Apply rotary position encoding w = h = math.sqrt(q.shape[-2]) - self.freqs_cis = self.freqs_cis.to(q.device) + # NOTE: Disabling this. + # self.freqs_cis = self.freqs_cis.to(q.device) if self.freqs_cis.shape[0] != q.shape[-2]: - self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device) + self.freqs_cis = self.compute_cis(end_x=w, end_y=h, device=q.device) # .to(q.device) if q.shape[-2] != k.shape[-2]: assert self.rope_k_repeat diff --git a/torchao/_models/sam2/modeling/sam2_base.py b/torchao/_models/sam2/modeling/sam2_base.py index 20874e0581..f467c448a6 100644 --- a/torchao/_models/sam2/modeling/sam2_base.py +++ b/torchao/_models/sam2/modeling/sam2_base.py @@ -628,7 +628,7 @@ def _prepare_memory_conditioned_features( if self.add_tpos_enc_to_obj_ptrs: t_diff_max = max_obj_ptrs_in_encoder - 1 tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim - obj_pos = torch.tensor(pos_list, device=device) + obj_pos = torch.tensor(pos_list).pin_memory().to(device=device, non_blocking=True) obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim) obj_pos = self.obj_ptr_tpos_proj(obj_pos) obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim) @@ -709,8 +709,8 @@ def _encode_new_memory( maskmem_out = self.memory_encoder( pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied ) - maskmem_features = maskmem_out["vision_features"] - maskmem_pos_enc = maskmem_out["vision_pos_enc"] + maskmem_features = maskmem_out["vision_features"].clone() + maskmem_pos_enc = [m.clone() for m in maskmem_out["vision_pos_enc"]] # add a no-object embedding to the spatial memory to indicate that the frame # is predicted to be occluded (i.e. no object is appearing in the frame) if self.no_obj_embed_spatial is not None: @@ -809,6 +809,7 @@ def _encode_memory_in_output( current_out["maskmem_features"] = None current_out["maskmem_pos_enc"] = None + @torch.autograd.profiler.record_function("track_step") def track_step( self, frame_idx, @@ -854,13 +855,13 @@ def track_step( object_score_logits, ) = sam_outputs - current_out["pred_masks"] = low_res_masks - current_out["pred_masks_high_res"] = high_res_masks - current_out["obj_ptr"] = obj_ptr + current_out["pred_masks"] = low_res_masks.clone() + current_out["pred_masks_high_res"] = high_res_masks.clone() + current_out["obj_ptr"] = obj_ptr.clone() if not self.training: # Only add this in inference (to avoid unused param in activation checkpointing; # it's mainly used in the demo to encode spatial memories w/ consolidated masks) - current_out["object_score_logits"] = object_score_logits + current_out["object_score_logits"] = object_score_logits.clone() # Finally run the memory encoder on the predicted mask to encode # it into a new memory feature (that can be used in future frames) @@ -870,7 +871,7 @@ def track_step( point_inputs, run_mem_encoder, high_res_masks, - object_score_logits, + object_score_logits.clone(), current_out, ) diff --git a/torchao/_models/sam2/sam2_image_predictor.py b/torchao/_models/sam2/sam2_image_predictor.py index 8fe01995ee..f404fe00e4 100644 --- a/torchao/_models/sam2/sam2_image_predictor.py +++ b/torchao/_models/sam2/sam2_image_predictor.py @@ -17,7 +17,7 @@ from torchao._models.sam2.utils.transforms import SAM2Transforms -class SAM2ImagePredictor: +class SAM2ImagePredictor(torch.nn.Module): def __init__( self, sam_model: SAM2Base, diff --git a/torchao/_models/sam2/sam2_video_predictor.py b/torchao/_models/sam2/sam2_video_predictor.py index c7e01ccf97..46ab610556 100644 --- a/torchao/_models/sam2/sam2_video_predictor.py +++ b/torchao/_models/sam2/sam2_video_predictor.py @@ -11,8 +11,8 @@ from tqdm import tqdm -from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base -from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames +from torchao._models.sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base +from torchao._models.sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames class SAM2VideoPredictor(SAM2Base): @@ -40,7 +40,21 @@ def __init__( self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond - @torch.inference_mode() + @staticmethod + def batch_inference_states(inference_states: list): + assert all(dict == type(state) for state in inference_states) + num_states = len(inference_states) + assert num_states > 0 + import copy + batched_inference_state = copy.copy(inference_states[0]) + + from torchao._models.sam2.map_tensor import to_map_tensor + # NOTE: Making a build assumption only images differ + all_images = torch.stack([state["images"] for state in inference_states]) + batched_inference_state["images"] = to_map_tensor(all_images) + return batched_inference_state + + @torch.no_grad() def init_state( self, video_path, @@ -169,7 +183,7 @@ def _get_obj_num(self, inference_state): """Get the total number of unique object ids received so far in this session.""" return len(inference_state["obj_idx_to_id"]) - @torch.inference_mode() + @torch.no_grad() def add_new_points_or_box( self, inference_state, @@ -317,7 +331,7 @@ def add_new_points(self, *args, **kwargs): """Deprecated method. Please use `add_new_points_or_box` instead.""" return self.add_new_points_or_box(*args, **kwargs) - @torch.inference_mode() + @torch.no_grad() def add_new_mask( self, inference_state, @@ -589,7 +603,7 @@ def _get_empty_mask_ptr(self, inference_state, frame_idx): ) return current_out["obj_ptr"] - @torch.inference_mode() + @torch.no_grad() def propagate_in_video_preflight(self, inference_state): """Prepare inference_state and consolidate temporary outputs before tracking.""" # Tracking has started and we don't allow adding new objects until session is reset. @@ -659,7 +673,7 @@ def propagate_in_video_preflight(self, inference_state): input_frames_inds.update(mask_inputs_per_frame.keys()) assert all_consolidated_frame_inds == input_frames_inds - @torch.inference_mode() + @torch.no_grad() def propagate_in_video( self, inference_state, @@ -773,7 +787,7 @@ def _add_output_per_object( obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc] obj_output_dict[storage_key][frame_idx] = obj_out - @torch.inference_mode() + @torch.no_grad() def clear_all_prompts_in_frame( self, inference_state, frame_idx, obj_id, need_output=True ): @@ -844,7 +858,7 @@ def clear_all_prompts_in_frame( ) return frame_idx, obj_ids, video_res_masks - @torch.inference_mode() + @torch.no_grad() def reset_state(self, inference_state): """Remove all input points or mask in all frames throughout the video.""" self._reset_tracking_results(inference_state) @@ -909,6 +923,7 @@ def _get_image_feature(self, inference_state, frame_idx, batch_size): features = (expanded_image,) + features return features + @torch.autograd.profiler.record_function("_run_single_frame_inference") def _run_single_frame_inference( self, inference_state, @@ -1038,7 +1053,7 @@ def _get_maskmem_pos_enc(self, inference_state, current_out): expanded_maskmem_pos_enc = None return expanded_maskmem_pos_enc - @torch.inference_mode() + @torch.no_grad() def remove_object(self, inference_state, obj_id, strict=False, need_output=True): """ Remove an object id from the tracking state. If strict is True, we check whether diff --git a/torchao/dtypes/README.md b/torchao/dtypes/README.md new file mode 100644 index 0000000000..c1124c648f --- /dev/null +++ b/torchao/dtypes/README.md @@ -0,0 +1,19 @@ +# README + +## File Structure of the `dtypes` Folder + +The `dtypes` folder contains several important files and subfolders that are organized as follows: + +- **affine_quantized_tensor.py**: This is the main file, from which the subfolders `uintx` and `floatx` inherit. It contains the base tensor subclass `AffineQuantizedTensor` and code for layout and tensorImpl registration. + +- **affine_quantized_tensor_ops.py**: This file defines all the overriden aten ops and different dispatch kernels related to affine quantized tensors. + +- **utils.py**: A utility file that provides helper functions and common utilities used across different files in the `dtypes` folder. + +- **nf4tensor.py**: This file is specific to the NF4 tensor implementation, and layouts. + +### Subfolders + +- **uintx**: A subfolder that contains layouts and tensor subclasses inheriting from `affine_quantized_tensor.py`. It is specialized for handling unsigned integer quantized tensors. + +- **floatx**: Similar to `uintx`, this subfolder contains layouts and tensor subclasses that inherit from `affine_quantized_tensor.py`, but it is focused on floating-point quantized tensors. diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index d1fbacdcb4..c7d98cb56e 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -1,14 +1,12 @@ from . import affine_quantized_tensor_ops from .affine_quantized_tensor import ( AffineQuantizedTensor, - MarlinQQQTensor, to_affine_quantized_floatx, to_affine_quantized_floatx_static, # experimental, will be merged into floatx in the future to_affine_quantized_fpx, to_affine_quantized_intx, to_affine_quantized_intx_static, - to_marlinqqq_quantized_intx, ) from .floatx import ( Float8Layout, @@ -16,11 +14,14 @@ from .nf4tensor import NF4Tensor, to_nf4 from .uintx import ( BlockSparseLayout, + Int4CPULayout, MarlinQQQLayout, + MarlinQQQTensor, MarlinSparseLayout, SemiSparseLayout, TensorCoreTiledLayout, UintxLayout, + to_marlinqqq_quantized_intx, ) from .utils import ( Layout, @@ -48,4 +49,5 @@ "UintxLayout", "MarlinQQQTensor", "MarlinQQQLayout", + "Int4CPULayout", ] diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 93d2766d1e..7aca25ecc5 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -16,10 +16,8 @@ choose_qparams_affine, choose_qparams_affine_floatx, choose_qparams_and_quantize_affine_hqq, - choose_qparams_and_quantize_affine_qqq, dequantize_affine, dequantize_affine_floatx, - dequantize_affine_qqq, quantize_affine, quantize_affine_floatx, ) @@ -33,14 +31,12 @@ __all__ = [ "AffineQuantizedTensor", - "MarlinQQQTensor", "register_layout", "to_affine_quantized_intx", "to_affine_quantized_floatx", "to_affine_quantized_intx_static", "to_affine_quantized_floatx_static", "to_affine_quantized_fpx", - "to_marlinqqq_quantized_intx", ] @@ -459,57 +455,6 @@ def _apply_fn_to_data(self, fn): # 2 - we're given non-floats - quantizing long to int8 is crazy -class MarlinQQQTensor(AffineQuantizedTensor): - """ - MarlinQQQ quantized tensor subclass which inherits AffineQuantizedTensor class. - - To see what happens during choose_qparams_and_quantize_affine_qqq, quantization and dequantization for marlin qqq quantization, - please checkout https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py - and check the two quant primitive ops: choose_qparams_and_quantize_affine_qqq and dequantize_affine_qqq - """ - - def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: - if output_dtype is None: - output_dtype = self.dtype - - int_data, s_group, s_channel = self.tensor_impl.get_plain() - nbits = int(math.log2(self.quant_max - self.quant_min + 1)) - group_size = max(self.block_size) - return dequantize_affine_qqq( - int_data, s_group, s_channel, nbits, group_size, output_dtype - ) - - @classmethod - def from_hp_to_intx( - cls, - input_float: torch.Tensor, - block_size: Tuple[int, ...], - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, - zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, - _layout: Optional[Layout] = None, - ): - original_shape = input_float.shape - input_float = _layout.pre_process(input_float) - nbits = int(math.log2(quant_max - quant_min + 1)) - group_size = max(block_size) - data, s_group, s_channel, _ = choose_qparams_and_quantize_affine_qqq( - input_float, nbits, group_size - ) - data = _layout.post_process(data) - tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) - tensor_impl = tensor_impl_ctr(data, s_group, s_channel, _layout) - return cls( - tensor_impl, - block_size, - original_shape, - quant_min, - quant_max, - zero_point_domain, - dtype=input_float.dtype, - ) - - ###################################################### # Layout and TensorImpl Subclass Registration # ###################################################### @@ -522,7 +467,6 @@ def from_hp_to_intx( to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static # experimental will be merged in to floatx to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx -to_marlinqqq_quantized_intx = MarlinQQQTensor.from_hp_to_intx if TORCH_VERSION_AT_LEAST_2_5: # Allow a model with AffineQuantizedTensor weights to be loaded with `weights_only=True` diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index bd7ff7d333..8938e7472c 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -20,7 +20,7 @@ _linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl, ) -from torchao.dtypes.uintx.marlin_qqq_layout import ( +from torchao.dtypes.uintx.marlin_qqq_tensor import ( _linear_int8_act_int4_weight_marlin_qqq_check, _linear_int8_act_int4_weight_marlin_qqq_impl, ) diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index a6059f93a3..4b1f3d39c8 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -1,8 +1,10 @@ from .block_sparse_layout import ( BlockSparseLayout, ) -from .marlin_qqq_layout import ( +from .marlin_qqq_tensor import ( MarlinQQQLayout, + MarlinQQQTensor, + to_marlinqqq_quantized_intx, ) from .marlin_sparse_layout import ( MarlinSparseLayout, @@ -11,6 +13,7 @@ SemiSparseLayout, ) from .tensor_core_tiled_layout import ( + Int4CPULayout, TensorCoreTiledLayout, ) from .uintx_layout import ( @@ -23,5 +26,8 @@ "MarlinSparseLayout", "SemiSparseLayout", "TensorCoreTiledLayout", + "Int4CPULayout", "MarlinQQQLayout", + "MarlinQQQTensor", + "to_marlinqqq_quantized_intx", ] diff --git a/torchao/dtypes/uintx/marlin_qqq_layout.py b/torchao/dtypes/uintx/marlin_qqq_tensor.py similarity index 79% rename from torchao/dtypes/uintx/marlin_qqq_layout.py rename to torchao/dtypes/uintx/marlin_qqq_tensor.py index c3b2a78394..b75d959b41 100644 --- a/torchao/dtypes/uintx/marlin_qqq_layout.py +++ b/torchao/dtypes/uintx/marlin_qqq_tensor.py @@ -1,5 +1,7 @@ import logging +import math from dataclasses import dataclass +from typing import Optional, Tuple import torch from torch.utils._python_dispatch import ( @@ -8,18 +10,75 @@ from torchao.dtypes.affine_quantized_tensor import ( AffineQuantizedTensor, + get_tensor_impl_constructor, register_layout, ) from torchao.dtypes.uintx.plain_layout import ( _aqt_is_int8_reduced_range, ) from torchao.dtypes.utils import AQTTensorImpl, Layout +from torchao.quantization.quant_primitives import ( + ZeroPointDomain, + choose_qparams_and_quantize_affine_qqq, + dequantize_affine_qqq, +) logger = logging.getLogger(__name__) aten = torch.ops.aten +class MarlinQQQTensor(AffineQuantizedTensor): + """ + MarlinQQQ quantized tensor subclass which inherits AffineQuantizedTensor class. + + To see what happens during choose_qparams_and_quantize_affine_qqq, quantization and dequantization for marlin qqq quantization, + please checkout https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py + and check the two quant primitive ops: choose_qparams_and_quantize_affine_qqq and dequantize_affine_qqq + """ + + def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: + if output_dtype is None: + output_dtype = self.dtype + + int_data, s_group, s_channel = self.tensor_impl.get_plain() + nbits = int(math.log2(self.quant_max - self.quant_min + 1)) + group_size = max(self.block_size) + return dequantize_affine_qqq( + int_data, s_group, s_channel, nbits, group_size, output_dtype + ) + + @classmethod + def from_hp_to_intx( + cls, + input_float: torch.Tensor, + block_size: Tuple[int, ...], + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + _layout: Optional[Layout] = None, + ): + original_shape = input_float.shape + input_float = _layout.pre_process(input_float) + nbits = int(math.log2(quant_max - quant_min + 1)) + group_size = max(block_size) + data, s_group, s_channel, _ = choose_qparams_and_quantize_affine_qqq( + input_float, nbits, group_size + ) + data = _layout.post_process(data) + tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) + tensor_impl = tensor_impl_ctr(data, s_group, s_channel, _layout) + return cls( + tensor_impl, + block_size, + original_shape, + quant_min, + quant_max, + zero_point_domain, + dtype=input_float.dtype, + ) + + @dataclass(frozen=True) class MarlinQQQLayout(Layout): pass @@ -279,3 +338,6 @@ def _linear_int8_act_int4_weight_marlin_qqq_impl(input_tensor, weight_tensor, bi if bias is not None: out += bias.to(out.dtype) return out + + +to_marlinqqq_quantized_intx = MarlinQQQTensor.from_hp_to_intx diff --git a/torchao/dtypes/uintx/semi_sparse_layout.py b/torchao/dtypes/uintx/semi_sparse_layout.py index e2c94a7a38..a554fd9bc6 100644 --- a/torchao/dtypes/uintx/semi_sparse_layout.py +++ b/torchao/dtypes/uintx/semi_sparse_layout.py @@ -41,13 +41,18 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl( w_vals_int8 = weight_tensor.tensor_impl.int_data w_scales = weight_tensor.tensor_impl.scale tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) + # must pad + row, col = tmp.shape + from torch.sparse import SparseSemiStructuredTensorCUSPARSELT + + tmp_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(tmp) # we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm( w_vals_int8, - tmp.t(), + tmp_padded.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16, - ).t() + ).t()[:row, :] y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape( *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1] ) diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py index ced3fc8922..df79b653e8 100644 --- a/torchao/dtypes/uintx/tensor_core_tiled_layout.py +++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py @@ -13,7 +13,12 @@ ) from torchao.dtypes.utils import AQTTensorImpl, Layout, is_device from torchao.quantization.quant_primitives import ZeroPointDomain, _get_reduction_params -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, fill_defaults, find_multiple +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + TORCH_VERSION_AT_LEAST_2_6, + fill_defaults, + find_multiple, +) aten = torch.ops.aten @@ -71,9 +76,14 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): # groupwise int4 quantization groupsize = weight_tensor.block_size[1] - y = torch.ops.aten._weight_int4pack_mm( - act_mat.contiguous(), packed_weight, groupsize, scale_and_zero - ) + if is_device(input_tensor.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + y = torch.ops.aten._weight_int4pack_mm_for_cpu( + act_mat.contiguous(), packed_weight, groupsize, scale_and_zero + ) + else: + y = torch.ops.aten._weight_int4pack_mm( + act_mat.contiguous(), packed_weight, groupsize, scale_and_zero + ) # remove out_feature padding orig_out_features = weight_tensor.shape[-2] @@ -383,3 +393,251 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: def get_layout(self) -> Layout: return self._layout + + +@dataclass(frozen=True) +class Int4CPULayout(Layout): + """Only for PyTorch version at least 2.6""" + + pass + + +@register_layout(Int4CPULayout) +class Int4CPUAQTTensorImpl(AQTTensorImpl): + """ + TensorImpl for int4 CPU layout for affine quantized tensor, this is for int4 only, + used by tinygemm kernels `_weight_int4pack_mm_for_cpu` + It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 2-d tensor of + dimension: [n][k / 2] (uint8 dtype) + (unpacked Tensor shape is n * k) + Note: we also pack scale and zero point together here for tinygemm kernel + Note: technically Int4 CPU layout should be the layout for the underlying packed weight + (int Tensor) but since the scale and zero_point are also packed into the same tensor here which is not used + in plain layout, we just created a layout for AQT right now, this could be improved if we split out + int4 aqt into a separate tensor subclass + fields: + packed_weight (torch.Tensor): the 2-d packed tensor in a Int4 CPU layout + scale_and_zero (torch.Tensor): the combined scale Tensor used to map between floating point tensor to quantized tensor and zero_point Tensor + """ + + def __new__( + cls, + packed_weight: torch.Tensor, + scale_and_zero: torch.Tensor, + transposed: bool, + _layout: Layout, + ): + kwargs = {} + kwargs["device"] = packed_weight.device + kwargs["layout"] = ( + kwargs.get("layout") + if kwargs.get("layout", False) + else packed_weight.layout + ) + kwargs["dtype"] = packed_weight.dtype + kwargs["requires_grad"] = False + shape = packed_weight.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + packed_weight: torch.Tensor, + scale_and_zero: torch.Tensor, + transposed: bool, + _layout: Layout, + ): + self.packed_weight = packed_weight + self.scale_and_zero = scale_and_zero + self.transposed = False + self._layout = _layout + + def __tensor_flatten__(self): + return ["packed_weight", "scale_and_zero"], [self.transposed, self._layout] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + packed_weight, scale_and_zero = ( + tensor_data_dict["packed_weight"], + tensor_data_dict["scale_and_zero"], + ) + ( + transposed, + _layout, + ) = tensor_attributes + return cls(packed_weight, scale_and_zero, transposed, _layout) + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + _layout: Layout, + ): + assert isinstance(_layout, Int4CPULayout) + + if TORCH_VERSION_AT_LEAST_2_6: + assert ( + int_data.dtype == torch.int32 + ), "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype" + packed_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu( + int_data, + 1, # TODO:remove + ) + elif TORCH_VERSION_AT_LEAST_2_5: + int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) + assert ( + int_data.dtype == torch.uint8 + ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype" + packed_weight = torch.ops.aten._convert_weight_to_int4pack( + int_data, _layout.inner_k_tiles + ) + else: + assert ( + int_data.dtype == torch.int32 + ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" + packed_weight = torch.ops.aten._convert_weight_to_int4pack( + int_data, _layout.inner_k_tiles + ) + + scale = scale.reshape(int_data.shape[0], -1) + zero_point = zero_point.reshape(int_data.shape[0], -1) + from torchao.quantization.utils import pack_tinygemm_scales_and_zeros + + scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point) + return cls(packed_weight, scale_and_zero, False, _layout) + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + device = kwargs["device"] + if not is_device(torch.device(self.device).type, device): + raise ValueError( + f"Int4CPUAQTTensorImpl does not support conversion from {self.device} to {device}" + ) + return self.__class__( + self.packed_weight.to(device), + self.scale_and_zero.to(device), + self.transposed, + self._layout, + ) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.packed_weight), + fn(self.scale_and_zero), + self.transposed, + self._layout, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + if func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + if func is aten.t.default: + """we don't need to repack the weight and just rely on external + shape being changed and record the status of transpose/no-transpose + """ + transposed = Int4CPUAQTTensorImpl( + args[0].packed_weight, + args[0].scale_and_zero, + not args[0].transposed, + args[0]._layout, + ) + return return_and_correct_aliasing(func, args, kwargs, transposed) + + if func is aten.slice.Tensor: + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + if dim == 0: + int_data, scale, zero_point = self.get_plain() + int_data = aten.slice.Tensor(int_data, dim, start, end, step) + # this is to handle padding + int_data = self._layout.post_process(int_data) + sliced = self.from_plain(int_data, scale, zero_point, self._layout) + return return_and_correct_aliasing(func, args, kwargs, sliced) + elif dim == 1: + int_data, scale, zero_point = self.get_plain() + assert step == 1, "Only step == 1 is supported in slicing right now" + data_len = int_data.shape[dim] + scale_len = scale.shape[dim] + ratio = data_len / scale_len + start_scale = int(start / ratio) + end_scale = int(end / ratio) + + int_data = aten.slice.Tensor(int_data, dim, start, end, step) + # this is to handle padding + int_data = self._layout.post_process(int_data) + scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step) + zero_point = aten.slice.Tensor( + zero_point, dim, start_scale, end_scale, step + ) + sliced = self.from_plain(int_data, scale, zero_point, self._layout) + return sliced + else: + raise NotImplementedError( + f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" + ) + + raise NotImplementedError( + f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + __torch_function__ = torch._C._disabled_torch_function_impl + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + from torchao.quantization.quant_primitives import ( + ZeroPointDomain, + quantize_affine, + ) + from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros + + scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) + + cur_shape = self.shape + assert len(cur_shape) == 2 + original_shape = (cur_shape[0], cur_shape[1] * 2) + eye_shape = original_shape[1] + groupsize = int(original_shape[1] / scale.shape[-2]) + block_size = (1, groupsize) + device = self.device + original_dtype = torch.bfloat16 + target_dtype = torch.int32 + quant_min = 0 + quant_max = 15 + zero_point_domain = ZeroPointDomain.FLOAT + assert len(block_size) == 2 and block_size[0] == 1 + dequantized = torch.ops.aten._weight_int4pack_mm_for_cpu( + torch.eye(eye_shape, device=device, dtype=original_dtype), + self.packed_weight, + groupsize, + self.scale_and_zero, + ) + dequantized = dequantized.t().contiguous() + # TODO: move this to `unpack_tinygemm_scales_and_zeros`? + scale = scale.reshape(scale.shape[:-1]).contiguous() + zero = zero.reshape(zero.shape[:-1]).contiguous() + int_data = quantize_affine( + dequantized, + block_size, + scale, + zero, + target_dtype, + quant_min, + quant_max, + zero_point_domain, + ) + return int_data, scale, zero + + def get_layout(self) -> Layout: + return self._layout diff --git a/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py b/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py index eea7e42666..7764c0871f 100644 --- a/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py +++ b/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py @@ -37,9 +37,9 @@ */ #ifdef USE_ATEN -using namespace at::native::mps; +using at::native::mps::MetalShaderLibrary; #else -#include +#include #endif static MetalShaderLibrary metal_lowbit_quantized_lib(R"METAL_LOWBIT( diff --git a/torchao/experimental/kernels/mps/src/MetalShaderLibrary.h b/torchao/experimental/kernels/mps/src/MetalShaderLibrary.h new file mode 100644 index 0000000000..3aca35e699 --- /dev/null +++ b/torchao/experimental/kernels/mps/src/MetalShaderLibrary.h @@ -0,0 +1,64 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include + +class MetalShaderLibrary { + public: + MetalShaderLibrary(const std::string& src) : shaderSource(src) { + lib = compileLibraryFromSource(shaderSource); + } + MetalShaderLibrary(const MetalShaderLibrary&) = delete; + MetalShaderLibrary(MetalShaderLibrary&&) = delete; + + id getPipelineStateForFunc( + const std::string& fname) { + id func = loadFunc(fname); + + NSError* error = nil; + id device = get_metal_device(); + auto cpl = [device newComputePipelineStateWithFunction:func error:&error]; + if (cpl == nil) { + throw std::runtime_error( + "Failed to construct pipeline state: " + + std::string(error.description.UTF8String)); + } + return cpl; + + } + + private: + std::string shaderSource; + id lib = nil; + + id loadFunc(const std::string& func_name) const { + id func = [lib + newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]]; + if (func == nil) { + throw std::runtime_error("Can't get function:" + func_name); + } + return func; + } + + id compileLibraryFromSource( + const std::string& source) { + NSError* error = nil; + MTLCompileOptions* options = [MTLCompileOptions new]; + [options setLanguageVersion:MTLLanguageVersion3_1]; + NSString* kernel_source = [NSString stringWithUTF8String:source.c_str()]; + id device = get_metal_device(); + id library = [device newLibraryWithSource:kernel_source + options:options + error:&error]; + if (library == nil) { + throw std::runtime_error( + "Failed to compile: " + std::string(error.description.UTF8String)); + } + return library; + } +}; diff --git a/torchao/experimental/kernels/mps/src/OperationUtils.h b/torchao/experimental/kernels/mps/src/OperationUtils.h index 7cb902f23f..5a41b264af 100644 --- a/torchao/experimental/kernels/mps/src/OperationUtils.h +++ b/torchao/experimental/kernels/mps/src/OperationUtils.h @@ -6,101 +6,12 @@ #pragma once -#include -#include - -static void throw_exception(const std::string& str) { - std::cerr << str << std::endl; - throw std::runtime_error(str); -} - -inline void dispatch_block( - [[maybe_unused]] id queue, - void (^block)()) { - __block std::optional block_exception; - try { - block(); - } catch (...) { - block_exception = std::current_exception(); - } - if (block_exception) { - std::rethrow_exception(*block_exception); - } -} - -inline id getMetalDevice() { - @autoreleasepool { - NSArray* devices = [MTLCopyAllDevices() autorelease]; - if (devices.count == 0) { - throw_exception("Metal is not supported"); - } - return devices[0]; - } -} - -static id MTL_DEVICE = getMetalDevice(); - -static id compileLibraryFromSource( - id device, - const std::string& source) { - NSError* error = nil; - MTLCompileOptions* options = [MTLCompileOptions new]; - [options setLanguageVersion:MTLLanguageVersion3_1]; - NSString* kernel_source = [NSString stringWithUTF8String:source.c_str()]; - id library = [device newLibraryWithSource:kernel_source - options:options - error:&error]; - if (library == nil) { - throw_exception( - "Failed to compile: " + std::string(error.description.UTF8String)); - } - return library; -} - -class MetalShaderLibrary { - public: - MetalShaderLibrary(const std::string& src) : shaderSource(src) { - lib = compileLibraryFromSource(device, shaderSource); - } - MetalShaderLibrary(const MetalShaderLibrary&) = delete; - MetalShaderLibrary(MetalShaderLibrary&&) = delete; - - id getPipelineStateForFunc( - const std::string& fname) { - return get_compute_pipeline_state(load_func(fname)); - } - - private: - std::string shaderSource; - id device = MTL_DEVICE; - id lib = nil; - - id load_func(const std::string& func_name) const { - id func = [lib - newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]]; - if (func == nil) { - throw_exception("Can't get function:" + func_name); - } - return func; - } - - id get_compute_pipeline_state( - id func) const { - NSError* error = nil; - auto cpl = [device newComputePipelineStateWithFunction:func error:&error]; - if (cpl == nil) { - throw_exception( - "Failed to construct pipeline state: " + - std::string(error.description.UTF8String)); - } - return cpl; - } -}; +id getMetalDevice(); class MPSStream { public: MPSStream() { - _commandQueue = [MTL_DEVICE newCommandQueue]; + _commandQueue = [getMetalDevice() newCommandQueue]; } ~MPSStream() { @@ -136,14 +47,6 @@ class MPSStream { id _commandEncoder = nil; }; -inline void finalize_block(MPSStream* mpsStream) { - id encoder = mpsStream->commandEncoder(); - id cmdBuffer = mpsStream->commandBuffer(); - [encoder endEncoding]; - [cmdBuffer commit]; - [cmdBuffer waitUntilCompleted]; -} - inline MPSStream* getCurrentMPSStream() { return new MPSStream(); } diff --git a/torchao/experimental/kernels/mps/src/OperationUtils.mm b/torchao/experimental/kernels/mps/src/OperationUtils.mm new file mode 100644 index 0000000000..795c93225a --- /dev/null +++ b/torchao/experimental/kernels/mps/src/OperationUtils.mm @@ -0,0 +1,20 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include + +id getMetalDevice() { + @autoreleasepool { + NSArray* devices = [MTLCopyAllDevices() autorelease]; + if (devices.count == 0) { + throw std::runtime_error("Metal is not supported"); + } + static id MTL_DEVICE = devices[0]; + return MTL_DEVICE; + } +} diff --git a/torchao/experimental/kernels/mps/src/common.h b/torchao/experimental/kernels/mps/src/common.h new file mode 100644 index 0000000000..0710d37b3a --- /dev/null +++ b/torchao/experimental/kernels/mps/src/common.h @@ -0,0 +1,51 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#ifdef USE_ATEN +#include +using namespace at::native::mps; +#elif defined(USE_EXECUTORCH) +#include +using namespace executorch::backends::mps::delegate; +#else +#include +#endif + +inline void dispatch_block( + MPSStream* mpsStream, + void (^block)()) { +#if defined(USE_ATEN) + dispatch_sync_with_rethrow(mpsStream->queue(), block); +#elif defined(USE_EXECUTORCH) + dispatch_sync(mpsStream->queue(), block); +#else + (void)mpsStream; + block(); +#endif +} + +inline void optionally_wait_for_command_completion(MPSStream* mpsStream) { +#if defined(USE_ATEN) +#elif defined(USE_EXECUTORCH) + ET_CHECK(mpsStream->synchronize(SyncType::COMMIT_AND_WAIT) == executorch::runtime::Error::Ok); +#else + id encoder = mpsStream->commandEncoder(); + id cmdBuffer = mpsStream->commandBuffer(); + [encoder endEncoding]; + [cmdBuffer commit]; + [cmdBuffer waitUntilCompleted]; +#endif +} + +inline id get_metal_device() { +#if defined(USE_ATEN) || defined(USE_EXECUTORCH) + return MPSDevice::getInstance()->device(); +#else + return getMetalDevice(); +#endif +} diff --git a/torchao/experimental/kernels/mps/src/lowbit.h b/torchao/experimental/kernels/mps/src/lowbit.h index d37001350a..ae3951e217 100644 --- a/torchao/experimental/kernels/mps/src/lowbit.h +++ b/torchao/experimental/kernels/mps/src/lowbit.h @@ -9,24 +9,11 @@ #include #include +#include #include -#include +#include // metal_lowbit_quantized_lib #include -#include -#include -#include - -#ifdef USE_ATEN -#include -using namespace at::native::mps; -inline void finalize_block(MPSStream* mpsStream) {} -void (*dispatch_block)(dispatch_queue_t, dispatch_block_t) = - dispatch_sync_with_rethrow; -#else -#include -#endif - namespace torchao::kernels::mps::lowbit { namespace { @@ -103,7 +90,7 @@ inline void linear_lowbit_quant_weights_mps_impl( 0}; MPSStream* mpsStream = getCurrentMPSStream(); - dispatch_block(mpsStream->queue(), ^() { + dispatch_block(mpsStream, ^() { @autoreleasepool { id computeEncoder = mpsStream->commandEncoder(); id cpl = @@ -119,7 +106,7 @@ inline void linear_lowbit_quant_weights_mps_impl( length:sizeof(uint32_t) * sizes.size() atIndex:5]; dispatch_fn(computeEncoder, maxThreadsPerGroup, M, N, K); - finalize_block(mpsStream); + optionally_wait_for_command_completion(mpsStream); } }); } diff --git a/torchao/experimental/kernels/mps/test/Makefile b/torchao/experimental/kernels/mps/test/Makefile index e8213818c5..3c0da54f7c 100644 --- a/torchao/experimental/kernels/mps/test/Makefile +++ b/torchao/experimental/kernels/mps/test/Makefile @@ -1,7 +1,7 @@ all: test_lowbit -test_lowbit: test_lowbit.mm - clang++ -I${TORCHAO_ROOT} -O3 -std=c++17 -Wall -Wextra -o $@ $< -framework Metal -framework Foundation +test_lowbit: test_lowbit.mm ../src/OperationUtils.mm + clang++ -I${TORCHAO_ROOT} -O3 -std=c++17 -Wall -Wextra -o $@ $^ -framework Metal -framework Foundation run: test_lowbit ./test_lowbit diff --git a/torchao/experimental/kernels/mps/test/test_lowbit.mm b/torchao/experimental/kernels/mps/test/test_lowbit.mm index 2d86223034..7fb20d254a 100644 --- a/torchao/experimental/kernels/mps/test/test_lowbit.mm +++ b/torchao/experimental/kernels/mps/test/test_lowbit.mm @@ -31,7 +31,7 @@ id rc = [device newBufferWithLength:length options:MTLResourceStorageModeShared]; if (rc == nil) { - throw_exception( + throw std::runtime_error( "Can't allocate " + std::to_string(length) + " bytes on GPU"); } return rc; @@ -80,7 +80,7 @@ void reference_linear_lowbit_quant_weights_cpu( : M(m), K(k), N(n), qGroupSize(group_size) {} void init() { - allocBuffers(MTL_DEVICE); + allocBuffers(getMetalDevice()); T* a_ptr = reinterpret_cast([buf_A contents]); uint8_t* w_ptr = reinterpret_cast([buf_W contents]); diff --git a/torchao/experimental/ops/mps/CMakeLists.txt b/torchao/experimental/ops/mps/CMakeLists.txt index 044433ef95..820205fa27 100644 --- a/torchao/experimental/ops/mps/CMakeLists.txt +++ b/torchao/experimental/ops/mps/CMakeLists.txt @@ -26,10 +26,14 @@ endif() find_package(Torch REQUIRED) # Generate metal_shader_lib.h by running gen_metal_shader_lib.py +set(METAL_SHADERS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/metal) +file(GLOB METAL_FILES ${METAL_SHADERS_DIR}/*.metal) +set(GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/codegen/gen_metal_shader_lib.py) set(GENERATED_METAL_SHADER_LIB ${CMAKE_INSTALL_PREFIX}/include/torchao/experimental/kernels/mps/src/metal_shader_lib.h) add_custom_command( OUTPUT ${GENERATED_METAL_SHADER_LIB} - COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/codegen/gen_metal_shader_lib.py ${GENERATED_METAL_SHADER_LIB} + COMMAND python ${GEN_SCRIPT} ${GENERATED_METAL_SHADER_LIB} + DEPENDS ${METAL_FILES} ${GEN_SCRIPT} COMMENT "Generating metal_shader_lib.h using gen_metal_shader_lib.py" ) add_custom_target(generated_metal_shader_lib ALL DEPENDS ${GENERATED_METAL_SHADER_LIB}) @@ -41,7 +45,7 @@ message(STATUS "TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}") include_directories(${TORCHAO_INCLUDE_DIRS}) include_directories(${CMAKE_INSTALL_PREFIX}/include) -add_library(torchao_ops_mps_linear_fp_act_xbit_weight_aten SHARED aten/register.mm) +add_library(torchao_ops_mps_linear_fp_act_xbit_weight_aten OBJECT linear_fp_act_xbit_weight_aten.mm) add_dependencies(torchao_ops_mps_linear_fp_act_xbit_weight_aten generated_metal_shader_lib) target_include_directories(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE "${TORCH_INCLUDE_DIRS}") @@ -53,8 +57,25 @@ find_library(METAL_LIB Metal) find_library(FOUNDATION_LIB Foundation) target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE ${METAL_LIB} ${FOUNDATION_LIB}) -install( - TARGETS torchao_ops_mps_linear_fp_act_xbit_weight_aten - EXPORT _targets - DESTINATION lib +add_library(torchao_ops_mps_aten SHARED) +target_link_libraries(torchao_ops_mps_aten PRIVATE + torchao_ops_mps_linear_fp_act_xbit_weight_aten ) +install(TARGETS torchao_ops_mps_aten DESTINATION lib) + +if(TORCHAO_BUILD_EXECUTORCH_OPS) + include_directories(${CMAKE_INSTALL_PREFIX}/../..) + include_directories(${CMAKE_INSTALL_PREFIX}/schema/include) + include_directories(${CMAKE_INSTALL_PREFIX}/../third-party/flatbuffers/include) + add_library(torchao_ops_mps_linear_fp_act_xbit_weight_executorch OBJECT linear_fp_act_xbit_weight_executorch.mm) + add_dependencies(torchao_ops_mps_linear_fp_act_xbit_weight_executorch generated_metal_shader_lib) + target_compile_definitions(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE USE_EXECUTORCH=1) + target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE executorch executorch_core mpsdelegate) + target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_executorch PRIVATE ${METAL_LIB} ${FOUNDATION_LIB}) + + add_library(torchao_ops_mps_executorch STATIC) + target_link_libraries(torchao_ops_mps_executorch PRIVATE + torchao_ops_mps_linear_fp_act_xbit_weight_executorch + ) + install(TARGETS torchao_ops_mps_executorch DESTINATION lib) +endif() diff --git a/torchao/experimental/ops/mps/aten/register.mm b/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_aten.mm similarity index 78% rename from torchao/experimental/ops/mps/aten/register.mm rename to torchao/experimental/ops/mps/linear_fp_act_xbit_weight_aten.mm index 92a3ba89f0..e11e55c5a0 100644 --- a/torchao/experimental/ops/mps/aten/register.mm +++ b/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_aten.mm @@ -70,12 +70,13 @@ void check_linear_mps_args( } template -Tensor linear_mps_kernel( +Tensor linear_mps_kernel_out( const Tensor& A, const Tensor& B, int64_t group_size, const Tensor& S, - const Tensor& Z) { + const Tensor& Z, + Tensor& C) { TORCH_CHECK( A.is_mps(), __func__, ": A is on ", A.device(), " but expected on mps"); TORCH_CHECK( @@ -84,6 +85,8 @@ Tensor linear_mps_kernel( S.is_mps(), __func__, ": S is on ", S.device(), " but expected on mps"); TORCH_CHECK( Z.is_mps(), __func__, ": Z is on ", Z.device(), " but expected on mps"); + TORCH_CHECK( + C.is_mps(), __func__, ": Z is on ", Z.device(), " but expected on mps"); check_linear_mps_args(A, B, group_size, S, Z); @@ -91,8 +94,6 @@ Tensor linear_mps_kernel( auto N = B.size(0); auto K = A.size(1); - auto C = at::empty({M, N}, A.options()); - LowBitQuantWeights::linear( getMTLBufferStorage(A), getMTLBufferStorage(B), @@ -108,6 +109,19 @@ Tensor linear_mps_kernel( return C; } +template +Tensor linear_mps_kernel( + const Tensor& A, + const Tensor& B, + int64_t group_size, + const Tensor& S, + const Tensor& Z) { + auto M = A.size(0); + auto N = B.size(0); + auto C = at::empty({M, N}, A.options()); + return linear_mps_kernel_out(A, B, group_size, S, Z, C); +} + template Tensor linear_mps_kernel_meta( const Tensor& A, @@ -169,6 +183,20 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) { "_linear_fp_act_6bit_weight(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z) -> Tensor"); m.def( "_linear_fp_act_7bit_weight(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z) -> Tensor"); + m.def( + "_linear_fp_act_1bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)"); + m.def( + "_linear_fp_act_2bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)"); + m.def( + "_linear_fp_act_3bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)"); + m.def( + "_linear_fp_act_4bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)"); + m.def( + "_linear_fp_act_5bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)"); + m.def( + "_linear_fp_act_6bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)"); + m.def( + "_linear_fp_act_7bit_weight.out(Tensor A, Tensor B, int group_size, Tensor S, Tensor Z, *, Tensor(a!) out) -> Tensor(a!)"); } TORCH_LIBRARY_IMPL(torchao, CPU, m) { @@ -189,6 +217,13 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) { m.impl("_linear_fp_act_5bit_weight", &linear_mps_kernel<5>); m.impl("_linear_fp_act_6bit_weight", &linear_mps_kernel<6>); m.impl("_linear_fp_act_7bit_weight", &linear_mps_kernel<7>); + m.impl("_linear_fp_act_1bit_weight.out", &linear_mps_kernel_out<1>); + m.impl("_linear_fp_act_2bit_weight.out", &linear_mps_kernel_out<2>); + m.impl("_linear_fp_act_3bit_weight.out", &linear_mps_kernel_out<3>); + m.impl("_linear_fp_act_4bit_weight.out", &linear_mps_kernel_out<4>); + m.impl("_linear_fp_act_5bit_weight.out", &linear_mps_kernel_out<5>); + m.impl("_linear_fp_act_6bit_weight.out", &linear_mps_kernel_out<6>); + m.impl("_linear_fp_act_7bit_weight.out", &linear_mps_kernel_out<7>); } TORCH_LIBRARY_IMPL(torchao, Meta, m) { diff --git a/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_executorch.mm b/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_executorch.mm new file mode 100644 index 0000000000..2892a67245 --- /dev/null +++ b/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_executorch.mm @@ -0,0 +1,138 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include + +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::backends::mps::delegate::getMTLBufferStorage; +using ::executorch::runtime::KernelRuntimeContext; +using ::executorch::runtime::tensor_is_rank; + +namespace { + +std::string scalar_type_to_string(const ScalarType& scalar_type) { + switch (scalar_type) { + case ScalarType::Float: + return "float"; + case ScalarType::Half: + return "half"; + case ScalarType::BFloat16: + return "bfloat"; + default: + ET_CHECK_MSG( + false, "Unsupported type by lowbit quantized linear"); + return "undefined"; + } +} + +template +bool check_linear_mps_args( + const Tensor& A, + const Tensor& B, + int64_t group_size, + const Tensor& S, + const Tensor& Z) { + auto N = B.size(0); + auto K = A.size(1); + + ET_LOG_MSG_AND_RETURN_IF_FALSE( + A.scalar_type() == ScalarType::BFloat16 || + A.scalar_type() == ScalarType::Half || + A.scalar_type() == ScalarType::Float, + "Expect A to be either 32-bit or 16-bit float tensor."); + ET_LOG_MSG_AND_RETURN_IF_FALSE( + tensor_is_rank(A, 2), "Expect A to be 2D tensor."); + + ET_LOG_MSG_AND_RETURN_IF_FALSE( + B.scalar_type() == ScalarType::Byte, "Expect B to be uint8 tensor."); + ET_LOG_MSG_AND_RETURN_IF_FALSE( + B.size(1) == (K / 8) * nbit, "Expect B.size(1) == (K / 8) * nbit"); + + ET_LOG_MSG_AND_RETURN_IF_FALSE(K % 8 == 0, "Expect K to be multiple of 8"); + + ET_LOG_MSG_AND_RETURN_IF_FALSE( + group_size == 32 || group_size == 64 || group_size == 128 || + group_size == 256, + "Expect group_size to be 32, 64, 128 or 256"); + + ET_LOG_MSG_AND_RETURN_IF_FALSE( + S.dim() == 2 && S.size(1) == N, + "Expect S to be 2d tensor with shape [:, N]"); + + ET_LOG_MSG_AND_RETURN_IF_FALSE( + Z.dim() == 2 && Z.size(1) == N, + "Expect Z to be 2d tensor with shape [:, N]"); + + return true; +} + +template +Tensor& linear_mps_kernel_et_ctx_out( + KernelRuntimeContext& ctx, + const Tensor& A, + const Tensor& B, + int64_t group_size, + const Tensor& S, + const Tensor& Z, + Tensor& out) { + ET_KERNEL_CHECK( + ctx, + check_linear_mps_args(A, B, group_size, S, Z), + InvalidArgument, + out); + + auto M = A.size(0); + auto N = B.size(0); + auto K = A.size(1); + + torchao::kernels::mps::lowbit::LowBitQuantWeights::linear( + getMTLBufferStorage(A), + getMTLBufferStorage(B), + group_size, + getMTLBufferStorage(S), + getMTLBufferStorage(Z), + getMTLBufferStorage(out), + M, + K, + N, + scalar_type_to_string(A.scalar_type())); + + return out; +} + +} // namespace + +namespace { +EXECUTORCH_LIBRARY(torchao, "_linear_fp_act_1bit_weight.out", linear_mps_kernel_et_ctx_out<1>); +} + +namespace { +EXECUTORCH_LIBRARY(torchao, "_linear_fp_act_2bit_weight.out", linear_mps_kernel_et_ctx_out<2>); +} + +namespace { +EXECUTORCH_LIBRARY(torchao, "_linear_fp_act_3bit_weight.out", linear_mps_kernel_et_ctx_out<3>); +} + +namespace { +EXECUTORCH_LIBRARY(torchao, "_linear_fp_act_4bit_weight.out", linear_mps_kernel_et_ctx_out<4>); +} + +namespace { +EXECUTORCH_LIBRARY(torchao, "_linear_fp_act_5bit_weight.out", linear_mps_kernel_et_ctx_out<5>); +} + +namespace { +EXECUTORCH_LIBRARY(torchao, "_linear_fp_act_6bit_weight.out", linear_mps_kernel_et_ctx_out<6>); +} + +namespace { +EXECUTORCH_LIBRARY(torchao, "_linear_fp_act_7bit_weight.out", linear_mps_kernel_et_ctx_out<7>); +} diff --git a/torchao/experimental/ops/mps/test/test_lowbit.py b/torchao/experimental/ops/mps/test/test_lowbit.py index f4c460a368..acff5624c8 100644 --- a/torchao/experimental/ops/mps/test/test_lowbit.py +++ b/torchao/experimental/ops/mps/test/test_lowbit.py @@ -11,7 +11,7 @@ from parameterized import parameterized -libname = "libtorchao_ops_mps_linear_fp_act_xbit_weight_aten.dylib" +libname = "libtorchao_ops_mps_aten.dylib" libpath = os.path.abspath( os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname) ) diff --git a/torchao/experimental/ops/mps/test/test_quantizer.py b/torchao/experimental/ops/mps/test/test_quantizer.py index 00c08738c2..5b3331c6a8 100644 --- a/torchao/experimental/ops/mps/test/test_quantizer.py +++ b/torchao/experimental/ops/mps/test/test_quantizer.py @@ -17,7 +17,7 @@ from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer from torchao.experimental.quant_api import _quantize -libname = "libtorchao_ops_mps_linear_fp_act_xbit_weight_aten.dylib" +libname = "libtorchao_ops_mps_aten.dylib" libpath = os.path.abspath( os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname) ) diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py index be72a59aab..0904d1d174 100644 --- a/torchao/experimental/quant_api.py +++ b/torchao/experimental/quant_api.py @@ -469,21 +469,19 @@ def quantize(self, model: nn.Module) -> nn.Module: return model -from torchao.experimental._linear_8bit_act_xbit_weight_layout import Linear8BitActXBitWeightLayout -from torchao.quantization.quant_api import ( - _get_linear_subclass_inserter, - MappingType, - to_affine_quantized_intx, - ZeroPointDomain, -) - - def int8_dynamic_activation_intx_weight( group_size: int = 128, nbit: int = 4, has_weight_zeros: bool = False, target: str = "native", ): + from torchao.experimental._linear_8bit_act_xbit_weight_layout import Linear8BitActXBitWeightLayout + from torchao.quantization.quant_api import ( + _get_linear_subclass_inserter, + MappingType, + to_affine_quantized_intx, + ZeroPointDomain, + ) def apply(weight): assert weight.shape[-1] % group_size == 0 @@ -541,10 +539,11 @@ def quantize_and_pack_weights(self, weights, nbit, group_size): ) weight_scales = torch.transpose_copy(weight_scales, 1, 0) weight_zeros = torch.transpose_copy(weight_zeros, 1, 0) - self.weight_scales = weight_scales - self.weight_zeros = -weight_zeros * weight_scales - - self.packed_weights = self._pack_weights_op(weight_qvals.cpu()).to(device="mps") + weight_zeros = -weight_zeros * weight_scales + self.weight_scales = nn.Parameter(weight_scales, requires_grad=False) + self.weight_zeros = nn.Parameter(weight_zeros, requires_grad=False) + packed_weights = self._pack_weights_op(weight_qvals.cpu()).to(device="mps") + self.packed_weights = nn.Parameter(packed_weights, requires_grad=False) def forward(self, x): assert x.dim() >= 2 diff --git a/torchao/experimental/temp_build.py b/torchao/experimental/temp_build.py new file mode 100644 index 0000000000..fb9d413037 --- /dev/null +++ b/torchao/experimental/temp_build.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import glob +import subprocess +import tempfile +import torch + +def cmake_build_torchao_ops(cmake_lists_path, temp_build_dir): + from distutils.sysconfig import get_python_lib + print("Building torchao ops for ATen target") + cmake_prefix_path = get_python_lib() + subprocess.run( + [ + "cmake", + "-DCMAKE_PREFIX_PATH=" + cmake_prefix_path, + "-DCMAKE_INSTALL_PREFIX=" + temp_build_dir.name, + "-S " + cmake_lists_path, + "-B " + temp_build_dir.name, + ] + ) + subprocess.run( + [ + "cmake", + "--build", + temp_build_dir.name, + "-j 16", + "--target install", + "--config Release", + ] + ) + +def temp_build_and_load_torchao_ops(cmake_lists_path): + temp_build_dir = tempfile.TemporaryDirectory() + cmake_build_torchao_ops(cmake_lists_path, temp_build_dir) + libs = glob.glob(f"{temp_build_dir.name}/lib/libtorchao_ops_aten.*") + libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs)) + assert len(libs) == 1 + torch.ops.load_library(libs[0]) + print(f"TorchAO ops are loaded from {libs[0]}") diff --git a/torchao/float8/config.py b/torchao/float8/config.py index 175ab03f3c..d4a5516154 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -53,6 +53,35 @@ def short_str(self): return "axs" +@dataclass +class Float8TypeConfig: + """ + Configuration for selecting the preferred float8 type pair, either e4m3fn/e5m2 or e4m3fnuz/e5m2fnuz. + + Currently, ROCm only supports fnuz variants. + """ + + # The preferred e4m3 type. + e4m3_dtype = torch.float8_e4m3fn + + # The preferred e5m2 type. + e5m2_dtype = torch.float8_e5m2 + + def __post_init__(self): + if torch.version.hip and torch.cuda.is_available(): + prop = torch.cuda.get_device_properties(0) + MI300_ARCH = ("gfx940", "gfx941", "gfx942") + if prop.gcnArchName.split(":")[0] in MI300_ARCH: + self.e4m3_dtype = torch.float8_e4m3fnuz + self.e5m2_dtype = torch.float8_e5m2fnuz + + +# User defined type for using the individual F8 type based on config +type_config = Float8TypeConfig() +e4m3_dtype = type_config.e4m3_dtype +e5m2_dtype = type_config.e5m2_dtype + + @dataclass(frozen=True) class CastConfig: """ @@ -62,9 +91,11 @@ class CastConfig: scaling_type: ScalingType = ScalingType.DYNAMIC scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE static_scale: Optional[torch.Tensor] = None + target_dtype: Optional[torch.dtype] = None def short_str(self): - return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}" + dtype = {e4m3_dtype: "e4m3", e5m2_dtype: "e5m2"}[self.target_dtype] + return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}_{dtype}" def __post_init__(self): if self.scaling_type is ScalingType.STATIC: @@ -75,6 +106,9 @@ def __post_init__(self): assert ( self.scaling_type is ScalingType.DYNAMIC ), "only dynamic scaling type is supported for axiswise scaling granularity" + assert self.target_dtype is None or ( + self.target_dtype.is_floating_point and self.target_dtype.itemsize == 1 + ), "must specify a 8-bit floating-point dtype" @dataclass(frozen=True) @@ -101,29 +135,6 @@ def __post_init__(self): ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now." -@dataclass -class Float8TypeConfig: - """ - Configuration for selecting the preferred float8 type pair, either e4m3fn/e5m2 or e4m3fnuz/e5m2fnuz. - - Currently, ROCm only supports fnuz variants. - """ - - # The preferred e4m3 type. - e4m3_dtype = torch.float8_e4m3fn - - # The preferred e5m2 type. - e5m2_dtype = torch.float8_e5m2 - - def __post_init__(self): - if torch.version.hip and torch.cuda.is_available(): - prop = torch.cuda.get_device_properties(0) - MI300_ARCH = ("gfx940", "gfx941", "gfx942") - if prop.gcnArchName.split(":")[0] in MI300_ARCH: - self.e4m3_dtype = torch.float8_e4m3fnuz - self.e5m2_dtype = torch.float8_e5m2fnuz - - @dataclass(frozen=True) class Float8GemmConfig: """ @@ -170,7 +181,6 @@ class Float8LinearConfig: # # Per-gemm configuration for gemms calculating `output`, `grad_input` and # `grad_weight` - # TODO(this PR): throw warning if fast_accum False is used with axiswise scaling # gemm_config_output: Float8GemmConfig = Float8GemmConfig(use_fast_accum=True) gemm_config_grad_input: Float8GemmConfig = Float8GemmConfig() @@ -277,6 +287,20 @@ def __post_init__(self): is_disabled_1 == is_disabled_2 ), f"incompatible operand precision for {gemm_name}" + for cc1, cc2, operand_name, default_dtype in [ + (cc_i, cc_i_gw, "input", e4m3_dtype), + (cc_w, cc_w_gi, "weight", e4m3_dtype), + (cc_go, cc_go_gw, "grad_output", e5m2_dtype), + ]: + # Override the dataclass being frozen + if cc1.target_dtype is None: + object.__setattr__(cc1, "target_dtype", default_dtype) + if cc2.target_dtype is None: + object.__setattr__(cc2, "target_dtype", default_dtype) + assert ( + cc1.target_dtype == cc2.target_dtype + ), f"{operand_name} must be cast to the same dtype in both matmuls it's used in" + if self.use_fp8_all_gather_only: assert self.enable_fsdp_float8_all_gather, "use_fp8_all_gather_only requires enable_fsdp_float8_all_gather to be True" @@ -317,21 +341,10 @@ def recipe_name_to_linear_config( cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) - # The current rowwise CUTLASS kernels in `torch._scaled_mm` are only - # fast with `use_fast_accum=True`. Note that rowwise scaling is more - # accurate than tensorwise scaling, so the overall impact on accuracy - # of tensorwise vs rowwise taking this flag into account will vary. - gc_o = Float8GemmConfig(use_fast_accum=True) - gc_gi = Float8GemmConfig(use_fast_accum=True) - gc_gw = Float8GemmConfig(use_fast_accum=True) - return Float8LinearConfig( cast_config_input=cc_i, cast_config_weight=cc_w, cast_config_grad_output=cc_go, - gemm_config_output=gc_o, - gemm_config_grad_input=gc_gi, - gemm_config_grad_weight=gc_gw, ) elif recipe_name is Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP: @@ -346,26 +359,23 @@ def recipe_name_to_linear_config( # * `input`, `weight` and `grad_output` now only need to be scaled # axiswise across a single dim compared to vanilla all-axiswise, # which is more amenable to fast kernels + # * the e4m3 dtype is used across the board, including for gradients # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1 cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise - cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) + cc_go = CastConfig( + scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype + ) cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE) # grad_weight_hp = input_t_hp @ grad_output_hp cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED) - cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED) - - # The current rowwise CUTLASS kernels in `torch._scaled_mm` are only - # fast with `use_fast_accum=True`. Note that rowwise scaling is more - # accurate than tensorwise scaling, so the overall impact on accuracy - # of tensorwise vs rowwise taking this flag into account will vary. - gc_o = Float8GemmConfig(use_fast_accum=True) - gc_gi = Float8GemmConfig(use_fast_accum=True) - gc_gw = Float8GemmConfig(use_fast_accum=True) + cc_go_gw = CastConfig( + scaling_type=ScalingType.DISABLED, target_dtype=e4m3_dtype + ) return Float8LinearConfig( cast_config_input=cc_i, @@ -374,9 +384,6 @@ def recipe_name_to_linear_config( cast_config_input_for_grad_weight=cc_i_gw, cast_config_weight_for_grad_input=cc_w_gi, cast_config_grad_output_for_grad_weight=cc_go_gw, - gemm_config_output=gc_o, - gemm_config_grad_input=gc_gi, - gemm_config_grad_weight=gc_gw, ) else: diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index 776de917f1..d412519c36 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -15,9 +15,9 @@ from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType from torchao.float8.distributed_utils import tensor_already_casted_to_fp8 from torchao.float8.float8_scaling_utils import ( - NoopFwToFloat8E5M2BwDelayed, - NoopFwToFloat8E5M2BwDynamic, - NoopFwToFloat8E5M2BwStatic, + NoopFwToFloat8BwDelayed, + NoopFwToFloat8BwDynamic, + NoopFwToFloat8BwStatic, _maybe_initialize_amaxes_scales_for_float8_cast, get_maybe_axiswise_dim, hp_tensor_to_float8_delayed, @@ -32,8 +32,6 @@ hp_tensor_and_scale_to_float8, ) from torchao.float8.float8_utils import ( - e4m3_dtype, - e5m2_dtype, tensor_to_amax, tensor_to_scale, ) @@ -136,7 +134,7 @@ def forward( else: input_maybe_fp8 = hp_tensor_to_float8_dynamic( input_hp, - e4m3_dtype, + c.cast_config_input.target_dtype, linear_mm_config, gemm_input_role=GemmInputRole.INPUT, scaling_granularity=c.cast_config_input.scaling_granularity, @@ -150,7 +148,7 @@ def forward( else: weight_maybe_fp8_t = hp_tensor_to_float8_dynamic( weight_hp_t, - e4m3_dtype, + c.cast_config_weight.target_dtype, linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, scaling_granularity=c.cast_config_weight.scaling_granularity, @@ -186,7 +184,7 @@ def backward(ctx, grad_output): else: grad_output_reshaped_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic( grad_output_reshaped, - e5m2_dtype, + c.cast_config_grad_output.target_dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.GRAD_OUTPUT, scaling_granularity=c.cast_config_grad_output.scaling_granularity, @@ -204,7 +202,7 @@ def backward(ctx, grad_output): # the entire tensor. weight_t_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic( weight_hp_t, - e4m3_dtype, + c.cast_config_weight_for_grad_input.target_dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, scaling_granularity=c.cast_config_weight_for_grad_input.scaling_granularity, @@ -236,7 +234,7 @@ def backward(ctx, grad_output): else: grad_output_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic( grad_output_reshaped, - e5m2_dtype, + c.cast_config_grad_output_for_grad_weight.target_dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.GRAD_OUTPUT, scaling_granularity=c.cast_config_grad_output_for_grad_weight.scaling_granularity, @@ -250,7 +248,7 @@ def backward(ctx, grad_output): else: input_reshaped_maybe_fp8_dim1 = hp_tensor_to_float8_dynamic( input_hp_reshaped, - e4m3_dtype, + c.cast_config_input_for_grad_weight.target_dtype, ctx.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, scaling_granularity=c.cast_config_input_for_grad_weight.scaling_granularity, @@ -347,11 +345,11 @@ def create_buffers(self): # Default values for history buffers, see above TODO history_len = self.config.delayed_scaling_config.history_len device = self.weight.device - # TODO(future PR): dtype values below don't have the other float8 - # flavors, fix it - default_input = torch.finfo(torch.float8_e4m3fn).max - default_weight = torch.finfo(torch.float8_e4m3fn).max - default_grad_output = torch.finfo(torch.float8_e5m2).max + default_input = torch.finfo(self.config.cast_config_input.target_dtype).max + default_weight = torch.finfo(self.config.cast_config_weight.target_dtype).max + default_grad_output = torch.finfo( + self.config.cast_config_grad_output.target_dtype + ).max # Note: for now, create all the buffers if any are needed, to postpone # the work to make the scale and amax syncing and history calculation @@ -438,14 +436,14 @@ def cast_input_to_float8( self.fp8_amax_history_input, self.fp8_scale_input, scale_fn_name, - e4m3_dtype, + self.config.cast_config_input.target_dtype, is_amax_initialized, reduce_amax=True, ) input_fp8 = hp_tensor_to_float8_delayed( input, self.fp8_scale_input, - e4m3_dtype, + self.config.cast_config_input.target_dtype, self.fp8_amax_input, linear_mm_config=self.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, @@ -453,14 +451,17 @@ def cast_input_to_float8( elif self.scaling_type_input is ScalingType.DYNAMIC: input_fp8 = hp_tensor_to_float8_dynamic( input, - e4m3_dtype, + self.config.cast_config_input.target_dtype, self.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, ) else: assert self.scaling_type_input is ScalingType.STATIC input_fp8 = hp_tensor_to_float8_static( - input, self.fp8_static_scale_input, e4m3_dtype, self.linear_mm_config + input, + self.fp8_static_scale_input, + self.config.cast_config_input.target_dtype, + self.linear_mm_config, ) return input_fp8 @@ -476,14 +477,14 @@ def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]: self.fp8_amax_history_weight, self.fp8_scale_weight, scale_fn_name, - e4m3_dtype, + self.config.cast_config_weight.target_dtype, self.is_amax_initialized, reduce_amax=True, ) self.fp8_amax_weight.fill_(tensor_to_amax(weight)) return self.fp8_scale_weight elif self.scaling_type_weight is ScalingType.DYNAMIC: - return tensor_to_scale(weight, e4m3_dtype) + return tensor_to_scale(weight, self.config.cast_config_weight.target_dtype) else: assert self.scaling_type_weight is ScalingType.STATIC return self.fp8_static_scale_weight @@ -499,7 +500,7 @@ def cast_weight_to_float8_t( weight_fp8 = hp_tensor_and_scale_to_float8( weight, weight_scale, - e4m3_dtype, + self.config.cast_config_weight.target_dtype, self.linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, ) @@ -514,7 +515,7 @@ def cast_weight_to_original_t(self, weight: torch.Tensor): def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor: if self.scaling_type_grad_output is ScalingType.DELAYED: scale_fn_name = self.config.delayed_scaling_config.scale_fn_name - output = NoopFwToFloat8E5M2BwDelayed.apply( + output = NoopFwToFloat8BwDelayed.apply( output, self.fp8_amax_grad_output, self.fp8_amax_history_grad_output, @@ -522,15 +523,21 @@ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor: scale_fn_name, self.is_amax_initialized, self.linear_mm_config, + self.config.cast_config_grad_output.target_dtype, ) elif self.scaling_type_grad_output is ScalingType.DYNAMIC: - output = NoopFwToFloat8E5M2BwDynamic.apply(output, self.linear_mm_config) + output = NoopFwToFloat8BwDynamic.apply( + output, + self.linear_mm_config, + self.config.cast_config_grad_output.target_dtype, + ) else: assert self.scaling_type_grad_output is ScalingType.STATIC - output = NoopFwToFloat8E5M2BwStatic.apply( + output = NoopFwToFloat8BwStatic.apply( output, self.fp8_static_scale_grad_output, self.linear_mm_config, + self.config.cast_config_grad_output.target_dtype, ) return output @@ -547,19 +554,16 @@ def float8_post_forward(self): return def forward_fp8_matmul(self, input: torch.Tensor) -> torch.Tensor: - has_any_axiswise_scaling = ( - self.config.cast_config_input.scaling_granularity - is ScalingGranularity.AXISWISE - or self.config.cast_config_weight.scaling_granularity - is ScalingGranularity.AXISWISE - or self.config.cast_config_grad_output.scaling_granularity - is ScalingGranularity.AXISWISE - or self.config.cast_config_input_for_grad_weight.scaling_granularity - is ScalingGranularity.AXISWISE - or self.config.cast_config_weight_for_grad_input.scaling_granularity - is ScalingGranularity.AXISWISE - or self.config.cast_config_grad_output_for_grad_weight.scaling_granularity - is ScalingGranularity.AXISWISE + has_any_axiswise_scaling = any( + cc.scaling_granularity is ScalingGranularity.AXISWISE + for cc in [ + self.config.cast_config_input, + self.config.cast_config_weight, + self.config.cast_config_grad_output, + self.config.cast_config_input_for_grad_weight, + self.config.cast_config_weight_for_grad_input, + self.config.cast_config_grad_output_for_grad_weight, + ] ) if not has_any_axiswise_scaling: @@ -682,6 +686,7 @@ def from_float( WeightWithDynamicFloat8CastTensor( new_mod.weight, new_mod.linear_mm_config, + new_mod.config.cast_config_weight.target_dtype, ) ) elif config.cast_config_weight.scaling_type is ScalingType.DELAYED: @@ -692,6 +697,7 @@ def from_float( new_mod.fp8_amax_history_weight, new_mod.fp8_scale_weight, new_mod.linear_mm_config, + new_mod.config.cast_config_weight.target_dtype, new_mod.is_amax_initialized, ) ) @@ -702,6 +708,7 @@ def from_float( new_mod.weight, new_mod.fp8_static_scale_weight, new_mod.linear_mm_config, + new_mod.config.cast_config_weight.target_dtype, ) ) diff --git a/torchao/float8/float8_linear_utils.py b/torchao/float8/float8_linear_utils.py index c4fc88eb37..64d2f7bc63 100644 --- a/torchao/float8/float8_linear_utils.py +++ b/torchao/float8/float8_linear_utils.py @@ -15,8 +15,6 @@ from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_utils import ( amax_history_to_scale_stack, - e4m3_dtype, - e5m2_dtype, ) log = logging.getLogger(__name__) @@ -227,6 +225,9 @@ def inner_func(): fp8_weight_amax_history_stack = [None] * len(fp8_layers) fp8_grad_output_amax_history_stack = [None] * len(fp8_layers) + input_dtypes = set() + weight_dtypes = set() + grad_output_dtypes = set() scale_fn_recipes = set() for idx, child in enumerate(fp8_layers): @@ -238,8 +239,15 @@ def inner_func(): fp8_weight_amax_history_stack[idx] = child.fp8_amax_history_weight fp8_grad_output_amax_history_stack[idx] = child.fp8_amax_history_grad_output + input_dtypes.add(child.config.cast_config_input.target_dtype) + weight_dtypes.add(child.config.cast_config_weight.target_dtype) + grad_output_dtypes.add(child.config.cast_config_grad_output.target_dtype) scale_fn_recipes.add(child.config.delayed_scaling_config.scale_fn_name) + (input_dtype,) = input_dtypes + (weight_dtype,) = weight_dtypes + (grad_output_dtype,) = grad_output_dtypes + if len(scale_fn_recipes) != 1: raise ValueError( f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}" @@ -297,13 +305,13 @@ def inner_func(): # Calculate the new scales from the updated history stacks new_input_scales = amax_history_to_scale_stack( - fp8_input_amax_history_stack, e4m3_dtype, scale_fn_recipe + fp8_input_amax_history_stack, input_dtype, scale_fn_recipe ) new_weight_scales = amax_history_to_scale_stack( - fp8_weight_amax_history_stack, e4m3_dtype, scale_fn_recipe + fp8_weight_amax_history_stack, weight_dtype, scale_fn_recipe ) new_grad_output_scales = amax_history_to_scale_stack( - fp8_grad_output_amax_history_stack, e5m2_dtype, scale_fn_recipe + fp8_grad_output_amax_history_stack, grad_output_dtype, scale_fn_recipe ) # Iterate through the layers and update the scales diff --git a/torchao/float8/float8_ops.py b/torchao/float8/float8_ops.py index 921d50e093..2af4160de4 100644 --- a/torchao/float8/float8_ops.py +++ b/torchao/float8/float8_ops.py @@ -85,7 +85,10 @@ def float8_desugar_data_and_scale_op(aten_op, args, kwargs=None): ) def float8_transpose(aten_op, args, kwargs=None): new_data = aten_op(args[0]._data, *args[1:], **kwargs) - new_scale = aten_op(args[0]._scale, *args[1:], **kwargs) + if args[0]._scale.ndim > 1: + new_scale = aten_op(args[0]._scale, *args[1:], **kwargs) + else: + new_scale = args[0]._scale if aten_op == aten.transpose.int: _assert_tensorwise_scale(aten_op, args[0]._scale) diff --git a/torchao/float8/float8_python_api.py b/torchao/float8/float8_python_api.py index 6608dba958..402ce2eb0f 100644 --- a/torchao/float8/float8_python_api.py +++ b/torchao/float8/float8_python_api.py @@ -37,19 +37,25 @@ def addmm_float8_unwrapped( a_inverse_scale = a_scale.reciprocal() b_inverse_scale = b_scale.reciprocal() - if output_dtype == torch.float32 and bias is not None: + post_inverse_scale = None + if ( + a_scale.shape == (a_data.shape[0], 1) + and b_scale.shape == (1, b_data.shape[1]) + and not use_fast_accum + ): + # The rowwise CUTLASS-based kernel is so slow without fast-accum that + # we'd rather use the tensorwise cuBLAS-based kernel and do the scaling + # manually afterwards (hoping Inductor will be able to fuse it). + post_inverse_scale = a_inverse_scale * b_inverse_scale + a_inverse_scale = a_inverse_scale.new_ones(()) + b_inverse_scale = a_inverse_scale.new_ones(()) + + post_bias = None + if output_dtype == torch.float32: # Bias is not supported by _scaled_mm when output is fp32 - output = torch._scaled_mm( - a_data, - b_data, - scale_a=a_inverse_scale, - scale_b=b_inverse_scale, - scale_result=output_scale, - out_dtype=output_dtype, - use_fast_accum=use_fast_accum, - ) - output += bias - return output + post_bias = bias + bias = None + output = torch._scaled_mm( a_data, b_data, @@ -60,4 +66,10 @@ def addmm_float8_unwrapped( out_dtype=output_dtype, use_fast_accum=use_fast_accum, ) + + if post_inverse_scale is not None: + output *= post_inverse_scale + if post_bias is not None: + output += post_bias + return output diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py index c8fe61c8a4..3a9841e625 100644 --- a/torchao/float8/float8_scaling_utils.py +++ b/torchao/float8/float8_scaling_utils.py @@ -22,7 +22,6 @@ ) from torchao.float8.float8_utils import ( amax_history_to_scale, - e5m2_dtype, tensor_to_amax, tensor_to_scale, ) @@ -182,7 +181,7 @@ def _maybe_initialize_amaxes_scales_for_float8_cast( @torch._dynamo.allow_in_graph -class NoopFwToFloat8E5M2BwDelayed(torch.autograd.Function): +class NoopFwToFloat8BwDelayed(torch.autograd.Function): """ Forward: no-op Backward: convert to float8_e5m2 with delayed scaling, initialize if needed @@ -198,6 +197,7 @@ def forward( scale_fn_name, is_amax_initialized, linear_mm_config: LinearMMConfig, + target_dtype: torch.dtype, ): ctx.save_for_backward( fp8_amax_grad_output, fp8_amax_history_grad_output, fp8_scale_grad_output @@ -205,6 +205,7 @@ def forward( ctx.scale_fn_name = scale_fn_name ctx.is_amax_initialized = is_amax_initialized ctx.linear_mm_config = linear_mm_config + ctx.target_dtype = target_dtype return tensor @staticmethod @@ -223,7 +224,7 @@ def backward(ctx, go): fp8_amax_history_grad_output, fp8_scale_grad_output, scale_fn_name, - e5m2_dtype, + ctx.target_dtype, is_amax_initialized, reduce_amax=True, ) @@ -233,16 +234,16 @@ def backward(ctx, go): res = hp_tensor_and_scale_to_float8( go, fp8_scale_grad_output, - e5m2_dtype, + ctx.target_dtype, ctx.linear_mm_config, GemmInputRole.GRAD_OUTPUT, ) - empty_grads = None, None, None, None, None, None + empty_grads = None, None, None, None, None, None, None return res, *empty_grads @torch._dynamo.allow_in_graph -class NoopFwToFloat8E5M2BwDynamic(torch.autograd.Function): +class NoopFwToFloat8BwDynamic(torch.autograd.Function): """ Forward: no-op Backward: convert to float8_e5m2 with dynamic scaling @@ -253,27 +254,29 @@ def forward( ctx, tensor, linear_mm_config: LinearMMConfig, + target_dtype: torch.dtype, ): ctx.linear_mm_config = linear_mm_config + ctx.target_dtype = target_dtype return tensor @staticmethod def backward(ctx, gradY): if tensor_already_casted_to_fp8(gradY): - return gradY, None - gradY_scale = tensor_to_scale(gradY, e5m2_dtype) + return gradY, None, None + gradY_scale = tensor_to_scale(gradY, ctx.target_dtype) fp8_tensor = hp_tensor_and_scale_to_float8( gradY, gradY_scale, - e5m2_dtype, + ctx.target_dtype, ctx.linear_mm_config, GemmInputRole.GRAD_OUTPUT, ) - return fp8_tensor, None + return fp8_tensor, None, None @torch._dynamo.allow_in_graph -class NoopFwToFloat8E5M2BwStatic(torch.autograd.Function): +class NoopFwToFloat8BwStatic(torch.autograd.Function): """ Forward: no-op Backward: convert to float8_e5m2 with static scaling @@ -285,21 +288,23 @@ def forward( tensor, scale, linear_mm_config: LinearMMConfig, + target_dtype: torch.dtype, ): ctx.save_for_backward(scale) ctx.linear_mm_config = linear_mm_config + ctx.target_dtype = target_dtype return tensor @staticmethod def backward(ctx, gradY): if tensor_already_casted_to_fp8(gradY): - return gradY, None + return gradY, None, None, None (gradY_scale,) = ctx.saved_tensors fp8_tensor = hp_tensor_and_scale_to_float8( gradY, gradY_scale, - e5m2_dtype, + ctx.target_dtype, ctx.linear_mm_config, GemmInputRole.GRAD_OUTPUT, ) - return fp8_tensor, None, None + return fp8_tensor, None, None, None diff --git a/torchao/float8/float8_tensor.py b/torchao/float8/float8_tensor.py index 20f40330a8..fe2498e2b0 100644 --- a/torchao/float8/float8_tensor.py +++ b/torchao/float8/float8_tensor.py @@ -10,7 +10,6 @@ from torch.distributed._tensor import DTensor from torchao.float8.float8_utils import ( - e4m3_dtype, to_fp8_saturated, ) @@ -133,7 +132,7 @@ def forward( ctx, tensor: torch.Tensor, scale: torch.Tensor, - float8_dtype=e4m3_dtype, + float8_dtype: torch.dtype, linear_mm_config: Optional[LinearMMConfig] = None, gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, axiswise_dim: Optional[int] = None, @@ -213,7 +212,7 @@ def backward(ctx, g): def hp_tensor_and_scale_to_float8( hp_tensor: torch.Tensor, s: torch.Tensor, - float8_dtype=e4m3_dtype, + float8_dtype: torch.dtype, linear_mm_config: Optional[LinearMMConfig] = None, gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, axiswise_dim: Optional[int] = None, diff --git a/torchao/float8/float8_tensor_parallel.py b/torchao/float8/float8_tensor_parallel.py index a3fc4ce7e5..37cb67c7e7 100644 --- a/torchao/float8/float8_tensor_parallel.py +++ b/torchao/float8/float8_tensor_parallel.py @@ -8,13 +8,12 @@ RowwiseParallel, ) -from torchao.float8.config import ScalingType +from torchao.float8.config import ScalingType, e4m3_dtype from torchao.float8.float8_scaling_utils import ( - NoopFwToFloat8E5M2BwDynamic, + NoopFwToFloat8BwDynamic, hp_tensor_to_float8_dynamic, ) from torchao.float8.float8_tensor import GemmInputRole -from torchao.float8.float8_utils import e4m3_dtype # subclass the ColwiseParallel and RowwiseParallel classes # to add the float8 support @@ -49,7 +48,7 @@ def _prepare_input_fn( input_tensor = hp_tensor_to_float8_dynamic( input_tensor, - e4m3_dtype, + mod.config.cast_config_input.target_dtype, mod.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, ) # DTensor(Float8Tensor) @@ -70,7 +69,11 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me ) # DTensor(torch.Tensor) # fwd noop bwd cast to DTensor(Float8Tensor) - outputs = NoopFwToFloat8E5M2BwDynamic.apply(outputs, mod.linear_mm_config) + outputs = NoopFwToFloat8BwDynamic.apply( + outputs, + mod.linear_mm_config, + mod.config.cast_config_grad_output.target_dtype, + ) # back to local tensor return outputs.to_local() if use_local_output else outputs @@ -103,7 +106,7 @@ def _prepare_input_fn( input_tensor = hp_tensor_to_float8_dynamic( input_tensor, - e4m3_dtype, + mod.config.cast_config_input.target_dtype, mod.linear_mm_config, gemm_input_role=GemmInputRole.INPUT, ) # DTensor(Float8Tensor) @@ -123,7 +126,11 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me outputs = outputs.redistribute(placements=output_layouts, async_op=True) # fwd noop bwd cast to DTensor(Float8Tensor) - outputs = NoopFwToFloat8E5M2BwDynamic.apply(outputs, mod.linear_mm_config) + outputs = NoopFwToFloat8BwDynamic.apply( + outputs, + mod.linear_mm_config, + mod.config.cast_config_grad_output.target_dtype, + ) # back to local tensor if use_local_output is True return outputs.to_local() if use_local_output else outputs diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 29319f3814..90927659f8 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -10,7 +10,7 @@ import torch.distributed as dist from torch.distributed._functional_collectives import AsyncCollectiveTensor, all_reduce -from torchao.float8.config import Float8TypeConfig, ScalingGranularity +from torchao.float8.config import ScalingGranularity # Helpful visualizer for debugging (only supports fp32): # https://www.h-schmidt.net/FloatConverter/IEEE754.html @@ -28,12 +28,6 @@ } -# User defined type for using the individual F8 type based on config -type_config = Float8TypeConfig() -e4m3_dtype = type_config.e4m3_dtype -e5m2_dtype = type_config.e5m2_dtype - - @torch.no_grad() def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype): """Converts the amax value of a tensor to the fp8 scale. @@ -173,7 +167,7 @@ def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: def fp8_tensor_statistics( - tensor: torch.Tensor, float8_dtype=e4m3_dtype + tensor: torch.Tensor, float8_dtype: torch.dtype ) -> Tuple[int, ...]: """Calculate FP8 tensor stats diff --git a/torchao/float8/fsdp_utils.py b/torchao/float8/fsdp_utils.py index 8c60995a86..9fde9922af 100644 --- a/torchao/float8/fsdp_utils.py +++ b/torchao/float8/fsdp_utils.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import math -from typing import Any, List, Optional, Tuple +from typing import Any, List, Optional, Set, Tuple import torch import torch.nn as nn @@ -22,7 +22,7 @@ LinearMMConfig, hp_tensor_and_scale_to_float8, ) -from torchao.float8.float8_utils import EPS, e4m3_dtype +from torchao.float8.float8_utils import EPS @torch.no_grad() @@ -54,9 +54,14 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None: and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor) ] weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears] + target_dtypes: Set[torch.dtype] = { + float8_linear.config.cast_config_weight.target_dtype + for float8_linear in float8_linears + } if not weights: return + (target_dtype,) = target_dtypes # inf-norm is equivalent to max(abs(w)) max_weights = torch._foreach_norm(weights, ord=math.inf) # Partial @@ -69,7 +74,7 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None: # upcast to float64 to ensure same numeric between compile and eager origin_dtype = amax_tensor.dtype amax_tensor = amax_tensor.to(torch.float64) - scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor # Replicate + scale_tensor = torch.finfo(target_dtype).max / amax_tensor # Replicate if origin_dtype is torch.float16: scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max) local_scale_tensor = scale_tensor.to_local().to(torch.float32) @@ -134,6 +139,7 @@ def __new__( cls, tensor: torch.Tensor, linear_mm_config: LinearMMConfig, + dtype: torch.dtype, precomputed_scale: Optional[torch.Tensor] = None, ): return torch.Tensor._make_wrapper_subclass( @@ -153,10 +159,12 @@ def __init__( self, tensor: torch.Tensor, linear_mm_config: LinearMMConfig, + dtype: torch.dtype, precomputed_scale: Optional[torch.Tensor] = None, ): self._tensor = tensor self._linear_mm_config = linear_mm_config + self._dtype = dtype # for dynamic scaling # `precompute_float8_dynamic_scale_for_fsdp` calculates scales # for all float8 parameters after optimizer step @@ -166,9 +174,10 @@ def __init__( def __torch_dispatch__(cls, func, types, args, kwargs=None): if func == torch.ops.aten.detach.default: return WeightWithDynamicFloat8CastTensor( - args[0]._tensor, args[0]._linear_mm_config + args[0]._tensor, args[0]._linear_mm_config, args[0]._dtype ) mm_config: Optional[LinearMMConfig] = None + dtype: Optional[torch.dtype] = None def unwrap(t): nonlocal mm_config @@ -176,6 +185,11 @@ def unwrap(t): mm_config = t._linear_mm_config else: assert t._linear_mm_config == mm_config + nonlocal dtype + if dtype is None: + dtype = t._dtype + else: + assert t._dtype == dtype return t._tensor args, kwargs = pytree.tree_map_only( @@ -185,40 +199,42 @@ def unwrap(t): if func not in _ops_to_preserve_subclass: return out return pytree.tree_map_only( - torch.Tensor, lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config), out + torch.Tensor, + lambda x: WeightWithDynamicFloat8CastTensor(x, mm_config, dtype), + out, ) def __tensor_flatten__(self): + tensors = ["_tensor"] if self._precomputed_scale: - return ["_tensor", "_precomputed_scale"], self._linear_mm_config - else: - return ["_tensor"], self._linear_mm_config + tensors.append("_precomputed_scale") + return tensors, {"mm_config": self._linear_mm_config, "dtype": self._dtype} @staticmethod def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): - mm_config = flatten_spec return WeightWithDynamicFloat8CastTensor( inner_tensors["_tensor"], - mm_config, + flatten_spec["mm_config"], + flatten_spec["dtype"], getattr(inner_tensors, "_precomputed_scale", None), ) def __repr__(self): - return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, linear_mm_config={self._linear_mm_config})" + return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, linear_mm_config={self._linear_mm_config}, dtype={self._dtype})" def fsdp_pre_all_gather(self, mesh): if self._precomputed_scale is not None: float8_tensor = hp_tensor_and_scale_to_float8( self._tensor, self._precomputed_scale, - torch.float8_e4m3fn, + self._dtype, self._linear_mm_config, GemmInputRole.WEIGHT, ) else: float8_tensor = hp_tensor_to_float8_dynamic( self._tensor, - e4m3_dtype, + self._dtype, self._linear_mm_config, reduce_amax=True, gemm_input_role=GemmInputRole.WEIGHT, @@ -268,6 +284,7 @@ def __new__( amax_history_buffer: torch.Tensor, scale_buffer: torch.Tensor, linear_mm_config: LinearMMConfig, + dtype: torch.dtype, is_amax_initialized: bool, ): return torch.Tensor._make_wrapper_subclass( @@ -290,6 +307,7 @@ def __init__( amax_history_buffer: torch.Tensor, scale_buffer: torch.Tensor, linear_mm_config: LinearMMConfig, + dtype: torch.dtype, is_amax_initialized: bool, ): self._tensor = tensor @@ -297,6 +315,7 @@ def __init__( self._amax_history_buffer = amax_history_buffer self._scale_buffer = scale_buffer self._linear_mm_config = linear_mm_config + self._dtype = dtype # Note: is_amax_initialized is not a buffer to avoid data dependent # control flow visible to dynamo @@ -312,9 +331,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): args[0]._amax_history_buffer, args[0]._scale_buffer, args[0]._linear_mm_config, + args[0]._dtype, args[0].is_amax_initialized, ) mm_config: Optional[LinearMMConfig] = None + dtype: Optional[torch.dtype] = None amax_buffer: Optional[torch.Tensor] = None amax_history_buffer: Optional[torch.Tensor] = None scale_buffer: Optional[torch.Tensor] = None @@ -326,6 +347,11 @@ def unwrap(t): mm_config = t._linear_mm_config else: assert t._linear_mm_config == mm_config + nonlocal dtype + if dtype is None: + dtype = t._dtype + else: + assert t._dtype == dtype nonlocal amax_buffer if amax_buffer is None: amax_buffer = t._amax_buffer @@ -354,6 +380,7 @@ def unwrap(t): amax_history_buffer, scale_buffer, mm_config, + dtype, is_amax_initialized, ), out, @@ -369,6 +396,7 @@ def __tensor_flatten__(self): ], { "mm_config": self._linear_mm_config, + "dtype": self._dtype, "is_amax_initialized": self.is_amax_initialized, }, ) @@ -381,11 +409,12 @@ def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride): inner_tensors["_amax_history_buffer"], inner_tensors["_scale_buffer"], metadata["mm_config"], + metadata["dtype"], metadata["is_amax_initialized"], ) def __repr__(self): - return f"WeightWithDelayedFloat8CastTensor(tensor={self._tensor}, amax_buffer={self._amax_buffer}, scale_buffer={self._scale_buffer}, mm_config={self._linear_mm_config})" + return f"WeightWithDelayedFloat8CastTensor(tensor={self._tensor}, amax_buffer={self._amax_buffer}, scale_buffer={self._scale_buffer}, mm_config={self._linear_mm_config}, dtype={self._dtype})" def fsdp_pre_all_gather(self, mesh): # initialize if needed @@ -401,7 +430,7 @@ def fsdp_pre_all_gather(self, mesh): self._amax_history_buffer, self._scale_buffer, "max", # TODO(before land): read this from parent - e4m3_dtype, + self._dtype, self.is_amax_initialized, reduce_amax=True, ) @@ -410,7 +439,7 @@ def fsdp_pre_all_gather(self, mesh): float8_tensor = hp_tensor_to_float8_delayed( self._tensor, self._scale_buffer, - e4m3_dtype, + self._dtype, self._amax_buffer, self._linear_mm_config, GemmInputRole.WEIGHT, @@ -447,6 +476,7 @@ def __new__( tensor: torch.Tensor, static_scale: torch.Tensor, linear_mm_config: LinearMMConfig, + dtype: torch.dtype, ): return torch.Tensor._make_wrapper_subclass( cls, @@ -466,19 +496,25 @@ def __init__( tensor: torch.Tensor, static_scale: torch.Tensor, linear_mm_config: LinearMMConfig, + dtype: torch.dtype, ): self._tensor = tensor self._static_scale = static_scale self._linear_mm_config = linear_mm_config + self._dtype = dtype @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): if func == torch.ops.aten.detach.default: return WeightWithStaticFloat8CastTensor( - args[0]._tensor, args[0]._static_scale, args[0]._linear_mm_config + args[0]._tensor, + args[0]._static_scale, + args[0]._linear_mm_config, + args[0]._dtype, ) static_scale: Optional[torch.Tensor] = None mm_config: Optional[LinearMMConfig] = None + dtype: Optional[torch.dtype] = None def unwrap(t): nonlocal static_scale @@ -489,6 +525,11 @@ def unwrap(t): mm_config = t._linear_mm_config else: assert t._linear_mm_config == mm_config + nonlocal dtype + if dtype is None: + dtype = t._dtype + else: + assert t._dtype == dtype return t._tensor args, kwargs = pytree.tree_map_only( @@ -499,30 +540,35 @@ def unwrap(t): return out return pytree.tree_map_only( torch.Tensor, - lambda x: WeightWithStaticFloat8CastTensor(x, static_scale, mm_config), + lambda x: WeightWithStaticFloat8CastTensor( + x, static_scale, mm_config, dtype + ), out, ) def __tensor_flatten__(self): - return ["_tensor", "_static_scale"], self._linear_mm_config + return ["_tensor", "_static_scale"], { + "mm_config": self._linear_mm_config, + "dtype": self._dtype, + } @staticmethod def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): - mm_config = flatten_spec return WeightWithStaticFloat8CastTensor( inner_tensors["_tensor"], inner_tensors["_static_scale"], - mm_config, + flatten_spec["mm_config"], + flatten_spec["dtype"], ) def __repr__(self): - return f"WeightWithStaticFloat8CastTensor(tensor={self._tensor}, static_scale={self._static_scale}, linear_mm_config={self._linear_mm_config})" + return f"WeightWithStaticFloat8CastTensor(tensor={self._tensor}, static_scale={self._static_scale}, linear_mm_config={self._linear_mm_config}, dtype={self.dtype})" def fsdp_pre_all_gather(self, mesh): float8_tensor = hp_tensor_and_scale_to_float8( self._tensor, self._static_scale, - torch.float8_e4m3fn, + self._dtype, self._linear_mm_config, GemmInputRole.WEIGHT, ) diff --git a/torchao/kernel/README.md b/torchao/kernel/README.md index ab97d148f2..903bca5a68 100644 --- a/torchao/kernel/README.md +++ b/torchao/kernel/README.md @@ -6,6 +6,9 @@ Set this to a nonzero value to enable the kernels generated by the autotuner. This is turned off by default, because it is still an experimental feature and also can take a long time to run. +`TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE=EXHAUSTIVE` +Use this to enable exhaustive search for both int8mm and scaled_mm kernels. + Searching a new config can take a long time and we'll save the updated data in `data.pkl`. If you'd like to contributed updated configs for your hardware or shapes, please open a pull request. `TORCHAO_AUTOTUNER_DATA_PATH=torchao/kernel/configs/data_a100.pkl` diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index 81c2550246..afc5bcfa3f 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -10,7 +10,8 @@ from torchao.kernel import intmm_triton else: intmm_triton = None -except ImportError: +except ImportError as e: + print("import error:", e) # On cpu-only builds might not be available. intmm_triton = None @@ -56,7 +57,7 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: and j_is_nonzero_multiple_of_8 and k_is_nonzero_multiple_of_8 ) - + if device_cpu or bad_dimensions_for_cublas: # fallback path return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to( @@ -75,8 +76,8 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: try: return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) except Exception: - # fallback path, would run on H100 for float8 dtypes - # Exception on H100 float8 dtype : "addmm_cuda" not implemented for 'Float8_e4m3fn' + # fallback path, would run on H100 for float8 dtypes + # Exception on H100 float8 dtype : "addmm_cuda" not implemented for 'Float8_e4m3fn' return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32) else: def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: diff --git a/torchao/kernel/intmm_triton.py b/torchao/kernel/intmm_triton.py index 4e84d9cd3c..f6f42e2f53 100644 --- a/torchao/kernel/intmm_triton.py +++ b/torchao/kernel/intmm_triton.py @@ -7,35 +7,50 @@ import triton.language as tl from torchao.kernel.autotuner import get_best_config_fn +from torchao.utils import TORCH_VERSION_AFTER_2_5 -int8_powers_of_two = [32, 64, 128, 256] -int8_mm_kernel_configs = sum( - [ - # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps" +# TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE=EXHAUSTIVE to enable exhaustive option +int8_mm_kernel_configs = ( + sum( [ - (i, j, k, 1, 1), - (i, j, k, 1, 2), - (i, j, k, 2, 2), - (i, j, k, 1, 4), - (i, j, k, 2, 4), - (i, j, k, 3, 4), - (i, j, k, 4, 4), - (i, j, k, 1, 8), - (i, j, k, 2, 8), - (i, j, k, 3, 8), - (i, j, k, 4, 8), - (i, j, k, 5, 8), - (i, j, k, 6, 8), - (i, j, k, 7, 8), - (i, j, k, 8, 8), - ] - for (i, j, k) in itertools.product( - int8_powers_of_two, int8_powers_of_two, int8_powers_of_two - ) - ], - [], + # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps" + [ + (i, j, k, 1, 1), + (i, j, k, 1, 2), + (i, j, k, 2, 2), + (i, j, k, 1, 4), + (i, j, k, 2, 4), + (i, j, k, 3, 4), + (i, j, k, 4, 4), + (i, j, k, 1, 8), + (i, j, k, 2, 8), + (i, j, k, 3, 8), + (i, j, k, 4, 8), + (i, j, k, 5, 8), + (i, j, k, 6, 8), + (i, j, k, 7, 8), + (i, j, k, 8, 8), + ] + for (i, j, k) in itertools.product( + [32, 64, 128, 256], repeat=3 + ) + ], + [] + ) ) +if TORCH_VERSION_AFTER_2_5: + if torch._inductor.config.max_autotune_gemm_search_space == "EXHAUSTIVE": + int8_mm_kernel_configs = [ + (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) + for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product( + [16, 32, 64, 128, 256], repeat=3 + ) + for num_stages in [1, 2, 3, 4, 5, 6, 7, 8] + for num_warps in [2, 4, 8] + ] + + # Baseline configs from pytorch/pytorch # https://github.com/pytorch/pytorch/blob/7718a1cd4f8e0b794c18a31ebd6353d6273c534e/torch/_inductor/kernel/mm_common.py#L132-L147 # int8_mm_kernel_configs = [ diff --git a/torchao/ops.py b/torchao/ops.py index 9713f68eb2..2774deb08a 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -3,13 +3,22 @@ from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 - lib = torch.library.Library("torchao", "FRAGMENT") -lib.define("quant_llm_linear(int EXPONENT, int MANTISSA, Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor") -lib.define("unpack_tensor_core_tiled_layout(Tensor packed_w, int inner_k_tiles) -> Tensor") -lib.define("dequantize_tensor_core_tiled_layout(Tensor packed_w, Tensor scales_and_zeros, int group_size, int inner_k_tiles) -> Tensor") -lib.define("marlin_24_gemm(Tensor x, Tensor weight_marlin, Tensor meta, Tensor s, Tensor workspace, int bits, int size_m, int size_n, int size_k) -> Tensor") -lib.define("marlin_qqq_gemm(Tensor x, Tensor weight_marlin, Tensor s_tok, Tensor s_ch, Tensor s_group, Tensor workspace, int size_m, int size_n, int size_k) -> Tensor") +lib.define( + "quant_llm_linear(int EXPONENT, int MANTISSA, Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor" +) +lib.define( + "unpack_tensor_core_tiled_layout(Tensor packed_w, int inner_k_tiles) -> Tensor" +) +lib.define( + "dequantize_tensor_core_tiled_layout(Tensor packed_w, Tensor scales_and_zeros, int group_size, int inner_k_tiles) -> Tensor" +) +lib.define( + "marlin_24_gemm(Tensor x, Tensor weight_marlin, Tensor meta, Tensor s, Tensor workspace, int bits, int size_m, int size_n, int size_k) -> Tensor" +) +lib.define( + "marlin_qqq_gemm(Tensor x, Tensor weight_marlin, Tensor s_tok, Tensor s_ch, Tensor s_group, Tensor workspace, int size_m, int size_n, int size_k) -> Tensor" +) def register_custom_op(name): @@ -18,6 +27,7 @@ def decorator(func): return torch.library.register_fake(f"{name}")(func) else: return torch.library.impl_abstract(f"{name}")(func) + return decorator @@ -43,7 +53,9 @@ def quant_llm_linear( Returns output of linear layer """ - return torch.ops.torchao.quant_llm_linear.default(EXPONENT, MANTISSA, _in_feats, _weights, _scales, splitK) + return torch.ops.torchao.quant_llm_linear.default( + EXPONENT, MANTISSA, _in_feats, _weights, _scales, splitK + ) @register_custom_op("torchao::quant_llm_linear") @@ -55,12 +67,29 @@ def _( _scales: Tensor, splitK: int = 1, ) -> Tensor: - torch._check(_in_feats.dim() == 2, lambda: f"input should be a 2d tensor, got {_in_feats.dim()}D") - torch._check(_in_feats.dtype in (torch.float16, torch.bfloat16), lambda: f"weight must be FP16 or BF16, got {_in_feats.dtype}") - torch._check(_weights.dim() == 2, lambda: f"weight should be a 2d tensor, got {_weights.dim()}D") - torch._check(_weights.dtype is torch.uint8, lambda: f"weight must be UINT8, got {_weights.dtype}") - torch._check(_scales.dim() == 1, lambda: f"scale should be a 2d tensor, got {_scales.dim()}D") - torch._check(_scales.dtype in (torch.float16, torch.bfloat16), lambda: f"scale must be FP16 or BF16, got {_scales.dtype}") + torch._check( + _in_feats.dim() == 2, + lambda: f"input should be a 2d tensor, got {_in_feats.dim()}D", + ) + torch._check( + _in_feats.dtype in (torch.float16, torch.bfloat16), + lambda: f"weight must be FP16 or BF16, got {_in_feats.dtype}", + ) + torch._check( + _weights.dim() == 2, + lambda: f"weight should be a 2d tensor, got {_weights.dim()}D", + ) + torch._check( + _weights.dtype is torch.uint8, + lambda: f"weight must be UINT8, got {_weights.dtype}", + ) + torch._check( + _scales.dim() == 1, lambda: f"scale should be a 2d tensor, got {_scales.dim()}D" + ) + torch._check( + _scales.dtype in (torch.float16, torch.bfloat16), + lambda: f"scale must be FP16 or BF16, got {_scales.dtype}", + ) BS, IC = _in_feats.shape OC, _ = _weights.shape @@ -71,7 +100,6 @@ def _( return _in_feats.new_empty((BS, OC)) - def unpack_tensor_core_tiled_layout(packed_w: Tensor, inner_k_tiles: int) -> Tensor: """ Unpacks weights that were packed with `torch.ops.aten._convert_weight_to_int4pack` to original tensor of shape `N x K`. @@ -115,7 +143,10 @@ def _(packed_w: Tensor, inner_k_tiles: int) -> Tensor: return torch.empty((N, K), dtype=torch.int32, device=packed_w.device) -def dequantize_tensor_core_tiled_layout(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles: int) -> Tensor: + +def dequantize_tensor_core_tiled_layout( + packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles: int +) -> Tensor: """ Dequantizes by: - Unpacking weights that were packed with `torch.ops.aten._convert_weight_to_int4pack` to original tensor of shape `N x K` @@ -143,7 +174,9 @@ def dequantize_tensor_core_tiled_layout(packed_w: Tensor, scales_and_zeros: Tens @register_custom_op("torchao::dequantize_tensor_core_tiled_layout") -def _(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles: int) -> Tensor: +def _( + packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles: int +) -> Tensor: # packed_w preconditions torch._check( packed_w.dim() == 4, @@ -166,12 +199,28 @@ def _(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles K = packed_w.size(1) * inner_k_tiles * 16 # scales_and_zeros preconditions - torch._check(scales_and_zeros.dtype is torch.bfloat16, lambda: "scales_and_zeros must be bfloat16") - torch._check(scales_and_zeros.dim() == 3, lambda: "scales_and_zeros must be 3D, got {scales_and_zeros.dim()}") - torch._check(group_size == 32 or group_size == 64 or group_size == 128 or group_size == 256, lambda: "qGroupSize must be 32, 64, 128, or 256") - torch._check(scales_and_zeros.size(0) == K // group_size, lambda: "scales_and_zeros must have K // qGroupSize at dim 0") - torch._check(scales_and_zeros.size(1) == N, lambda: "scales_and_zeros must have N at dim 1") - torch._check(scales_and_zeros.size(2) == 2, lambda: "scales_and_zeros must have 2 at dim 2") + torch._check( + scales_and_zeros.dtype is torch.bfloat16, + lambda: "scales_and_zeros must be bfloat16", + ) + torch._check( + scales_and_zeros.dim() == 3, + lambda: "scales_and_zeros must be 3D, got {scales_and_zeros.dim()}", + ) + torch._check( + group_size == 32 or group_size == 64 or group_size == 128 or group_size == 256, + lambda: "qGroupSize must be 32, 64, 128, or 256", + ) + torch._check( + scales_and_zeros.size(0) == K // group_size, + lambda: "scales_and_zeros must have K // qGroupSize at dim 0", + ) + torch._check( + scales_and_zeros.size(1) == N, lambda: "scales_and_zeros must have N at dim 1" + ) + torch._check( + scales_and_zeros.size(2) == 2, lambda: "scales_and_zeros must have 2 at dim 2" + ) return torch.empty((N, K), dtype=torch.bfloat16, device=packed_w.device) @@ -224,27 +273,55 @@ def _( MAX_PARALLELISM = 64 # Verify num_bits - torch._check(bits == 4 or bits == 8, lambda: f"num_bits must be 4 or 8. Got = {bits}") + torch._check( + bits == 4 or bits == 8, lambda: f"num_bits must be 4 or 8. Got = {bits}" + ) pack_factor = 32 // bits # Verify M - torch._check(size_m == x.size(0), lambda: f"Shape mismatch: x.size(0) = {x.size(0)}, size_m = {size_m}") + torch._check( + size_m == x.size(0), + lambda: f"Shape mismatch: x.size(0) = {x.size(0)}, size_m = {size_m}", + ) # Verify K - torch._check(size_k == x.size(1), lambda: f"Shape mismatch: x.size(1) = {x.size(1)}, size_k = {size_k}") - torch._check(size_k % TILE_SIZE == 0, lambda: f"size_k = {size_k} is not divisible by tile_size = {TILE_SIZE}") - torch._check((size_k // TILE_SIZE // 2) == weight_marlin.size(0), lambda: f"Shape mismatch: weight_marlin.size(0) = {weight_marlin.size(0)}, size_k = {size_k}, tile_size = {TILE_SIZE}") + torch._check( + size_k == x.size(1), + lambda: f"Shape mismatch: x.size(1) = {x.size(1)}, size_k = {size_k}", + ) + torch._check( + size_k % TILE_SIZE == 0, + lambda: f"size_k = {size_k} is not divisible by tile_size = {TILE_SIZE}", + ) + torch._check( + (size_k // TILE_SIZE // 2) == weight_marlin.size(0), + lambda: f"Shape mismatch: weight_marlin.size(0) = {weight_marlin.size(0)}, size_k = {size_k}, tile_size = {TILE_SIZE}", + ) # Verify N - torch._check(s.size(1) == size_n, lambda: f"s.size(1) = {s.size(1)}, size_n = {size_n}") - torch._check(weight_marlin.size(1) % TILE_SIZE == 0, lambda: f"weight_marlin.size(1) = {weight_marlin.size(1)} is not divisible by tile_size = {TILE_SIZE}") + torch._check( + s.size(1) == size_n, lambda: f"s.size(1) = {s.size(1)}, size_n = {size_n}" + ) + torch._check( + weight_marlin.size(1) % TILE_SIZE == 0, + lambda: f"weight_marlin.size(1) = {weight_marlin.size(1)} is not divisible by tile_size = {TILE_SIZE}", + ) actual_size_n = (weight_marlin.size(1) // TILE_SIZE) * pack_factor - torch._check(size_n == actual_size_n, lambda: f"size_n = {size_n}, actual_size_n = {actual_size_n}") + torch._check( + size_n == actual_size_n, + lambda: f"size_n = {size_n}, actual_size_n = {actual_size_n}", + ) # Verify meta - torch._check(meta.size(0) == size_k // 8 // 2 // 2, lambda: f"meta.size(0) = {meta.size(0)} is not size_k / 8 / 2 / 2 = {size_k // 8 // 2 // 2}") - torch._check(meta.size(1) == size_n * 2, lambda: f"meta.size(1) = {meta.size(1)} is not size_n * 2 = {size_n * 2}") + torch._check( + meta.size(0) == size_k // 8 // 2 // 2, + lambda: f"meta.size(0) = {meta.size(0)} is not size_k / 8 / 2 / 2 = {size_k // 8 // 2 // 2}", + ) + torch._check( + meta.size(1) == size_n * 2, + lambda: f"meta.size(1) = {meta.size(1)} is not size_n * 2 = {size_n * 2}", + ) # Verify A device and strides torch._check(x.is_cuda, lambda: "x is not on GPU") @@ -252,7 +329,9 @@ def _( # Verify B device and strides torch._check(weight_marlin.is_cuda, lambda: "weight_marlin is not on GPU") - torch._check(weight_marlin.is_contiguous(), lambda: "weight_marlin is not contiguous") + torch._check( + weight_marlin.is_contiguous(), lambda: "weight_marlin is not contiguous" + ) # Verify meta device and strides torch._check(meta.is_cuda, lambda: "meta is not on GPU") @@ -265,15 +344,27 @@ def _( # Verify groupsize groupsize = -1 if s.size(0) > 1: - torch._check(size_k % s.size(0) == 0, lambda: f"size_k = {size_k} is not divisible by s.size(0) = {s.size(0)}") + torch._check( + size_k % s.size(0) == 0, + lambda: f"size_k = {size_k} is not divisible by s.size(0) = {s.size(0)}", + ) groupsize = size_k // s.size(0) groupsize //= 2 # Because of 24 - torch._check(groupsize == -1 or groupsize == 64, lambda: f"Unexpected groupsize = {groupsize}") + torch._check( + groupsize == -1 or groupsize == 64, + lambda: f"Unexpected groupsize = {groupsize}", + ) # Verify workspace size - torch._check(size_n % MIN_THREAD_N == 0, lambda: f"size_n = {size_n} is not divisible by min_thread_n = {MIN_THREAD_N}") + torch._check( + size_n % MIN_THREAD_N == 0, + lambda: f"size_n = {size_n} is not divisible by min_thread_n = {MIN_THREAD_N}", + ) min_workspace_size = (size_n // MIN_THREAD_N) * MAX_PARALLELISM - torch._check(workspace.numel() >= min_workspace_size, lambda: f"workspace.numel = {workspace.numel()} is below min_workspace_size = {min_workspace_size}") + torch._check( + workspace.numel() >= min_workspace_size, + lambda: f"workspace.numel = {workspace.numel()} is below min_workspace_size = {min_workspace_size}", + ) return torch.empty((x.size(0), s.size(1)), dtype=x.dtype, device=x.device) diff --git a/torchao/prototype/hqq/README.md b/torchao/prototype/hqq/README.md index 8bf1d34260..1bdbcd96e1 100644 --- a/torchao/prototype/hqq/README.md +++ b/torchao/prototype/hqq/README.md @@ -83,7 +83,7 @@ Initial benchmarking (on `A6000`) demonstrates promising results, scaling well f - Times are in `ms`, see `benchmarks/benchmark_hqq.py`. - `hqq_ref` is the base `HQQ_Linear` [module](https://github.com/mobiusml/hqq/blob/6d50eee4bcdd99cc10716f1297c5b2803d2b6da4/hqq/core/quantize.py#L349) that is unfused (dequantization followed by call to torch.matmul). -- `tinygemm` calls `torch.ops.aten._weight_int4pack_mm`. Implementation is a custom HQQLinear layer that wraps the preprocessing necessary for this kernel, adapted from a benchmark script posted by @mobicham from `CUDA-mode` Discord discussions. +- `tinygemm` calls `torch.ops.aten._weight_int4pack_mm` or `torch.ops.aten._weight_int4pack_mm_for_cpu`. Implementation is a custom HQQLinear layer that wraps the preprocessing necessary for this kernel, adapted from a benchmark script posted by @mobicham from `CUDA-mode` Discord discussions. GPU details: diff --git a/torchao/prototype/hqq/hqq_tinygemm_linear.py b/torchao/prototype/hqq/hqq_tinygemm_linear.py index 8abdad039a..743c6128a7 100644 --- a/torchao/prototype/hqq/hqq_tinygemm_linear.py +++ b/torchao/prototype/hqq/hqq_tinygemm_linear.py @@ -12,7 +12,8 @@ from hqq.core.utils import * import torch.nn.functional as F -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6 +from torchao.dtypes.utils import is_device class HQQLinearTorchWeightOnlyInt4(torch.nn.Module): @@ -162,9 +163,14 @@ def process_hqq_quants(self, W_q, meta): W_q_torch, scales_torch, zeros_torch = self.hqq_quants_to_torch_quants( W_q=W_q, scales=scales, zeros=zeros, shape=shape, nbits=self.nbits ) - self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( - W_q_torch, self.inner_k_tiles - ) + if is_device(W_q.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu( + W_q_torch, self.inner_k_tiles + ) + else: + self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( + W_q_torch, self.inner_k_tiles + ) self.scales_and_zeros = self.pack_scales_and_zeros(scales_torch, zeros_torch) del W_q_torch, scales_torch, zeros_torch @@ -200,7 +206,8 @@ def hqq_quants_to_torch_quants( .contiguous() ) if TORCH_VERSION_AT_LEAST_2_5: - W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8) + if not is_device(W_q.device.type, "cpu"): + W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8) # group_dequantize_tensor_from_qparams # W_r = W_q*scales + min_val @@ -232,9 +239,14 @@ def pack_scales_and_zeros(self, scales, zeros): def matmul(self, x): origin_x_size = x.size() x = x.reshape(-1, origin_x_size[-1]) - c = torch.ops.aten._weight_int4pack_mm( - x, self.weight_int4pack, self.groupsize, self.scales_and_zeros - ) + if is_device(x.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + c = torch.ops.aten._weight_int4pack_mm_for_cpu( + x, self.weight_int4pack, self.groupsize, self.scales_and_zeros + ) + else: + c = torch.ops.aten._weight_int4pack_mm( + x, self.weight_int4pack, self.groupsize, self.scales_and_zeros + ) new_shape = origin_x_size[:-1] + (self.out_features,) c = c.reshape(new_shape) return c diff --git a/torchao/profiler/__init__.py b/torchao/prototype/profiler/__init__.py similarity index 99% rename from torchao/profiler/__init__.py rename to torchao/prototype/profiler/__init__.py index e748438e87..976d4e3a05 100644 --- a/torchao/profiler/__init__.py +++ b/torchao/prototype/profiler/__init__.py @@ -1,4 +1,3 @@ - # Re-exports from .device_spec import CUDADeviceSpec, DeviceSpec from .performance_counter import ( @@ -20,4 +19,3 @@ "DeviceSpec", "total_model_params", ] - diff --git a/torchao/profiler/device_spec.py b/torchao/prototype/profiler/device_spec.py similarity index 100% rename from torchao/profiler/device_spec.py rename to torchao/prototype/profiler/device_spec.py diff --git a/torchao/profiler/performance_counter.py b/torchao/prototype/profiler/performance_counter.py similarity index 100% rename from torchao/profiler/performance_counter.py rename to torchao/prototype/profiler/performance_counter.py diff --git a/torchao/profiler/utils.py b/torchao/prototype/profiler/utils.py similarity index 100% rename from torchao/profiler/utils.py rename to torchao/prototype/profiler/utils.py diff --git a/torchao/prototype/quantization/autoquant_v2.py b/torchao/prototype/quantization/autoquant_v2.py index a11fe861e4..977c1fd288 100644 --- a/torchao/prototype/quantization/autoquant_v2.py +++ b/torchao/prototype/quantization/autoquant_v2.py @@ -30,7 +30,9 @@ from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5, - benchmark_model, + TorchAOBaseTensor, + is_sm_at_least_89, + is_sm_at_least_90, ) from torchao.quantization.granularity import ( @@ -61,7 +63,9 @@ "autoquant_v2", "DEFAULT_AUTOQUANT_CLASS_LIST", "DEFAULT_INT4_AUTOQUANT_CLASS_LIST", + "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST", "OTHER_AUTOQUANT_CLASS_LIST", + "ALL_AUTOQUANT_CLASS_LIST", "_is_linear", ] @@ -288,7 +292,7 @@ def to_quantized(self, error_on_unseen, **kwargs): ) elif (self.logged_data == {}) and not error_on_unseen: # default back to non-quantized weight if not seen - self = AQFloatLinearWeight.from_float(self.weight) + self = AQDefaultLinearWeight.from_float(self.weight) return self # only want to print shape (at start) and final result (at end) @@ -360,7 +364,7 @@ def count_shapes(self, do_print=True): print(f"best_cls={best_cls}\n") # TODO handle random cls args/kwargs? or should they be curried? if best_cls is None: - best_cls = AQFloatLinearWeight + best_cls = AQDefaultLinearWeight self = best_cls.from_float(self.weight) return self @@ -802,7 +806,7 @@ class AQInt4G256WeightOnlyQuantizedLinearWeight( group_size: int = 256 -class AQFloatLinearWeight(torch.Tensor, AQMixin): +class AQDefaultLinearWeight(torch.Tensor, AQMixin): """ A class to be used in concert with AutoQuantizableLinearWeight to provide a default/non-quantized option. Only implements the bare minimum needed to work with the @@ -823,6 +827,130 @@ def from_float(cls, weight): return weight +class Float32Tensor(TorchAOBaseTensor): + """ Tensor subclass tensor for fp32 dtype + """ + def __init__(self, weight): + self.weight = weight.to(torch.float32) + + @staticmethod + def _quantized_linear_op(act_mat, w_qtensor, bias): + _DTYPE = torch.float32 + orig_dtype = act_mat.dtype + return torch.nn.functional.linear( + act_mat.to(_DTYPE), + w_qtensor.weight, + bias.to(_DTYPE) if bias is not None else bias, + ).to(dtype=orig_dtype) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.weight), + ) + + @classmethod + def from_float(cls, weight): + return cls(weight) + +@Float32Tensor.implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) + +@Float32Tensor.implements(aten.detach.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + +@Float32Tensor.implements(aten.clone.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + +@Float32Tensor.implements(aten._to_copy.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), + ) + + +class BFloat16Tensor(Float32Tensor): + def __init__(self, weight): + self.weight = weight.to(torch.bfloat16) + + @staticmethod + def _quantized_linear_op(act_mat, w_qtensor, bias): + _DTYPE = torch.bfloat16 + orig_dtype = act_mat.dtype + return torch.nn.functional.linear( + act_mat.to(_DTYPE), + w_qtensor.weight, + bias.to(_DTYPE) if bias is not None else bias, + ).to(dtype=orig_dtype) + + +class Float16Tensor(Float32Tensor): + def __init__(self, weight): + self.weight = weight.to(torch.float16) + + @staticmethod + def _quantized_linear_op(act_mat, w_qtensor, bias): + _DTYPE = torch.float16 + orig_dtype = act_mat.dtype + return torch.nn.functional.linear( + act_mat.to(_DTYPE), + w_qtensor.weight, + bias.to(_DTYPE) if bias is not None else bias, + ).to(dtype=orig_dtype) + + +class AQFloat32LinearWeight(Float32Tensor, AQMixin): + """ + AutoQuantizable version for float32 precision weight + + (also converts input activation and bias to float32, and restores the original precision after + linear) + """ + @classmethod + def from_float(cls, weight): + return super(AQFloat32LinearWeight, cls).from_float(weight) + + +class AQBFloat16LinearWeight(BFloat16Tensor, AQMixin): + """ + AutoQuantizable version for bfloat16 precision weight + + (also converts input activation and bias to bfloat16, and restores the original precision after + linear) + """ + @classmethod + def from_float(cls, weight): + return super(AQBFloat16LinearWeight, cls).from_float(weight) + + +class AQFloat16LinearWeight(Float16Tensor, AQMixin): + """ + AutoQuantizable version for float16 precision weight + + (also converts input activation and bias to float16, and restores the original precision after + linear) + """ + @classmethod + def from_float(cls, weight): + return super(AQFloat16LinearWeight, cls).from_float(weight) + + class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): """ AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight for target_dtype=torch.float8_e4m3fn @@ -936,7 +1064,7 @@ def get_weight_block_size(x): # here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison DEFAULT_AUTOQUANT_CLASS_LIST = [ - AQFloatLinearWeight, + AQDefaultLinearWeight, AQInt8WeightOnlyQuantizedLinearWeight, AQInt8WeightOnlyQuantizedLinearWeight2, # AQInt8WeightOnlyQuantizedLinearWeight3, @@ -945,17 +1073,30 @@ def get_weight_block_size(x): ] DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [ - AQFloatLinearWeight, + AQDefaultLinearWeight, AQInt8DynamicallyQuantizedLinearWeight, AQInt4G64WeightOnlyQuantizedLinearWeight, ] +DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST = [ + AQFloat32LinearWeight, + AQBFloat16LinearWeight, + AQFloat16LinearWeight, +] + OTHER_AUTOQUANT_CLASS_LIST = [ AQFloat8WeightOnlyQuantizedLinearWeight, AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, ] +ALL_AUTOQUANT_CLASS_LIST = list(set(DEFAULT_AUTOQUANT_CLASS_LIST + DEFAULT_INT4_AUTOQUANT_CLASS_LIST + DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST)) +if is_sm_at_least_89(): + ALL_AUTOQUANT_CLASS_LIST += [AQFloat8WeightOnlyQuantizedLinearWeight, AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight] + +if is_sm_at_least_90(): + ALL_AUTOQUANT_CLASS_LIST += [AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight] + def _replace_with_custom_fn_if_matches_filter( model, diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index dc68f59ceb..cb7c8d0481 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -17,8 +17,10 @@ import torch.nn.functional as F from torch.utils._pytree import tree_flatten, tree_unflatten +from torchao.dtypes.utils import is_device from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, + TORCH_VERSION_AT_LEAST_2_6, find_multiple, ) @@ -537,12 +539,20 @@ def linear_forward_int4( ): origin_x_size = x.size() x = x.reshape(-1, origin_x_size[-1]) - c = torch.ops.aten._weight_int4pack_mm( - x.to(precision), - weight_int4pack, - groupsize, - scales_and_zeros.to(scales_precision), - ).to(dtype=x.dtype) + if is_device(x.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + c = torch.ops.aten._weight_int4pack_mm_for_cpu( + x.to(precision), + weight_int4pack, + groupsize, + scales_and_zeros.to(scales_precision), + ).to(dtype=x.dtype) + else: + c = torch.ops.aten._weight_int4pack_mm( + x.to(precision), + weight_int4pack, + groupsize, + scales_and_zeros.to(scales_precision), + ).to(dtype=x.dtype) new_shape = origin_x_size[:-1] + (out_features,) c = c.reshape(new_shape) return c @@ -570,8 +580,6 @@ def __init__( super().__init__() self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles) if self.padding: - from .utils import find_multiple - self.origin_in_features = in_features in_features = find_multiple(in_features, 1024) @@ -591,19 +599,32 @@ def __init__( assert ( in_features % (inner_k_tiles * 16) == 0 ), "require in_features % (innerKTiles * 16) == 0" - self.register_buffer( - "weight", - torch.zeros( - ( - out_features // 8, - in_features // (inner_k_tiles * 16), - 32, - inner_k_tiles // 2, + if is_device(device.type, "cpu"): + self.register_buffer( + "weight", + torch.zeros( + ( + out_features, + in_features // 2, + ), + dtype=torch.uint8, + device=device, ), - dtype=torch.int32, - device=device, - ), - ) + ) + else: + self.register_buffer( + "weight", + torch.zeros( + ( + out_features // 8, + in_features // (inner_k_tiles * 16), + 32, + inner_k_tiles // 2, + ), + dtype=torch.int32, + device=device, + ), + ) self.dtype = dtype self.register_buffer( "scales_and_zeros", @@ -738,8 +759,6 @@ def _create_quantized_state_dict( if self.padding_allowed: import torch.nn.functional as F - from .utils import find_multiple - logging.warn( f"warning: {fqn} is padded to satisfy in_features % 1024 == 0" ) @@ -760,9 +779,19 @@ def _create_quantized_state_dict( self.precision, # dtype for scales_and_zeros ) # TODO: just get the device from mod.weight.device? - weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( - w_int4x8.to(self.device), self.inner_k_tiles - ) + if ( + is_device(w_int4x8.device.type, "cpu") + and TORCH_VERSION_AT_LEAST_2_6 + ): + weight_int4pack = ( + torch.ops.aten._convert_weight_to_int4pack_for_cpu( + w_int4x8.to(self.device), self.inner_k_tiles + ) + ) + else: + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( + w_int4x8.to(self.device), self.inner_k_tiles + ) cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to(self.device) cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to( self.device @@ -846,9 +875,14 @@ def make_names_and_values_dict_func(q, qparams): # how much we need to pad the weight delta_k = int((new_k - k) / 2) q = q.to(self.device) - final_q = torch.ops.aten._convert_weight_to_int4pack( - F.pad(q, pad=(0, delta_k)), inner_k_tiles - ) + if is_device(self.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + final_q = torch.ops.aten._convert_weight_to_int4pack_for_cpu( + F.pad(q, pad=(0, delta_k)), inner_k_tiles + ) + else: + final_q = torch.ops.aten._convert_weight_to_int4pack( + F.pad(q, pad=(0, delta_k)), inner_k_tiles + ) scales = qparams[0].to(torch.bfloat16).to(self.device) zeros = qparams[1].to(torch.bfloat16).to(self.device) scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros) @@ -1113,8 +1147,6 @@ def _create_quantized_state_dict( if self.padding_allowed: import torch.nn.functional as F - from .utils import find_multiple - logging.warn( f"warning: {fqn} is padded to satisfy in_features % 1024 == 0" ) diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 022fe7d916..3fc2cb5ef0 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -3,7 +3,7 @@ Typically quantization algorithms will have different schemes for how the activa ## Benchmarks Benchmarks and evaluation are run on a machine with a single NVIDIA-A100-80GB GPU using the scripts for [generation](../_models/llama/generate.py) and [eval](../_models/llama/eval.py). Evaluation was done using the lm_eval library for tasks/data. The models used were meta-llama/Llama-2-7b-chat-hf and meta-llama/Meta-Llama-3-8B. - +### CUDA backend | Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) | | ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- | | Llama-2-7B | Base (bfloat16) | 12.212 | 107.38 | 1418.93 | 13.88 | 13.21 | @@ -20,9 +20,16 @@ Benchmarks and evaluation are run on a machine with a single NVIDIA-A100-80GB GP | | int4wo-64 | 8.316 | 180.80 | 763.33 | 6.88 | 4.22 | | | int4wo-64-GPTQ | 7.921 | 180.80 | 763.33 | 6.88 | 4.22 | | | autoquant-int4hqq | 8.110 | 188.41 | 800.58 | 7.14 | 4.25 | +### XPU backend +| Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) | +| ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- | +| Llama-2-7B | Base (bfloat16) | NA | 42.20 | 557.71 | 13.89 | 13.21 | +| | int8dq | NA | 9.87 | 65.35 | 14.60 | 6.62 | +| | int8wo | NA | 66.24 | 438.61 | 14.60 | 6.62 -Benchmarks and evaluation for model meta-llama/Meta-Llama-3.1-8B are run on a machine with a single NVIDIA-H100 GPU using the scripts for [generation](../_models/llama/generate.py) and [eval](../_models/llama/eval.py). Evaluation was done using the lm_eval library for tasks/data. + +### CUDA backend | Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) | | ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- | | Llama-3.1-8B | Base (bfloat16) | 7.54 | 126.90 | 1904.75 | 16.75 | 15.01 | @@ -31,6 +38,15 @@ Benchmarks and evaluation for model meta-llama/Meta-Llama-3.1-8B are run on a ma | | float8wo | 7.60 | 178.46 | 1339.93 | 12.09 | 7.51 | | | float8dq (PerTensor) | 7.62 | 116.40 | 873.58 | 11.14 | 7.51 | | | float8dq (Per Row) | 7.61 | 154.63 | 1161.47 | 11.14 | 7.51 | +### XPU backend +| Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) | +| ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- | +| Llama-3-8.1B | Base (bfloat16) | 7.441 | 40.36 | 605.77 | 16.35 | 15.01 | +| | int8dq | 7.581 | 13.60 | 102.28 | 18.69 | 7.52 | +| | int8wo | 7.447 | 59.49 | 447.27 | 18.60 | 7.52 + + +Benchmarks and evaluation for model meta-llama/Meta-Llama-3.1-8B are run on a machine with a single NVIDIA-H100 GPU or Intel-Max1100 using the scripts for [generation](../_models/llama/generate.py) and [eval](../_models/llama/eval.py). Evaluation was done using the lm_eval library for tasks/data. note: Int8 dynamic quantization works best on compute bound models like [SAM](https://github.com/pytorch-labs/segment-anything-fast) whereas Llama with batchsize=1 tends to be memory bound, thus the rather low performance. @@ -333,7 +349,16 @@ We're trying to develop kernels for low bit quantization for intx quantization f You try can out these apis with the `quantize_` api as above alongside the constructor `uintx_weight_only` an example can be found in in `torchao/_models/llama/generate.py`. +### int8_dynamic_activation_intx_weight Quantization +We have kernels that do 8-bit dynamic quantization of activations and uintx groupwise quantization of weights. These kernels are experimental and can only be run on a device with an ARM CPU (e.g., a Mac computers with Apple silicon). The benchmarks below were run on an M1 Mac Pro, with 8 perf cores, and 2 efficiency cores, and 32GB of RAM. In all cases, torch.compile was used. + +| Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) | +| ------------- | -------------------------------------------------| --------------| ------------------------| ---------------- | ----------------| +| Llama-3.1-8B | Base (bfloat16) | 1.24 | 18.62 | NA | 15.01 | +| | int8_dynamic_activation_intx_weight-4-256-false | 16.03 | 65.81 | NA | 4.11 | +| | int8_dynamic_activation_intx_weight-3-256-false | 18.94 | 59.97 | NA | 3.17 | +You try can out these apis with the `quantize_` api as above alongside the constructor `int8_dynamic_activation_intx_weight`. An example can be found in `torchao/_models/llama/generate.py`. ### Automatic Inductor Configuration The `quantize_` and `autoquant` apis now automatically use our recommended inductor configuration setings. You can mimic the same configuration settings for your own experiments by using the `torchao.quantization.utils.recommended_inductor_config_setter` to replicate our recommended configuration settings. Alternatively if you wish to disable these recommended settings, you can use the key word argument `set_inductor_config` and set it to false in the `quantize_` or `autoquant` apis to prevent assignment of those configuration settings. You can also overwrite these configuration settings after they are assigned if you so desire, as long as they are overwritten before passing any inputs to the torch.compiled model. This means that previous flows which referenced a variety of inductor configurations that needed to be set are now outdated, though continuing to manually set those same inductor configurations is unlikely to cause any issues. diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index ff66e23cc9..14dfbab52b 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -10,8 +10,11 @@ ) from .autoquant import ( + ALL_AUTOQUANT_CLASS_LIST, DEFAULT_AUTOQUANT_CLASS_LIST, + DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, DEFAULT_INT4_AUTOQUANT_CLASS_LIST, + DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST, OTHER_AUTOQUANT_CLASS_LIST, autoquant, ) @@ -89,7 +92,10 @@ "autoquant", "DEFAULT_AUTOQUANT_CLASS_LIST", "DEFAULT_INT4_AUTOQUANT_CLASS_LIST", + "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST", + "DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST", "OTHER_AUTOQUANT_CLASS_LIST", + "ALL_AUTOQUANT_CLASS_LIST", # top level API - manual "quantize_", "int8_dynamic_activation_int4_weight", diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 87cb5e2655..b8cd0125f0 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -6,9 +6,12 @@ from torchao.dtypes import ( AffineQuantizedTensor, Float8Layout, + MarlinSparseLayout, PlainLayout, + SemiSparseLayout, TensorCoreTiledLayout, ) +from torchao.dtypes.utils import Layout from torchao.float8.inference import Float8MMConfig from torchao.kernel import safe_int_mm from torchao.quantization.linear_activation_quantized_tensor import ( @@ -18,8 +21,17 @@ MappingType, ZeroPointDomain, ) -from torchao.quantization.utils import quantize_activation_per_token_absmax -from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5 +from torchao.quantization.utils import ( + compute_error, + quantize_activation_per_token_absmax, +) +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_3, + TORCH_VERSION_AT_LEAST_2_5, + TorchAOBaseTensor, + is_sm_at_least_89, + is_sm_at_least_90, +) from .granularity import ( PerRow, @@ -36,7 +48,10 @@ "autoquant", "DEFAULT_AUTOQUANT_CLASS_LIST", "DEFAULT_INT4_AUTOQUANT_CLASS_LIST", + "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST", + "DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST", "OTHER_AUTOQUANT_CLASS_LIST", + "ALL_AUTOQUANT_CLASS_LIST", ] @@ -69,7 +84,15 @@ class AutoQuantizableLinearWeight(torch.Tensor): """ @staticmethod - def __new__(cls, weight, qtensor_class_list, *args, mode=["relu", None], **kwargs): + def __new__( + cls, + weight, + qtensor_class_list, + *args, + mode=["relu", None], + min_sqnr=None, + **kwargs, + ): kwargs["device"] = weight.device kwargs["layout"] = ( kwargs.get("layout") if kwargs.get("layout", False) else weight.layout @@ -82,12 +105,19 @@ def __new__(cls, weight, qtensor_class_list, *args, mode=["relu", None], **kwarg return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] def __init__( - self, weight, qtensor_class_list, *args, mode=["relu", None], **kwargs + self, + weight, + qtensor_class_list, + *args, + mode=["relu", None], + min_sqnr=None, + **kwargs, ): self.weight = weight self.qtensor_class_list = qtensor_class_list self.logged_data = {} self.mode = mode + self.min_sqnr = min_sqnr def __repr__(self): return ( @@ -123,9 +153,25 @@ def tune_autoquant(self, q_cls, shapes_and_dtype, best_time): else torch.randn(bias_shape, dtype=act_dtype, device=self.device) ) try: - res = q_cls._autoquant_test( - act_mat, self.weight, bias, best_time, self.mode + ref_output = AQDefaultLinearWeight._quantized_linear_op( + act_mat, self.weight, bias ) + q_output = q_cls._quantized_linear_op( + act_mat, q_cls.from_float(self.weight), bias + ) + if ( + self.min_sqnr is not None + and (sqnr := compute_error(q_output, ref_output)) + < self.min_sqnr + ): + print( + f"skipping q_cls: {q_cls} because the sqnr is too small, minimum expected sqnr: {self.min_sqnr}, got {sqnr}" + ) + res = torch.inf + else: + res = q_cls._autoquant_test( + act_mat, self.weight, bias, best_time, self.mode + ) except Exception as e: print( f"warning: failed to autoquant {q_cls.__name__} for shape: {shapes_and_dtype} due to {e}" @@ -141,7 +187,7 @@ def to_quantized(self, error_on_unseen, **kwargs): ) elif (self.logged_data == {}) and not error_on_unseen: # default back to non-quantized weight if not seen - self = AQFloatLinearWeight.from_float(self.weight) + self = AQDefaultLinearWeight.from_float(self.weight) return self # only want to print shape (at start) and final result (at end) @@ -194,34 +240,49 @@ def count_shapes(self, do_print=True): print( f">time (all shapes): {cur_time:0.4f}ms for {q_cls}, prev_best: {best_time:0.4f}ms" ) - if best_time >= cur_time: + if cur_time != torch.inf and best_time >= cur_time: best_time = cur_time best_cls = q_cls # if no new benchmarking was done, don't print the final result, it will be the same as for another layer if ran_new_benchmarks: print(f"best_cls={best_cls}\n") + + if best_cls is None: + best_cls = AQDefaultLinearWeight + # TODO handle random cls args/kwargs? or should they be curried? self = best_cls.from_float(self.weight) return self def _apply_fn_to_data(self, fn): return self.__class__( - fn(self.weight), self.qtensor_class_list, dtype=self.dtype, mode=self.mode + fn(self.weight), + self.qtensor_class_list, + dtype=self.dtype, + mode=self.mode, + min_sqnr=self.min_sqnr, ) def __tensor_flatten__(self): - return ["weight"], [self.qtensor_class_list, self.mode, self.dtype, self.shape] + return ["weight"], [ + self.qtensor_class_list, + self.mode, + self.min_sqnr, + self.dtype, + self.shape, + ] @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None ): weight = tensor_data_dict["weight"] - qtensor_class_list, mode, dtype, shape = tensor_attributes + qtensor_class_list, mode, min_sqnr, dtype, shape = tensor_attributes return cls( weight, qtensor_class_list, - mode, + mode=mode, + min_sqnr=min_sqnr, shape=shape if outer_size is None else outer_size, dtype=dtype, strides=outer_stride, @@ -349,6 +410,8 @@ class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedT AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight """ + layout: Layout = PlainLayout() + @classmethod def from_float(cls, weight): # TODO test if this is valid @@ -357,6 +420,9 @@ def from_float(cls, weight): # if in_features <= 16: # return weight + if weight.dim() != 2: + return weight + # avoid circular dep from torchao.dtypes import to_affine_quantized_intx @@ -382,7 +448,7 @@ def get_per_token_block_size(x): input_eps = 1e-5 input_quant_min = -127 input_quant_max = 127 - _layout = PlainLayout() + _layout = cls.layout input_quant_func = lambda x: to_affine_quantized_intx( x, input_mapping_type, @@ -469,6 +535,16 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): return res_f +class AQInt8DynamicallyQuantizedSemiSparseLinearWeight( + AQInt8DynamicallyQuantizedLinearWeight +): + layout: Layout = SemiSparseLayout() + + @classmethod + def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): + return super()._autoquant_test(act_mat, weight, bias, best_time, None) + + class AQInt8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): """ AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight @@ -556,14 +632,16 @@ class AQInt4G32WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): """ group_size: int = 32 + layout: Layout = TensorCoreTiledLayout(inner_k_tiles=8) @classmethod def from_float(cls, weight): group_size = cls.group_size - _layout = TensorCoreTiledLayout(inner_k_tiles=8) + _layout = cls.layout if weight.shape[-1] % group_size != 0: return weight + use_hqq = True mapping_type = MappingType.ASYMMETRIC block_size = (1, group_size) @@ -574,6 +652,13 @@ def from_float(cls, weight): preserve_zero = False zero_point_dtype = torch.bfloat16 zero_point_domain = ZeroPointDomain.FLOAT + + if isinstance(_layout, MarlinSparseLayout): + mapping_type = MappingType.SYMMETRIC + preserve_zero = True + zero_point_domain = ZeroPointDomain.INT + use_hqq = False + return super(AQInt4G32WeightOnlyQuantizedLinearWeight, cls).from_hp_to_intx( weight, mapping_type, @@ -608,7 +693,14 @@ class AQInt4G256WeightOnlyQuantizedLinearWeight( group_size: int = 256 -class AQFloatLinearWeight(torch.Tensor, AQMixin): +class AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight( + AQInt4G32WeightOnlyQuantizedLinearWeight +): + group_size: int = 128 + layout: Layout = MarlinSparseLayout() + + +class AQDefaultLinearWeight(torch.Tensor, AQMixin): """ A class to be used in concert with AutoQuantizableLinearWeight to provide a default/non-quantized option. Only implements the bare minimum needed to work with the @@ -629,6 +721,135 @@ def from_float(cls, weight): return weight +class Float32Tensor(TorchAOBaseTensor): + """Tensor subclass tensor for fp32 dtype""" + + def __init__(self, weight): + self.weight = weight.to(torch.float32) + + @staticmethod + def _quantized_linear_op(act_mat, w_qtensor, bias): + _DTYPE = torch.float32 + orig_dtype = act_mat.dtype + return torch.nn.functional.linear( + act_mat.to(_DTYPE), + w_qtensor.weight, + bias.to(_DTYPE) if bias is not None else bias, + ).to(dtype=orig_dtype) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.weight), + ) + + @classmethod + def from_float(cls, weight): + return cls(weight) + + +@Float32Tensor.implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) + + +@Float32Tensor.implements(aten.detach.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + +@Float32Tensor.implements(aten.clone.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + +@Float32Tensor.implements(aten._to_copy.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), + ) + + +class BFloat16Tensor(Float32Tensor): + def __init__(self, weight): + self.weight = weight.to(torch.bfloat16) + + @staticmethod + def _quantized_linear_op(act_mat, w_qtensor, bias): + _DTYPE = torch.bfloat16 + orig_dtype = act_mat.dtype + return torch.nn.functional.linear( + act_mat.to(_DTYPE), + w_qtensor.weight, + bias.to(_DTYPE) if bias is not None else bias, + ).to(dtype=orig_dtype) + + +class Float16Tensor(Float32Tensor): + def __init__(self, weight): + self.weight = weight.to(torch.float16) + + @staticmethod + def _quantized_linear_op(act_mat, w_qtensor, bias): + _DTYPE = torch.float16 + orig_dtype = act_mat.dtype + return torch.nn.functional.linear( + act_mat.to(_DTYPE), + w_qtensor.weight, + bias.to(_DTYPE) if bias is not None else bias, + ).to(dtype=orig_dtype) + + +class AQFloat32LinearWeight(Float32Tensor, AQMixin): + """ + AutoQuantizable version for float32 precision weight + + (also converts input activation and bias to float32, and restores the original precision after + linear) + """ + + @classmethod + def from_float(cls, weight): + return super(AQFloat32LinearWeight, cls).from_float(weight) + + +class AQBFloat16LinearWeight(BFloat16Tensor, AQMixin): + """ + AutoQuantizable version for bfloat16 precision weight + + (also converts input activation and bias to bfloat16, and restores the original precision after + linear) + """ + + @classmethod + def from_float(cls, weight): + return super(AQBFloat16LinearWeight, cls).from_float(weight) + + +class AQFloat16LinearWeight(Float16Tensor, AQMixin): + """ + AutoQuantizable version for float16 precision weight + + (also converts input activation and bias to float16, and restores the original precision after + linear) + """ + + @classmethod + def from_float(cls, weight): + return super(AQFloat16LinearWeight, cls).from_float(weight) + + class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): """ AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight for target_dtype=torch.float8_e4m3fn @@ -742,7 +963,7 @@ def get_weight_block_size(x): # here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison DEFAULT_AUTOQUANT_CLASS_LIST = [ - AQFloatLinearWeight, + AQDefaultLinearWeight, AQInt8WeightOnlyQuantizedLinearWeight, AQInt8WeightOnlyQuantizedLinearWeight2, # AQInt8WeightOnlyQuantizedLinearWeight3, @@ -751,17 +972,47 @@ def get_weight_block_size(x): ] DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [ - AQFloatLinearWeight, + AQDefaultLinearWeight, AQInt8DynamicallyQuantizedLinearWeight, AQInt4G64WeightOnlyQuantizedLinearWeight, ] +DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST = [ + AQFloat32LinearWeight, + AQBFloat16LinearWeight, + AQFloat16LinearWeight, +] + OTHER_AUTOQUANT_CLASS_LIST = [ + AQDefaultLinearWeight, AQFloat8WeightOnlyQuantizedLinearWeight, AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, ] +DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST = [ + AQDefaultLinearWeight, + AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight, + AQInt8DynamicallyQuantizedSemiSparseLinearWeight, +] + +ALL_AUTOQUANT_CLASS_LIST = list( + set( + DEFAULT_AUTOQUANT_CLASS_LIST + + DEFAULT_INT4_AUTOQUANT_CLASS_LIST + + DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST + + DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST + ) +) +if is_sm_at_least_89(): + ALL_AUTOQUANT_CLASS_LIST += [ + AQFloat8WeightOnlyQuantizedLinearWeight, + AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, + ] + +if is_sm_at_least_90(): + ALL_AUTOQUANT_CLASS_LIST += [AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight] + def _change_linears_to_autoquantizable(model, **kwargs): """ @@ -779,6 +1030,7 @@ def _change_linears_to_autoquantizable(model, **kwargs): "qtensor_class_list", DEFAULT_AUTOQUANT_CLASS_LIST ) kwargs["mode"] = kwargs.get("mode", ["relu", None]) + kwargs["min_sqnr"] = kwargs.get("min_sqnr", None) from torchao.quantization.quant_api import ( _get_subclass_inserter, _replace_with_custom_fn_if_matches_filter, @@ -853,6 +1105,7 @@ def autoquant( manual=False, set_inductor_config=True, supress_autoquant_errors=True, + min_sqnr=None, **aq_kwargs, ): """ @@ -887,6 +1140,9 @@ def autoquant( the user to call model.finalize_autoquant (True) so inputs with several shapes/dtypes can be logged. set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True) supress_autoquant_errors (bool, optional): Whether to suppress errors during autoquantization. (defaults to True) + min_sqnr (float, optional): minimum acceptable signal to quantization noise ration (https://en.wikipedia.org/wiki/Signal-to-quantization-noise_ratio) for output of quantized layer v.s. non-quantized layer, this is used to filter + out quantization methods that causes too large numerical impact, user can start with a resaonable + number like 40 and adjust depending on the result **aq_kwargs: Additional keyword arguments for the autoquantization process. Returns: @@ -919,6 +1175,7 @@ def autoquant( filter_fn=filter_fn, qtensor_class_list=qtensor_class_list, mode=mode, + min_sqnr=min_sqnr, **aq_kwargs, ) diff --git a/torchao/quantization/linear_activation_quantized_tensor.py b/torchao/quantization/linear_activation_quantized_tensor.py index 46b48393a3..e86b2f8e64 100644 --- a/torchao/quantization/linear_activation_quantized_tensor.py +++ b/torchao/quantization/linear_activation_quantized_tensor.py @@ -147,8 +147,8 @@ def _(func, types, args, kwargs): ) input_quant_func = weight_tensor.input_quant_func original_weight_tensor = weight_tensor.original_weight_tensor - aqt = input_quant_func(input_tensor) - return func(bias, aqt, original_weight_tensor) + qtensor = input_quant_func(input_tensor, **weight_tensor.quant_kwargs) + return func(bias, qtensor, original_weight_tensor) else: # aten.mm.default assert args[0].shape[-1] == args[1].shape[0], ( @@ -161,8 +161,8 @@ def _(func, types, args, kwargs): ) input_quant_func = weight_tensor.input_quant_func original_weight_tensor = weight_tensor.original_weight_tensor - aqt = input_quant_func(input_tensor) - return func(aqt, original_weight_tensor) + qtensor = input_quant_func(input_tensor, **weight_tensor.quant_kwargs) + return func(qtensor, original_weight_tensor) @implements(aten.detach.default) @@ -203,7 +203,9 @@ def _(func, types, args, kwargs): args, kwargs, LinearActivationQuantizedTensor( - func(args[0].original_weight_tensor, *args[1:]), args[0].input_quant_func + func(args[0].original_weight_tensor, *args[1:]), + args[0].input_quant_func, + args[0].quant_kwargs, ), ) @@ -216,7 +218,9 @@ def _(func, types, args, kwargs): args, kwargs, LinearActivationQuantizedTensor( - func(args[0].original_weight_tensor, *args[1:]), args[0].input_quant_func + func(args[0].original_weight_tensor, *args[1:]), + args[0].input_quant_func, + args[0].quant_kwargs, ), ) diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py index cbe6296407..d5f2dca5b4 100644 --- a/torchao/quantization/qat/linear.py +++ b/torchao/quantization/qat/linear.py @@ -9,6 +9,7 @@ import torch import torch.nn.functional as F +from torchao.dtypes.utils import is_device from torchao.quantization.GPTQ import ( Int8DynActInt4WeightLinear, WeightOnlyInt4Linear, @@ -23,6 +24,7 @@ ) from torchao.quantization.unified import TwoStepQuantizer from torchao.quantization.utils import get_group_qparams_symmetric +from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 from .api import FakeQuantizeConfig from .fake_quantizer import FakeQuantizer @@ -363,6 +365,7 @@ def _convert_qat_linear_4w(self, module: torch.nn.Module): inner_k_tiles=inner_k_tiles, precision=child.weight.dtype, scales_precision=config.scale_precision, + device=next(child.parameters()).device, ) setattr(module, name, quantized_linear) @@ -373,10 +376,19 @@ def _convert_qat_linear_4w(self, module: torch.nn.Module): n_bit, config.group_size, ) - q_weight = torch.ops.aten._convert_weight_to_int4pack( - q_weight.to(child.weight.device), - child.inner_k_tiles, - ) + if ( + is_device(q_weight.device.type, "cpu") + and TORCH_VERSION_AT_LEAST_2_6 + ): + q_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu( + q_weight.to(child.weight.device), + child.inner_k_tiles, + ) + else: + q_weight = torch.ops.aten._convert_weight_to_int4pack( + q_weight.to(child.weight.device), + child.inner_k_tiles, + ) quantized_linear.weight = q_weight quantized_linear.scales_and_zeros = scales_and_zeros else: diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index c730ec9046..99da86b87b 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -39,6 +39,7 @@ to_affine_quantized_intx, to_marlinqqq_quantized_intx, ) +from torchao.float8.float8_linear import Float8Linear from torchao.float8.inference import Float8MMConfig from torchao.quantization.linear_activation_weight_observed_tensor import ( LinearActivationWeightObservedTensor, @@ -52,8 +53,8 @@ TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, is_MI300, - is_sm_89, - is_sm_90, + is_sm_at_least_89, + is_sm_at_least_90, ) from .autoquant import AutoQuantizableLinearWeight, autoquant @@ -222,6 +223,12 @@ def _replace_with_custom_fn_if_matches_filter( Returns: None """ + if isinstance(model, Float8Linear): + with torch.device("meta"): + new_module = nn.Linear(model.in_features, model.out_features) + new_module.weight = model.weight + new_module.bias = model.bias + model = new_module if filter_fn(model, cur_fqn[:-1]): if device is not None: model.to(device=device) # move to device before quantization @@ -630,7 +637,8 @@ def int4_weight_only( "tensor_core_tiled" layout for speedup with tinygemm kernel Note: - This is targeting `tinygemm` int4mm kernel (`torch.ops.aten._weight_int4pack_mm`), the main difference + This is targeting `tinygemm` int4mm kernel (`torch.ops.aten._weight_int4pack_mm` + and `torch.ops.aten._weight_int4pack_mm_for_cpu`), the main difference of quantization algorithm compared to the more traditional type of integer quantization is the following: 1). zero_point is in floating point domain instead of integer domain (`zero_point_domain`=`ZeroPointDomain.FLOAT`) 2). floating point zero does not have to be exactly representable (`preserve_zero`=False in `choose_qparams_affine`) @@ -668,6 +676,9 @@ def apply_int4_weight_only_quant(weight): mapping_type = MappingType.SYMMETRIC preserve_zero = True zero_point_domain = ZeroPointDomain.INT + assert ( + group_size == 128 or group_size == weight.shape[-1] + ), f"MarlinSparseLayout only supports 128 group size or per channel quantization, got {group_size}" return to_affine_quantized_intx( weight, @@ -856,11 +867,11 @@ def _normalize_granularity( for _granularity in processed_granularity: if isinstance(_granularity, PerTensor): assert ( - is_sm_89() or is_MI300() + is_sm_at_least_89() or is_MI300() ), "PerTensor quantization only works for CUDA>=8.9 and MI300+" elif isinstance(_granularity, PerRow): assert ( - is_sm_90() or is_MI300() + is_sm_at_least_90() or is_MI300() ), "PerRow quantization only works for CUDA>=9.0 and MI300+" else: raise ValueError(f"Invalid granularity type: {_granularity}") @@ -958,7 +969,7 @@ def float8_dynamic_activation_float8_weight( """ assert ( - is_sm_89() or is_MI300() + is_sm_at_least_89() or is_MI300() ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+" if mm_config is None: mm_config = Float8MMConfig(use_fast_accum=True) @@ -1015,7 +1026,7 @@ def float8_static_activation_float8_weight( mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation. """ assert ( - is_sm_89() or is_MI300() + is_sm_at_least_89() or is_MI300() ), "Float8 static activation quantization is only supported on CUDA 8.9 and above" if mm_config is None: mm_config = Float8MMConfig(use_fast_accum=True) diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index 036109bc8d..9715d99e08 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -8,6 +8,7 @@ import torch from torch.utils._python_dispatch import return_and_correct_aliasing +from torchao.dtypes.utils import is_device from torchao.quantization.utils import ( dequantize_per_channel, dynamically_quantize_per_channel, @@ -15,7 +16,7 @@ quant_int8_dynamic_per_token_linear, unpack_tinygemm_scales_and_zeros, ) -from torchao.utils import find_multiple +from torchao.utils import TORCH_VERSION_AT_LEAST_2_6, find_multiple __all__ = [ "Int8DynamicallyQuantizedLinearWeight", @@ -458,12 +459,20 @@ def _quantized_op(act_mat, w_qtensor, bias): act_mat = torch.nn.functional.pad(act_mat, (0, pad_size - act_mat.shape[-1])) # matmul - y = aten._weight_int4pack_mm( - act_mat.contiguous(), - w_qtensor.int_data, - w_qtensor.groupsize, - w_qtensor.scales_and_zeros, - ) + if is_device(act_mat.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + y = aten._weight_int4pack_mm_for_cpu( + act_mat.contiguous(), + w_qtensor.int_data, + w_qtensor.groupsize, + w_qtensor.scales_and_zeros, + ) + else: + y = aten._weight_int4pack_mm( + act_mat.contiguous(), + w_qtensor.int_data, + w_qtensor.groupsize, + w_qtensor.scales_and_zeros, + ) # remove out_feature padding orig_out_features = ( @@ -609,5 +618,10 @@ def to_qtensor_components(cls, input_float, groupsize=128, inner_k_tiles=8): input_int4x8, scales_and_zeros = groupwise_affine_quantize_tensor( input_float, 4, groupsize, dtype=input_float.dtype ) - int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles) + if is_device(input_float.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: + int_data = aten._convert_weight_to_int4pack_for_cpu( + input_int4x8, inner_k_tiles + ) + else: + int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles) return int_data, scales_and_zeros, False, groupsize, inner_k_tiles diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 9083dd7621..e1cf98b549 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -9,6 +9,7 @@ import torch from torch.utils._python_dispatch import TorchDispatchMode +from torchao.dtypes.utils import is_device from torchao.kernel import ( int_scaled_matmul, ) @@ -19,7 +20,7 @@ dequantize_affine, quantize_affine, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6 __all__ = [ "compute_error", @@ -402,13 +403,8 @@ def groupwise_affine_quantize_tensor_from_qparams( zero_point_domain=ZeroPointDomain.FLOAT, ) if TORCH_VERSION_AT_LEAST_2_5 and w.shape[-1] > 1: - int_data_device_type = int_data.device.type - # Move to cpu, until issue with MPS memory management of temporary tensors is resolved - if int_data_device_type == "mps": - int_data = int_data.cpu() - int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) - if int_data_device_type == "mps": - int_data = int_data.to(device="mps") + if not (is_device(int_data.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6): + int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) return int_data @@ -422,8 +418,10 @@ def groupwise_affine_dequantize_tensor_from_qparams( assert groupsize > 1 assert w_int4x8.dim() == 2 # need to handle single column case so check for dtype/size from groupwise_affine_quantize_tensor_from_qparams path - if TORCH_VERSION_AT_LEAST_2_5 and ( - w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1 + if ( + TORCH_VERSION_AT_LEAST_2_5 + and (w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1) + and not (is_device(w_int4x8.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6) ): data = w_int4x8.to(torch.int32) high_bits = data >> 4 diff --git a/torchao/testing/float8/dtensor_utils.py b/torchao/testing/float8/dtensor_utils.py index 1fab31d850..84e4095263 100644 --- a/torchao/testing/float8/dtensor_utils.py +++ b/torchao/testing/float8/dtensor_utils.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -import torch import torch.nn as nn import torch.nn.functional as F diff --git a/torchao/testing/float8/fsdp2_utils.py b/torchao/testing/float8/fsdp2_utils.py index 7744ae4e92..af46b7fa71 100644 --- a/torchao/testing/float8/fsdp2_utils.py +++ b/torchao/testing/float8/fsdp2_utils.py @@ -1,16 +1,13 @@ -import contextlib -from typing import List, Optional +from typing import List import torch import torch.distributed as dist import torch.nn as nn -import torchao.float8.config as config from torchao.float8.config import ( Float8LinearConfig, ScalingType, ) - from torchao.float8.float8_linear_utils import ( linear_requires_sync, sync_float8_amax_and_scale_history, @@ -52,7 +49,11 @@ def check_parity_no_mp( ): precompute_float8_dynamic_scale_for_fsdp(model) - test_cls.assertEqual(losses[0], losses[1], msg = f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}") + test_cls.assertEqual( + losses[0], + losses[1], + msg=f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}", + ) def check_parity_bf16_mp( @@ -87,7 +88,11 @@ def check_parity_bf16_mp( ref_model.parameters(), ref_model_bf16.parameters() ): param_bf16.detach().copy_(param_fp32) - test_cls.assertEqual(losses[0], losses[1], msg = f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}") + test_cls.assertEqual( + losses[0], + losses[1], + msg=f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}", + ) def check_parity_fp8_comm_only( @@ -104,7 +109,6 @@ def check_parity_fp8_comm_only( for iter_idx in range(10): losses: List[torch.Tensor] = [] for model, optim in ((ref_model, ref_optim), (fsdp_model, fsdp_optim)): - optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) losses.append(model(local_inp).sum()) losses[-1].backward() @@ -123,9 +127,15 @@ def check_parity_fp8_comm_only( and config.cast_config_weight.scaling_type is ScalingType.DYNAMIC ): precompute_float8_dynamic_scale_for_fsdp(model) - + if compile: # When compile, the ref loss and fsdp loss are not exactly the same, only check the loss values are valid for now. - assert (torch.isfinite(losses[0]).any() and torch.isfinite(losses[1]).any()), f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}" + assert ( + torch.isfinite(losses[0]).any() and torch.isfinite(losses[1]).any() + ), f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}" else: - test_cls.assertEqual(losses[0], losses[1], f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}") + test_cls.assertEqual( + losses[0], + losses[1], + f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}", + ) diff --git a/torchao/testing/float8/test_utils.py b/torchao/testing/float8/test_utils.py index 7f37c3f30a..7b8ac121b6 100644 --- a/torchao/testing/float8/test_utils.py +++ b/torchao/testing/float8/test_utils.py @@ -1,9 +1,9 @@ import torch + from torchao.float8.config import ( - ScalingGranularity, - ScalingType, - CastConfig, + CastConfig, Float8LinearConfig, + ScalingType, ) diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index 39edc50085..d88241783f 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -1,15 +1,19 @@ -import unittest -import functools import copy -import torch -import torchao -import os +import functools +import unittest +import torch +from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard from torch.testing._internal import common_utils -from torchao.dtypes import AffineQuantizedTensor -from torchao.dtypes import to_affine_quantized_intx +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, +) + +import torchao +from torchao.dtypes import AffineQuantizedTensor, to_affine_quantized_intx +from torchao.quantization import int8_weight_only, quantize_ from torchao.quantization.quant_primitives import MappingType -from torchao.quantization import quantize_, int8_weight_only from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 """ @@ -36,10 +40,9 @@ class MyTestCase(TorchAOBasicTestCase): unittest.main() """ + # copied from https://github.com/pytorch/pytorch/blob/941d094dd1b507dacf06ddc6ed3485a9537e09b7/test/inductor/test_torchinductor.py#L11389 -def copy_tests( - my_cls, other_cls, suffix, test_failures=None, xfail_prop=None -): # noqa: B902 +def copy_tests(my_cls, other_cls, suffix, test_failures=None, xfail_prop=None): # noqa: B902 for name, value in my_cls.__dict__.items(): if name.startswith("test_"): # You cannot copy functions in Python, so we use closures here to @@ -70,7 +73,6 @@ def new_test(self, value=value): setattr(other_cls, f"{name}_{suffix}", new_test) - class TorchAOBasicTestCase(common_utils.TestCase): COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] @@ -90,17 +92,21 @@ def test_flatten_unflatten(self): hp_tensor = torch.randn(4, 128) lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs) tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__() - tensor_data_dict = {name: getattr(lp_tensor, name) for name in tensor_data_name_dict} + tensor_data_dict = { + name: getattr(lp_tensor, name) for name in tensor_data_name_dict + } outer_size = lp_tensor.size() outer_stride = lp_tensor.stride() - reconstructed = self.TENSOR_SUBCLASS.__tensor_unflatten__(tensor_data_dict, tensor_attributes, outer_size, outer_stride) + reconstructed = self.TENSOR_SUBCLASS.__tensor_unflatten__( + tensor_data_dict, tensor_attributes, outer_size, outer_stride + ) self.assertEqual(lp_tensor.dequantize(), reconstructed.dequantize()) @common_utils.parametrize("device", COMMON_DEVICES) @common_utils.parametrize("dtype", COMMON_DTYPES) def test_hp_tensor_device_dtype(self, device, dtype): hp_tensor = torch.randn(4, 128, device=device, dtype=dtype) - lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs) + self.FACTORY_FN(hp_tensor, **self.kwargs) @common_utils.parametrize("device1", COMMON_DEVICES) @common_utils.parametrize("device2", COMMON_DEVICES) @@ -141,7 +147,10 @@ def test_linear(self, device, dtype): hp_act_tensor = torch.randn(32, 128, device=device, dtype=dtype) hp_res = torch.nn.functional.linear(hp_act_tensor, hp_tensor) lp_res = torch.nn.functional.linear(hp_act_tensor, lp_tensor) - self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR) + self.assertGreater( + torchao.quantization.utils.compute_error(hp_res, lp_res), + self.LINEAR_MIN_SQNR, + ) class TorchAOCompileTestCase(common_utils.TestCase): @@ -165,6 +174,7 @@ class TorchAOCompileTestCase(common_utils.TestCase): def test_input_output_tensor_subclass(self, device, dtype): hp_tensor = torch.randn(4, 128, device=device, dtype=dtype) lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs) + def f(tensor): return tensor @@ -179,6 +189,7 @@ def f(tensor): def test_input_tensor_subclass(self, device, dtype): hp_tensor = torch.randn(4, 128, device=device, dtype=dtype) lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs) + def f(tensor): return tensor.dequantize() @@ -192,6 +203,7 @@ def f(tensor): @common_utils.parametrize("dtype", COMMON_DTYPES) def test_output_tensor_subclass(self, device, dtype): hp_tensor = torch.randn(4, 128, device=device, dtype=dtype) + def f(hp_tensor): return self.FACTORY_FN(hp_tensor, **self.kwargs) @@ -201,7 +213,12 @@ def f(hp_tensor): self.assertTrue(isinstance(f(hp_tensor), self.TENSOR_SUBCLASS)) # bfloat16 seems to result in much larger numerical differences if dtype != torch.bfloat16: - self.assertGreater(torchao.quantization.utils.compute_error(ref.dequantize(), compiled.dequantize()), self.COMPILE_MIN_SQNR) + self.assertGreater( + torchao.quantization.utils.compute_error( + ref.dequantize(), compiled.dequantize() + ), + self.COMPILE_MIN_SQNR, + ) @common_utils.parametrize("device", COMMON_DEVICES) @common_utils.parametrize("dtype", COMMON_DTYPES) @@ -211,22 +228,18 @@ def test_linear_compile(self, device, dtype): hp_act_tensor = torch.randn(32, 128, device=device, dtype=dtype) hp_res = torch.nn.functional.linear(hp_act_tensor, hp_tensor) - l = torch.nn.Linear(128, 4, bias=False, device=device, dtype=dtype) - l.weight = torch.nn.Parameter(lp_tensor) - lp_res = torch.compile(l)(hp_act_tensor) - self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR) + linear = torch.nn.Linear(128, 4, bias=False, device=device, dtype=dtype) + linear.weight = torch.nn.Parameter(lp_tensor) + lp_res = torch.compile(linear)(hp_act_tensor) + self.assertGreater( + torchao.quantization.utils.compute_error(hp_res, lp_res), + self.LINEAR_MIN_SQNR, + ) -import torch.distributed as dist -from torch.distributed._tensor import DTensor, Replicate, Shard, DeviceMesh -from torch.testing._internal.distributed._tensor.common_dtensor import ( - DTensorTestBase, - with_comms, - NUM_DEVICES, -) class TorchAOTensorParallelTestCase(DTensorTestBase): - """Basic test case for tensor subclasses - """ + """Basic test case for tensor subclasses""" + COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] TENSOR_SUBCLASS = AffineQuantizedTensor @@ -247,9 +260,7 @@ def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module: # Construct DTensor from local shard dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)]) # Replace parameter in module - m.linear.weight = torch.nn.Parameter( - dtensor, requires_grad=False - ) + m.linear.weight = torch.nn.Parameter(dtensor, requires_grad=False) return m @staticmethod @@ -266,9 +277,7 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module: # Construct DTensor from local shard dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)]) # Replace parameter in module - m.linear.weight = torch.nn.Parameter( - dtensor, requires_grad=False - ) + m.linear.weight = torch.nn.Parameter(dtensor, requires_grad=False) return m def quantize(self, m: torch.nn.Module) -> torch.nn.Module: @@ -289,7 +298,9 @@ def test_tp(self, dtype): class M(torch.nn.Module): def __init__(self, in_features, out_features, **kwargs) -> None: super().__init__(**kwargs) - self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda") + self.linear = torch.nn.Linear( + in_features, out_features, bias=False, device="cuda" + ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.linear(x) @@ -301,12 +312,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: proj_up = M(1024, 2048).to(device).to(dtype) proj_dn = M(2048, 1024).to(device).to(dtype) example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype) - y = proj_dn(proj_up(example_input)) + proj_dn(proj_up(example_input)) # Quantize the model up_quant = self.quantize(proj_up) dn_quant = self.quantize(proj_dn) - y_q = dn_quant(up_quant(example_input)) + dn_quant(up_quant(example_input)) mesh = self.build_device_mesh() mesh.device_type = "cuda" @@ -316,11 +327,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: dn_dist = self.rowwise_shard(dn_quant, mesh) # We need to turn inputs into DTensor form as well -- just a format change - input_dtensor = DTensor.from_local( - example_input, mesh, [Replicate()] - ) + input_dtensor = DTensor.from_local(example_input, mesh, [Replicate()]) - y_d = dn_dist(up_dist(input_dtensor)) + dn_dist(up_dist(input_dtensor)) if not TORCH_VERSION_AT_LEAST_2_6: # Need torch 2.6 to support compiled tensor parallelism @@ -329,7 +338,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: up_compiled = torch.compile(up_dist) y_up = up_compiled(input_dtensor) dn_compiled = torch.compile(dn_dist) - y_dn = dn_compiled(y_up) + dn_compiled(y_up) + common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase) common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase) diff --git a/torchao/utils.py b/torchao/utils.py index ba91fb3fe0..d56191ed6b 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -33,8 +33,8 @@ "TORCH_VERSION_AFTER_2_4", "TORCH_VERSION_AFTER_2_5", "is_MI300", - "is_sm_89", - "is_sm_90", + "is_sm_at_least_89", + "is_sm_at_least_90", ] @@ -612,7 +612,7 @@ def is_MI300(): return False -def is_sm_89(): +def is_sm_at_least_89(): return ( torch.cuda.is_available() and torch.version.cuda @@ -620,7 +620,7 @@ def is_sm_89(): ) -def is_sm_90(): +def is_sm_at_least_90(): return ( torch.cuda.is_available() and torch.version.cuda diff --git a/version.txt b/version.txt index faef31a435..a3df0a6959 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.7.0 +0.8.0