Skip to content

Commit

Permalink
feat: supports openai chat completions API (huggingface#1427)
Browse files Browse the repository at this point in the history
This PR adds support to make TGI a drop in replacement for OpenAI
clients by exposing the same HTTP interface.

Notes
- TGI inits a single model at startup so the `model` field is unused in
HTTP requests.
- `max_tokens` and `stream` should work as expected but other params may
be (unimplemented or not supported)

General approach
- fetch the `tokenizer_config` at startup from the hub
- pass `tokenizer_config` into `Infer` so we have it at request time
- use the `chat_template` on the config to format chat request
- parse jinja template and render chat string
- pass inputs into existing generate function
- wrap generation output in expected structure before returning

```bash
curl localhost:3000/v1/chat/completions \
    -X POST \
    -d '{
  "model": "tgi",
  "messages": [
    {
      "role": "system",
      "content": "You are a helpful assistant."
    },
    {
      "role": "user",
      "content": "What is deep learning?"
    }
  ],
  "stream": true,
  "max_tokens": 20
}' \
    -H 'Content-Type: application/json'
```

It is also possible to use the `openai` python library and change the
base url

```python
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:3000/v1",
    api_key="not needed for a local LLM"
)

chat_completion = client.chat.completions.create(
    model="tgi",
    messages=[
        {"role": "system", "content": "You are a helpful assistant." },
        {"role": "user", "content": "What is deep learning?"}
    ],
    stream=True
)

for message in chat_completion:
    print(message)

```

```python
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:3000/v1",
    api_key="not needed for a local LLM"
)

chat_completion = client.chat.completions.create(
    model="tgi",
    messages=[
        {"role": "system", "content": "You are a helpful assistant." },
        {"role": "user", "content": "What is deep learning?"}
    ],
    stream=False
)

print(chat_completion)
```

```bash
cd text-generation-inference/server
MASTER_ADDR=127.0.0.1 MASTER_PORT=5555 text-generation-server serve --trust-remote-code gpt2
```

***note many of the existing `chat_templates` use non standard `jinja`
(ie. adding a `raise` to the template) which will throw an error when
parsing; hence using `upstage/SOLAR-10.7B-Instruct-v1.0` since it has a
valid template
```bash
cd text-generation-inference/router
cargo run -- --tokenizer-name upstage/SOLAR-10.7B-Instruct-v1.0
```

trigger
```bash
curl localhost:3000/v1/chat/completions \
    -X POST \
    -d '{ "model": "gpt-3.5-turbo", "messages": [ { "role": "system", "content": "You are a helpful assistant." }, { "role": "user", "content": "What is the IP address of the Google DNS servers?" } ], "stream": true, "max_tokens": 20, "logprobs": true }' \
    -H 'Content-Type: application/json'
```

^ supports `stream: true` and `stream: false` requests
  • Loading branch information
drbh authored and kdamaszk committed Apr 22, 2024
1 parent 12cfc79 commit 76b226b
Show file tree
Hide file tree
Showing 7 changed files with 550 additions and 56 deletions.
11 changes: 11 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions router/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ utoipa = { version = "3.5.0", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
ngrok = { version = "0.13.1", features = ["axum"], optional = true }
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
minijinja = "1.0.10"
futures-util = "0.3.30"

[build-dependencies]
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
Expand Down
51 changes: 38 additions & 13 deletions router/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
/// Batching and inference logic
use crate::validation::{Validation, ValidationError};
use crate::HubTokenizerConfig;
use crate::{ChatRequest, GenerateRequest, GenerateStreamResponse, PrefillToken};
use crate::{Entry, Queue, Token};
use crate::{GenerateRequest, PrefillToken};
use futures::future::try_join_all;
use minijinja::{Environment, ErrorKind, Template};
use nohash_hasher::IntMap;
use std::sync::{
atomic::{AtomicBool, Ordering},
Expand All @@ -15,7 +17,7 @@ use text_generation_client::{
};
use thiserror::Error;
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;
Expand All @@ -32,6 +34,8 @@ pub struct Infer {
shared: Arc<Shared>,
/// Inference limit
limit_concurrent_requests: Arc<Semaphore>,
/// Chat template
template: Option<Template<'static, 'static>>,
}

/// Infer shared state
Expand All @@ -56,6 +60,7 @@ impl Infer {
window_size: Option<u32>,
speculate: u32,
generation_health: Arc<AtomicBool>,
tokenizer_config: HubTokenizerConfig,
) -> Self {
// Infer shared state
let queue = Queue::new(
Expand Down Expand Up @@ -85,11 +90,21 @@ impl Infer {
// Inference limit with a semaphore
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));

let template = tokenizer_config.chat_template.map(|t| {
let env = Box::new(Environment::new());
let template_str = t.into_boxed_str();
// leaking env and template_str as read-only, static resources for performance.
Box::leak(env)
.template_from_str(Box::leak(template_str))
.unwrap()
});

Self {
validation,
queue,
shared,
limit_concurrent_requests: semaphore,
template,
}
}

Expand All @@ -98,14 +113,7 @@ impl Infer {
pub(crate) async fn generate_stream(
&self,
request: GenerateRequest,
) -> Result<
(
OwnedSemaphorePermit,
u32,
UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
),
InferError,
> {
) -> Result<GenerateStreamResponse, InferError> {
// Limit concurrent requests by acquiring a permit from the semaphore
let permit = self
.clone()
Expand Down Expand Up @@ -150,6 +158,20 @@ impl Infer {
))
}

/// Apply the chat template to the chat request
#[instrument(skip_all)]
pub(crate) fn apply_chat_template(&self, chat: ChatRequest) -> Result<String, InferError> {
self.template
.as_ref()
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
.render(chat)
.map_err(|e| {
metrics::increment_counter!("tgi_request_failure", "err" => "template");
tracing::error!("{e}");
InferError::TemplateError(e)
})
}

/// Add a new request to the queue and return a InferResponse
#[instrument(skip_all)]
pub(crate) async fn generate(
Expand Down Expand Up @@ -561,9 +583,9 @@ fn send_responses(
let mut iterator = tokens_
.ids
.into_iter()
.zip(tokens_.logprobs.into_iter())
.zip(tokens_.texts.into_iter())
.zip(tokens_.is_special.into_iter())
.zip(tokens_.logprobs)
.zip(tokens_.texts)
.zip(tokens_.is_special)
.enumerate()
.peekable();
while let Some((i, (((id, logprob), text), special))) = iterator.next() {
Expand Down Expand Up @@ -676,6 +698,8 @@ pub enum InferError {
ValidationError(#[from] ValidationError),
#[error("Incomplete generation")]
IncompleteGeneration,
#[error("Template error: {0}")]
TemplateError(#[from] minijinja::Error),
}

impl InferError {
Expand All @@ -685,6 +709,7 @@ impl InferError {
InferError::Overloaded(_) => "overloaded",
InferError::ValidationError(_) => "validation",
InferError::IncompleteGeneration => "incomplete_generation",
InferError::TemplateError(_) => "template_error",
}
}
}
Loading

0 comments on commit 76b226b

Please sign in to comment.