Skip to content
This repository has been archived by the owner on Jul 9, 2024. It is now read-only.

Add response_format option #76

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ impl ChatGPT {
temperature: self.config.temperature,
top_p: self.config.top_p,
max_tokens: self.config.max_tokens,
response_format: self.config.response_format,
frequency_penalty: self.config.frequency_penalty,
presence_penalty: self.config.presence_penalty,
reply_count: self.config.reply_count,
Expand Down Expand Up @@ -198,6 +199,7 @@ impl ChatGPT {
temperature: self.config.temperature,
top_p: self.config.top_p,
max_tokens: self.config.max_tokens,
response_format: self.config.response_format,
frequency_penalty: self.config.frequency_penalty,
presence_penalty: self.config.presence_penalty,
reply_count: self.config.reply_count,
Expand Down Expand Up @@ -230,6 +232,7 @@ impl ChatGPT {
temperature: self.config.temperature,
top_p: self.config.top_p,
max_tokens: self.config.max_tokens,
response_format: self.config.response_format,
frequency_penalty: self.config.frequency_penalty,
presence_penalty: self.config.presence_penalty,
reply_count: self.config.reply_count,
Expand Down Expand Up @@ -273,6 +276,7 @@ impl ChatGPT {
temperature: self.config.temperature,
top_p: self.config.top_p,
max_tokens: self.config.max_tokens,
response_format: self.config.response_format,
frequency_penalty: self.config.frequency_penalty,
presence_penalty: self.config.presence_penalty,
reply_count: self.config.reply_count,
Expand Down Expand Up @@ -370,6 +374,7 @@ impl ChatGPT {
stream: false,
temperature: self.config.temperature,
top_p: self.config.top_p,
response_format: self.config.response_format,
frequency_penalty: self.config.frequency_penalty,
presence_penalty: self.config.presence_penalty,
reply_count: self.config.reply_count,
Expand Down Expand Up @@ -407,6 +412,7 @@ impl ChatGPT {
stream: false,
temperature: self.config.temperature,
top_p: self.config.top_p,
response_format: self.config.response_format,
frequency_penalty: self.config.frequency_penalty,
presence_penalty: self.config.presence_penalty,
reply_count: self.config.reply_count,
Expand Down
37 changes: 36 additions & 1 deletion src/config.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::{fmt::Display, str::FromStr};
use std::time::Duration;
use std::{fmt::Display, str::FromStr};

#[cfg(feature = "functions")]
use crate::functions::FunctionValidationStrategy;
Expand All @@ -18,6 +18,8 @@ pub struct ModelConfiguration {
pub top_p: f32,
/// Controls the maximum number of tokens to generate in the completion
pub max_tokens: Option<u32>,
/// Format of the response from the model, e.g., text or json_object.
pub response_format: ResponseFormat,
/// Determines how much to penalize new tokens based on their existing presence so far
pub presence_penalty: f32,
/// Determines how much to penalize new tokens based on their existing frequency so far
Expand All @@ -40,6 +42,7 @@ impl Default for ModelConfiguration {
temperature: 0.5,
top_p: 1.0,
max_tokens: None,
response_format: ResponseFormat::default(),
presence_penalty: 0.0,
frequency_penalty: 0.0,
reply_count: 1,
Expand All @@ -51,6 +54,38 @@ impl Default for ModelConfiguration {
}
}

/// Specifies the format of the response.
#[derive(Serialize, Debug, Clone, Copy, PartialEq, PartialOrd, Builder)]
#[builder(default, setter(into))]
pub struct ResponseFormat {
/// The type of format for the response.
pub format_type: FormatType,
}

impl Default for ResponseFormat {
fn default() -> Self {
Self {
format_type: FormatType::default(),
}
}
}

/// Specifies the type of the format.
///
/// Variants are serialized in `snake_case` to match the wire values the
/// OpenAI API accepts: `"text"` and `"json_object"`.
// `Eq`, `Hash`, and `Ord` are derivable for this fieldless enum and let the
// type be used as a map key or in ordered collections; adding derives is
// backward compatible.
#[derive(Serialize, Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[serde(rename_all = "snake_case")]
pub enum FormatType {
    /// Standard text format.
    Text,
    /// JSON object format.
    JsonObject,
}

impl Default for FormatType {
fn default() -> Self {
FormatType::Text
}
}

/// The engine version for ChatGPT
#[derive(Serialize, Debug, Default, Copy, Clone, PartialEq, PartialOrd)]
#[allow(non_camel_case_types)]
Expand Down
3 changes: 3 additions & 0 deletions src/types.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::config::ResponseFormat;
#[cfg(feature = "functions")]
use crate::functions::FunctionCall;
use serde::{Deserialize, Deserializer, Serialize};
Expand Down Expand Up @@ -85,6 +86,8 @@ pub struct CompletionRequest<'a> {
pub stream: bool,
/// The extra randomness of response
pub temperature: f32,
/// Format of the response from the model, e.g., text or json_object.
pub response_format: ResponseFormat,
/// Controls diversity via nucleus sampling, not recommended to use with temperature
pub top_p: f32,
/// Controls the maximum number of tokens to generate in the completion
Expand Down