diff --git a/.github/workflows/dashboard_perf_test.yml b/.github/workflows/dashboard_perf_test.yml new file mode 100644 index 0000000000..c2933be107 --- /dev/null +++ b/.github/workflows/dashboard_perf_test.yml @@ -0,0 +1,46 @@ +name: A100-perf-nightly + +on: + workflow_dispatch: + schedule: + - cron: 0 7 * * 0-6 + +jobs: + benchmark: + runs-on: linux.aws.a100 + strategy: + matrix: + torch-spec: + - '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu124' + steps: + - uses: actions/checkout@v3 + + - name: Setup miniconda + uses: pytorch/test-infra/.github/actions/setup-miniconda@main + with: + python-version: "3.9" + + - name: Run benchmark + shell: bash + run: | + set -eux + ${CONDA_RUN} python -m pip install --upgrade pip + ${CONDA_RUN} pip install ${{ matrix.torch-spec }} + ${CONDA_RUN} pip install -r dev-requirements.txt + ${CONDA_RUN} pip install . + + export CHECKPOINT_PATH=checkpoints + export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf + ${CONDA_RUN} python scripts/download.py --repo_id meta-llama/Llama-2-7b-chat-hf --hf_token ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + ${CONDA_RUN} python scripts/convert_hf_checkpoint.py --checkpoint_dir "${CHECKPOINT_PATH}/${MODEL_REPO}" + + mkdir -p ${{ runner.temp }}/benchmark-results + ${CONDA_RUN} python torchao/_models/llama/generate.py --checkpoint_path "${CHECKPOINT_PATH}/${MODEL_REPO}/model.pth" --compile --output_json_path ${{ runner.temp }}/benchmark-results/benchmark-results.json + + - name: Upload the benchmark results to OSS benchmark database for the dashboard + uses: pytorch/test-infra/.github/actions/upload-benchmark-results@main + with: + benchmark-results-dir: ${{ runner.temp }}/benchmark-results + dry-run: false + schema-version: v3 + github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 065cc9c56d..7570700c65 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -3,22 +3,26 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
+import json import os +import platform import sys import time +from datetime import datetime from pathlib import Path from typing import Optional, Tuple -from datetime import datetime + import torch -import torchao import torch._dynamo.config import torch._inductor.config -from torchao.utils import get_model_size_in_bytes + +import torchao from torchao.quantization.quant_primitives import MappingType -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import get_model_size_in_bytes, TORCH_VERSION_AT_LEAST_2_5 torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False + class HostEvent: def __init__(self): self.event_time = None @@ -32,6 +36,15 @@ def elapsed_time(self, other_event): # return ms to match cuda event return abs(other_event.event_time - self.event_time) * 1000 + +def get_arch_name() -> str: + if torch.cuda.is_available(): + return torch.cuda.get_device_name() + else: + # This returns x86_64 or arm64 (for aarch64) + return platform.machine() + + def device_timer(device): if "cuda" in device: return torch.cuda.Event(enable_timing=True) @@ -40,6 +53,7 @@ def device_timer(device): else: print(f"device={device} is not yet suppported") + def device_sync(device): if "cuda" in device: torch.cuda.synchronize(device) @@ -50,19 +64,63 @@ def device_sync(device): else: print(f"device={device} is not yet suppported") -default_device = 'cuda' if torch.cuda.is_available() else 'xpu' if torch.xpu.is_available() else 'cpu' + +def write_json_result(output_json_path, headers, row): + """ + Write the result into JSON format, so that it can be uploaded to the benchmark database + to be displayed on OSS dashboard. The JSON format is defined at + https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database + """ + mapping_headers = {headers[i]: v for i, v in enumerate(row)} + record = { + "benchmark": { + "name": "TorchAO benchmark", + "mode": "inference", + "dtype": mapping_headers["dtype"], + "extra_info": { + "device": mapping_headers["device"], + "arch": mapping_headers["arch"], + }, + }, + "model": { + "name": mapping_headers["name"], + "type": "model", + "origins": ["pytorch"], + }, + "metric": { + "name": mapping_headers["metric"], + "benchmark_values": [mapping_headers["actual"]], + "target_value": mapping_headers["target"], + }, + } + + with open(f"{os.path.splitext(output_json_path)[0]}.json", "a") as f: + print(json.dumps(record), file=f) + + +default_device = ( + "cuda" + if torch.cuda.is_available() + else "xpu" + if torch.xpu.is_available() + else "cpu" +) # support running without installing as a package wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) -from torchao._models.llama.model import Transformer, prepare_inputs_for_model +from torchao._models.llama.model import prepare_inputs_for_model, Transformer from torchao._models.llama.tokenizer import get_tokenizer -def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization + +def multinomial_sample_one_no_sync( + probs_sort, +): # Does multinomial sampling without a cuda synchronization q = torch.empty_like(probs_sort).exponential_(1) return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): logits = logits / max(temperature, 1e-5) @@ -73,23 +131,38 @@ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = Non probs = torch.nn.functional.softmax(logits, dim=-1) return probs + def 
sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): probs = logits_to_probs(logits[:, -1], temperature, top_k) idx_next = multinomial_sample_one_no_sync(probs) return idx_next, probs -def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor: + +def prefill( + model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs +) -> torch.Tensor: # input_pos: [B, S] logits = model(x, input_pos) return sample(logits, **sampling_kwargs)[0] -def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]: + +def decode_one_token( + model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs +) -> Tuple[torch.Tensor, torch.Tensor]: # input_pos: [B, 1] assert input_pos.shape[-1] == 1 logits = model(x, input_pos) return sample(logits, **sampling_kwargs) -def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs): + +def decode_n_tokens( + model: Transformer, + cur_token: torch.Tensor, + input_pos: torch.Tensor, + num_new_tokens: int, + callback=lambda _: _, + **sampling_kwargs, +): new_tokens, new_probs = [], [] for i in range(num_new_tokens): with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): @@ -109,6 +182,7 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc def model_forward(model, x, input_pos): return model(x, input_pos) + @torch.no_grad() def generate( model: Transformer, @@ -117,15 +191,15 @@ def generate( batch_size: int, *, interactive: bool, - callback = lambda x: x, + callback=lambda x: x, kv_cache_quantization: bool = False, cache_size: Optional[int] = None, - linear_causal_mask: bool=False, - prefill_start_event: Optional[torch.cuda.Event]=None, - prefill_end_event: Optional[torch.cuda.Event]=None, - decode_start_event: Optional[torch.cuda.Event]=None, - decode_end_event: Optional[torch.cuda.Event]=None, - **sampling_kwargs + linear_causal_mask: bool = False, + prefill_start_event: Optional[torch.cuda.Event] = None, + prefill_end_event: Optional[torch.cuda.Event] = None, + decode_start_event: Optional[torch.cuda.Event] = None, + decode_end_event: Optional[torch.cuda.Event] = None, + **sampling_kwargs, ) -> torch.Tensor: """ Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. 
@@ -136,12 +210,14 @@ def generate( T = prompt.size(-1) # calculate how many tokens to generate based on max_new_tokens and model's upper bound (block_size) - max_seq_length = min(T + max_new_tokens, model.config.block_size) if not interactive else 350 + max_seq_length = ( + min(T + max_new_tokens, model.config.block_size) if not interactive else 350 + ) new_tokens = max_seq_length - T # format model input prompt, input_pos = prepare_inputs_for_model(prompt) - prompt = prompt.repeat(batch_size, 1) # expand prompt based on batchsize + prompt = prompt.repeat(batch_size, 1) # expand prompt based on batchsize # full prompt+output will be stored in seq seq = torch.empty(batch_size, max_seq_length, dtype=prompt.dtype, device=device) @@ -151,13 +227,23 @@ def generate( with torch.device(device): if cache_size is None: cache_size = max_seq_length - assert cache_size >= max_seq_length, "need cache_size to be greater than max_new_tokens + size-of-prompt" - model.setup_caches(max_batch_size=batch_size, max_seq_length=cache_size, kv_cache_quantization=kv_cache_quantization, linear_causal_mask=linear_causal_mask, prompt_length=T) + assert ( + cache_size >= max_seq_length + ), "need cache_size to be greater than max_new_tokens + size-of-prompt" + model.setup_caches( + max_batch_size=batch_size, + max_seq_length=cache_size, + kv_cache_quantization=kv_cache_quantization, + linear_causal_mask=linear_causal_mask, + prompt_length=T, + ) # execute prefill if prefill_start_event is not None: prefill_start_event.record() - next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs).clone() + next_token = prefill( + model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs + ).clone() seq[:, T] = next_token.squeeze() if prefill_end_event is not None: prefill_end_event.record() @@ -166,19 +252,28 @@ def generate( if decode_start_event is not None: decode_start_event.record() input_pos = torch.tensor([T], device=device, dtype=torch.int) - generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs) - seq = torch.cat((seq[:, :T+1], *generated_tokens), dim=-1) + generated_tokens, _ = decode_n_tokens( + model, + next_token.view(batch_size, -1), + input_pos, + new_tokens - 1, + callback=callback, + **sampling_kwargs, + ) + seq = torch.cat((seq[:, : T + 1], *generated_tokens), dim=-1) if decode_end_event is not None: decode_end_event.record() return seq + def encode_tokens(tokenizer, string, bos=True, device=default_device): tokens = tokenizer.encode(string) if bos: tokens = [tokenizer.bos_id()] + tokens return torch.tensor(tokens, dtype=torch.int, device=device) + def _load_model(checkpoint_path, device, precision): checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) if "model" in checkpoint and "stories" in str(checkpoint_path): @@ -190,8 +285,10 @@ def _load_model(checkpoint_path, device, precision): return model.eval() + B_INST, E_INST = "[INST]", "[/INST]" + def main( prefill_size: Optional[int] = None, prompt: str = "Hello, my name is", @@ -201,12 +298,14 @@ def main( batch_size: int = 1, top_k: int = 200, temperature: float = 0.8, - checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), + checkpoint_path: Path = Path( + "checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth" + ), quantization: Optional[str] = None, sparsity: Optional[str] = None, kv_cache_quantization: bool = False, cache_size: Optional[int] = None, - 
linear_causal_mask: bool=False, + linear_causal_mask: bool = False, save: bool = False, compile: bool = True, compile_prefill: bool = False, @@ -215,13 +314,13 @@ def main( device=default_device, precision=torch.bfloat16, write_result: Optional[Path] = None, + output_json_path: Optional[Path] = None, ) -> None: - """Generates text samples based on a pre-trained Transformer model and tokenizer. - """ + """Generates text samples based on a pre-trained Transformer model and tokenizer.""" if prefill_size is not None and prefill_size > 0: - # create prompt of prefill size - prompt = "prompt " * (int(prefill_size)-3) + # create prompt of prefill size + prompt = "prompt " * (int(prefill_size) - 3) torchao.quantization.utils.recommended_inductor_config_setter() @@ -236,8 +335,7 @@ def main( t0 = time.time() model = _load_model(checkpoint_path, device, precision) - - device_sync(device=device) # MKG + device_sync(device=device) # MKG print(f"Time to load model: {time.time() - t0:.02f} seconds") tokenizer = get_tokenizer(tokenizer_path, checkpoint_path) @@ -248,54 +346,70 @@ def main( torch.manual_seed(1234) def ffn_only(mod, fqn): - return isinstance(mod, torch.nn.Linear) and "feed_forward" in fqn + return isinstance(mod, torch.nn.Linear) and "feed_forward" in fqn def not_ffn_only(mod, fqn): return isinstance(mod, torch.nn.Linear) and not ffn_only(mod, fqn) def ffn_or_attn_only(mod, fqn): - return isinstance(mod, torch.nn.Linear) and ("feed_forward" in fqn or "attention" in fqn) + return isinstance(mod, torch.nn.Linear) and ( + "feed_forward" in fqn or "attention" in fqn + ) if quantization: from torchao.quantization import ( - quantize_, autoquant, - int8_weight_only, - int8_dynamic_activation_int8_weight, + float8_dynamic_activation_float8_weight, + float8_weight_only, + fpx_weight_only, int4_weight_only, int8_dynamic_activation_int4_weight, - fpx_weight_only, + int8_dynamic_activation_int8_weight, + int8_weight_only, + quantize_, uintx_weight_only, - float8_weight_only, - float8_dynamic_activation_float8_weight, ) - from torchao.utils import unwrap_tensor_subclass - from torchao.quantization.granularity import PerTensor, PerRow + from torchao.quantization.granularity import PerRow, PerTensor from torchao.utils import unwrap_tensor_subclass + if "spinquant" in quantization: from torchao.prototype.spinquant import apply_spinquant + apply_spinquant(model) if "int8wo" in quantization: quantize_(model, int8_weight_only()) if "int8dq" in quantization: if sparsity and "semi" in sparsity: from torchao.dtypes import SemiSparseLayout - quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), filter_fn=ffn_only) - quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=not_ffn_only) + + quantize_( + model, + int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), + filter_fn=ffn_only, + ) + quantize_( + model, int8_dynamic_activation_int8_weight(), filter_fn=not_ffn_only + ) else: quantize_(model, int8_dynamic_activation_int8_weight()) if "int4wo" in quantization: if "hqq" in quantization: - use_hqq=True + use_hqq = True else: - use_hqq=False - group_size=int(quantization.split("-")[1]) - assert group_size in [32,64,128,256], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}" + use_hqq = False + group_size = int(quantization.split("-")[1]) + assert group_size in [ + 32, + 64, + 128, + 256, + ], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}" quantize_(model, int4_weight_only(group_size=group_size)) if "marlin" 
in quantization: if "qqq" in quantization: from torchao.dtypes import MarlinQQQLayout + quantize_( model, int8_dynamic_activation_int4_weight( @@ -307,25 +421,41 @@ def ffn_or_attn_only(mod, fqn): ) elif "semi" in sparsity: from torchao.dtypes import MarlinSparseLayout - quantize_(model, int4_weight_only(layout=MarlinSparseLayout()), filter_fn=ffn_or_attn_only) + + quantize_( + model, + int4_weight_only(layout=MarlinSparseLayout()), + filter_fn=ffn_or_attn_only, + ) if "fp6" in quantization: quantize_(model, fpx_weight_only(3, 2)) elif "embed-int8wo" in quantization: - quantize_(model, int8_weight_only(group_size=64), filter_fn=lambda x, *args: isinstance(x, torch.nn.Embedding)) + quantize_( + model, + int8_weight_only(group_size=64), + filter_fn=lambda x, *args: isinstance(x, torch.nn.Embedding), + ) elif quantization.startswith("awq"): from torchao._models._eval import TransformerEvalWrapper from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 - from torchao.prototype.awq.example import get_calib_dataset + if not TORCH_VERSION_AT_LEAST_2_3: print("Awq requires torch2.3+") exit() - from torchao.prototype.awq import insert_awq_observer_, awq_uintx, AWQObservedLinear + from torchao.prototype.awq import ( + awq_uintx, + AWQObservedLinear, + insert_awq_observer_, + ) + quant_dtype = quantization.split("-")[1] group_size = int(quantization.split("-")[2]) quant_dtype = getattr(torch, quant_dtype, torch.uint8) - model=model.to(device) + model = model.to(device) # get calibration data - insert_awq_observer_(model, 1, 256, quant_dtype=quant_dtype, group_size=group_size) + insert_awq_observer_( + model, 1, 256, quant_dtype=quant_dtype, group_size=group_size + ) TransformerEvalWrapper( model=model.to(device), tokenizer=tokenizer, @@ -333,12 +463,18 @@ def ffn_or_attn_only(mod, fqn): input_prep_func=prepare_inputs_for_model, device=device, ).run_eval( - tasks=['wikitext'], + tasks=["wikitext"], limit=1, ) is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) use_hqq = "hqq" in quantization - quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size, use_hqq=use_hqq), is_observed_linear) + quantize_( + model, + awq_uintx( + quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq + ), + is_observed_linear, + ) elif "uintx" in quantization: # uintx-nbits-group_size, e.g. 
"uintx-2-64" if "hqq" in quantization: @@ -349,18 +485,36 @@ def ffn_or_attn_only(mod, fqn): _quant_args = quantization.split("-") nbits = int(_quant_args[1]) assert nbits >= 1 and nbits <= 8, "nbits must be 1 to 8" - _NBITS_TO_DTYPE = {1: torch.uint1, 2: torch.uint2, 3: torch.uint3, 4: torch.uint4, 5: torch.uint5, 6: torch.uint6, 7: torch.uint7, 8: torch.uint8} + _NBITS_TO_DTYPE = { + 1: torch.uint1, + 2: torch.uint2, + 3: torch.uint3, + 4: torch.uint4, + 5: torch.uint5, + 6: torch.uint6, + 7: torch.uint7, + 8: torch.uint8, + } dtype = _NBITS_TO_DTYPE[nbits] group_size = int(_quant_args[2]) quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq)) elif "int8_dynamic_activation_intx_weight" in quantization: - from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight - assert precision == torch.float32, "int8_dynamic_activation_intx_weight requires fp32 precision" + from torchao.experimental.quant_api import ( + int8_dynamic_activation_intx_weight, + ) + + assert ( + precision == torch.float32 + ), "int8_dynamic_activation_intx_weight requires fp32 precision" # Build kernels in temp location, and load them in torch # This requires an ARM CPU from torchao.experimental.temp_build import temp_build_and_load_torchao_ops - temp_build_and_load_torchao_ops(cmake_lists_path=os.path.dirname(os.path.realpath(__file__)) + "/../../experimental") + + temp_build_and_load_torchao_ops( + cmake_lists_path=os.path.dirname(os.path.realpath(__file__)) + + "/../../experimental" + ) # Quantize model _quant_args = quantization.split("-") @@ -380,31 +534,38 @@ def ffn_or_attn_only(mod, fqn): quantize_(model, float8_weight_only()) elif "float8dq" in quantization: granularity = str(quantization.split("-")[-1]) - if granularity=="tensor": + if granularity == "tensor": granularity = PerTensor() - elif granularity=="row": + elif granularity == "row": granularity = PerRow() else: granularity = PerTensor() - quantize_(model, float8_dynamic_activation_float8_weight(granularity=granularity)) + quantize_( + model, float8_dynamic_activation_float8_weight(granularity=granularity) + ) elif "autoquant_v2" in quantization: - from torchao.prototype.quantization.autoquant_v2 import autoquant_v2 from torchao._models._eval import InputRecorder from torchao._models.llama.model import prepare_inputs_for_model + from torchao.prototype.quantization.autoquant_v2 import autoquant_v2 calibration_seq_length = 256 calibration_limit = 1 - inputs = InputRecorder( - tokenizer, - calibration_seq_length, - prepare_inputs_for_model, - False, # pad_calibration_inputs - model.config.vocab_size, - device="cuda" - ).record_inputs( - ["wikitext"], - 1, - ).get_inputs()[0].values[0] + inputs = ( + InputRecorder( + tokenizer, + calibration_seq_length, + prepare_inputs_for_model, + False, # pad_calibration_inputs + model.config.vocab_size, + device="cuda", + ) + .record_inputs( + ["wikitext"], + 1, + ) + .get_inputs()[0] + .values[0] + ) inputs = prepare_inputs_for_model(inputs) with torch.device("cuda"): model.setup_caches( @@ -412,19 +573,54 @@ def ffn_or_attn_only(mod, fqn): ) if "autoquant_v2-int4" == quantization: - model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs, batch_size=calibration_seq_length) + model = autoquant_v2( + model, + manual=True, + qtensor_class_list=torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, + example_input=inputs, + batch_size=calibration_seq_length, 
+ ) elif "autoquant_v2-float8" == quantization: - model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs, batch_size=calibration_seq_length) + model = autoquant_v2( + model, + manual=True, + qtensor_class_list=torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST, + example_input=inputs, + batch_size=calibration_seq_length, + ) elif "autoquant_v2-fp" == quantization: - model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, example_input=inputs, batch_size=calibration_seq_length) + model = autoquant_v2( + model, + manual=True, + qtensor_class_list=torchao.prototype.quantization.autoquant_v2.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, + example_input=inputs, + batch_size=calibration_seq_length, + ) elif "autoquant_v2-all" == quantization: - all_qtensor_classes = torchao.prototype.quantization.autoquant_v2.DEFAULT_AUTOQUANT_CLASS_LIST + torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST + torchao.prototype.quantization.autoquant_v2.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST + all_qtensor_classes = ( + torchao.prototype.quantization.autoquant_v2.DEFAULT_AUTOQUANT_CLASS_LIST + + torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST + + torchao.prototype.quantization.autoquant_v2.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST + ) if torchao.utils.is_sm_89(): # this is fp8 related subclasses, should rename - all_qtensor_classes += torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST - model = autoquant_v2(model, manual=True, qtensor_class_list = all_qtensor_classes, example_input=inputs, batch_size=calibration_seq_length) + all_qtensor_classes += ( + torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST + ) + model = autoquant_v2( + model, + manual=True, + qtensor_class_list=all_qtensor_classes, + example_input=inputs, + batch_size=calibration_seq_length, + ) else: - model = autoquant_v2(model, manual=True, example_input=inputs, batch_size=calibration_seq_length) + model = autoquant_v2( + model, + manual=True, + example_input=inputs, + batch_size=calibration_seq_length, + ) print("running generate") generate( @@ -446,17 +642,22 @@ def ffn_or_attn_only(mod, fqn): calibration_seq_length = 256 calibration_limit = 1 - inputs = InputRecorder( - tokenizer, - calibration_seq_length, - prepare_inputs_for_model, - False, # pad_calibration_inputs - model.config.vocab_size, - device="cuda" - ).record_inputs( - ["wikitext"], - 1, - ).get_inputs()[0].values[0] + inputs = ( + InputRecorder( + tokenizer, + calibration_seq_length, + prepare_inputs_for_model, + False, # pad_calibration_inputs + model.config.vocab_size, + device="cuda", + ) + .record_inputs( + ["wikitext"], + 1, + ) + .get_inputs()[0] + .values[0] + ) inputs = prepare_inputs_for_model(inputs) with torch.device("cuda"): model.setup_caches( @@ -464,17 +665,43 @@ def ffn_or_attn_only(mod, fqn): ) if "autoquant-int4" == quantization: - model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs) + model = autoquant( + model, + manual=True, + qtensor_class_list=torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, + example_input=inputs, + ) elif "autoquant-float8" == quantization: - model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs) + model 
= autoquant( + model, + manual=True, + qtensor_class_list=torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST, + example_input=inputs, + ) if "autoquant-fp" == quantization: - model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, example_input=inputs) + model = autoquant( + model, + manual=True, + qtensor_class_list=torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, + example_input=inputs, + ) if "autoquant-all" == quantization: - all_qtensor_classes = torchao.quantization.DEFAULT_AUTOQUANT_CLASS_LIST + torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST + torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST + all_qtensor_classes = ( + torchao.quantization.DEFAULT_AUTOQUANT_CLASS_LIST + + torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST + + torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST + ) if torchao.utils.is_sm_89(): # this is fp8 related subclasses, should rename - all_qtensor_classes += torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST - model = autoquant(model, manual=True, qtensor_class_list = all_qtensor_classes, example_input=inputs) + all_qtensor_classes += ( + torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST + ) + model = autoquant( + model, + manual=True, + qtensor_class_list=all_qtensor_classes, + example_input=inputs, + ) else: model = autoquant(model, manual=True, example_input=inputs) @@ -498,8 +725,9 @@ def ffn_or_attn_only(mod, fqn): # standalone sparsity elif sparsity: from torchao.sparsity import semi_sparse_weight, sparsify_ + if "semi" in sparsity: - #TODO there is a bug here, need to fix + # TODO there is a bug here, need to fix sparsify_(model.to(device), semi_sparse_weight(), filter_fn=ffn_only) model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9 @@ -507,39 +735,48 @@ def ffn_or_attn_only(mod, fqn): if save: output_dir = str(checkpoint_path.cwd()) filename = str(checkpoint_path.name).split(".")[0] - torch.save(model.state_dict(), os.path.join(output_dir, filename + f"-{quantization}.pt")) + torch.save( + model.state_dict(), + os.path.join(output_dir, filename + f"-{quantization}.pt"), + ) if compile: print("Compiling Model") global decode_one_token, prefill - decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) + decode_one_token = torch.compile( + decode_one_token, mode="reduce-overhead", fullgraph=True + ) if compile_prefill: prefill = torch.compile(prefill, fullgraph=True, dynamic=True) if memory_profile: if device == "cuda": - torch.cuda.memory._record_memory_history(True,trace_alloc_max_entries=250000, trace_alloc_record_context=True) + torch.cuda.memory._record_memory_history( + True, trace_alloc_max_entries=250000, trace_alloc_record_context=True + ) elif device == "xpu": - torch.xpu.memory._record_memory_history(True,trace_alloc_max_entries=250000, trace_alloc_record_context=True) + torch.xpu.memory._record_memory_history( + True, trace_alloc_max_entries=250000, trace_alloc_record_context=True + ) else: print("Memory profiling only works on CUDA or XPU devices") - + aggregate_metrics = { - 'tokens_per_sec': [], - 'time': [], - 'decode_tokens_per_sec': [], - 'prefill_time': [], + "tokens_per_sec": [], + "time": [], + "decode_tokens_per_sec": [], + "prefill_time": [], } start = -1 if compile else 0 for i in range(start, num_samples): - if i==0: + if i == 0: if device == "cuda": - torch.cuda.reset_peak_memory_stats() # MKG + torch.cuda.reset_peak_memory_stats() # MKG elif device == "xpu": - 
torch.xpu.reset_peak_memory_stats() # MKG - device_sync(device=device) # MKG + torch.xpu.reset_peak_memory_stats() # MKG + device_sync(device=device) # MKG if i >= 0 and interactive: prompt = input("What is your prompt? ") if is_chat: @@ -548,8 +785,9 @@ def ffn_or_attn_only(mod, fqn): if interactive and i >= 0: buffer = [] - period_id = tokenizer.encode('.')[0] + period_id = tokenizer.encode(".")[0] done_generating = False + def callback(x): nonlocal done_generating if done_generating: @@ -558,16 +796,22 @@ def callback(x): if x.item() == tokenizer.eos_id(): done_generating = True if len(buffer) == 4 or done_generating: - print(''.join(buffer), end='', flush=True) + print("".join(buffer), end="", flush=True) buffer.clear() # print(, end='', flush=True) + else: - callback = lambda x : x + callback = lambda x: x t0 = time.perf_counter() - prefill_start_event, prefill_end_event = device_timer(device), device_timer(device) - decode_start_event, decode_end_event = device_timer(device), device_timer(device) + prefill_start_event, prefill_end_event = device_timer(device), device_timer( + device + ) + decode_start_event, decode_end_event = device_timer(device), device_timer( + device + ) import contextlib - if (i != num_samples - 1 or not profile): + + if i != num_samples - 1 or not profile: prof = contextlib.nullcontext() else: torch.profiler._utils._init_for_cuda_graphs() @@ -595,60 +839,69 @@ def callback(x): continue if hasattr(prof, "export_chrome_trace"): prof.export_chrome_trace(f"{profile}.json") - device_sync(device=device) # MKG + device_sync(device=device) # MKG t = time.perf_counter() - t0 if not interactive and prefill_size is None: - tok_list = y[0].tolist() - # truncate text after end of string token - tokens = tok_list if not tokenizer.eos_id() in tok_list else tok_list[:tok_list.index(tokenizer.eos_id())] - print(tokenizer.decode(tokens)) + tok_list = y[0].tolist() + # truncate text after end of string token + tokens = ( + tok_list + if tokenizer.eos_id() not in tok_list + else tok_list[: tok_list.index(tokenizer.eos_id())] + ) + print(tokenizer.decode(tokens)) else: print() - tokens_generated = (y.size(-1) - prompt_length) + tokens_generated = y.size(-1) - prompt_length tokens_sec = tokens_generated / t - aggregate_metrics['tokens_per_sec'].append(tokens_sec) - aggregate_metrics['time'].append(t) + aggregate_metrics["tokens_per_sec"].append(tokens_sec) + aggregate_metrics["time"].append(t) decode_time = decode_start_event.elapsed_time(decode_end_event) / 1000 decode_tokens_sec = tokens_generated / decode_time - aggregate_metrics['decode_tokens_per_sec'].append(decode_tokens_sec) + aggregate_metrics["decode_tokens_per_sec"].append(decode_tokens_sec) prefill_time = prefill_start_event.elapsed_time(prefill_end_event) / 1000 - aggregate_metrics['prefill_time'].append(prefill_time) - print(f"Sample {i+1} | overall time {t:.04f} s {tokens_sec:.02f} tokens/sec", - f"| prefill time {prefill_time:.04f} s decode {decode_tokens_sec:.02f} tokens/sec") + aggregate_metrics["prefill_time"].append(prefill_time) + print( + f"Sample {i+1} | overall time {t:.04f} s {tokens_sec:.02f} tokens/sec", + f"| prefill time {prefill_time:.04f} s decode {decode_tokens_sec:.02f} tokens/sec", + ) print(f"Bandwidth achieved: {model_size * tokens_sec:.02f} GB/s") - if memory_profile and i==0: + if memory_profile and i == 0: if device == "cuda": snapshot = torch.cuda.memory._snapshot() elif device == "xpu": snapshot = torch.xpu.memory._snapshot() else: print("Memory profiling only works on CUDA or XPU devices") - 
- with open(f"{memory_profile}.pickle", 'wb') as f: + + with open(f"{memory_profile}.pickle", "wb") as f: from pickle import dump + dump(snapshot, f) print( f"\nmemory profile {memory_profile}.pickle saved, to convert that to a usable file, use", - "python pytorch/torch/cuda/_memory_viz.py trace_plot -o .html" + "python pytorch/torch/cuda/_memory_viz.py trace_plot -o .html", ) break print("==========") - #ignore first sample for warmup - tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item() - ttft = torch.mean(torch.tensor(aggregate_metrics['prefill_time'])).item() - decode_tokpersec = torch.mean(torch.tensor(aggregate_metrics['decode_tokens_per_sec'])).item() + # ignore first sample for warmup + tokpersec = torch.mean(torch.tensor(aggregate_metrics["tokens_per_sec"])).item() + ttft = torch.mean(torch.tensor(aggregate_metrics["prefill_time"])).item() + decode_tokpersec = torch.mean( + torch.tensor(aggregate_metrics["decode_tokens_per_sec"]) + ).item() bandwidth = model_size * tokpersec - mem = torch.cuda.max_memory_reserved() /1e9 + mem = torch.cuda.max_memory_reserved() / 1e9 print(f"Average overall tokens/sec: {tokpersec:.2f}") print(f"Average decode tokens/sec: {decode_tokens_sec:.04f} s") print(f"Average TTFT: {ttft:.04f} s") - if device == "cuda": - mem = torch.cuda.max_memory_reserved() /1e9 + if device == "cuda": + mem = torch.cuda.max_memory_reserved() / 1e9 elif device == "xpu": - mem = torch.xpu.max_memory_reserved() /1e9 + mem = torch.xpu.max_memory_reserved() / 1e9 print(f"Average tokens/sec: {tokpersec:.2f}") if batch_size > 1: print(f"Average tokens/sec including batches {batch_size*tokpersec:.2f}") @@ -658,71 +911,163 @@ def callback(x): if write_result: result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, tok/s_decode={decode_tokpersec:6.2f}, ttft={ttft:5.4f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB " result_txt += f"quant: {quantization}, sparse: {sparsity}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} " - result_txt += f"repro: python generate.py " + result_txt += "repro: python generate.py " result_txt += f"--quantization {quantization} " if quantization else "" result_txt += f"--sparsity {sparsity} " if sparsity else "" result_txt += f"--checkpoint_path {checkpoint_path} " result_txt += f"--device {device} " result_txt += f"--precision {precision} " - result_txt += f"--compile " if compile else "" - result_txt += f"--compile_prefill " if compile_prefill else "" + result_txt += "--compile " if compile else "" + result_txt += "--compile_prefill " if compile_prefill else "" result_txt += f"--prefill_size {prefill_size}" if prefill_size else "" result_txt += f"--profile {profile} " if profile else "" result_txt += f"--profile {memory_profile} " if memory_profile else "" - result_txt += f"--interactive " if interactive else "" + result_txt += "--interactive " if interactive else "" result_txt += f"--num_samples {num_samples} " result_txt += f"--max_new_tokens {max_new_tokens} " result_txt += f"--batch_size {batch_size} " result_txt += f"--top_k {top_k} " result_txt += f"--temperature {temperature} " result_txt += f"--cache_size {cache_size}" if cache_size else "" - result_txt += f"--kv_cache_quantization " if kv_cache_quantization else "" - result_txt += f"--linear_causal_mask " if linear_causal_mask else "" + result_txt += "--kv_cache_quantization " if 
kv_cache_quantization else "" + result_txt += "--linear_causal_mask " if linear_causal_mask else "" - f=open(write_result, "a") + f = open(write_result, "a") f.write(result_txt) f.close() + headers = ["name", "dtype", "device", "arch", "metric", "actual", "target"] + name = checkpoint_path.parent.name + arch = get_arch_name() + dtype = quantization or str(precision) + memory_result = [name, dtype, device, arch, "mem/s", bandwidth, None] + performance_result = [name, dtype, device, arch, "tok/s", tokpersec, None] + if output_json_path: + write_json_result(output_json_path, headers, memory_result) + write_json_result(output_json_path, headers, performance_result) -if __name__ == '__main__': +if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser(description='Your CLI description.') - parser.add_argument('--prefill_size', type=int, default=0, help='Whether to run in ttft mode') - parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.') - parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode') - parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.') - parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.') - parser.add_argument('--batch_size', type=int, default=1, help='Batch size to benchmark with') - parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.') - parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') - parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') - parser.add_argument('-q', '--quantization', type=str, - help=( - 'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-, int4wo--hqq, autoquant, ' - +'autoquant-int4, autoquant-float8, uintx--, uintx---hqq, sparse-marlin, spinquant, ' - +'embed-int8wo, marlin_qqq' - ) + + parser = argparse.ArgumentParser(description="Your CLI description.") + parser.add_argument( + "--prefill_size", type=int, default=0, help="Whether to run in ttft mode" + ) + parser.add_argument( + "--prompt", type=str, default="Hello, my name is", help="Input prompt." ) - parser.add_argument('-s', '--sparsity', type=str, + parser.add_argument( + "--interactive", + action="store_true", + help="Whether to launch in interactive mode", + ) + parser.add_argument("--num_samples", type=int, default=5, help="Number of samples.") + parser.add_argument( + "--max_new_tokens", type=int, default=200, help="Maximum number of new tokens." + ) + parser.add_argument( + "--batch_size", type=int, default=1, help="Batch size to benchmark with" + ) + parser.add_argument("--top_k", type=int, default=200, help="Top-k for sampling.") + parser.add_argument( + "--temperature", type=float, default=0.8, help="Temperature for sampling." 
+ ) + parser.add_argument( + "--checkpoint_path", + type=Path, + default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), + help="Model checkpoint path.", + ) + parser.add_argument( + "-q", + "--quantization", + type=str, help=( - 'Which sparsity techniques to apply: semi-structured' - ) + "Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-, int4wo--hqq, autoquant, " + + "autoquant-int4, autoquant-float8, uintx--, uintx---hqq, sparse-marlin, spinquant, " + + "embed-int8wo, marlin_qqq" + ), + ) + parser.add_argument( + "-s", + "--sparsity", + type=str, + help=("Which sparsity techniques to apply: semi-structured"), + ) + parser.add_argument( + "--kv_cache_quantization", + action="store_true", + help="Whether to quantize the KV cache", + ) + parser.add_argument( + "--cache_size", + type=int, + default=None, + help="Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size", + ) + parser.add_argument( + "--linear_causal_mask", + action="store_true", + help="Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)", + ) + parser.add_argument( + "--save", action="store_true", help="Whether to save the quantized model." + ) + parser.add_argument( + "--compile", action="store_true", help="Whether to compile the model." + ) + parser.add_argument( + "--compile_prefill", + action="store_true", + help="Whether to compile the prefill (improves prefill perf, but higher compile times)", + ) + parser.add_argument("--profile", type=Path, default=None, help="Profile path.") + parser.add_argument( + "--memory_profile", type=Path, default=None, help="filename for memory profile." + ) + parser.add_argument( + "--device", type=str, default=default_device, help="Device to use" + ) + parser.add_argument( + "--precision", + type=lambda x: getattr(torch, x.split(".")[-1]), + default=torch.bfloat16, + help="dtype precision to use", + ) + parser.add_argument( + "--write_result", type=Path, default=None, help="Path where to write the result" + ) + parser.add_argument( + "--output_json_path", + type=Path, + default=None, + help="Path where to write the json result for dashboard", ) - parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache') - parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size') - parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)') - parser.add_argument('--save', action='store_true', help='Whether to save the quantized model.') - parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') - parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)') - parser.add_argument('--profile', type=Path, default=None, help='Profile path.') - parser.add_argument('--memory_profile', type=Path, default=None, help='filename for memory profile.') - parser.add_argument('--device', type=str, default=default_device, help='Device to use') - parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use') - parser.add_argument('--write_result', type=Path, default=None, help='Path where to 
write the result') args = parser.parse_args() main( - args.prefill_size, args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k, - args.temperature, args.checkpoint_path, args.quantization, args.sparsity, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result + args.prefill_size, + args.prompt, + args.interactive, + args.num_samples, + args.max_new_tokens, + args.batch_size, + args.top_k, + args.temperature, + args.checkpoint_path, + args.quantization, + args.sparsity, + args.kv_cache_quantization, + args.cache_size, + args.linear_causal_mask, + args.save, + args.compile, + args.compile_prefill, + args.profile, + args.memory_profile, + args.device, + args.precision, + args.write_result, + args.output_json_path, )
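
For reference, a minimal usage sketch of the new write_json_result helper, mirroring how main() calls it when --output_json_path is set; the model name, dtype, and arch values below are illustrative placeholders, not results from a real run.

# Minimal sketch: emit one dashboard record per metric, the way main() does
# when --output_json_path is passed. All values here are placeholders.
from torchao._models.llama.generate import write_json_result

headers = ["name", "dtype", "device", "arch", "metric", "actual", "target"]
# main() writes two rows per run: memory bandwidth ("mem/s") and throughput ("tok/s").
row = ["Llama-2-7b-chat-hf", "int8wo", "cuda", "NVIDIA A100-SXM4-40GB", "tok/s", 123.4, None]
write_json_result("benchmark-results.json", headers, row)
# Each call appends one JSON line shaped like:
# {"benchmark": {"name": "TorchAO benchmark", "mode": "inference", "dtype": "int8wo",
#                "extra_info": {"device": "cuda", "arch": "NVIDIA A100-SXM4-40GB"}},
#  "model": {"name": "Llama-2-7b-chat-hf", "type": "model", "origins": ["pytorch"]},
#  "metric": {"name": "tok/s", "benchmark_values": [123.4], "target_value": null}}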