From fc2b181b92f928b0fb1aeb704e584eb9656925bf Mon Sep 17 00:00:00 2001
From: Aleksandar-Argead <88592406+Aleksandar-Argead@users.noreply.github.com>
Date: Fri, 22 Dec 2023 00:13:17 +0100
Subject: [PATCH] Add response_format option

Add response_format option to client config
---
Review notes (this section is ignored when the patch is applied):
- `ResponseFormat::format_type` now carries `#[serde(rename = "type")]`, so the
  request body contains `"response_format": {"type": "..."}` instead of
  `{"format_type": "..."}`. Hunk and diffstat counts were updated for the
  extra line; the blob ids on the `index` lines are those of the original patch.
- Leading indentation and the `Option<u32>` context line in src/config.rs were
  reconstructed from a whitespace-collapsed copy. If context does not match,
  confirm against the target tree or apply with `git apply --ignore-whitespace`.

 src/client.rs |  6 ++++++
 src/config.rs | 38 +++++++++++++++++++++++++++++++++++++-
 src/types.rs  |  3 +++
 3 files changed, 46 insertions(+), 1 deletion(-)

diff --git a/src/client.rs b/src/client.rs
index 7395d33..5a82a9f 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -157,6 +157,7 @@ impl ChatGPT {
                 temperature: self.config.temperature,
                 top_p: self.config.top_p,
                 max_tokens: self.config.max_tokens,
+                response_format: self.config.response_format,
                 frequency_penalty: self.config.frequency_penalty,
                 presence_penalty: self.config.presence_penalty,
                 reply_count: self.config.reply_count,
@@ -198,6 +199,7 @@ impl ChatGPT {
                 temperature: self.config.temperature,
                 top_p: self.config.top_p,
                 max_tokens: self.config.max_tokens,
+                response_format: self.config.response_format,
                 frequency_penalty: self.config.frequency_penalty,
                 presence_penalty: self.config.presence_penalty,
                 reply_count: self.config.reply_count,
@@ -230,6 +232,7 @@ impl ChatGPT {
                 temperature: self.config.temperature,
                 top_p: self.config.top_p,
                 max_tokens: self.config.max_tokens,
+                response_format: self.config.response_format,
                 frequency_penalty: self.config.frequency_penalty,
                 presence_penalty: self.config.presence_penalty,
                 reply_count: self.config.reply_count,
@@ -273,6 +276,7 @@ impl ChatGPT {
                 temperature: self.config.temperature,
                 top_p: self.config.top_p,
                 max_tokens: self.config.max_tokens,
+                response_format: self.config.response_format,
                 frequency_penalty: self.config.frequency_penalty,
                 presence_penalty: self.config.presence_penalty,
                 reply_count: self.config.reply_count,
@@ -370,6 +374,7 @@ impl ChatGPT {
                 stream: false,
                 temperature: self.config.temperature,
                 top_p: self.config.top_p,
+                response_format: self.config.response_format,
                 frequency_penalty: self.config.frequency_penalty,
                 presence_penalty: self.config.presence_penalty,
                 reply_count: self.config.reply_count,
@@ -407,6 +412,7 @@ impl ChatGPT {
                 stream: false,
                 temperature: self.config.temperature,
                 top_p: self.config.top_p,
+                response_format: self.config.response_format,
                 frequency_penalty: self.config.frequency_penalty,
                 presence_penalty: self.config.presence_penalty,
                 reply_count: self.config.reply_count,
diff --git a/src/config.rs b/src/config.rs
index 0cc3849..8d7b6bb 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -1,5 +1,5 @@
-use std::{fmt::Display, str::FromStr};
 use std::time::Duration;
+use std::{fmt::Display, str::FromStr};
 
 #[cfg(feature = "functions")]
 use crate::functions::FunctionValidationStrategy;
@@ -18,6 +18,8 @@ pub struct ModelConfiguration {
     pub top_p: f32,
     /// Controls the maximum number of tokens to generate in the completion
     pub max_tokens: Option<u32>,
+    /// Format of the response from the model, e.g., text or json_object.
+    pub response_format: ResponseFormat,
     /// Determines how much to penalize new tokens passed on their existing presence so far
     pub presence_penalty: f32,
     /// Determines how much to penalize new tokens based on their existing frequency so far
@@ -40,6 +42,7 @@ impl Default for ModelConfiguration {
             temperature: 0.5,
             top_p: 1.0,
             max_tokens: None,
+            response_format: ResponseFormat::default(),
             presence_penalty: 0.0,
             frequency_penalty: 0.0,
             reply_count: 1,
@@ -51,6 +54,39 @@ impl Default for ModelConfiguration {
         }
     }
 }
+/// Specifies the format of the response.
+#[derive(Serialize, Debug, Clone, Copy, PartialEq, PartialOrd, Builder)]
+#[builder(default, setter(into))]
+pub struct ResponseFormat {
+    /// The type of format for the response; sent to the API under the key `type`.
+    #[serde(rename = "type")]
+    pub format_type: FormatType,
+}
+
+impl Default for ResponseFormat {
+    fn default() -> Self {
+        Self {
+            format_type: FormatType::default(),
+        }
+    }
+}
+
+/// Specifies the type of the format.
+#[derive(Serialize, Debug, Clone, Copy, PartialEq, PartialOrd)]
+#[serde(rename_all = "snake_case")]
+pub enum FormatType {
+    /// Standard text format.
+    Text,
+    /// JSON object format.
+    JsonObject,
+}
+
+impl Default for FormatType {
+    fn default() -> Self {
+        FormatType::Text
+    }
+}
+
 /// The engine version for ChatGPT
 #[derive(Serialize, Debug, Default, Copy, Clone, PartialEq, PartialOrd)]
 #[allow(non_camel_case_types)]
diff --git a/src/types.rs b/src/types.rs
index 35c650a..2e043a9 100644
--- a/src/types.rs
+++ b/src/types.rs
@@ -1,3 +1,4 @@
+use crate::config::ResponseFormat;
 #[cfg(feature = "functions")]
 use crate::functions::FunctionCall;
 use serde::{Deserialize, Deserializer, Serialize};
@@ -85,6 +86,8 @@ pub struct CompletionRequest<'a> {
     pub stream: bool,
     /// The extra randomness of response
     pub temperature: f32,
+    /// Format of the response from the model, e.g., text or json_object.
+    pub response_format: ResponseFormat,
     /// Controls diversity via nucleus sampling, not recommended to use with temperature
     pub top_p: f32,
     /// Controls the maximum number of tokens to generate in the completion