Skip to content

Commit

Permalink
feat: use model name as adapter id in chat endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
drbh committed Jun 26, 2024
1 parent be2d380 commit 29a1137
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
4 changes: 2 additions & 2 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ pub struct CompletionRequest {
/// UNUSED
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
/// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
pub model: String,
pub model: Option<String>,

/// The prompt to generate completions for.
#[schema(example = "What is Deep Learning?")]
Expand Down Expand Up @@ -706,7 +706,7 @@ impl ChatCompletionChunk {
pub(crate) struct ChatRequest {
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
/// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
pub model: String,
pub model: Option<String>,

/// A list of messages comprising the conversation so far.
#[schema(example = "[{\"role\": \"user\", \"content\": \"What is Deep Learning?\"}]")]
Expand Down
6 changes: 4 additions & 2 deletions router/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,7 @@ async fn completions(
metrics::increment_counter!("tgi_request_count");

let CompletionRequest {
model,
max_tokens,
seed,
stop,
Expand Down Expand Up @@ -673,7 +674,7 @@ async fn completions(
seed,
top_n_tokens: None,
grammar: None,
..Default::default()
adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from),
},
})
.collect();
Expand Down Expand Up @@ -1011,6 +1012,7 @@ async fn chat_completions(
let span = tracing::Span::current();
metrics::increment_counter!("tgi_request_count");
let ChatRequest {
model,
logprobs,
max_tokens,
messages,
Expand Down Expand Up @@ -1116,7 +1118,7 @@ async fn chat_completions(
seed,
top_n_tokens: req.top_logprobs,
grammar,
..Default::default()
adapter_id: model.filter(|m| *m != "tgi").map(String::from),
},
};

Expand Down

0 comments on commit 29a1137

Please sign in to comment.