Skip to content

Commit

Permalink
fix: prefer enum for chat object
Browse files Browse the repository at this point in the history
  • Loading branch information
drbh committed Jun 26, 2024
1 parent be2d380 commit 226b766
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 11 deletions.
23 changes: 16 additions & 7 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -420,10 +420,11 @@ pub struct CompletionRequest {
pub stop: Option<Vec<String>>,
}

#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct Completion {
pub id: String,
pub object: String,
#[schema(default = "ObjectType::ChatCompletion")]
pub object: ObjectType,
#[schema(example = "1706270835")]
pub created: u64,
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
Expand All @@ -444,7 +445,7 @@ pub(crate) struct CompletionComplete {
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletion {
pub id: String,
pub object: String,
pub object: ObjectType,
#[schema(example = "1706270835")]
pub created: u64,
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
Expand Down Expand Up @@ -540,6 +541,14 @@ pub(crate) struct Usage {
pub total_tokens: u32,
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ObjectType {
#[serde(rename = "chat.completion")]
ChatCompletion,
#[serde(rename = "chat.completion.chunk")]
ChatCompletionChunk,
}

impl ChatCompletion {
pub(crate) fn new(
model: String,
Expand Down Expand Up @@ -576,7 +585,7 @@ impl ChatCompletion {
};
Self {
id: String::new(),
object: "chat.completion".into(),
object: ObjectType::ChatCompletion,
created,
model,
system_fingerprint,
Expand All @@ -598,7 +607,7 @@ impl ChatCompletion {
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct CompletionCompleteChunk {
pub id: String,
pub object: String,
pub object: ObjectType,
pub created: u64,
pub choices: Vec<CompletionComplete>,
pub model: String,
Expand All @@ -608,7 +617,7 @@ pub(crate) struct CompletionCompleteChunk {
#[derive(Clone, Serialize, ToSchema)]
pub(crate) struct ChatCompletionChunk {
pub id: String,
pub object: String,
pub object: ObjectType,
#[schema(example = "1706270978")]
pub created: u64,
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
Expand Down Expand Up @@ -688,7 +697,7 @@ impl ChatCompletionChunk {
};
Self {
id: String::new(),
object: "chat.completion.chunk".to_string(),
object: ObjectType::ChatCompletionChunk,
created,
model,
system_fingerprint,
Expand Down
8 changes: 4 additions & 4 deletions router/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ use crate::validation::ValidationError;
use crate::{
BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info,
Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse,
Usage, Validation,
Message, ObjectType, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token,
TokenizeResponse, Usage, Validation,
};
use crate::{
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
Expand Down Expand Up @@ -705,7 +705,7 @@ async fn completions(
event
.json_data(CompletionCompleteChunk {
id: "".to_string(),
object: "text_completion".to_string(),
object: ObjectType::ChatCompletionChunk,
created: current_time,

choices: vec![CompletionComplete {
Expand Down Expand Up @@ -932,7 +932,7 @@ async fn completions(

let response = Completion {
id: "".to_string(),
object: "text_completion".to_string(),
object: ObjectType::ChatCompletion,
created: current_time,
model: info.model_id.clone(),
system_fingerprint: format!(
Expand Down

0 comments on commit 226b766

Please sign in to comment.