diff --git a/scripts/prepare.sh b/scripts/prepare.sh
index 9cbc8295e..fe037de29 100644
--- a/scripts/prepare.sh
+++ b/scripts/prepare.sh
@@ -1,11 +1,13 @@
 python scripts/download.py --repo_id meta-llama/Llama-2-7b-chat-hf
 python scripts/download.py --repo_id meta-llama/Meta-Llama-3-8B
 python scripts/download.py --repo_id meta-llama/Meta-Llama-3.1-8B
+python scripts/download.py --repo_id meta-llama/Meta-Llama-3.1-8B-Instruct
 python scripts/download.py --repo_id meta-llama/Llama-3.2-3B
 python scripts/download.py --repo_id nm-testing/SparseLlama-3-8B-pruned_50.2of4
 python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-2-7b-chat-hf
 python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3-8B
 python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3.1-8B
+python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3.1-8B-Instruct
 python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-3.2-3B
 # neuralmagic doesn't come with tokenizer, so we need to copy it over
 mkdir -p checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/original && cp checkpoints/meta-llama/Meta-Llama-3-8B/original/tokenizer.model checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/original/tokenizer.model
diff --git a/torchao/_models/llama/.gitignore b/torchao/_models/llama/.gitignore
new file mode 100644
index 000000000..819b89eab
--- /dev/null
+++ b/torchao/_models/llama/.gitignore
@@ -0,0 +1 @@
+moby.txt
diff --git a/torchao/_models/llama/demo_summarize.sh b/torchao/_models/llama/demo_summarize.sh
new file mode 100644
index 000000000..5aeb91ee6
--- /dev/null
+++ b/torchao/_models/llama/demo_summarize.sh
@@ -0,0 +1,8 @@
+# grab moby dick prompt
+wget -nc -O moby.txt https://gist.githubusercontent.com/jcaip/f319146bb543e92e23b2c76815b0f29f/raw/31a9cd12b0b59f323eb197c9534953bdac352986/gistfile1.txt
+
+export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B-Instruct
+
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq_prefill_wo_decode --prefill_size 8192 --max_new_tokens 256 --num_samples 1 --demo_summarize_prompt moby.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8wo --prefill_size 8192 --max_new_tokens 256 --num_samples 1 --demo_summarize_prompt moby.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq --prefill_size 8192 --max_new_tokens 256 --num_samples 1 --demo_summarize_prompt moby.txt
diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py
index 56950c0e0..a084f9d12 100644
--- a/torchao/_models/llama/eval.py
+++ b/torchao/_models/llama/eval.py
@@ -34,6 +34,7 @@ def run_evaluation(
     device = "cuda",
     precision = torch.bfloat16,
     quantization: Optional[str] = None,
+    sparsity: Optional[str] = None,
     compile=False,
     max_length=None,
     calibration_tasks: Optional[List[str]] = None,
@@ -44,7 +45,7 @@
     """Runs the evaluation of a model using LM Eval."""
     print(
         f"\nEvaluating model {checkpoint_path} on tasks: {tasks}, limit: {limit}, device: {device}, precision: {precision}, "
-        +f"quantization: {quantization}, compile: {compile}, max_length: {max_length}, calibration_tasks: {calibration_tasks}, "
+        +f"quantization: {quantization}, sparsity: {sparsity}, compile: {compile}, max_length: {max_length}, calibration_tasks: {calibration_tasks}, "
         +f"calibration_seq_length: {calibration_seq_length}, pad_calibration_inputs: {pad_calibration_inputs}\n"
     )
     torchao.quantization.utils.recommended_inductor_config_setter()
@@ -236,6 +237,13 @@ def run_evaluation(
             "float8wo, float8dq, float8saq"
         ),
     )
+    parser.add_argument(
+        "--sparsity",
+        type=str,
+        help=(
+            "Which sparsity techniques to apply: semi-structured"
+        ),
+    )
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
     parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
     parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
@@ -251,6 +259,7 @@ def run_evaluation(
         args.device,
         args.precision,
         args.quantization,
+        args.sparsity,
         args.compile,
         args.max_length,
         args.calibration_tasks,
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index 6e2e4f713..81b981f29 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -25,7 +25,7 @@
 )
 
 torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False
-
+torch.backends.cuda.enable_cudnn_sdp(True)
 
 class HostEvent:
     def __init__(self):
@@ -256,6 +257,7 @@ def _load_model(checkpoint_path, device, precision):
 def main(
     prefill_size: Optional[int] = None,
     prompt: str = "Hello, my name is",
+    demo_summarize_prompt: Optional[str] = None,
     interactive: bool = False,
     num_samples: int = 5,
     max_new_tokens: int = 100,
@@ -285,7 +286,11 @@ def main(
 
     if prefill_size is not None and prefill_size > 0:
         # create prompt of prefill size
-        prompt = "prompt " * (int(prefill_size) - 3)
+        if demo_summarize_prompt is None:
+            prompt = "prompt " * (int(prefill_size) - 2)
+        else:
+            with open(demo_summarize_prompt, "r") as f:
+                prompt = f.read()
 
     torchao.quantization.utils.recommended_inductor_config_setter()
 
@@ -306,6 +311,12 @@ def main(
     tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
 
     encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
+
+    if demo_summarize_prompt is not None:
+        end_tag = encode_tokens(tokenizer, "\n ", bos=False, device=device)
+        encoded = encoded[:prefill_size - end_tag.size(0)]
+        encoded = torch.cat((encoded, end_tag), dim=0)
+
     prompt_length = encoded.size(0)
 
     torch.manual_seed(1234)
@@ -390,6 +401,8 @@ def ffn_or_attn_only(mod, fqn):
                 quantize_(
                     model, int8_dynamic_activation_int8_weight(), filter_fn=not_ffn_only
                 )
+            elif "int8dq_prefill_wo_decode" in quantization:
+                quantize_(model, int8_dynamic_activation_int8_weight(weight_only_decode=True))
             else:
                 quantize_(model, int8_dynamic_activation_int8_weight())
         if "int4wo" in quantization:
@@ -809,14 +822,23 @@ def callback(x):
                 nonlocal done_generating
                 if done_generating:
                     return
-                buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
+                buffer.append(tokenizer.decode([period_id] + x.squeeze(0).tolist())[1:])
                 if x.item() == tokenizer.eos_id():
                     done_generating = True
                 if len(buffer) == 4 or done_generating:
                     print("".join(buffer), end="", flush=True)
                     buffer.clear()
-            # print(, end='', flush=True)
+            # print(, end="", flush=True)
+
+        elif demo_summarize_prompt is not None and i >= 0:
+            buffer = []
+            period_id = tokenizer.encode(".")[0]
+            def callback(x):
+                buffer.append(tokenizer.decode([period_id] + x.squeeze(0).tolist())[1:])
+                if len(buffer) == 4:
+                    print("".join(buffer), end="", flush=True)
+                    buffer.clear()
 
         else:
             callback = lambda x: x
         t0 = time.perf_counter()
@@ -851,7 +873,7 @@ def callback(x):
             decode_start_event=decode_start_event,
            decode_end_event=decode_end_event,
         )
-        if i == -1:
+        if i < 0:
             print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
             continue
         if hasattr(prof, "export_chrome_trace"):
@@ -859,7 +881,7 @@ def callback(x):
         device_sync(device=device)  # MKG
         t = time.perf_counter() - t0
 
-        if not interactive and prefill_size is None:
+        if not interactive and demo_summarize_prompt is None:
             tok_list = y[0].tolist()
             # truncate text after end of string token
             tokens = (
@@ -869,7 +891,7 @@ def callback(x):
             )
             print(tokenizer.decode(tokens))
         else:
-            print()
+            print("\n")
         tokens_generated = y.size(-1) - prompt_length
         tokens_sec = tokens_generated / t
         aggregate_metrics["tokens_per_sec"].append(tokens_sec)
@@ -913,7 +935,7 @@ def callback(x):
     bandwidth = model_size * tokpersec
     mem = torch.cuda.max_memory_reserved() / 1e9
     print(f"Average overall tokens/sec: {tokpersec:.2f}")
-    print(f"Average decode tokens/sec: {decode_tokens_sec:.04f} s")
+    print(f"Average decode tokens/sec: {decode_tokpersec:.04f} s")
     print(f"Average TTFT: {ttft:.04f} s")
     if device == "cuda":
         mem = torch.cuda.max_memory_reserved() / 1e9
@@ -975,6 +997,9 @@ def callback(x):
     parser.add_argument(
         "--prompt", type=str, default="Hello, my name is", help="Input prompt."
     )
+    parser.add_argument(
+        "--demo_summarize_prompt", type=str, help="Read prompt from text file"
+    )
     parser.add_argument(
         "--interactive",
         action="store_true",
@@ -1073,6 +1098,7 @@ def callback(x):
     main(
         args.prefill_size,
         args.prompt,
+        args.demo_summarize_prompt,
         args.interactive,
         args.num_samples,
         args.max_new_tokens,
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index af950cb79..c7c0de04d 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -803,8 +803,33 @@ def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
     )
 
 
+def _int8_symm_per_token_reduced_range_quant_noop_decode(
+    x: torch.Tensor,
+) -> torch.Tensor:
+    mapping_type = MappingType.SYMMETRIC
+    target_dtype = torch.int8
+    eps = 1e-5
+    quant_min = -127
+    quant_max = 127
+    if x.shape[1] == 1:
+        return x
+    else:
+        return to_affine_quantized_intx(
+            x,
+            mapping_type,
+            _get_per_token_block_size(x),
+            target_dtype,
+            eps=eps,
+            quant_min=quant_min,
+            quant_max=quant_max,
+            scale_dtype=torch.float32 if x.dtype == torch.float16 else None,
+        )
+
+
 def int8_dynamic_activation_int8_weight(
-    layout=PlainLayout(), act_mapping_type=MappingType.SYMMETRIC
+    layout=PlainLayout(),
+    act_mapping_type=MappingType.SYMMETRIC,
+    weight_only_decode=False,
 ):
     """
     Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
@@ -831,11 +856,14 @@ def get_weight_block_size(x):
         eps = torch.finfo(torch.float32).eps
         zero_point_dtype = torch.int64
 
-        # input settings
-        if act_mapping_type == MappingType.SYMMETRIC:
-            input_quant_func = _int8_symm_per_token_reduced_range_quant
+        if weight_only_decode:
+            input_quant_func = _int8_symm_per_token_reduced_range_quant_noop_decode
         else:
-            input_quant_func = _int8_asymm_per_token_quant
+            # input settings
+            if act_mapping_type == MappingType.SYMMETRIC:
+                input_quant_func = _int8_symm_per_token_reduced_range_quant
+            else:
+                input_quant_func = _int8_asymm_per_token_quant
 
         block_size = get_weight_block_size(weight)
         weight = to_affine_quantized_intx(
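
For context on what the new `int8dq_prefill_wo_decode` option does end to end, here is a minimal usage sketch. It is not part of the diff; it assumes the patch above is applied, torchao is installed, and a CUDA device is available, and the toy `nn.Linear` model and tensor shapes are illustrative only. With `weight_only_decode=True`, the new `_int8_symm_per_token_reduced_range_quant_noop_decode` input-quant function dynamically quantizes activations for prefill-shaped inputs (sequence length > 1) but returns decode-shaped inputs (sequence length == 1) unchanged, so single-token decode runs as an int8 weight-only matmul.

```python
# Not part of the patch: a minimal sketch of the new weight_only_decode flag,
# assuming this diff is applied and a CUDA device is available.
import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight

# Toy stand-in for a transformer linear layer; sizes are illustrative only.
model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).to(torch.bfloat16).cuda()
quantize_(model, int8_dynamic_activation_int8_weight(weight_only_decode=True))

with torch.no_grad():
    # Prefill-shaped input (seq_len > 1): activations are dynamically quantized to int8.
    prefill = torch.randn(1, 8192, 4096, dtype=torch.bfloat16, device="cuda")
    model(prefill)
    # Decode-shaped input (seq_len == 1): activations pass through unquantized,
    # so the matmul falls back to the int8 weight-only path.
    decode = torch.randn(1, 1, 4096, dtype=torch.bfloat16, device="cuda")
    model(decode)
```

This is the same configuration that `generate.py` selects for `--quantization int8dq_prefill_wo_decode` in `demo_summarize.sh`.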