Enable multiple LoRA adapters #2010

Merged Jun 25, 2024 · 41 commits

Commits
db3d8e6
feat: first draft load multiple lora
drbh May 30, 2024
0a6ea7f
feat: load weights within layer and refactor lora pass
drbh Jun 4, 2024
a046c30
fix: refactor and reduce lora math
drbh Jun 4, 2024
c661631
feat: baseline impl single request multi lora support
drbh Jun 4, 2024
8b50f4b
feat: prefer lorax implementation and port loading logic
drbh Jun 5, 2024
d5f21d5
fix: prefer adapter_data and refactors
drbh Jun 6, 2024
8984ce6
feat: perfer loraxs custom punica kernels and add mlp loras
drbh Jun 6, 2024
ad088d5
fix: adjust batch for bgmv
drbh Jun 6, 2024
c927376
fix: adjust adapter_segments logic when in batch
drbh Jun 6, 2024
73eb2ae
fix: refactor and move changes to v3 proto
drbh Jun 6, 2024
88bd5c2
fix: pass model_id for all flash causal lms
drbh Jun 6, 2024
dc0f765
fix: pass model_id for all causal and seq2seq lms
drbh Jun 6, 2024
9c45d34
fix: add model_id to model test
drbh Jun 6, 2024
de56a81
feat: add lora support to mistral and refactors
drbh Jun 6, 2024
68399c1
feat: prefer model id in request
drbh Jun 6, 2024
81707bf
fix: include rust code for adapter id
drbh Jun 6, 2024
43ec9df
feat: bump launcher and add new lora docs
drbh Jun 6, 2024
611225f
feat: support base model generation and refactors
drbh Jun 7, 2024
a563a93
fix: rename doc to retry ci build
drbh Jun 7, 2024
91f4072
feat: support if vlm models
drbh Jun 7, 2024
b116927
fix: add adapter_data param and avoid missing layers
drbh Jun 7, 2024
1deb372
fix: add adapter_data param to phi and neox
drbh Jun 7, 2024
101b95a
fix: update all models forwards to include adapter_data
drbh Jun 7, 2024
ce40ad2
fix: add model_id to IdeficsCausalLM
drbh Jun 7, 2024
1be1ebc
Update lora.md
datavistics Jun 10, 2024
d6cf63c
Update lora.md
datavistics Jun 10, 2024
aa88c4f
fix: add lora kernel to dockerfile, support running without kernels a…
drbh Jun 14, 2024
06c3254
fix: avoid dockerfile conflict
drbh Jun 14, 2024
0e1c28c
fix: merge 'main' into lora-internal to resolve conflicts
drbh Jun 14, 2024
1104885
Merge branch 'main' into lora-internal
drbh Jun 14, 2024
224455f
Merge branch 'main' into lora-internal
drbh Jun 18, 2024
4f1543d
fix: refactors and adjust flash llama lora logic
drbh Jun 19, 2024
ce70fce
fix: skip llama test due to CI issue (temp)
drbh Jun 19, 2024
c9e4526
fix: skip llama test CI (temp) 2
drbh Jun 19, 2024
a07b612
fix: revert skips and prefer updated ci token for tests
drbh Jun 19, 2024
3c9b28e
fix: refactors and helpful comments
drbh Jun 24, 2024
c927cff
fix: add noop in TensorParallelAdapterRowLinear too
drbh Jun 24, 2024
f94f2b3
fix: refactor and move shard_lora_weights logic
drbh Jun 24, 2024
0d496ba
Merge branch 'main' into lora-internal
drbh Jun 24, 2024
a2d821c
fix: exit early if no adapter_data
drbh Jun 25, 2024
59575fe
Merge branch 'main' into lora-internal
drbh Jun 25, 2024
10 changes: 9 additions & 1 deletion Dockerfile
@@ -145,6 +145,13 @@ COPY server/marlin/ .
# Build specific version of transformers
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build

# Build Lorax Punica kernels
FROM kernel-builder as lorax-punica-builder
WORKDIR /usr/src
COPY server/Makefile-lorax-punica Makefile
# Build Lorax Punica kernels
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica

# Build Transformers CUDA kernels
FROM kernel-builder as custom-kernels-builder
WORKDIR /usr/src
@@ -215,6 +222,7 @@ COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from marlin kernels builder
COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy builds artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
@@ -266,4 +274,4 @@ COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh

ENTRYPOINT ["/tgi-entrypoint.sh"]
CMD ["--json-output"]
# CMD ["--json-output"]
1 change: 1 addition & 0 deletions benchmark/src/generation.rs
@@ -157,6 +157,7 @@ async fn prefill(
top_n_tokens: top_n_tokens.unwrap_or(0),
blocks: vec![],
slots: vec![],
adapter_id: None,
})
.collect();

5 changes: 4 additions & 1 deletion docs/source/_toctree.yml
@@ -60,6 +60,9 @@
- local: conceptual/speculation
title: Speculation (Medusa, ngram)
- local: conceptual/guidance
title: How Guidance Works (via outlines)
title: How Guidance Works (via outlines
- local: conceptual/lora
title: LoRA (Low-Rank Adaptation)


title: Conceptual Guides
8 changes: 8 additions & 0 deletions docs/source/basic_tutorials/launcher.md
@@ -416,6 +416,14 @@ Options:
[env: MAX_CLIENT_BATCH_SIZE=]
[default: 4]

```
## LORA_ADAPTERS
```shell
--lora-adapters <LORA_ADAPTERS>
LoRA adapters: a comma-separated list of adapter ids, e.g. `repo/adapter1,repo/adapter2`, to load during startup. These will be available to callers via the `adapter_id` field in a request

[env: LORA_ADAPTERS=]

```
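For example (an illustrative invocation; substitute the base model your adapters were trained on):

```shell
text-generation-launcher \
    --model-id mistralai/Mistral-7B-v0.1 \
    --lora-adapters predibase/customer_support,predibase/dbpedia
```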
## HELP
```shell
65 changes: 65 additions & 0 deletions docs/source/conceptual/lora.md
@@ -0,0 +1,65 @@
# LoRA (Low-Rank Adaptation)

## What is LoRA?

LoRA is a technique that allows for efficient fine-tuning of a model while only updating a small portion of the model's weights. This is useful when you have a large model that has been pre-trained on a large dataset, but you want to fine-tune it on a smaller dataset or for a specific task.

LoRA works by adding a small number of extra weights to the model, which are used to adapt it to the new dataset or task. Concretely, each adapted layer gets a pair of low-rank matrices whose product forms the weight update; these matrices are learned during fine-tuning while the rest of the model's weights are kept fixed.
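To make the idea concrete, here is a rough sketch of a LoRA-adapted linear layer in PyTorch (an illustration only, not TGI's implementation): the frozen base output is combined with a scaled low-rank update `B(Ax)`.

```python
import torch

# Minimal LoRA illustration (not TGI's code): W stays frozen,
# A (r x in_features) and B (out_features x r) are the trained adapter weights.
class LoRALinear(torch.nn.Module):
    def __init__(self, base: torch.nn.Linear, r: int = 16, alpha: int = 32):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)  # base weights are kept fixed
        self.lora_A = torch.nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_B = torch.nn.Parameter(torch.zeros(base.out_features, r))
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # h = W x + (alpha / r) * B (A x)
        return self.base(x) + self.scaling * (x @ self.lora_A.T @ self.lora_B.T)
```

Because `lora_B` starts at zero, the adapted layer initially behaves exactly like the base layer, and only the small `A`/`B` matrices need to be stored per adapter.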

## How is it used?

LoRA can be used in many ways, and the community is always finding new applications for it. At its core, LoRA fine-tunes a large language model on a comparatively small dataset, which covers a wide range of use cases, such as:

- fine-tuning a language model on a small dataset
- fine-tuning a language model on a domain-specific dataset
- fine-tuning a language model on a dataset with limited labels

## Optimizing Inference with LoRA

LoRA adapters can be used during inference by multiplying the adapter weights with the model weights at each specified layer. This process can be computationally expensive, but thanks to excellent work by [punica-ai](https://github.com/punica-ai/punica) and the [lorax](https://github.com/predibase/lorax) team, optimized kernels and frameworks have been developed to make it more efficient. TGI leverages these optimizations to provide fast and efficient inference with multiple LoRA models.
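To give an intuition for what those kernels optimize, here is a naive reference (illustration only; the real punica/lorax kernels fuse this into batched GPU operations) for applying different adapters to different rows of the same batch:

```python
import torch

def naive_multi_lora(x, base_out, adapters, adapter_ids):
    """Apply a (possibly different) LoRA adapter to each row of a batch.

    x: (batch, hidden) layer input, base_out: (batch, out) frozen-layer output,
    adapters: dict mapping adapter id -> (A, B, scaling), adapter_ids: per-row id or None.
    """
    out = base_out.clone()
    for adapter_id, (A, B, scaling) in adapters.items():
        rows = [i for i, rid in enumerate(adapter_ids) if rid == adapter_id]
        if not rows:
            continue
        idx = torch.tensor(rows)
        # Rows that requested this adapter get the low-rank update added;
        # rows with no adapter keep the base model output untouched.
        out[idx] += scaling * (x[idx] @ A.T @ B.T)
    return out
```

Doing this adapter-by-adapter in Python is slow; the optimized kernels compute all adapter segments in a single pass, which is what makes serving many adapters in one batch practical.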

## Serving multiple LoRA adapters with TGI

Once a LoRA model has been trained, it can be used to generate text or perform other tasks just like a regular language model. However, because the model has been fine-tuned on a specific dataset, it may perform better on that dataset than a model that has not been fine-tuned.

In practice, it is often useful to have multiple LoRA adapters, each fine-tuned on a different dataset or for a different task. This allows you to use the adapter that is best suited to a particular task or dataset.

Text Generation Inference (TGI) now supports loading multiple LoRA models at startup that can be used in generation requests. This feature is available starting from version `~2.0.6` and is compatible with LoRA models trained using the `peft` library.

### Specifying LoRA models

To use LoRA in TGI, specify the list of LoRA adapters to load at startup with the `LORA_ADAPTERS` environment variable (or the equivalent `--lora-adapters` launcher flag). For example:

```bash
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
```
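With the Docker image, that might look like the following (an illustrative invocation only; substitute the base model your adapters were trained on and your own paths):

```bash
model=mistralai/Mistral-7B-v0.1
volume=$PWD/data

docker run --gpus all --shm-size 1g -p 3000:80 -v $volume:/data \
    ghcr.io/huggingface/text-generation-inference:latest \
    --model-id $model \
    --lora-adapters predibase/customer_support,predibase/dbpedia
```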

In the server logs, you will see the following message:

```txt
Loading adapter weights into model: predibase/customer_support
Loading adapter weights into model: predibase/dbpedia
```

## Generate text

You can then use these adapters in generation requests by specifying the `adapter_id` parameter in the request payload. For example:

```bash
curl 127.0.0.1:3000/generate \
-X POST \
-H 'Content-Type: application/json' \
-d '{
"inputs": "Hello who are you?",
"parameters": {
"max_new_tokens": 40,
"adapter_id": "predibase/customer_support"
}
}'
```
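The same request can be made from Python (a minimal sketch using `requests`; the payload mirrors the curl example above):

```python
import requests

response = requests.post(
    "http://127.0.0.1:3000/generate",
    json={
        "inputs": "Hello who are you?",
        "parameters": {
            "max_new_tokens": 40,
            # Omit adapter_id (or set it to null) to generate with the base model.
            "adapter_id": "predibase/customer_support",
        },
    },
)
print(response.json()["generated_text"])
```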

> **Note:** The LoRA feature is new and still being improved. If you encounter any issues or have any feedback, please let us know by opening an issue on the [GitHub repository](https://github.com/huggingface/text-generation-inference/issues/new/choose). Additionally, documentation and an improved client library will be published soon.

An updated tutorial with detailed examples will be published soon. Stay tuned!
13 changes: 13 additions & 0 deletions launcher/src/main.rs
@@ -452,6 +452,11 @@ struct Args {
/// Control the maximum number of inputs that a client can send in a single request
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,

/// LoRA adapters: a comma-separated list of adapter ids, e.g. `repo/adapter1,repo/adapter2`, to load
/// during startup. These will be available to callers via the `adapter_id` field in a request.
#[clap(long, env)]
lora_adapters: Option<String>,
}

#[derive(Debug)]
@@ -485,6 +490,7 @@ fn shard_manager(
max_total_tokens: usize,
max_batch_size: Option<usize>,
max_input_tokens: usize,
lora_adapters: Option<String>,
otlp_endpoint: Option<String>,
otlp_service_name: String,
log_level: LevelFilter,
@@ -620,6 +626,11 @@
envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
}

// Lora Adapters
if let Some(lora_adapters) = lora_adapters {
envs.push(("LORA_ADAPTERS".into(), lora_adapters.into()));
}

// If huggingface_hub_cache is some, pass it to the shard
// Useful when running inside a docker container
if let Some(huggingface_hub_cache) = huggingface_hub_cache {
@@ -1060,6 +1071,7 @@ fn spawn_shards(
let rope_scaling = args.rope_scaling;
let rope_factor = args.rope_factor;
let max_batch_size = args.max_batch_size;
let lora_adapters = args.lora_adapters.clone();
thread::spawn(move || {
shard_manager(
model_id,
@@ -1085,6 +1097,7 @@
max_total_tokens,
max_batch_size,
max_input_tokens,
lora_adapters,
otlp_endpoint,
otlp_service_name,
max_log_level,
2 changes: 2 additions & 0 deletions proto/v3/generate.proto
@@ -134,6 +134,8 @@ message Request {
repeated uint32 blocks = 9;
/// Paged attention slots
repeated uint32 slots = 10;
/// LoRA adapter id
optional string adapter_id = 11;
}

message Batch {
1 change: 1 addition & 0 deletions router/client/src/v3/client.rs
@@ -177,6 +177,7 @@ impl Client {
}),
prefill_logprobs: true,
top_n_tokens: 20,
adapter_id: None,
});
n_tokens += max_input_length;

1 change: 1 addition & 0 deletions router/client/src/v3/sharded_client.rs
@@ -244,6 +244,7 @@ impl Health for ShardedClient {
// Block 0 is reserved for health checks
blocks: vec![0],
slots: (0..16).collect(),
adapter_id: None,
};
let batch = Batch {
id: u64::MAX,
1 change: 1 addition & 0 deletions router/src/infer/v2/queue.rs
@@ -429,6 +429,7 @@ mod tests {
stop_sequences: vec![],
},
top_n_tokens: 0,
adapter_id: None,
},
response_tx,
span: info_span!("entry"),
2 changes: 2 additions & 0 deletions router/src/infer/v3/queue.rs
@@ -351,6 +351,7 @@ impl State {
top_n_tokens: entry.request.top_n_tokens,
blocks,
slots,
adapter_id: entry.request.adapter_id.clone(),
});
// Set batch_time
entry.batch_time = Some(Instant::now());
@@ -491,6 +492,7 @@ mod tests {
stop_sequences: vec![],
},
top_n_tokens: 0,
adapter_id: None,
},
response_tx,
span: info_span!("entry"),
6 changes: 6 additions & 0 deletions router/src/lib.rs
@@ -302,6 +302,11 @@ pub(crate) struct GenerateParameters {
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub grammar: Option<GrammarType>,

/// Lora adapter id
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub adapter_id: Option<String>,
}

fn default_max_new_tokens() -> Option<u32> {
@@ -328,6 +333,7 @@ fn default_parameters() -> GenerateParameters {
seed: None,
top_n_tokens: None,
grammar: None,
adapter_id: None,
}
}

2 changes: 2 additions & 0 deletions router/src/server.rs
@@ -673,6 +673,7 @@ async fn completions(
seed,
top_n_tokens: None,
grammar: None,
..Default::default()
},
})
.collect();
@@ -1115,6 +1116,7 @@ async fn chat_completions(
seed,
top_n_tokens: req.top_logprobs,
grammar,
..Default::default()
},
};

3 changes: 3 additions & 0 deletions router/src/validation.rs
@@ -202,6 +202,7 @@ impl Validation {
decoder_input_details,
top_n_tokens,
grammar,
adapter_id,
..
} = request.parameters;

@@ -383,6 +384,7 @@
parameters,
stopping_parameters,
top_n_tokens,
adapter_id,
})
}

@@ -678,6 +680,7 @@ pub(crate) struct ValidGenerateRequest {
pub parameters: ValidParameters,
pub stopping_parameters: ValidStoppingParameters,
pub top_n_tokens: u32,
pub adapter_id: Option<String>,
}

#[derive(Error, Debug)]
1 change: 1 addition & 0 deletions server/Makefile
@@ -4,6 +4,7 @@ include Makefile-vllm
include Makefile-awq
include Makefile-eetq
include Makefile-selective-scan
include Makefile-lorax-punica

unit-tests:
pytest -s -vv -m "not private" tests
12 changes: 12 additions & 0 deletions server/Makefile-lorax-punica
@@ -0,0 +1,12 @@
lorax_punica_commit := c71861a653412267dc27ec86013dd945ce3474bc

build-lorax-punica:
if [ ! -d 'lorax-punica' ]; then \
git clone --no-checkout https://github.com/predibase/lorax.git lorax-punica; \
fi
cd lorax-punica && git sparse-checkout set server/punica_kernels && git checkout $(lorax_punica_commit)
cd lorax-punica && git submodule update --init --recursive
cd lorax-punica/server/punica_kernels && python setup.py build

install-lorax-punica: build-lorax-punica
cd lorax-punica/server/punica_kernels && python setup.py install
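For local (non-Docker) builds, these targets are wired into the server Makefile via the new `include` line above, so they can be invoked from the repository root (a sketch; assumes the server's Python build dependencies are already installed):

```bash
make -C server build-lorax-punica    # clone, sparse-checkout, and build the punica kernels
make -C server install-lorax-punica  # build and install them into the current Python environment
```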
7 changes: 6 additions & 1 deletion server/tests/models/test_model.py
@@ -17,7 +17,12 @@ def generate_token(self, batch):
tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")

model = TestModel(
torch.nn.Linear(1, 1), tokenizer, False, torch.float32, torch.device("cpu")
"test_model_id",
torch.nn.Linear(1, 1),
tokenizer,
False,
torch.float32,
torch.device("cpu"),
)
return model

13 changes: 13 additions & 0 deletions server/text_generation_server/adapters/__init__.py
@@ -0,0 +1,13 @@
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/__init__.py
# License: Apache License Version 2.0, January 2004

from text_generation_server.adapters.weights import (
AdapterBatchData,
AdapterBatchMetadata,
)

__all__ = [
"AdapterBatchData",
"AdapterBatchMetadata",
]
44 changes: 44 additions & 0 deletions server/text_generation_server/adapters/config.py
@@ -0,0 +1,44 @@
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/config.py
# License: Apache License Version 2.0, January 2004

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple

import torch

from text_generation_server.adapters.weights import AdapterWeights

if TYPE_CHECKING:
from text_generation_server.models.model import Model


@dataclass
class ModuleMap:
module_name: str
module_weights: Dict[str, Tuple[torch.Tensor, str]]


@dataclass
class AdapterConfig(ABC):
base_model_name_or_path: str

@abstractmethod
def map_weights_for_model(
self,
adapter_weights: Dict[int, AdapterWeights],
weight_names: Tuple[str],
) -> Tuple[ModuleMap, Set[str]]:
pass

@abstractmethod
def load_batched_adapter_weights(
self,
model: "Model",
module_map: ModuleMap,
layer_type: str,
unused_weight_names: Set[str],
dynamic: bool,
) -> Optional[AdapterWeights]:
pass
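To illustrate how this abstract base class is meant to be used, here is a hypothetical sketch of a concrete subclass (names and fields are invented for illustration; they are not the classes added by this PR):

```python
# Hypothetical sketch of a concrete AdapterConfig subclass (illustration only).
from dataclasses import dataclass
from typing import Dict, Optional, Set, Tuple

from text_generation_server.adapters.config import AdapterConfig, ModuleMap
from text_generation_server.adapters.weights import AdapterWeights


@dataclass
class ToyLoraConfig(AdapterConfig):
    r: int
    lora_alpha: int
    target_modules: Tuple[str, ...]

    def map_weights_for_model(
        self,
        adapter_weights: Dict[int, AdapterWeights],
        weight_names: Tuple[str],
    ) -> Tuple[ModuleMap, Set[str]]:
        # Group the raw checkpoint tensors by target module (e.g. q_proj, v_proj)
        # and report any tensors that do not belong to a supported layer type.
        ...

    def load_batched_adapter_weights(
        self,
        model: "Model",
        module_map: ModuleMap,
        layer_type: str,
        unused_weight_names: Set[str],
        dynamic: bool,
    ) -> Optional[AdapterWeights]:
        # Stack this adapter's A/B matrices for `layer_type` into batched tensors
        # shaped for the punica kernels, sharded for the current tensor-parallel rank.
        ...
```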