The first versions of TensorRT-LLM were developed on a very aggressive timeline, and for those versions little emphasis was placed on defining a unified workflow. Now that TensorRT-LLM has reached a certain level of feature richness, the development team has decided to put more effort into unifying the APIs and workflow of TensorRT-LLM. This document summarises the new workflow adopted by TensorRT-LLM at its core.
There are three steps in the new workflow:
- Convert weights from different source frameworks into a TensorRT-LLM checkpoint
- Build the TensorRT-LLM checkpoint into TensorRT engine(s) with a unified build command
- Load the engine(s) into the TensorRT-LLM model runner and evaluate them with different evaluation tasks
```text
NeMo -------------
                  |
HuggingFace ------
                  |   convert                             build                    load
AMMO -------------  ----------> TensorRT-LLM Checkpoint --------> TensorRT Engine ------> TensorRT-LLM ModelRunner
                  |
JAX --------------
                  |
DeepSpeed --------
```
TensorRT-LLM aims to support models from different sources:
- Trained models from NeMo, DeepSpeed, JAX
- Quantized models from AMMO
- Popular models from HuggingFace
TensorRT-LLM defines its own checkpoint format. A checkpoint directory includes:
- One config JSON file, which contains several model hyper-parameters
- One or more rank weights files, each containing a dictionary of tensors (weights). In a multi-GPU (multi-process) scenario, each rank loads its own file
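As a rough illustration of this layout, the sketch below lists the files a checkpoint with a given number of ranks is expected to contain and which file each rank would read. The names `config.json` and `rank{N}.safetensors` are taken from the OPT example later in this document; the helper functions are hypothetical and are not part of the TensorRT-LLM API.

```python
import os


def expected_checkpoint_files(world_size: int) -> list[str]:
    """File names expected in a checkpoint directory with `world_size` ranks.

    Assumes the naming seen in the OPT example: config.json plus
    rank0.safetensors ... rank{world_size-1}.safetensors.
    """
    return ["config.json"] + [f"rank{r}.safetensors" for r in range(world_size)]


def rank_weights_file(checkpoint_dir: str, rank: int) -> str:
    """Path of the weights file that process `rank` would load (hypothetical helper)."""
    return os.path.join(checkpoint_dir, f"rank{rank}.safetensors")


def missing_checkpoint_files(checkpoint_dir: str, world_size: int) -> list[str]:
    """Expected files that are absent from `checkpoint_dir`."""
    present = set(os.listdir(checkpoint_dir))
    return [name for name in expected_checkpoint_files(world_size) if name not in present]


if __name__ == "__main__":
    print(expected_checkpoint_files(2))
    # ['config.json', 'rank0.safetensors', 'rank1.safetensors']
```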
The config file supports the following common fields:

Field | Type | Default Value |
---|---|---|
architecture | string | mandatory |
dtype | string | mandatory |
logits_dtype | string | 'float32' |
vocab_size | int | mandatory |
max_position_embeddings | int | null |
hidden_size | int | mandatory |
num_hidden_layers | int | mandatory |
num_attention_heads | int | mandatory |
num_key_value_heads | int | num_attention_heads |
hidden_act | string | mandatory |
intermediate_size | int | null |
norm_epsilon | float | 1e-5 |
position_embedding_type | string | 'learned_absolute' |
use_prompt_tuning | bool | false |
mapping.world_size | int | 1 |
mapping.tp_size | int | 1 |
mapping.pp_size | int | 1 |
quantization.quant_algo | string | null |
quantization.kv_cache_quant_algo | string | null |
quantization.group_size | int | 64 |
quantization.has_zero_point | bool | false |
quantization.pre_quant_scale | bool | false |
quantization.exclude_modules | list | null |
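To make the table concrete, here is a minimal sketch of how a loader might validate the mandatory fields and fill in the defaults. Dotted names such as `mapping.world_size` are treated as nested JSON objects, and the quantization algorithm fields are assumed to be spelled `quant_algo` and `kv_cache_quant_algo`. The function and constant names are hypothetical; this is not how TensorRT-LLM itself parses the file.

```python
import copy

# Fields marked "mandatory" in the table: no default, must be present.
MANDATORY_FIELDS = (
    "architecture", "dtype", "vocab_size", "hidden_size",
    "num_hidden_layers", "num_attention_heads", "hidden_act",
)

# Top-level defaults copied from the table (JSON null -> Python None).
TOP_LEVEL_DEFAULTS = {
    "logits_dtype": "float32",
    "max_position_embeddings": None,
    "intermediate_size": None,
    "norm_epsilon": 1e-5,
    "position_embedding_type": "learned_absolute",
    "use_prompt_tuning": False,
}

# Defaults for dotted fields, stored as nested dictionaries.
NESTED_DEFAULTS = {
    "mapping": {"world_size": 1, "tp_size": 1, "pp_size": 1},
    "quantization": {
        "quant_algo": None,
        "kv_cache_quant_algo": None,
        "group_size": 64,
        "has_zero_point": False,
        "pre_quant_scale": False,
        "exclude_modules": None,
    },
}


def apply_config_defaults(config: dict) -> dict:
    """Return a copy of `config` with table defaults filled in; raise KeyError on missing mandatory fields."""
    missing = [f for f in MANDATORY_FIELDS if f not in config]
    if missing:
        raise KeyError(f"missing mandatory config fields: {missing}")

    result = copy.deepcopy(config)
    for key, value in TOP_LEVEL_DEFAULTS.items():
        result.setdefault(key, value)
    # Defaults to num_attention_heads, i.e. one KV head per attention head.
    result.setdefault("num_key_value_heads", result["num_attention_heads"])
    for section, defaults in NESTED_DEFAULTS.items():
        sub = result.get(section) or {}  # treat a missing or null section as empty
        result[section] = sub
        for key, value in defaults.items():
            sub.setdefault(key, value)
    return result
```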
`mapping.world_size` means that `mapping` is a dictionary containing the `world_size` sub-field:

```json
{
    "architecture": "OPTForCausalLM",
    "mapping": {
        "world_size": 1
    }
}
```
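The dotted notation can be resolved against the parsed JSON with a small helper such as the one below. It is a hypothetical sketch for illustration, not a TensorRT-LLM function; it walks one nested dictionary per dot and falls back to a caller-supplied default when a level is missing.

```python
from typing import Any

_MISSING = object()


def get_config_field(config: dict, dotted_name: str, default: Any = _MISSING) -> Any:
    """Resolve a dotted field name such as "mapping.world_size" in a nested config dict."""
    node: Any = config
    for part in dotted_name.split("."):
        if isinstance(node, dict) and part in node:
            node = node[part]
        elif default is not _MISSING:
            return default
        else:
            raise KeyError(dotted_name)
    return node


config = {"architecture": "OPTForCausalLM", "mapping": {"world_size": 1}}
print(get_config_field(config, "mapping.world_size"))   # 1
print(get_config_field(config, "mapping.tp_size", 1))   # 1 (falls back to the table default)
```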
Supported quantization algorithm list:
- W8A16
- W4A16
- W4A16_AWQ
- W4A8_AWQ
- W4A16_GPTQ
- FP8
- W8A8_SQ_PER_CHANNEL
Supported KV cache quantization algorithm list:
- FP8
- INT8
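A converter or loader could reject unsupported values early. The sketch below checks a `quantization` section against the two lists above; it assumes the algorithm names are stored exactly as listed (same spelling and case) under the keys `quant_algo` and `kv_cache_quant_algo`, and the function name is hypothetical.

```python
from typing import Optional

SUPPORTED_QUANT_ALGOS = {
    "W8A16", "W4A16", "W4A16_AWQ", "W4A8_AWQ", "W4A16_GPTQ", "FP8", "W8A8_SQ_PER_CHANNEL",
}
SUPPORTED_KV_CACHE_QUANT_ALGOS = {"FP8", "INT8"}


def validate_quantization(quant: Optional[dict]) -> None:
    """Raise ValueError if the quantization section names an unsupported algorithm.

    A missing section, or a null (None) algorithm, means "not quantized".
    """
    quant = quant or {}
    algo = quant.get("quant_algo")
    if algo is not None and algo not in SUPPORTED_QUANT_ALGOS:
        raise ValueError(f"unsupported quant_algo: {algo!r}")
    kv_algo = quant.get("kv_cache_quant_algo")
    if kv_algo is not None and kv_algo not in SUPPORTED_KV_CACHE_QUANT_ALGOS:
        raise ValueError(f"unsupported kv_cache_quant_algo: {kv_algo!r}")
```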
The config is extensible: a model can add its own specific config fields. For example, the OPT model has a `do_layer_norm_before` field.

Here is the list of model-specific config fields:

Field | Type | Default Value |
---|---|---|
OPT | | |
do_layer_norm_before | bool | false |
Falcon | | |
bias | bool | true |
new_decoder_architecture | bool | false |
parallel_attention | bool | false |
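Model-specific defaults can be layered on top of the common ones by keying them on the `architecture` field. In this sketch, `"OPTForCausalLM"` comes from this document, while `"FalconForCausalLM"` is an assumed architecture string; the table and function are hypothetical illustrations only.

```python
import copy

# Model-specific defaults from the table above, keyed by the "architecture" string.
# "FalconForCausalLM" is an assumed name used for illustration.
MODEL_SPECIFIC_DEFAULTS = {
    "OPTForCausalLM": {"do_layer_norm_before": False},
    "FalconForCausalLM": {
        "bias": True,
        "new_decoder_architecture": False,
        "parallel_attention": False,
    },
}


def apply_model_specific_defaults(config: dict) -> dict:
    """Return a copy of `config` with the defaults of its architecture filled in."""
    result = copy.deepcopy(config)
    defaults = MODEL_SPECIFIC_DEFAULTS.get(result.get("architecture"), {})
    for key, value in defaults.items():
        result.setdefault(key, value)  # explicit values in the config win
    return result
```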
As in PyTorch, the tensor (weight) name is a string containing hierarchical information, which maps uniquely to a certain parameter of a TensorRT-LLM model.
For example, each transformer layer of the OPT model contains an `Attention` layer, an `MLP` layer and two `LayerNorm` layers.
The `Attention` layer contains two `Linear` layers, qkv and dense; each `Linear` layer contains one weight and one bias.
So, there are four tensors (weights) in total, whose names are:
- "transformer.layers.0.attention.qkv.weight"
- "transformer.layers.0.attention.qkv.bias"
- "transformer.layers.0.attention.dense.weight"
- "transformer.layers.0.attention.dense.bias"
where `transformer.layers.0.attention` is the prefix name, indicating that the weights/biases are in the attention module of the 0-th transformer layer.
The `MLP` layer also contains two `Linear` layers, fc and proj; each `Linear` layer contains one weight and one bias.
So, there are four tensors (weights) in total, whose names are:
- "transformer.layers.0.mlp.fc.weight"
- "transformer.layers.0.mlp.fc.bias"
- "transformer.layers.0.mlp.proj.weight"
- "transformer.layers.0.mlp.proj.bias"
where `transformer.layers.0.mlp` is the prefix name, indicating that the weights/biases are in the mlp module of the 0-th transformer layer.
Each of the two `LayerNorm` layers, namely input_layernorm and post_layernorm, contains one weight and one bias.
So, there are four tensors (weights) in total, whose names are:
- "transformer.layers.0.input_layernorm.weight"
- "transformer.layers.0.input_layernorm.bias"
- "transformer.layers.0.post_layernorm.weight"
- "transformer.layers.0.post_layernorm.bias"
where `transformer.layers.0.input_layernorm` and `transformer.layers.0.post_layernorm` are the prefix names for the two layernorm modules.
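The naming convention above can be generated mechanically. The sketch below produces the twelve tensor names of one unquantized OPT transformer layer from the module list described in this section; the function name is hypothetical, and embedding or final-norm tensors outside the layers are not covered.

```python
def opt_layer_tensor_names(layer_idx: int) -> list[str]:
    """Tensor names of one (unquantized) OPT transformer layer, following the convention above."""
    prefix = f"transformer.layers.{layer_idx}"
    modules = [
        "attention.qkv", "attention.dense",   # Attention: two Linear layers
        "mlp.fc", "mlp.proj",                 # MLP: two Linear layers
        "input_layernorm", "post_layernorm",  # the two LayerNorm layers
    ]
    # Every module listed here has exactly one weight and one bias.
    return [f"{prefix}.{m}.{p}" for m in modules for p in ("weight", "bias")]


names = opt_layer_tensor_names(0)
print(len(names))  # 12
print(names[0])    # transformer.layers.0.attention.qkv.weight
```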
Note that if the model is quantized, the checkpoint contains additional tensors, depending on the quantization method applied.
For example, if the KV cache is quantized, the `Attention` layer has this extra scaling factor:
- "transformer.layers.0.attention.kv_cache_scaling_factor"
For example, here are the FP8 scaling factors of the `attention.qkv` linear layer:
- "transformer.layers.0.attention.qkv.activation_scaling_factor"
- "transformer.layers.0.attention.qkv.weights_scaling_factor"
For example, here are the AWQ scaling factors of the `mlp.fc` linear layer:
- "transformer.layers.0.mlp.fc.weights_scaling_factor"
- "transformer.layers.0.mlp.fc.prequant_scaling_factor"
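The extra tensors follow the same prefix convention, with a suffix per quantization method. This sketch covers only the three cases listed above (KV cache, FP8 and AWQ); the mapping and function name are hypothetical, and other methods may add different tensors.

```python
# Extra tensor suffixes per quantization method, as listed in the examples above.
EXTRA_SUFFIXES = {
    "kv_cache": ("kv_cache_scaling_factor",),                         # on the attention module
    "fp8": ("activation_scaling_factor", "weights_scaling_factor"),   # per linear layer
    "awq": ("weights_scaling_factor", "prequant_scaling_factor"),     # per linear layer
}


def extra_tensor_names(module_prefix: str, method: str) -> list[str]:
    """Names of the extra scaling-factor tensors `method` adds under `module_prefix`."""
    return [f"{module_prefix}.{suffix}" for suffix in EXTRA_SUFFIXES[method]]


print(extra_tensor_names("transformer.layers.0.attention", "kv_cache"))
# ['transformer.layers.0.attention.kv_cache_scaling_factor']
```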
Note: The linear weights in a TensorRT-LLM checkpoint always follow the (out_feature, in_feature) shape, whereas some quantized linear layers in TensorRT-LLM that are implemented by plugins may use the (in_feature, out_feature) shape. The `trtllm-build` command adds a transpose operation to post-process such weights.
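The following pure-Python sketch illustrates why a transpose is enough to reconcile the two layouts: the same linear product is computed whether the weight is stored as (out_features, in_features) or transposed to (in_features, out_features). It only illustrates the shape convention and says nothing about how `trtllm-build` or the plugins implement it.

```python
def linear_out_in(weight, x):
    """y = W x, with W stored as (out_features, in_features), the checkpoint layout."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def linear_in_out(weight_t, x):
    """The same product, with the weight stored as (in_features, out_features)."""
    out_features = len(weight_t[0])
    return [sum(x[i] * weight_t[i][j] for i in range(len(x))) for j in range(out_features)]


def transpose(matrix):
    """Swap rows and columns: (out, in) -> (in, out)."""
    return [list(col) for col in zip(*matrix)]


w_checkpoint = [[1, 2, 3],
                [4, 5, 6]]             # out_features=2, in_features=3
x = [1, 1, 0]

w_plugin = transpose(w_checkpoint)     # [[1, 4], [2, 5], [3, 6]]
print(linear_out_in(w_checkpoint, x))  # [3, 9]
print(linear_in_out(w_plugin, x))      # [3, 9]
```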
Let's take OPT as an example and deploy the model with tensor parallelism 2. First, convert the checkpoint:
```bash
cd examples/opt
python3 convert_checkpoint.py --model_dir ./opt-125m \
                --dtype float16 \
                --world_size 2 \
                --output_dir ./opt/125M/trt_ckpt/fp16/2-gpu/
```
Here is the checkpoint directory:

```text
./opt/125M/trt_ckpt/fp16/2-gpu/
    config.json
    rank0.safetensors
    rank1.safetensors
```
Here is the `config.json`:
```json
{
    "architecture": "OPTForCausalLM",
    "dtype": "float16",
    "logits_dtype": "float32",
    "num_hidden_layers": 12,
    "num_attention_heads": 12,
    "hidden_size": 768,
    "vocab_size": 50272,
    "position_embedding_type": "learned_absolute",
    "max_position_embeddings": 2048,
    "hidden_act": "relu",
    "quantization": {
        "use_weight_only": false,
        "weight_only_precision": "int8"
    },
    "mapping": {
        "world_size": 2,
        "tp_size": 2
    },
    "use_parallel_embedding": false,
    "embedding_sharding_dim": 0,
    "share_embedding_table": false,
    "do_layer_norm_before": true,
    "use_prompt_tuning": false
}
```
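A quick consistency check on such a config: with only tensor and pipeline parallelism in play, as in the config table of this document, `world_size` should equal `tp_size * pp_size`, and there should be one rank weights file per rank. The sketch below is a hypothetical helper built on that assumption, using the `mapping` section of the config above.

```python
import json

CONFIG_JSON = """
{
    "architecture": "OPTForCausalLM",
    "mapping": {"world_size": 2, "tp_size": 2}
}
"""


def check_mapping(config: dict) -> int:
    """Return world_size after checking world_size == tp_size * pp_size.

    Assumes tp_size and pp_size default to 1 as in the config table.
    """
    mapping = config.get("mapping", {})
    world_size = mapping.get("world_size", 1)
    tp_size = mapping.get("tp_size", 1)
    pp_size = mapping.get("pp_size", 1)
    if world_size != tp_size * pp_size:
        raise ValueError(f"world_size={world_size} != tp_size*pp_size={tp_size * pp_size}")
    return world_size


world_size = check_mapping(json.loads(CONFIG_JSON))
print([f"rank{r}.safetensors" for r in range(world_size)])
# ['rank0.safetensors', 'rank1.safetensors']
```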
TensorRT-LLM provides a unified build command: `trtllm-build`. Before using it, you may need to add its install location to your `PATH`:

```bash
export PATH=/usr/local/bin:$PATH
```
```bash
trtllm-build --checkpoint_dir ./opt/125M/trt_ckpt/fp16/2-gpu/ \
                --gemm_plugin float16 \
                --max_batch_size 8 \
                --max_input_len 924 \
                --max_output_len 100 \
                --output_dir ./opt/125M/trt_engines/fp16/2-gpu/
```
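If the build is driven from a script, the same invocation can be assembled as an argument list for `subprocess.run`. The sketch below uses only the flags from the command above; the extra check that `max_input_len + max_output_len` fits within the model's `max_position_embeddings` (924 + 100 = 1024 ≤ 2048 for OPT-125M) is an illustrative assumption for learned absolute position embeddings, not a documented behaviour of `trtllm-build`.

```python
import shlex


def build_command(checkpoint_dir: str, output_dir: str, max_batch_size: int,
                  max_input_len: int, max_output_len: int,
                  max_position_embeddings: int) -> list[str]:
    """Assemble the trtllm-build invocation shown above (hypothetical helper).

    The sequence-length check is an assumed sanity check, not part of trtllm-build.
    """
    total = max_input_len + max_output_len
    if total > max_position_embeddings:
        raise ValueError(f"{total} tokens exceed max_position_embeddings={max_position_embeddings}")
    return [
        "trtllm-build",
        "--checkpoint_dir", checkpoint_dir,
        "--gemm_plugin", "float16",
        "--max_batch_size", str(max_batch_size),
        "--max_input_len", str(max_input_len),
        "--max_output_len", str(max_output_len),
        "--output_dir", output_dir,
    ]


cmd = build_command("./opt/125M/trt_ckpt/fp16/2-gpu/", "./opt/125M/trt_engines/fp16/2-gpu/",
                    max_batch_size=8, max_input_len=924, max_output_len=100,
                    max_position_embeddings=2048)
print(shlex.join(cmd))  # pass `cmd` to subprocess.run(cmd, check=True) to actually build
```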
Finally, load the engines with two MPI processes (one per rank) and evaluate them on a summarization task:

```bash
mpirun -n 2 --allow-run-as-root \
    python3 ../summarize.py --engine_dir ./opt/125M/trt_engines/fp16/2-gpu/ \
                        --batch_size 1 \
                        --test_trt_llm \
                        --hf_model_dir opt-125m \
                        --data_type fp16 \
                        --check_accuracy \
                        --tensorrt_llm_rouge1_threshold=14
```
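The number of MPI processes passed to `-n` has to match the `mapping.world_size` the checkpoint was converted with (2 in this example), because each rank runs its own part of the model. The sketch below derives `-n` from a parsed checkpoint `config.json` instead of hard-coding it; the helper name is hypothetical, and the remaining arguments are copied from the command above.

```python
def summarize_command(checkpoint_config: dict, engine_dir: str, hf_model_dir: str) -> list[str]:
    """mpirun launch line for the evaluation above, with -n taken from mapping.world_size."""
    world_size = checkpoint_config.get("mapping", {}).get("world_size", 1)
    return [
        "mpirun", "-n", str(world_size), "--allow-run-as-root",
        "python3", "../summarize.py",
        "--engine_dir", engine_dir,
        "--batch_size", "1",
        "--test_trt_llm",
        "--hf_model_dir", hf_model_dir,
        "--data_type", "fp16",
        "--check_accuracy",
        "--tensorrt_llm_rouge1_threshold=14",
    ]


cmd = summarize_command({"mapping": {"world_size": 2, "tp_size": 2}},
                        "./opt/125M/trt_engines/fp16/2-gpu/", "opt-125m")
print(" ".join(cmd))
```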