From 1797c75bbf36d660df0e17f75847044253cccc89 Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Wed, 18 Dec 2024 05:36:53 -0800
Subject: [PATCH] fixing some corner cases not caught in unit tests

Summary:

Shapes need to be divisible by 128 or they will not work with gemlite.
Need fp32 accumulation for group_size=None on int4.

Test Plan:

python test_integration.py -k "test_gemlite" (new test for non-divisible shape)

python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-4-None --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-4-None --write_result benchmark_results.txt

(previously these gave nonsense responses)

Reviewers:

Subscribers:

Tasks:

Tags:
---
 test/integration/test_integration.py   |  8 +++
 torchao/_models/llama/generate.py      | 85 ++++++++++++++++++--------
 torchao/dtypes/uintx/gemlite_layout.py | 13 ++++
 3 files changed, 81 insertions(+), 25 deletions(-)

diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index faabf48ab..65a4d2093 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -962,6 +962,14 @@ def test_gemlite_layout(self, device, dtype):
             test_shape=test_shape,
             test_dtype=dtype,
         )
+        # test that shapes not divisible by 128 aren't causing errors
+        self._test_lin_weight_subclass_api_impl(
+            lambda mod: quantize_(mod, gemlite_uintx_weight_only(None, 4, 32)),
+            device,
+            15,
+            test_shape=[1, 1025, 513],
+            test_dtype=dtype,
+        )
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)

diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index 231133c2c..9b0208375 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -3,6 +3,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import json
 import os
 import platform
 import sys
@@ -18,11 +19,6 @@
 import torchao
 from torchao.quantization.quant_primitives import MappingType
 from torchao.utils import get_model_size_in_bytes, TORCH_VERSION_AT_LEAST_2_5
-from torchao._models.utils import (
-    get_arch_name,
-    write_json_result_ossci,
-    write_json_result_local,
-)
 
 torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False
 
@@ -41,6 +37,14 @@ def elapsed_time(self, other_event):
         return abs(other_event.event_time - self.event_time) * 1000
 
 
+def get_arch_name() -> str:
+    if torch.cuda.is_available():
+        return torch.cuda.get_device_name()
+    else:
+        # This returns x86_64 or arm64 (for aarch64)
+        return platform.machine()
+
+
 def device_timer(device):
     if "cuda" in device:
         return torch.cuda.Event(enable_timing=True)
@@ -61,6 +65,39 @@ def device_sync(device):
         print(f"device={device} is not yet suppported")
 
 
+def write_json_result(output_json_path, headers, row):
+    """
+    Write the result into JSON format, so that it can be uploaded to the benchmark database
+    to be displayed on OSS dashboard. The JSON format is defined at
+    https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database
+    """
+    mapping_headers = {headers[i]: v for i, v in enumerate(row)}
+    record = {
+        "benchmark": {
+            "name": "TorchAO benchmark",
+            "mode": "inference",
+            "dtype": mapping_headers["dtype"],
+            "extra_info": {
+                "device": mapping_headers["device"],
+                "arch": mapping_headers["arch"],
+            },
+        },
+        "model": {
+            "name": mapping_headers["name"],
+            "type": "model",
+            "origins": ["pytorch"],
+        },
+        "metric": {
+            "name": mapping_headers["metric"],
+            "benchmark_values": [mapping_headers["actual"]],
+            "target_value": mapping_headers["target"],
+        },
+    }
+
+    with open(f"{os.path.splitext(output_json_path)[0]}.json", "a") as f:
+        print(json.dumps(record), file=f)
+
+
 default_device = (
     "cuda"
     if torch.cuda.is_available()
@@ -135,7 +172,7 @@ def decode_n_tokens(
             next_token, next_prob = next_token.clone(), next_prob.clone()
             input_pos += 1
             # in some instances not having this causes weird issues with the stored tokens when you run the next decode_one_token step
-            new_tokens.append(next_token.clone())
+            new_tokens.append(next_token.clone())
             callback(new_tokens[-1])
             new_probs.append(next_prob)
             cur_token = next_token
@@ -279,7 +316,6 @@ def main(
     precision=torch.bfloat16,
     write_result: Optional[Path] = None,
     output_json_path: Optional[Path] = None,
-    output_json_local: bool = False,
 ) -> None:
     """Generates text samples based on a pre-trained Transformer model and tokenizer."""
 
@@ -692,10 +728,20 @@ def ffn_or_attn_only(mod, fqn):
                 example_input=inputs,
             )
         if "autoquant-all" == quantization:
+            all_qtensor_classes = (
+                torchao.quantization.DEFAULT_AUTOQUANT_CLASS_LIST
+                + torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST
+                + torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
+            )
+            if torchao.utils.is_sm_89():
+                # these are fp8-related subclasses; should rename
+                all_qtensor_classes += (
+                    torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST
+                )
             model = autoquant(
                 model,
                 manual=True,
-                qtensor_class_list=torchao.quantization.ALL_AUTOQUANT_CLASS_LIST,
+                qtensor_class_list=all_qtensor_classes,
                 example_input=inputs,
             )
         else:
@@ -713,10 +759,6 @@ def ffn_or_attn_only(mod, fqn):
 
             # do autoquantization
             model.finalize_autoquant()
-        elif "codebook" in quantization:
-            from torchao.prototype.quantization.codebook import codebook_weight_only
-            model.to(device)
-            quantize_(model, codebook_weight_only(dtype=torch.uint4, scale_block_size=64))
     else:
         if not TORCH_VERSION_AT_LEAST_2_5:
@@ -936,14 +978,13 @@ def callback(x):
         f.write(result_txt)
         f.close()
 
+    headers = ["name", "dtype", "device", "arch", "metric", "actual", "target"]
+    name = checkpoint_path.parent.name
+    arch = get_arch_name()
+    dtype = quantization or str(precision)
+    memory_result = [name, dtype, device, arch, "mem/s", bandwidth, None]
+    performance_result = [name, dtype, device, arch, "tok/s", tokpersec, None]
     if output_json_path:
-        headers = ["name", "dtype", "device", "arch", "metric", "actual", "target"]
-        name = checkpoint_path.parent.name
-        arch = get_arch_name()
-        dtype = quantization or "noquant"
-        memory_result = [name, dtype, device, arch, "mem/s", bandwidth, None]
-        performance_result = [name, dtype, device, arch, "tok/s", tokpersec, None]
-        write_json_result = write_json_result_local if output_json_local else write_json_result_ossci
         write_json_result(output_json_path, headers, memory_result)
         write_json_result(output_json_path, headers, performance_result)
 
@@ -1045,11 +1086,6 @@ def callback(x):
         default=None,
         help="Path where to write the json result for dashboard",
     )
-    parser.add_argument(
-        "--output_json_local",
-        action="store_true",
-        help="Whether to output json result for local machine or for CI machine, local option will fill in some dummy fields",
-    )
 
     args = parser.parse_args()
     print(args)
@@ -1077,5 +1113,4 @@ def callback(x):
         args.precision,
         args.write_result,
         args.output_json_path,
-        args.output_json_local,
     )

diff --git a/torchao/dtypes/uintx/gemlite_layout.py b/torchao/dtypes/uintx/gemlite_layout.py
index 969816727..c775996ba 100644
--- a/torchao/dtypes/uintx/gemlite_layout.py
+++ b/torchao/dtypes/uintx/gemlite_layout.py
@@ -15,6 +15,7 @@
 from torchao.dtypes.utils import Layout, is_device
 from torchao.quantization.quant_primitives import quantize_affine
 from torchao.utils import fill_defaults
+import warnings
 
 aten = torch.ops.aten
 
@@ -76,6 +77,14 @@ def apply_gemlite_quant(
     out_features, in_features = weight.shape
     group_size = in_features if group_size is None else group_size
 
+    if in_features % 128 != 0 and out_features % 128 != 0:
+        warnings.simplefilter("once", UserWarning)
+        warnings.warn(
+            "Gemlite only works for layers with in_features or out_features divisible by 128, "
+            + "some layers have been skipped", UserWarning
+        )
+        return weight
+
     quant_kwargs = get_gemlite_quant_kwargs(bit_width, group_size)
 
     layout = GemlitePackedLayout(
@@ -173,6 +182,10 @@ def from_plain(
             exhaustive=False,
             use_cuda_graph=False,
         )
+        if _layout.group_size == None and _layout.bit_width == 4:
+            from gemlite.core import GEMLITE_ACC_DTYPE
+            from gemlite.dtypes import DType
+            GEMLITE_ACC_DTYPE[DType.FP16] = DType.FP32
 
         out_features, in_features = int_data.shape
         input_dtype, output_dtype = DType.FP16, DType.FP16
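
For reference, below is a minimal standalone sketch (not part of the patch) of how the new guard is exercised from the caller's side. It assumes a CUDA device with gemlite installed, that quantize_ and gemlite_uintx_weight_only are importable from torchao.quantization, and that the positional arguments follow the (group_size, bit_width, packing_bitwidth) order implied by the new test; adjust names to your torchao version.

    # Illustrative sketch only, not part of the patch.
    import torch
    from torchao.quantization import quantize_, gemlite_uintx_weight_only

    # 1025 and 513 mirror the new test shape; neither is divisible by 128,
    # so with this patch the layer is skipped with a one-time UserWarning
    # instead of producing a broken gemlite layout.
    model = torch.nn.Sequential(torch.nn.Linear(1025, 513)).half().cuda()

    # group_size=None with bit_width=4 also switches gemlite to fp32 accumulation
    # for the layers that do get quantized.
    quantize_(model, gemlite_uintx_weight_only(None, 4, 32))

    x = torch.randn(1, 1025, dtype=torch.float16, device="cuda")
    print(model(x).shape)  # still runs in fp16 here because quantization was skipped

A quick way to confirm the skip is to check that model[0].weight is still a plain torch.Tensor after quantize_ and that the UserWarning was emitted once.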