-
Notifications
You must be signed in to change notification settings - Fork 191
/
generate.py
1124 lines (1009 loc) · 41.4 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import platform
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Optional, Tuple
import torch
import torch._dynamo.config
import torch._inductor.config
import torchao
from torchao.quantization.quant_primitives import MappingType
from torchao.utils import get_model_size_in_bytes, TORCH_VERSION_AT_LEAST_2_5
from torchao._models.utils import (
get_arch_name,
write_json_result_ossci,
write_json_result_local,
)
# Process-wide backend switches applied at import time.
# Turn off the forced CUTLASS path for 2:4 semi-structured sparse tensors
# (NOTE(review): presumably lets torch pick cuSPARSELt instead — confirm).
torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False
# Allow the cuDNN backend for scaled-dot-product attention.
torch.backends.cuda.enable_cudnn_sdp(True)
class HostEvent:
    """Host-side timing event used where device events are not used (CPU, MPS).

    Mirrors the subset of the ``torch.cuda.Event`` API that the benchmark loop
    uses: ``record()`` stores a timestamp and ``elapsed_time()`` returns the
    difference in milliseconds.
    """

    def __init__(self):
        # Seconds from time.perf_counter(); None until record() is called.
        self.event_time = None

    def record(self):
        """Store the current ``time.perf_counter()`` reading."""
        self.event_time = time.perf_counter()

    def elapsed_time(self, other_event):
        """Return the absolute time between this event and ``other_event`` in ms.

        Raises:
            ValueError: if either event has not been recorded yet (previously an
                unrecorded ``other_event`` surfaced as a TypeError).
        """
        if self.event_time is None or other_event.event_time is None:
            raise ValueError("Event not recorded!")
        # return ms to match cuda event
        return abs(other_event.event_time - self.event_time) * 1000
def device_timer(device):
    """Return a timing event suited to ``device``.

    CUDA and XPU get native device events with timing enabled; CPU and MPS get a
    host-side ``HostEvent``. For any other device a message is printed and
    ``None`` is returned (``generate`` tolerates ``None`` events, but callers
    that call ``elapsed_time`` on the result do not).
    """
    if "cuda" in device:
        return torch.cuda.Event(enable_timing=True)
    elif "xpu" in device:
        # XPU is supported by device_sync/default_device; without this branch
        # the benchmark loop crashed on ``None.elapsed_time``.
        return torch.xpu.Event(enable_timing=True)
    elif ("cpu" in device) or ("mps" in device):
        return HostEvent()
    else:
        print(f"device={device} is not yet suppported")
        return None
def device_sync(device):
    """Block the host until queued work on ``device`` has finished.

    Needed before reading host-side timers so that measured times include the
    device work. CPU needs no synchronization; unknown devices only print a
    message.
    """
    if "cuda" in device:
        torch.cuda.synchronize(device)
    elif "xpu" in device:
        torch.xpu.synchronize(device)
    elif "mps" in device:
        # MPS executes asynchronously; without a sync, host timers (HostEvent /
        # perf_counter) would only measure kernel enqueue time.
        torch.mps.synchronize()
    elif "cpu" in device:
        pass
    else:
        print(f"device={device} is not yet suppported")
default_device = (
"cuda"
if torch.cuda.is_available()
else "xpu"
if torch.xpu.is_available()
else "cpu"
)
# support running without installing as a package:
# append the parent of this file's directory to sys.path so the project
# imports below resolve from a source checkout.
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
from torchao._models.llama.model import prepare_inputs_for_model, Transformer
from torchao._models.llama.tokenizer import get_tokenizer
def multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
    """Draw one index per row of ``probs_sort`` without a host/device sync.

    Dividing the probabilities by i.i.d. Exponential(1) noise and taking the
    argmax selects index ``i`` with probability proportional to ``probs_sort[i]``,
    which avoids ``torch.multinomial``. Returns ``torch.int`` indices with a
    trailing singleton dimension.
    """
    noise = torch.empty_like(probs_sort).exponential_(1)
    scores = probs_sort / noise
    winners = torch.argmax(scores, dim=-1, keepdim=True)
    return winners.to(dtype=torch.int)
def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Turn logits into a probability distribution over the last dimension.

    Logits are divided by ``temperature`` (floored at 1e-5 so a temperature of 0
    does not divide by zero). If ``top_k`` is given, every logit strictly below
    the k-th largest is set to -inf before the softmax; ``top_k`` larger than the
    vocabulary is clamped to its size.
    """
    scaled = logits / max(temperature, 1e-5)

    if top_k is not None:
        k = min(top_k, scaled.size(-1))
        # k-th largest value per row, kept as a trailing dim for broadcasting
        kth_largest = torch.topk(scaled, k).values[..., -1:]
        scaled = scaled.masked_fill(scaled < kth_largest, -float("Inf"))

    return torch.nn.functional.softmax(scaled, dim=-1)
def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Sample one next token per batch row from the logits of the last position.

    Returns:
        ``(idx_next, probs)``: the sampled ids (``torch.int``, trailing singleton
        dim) and the distribution they were drawn from.
    """
    last_position_logits = logits[:, -1]
    probs = logits_to_probs(last_position_logits, temperature, top_k)
    return multinomial_sample_one_no_sync(probs), probs
def prefill(
    model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
) -> torch.Tensor:
    """Run the whole prompt through ``model`` and sample the first new token.

    Only the sampled token ids are returned; the probabilities are discarded.
    """
    # input_pos: [B, S]
    next_token, _ = sample(model(x, input_pos), **sampling_kwargs)
    return next_token
def decode_one_token(
    model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Advance generation by exactly one position.

    ``input_pos`` must have a last dimension of size 1. Returns the
    ``(next_token, probs)`` pair produced by ``sample``.
    """
    # input_pos: [B, 1]
    assert input_pos.shape[-1] == 1
    return sample(model(x, input_pos), **sampling_kwargs)
def decode_n_tokens(
    model: Transformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    callback=lambda _: _,
    **sampling_kwargs,
):
    """Autoregressively generate ``num_new_tokens`` tokens, one step at a time.

    Args:
        model: the Transformer; caches are expected to be set up by the caller.
        cur_token: token fed into the first decode step.
        input_pos: position tensor for ``cur_token``. It is incremented IN PLACE
            by 1 after every step, so the caller's tensor is modified.
        num_new_tokens: number of decode steps to run.
        callback: invoked with each newly generated token tensor.
        **sampling_kwargs: forwarded to ``decode_one_token`` (e.g. temperature, top_k).

    Returns:
        ``(new_tokens, new_probs)``: lists holding one entry per step.
    """
    new_tokens, new_probs = [], []
    for i in range(num_new_tokens):
        # Restrict SDPA to the math backend for each decode step.
        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
            next_token, next_prob = decode_one_token(
                model, cur_token, input_pos, **sampling_kwargs
            )
            # NOTE(review): the clones presumably guard against output-buffer reuse
            # when decode_one_token is compiled with mode="reduce-overhead"
            # (CUDA graphs) in main — confirm before removing any of them.
            next_token, next_prob = next_token.clone(), next_prob.clone()
            input_pos += 1
            # in some instances not having this causes weird issues with the stored tokens when you run the next decode_one_token step
            new_tokens.append(next_token.clone())
            callback(new_tokens[-1])
            new_probs.append(next_prob)
            cur_token = next_token
    return new_tokens, new_probs
def model_forward(model, x, input_pos):
    """Call ``model(x, input_pos)`` and hand back whatever it returns."""
    output = model(x, input_pos)
    return output
@torch.no_grad()
def generate(
    model: Transformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    batch_size: int,
    *,
    interactive: bool,
    callback=lambda x: x,
    kv_cache_quantization: bool = False,
    cache_size: Optional[int] = None,
    linear_causal_mask: bool = False,
    prefill_start_event: Optional[torch.cuda.Event] = None,
    prefill_end_event: Optional[torch.cuda.Event] = None,
    decode_start_event: Optional[torch.cuda.Event] = None,
    decode_end_event: Optional[torch.cuda.Event] = None,
    **sampling_kwargs,
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    Args:
        model: model exposing ``config.block_size`` and ``setup_caches``.
        prompt: 1-D token tensor; it is repeated ``batch_size`` times.
        max_new_tokens: requested number of new tokens; the total length is
            capped at ``model.config.block_size`` (fixed at 350 when ``interactive``).
        batch_size: number of identical prompts generated in parallel.
        interactive: use the fixed 350-token sequence length.
        callback: called with each token produced by the decode loop.
        kv_cache_quantization, linear_causal_mask: forwarded to ``setup_caches``.
        cache_size: KV-cache length; defaults to the sequence length and must be
            at least that large.
        prefill_*_event / decode_*_event: optional timing events recorded
            around the prefill and decode phases.
        **sampling_kwargs: forwarded to ``prefill`` / ``decode_n_tokens``.

    Returns:
        ``(batch_size, T + generated)`` tensor of prompt plus generated tokens.

    Raises:
        ValueError: if the prompt leaves no room for a single new token.
    """
    # create an empty tensor of the expected final shape and fill in the current tokens
    device = prompt.device
    T = prompt.size(-1)

    # calculate how many tokens to generate based on max_new_tokens and model's upper bound (block_size)
    max_seq_length = (
        min(T + max_new_tokens, model.config.block_size) if not interactive else 350
    )
    new_tokens = max_seq_length - T
    # Prefill always writes one token at seq[:, T]; without at least one free
    # slot that write fails later with an opaque IndexError (after cache setup
    # and a full prefill pass have already been paid for).
    if new_tokens < 1:
        raise ValueError(
            f"no room to generate: prompt has {T} tokens but the sequence length "
            f"is capped at {max_seq_length} (max_new_tokens={max_new_tokens}, "
            f"interactive={interactive})"
        )

    # format model input
    prompt, input_pos = prepare_inputs_for_model(prompt)
    prompt = prompt.repeat(batch_size, 1)  # expand prompt based on batchsize

    # full prompt+output will be stored in seq
    seq = torch.empty(batch_size, max_seq_length, dtype=prompt.dtype, device=device)
    seq[:, :T] = prompt

    # setup model caches
    with torch.device(device):
        if cache_size is None:
            cache_size = max_seq_length
        assert (
            cache_size >= max_seq_length
        ), "need cache_size to be greater than max_new_tokens + size-of-prompt"
        model.setup_caches(
            max_batch_size=batch_size,
            max_seq_length=cache_size,
            kv_cache_quantization=kv_cache_quantization,
            linear_causal_mask=linear_causal_mask,
            prompt_length=T,
        )

    # execute prefill
    if prefill_start_event is not None:
        prefill_start_event.record()
    next_token = prefill(
        model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs
    ).clone()
    seq[:, T] = next_token.squeeze()
    if prefill_end_event is not None:
        prefill_end_event.record()

    # execute token generation
    if decode_start_event is not None:
        decode_start_event.record()
    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    generated_tokens, _ = decode_n_tokens(
        model,
        next_token.view(batch_size, -1),
        input_pos,
        new_tokens - 1,  # the first new token came from prefill
        callback=callback,
        **sampling_kwargs,
    )
    seq = torch.cat((seq[:, : T + 1], *generated_tokens), dim=-1)
    if decode_end_event is not None:
        decode_end_event.record()

    return seq
def encode_tokens(tokenizer, string, bos=True, device=default_device):
    """Tokenize ``string`` into a 1-D ``torch.int`` tensor placed on ``device``.

    When ``bos`` is true, ``tokenizer.bos_id()`` is prepended.
    """
    token_ids = tokenizer.encode(string)
    token_ids = [tokenizer.bos_id()] + token_ids if bos else token_ids
    return torch.tensor(token_ids, dtype=torch.int, device=device)
def _load_model(checkpoint_path, device, precision):
    """Load a ``Transformer`` checkpoint and return it in eval mode.

    The architecture is chosen by ``Transformer.from_name`` from the name of the
    checkpoint's parent directory. The model skeleton is built on the meta
    device and the memory-mapped weights are assigned into it, then moved to
    ``device``/``precision``.

    Args:
        checkpoint_path: path to the checkpoint file (``str`` or ``Path``;
            previously only ``Path`` worked because of ``.parent``).
        device: target device.
        precision: target dtype.
    """
    checkpoint_path = Path(checkpoint_path)
    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
    # "stories" checkpoints keep the state dict under a "model" key.
    if "model" in checkpoint and "stories" in str(checkpoint_path):
        checkpoint = checkpoint["model"]
    with torch.device("meta"):
        model = Transformer.from_name(checkpoint_path.parent.name)
    model.load_state_dict(checkpoint, assign=True)
    model = model.to(device=device, dtype=precision)

    return model.eval()
# Instruction delimiters wrapped around interactive prompts when the
# checkpoint path contains "chat" (see ``is_chat`` in ``main``).
B_INST, E_INST = "[INST]", "[/INST]"
def _gemlite_config_path() -> str:
    """Return the path of the cached GemLite kernel config for the current user.

    The file lives in /tmp and its name embeds the user's passwd GECOS field.
    Kept in a helper so ``main`` never does a function-local ``import os``:
    that made ``os`` a local name for the whole of ``main`` and raised
    UnboundLocalError on the ``--save`` and intx-kernel paths.
    """
    # ``pwd`` is POSIX-only, so import it lazily: only gemlite paths need it.
    import pwd

    return f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"


def main(
    prefill_size: Optional[int] = None,
    prompt: str = "Hello, my name is",
    demo_summarize_prompt: Optional[str] = None,
    interactive: bool = False,
    num_samples: int = 5,
    max_new_tokens: int = 100,
    batch_size: int = 1,
    top_k: int = 200,
    temperature: float = 0.8,
    checkpoint_path: Path = Path(
        "checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"
    ),
    quantization: Optional[str] = None,
    sparsity: Optional[str] = None,
    kv_cache_quantization: bool = False,
    cache_size: Optional[int] = None,
    linear_causal_mask: bool = False,
    save: bool = False,
    compile: bool = True,
    compile_prefill: bool = False,
    profile: Optional[Path] = None,
    memory_profile: Optional[Path] = None,
    device=default_device,
    precision=torch.bfloat16,
    write_result: Optional[Path] = None,
    output_json_path: Optional[Path] = None,
    output_json_local: bool = False,
) -> None:
    """Generates text samples based on a pre-trained Transformer model and tokenizer.

    Loads the checkpoint at ``checkpoint_path`` (the tokenizer must be
    ``tokenizer.model`` in the same directory). The model is optionally
    quantized and/or sparsified according to ``quantization`` / ``sparsity``
    and optionally compiled. It then runs ``num_samples`` timed generations and
    prints throughput, TTFT, bandwidth and peak-memory statistics. When
    ``compile`` is set, one extra untimed warm-up iteration runs first.

    Results can also be appended as a text line (including a repro command) to
    ``write_result`` and/or written as JSON to ``output_json_path``.
    """
    # Accept plain strings as well as Path objects.
    checkpoint_path = Path(checkpoint_path)

    if prefill_size is not None and prefill_size > 0:
        # create prompt of prefill size
        if demo_summarize_prompt is None:
            prompt = "prompt " * (int(prefill_size) - 2)
        else:
            with open(demo_summarize_prompt, "r") as f:
                prompt = f.read()

    torchao.quantization.utils.recommended_inductor_config_setter()

    assert checkpoint_path.is_file(), checkpoint_path
    tokenizer_path = checkpoint_path.parent / "tokenizer.model"
    assert tokenizer_path.is_file(), str(tokenizer_path)

    print(f"Using device={device}")
    is_chat = "chat" in str(checkpoint_path)

    print("Loading model ...")
    t0 = time.time()
    model = _load_model(checkpoint_path, device, precision)

    device_sync(device=device)  # MKG
    print(f"Time to load model: {time.time() - t0:.02f} seconds")

    tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

    encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)

    if demo_summarize_prompt is not None:
        # Truncate the document so that document + end tag fit in prefill_size.
        end_tag = encode_tokens(tokenizer, "\n <END_TEXT>", bos=False, device=device)
        encoded = encoded[: prefill_size - end_tag.size(0)]
        encoded = torch.cat((encoded, end_tag), dim=0)

    prompt_length = encoded.size(0)

    torch.manual_seed(1234)

    def ffn_only(mod, fqn):
        return isinstance(mod, torch.nn.Linear) and "feed_forward" in fqn

    def not_ffn_only(mod, fqn):
        return isinstance(mod, torch.nn.Linear) and not ffn_only(mod, fqn)

    def ffn_or_attn_only(mod, fqn):
        return isinstance(mod, torch.nn.Linear) and (
            "feed_forward" in fqn or "attention" in fqn
        )

    if quantization:
        from torchao.quantization import (
            autoquant,
            float8_dynamic_activation_float8_weight,
            float8_weight_only,
            fpx_weight_only,
            gemlite_uintx_weight_only,
            int4_weight_only,
            int8_dynamic_activation_int4_weight,
            int8_dynamic_activation_int8_weight,
            int8_weight_only,
            quantize_,
            uintx_weight_only,
        )
        from torchao.quantization.granularity import PerRow, PerTensor
        from torchao.utils import unwrap_tensor_subclass

        if "spinquant" in quantization:
            from torchao.prototype.spinquant import apply_spinquant

            apply_spinquant(model)
        if quantization.startswith("gemlite"):
            from gemlite.core import GemLiteLinearTriton

            # gemlite-[packing_bitwidth-]bit_width-group_size
            _quant_args = quantization.split("-")
            bit_width = int(_quant_args[-2])
            group_size = None if _quant_args[-1] == "None" else int(_quant_args[-1])
            try:
                packing_bitwidth = int(_quant_args[-3])
            except (IndexError, ValueError):
                # if only 2 inputs found, use default value
                packing_bitwidth = 32

            quantize_(
                model,
                gemlite_uintx_weight_only(group_size, bit_width, packing_bitwidth),
            )

            # try to load gemlite kernel config
            gemlite_config = _gemlite_config_path()
            try:
                GemLiteLinearTriton.load_config(gemlite_config)
                print(f"loaded gemlite kernel cache {gemlite_config}")
            except Exception:
                # best effort: a missing/corrupt cache only means re-autotuning
                print(f"unable to load gemlite kernel cache {gemlite_config}")

            print("running gemlite warmup")
            generate(
                model,
                encode_tokens(tokenizer, prompt, bos=True, device=device),
                max_new_tokens,
                batch_size,
                interactive=False,
                temperature=temperature,
                top_k=top_k,
            )
            GemLiteLinearTriton.cache_config(gemlite_config)
        if "int8wo" in quantization:
            quantize_(model, int8_weight_only())
        if "int8dq" in quantization:
            if sparsity and "semi" in sparsity:
                from torchao.dtypes import SemiSparseLayout

                quantize_(
                    model,
                    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()),
                    filter_fn=ffn_only,
                )
                quantize_(
                    model, int8_dynamic_activation_int8_weight(), filter_fn=not_ffn_only
                )
            elif "int8dq_prefill_wo_decode" in quantization:
                quantize_(
                    model, int8_dynamic_activation_int8_weight(weight_only_decode=True)
                )
            else:
                quantize_(model, int8_dynamic_activation_int8_weight())
        if "int4wo" in quantization:
            # int4wo-group_size[-hqq]
            use_hqq = "hqq" in quantization
            group_size = int(quantization.split("-")[1])
            assert group_size in [
                32,
                64,
                128,
                256,
            ], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
            # use_hqq was previously computed but never passed, so "-hqq" was ignored
            quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq))
        if "marlin" in quantization:
            if "qqq" in quantization:
                from torchao.dtypes import MarlinQQQLayout

                quantize_(
                    model,
                    int8_dynamic_activation_int4_weight(
                        group_size=128,
                        mapping_type=MappingType.SYMMETRIC,
                        act_mapping_type=MappingType.SYMMETRIC,
                        layout=MarlinQQQLayout(),
                    ),
                )
            elif sparsity and "semi" in sparsity:
                # guard against sparsity=None (previously a TypeError)
                from torchao.dtypes import MarlinSparseLayout

                quantize_(
                    model,
                    int4_weight_only(layout=MarlinSparseLayout()),
                    filter_fn=ffn_or_attn_only,
                )
        if "fp6" in quantization:
            quantize_(model, fpx_weight_only(3, 2))
        elif "embed-int8wo" in quantization:
            quantize_(
                model,
                int8_weight_only(group_size=64),
                filter_fn=lambda x, *args: isinstance(x, torch.nn.Embedding),
            )
        elif quantization.startswith("awq"):
            from torchao._models._eval import TransformerEvalWrapper
            from torchao.utils import TORCH_VERSION_AT_LEAST_2_3

            if not TORCH_VERSION_AT_LEAST_2_3:
                print("Awq requires torch2.3+")
                sys.exit()
            from torchao.prototype.awq import (
                AWQObservedLinear,
                awq_uintx,
                insert_awq_observer_,
            )

            # awq-<torch dtype name>-group_size[-hqq]
            quant_dtype = quantization.split("-")[1]
            group_size = int(quantization.split("-")[2])
            quant_dtype = getattr(torch, quant_dtype, torch.uint8)
            model = model.to(device)
            # get calibration data
            insert_awq_observer_(
                model, 1, 256, quant_dtype=quant_dtype, group_size=group_size
            )
            TransformerEvalWrapper(
                model=model.to(device),
                tokenizer=tokenizer,
                max_seq_length=256,
                input_prep_func=prepare_inputs_for_model,
                device=device,
            ).run_eval(
                tasks=["wikitext"],
                limit=1,
            )
            is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
            use_hqq = "hqq" in quantization
            quantize_(
                model,
                awq_uintx(
                    quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq
                ),
                is_observed_linear,
            )
        elif "uintx" in quantization:
            # uintx-nbits-group_size, e.g. "uintx-2-64"
            # uintx-nbits-group_size-hqq enables HQQ
            use_hqq = "hqq" in quantization
            _quant_args = quantization.split("-")
            nbits = int(_quant_args[1])
            assert nbits >= 1 and nbits <= 8, "nbits must be 1 to 8"
            _NBITS_TO_DTYPE = {
                1: torch.uint1,
                2: torch.uint2,
                3: torch.uint3,
                4: torch.uint4,
                5: torch.uint5,
                6: torch.uint6,
                7: torch.uint7,
                8: torch.uint8,
            }
            dtype = _NBITS_TO_DTYPE[nbits]
            group_size = int(_quant_args[2])
            quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq))
        elif "int8_dynamic_activation_intx_weight" in quantization:
            from torchao.experimental.quant_api import (
                int8_dynamic_activation_intx_weight,
            )

            assert (
                precision == torch.float32
            ), "int8_dynamic_activation_intx_weight requires fp32 precision"

            # Build kernels in temp location, and load them in torch
            # This requires an ARM CPU
            from torchao.experimental.temp_build import temp_build_and_load_torchao_ops

            temp_build_and_load_torchao_ops(
                cmake_lists_path=os.path.dirname(os.path.realpath(__file__))
                + "/../../experimental"
            )

            # Quantize model
            # int8_dynamic_activation_intx_weight-nbit-group_size-has_weight_zeros
            _quant_args = quantization.split("-")
            nbit = int(_quant_args[1])
            assert nbit >= 1 and nbit <= 8, "nbits must be 1 to 8"
            group_size = int(_quant_args[2])
            # bool("false") is True, so parse the flag text explicitly
            has_weight_zeros = _quant_args[3].strip().lower() not in (
                "false",
                "0",
                "no",
                "",
            )
            quantize_(
                model,
                int8_dynamic_activation_intx_weight(
                    group_size=group_size,
                    nbit=nbit,
                    has_weight_zeros=has_weight_zeros,
                ),
            )
        elif "float8wo" in quantization:
            quantize_(model, float8_weight_only())
        elif "float8dq" in quantization:
            # float8dq[-tensor|-row]; anything other than "row" means per-tensor
            if quantization.split("-")[-1] == "row":
                granularity = PerRow()
            else:
                granularity = PerTensor()
            quantize_(
                model, float8_dynamic_activation_float8_weight(granularity=granularity)
            )
        elif "autoquant_v2" in quantization:
            from torchao._models._eval import InputRecorder
            from torchao.prototype.quantization.autoquant_v2 import autoquant_v2

            calibration_seq_length = 256
            calibration_limit = 1
            inputs = (
                InputRecorder(
                    tokenizer,
                    calibration_seq_length,
                    prepare_inputs_for_model,
                    False,  # pad_calibration_inputs
                    model.config.vocab_size,
                    device="cuda",
                )
                .record_inputs(
                    ["wikitext"],
                    calibration_limit,
                )
                .get_inputs()[0]
                .values[0]
            )
            inputs = prepare_inputs_for_model(inputs)
            with torch.device("cuda"):
                model.setup_caches(
                    max_batch_size=1, max_seq_length=calibration_seq_length
                )

            if "autoquant_v2-int4" == quantization:
                model = autoquant_v2(
                    model,
                    manual=True,
                    qtensor_class_list=torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
                    example_input=inputs,
                    batch_size=calibration_seq_length,
                )
            elif "autoquant_v2-float8" == quantization:
                model = autoquant_v2(
                    model,
                    manual=True,
                    qtensor_class_list=torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST,
                    example_input=inputs,
                    batch_size=calibration_seq_length,
                )
            elif "autoquant_v2-fp" == quantization:
                model = autoquant_v2(
                    model,
                    manual=True,
                    qtensor_class_list=torchao.prototype.quantization.autoquant_v2.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
                    example_input=inputs,
                    batch_size=calibration_seq_length,
                )
            elif "autoquant_v2-all" == quantization:
                all_qtensor_classes = (
                    torchao.prototype.quantization.autoquant_v2.DEFAULT_AUTOQUANT_CLASS_LIST
                    + torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST
                    + torchao.prototype.quantization.autoquant_v2.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
                )
                if torchao.utils.is_sm_89():
                    # this is fp8 related subclasses, should rename
                    all_qtensor_classes += (
                        torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST
                    )
                model = autoquant_v2(
                    model,
                    manual=True,
                    qtensor_class_list=all_qtensor_classes,
                    example_input=inputs,
                    batch_size=calibration_seq_length,
                )
            else:
                model = autoquant_v2(
                    model,
                    manual=True,
                    example_input=inputs,
                    batch_size=calibration_seq_length,
                )

            print("running generate")
            generate(
                model,
                encode_tokens(tokenizer, prompt, bos=True, device=device),
                max_new_tokens,
                batch_size,
                interactive=False,
                temperature=temperature,
                top_k=top_k,
            )

            print("running finalize autoquant")
            # do autoquantization
            model.finalize_autoquant()
        elif "autoquant" in quantization:
            from torchao._models._eval import InputRecorder

            calibration_seq_length = 256
            calibration_limit = 1
            inputs = (
                InputRecorder(
                    tokenizer,
                    calibration_seq_length,
                    prepare_inputs_for_model,
                    False,  # pad_calibration_inputs
                    model.config.vocab_size,
                    device="cuda",
                )
                .record_inputs(
                    ["wikitext"],
                    calibration_limit,
                )
                .get_inputs()[0]
                .values[0]
            )
            inputs = prepare_inputs_for_model(inputs)
            with torch.device("cuda"):
                model.setup_caches(
                    max_batch_size=1, max_seq_length=calibration_seq_length
                )

            if "autoquant-int4" == quantization:
                model = autoquant(
                    model,
                    manual=True,
                    qtensor_class_list=torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
                    example_input=inputs,
                )
            elif "autoquant-float8" == quantization:
                model = autoquant(
                    model,
                    manual=True,
                    qtensor_class_list=torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST,
                    example_input=inputs,
                )
            elif "autoquant-fp" == quantization:
                model = autoquant(
                    model,
                    manual=True,
                    qtensor_class_list=torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
                    example_input=inputs,
                )
            elif "autoquant-sparse" == quantization:
                model = autoquant(
                    model,
                    manual=True,
                    qtensor_class_list=torchao.quantization.DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST,
                    example_input=inputs,
                )
            elif "autoquant-gemlite-int4" == quantization:
                from gemlite.core import GemLiteLinearTriton

                GemLiteLinearTriton.load_config(_gemlite_config_path())
                model = autoquant(
                    model,
                    manual=True,
                    qtensor_class_list=torchao.quantization.GEMLITE_INT4_AUTOQUANT_CLASS_LIST,
                    example_input=inputs,
                )
            elif "autoquant-all" == quantization:
                try:
                    from gemlite.core import GemLiteLinearTriton

                    GemLiteLinearTriton.load_config(_gemlite_config_path())
                except Exception:
                    # gemlite and its kernel cache are optional for autoquant-all
                    pass

                model = autoquant(
                    model,
                    manual=True,
                    qtensor_class_list=torchao.quantization.ALL_AUTOQUANT_CLASS_LIST,
                    example_input=inputs,
                )
            else:
                model = autoquant(model, manual=True, example_input=inputs)

            generate(
                model,
                encode_tokens(tokenizer, prompt, bos=True, device=device),
                max_new_tokens,
                batch_size,
                interactive=False,
                temperature=temperature,
                top_k=top_k,
            )

            # do autoquantization
            model.finalize_autoquant()
        elif "codebook" in quantization:
            from torchao.prototype.quantization.codebook import codebook_weight_only

            model.to(device)
            quantize_(
                model, codebook_weight_only(dtype=torch.uint4, scale_block_size=64)
            )
        else:
            if not TORCH_VERSION_AT_LEAST_2_5:
                unwrap_tensor_subclass(model)

    # standalone sparsity
    elif sparsity:
        from torchao.sparsity import semi_sparse_weight, sparsify_

        if "semi" in sparsity:
            # TODO there is a bug here, need to fix
            sparsify_(model.to(device), semi_sparse_weight(), filter_fn=ffn_only)

    model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9

    if save:
        output_dir = str(checkpoint_path.cwd())
        filename = str(checkpoint_path.name).split(".")[0]
        torch.save(
            model.state_dict(),
            os.path.join(output_dir, filename + f"-{quantization}.pt"),
        )

    if compile:
        print("Compiling Model")
        # Rebind the module-level functions so generate() picks up the compiled versions.
        global decode_one_token, prefill
        decode_one_token = torch.compile(
            decode_one_token, mode="reduce-overhead", fullgraph=True
        )

        if compile_prefill:
            prefill = torch.compile(prefill, fullgraph=True, dynamic=True)

    if memory_profile:
        if device == "cuda":
            torch.cuda.memory._record_memory_history(
                True, trace_alloc_max_entries=250000, trace_alloc_record_context=True
            )
        elif device == "xpu":
            torch.xpu.memory._record_memory_history(
                True, trace_alloc_max_entries=250000, trace_alloc_record_context=True
            )
        else:
            print("Memory profiling only works on CUDA or XPU devices")

    aggregate_metrics = {
        "tokens_per_sec": [],
        "time": [],
        "decode_tokens_per_sec": [],
        "prefill_time": [],
    }
    # With compile, iteration -1 is an untimed warm-up that triggers compilation.
    start = -1 if compile else 0

    for i in range(start, num_samples):
        if i == 0:
            if device == "cuda":
                torch.cuda.reset_peak_memory_stats()  # MKG
            elif device == "xpu":
                torch.xpu.reset_peak_memory_stats()  # MKG
        device_sync(device=device)  # MKG

        if i >= 0 and interactive:
            prompt = input("What is your prompt? ")
            if is_chat:
                prompt = f"{B_INST} {prompt.strip()} {E_INST}"
            encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
            # keep tokens/sec accounting in sync with the new prompt
            prompt_length = encoded.size(0)

        if interactive and i >= 0:
            buffer = []
            period_id = tokenizer.encode(".")[0]
            done_generating = False

            def callback(x):
                nonlocal done_generating
                if done_generating:
                    return
                # decode with a leading period then drop it, so tokenizers that
                # strip leading whitespace still render the token's spacing
                buffer.append(tokenizer.decode([period_id] + x.squeeze(0).tolist())[1:])
                if x.item() == tokenizer.eos_id():
                    done_generating = True
                if len(buffer) == 4 or done_generating:
                    print("".join(buffer), end="", flush=True)
                    buffer.clear()

        elif demo_summarize_prompt is not None and i >= 0:
            buffer = []
            period_id = tokenizer.encode(".")[0]

            def callback(x):
                buffer.append(tokenizer.decode([period_id] + x.squeeze(0).tolist())[1:])
                if len(buffer) == 4:
                    print("".join(buffer), end="", flush=True)
                    buffer.clear()

        else:
            callback = lambda x: x

        t0 = time.perf_counter()
        prefill_start_event, prefill_end_event = device_timer(device), device_timer(
            device
        )
        decode_start_event, decode_end_event = device_timer(device), device_timer(
            device
        )
        import contextlib

        # Only the last sample is profiled (when profiling is requested).
        if i != num_samples - 1 or not profile:
            prof = contextlib.nullcontext()
        else:
            torch.profiler._utils._init_for_cuda_graphs()
            prof = torch.profiler.profile()
        with prof:
            y = generate(
                model,
                encoded,
                max_new_tokens,
                batch_size,
                interactive=interactive,
                callback=callback,
                temperature=temperature,
                top_k=top_k,
                kv_cache_quantization=kv_cache_quantization,
                cache_size=cache_size,
                linear_causal_mask=linear_causal_mask,
                prefill_start_event=prefill_start_event,
                prefill_end_event=prefill_end_event,
                decode_start_event=decode_start_event,
                decode_end_event=decode_end_event,
            )
        if i < 0:
            print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
            continue
        if hasattr(prof, "export_chrome_trace"):
            prof.export_chrome_trace(f"{profile}.json")
        device_sync(device=device)  # MKG
        t = time.perf_counter() - t0

        if not interactive and demo_summarize_prompt is None:
            tok_list = y[0].tolist()
            # truncate text after end of string token
            tokens = (
                tok_list
                if tokenizer.eos_id() not in tok_list
                else tok_list[: tok_list.index(tokenizer.eos_id())]
            )
            print(tokenizer.decode(tokens))
        else:
            print("\n")
        tokens_generated = y.size(-1) - prompt_length
        tokens_sec = tokens_generated / t
        aggregate_metrics["tokens_per_sec"].append(tokens_sec)
        aggregate_metrics["time"].append(t)
        decode_time = decode_start_event.elapsed_time(decode_end_event) / 1000
        decode_tokens_sec = tokens_generated / decode_time
        aggregate_metrics["decode_tokens_per_sec"].append(decode_tokens_sec)
        prefill_time = prefill_start_event.elapsed_time(prefill_end_event) / 1000
        aggregate_metrics["prefill_time"].append(prefill_time)
        print(
            f"Sample {i+1} | overall time {t:.04f} s {tokens_sec:.02f} tokens/sec",
            f"| prefill time {prefill_time:.04f} s decode {decode_tokens_sec:.02f} tokens/sec",
        )
        print(f"Bandwidth achieved: {model_size * tokens_sec:.02f} GB/s")

        if memory_profile and i == 0:
            # Only CUDA/XPU can produce a snapshot; previously other devices
            # crashed here with a NameError on ``snapshot``.
            snapshot = None
            if device == "cuda":
                snapshot = torch.cuda.memory._snapshot()
            elif device == "xpu":
                snapshot = torch.xpu.memory._snapshot()
            else:
                print("Memory profiling only works on CUDA or XPU devices")

            if snapshot is not None:
                with open(f"{memory_profile}.pickle", "wb") as f:
                    from pickle import dump

                    dump(snapshot, f)
                print(
                    f"\nmemory profile {memory_profile}.pickle saved, to convert that to a usable file, use",
                    "python pytorch/torch/cuda/_memory_viz.py trace_plot <pickle file> -o <desired output name>.html",
                )
            break
    print("==========")

    # The compile warm-up iteration (i == -1) was skipped above, so every
    # recorded sample is averaged here.
    tokpersec = torch.mean(torch.tensor(aggregate_metrics["tokens_per_sec"])).item()
    ttft = torch.mean(torch.tensor(aggregate_metrics["prefill_time"])).item()
    decode_tokpersec = torch.mean(
        torch.tensor(aggregate_metrics["decode_tokens_per_sec"])
    ).item()
    bandwidth = model_size * tokpersec
    if device == "xpu":
        mem = torch.xpu.max_memory_reserved() / 1e9
    else:
        mem = torch.cuda.max_memory_reserved() / 1e9
    print(f"Average overall tokens/sec: {tokpersec:.2f}")
    print(f"Average decode tokens/sec: {decode_tokpersec:.04f}")
    print(f"Average TTFT: {ttft:.04f} s")
    print(f"Average tokens/sec: {tokpersec:.2f}")
    if batch_size > 1:
        print(f"Average tokens/sec including batches {batch_size*tokpersec:.2f}")
    print(f"Average Bandwidth: {bandwidth:.02f} GB/s")
    print(f"Peak Memory Usage: {mem:.02f} GB")
    print(f"Model Size: {model_size:.02f} GB")
    if write_result:
        result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, tok/s_decode={decode_tokpersec:6.2f}, ttft={ttft:5.4f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB "
        result_txt += f"quant: {quantization}, sparse: {sparsity}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} "
        # Reproduction command; every flag carries its own trailing space.
        result_txt += "repro: python generate.py "
        result_txt += f"--quantization {quantization} " if quantization else ""
        result_txt += f"--sparsity {sparsity} " if sparsity else ""
        result_txt += f"--checkpoint_path {checkpoint_path} "
        result_txt += f"--device {device} "
        result_txt += f"--precision {precision} "
        result_txt += "--compile " if compile else ""
        result_txt += "--compile_prefill " if compile_prefill else ""
        result_txt += f"--prefill_size {prefill_size} " if prefill_size else ""
        result_txt += f"--profile {profile} " if profile else ""
        result_txt += f"--memory_profile {memory_profile} " if memory_profile else ""
        result_txt += "--interactive " if interactive else ""
        result_txt += f"--num_samples {num_samples} "
        result_txt += f"--max_new_tokens {max_new_tokens} "
        result_txt += f"--batch_size {batch_size} "
        result_txt += f"--top_k {top_k} "
        result_txt += f"--temperature {temperature} "
        result_txt += f"--cache_size {cache_size} " if cache_size else ""
        result_txt += "--kv_cache_quantization " if kv_cache_quantization else ""
        result_txt += "--linear_causal_mask " if linear_causal_mask else ""
        with open(write_result, "a") as result_file:
            result_file.write(result_txt)

    if output_json_path:
        headers = ["name", "dtype", "device", "arch", "metric", "actual", "target"]
        name = checkpoint_path.parent.name
        arch = get_arch_name()
        dtype = quantization or "noquant"
        memory_result = [name, dtype, device, arch, "mem/s", bandwidth, None]
        performance_result = [name, dtype, device, arch, "tok/s", tokpersec, None]
        write_json_result = (
            write_json_result_local if output_json_local else write_json_result_ossci
        )
        write_json_result(output_json_path, headers, memory_result)
        write_json_result(output_json_path, headers, performance_result)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Your CLI description.")
parser.add_argument(
"--prefill_size", type=int, default=None, help="Whether to run in ttft mode"
)
parser.add_argument(
"--prompt", type=str, default="Hello, my name is", help="Input prompt."
)
parser.add_argument(