server : allow using LoRA adapters per-request #10994

Merged 12 commits on Jan 2, 2025
Changes from 9 commits
6 changes: 6 additions & 0 deletions examples/server/README.md
@@ -452,6 +452,8 @@ These words will not be included in the completion, so make sure to add them to

`response_fields`: A list of response fields, for example: `"response_fields": ["content", "generation_settings/n_predict"]`. If the specified field is missing, it will simply be omitted from the response without triggering an error. Note that fields with a slash will be unnested; for example, `generation_settings/n_predict` will move the field `n_predict` from the `generation_settings` object to the root of the response and give it a new name.

`lora`: A list of LoRA adapters to be applied to this specific request. Each object in the list must contain `id` and `scale` fields. For example: `[{"id": 0, "scale": 0.5}, {"id": 1, "scale": 1.1}]`. If a LoRA adapter is not specified in the list, its scale will default to `0.0`. Please note that requests with different LoRA configurations will not be batched together, which may result in performance degradation.
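
For example, the following completion request body applies only the first loaded adapter, at half strength (an illustrative sketch; `prompt` and `n_predict` are ordinary completion fields and the values here are arbitrary):

```json
{
  "prompt": "Write a haiku about autumn.",
  "n_predict": 64,
  "lora": [{ "id": 0, "scale": 0.5 }]
}
```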

**Response format**

- Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
@@ -945,6 +947,8 @@ This endpoint returns the loaded LoRA adapters. You can add adapters using `--lo

By default, all adapters will be loaded with their scale set to 1. To initialize all adapter scales to 0 instead, add `--lora-init-without-apply`.

Note that these scales can be overridden on a per-request basis via the `lora` field.

If an adapter is disabled, its scale will be set to 0.
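
Each entry of the response contains the adapter's `id`, its `path` and its current `scale`. A response might look like the following sketch (the adapter file names are hypothetical):

```json
[
  { "id": 0, "path": "adapter-a.gguf", "scale": 0.0 },
  { "id": 1, "path": "adapter-b.gguf", "scale": 1.0 }
]
```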

**Response format**
@@ -966,6 +970,8 @@ If an adapter is disabled, the scale will be set to 0.

### POST `/lora-adapters`: Set list of LoRA adapters

This sets the global scales for the LoRA adapters. Note that these global scales can be overridden on a per-request basis via the `lora` field.

To disable an adapter, either remove it from the list below or set its scale to 0.
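
For example, posting the body below leaves only adapter 0 active, at scale 0.2 (an illustrative sketch; the canonical schema is shown under **Request format** below):

```json
[
  { "id": 0, "scale": 0.2 }
]
```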

**Request format**
116 changes: 76 additions & 40 deletions examples/server/server.cpp
@@ -98,6 +98,8 @@ struct slot_params {
int64_t t_max_prompt_ms = -1; // TODO: implement
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit

std::vector<common_lora_adapter_container> lora;

std::vector<std::string> antiprompt;
std::vector<std::string> response_fields;
bool timings_per_token = false;
@@ -120,6 +122,11 @@ struct slot_params {
samplers.emplace_back(common_sampler_type_to_str(sampler));
}

json lora = json::array();
for (size_t i = 0; i < this->lora.size(); ++i) {
lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
}

return json {
{"n_predict", n_predict}, // Server configured n_predict
{"seed", sampling.seed},
@@ -160,6 +167,7 @@ struct slot_params {
{"speculative.p_min", speculative.p_min},
{"timings_per_token", timings_per_token},
{"post_sampling_probs", post_sampling_probs},
{"lora", lora},
};
}
};
@@ -189,12 +197,16 @@ struct server_task {
// used by SERVER_TASK_TYPE_METRICS
bool metrics_reset_bucket = false;

// used by SERVER_TASK_TYPE_SET_LORA
std::vector<common_lora_adapter_container> set_lora;

server_task(server_task_type type) : type(type) {}

static slot_params params_from_json_cmpl(
const llama_model * model,
const llama_context * ctx,
const common_params & params_base,
const std::vector<common_lora_adapter_container> & base_lora,
const json & data) {
slot_params params;

@@ -251,6 +263,16 @@ struct server_task {
params.speculative.n_min = std::max(params.speculative.n_min, 2);
params.speculative.n_max = std::max(params.speculative.n_max, 0);

if (data.contains("lora")) {
if (data.at("lora").is_array()) {
params.lora = parse_lora_request(base_lora, data.at("lora"));
} else {
throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
}
} else {
params.lora = base_lora;
}

// TODO: add more sanity checks for the input parameters

if (params.sampling.penalty_last_n < -1) {
@@ -1110,6 +1132,8 @@ struct server_slot {

common_speculative * spec = nullptr;

std::vector<common_lora_adapter_container> lora;

// the index relative to completion multi-task request
size_t index = 0;

@@ -1191,6 +1215,11 @@ struct server_slot {
return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
}

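// slots can only be decoded in the same batch if they agree on non-causal mode
// (embedding/rerank) and on the LoRA configuration, since both are set on the
// whole llama_context once per batch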
bool can_batch_with(server_slot & other_slot) {
return is_non_causal() == other_slot.is_non_causal()
&& are_lora_equal(lora, other_slot.lora);
}

bool has_budget(const common_params & global_params) {
if (params.n_predict == -1 && global_params.n_predict == -1) {
return true; // limitless
@@ -1600,7 +1629,7 @@ struct server_context {

llama_model * model = nullptr;
llama_context * ctx = nullptr;
std::vector<common_lora_adapter_container> loras;
std::vector<common_lora_adapter_container> lora;

llama_model * model_dft = nullptr;
llama_context_params cparams_dft;
@@ -1667,7 +1696,7 @@ struct server_context {

model = llama_init.model;
ctx = llama_init.context;
loras = llama_init.lora_adapters;
lora = llama_init.lora_adapters;

if (model == nullptr) {
SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
@@ -1866,6 +1895,12 @@ struct server_context {
slot.params = std::move(task.params);
slot.prompt_tokens = std::move(task.prompt_tokens);

if (!are_lora_equal(task.params.lora, slot.lora)) {
// if lora is changed, we cannot reuse cached tokens
slot.cache_tokens.clear();
slot.lora = std::move(task.params.lora);
}

SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());

if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
@@ -2557,7 +2592,7 @@ struct server_context {
} break;
case SERVER_TASK_TYPE_SET_LORA:
{
common_lora_adapters_apply(ctx, loras);
lora = std::move(task.set_lora);
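// note: the adapters are no longer applied here; the new configuration takes
// effect lazily, once per batch, right before decoding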
auto res = std::make_unique<server_task_result_apply_lora>();
res->id = task.id;
queue_results.send(std::move(res));
@@ -2634,12 +2669,22 @@ struct server_context {
// start populating the batch for this iteration
common_batch_clear(batch);

// track if given slot can be batched with slots already in the batch
server_slot * slot_batched = nullptr;

// first, add sampled tokens from any ongoing sequences
for (auto & slot : slots) {
if (slot.state != SLOT_STATE_GENERATING) {
continue;
}

// check if we can batch this slot with the previous one
if (!slot_batched) {
slot_batched = &slot;
} else if (slot_batched && !slot_batched->can_batch_with(slot)) {
continue;
}

slot.i_batch = batch.n_tokens;

common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
@@ -2658,15 +2703,18 @@ struct server_context {
int32_t n_batch = llama_n_batch(ctx);
int32_t n_ubatch = llama_n_ubatch(ctx);

// track if this is an embedding or non-embedding batch
// if we've added sampled tokens above, we are in non-embedding mode
// -1: none, 0: non-embedding, 1: embedding
// TODO: make enum
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;

// next, batch any pending prompts without exceeding n_batch
if (params_base.cont_batching || batch.n_tokens == 0) {
for (auto & slot : slots) {
// check if we can batch this slot with the previous one
if (slot.is_processing()) {
if (!slot_batched) {
slot_batched = &slot;
} else if (slot_batched && !slot_batched->can_batch_with(slot)) {
continue;
}
}

// this slot still has a prompt to be processed
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
auto & prompt_tokens = slot.prompt_tokens;
@@ -2827,14 +2875,6 @@ struct server_context {
}
}

// check that we are in the right batch_type, if not defer the slot
int slot_type = slot.is_non_causal();
if (batch_type == -1) {
batch_type = slot_type;
} else if (batch_type != slot_type) {
continue;
}

// keep only the common part
if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
// could not partially delete (likely using a non-Transformer model)
@@ -2902,8 +2942,12 @@ struct server_context {

SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);

// make sure we're in the right embedding mode
llama_set_embeddings(ctx, batch_type == 1);
if (slot_batched) {
// make sure we're in the right embedding mode
llama_set_embeddings(ctx, slot_batched->is_non_causal());
// apply lora, only need to do it once per batch
common_lora_adapters_apply(ctx, slot_batched->lora);
}

// process the created batch of tokens
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
@@ -3623,7 +3667,12 @@ int main(int argc, char ** argv) {
task.index = i;

task.prompt_tokens = std::move(tokenized_prompts[i]);
task.params = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.ctx, ctx_server.params_base, data);
task.params = server_task::params_from_json_cmpl(
ctx_server.model,
ctx_server.ctx,
ctx_server.params_base,
ctx_server.lora,
data);
task.id_selected_slot = json_value(data, "id_slot", -1);

// OAI-compat
@@ -4049,8 +4098,8 @@ int main(int argc, char ** argv) {

const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
json result = json::array();
for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
auto & lora = ctx_server.loras[i];
for (size_t i = 0; i < ctx_server.lora.size(); ++i) {
auto & lora = ctx_server.lora[i];
result.push_back({
{"id", i},
{"path", lora.path},
};

const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
const std::vector<json> body = json::parse(req.body);
int max_idx = ctx_server.loras.size();

// clear existing value
for (auto & lora : ctx_server.loras) {
lora.scale = 0.0f;
}

// set value
for (auto entry : body) {
int id = entry.at("id");
float scale = entry.at("scale");
if (0 <= id && id < max_idx) {
ctx_server.loras[id].scale = scale;
} else {
throw std::runtime_error("invalid adapter id");
}
const json body = json::parse(req.body);
if (!body.is_array()) {
res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
return;
}

server_task task(SERVER_TASK_TYPE_SET_LORA);
task.id = ctx_server.queue_tasks.get_new_id();
task.set_lora = parse_lora_request(ctx_server.lora, body);
ctx_server.queue_results.add_waiting_task_id(task.id);
ctx_server.queue_tasks.post(task);

6 changes: 6 additions & 0 deletions examples/server/tests/README.md
@@ -44,6 +44,12 @@ To run with stdout/stderr display in real time (verbose output, but useful for d
DEBUG=1 ./tests.sh -s -v -x
```

To run a single test unit:

```shell
./tests.sh unit/test_{name of test case here}.py -v -x
```

Hint: You can compile and run the tests in a single command, useful for local development:
```shell
1 change: 1 addition & 0 deletions examples/server/tests/requirements.txt
@@ -5,3 +5,4 @@ numpy~=1.26.4
openai~=1.55.3
prometheus-client~=0.20.0
requests~=2.32.3
wget~=3.2