diff --git a/router/src/lib.rs b/router/src/lib.rs index 126726c6a58..e198b4139e9 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -420,10 +420,11 @@ pub struct CompletionRequest { pub stop: Option>, } -#[derive(Clone, Deserialize, Serialize, ToSchema, Default)] +#[derive(Clone, Deserialize, Serialize, ToSchema)] pub(crate) struct Completion { pub id: String, - pub object: String, + #[schema(default = "ObjectType::ChatCompletion")] + pub object: ObjectType, #[schema(example = "1706270835")] pub created: u64, #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] @@ -444,7 +445,7 @@ pub(crate) struct CompletionComplete { #[derive(Clone, Deserialize, Serialize, ToSchema)] pub(crate) struct ChatCompletion { pub id: String, - pub object: String, + pub object: ObjectType, #[schema(example = "1706270835")] pub created: u64, #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] @@ -540,6 +541,14 @@ pub(crate) struct Usage { pub total_tokens: u32, } +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum ObjectType { + #[serde(rename = "chat.completion")] + ChatCompletion, + #[serde(rename = "chat.completion.chunk")] + ChatCompletionChunk, +} + impl ChatCompletion { pub(crate) fn new( model: String, @@ -576,7 +585,7 @@ impl ChatCompletion { }; Self { id: String::new(), - object: "chat.completion".into(), + object: ObjectType::ChatCompletion, created, model, system_fingerprint, @@ -598,7 +607,7 @@ impl ChatCompletion { #[derive(Clone, Deserialize, Serialize, ToSchema)] pub(crate) struct CompletionCompleteChunk { pub id: String, - pub object: String, + pub object: ObjectType, pub created: u64, pub choices: Vec, pub model: String, @@ -608,7 +617,7 @@ pub(crate) struct CompletionCompleteChunk { #[derive(Clone, Serialize, ToSchema)] pub(crate) struct ChatCompletionChunk { pub id: String, - pub object: String, + pub object: ObjectType, #[schema(example = "1706270978")] pub created: u64, #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")] @@ -688,7 +697,7 @@ impl ChatCompletionChunk { }; Self { id: String::new(), - object: "chat.completion.chunk".to_string(), + object: ObjectType::ChatCompletionChunk, created, model, system_fingerprint, diff --git a/router/src/server.rs b/router/src/server.rs index 7f15bfdd6a6..d20785b6352 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -13,8 +13,8 @@ use crate::validation::ValidationError; use crate::{ BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info, - Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, - Usage, Validation, + Message, ObjectType, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, + TokenizeResponse, Usage, Validation, }; use crate::{ ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete, @@ -705,7 +705,7 @@ async fn completions( event .json_data(CompletionCompleteChunk { id: "".to_string(), - object: "text_completion".to_string(), + object: ObjectType::ChatCompletionChunk, created: current_time, choices: vec![CompletionComplete { @@ -932,7 +932,7 @@ async fn completions( let response = Completion { id: "".to_string(), - object: "text_completion".to_string(), + object: ObjectType::ChatCompletion, created: current_time, model: info.model_id.clone(), system_fingerprint: format!(