forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
1432 lines (1256 loc) · 58.6 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections import OrderedDict
from typing import Optional
import tensorrt as trt
from tensorrt_llm._common import default_net
from tensorrt_llm._utils import str_dtype_to_trt
from tensorrt_llm.functional import (LayerNormPositionType, LayerNormType,
MLPType, PositionEmbeddingType, Tensor,
assertion, gather_last_token_logits, gelu,
recv, send, shape, transpose)
from tensorrt_llm.layers import (MLP, Attention, AttentionMaskType,
AttentionParams, BertAttention, ColumnLinear,
Conv1d, Embedding, FusedGatedMLP, GatedMLP,
GroupNorm, KeyValueCacheParams, LayerNorm,
PromptTuningEmbedding, RmsNorm)
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.generation_mixin import GenerationMixin
from tensorrt_llm.module import Module, ModuleList
from tensorrt_llm.parameter import Parameter
from tensorrt_llm.plugin.plugin import current_all_reduce_helper
# Normalization layer class to build for each LayerNormType
# (e.g. BART uses regular LayerNorm, T5 uses RmsNorm).
layernorm_map = dict((
    (LayerNormType.LayerNorm, LayerNorm),
    (LayerNormType.RmsNorm, RmsNorm),
    (LayerNormType.GroupNorm, GroupNorm),
))

# Feed-forward layer class to build for each MLPType
# (e.g. T5/BART use MLP, Flan-T5 uses GatedMLP).
mlp_map = dict((
    (MLPType.MLP, MLP),
    (MLPType.GatedMLP, GatedMLP),
    (MLPType.FusedGatedMLP, FusedGatedMLP),
))
class EncDecEmbedding(Module):
    """Input embedding shared by the encoder and decoder models.

    Looks up word embeddings (optionally via prompt tuning), multiplies them by
    ``embedding_scale``, adds the optional position and token-type embeddings,
    and finally applies the optional embedding normalization layer.

    Args:
        vocab_size: number of rows of the word embedding table.
        hidden_size: embedding width.
        max_position_embeddings: number of rows of the position table; only
            used when ``has_position_embedding`` is True.
        has_position_embedding: build a learned position embedding table.
        type_vocab_size: number of rows of the token-type table; a falsy value
            disables the token-type embedding.
        has_embedding_layernorm: normalize the summed embedding (e.g. BART).
        has_embedding_scale: scale word embeddings by ``sqrt(hidden_size)``.
        layernorm_eps: epsilon forwarded to the normalization layer.
        layernorm_type: key into ``layernorm_map`` selecting the norm class.
        dtype: parameter dtype forwarded to the sub-layers.
        use_prompt_tuning: use ``PromptTuningEmbedding`` for the word table.
        use_parallel_embedding: shard the embedding tables over the TP group.
        embedding_sharding_dim: ``sharding_dim`` forwarded to every table.
        mapping: parallel mapping supplying TP size, group and rank.
    """

    def __init__(self,
                 vocab_size,
                 hidden_size,
                 max_position_embeddings=None,
                 has_position_embedding=False,
                 type_vocab_size=None,
                 has_embedding_layernorm=False,
                 has_embedding_scale=False,
                 layernorm_eps=1e-5,
                 layernorm_type=LayerNormType.LayerNorm,
                 dtype=None,
                 use_prompt_tuning=False,
                 use_parallel_embedding=False,
                 embedding_sharding_dim=0,
                 mapping=Mapping()):
        super().__init__()
        self.layernorm_type = layernorm_type
        ln_type = layernorm_map[layernorm_type]
        self.use_prompt_tuning = use_prompt_tuning

        # Sharding arguments shared by every embedding table below; tables are
        # only split across the TP group when use_parallel_embedding is set.
        shard_kwargs = dict(
            tp_size=mapping.tp_size if use_parallel_embedding else 1,
            tp_group=mapping.tp_group if use_parallel_embedding else None,
            sharding_dim=embedding_sharding_dim,
            tp_rank=mapping.tp_rank)

        EmbeddingCls = PromptTuningEmbedding if use_prompt_tuning else Embedding
        self.vocab_embedding = EmbeddingCls(vocab_size,
                                            hidden_size,
                                            dtype=dtype,
                                            **shard_kwargs)

        self.position_embedding = None
        self.max_position_embeddings = max_position_embeddings
        if has_position_embedding:
            self.position_embedding = Embedding(max_position_embeddings,
                                                hidden_size,
                                                dtype=dtype,
                                                **shard_kwargs)

        self.token_type_embedding = None
        if type_vocab_size:
            self.token_type_embedding = Embedding(type_vocab_size,
                                                  hidden_size,
                                                  dtype=dtype,
                                                  **shard_kwargs)

        # e.g. BART true, T5 false
        self.embedding_layernorm = None
        if has_embedding_layernorm:
            self.embedding_layernorm = ln_type(normalized_shape=hidden_size,
                                               eps=layernorm_eps,
                                               dtype=dtype)

        # e.g. BART true, T5 false
        self.embedding_scale = 1.0
        if has_embedding_scale:
            self.embedding_scale = math.sqrt(hidden_size)

        # Note: embedding offset in BART is not considered as a standard. For the specific case,
        # we just need to shrink its position embedding table by [offset:] during weight loading

    def forward(self,
                input_ids,
                position_ids=None,
                token_type_ids=None,
                prompt_embedding_table=None,
                prompt_tasks=None,
                prompt_vocab_size=None):
        """Build the embedding sub-graph.

        The prompt-tuning arguments are forwarded to the word embedding only
        when ``use_prompt_tuning`` is set.

        Raises:
            ValueError: a position (token-type) embedding table exists but
                ``position_ids`` (``token_type_ids``) was not supplied.
        """
        # position_ids and token_type_ids are provided as inputs
        # and are not derived deterministically from input_ids.
        ptuning_args = []
        if self.use_prompt_tuning:
            ptuning_args = [
                prompt_embedding_table, prompt_tasks, prompt_vocab_size
            ]
        x = self.vocab_embedding(input_ids, *ptuning_args) * self.embedding_scale
        self.register_network_output('word_embeddings', x)

        # Optional tables are None when disabled; test that explicitly rather
        # than relying on the truthiness of a Module instance.
        if self.position_embedding is not None:
            if position_ids is None:
                raise ValueError(
                    'position_ids is required when the position embedding '
                    'is enabled')
            pos_emb = self.position_embedding(position_ids)
            self.register_network_output('position_embeddings', pos_emb)
            x = x + pos_emb
        if self.token_type_embedding is not None:
            if token_type_ids is None:
                raise ValueError(
                    'token_type_ids is required when the token type '
                    'embedding is enabled')
            x = x + self.token_type_embedding(token_type_ids)
        if self.embedding_layernorm is not None:
            x = self.embedding_layernorm(x)
        return x
class EncoderLayer(Module):
    """One transformer encoder block: self-attention followed by an MLP.

    Both sub-layers use a residual connection whose skip branch is multiplied
    by ``residual_scaling``. The matching normalization layer runs either on
    the sub-layer input (pre_layernorm, e.g. T5) or on the residual sum
    (post_layernorm, e.g. BART).
    """

    def __init__(self,
                 hidden_size,
                 ffn_hidden_size,
                 num_attention_heads,
                 num_kv_heads,
                 head_size,
                 max_position_embeddings=None,
                 q_scaling=1.0,
                 has_attention_qkvo_bias=False,
                 has_mlp_bias=False,
                 layernorm_position=LayerNormPositionType.pre_layernorm,
                 layernorm_type=LayerNormType.LayerNorm,
                 layernorm_eps=1e-5,
                 hidden_act="relu",
                 mlp_type=MLPType.MLP,
                 mapping=Mapping(),
                 dtype=None,
                 residual_scaling=1.0,
                 relative_attention=False,
                 max_distance=0,
                 num_buckets=0):
        super().__init__()

        # Norm flavour: e.g. regular LayerNorm for BART, RmsNorm for T5.
        self.layernorm_type = layernorm_type
        norm_cls = layernorm_map[layernorm_type]
        # Norm placement: e.g. post for BART, pre for T5.
        self.layernorm_position = layernorm_position

        # q_scaling is e.g. 1.0 for BART and 1/sqrt(head_size) for T5.
        self.attention = BertAttention(
            hidden_size,
            num_attention_heads,
            attention_head_size=head_size,
            num_kv_heads=num_kv_heads,
            max_position_embeddings=max_position_embeddings,
            q_scaling=q_scaling,
            bias=has_attention_qkvo_bias,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            tp_rank=mapping.tp_rank,
            dtype=dtype,
            relative_attention=relative_attention,
            max_distance=max_distance,
            num_buckets=num_buckets,
        )
        self.attention_layernorm = norm_cls(normalized_shape=hidden_size,
                                            eps=layernorm_eps,
                                            dtype=dtype)

        # Plain MLP for T5/BART, GatedMLP for Flan-T5.
        self.mlp_type = mlp_type
        ffn_cls = mlp_map[mlp_type]
        self.mlp = ffn_cls(
            hidden_size=hidden_size,
            ffn_hidden_size=ffn_hidden_size,
            hidden_act=hidden_act,
            bias=has_mlp_bias,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            dtype=dtype,
        )
        self.mlp_layernorm = norm_cls(normalized_shape=hidden_size,
                                      eps=layernorm_eps,
                                      dtype=dtype)

        self.residual_scaling = residual_scaling

    def forward(self,
                hidden_states: Tensor,
                attention_mask=None,
                input_lengths=None,
                max_input_length=None):
        """Apply the attention and MLP sub-layers and return the new hidden states."""
        assert isinstance(hidden_states, Tensor)
        norm_first = self.layernorm_position == LayerNormPositionType.pre_layernorm
        norm_last = self.layernorm_position == LayerNormPositionType.post_layernorm

        # Self-attention sub-layer.
        skip = hidden_states * self.residual_scaling
        x = self.attention_layernorm(hidden_states) if norm_first else hidden_states
        attn_out = self.attention(x,
                                  attention_mask=attention_mask,
                                  input_lengths=input_lengths,
                                  max_input_length=max_input_length)
        self.register_network_output('attention_output', attn_out)
        x = skip + attn_out
        if norm_last:
            x = self.attention_layernorm(x)

        # Feed-forward sub-layer.
        skip = x * self.residual_scaling
        if norm_first:
            x = self.mlp_layernorm(x)
        ffn_out = self.mlp(x)
        self.register_network_output('mlp_output', ffn_out)
        x = skip + ffn_out
        if norm_last:
            x = self.mlp_layernorm(x)
        return x
class DecoderLayer(Module):
    """One transformer decoder block: causal self-attention, cross-attention
    over the encoder output, then an MLP.

    Each sub-layer uses a residual connection whose skip branch is multiplied
    by ``residual_scaling``. The matching normalization layer runs either on
    the sub-layer input (pre_layernorm, e.g. T5) or on the residual sum
    (post_layernorm, e.g. BART).
    """

    def __init__(self,
                 hidden_size,
                 ffn_hidden_size,
                 num_attention_heads,
                 num_kv_heads,
                 head_size,
                 max_position_embeddings=None,
                 q_scaling=1.0,
                 has_attention_qkvo_bias=False,
                 has_mlp_bias=False,
                 layernorm_position=LayerNormPositionType.pre_layernorm,
                 layernorm_type=LayerNormType.LayerNorm,
                 layernorm_eps=1e-5,
                 hidden_act="relu",
                 mlp_type=MLPType.MLP,
                 mapping=Mapping(),
                 dtype=None,
                 residual_scaling=1.0,
                 relative_attention=False,
                 max_distance=0,
                 num_buckets=0):
        super().__init__()

        # e.g. BART regular, T5 RMS
        self.layernorm_type = layernorm_type
        ln_type = layernorm_map[layernorm_type]

        # e.g. BART post, T5 pre
        self.layernorm_position = layernorm_position

        # e.g. BART q_scaling = 1.f, T5 q_scaling = 1.f/sqrt(head_size)
        self.self_attention = Attention(
            hidden_size,
            num_attention_heads,
            attention_head_size=head_size,
            num_kv_heads=num_kv_heads,
            max_position_embeddings=max_position_embeddings,
            q_scaling=q_scaling,
            bias=has_attention_qkvo_bias,
            attention_mask_type=AttentionMaskType.causal,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            tp_rank=mapping.tp_rank,
            dtype=dtype,
            cross_attention=False,
            relative_attention=relative_attention,
            max_distance=max_distance,
            num_buckets=num_buckets,
            position_embedding_type=PositionEmbeddingType.relative
            if relative_attention else PositionEmbeddingType.learned_absolute)

        self.self_attention_layernorm = ln_type(normalized_shape=hidden_size,
                                                eps=layernorm_eps,
                                                dtype=dtype)

        # Note: self attn uses MMHA, mask is always causal triangular
        # cross attn has two scenarios:
        # - in context phase, all ones mask, same as padding type
        # - in generation phase, same causal triangular mask as MMHA
        # - context phase special handling is done in plugin by resetting mask type
        #
        # e.g. BART q_scaling = 1.f, T5 q_scaling = 1.f/sqrt(head_size)
        self.cross_attention = Attention(
            hidden_size,
            num_attention_heads,
            attention_head_size=head_size,
            num_kv_heads=num_kv_heads,
            max_position_embeddings=max_position_embeddings,
            q_scaling=q_scaling,
            bias=has_attention_qkvo_bias,
            attention_mask_type=AttentionMaskType.causal,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            tp_rank=mapping.tp_rank,
            dtype=dtype,
            cross_attention=True,
            # Cross attention has no relative attention bias
            relative_attention=False,
            max_distance=max_distance,
            num_buckets=num_buckets,
            position_embedding_type=PositionEmbeddingType.learned_absolute)

        self.cross_attention_layernorm = ln_type(normalized_shape=hidden_size,
                                                 eps=layernorm_eps,
                                                 dtype=dtype)

        # T5/BART MLP, Flan-T5 GatedMLP
        self.mlp_type = mlp_type
        mlp_f = mlp_map[mlp_type]
        self.mlp = mlp_f(
            hidden_size=hidden_size,
            ffn_hidden_size=ffn_hidden_size,
            hidden_act=hidden_act,
            bias=has_mlp_bias,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            dtype=dtype,
        )
        self.mlp_layernorm = ln_type(normalized_shape=hidden_size,
                                     eps=layernorm_eps,
                                     dtype=dtype)

        self.residual_scaling = residual_scaling

    def forward(self,
                hidden_states: Tensor,
                encoder_output: Optional[Tensor] = None,
                attention_mask=None,
                cross_attention_mask=None,
                use_cache=False,
                kv_cache_params=None,
                attention_params=None):
        """Apply self-attention, cross-attention and MLP sub-layers.

        Returns:
            ``hidden_states`` when ``use_cache`` is False; otherwise the tuple
            ``(hidden_states, presents_self, presents_cross)`` where the last
            two are the extra outputs returned by the self- and
            cross-attention layers.
        """
        assert isinstance(hidden_states, Tensor)

        # Explicit None test: the check must not depend on how a graph Tensor
        # behaves in a boolean context.
        if encoder_output is not None:
            assert isinstance(encoder_output, Tensor)

        # self-attention
        residual = hidden_states * self.residual_scaling

        if self.layernorm_position == LayerNormPositionType.pre_layernorm:
            hidden_states = self.self_attention_layernorm(hidden_states)

        attention_output = self.self_attention(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            use_cache=use_cache,
            kv_cache_params=kv_cache_params,
            attention_params=attention_params)

        if use_cache:
            attention_output, presents_self = attention_output

        self.register_network_output('self_attention_output', attention_output)

        hidden_states = residual + attention_output

        if self.layernorm_position == LayerNormPositionType.post_layernorm:
            hidden_states = self.self_attention_layernorm(hidden_states)

        # cross attention
        residual = hidden_states * self.residual_scaling

        if self.layernorm_position == LayerNormPositionType.pre_layernorm:
            hidden_states = self.cross_attention_layernorm(hidden_states)

        attention_output = self.cross_attention(
            hidden_states=hidden_states,
            attention_mask=cross_attention_mask,
            encoder_output=encoder_output,
            use_cache=use_cache,
            kv_cache_params=kv_cache_params,
            attention_params=attention_params)

        if use_cache:
            attention_output, presents_cross = attention_output

        self.register_network_output('cross_attention_output', attention_output)

        hidden_states = residual + attention_output

        if self.layernorm_position == LayerNormPositionType.post_layernorm:
            hidden_states = self.cross_attention_layernorm(hidden_states)

        # MLP
        residual = hidden_states * self.residual_scaling

        if self.layernorm_position == LayerNormPositionType.pre_layernorm:
            hidden_states = self.mlp_layernorm(hidden_states)

        hidden_states = self.mlp(hidden_states)

        self.register_network_output('mlp_output', hidden_states)

        hidden_states = residual + hidden_states

        if self.layernorm_position == LayerNormPositionType.post_layernorm:
            hidden_states = self.mlp_layernorm(hidden_states)

        if use_cache:
            return (hidden_states, presents_self, presents_cross)
        return hidden_states
class EncoderModel(Module, GenerationMixin):
    """Encoder stack of an encoder-decoder model, split across pipeline ranks.

    The first pipeline rank owns the input embedding. Every rank owns its
    share of ``EncoderLayer``s. The last rank owns the optional final norm
    and marks the ``encoder_output`` network output; other ranks send their
    hidden states to the next rank instead.
    """

    def __init__(self,
                 num_layers,
                 num_heads,
                 hidden_size,
                 ffn_hidden_size,
                 vocab_size,
                 dtype,
                 head_size=None,
                 num_kv_heads=None,
                 max_position_embeddings=None,
                 has_position_embedding=False,
                 relative_attention=False,
                 max_distance=None,
                 num_buckets=None,
                 type_vocab_size=None,
                 has_embedding_layernorm=False,
                 has_embedding_scale=False,
                 q_scaling=1.0,
                 has_attention_qkvo_bias=False,
                 has_mlp_bias=False,
                 has_model_final_layernorm=False,
                 layernorm_eps=1e-5,
                 layernorm_position=LayerNormPositionType.pre_layernorm,
                 layernorm_type=LayerNormType.LayerNorm,
                 hidden_act="relu",
                 mlp_type=MLPType.MLP,
                 residual_scaling=1.0,
                 use_prompt_tuning=False,
                 use_parallel_embedding=False,
                 embedding_sharding_dim=0,
                 mapping=Mapping()):
        super().__init__()
        self.mapping = mapping

        self.has_position_embedding = has_position_embedding
        self.has_token_type_embedding = type_vocab_size is not None

        # e.g. BART regular, T5 RMS
        self.layernorm_type = layernorm_type
        ln_type = layernorm_map[layernorm_type]

        # e.g. BART true, T5 false
        self.has_attention_qkvo_bias = has_attention_qkvo_bias
        self.has_mlp_bias = has_mlp_bias

        # e.g. BART false, T5 true
        self.has_model_final_layernorm = has_model_final_layernorm

        if isinstance(dtype, str):
            self._dtype = str_dtype_to_trt(dtype)
        else:
            assert isinstance(dtype, trt.DataType)
            self._dtype = dtype

        self.total_num_layers = num_layers
        self.num_layers = num_layers // self.mapping.pp_size

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        # A missing or non-positive KV head count means plain multi-head attention.
        if num_kv_heads is None or num_kv_heads <= 0:
            num_kv_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_size = self.hidden_size // self.num_heads if head_size is None else head_size

        if self.mapping.is_first_pp_rank():
            self.embedding = EncDecEmbedding(
                vocab_size,
                hidden_size,
                max_position_embeddings=max_position_embeddings,
                has_position_embedding=has_position_embedding,
                type_vocab_size=type_vocab_size,
                has_embedding_layernorm=has_embedding_layernorm,
                has_embedding_scale=has_embedding_scale,
                layernorm_eps=layernorm_eps,
                layernorm_type=layernorm_type,
                dtype=dtype,
                use_prompt_tuning=use_prompt_tuning,
                use_parallel_embedding=use_parallel_embedding,
                embedding_sharding_dim=embedding_sharding_dim,
                mapping=self.mapping)

        # One layer per index returned by mapping.pp_layers for this rank.
        self.encoder_layers = ModuleList([
            EncoderLayer(hidden_size=hidden_size,
                         ffn_hidden_size=ffn_hidden_size,
                         num_attention_heads=num_heads,
                         num_kv_heads=num_kv_heads,
                         head_size=self.head_size,
                         max_position_embeddings=max_position_embeddings,
                         q_scaling=q_scaling,
                         has_attention_qkvo_bias=has_attention_qkvo_bias,
                         has_mlp_bias=has_mlp_bias,
                         layernorm_position=layernorm_position,
                         layernorm_eps=layernorm_eps,
                         layernorm_type=layernorm_type,
                         hidden_act=hidden_act,
                         mlp_type=mlp_type,
                         mapping=self.mapping,
                         dtype=dtype,
                         residual_scaling=residual_scaling,
                         relative_attention=relative_attention,
                         max_distance=max_distance,
                         num_buckets=num_buckets)
            for _ in self.mapping.pp_layers(self.total_num_layers)
        ])

        if self.mapping.is_last_pp_rank():
            if self.has_model_final_layernorm:
                self.final_layernorm = ln_type(normalized_shape=hidden_size,
                                               eps=layernorm_eps,
                                               dtype=dtype)

    def forward(self,
                input_ids: Tensor,
                input_lengths=None,
                position_ids=None,
                token_type_ids=None,
                hidden_states=None,
                max_input_length=None,
                prompt_embedding_table=None,
                prompt_tasks=None,
                prompt_vocab_size=None,
                attention_mask=None):
        """Build the encoder graph for this pipeline rank and return its hidden states."""
        # In PP, layer 0 has ids as inputs, all other layers have hidden_states as inputs
        if self.mapping.is_first_pp_rank():
            hidden_states = self.embedding(input_ids, position_ids,
                                           token_type_ids,
                                           prompt_embedding_table, prompt_tasks,
                                           prompt_vocab_size)
            self.register_network_output('embedding_layer_output',
                                         hidden_states)
        else:
            hidden_states = recv(hidden_states, self.mapping.prev_pp_rank())

        for encoder_layer in self.encoder_layers:
            hidden_states = encoder_layer(hidden_states=hidden_states,
                                          attention_mask=attention_mask,
                                          input_lengths=input_lengths,
                                          max_input_length=max_input_length)

        if self.mapping.is_last_pp_rank():
            if self.has_model_final_layernorm:
                hidden_states = self.final_layernorm(hidden_states)
            hidden_states.mark_output('encoder_output', self._dtype)
        else:
            hidden_states = send(hidden_states, self.mapping.next_pp_rank())
            hidden_states.mark_output('hidden_states_output', self._dtype)

        return hidden_states

    def prepare_inputs(self,
                       max_batch_size,
                       max_input_len,
                       prompt_embedding_table_size: int = 0):
        """Create the input Tensors for ``forward``.

        The given sizes determine the [min, opt, max] ranges of the TensorRT
        dynamic-shape dimensions.

        Returns:
            A tuple ``(input_ids, input_lengths, position_ids, token_type_ids,
            hidden_states, max_input_length, prompt_embedding_table, tasks,
            prompt_vocab_size, attention_mask)``; entries that do not apply
            to this configuration / pipeline rank are None.
        """
        hidden_size = self.hidden_size

        bs_range = [1, (max_batch_size + 1) // 2, max_batch_size]
        inlen_range = [1, (max_input_len + 1) // 2, max_input_len]
        num_tokens_range = [
            1,
            (max_input_len * max_batch_size + 1) // 2,
            max_input_len * max_batch_size,
        ]

        input_ids, position_ids, token_type_ids, hidden_states = None, None, None, None
        remove_input_padding = default_net().plugin_config.remove_input_padding
        use_custom_all_reduce = default_net(
        ).plugin_config.use_custom_all_reduce

        attention_mask = None
        if remove_input_padding:
            if self.mapping.is_first_pp_rank():
                input_ids = Tensor(
                    name="input_ids",
                    dtype=trt.int32,
                    shape=[-1],
                    dim_range=OrderedDict([("num_tokens", [num_tokens_range])]),
                )
                if self.has_position_embedding:
                    position_ids = Tensor(
                        name='position_ids',
                        dtype=trt.int32,
                        shape=[-1],
                        dim_range=OrderedDict([('num_tokens',
                                                [num_tokens_range])]),
                    )
                if self.has_token_type_embedding:
                    token_type_ids = Tensor(
                        name='token_type_ids',
                        dtype=trt.int32,
                        shape=[-1],
                        dim_range=OrderedDict([('num_tokens',
                                                [num_tokens_range])]),
                    )
            else:
                hidden_states = Tensor(name='hidden_states_input',
                                       dtype=self._dtype,
                                       shape=[-1, hidden_size],
                                       dim_range=OrderedDict([
                                           ('num_tokens', [num_tokens_range]),
                                           ('hidden_size', [hidden_size]),
                                       ]))
        else:
            if self.mapping.is_first_pp_rank():
                input_ids = Tensor(
                    name="input_ids",
                    dtype=trt.int32,
                    shape=[-1, -1],
                    dim_range=OrderedDict([("batch_size", [bs_range]),
                                           ("input_len", [inlen_range])]),
                )
                if self.has_position_embedding:
                    position_ids = Tensor(
                        name='position_ids',
                        dtype=trt.int32,
                        shape=[-1, -1],
                        dim_range=OrderedDict([('batch_size', [bs_range]),
                                               ('input_len', [inlen_range])]),
                    )
                if self.has_token_type_embedding:
                    token_type_ids = Tensor(
                        name='token_type_ids',
                        dtype=trt.int32,
                        shape=[-1, -1],
                        dim_range=OrderedDict([('batch_size', [bs_range]),
                                               ('input_len', [inlen_range])]),
                    )
            else:
                hidden_states = Tensor(name='hidden_states_input',
                                       dtype=self._dtype,
                                       shape=[-1, -1, hidden_size],
                                       dim_range=OrderedDict([
                                           ('batch_size', [bs_range]),
                                           ('input_len', [inlen_range]),
                                           ('hidden_size', [hidden_size]),
                                       ]))

            if not default_net().plugin_config.bert_attention_plugin:
                attention_mask = Tensor(
                    name='attention_mask',
                    dtype=trt.int32,
                    shape=[-1, -1],
                    dim_range=OrderedDict([
                        ('batch_size', [bs_range]),
                        ('input_len', [inlen_range]),
                    ]),
                )

        if use_custom_all_reduce and self.mapping.tp_size > 1:
            current_all_reduce_helper().set_workspace_tensor(
                self.mapping, False)

        input_lengths = Tensor(
            name="input_lengths",
            dtype=trt.int32,
            shape=[-1],
            dim_range=OrderedDict([("batch_size", [bs_range])]),
        )
        # Only the (dynamic) shape of this tensor conveys the max input length.
        max_input_length = Tensor(
            name="max_input_length",
            dtype=trt.int32,
            shape=[-1],
            dim_range=OrderedDict([("max_input_length", [inlen_range])]),
        )

        prompt_embedding_table = None
        tasks = None
        prompt_vocab_size = None

        if self.mapping.is_first_pp_rank() and prompt_embedding_table_size > 0:
            # Clamp opt to >= min (1); a table size of 1 would otherwise give
            # opt = 0 and an invalid optimization profile.
            p_embedding_range = [[
                1,
                max(1, prompt_embedding_table_size // 2),
                prompt_embedding_table_size,
            ]]

            prompt_embedding_table = Tensor(name='prompt_embedding_table',
                                            dtype=self._dtype,
                                            shape=[-1, hidden_size],
                                            dim_range=OrderedDict([
                                                ('prompt_embedding_table_size',
                                                 p_embedding_range),
                                                ('hidden_size', [hidden_size]),
                                            ]))
            if remove_input_padding:
                tasks = Tensor(name='tasks',
                               dtype=trt.int32,
                               shape=[-1],
                               dim_range=OrderedDict([('input_len_task',
                                                       [num_tokens_range])]))
            else:
                tasks = Tensor(name='tasks',
                               dtype=trt.int32,
                               shape=[-1, 1],
                               dim_range=OrderedDict([
                                   # One [min, opt, max] range per optimization
                                   # profile, like every other batch dimension.
                                   ('batch_size_beam_width', [bs_range]),
                                   ('broadcast_dim', [1]),
                               ]))
            prompt_vocab_size = Tensor(name='prompt_vocab_size',
                                       dtype=trt.int32,
                                       shape=[1],
                                       dim_range=OrderedDict([('size', [1])]))

        return (input_ids, input_lengths, position_ids, token_type_ids,
                hidden_states, max_input_length, prompt_embedding_table, tasks,
                prompt_vocab_size, attention_mask)
class DecoderModel(Module, GenerationMixin):
def __init__(self,
             num_layers,
             num_heads,
             hidden_size,
             ffn_hidden_size,
             encoder_num_heads,
             encoder_hidden_size,
             vocab_size,
             dtype,
             logits_dtype='float32',
             head_size=None,
             encoder_head_size=None,
             num_kv_heads=None,
             encoder_num_kv_heads=None,
             max_position_embeddings=None,
             has_position_embedding=False,
             relative_attention=False,
             max_distance=None,
             num_buckets=None,
             type_vocab_size=None,
             has_embedding_layernorm=False,
             has_embedding_scale=False,
             q_scaling=1.0,
             has_attention_qkvo_bias=False,
             has_mlp_bias=False,
             has_model_final_layernorm=False,
             layernorm_eps=1e-5,
             layernorm_position=LayerNormPositionType.pre_layernorm,
             layernorm_type=LayerNormType.LayerNorm,
             hidden_act="relu",
             mlp_type=MLPType.MLP,
             rescale_before_lm_head=False,
             has_lm_head_bias=False,
             residual_scaling=1.0,
             use_parallel_embedding=False,
             embedding_sharding_dim=0,
             mapping=Mapping()):
    """Build the decoder stack for this pipeline rank.

    The first pipeline rank owns the input embedding. Every rank owns its
    share of ``DecoderLayer``s. The last rank owns the optional final norm
    and the ``lm_head`` projection to ``vocab_size`` logits.

    ``dtype`` and ``logits_dtype`` may be given as strings (converted with
    ``str_dtype_to_trt``) or as ``trt.DataType``. A missing or non-positive
    ``num_kv_heads`` / ``encoder_num_kv_heads`` falls back to the
    corresponding head count.
    """
    super().__init__()
    self.mapping = mapping

    self.has_position_embedding = has_position_embedding
    self.has_token_type_embedding = type_vocab_size is not None
    self.rescale_before_lm_head = rescale_before_lm_head

    # e.g. BART regular, T5 RMS
    self.layernorm_type = layernorm_type
    ln_type = layernorm_map[layernorm_type]

    # e.g. BART true, T5 false
    self.has_attention_qkvo_bias = has_attention_qkvo_bias
    self.has_mlp_bias = has_mlp_bias

    # e.g. BART false, T5 true
    self.has_model_final_layernorm = has_model_final_layernorm

    if isinstance(dtype, str):
        self._dtype = str_dtype_to_trt(dtype)
    else:
        assert isinstance(dtype, trt.DataType)
        self._dtype = dtype
    # no quantization considered for now
    self._kv_dtype = self._dtype

    if isinstance(logits_dtype, str):
        self._logits_dtype = str_dtype_to_trt(logits_dtype)
    else:
        assert isinstance(logits_dtype, trt.DataType)
        self._logits_dtype = logits_dtype

    self.total_num_layers = num_layers
    self.num_layers = num_layers // self.mapping.pp_size

    self.hidden_size = hidden_size
    self.num_heads = num_heads
    if num_kv_heads is None or num_kv_heads <= 0:
        num_kv_heads = num_heads
    self.num_kv_heads = num_kv_heads
    self.head_size = self.hidden_size // self.num_heads if head_size is None else head_size

    self.encoder_hidden_size = encoder_hidden_size
    self.encoder_num_heads = encoder_num_heads
    if encoder_num_kv_heads is None or encoder_num_kv_heads <= 0:
        encoder_num_kv_heads = encoder_num_heads
    self.encoder_num_kv_heads = encoder_num_kv_heads
    # NOTE(review): the fallback divides by the *decoder* head count
    # (self.num_heads), unlike head_size above, which uses its own side's head
    # count. Kept unchanged because the consumer of encoder_head_size is not
    # visible here; confirm whether self.encoder_num_heads was intended.
    self.encoder_head_size = self.encoder_hidden_size // self.num_heads if encoder_head_size is None else encoder_head_size

    if self.mapping.is_first_pp_rank():
        self.embedding = EncDecEmbedding(
            vocab_size,
            hidden_size,
            max_position_embeddings=max_position_embeddings,
            has_position_embedding=has_position_embedding,
            type_vocab_size=type_vocab_size,
            has_embedding_layernorm=has_embedding_layernorm,
            has_embedding_scale=has_embedding_scale,
            layernorm_eps=layernorm_eps,
            layernorm_type=layernorm_type,
            dtype=dtype,
            use_parallel_embedding=use_parallel_embedding,
            embedding_sharding_dim=embedding_sharding_dim,
            mapping=self.mapping)

    # One layer per index returned by mapping.pp_layers for this rank.
    self.decoder_layers = ModuleList([
        DecoderLayer(hidden_size=hidden_size,
                     ffn_hidden_size=ffn_hidden_size,
                     num_attention_heads=num_heads,
                     num_kv_heads=self.num_kv_heads,
                     head_size=self.head_size,
                     max_position_embeddings=max_position_embeddings,
                     q_scaling=q_scaling,
                     has_attention_qkvo_bias=has_attention_qkvo_bias,
                     has_mlp_bias=has_mlp_bias,
                     layernorm_position=layernorm_position,
                     layernorm_eps=layernorm_eps,
                     layernorm_type=layernorm_type,
                     hidden_act=hidden_act,
                     mlp_type=mlp_type,
                     mapping=self.mapping,
                     dtype=dtype,
                     residual_scaling=residual_scaling,
                     relative_attention=relative_attention,
                     max_distance=max_distance,
                     num_buckets=num_buckets)
        for _ in self.mapping.pp_layers(self.total_num_layers)
    ])

    if self.mapping.is_last_pp_rank():
        if self.has_model_final_layernorm:
            self.final_layernorm = ln_type(normalized_shape=hidden_size,
                                           eps=layernorm_eps,
                                           dtype=dtype)

        # Output projection to the vocabulary; gather_output=True gathers the
        # TP-sharded columns into full logits.
        self.lm_head = ColumnLinear(
            hidden_size,
            vocab_size,
            bias=has_lm_head_bias,
            dtype=dtype,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            gather_output=True,
        )
def forward(self,
decoder_input_ids: Tensor,
encoder_output: Tensor,
position_ids=None,
token_type_ids=None,
use_cache=False,
attention_mask=None,
cross_attention_mask=None,
last_token_ids=None,
kv_cache_params=None,
attention_params=None,
hidden_states=None):
if self.mapping.is_first_pp_rank():
assert isinstance(decoder_input_ids, Tensor)
else:
assert isinstance(hidden_states, Tensor)
# In PP, layer 0 has ids as inputs, all other layers have hidden_states as inputs
if self.mapping.is_first_pp_rank():
hidden_states = self.embedding(decoder_input_ids, position_ids,
token_type_ids)
self.register_network_output('embedding_layer_output',
hidden_states)
else:
hidden_states = recv(hidden_states, self.mapping.prev_pp_rank())
kv_cache_params.fill_none_tensor_list(len(self.decoder_layers))
if use_cache:
presents = []
for i, (decoder_layer, past, max_attention_window_size) in enumerate(
zip(self.decoder_layers, kv_cache_params.past_key_value,
kv_cache_params.host_max_attention_window_sizes)):
hidden_states = decoder_layer(
hidden_states,
encoder_output=encoder_output,
attention_mask=attention_mask,
cross_attention_mask=cross_attention_mask,
use_cache=use_cache,
kv_cache_params=KeyValueCacheParams(
past_key_value=past,
host_past_key_value_lengths=kv_cache_params.
host_past_key_value_lengths,
host_max_attention_window_sizes=max_attention_window_size,
host_sink_token_length=kv_cache_params.
host_sink_token_length,
cache_indirection=kv_cache_params.cache_indirection),
attention_params=attention_params)
if use_cache:
presents_self, presents_cross = hidden_states[1], hidden_states[
2]
presents.append((presents_self, presents_cross))
hidden_states = hidden_states[0]
self.register_network_output(f'decoder_layer_{i}_output',
hidden_states)
if self.mapping.is_last_pp_rank():
if self.has_model_final_layernorm:
hidden_states = self.final_layernorm(hidden_states)
# [bs, seq, hidden_size] or [num_tokens, hidden_size] -> [bs, hidden_size]
hidden_states = gather_last_token_logits(
hidden_states, last_token_ids,
default_net().plugin_config.remove_input_padding)
self.register_network_output('logits_before_lmhead', hidden_states)
# Rescale output before projecting on vocab (for T5)
# See https://github.com/huggingface/transformers/blob/0b192de1f353b0e04dad4813e02e2c672de077be/src/transformers/models/t5/modeling_t5.py#L1769-L1772
# Note: this is specific for T5, to make it more generic, one can pass in a config:
# self.config.tie_word_embeddings - default to be True for T5
# openai whisper model didn't use this rescale
if self.rescale_before_lm_head:
hidden_states = hidden_states * (self.hidden_size**-0.5)
# [bs, hidden_size] -> [bs, vocab_size]
lm_logits = self.lm_head(hidden_states)
lm_logits.mark_output('logits', self._logits_dtype)
else:
hidden_states = send(hidden_states, self.mapping.next_pp_rank())
hidden_states.mark_output('hidden_states_output', self._dtype)