int8 dynamic prefill weight only decode (#1436)
This PR adds a weight_only_decode option to int8_dynamic_activation_int8_weight. When set, it uses dynamic activation quantization for matmuls of shape (> 1, x) * (x, n) and weight-only quantization for the batch_size=1 (decode) case.

It also updates generate.py to accept a text file as the prompt; we use this to demonstrate the prefill speedups with sh demo_summarize.sh.
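
Below is a minimal usage sketch (not part of this commit) of the new flag, assuming torchao with this change installed and a CUDA GPU; the toy linear layer and tensor sizes are hypothetical stand-ins for the Llama model used in generate.py.

import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight

# Hypothetical stand-in for the model's linear layers.
model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).to(torch.bfloat16).cuda()

# With weight_only_decode=True, activations carrying more than one token
# (prefill) are dynamically quantized to int8, while a single-token (decode)
# activation is passed through unquantized, so that matmul runs as int8
# weight-only.
quantize_(model, int8_dynamic_activation_int8_weight(weight_only_decode=True))

prefill_act = torch.randn(1, 8192, 4096, dtype=torch.bfloat16, device="cuda")  # dynamic path
decode_act = torch.randn(1, 1, 4096, dtype=torch.bfloat16, device="cuda")      # weight-only path
with torch.no_grad():
    model(prefill_act)
    model(decode_act)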
jcaip authored Dec 30, 2024
1 parent 52a5137 commit 52b6f4d
Showing 6 changed files with 88 additions and 14 deletions.
2 changes: 2 additions & 0 deletions scripts/prepare.sh
@@ -1,11 +1,13 @@
python scripts/download.py --repo_id meta-llama/Llama-2-7b-chat-hf
python scripts/download.py --repo_id meta-llama/Meta-Llama-3-8B
python scripts/download.py --repo_id meta-llama/Meta-Llama-3.1-8B
python scripts/download.py --repo_id meta-llama/Meta-Llama-3.1-8B-Instruct
python scripts/download.py --repo_id meta-llama/Llama-3.2-3B
python scripts/download.py --repo_id nm-testing/SparseLlama-3-8B-pruned_50.2of4
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-2-7b-chat-hf
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3-8B
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3.1-8B
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3.1-8B-Instruct
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-3.2-3B
# neuralmagic doesn't come with tokenizer, so we need to copy it over
mkdir -p checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/original && cp checkpoints/meta-llama/Meta-Llama-3-8B/original/tokenizer.model checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/original/tokenizer.model
1 change: 1 addition & 0 deletions torchao/_models/llama/.gitignore
@@ -0,0 +1 @@
moby.txt
8 changes: 8 additions & 0 deletions torchao/_models/llama/demo_summarize.sh
@@ -0,0 +1,8 @@
# grab moby dick prompt
wget -nc -O moby.txt https://gist.githubusercontent.com/jcaip/f319146bb543e92e23b2c76815b0f29f/raw/31a9cd12b0b59f323eb197c9534953bdac352986/gistfile1.txt

export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B-Instruct

python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq_prefill_wo_decode --prefill_size 8192 --max_new_tokens 256 --num_samples 1 --demo_summarize_prompt moby.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8wo --prefill_size 8192 --max_new_tokens 256 --num_samples 1 --demo_summarize_prompt moby.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq --prefill_size 8192 --max_new_tokens 256 --num_samples 1 --demo_summarize_prompt moby.txt
11 changes: 10 additions & 1 deletion torchao/_models/llama/eval.py
@@ -34,6 +34,7 @@ def run_evaluation(
device = "cuda",
precision = torch.bfloat16,
quantization: Optional[str] = None,
sparsity:Optional[str] = None,
compile=False,
max_length=None,
calibration_tasks: Optional[List[str]] = None,
@@ -44,7 +45,7 @@
"""Runs the evaluation of a model using LM Eval."""
print(
f"\nEvaluating model {checkpoint_path} on tasks: {tasks}, limit: {limit}, device: {device}, precision: {precision}, "
+f"quantization: {quantization}, compile: {compile}, max_length: {max_length}, calibration_tasks: {calibration_tasks}, "
+f"quantization: {quantization}, sparsity: {sparsity}, compile: {compile}, max_length: {max_length}, calibration_tasks: {calibration_tasks}, "
+f"calibration_seq_length: {calibration_seq_length}, pad_calibration_inputs: {pad_calibration_inputs}\n"
)
torchao.quantization.utils.recommended_inductor_config_setter()
@@ -236,6 +237,13 @@ def run_evaluation(
"float8wo, float8dq, float8saq"
),
)
parser.add_argument(
"--sparsity",
type=str,
help=(
"Which sparsity techniques to apply: semi-structured"
),
)
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
@@ -251,6 +259,7 @@ def run_evaluation(
args.device,
args.precision,
args.quantization,
args.sparsity,
args.compile,
args.max_length,
args.calibration_tasks,
42 changes: 34 additions & 8 deletions torchao/_models/llama/generate.py
@@ -25,7 +25,7 @@
)

torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False

torch.backends.cuda.enable_cudnn_sdp(True)

class HostEvent:
def __init__(self):
@@ -256,6 +257,7 @@ def _load_model(checkpoint_path, device, precision):
def main(
prefill_size: Optional[int] = None,
prompt: str = "Hello, my name is",
demo_summarize_prompt: Optional[str] = None,
interactive: bool = False,
num_samples: int = 5,
max_new_tokens: int = 100,
@@ -285,7 +286,11 @@ def main(

if prefill_size is not None and prefill_size > 0:
# create prompt of prefill size
prompt = "prompt " * (int(prefill_size) - 3)
if demo_summarize_prompt is None:
prompt = "prompt " * (int(prefill_size) - 2)
else:
with open(demo_summarize_prompt, "r") as f:
prompt = f.read()

torchao.quantization.utils.recommended_inductor_config_setter()

@@ -306,6 +311,12 @@ def main(
tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)

if demo_summarize_prompt is not None:
end_tag = encode_tokens(tokenizer, "\n <END_TEXT>", bos=False, device=device)
encoded = encoded[:prefill_size-end_tag.size(0)]
encoded = torch.cat((encoded, end_tag), dim=0)

prompt_length = encoded.size(0)

torch.manual_seed(1234)
@@ -390,6 +401,8 @@ def ffn_or_attn_only(mod, fqn):
quantize_(
model, int8_dynamic_activation_int8_weight(), filter_fn=not_ffn_only
)
elif "int8dq_prefill_wo_decode" in quantization:
quantize_(model, int8_dynamic_activation_int8_weight(weight_only_decode=True))
else:
quantize_(model, int8_dynamic_activation_int8_weight())
if "int4wo" in quantization:
@@ -809,14 +822,23 @@ def callback(x):
nonlocal done_generating
if done_generating:
return
buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
buffer.append(tokenizer.decode([period_id] + x.squeeze(0).tolist())[1:])
if x.item() == tokenizer.eos_id():
done_generating = True
if len(buffer) == 4 or done_generating:
print("".join(buffer), end="", flush=True)
buffer.clear()
# print(, end='', flush=True)
# print(, end="", flush=True)

elif demo_summarize_prompt is not None and i >= 0:
buffer = []
period_id = tokenizer.encode(".")[0]

def callback(x):
buffer.append(tokenizer.decode([period_id] + x.squeeze(0).tolist())[1:])
if len(buffer) == 4:
print("".join(buffer), end="", flush=True)
buffer.clear()
else:
callback = lambda x: x
t0 = time.perf_counter()
@@ -851,15 +873,15 @@ def callback(x):
decode_start_event=decode_start_event,
decode_end_event=decode_end_event,
)
if i == -1:
if i < 0:
print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
continue
if hasattr(prof, "export_chrome_trace"):
prof.export_chrome_trace(f"{profile}.json")
device_sync(device=device) # MKG
t = time.perf_counter() - t0

if not interactive and prefill_size is None:
if not interactive and demo_summarize_prompt is None:
tok_list = y[0].tolist()
# truncate text after end of string token
tokens = (
Expand All @@ -869,7 +891,7 @@ def callback(x):
)
print(tokenizer.decode(tokens))
else:
print()
print("\n")
tokens_generated = y.size(-1) - prompt_length
tokens_sec = tokens_generated / t
aggregate_metrics["tokens_per_sec"].append(tokens_sec)
@@ -913,7 +935,7 @@ def callback(x):
bandwidth = model_size * tokpersec
mem = torch.cuda.max_memory_reserved() / 1e9
print(f"Average overall tokens/sec: {tokpersec:.2f}")
print(f"Average decode tokens/sec: {decode_tokens_sec:.04f} s")
print(f"Average decode tokens/sec: {decode_tokpersec:.04f} s")
print(f"Average TTFT: {ttft:.04f} s")
if device == "cuda":
mem = torch.cuda.max_memory_reserved() / 1e9
@@ -975,6 +997,9 @@ def callback(x):
parser.add_argument(
"--prompt", type=str, default="Hello, my name is", help="Input prompt."
)
parser.add_argument(
"--demo_summarize_prompt", type=str, help="Read prompt from text file"
)
parser.add_argument(
"--interactive",
action="store_true",
@@ -1073,6 +1098,7 @@ def callback(x):
main(
args.prefill_size,
args.prompt,
args.demo_summarize_prompt,
args.interactive,
args.num_samples,
args.max_new_tokens,
38 changes: 33 additions & 5 deletions torchao/quantization/quant_api.py
@@ -803,8 +803,33 @@ def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
)


def _int8_symm_per_token_reduced_range_quant_noop_decode(
x: torch.Tensor,
) -> torch.Tensor:
mapping_type = MappingType.SYMMETRIC
target_dtype = torch.int8
eps = 1e-5
quant_min = -127
quant_max = 127
if x.shape[1] == 1:
return x
else:
return to_affine_quantized_intx(
x,
mapping_type,
_get_per_token_block_size(x),
target_dtype,
eps=eps,
quant_min=quant_min,
quant_max=quant_max,
scale_dtype=torch.float32 if x.dtype == torch.float16 else None,
)


def int8_dynamic_activation_int8_weight(
layout=PlainLayout(), act_mapping_type=MappingType.SYMMETRIC
layout=PlainLayout(),
act_mapping_type=MappingType.SYMMETRIC,
weight_only_decode=False,
):
"""
Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
@@ -831,11 +856,14 @@ def get_weight_block_size(x):
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64

# input settings
if act_mapping_type == MappingType.SYMMETRIC:
input_quant_func = _int8_symm_per_token_reduced_range_quant
if weight_only_decode:
input_quant_func = _int8_symm_per_token_reduced_range_quant_noop_decode
else:
input_quant_func = _int8_asymm_per_token_quant
# input settings
if act_mapping_type == MappingType.SYMMETRIC:
input_quant_func = _int8_symm_per_token_reduced_range_quant
else:
input_quant_func = _int8_asymm_per_token_quant

block_size = get_weight_block_size(weight)
weight = to_affine_quantized_intx(
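
For intuition, here is a small self-contained sketch (not part of the commit) of the routing rule that the new noop-decode input quantization function implements; the helper name and tensor shapes are illustrative only.

import torch

def route_activation(x: torch.Tensor) -> str:
    # Mirrors the shape check above: a single-token (decode) activation is
    # returned unquantized, so the linear falls back to the int8 weight-only
    # kernel; multi-token (prefill) activations get per-token dynamic int8
    # quantization.
    return "weight-only decode path" if x.shape[1] == 1 else "dynamic int8 prefill path"

prefill = torch.randn(1, 8192, 4096)  # 8192-token prompt, as in demo_summarize.sh
decode = torch.randn(1, 1, 4096)      # one new token per decode step
print(route_activation(prefill))  # -> dynamic int8 prefill path
print(route_activation(decode))   # -> weight-only decode path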
