Skip to content

Commit

Permalink
Merge pull request #32 from Aleph-Alpha/stream-completion
Browse files Browse the repository at this point in the history
feat: add stream completion method
  • Loading branch information
moldhouse authored Oct 28, 2024
2 parents 390f24f + c513a36 commit 2ef4d7f
Show file tree
Hide file tree
Showing 7 changed files with 519 additions and 44 deletions.
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ categories = ["api-bindings"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
async-stream = "0.3.6"
base64 = "0.22.0"
futures-util = "0.3.31"
image = "0.25.1"
itertools = "0.13.0"
reqwest = { version = "0.12.3", features = ["json"] }
reqwest = { version = "0.12.3", features = ["json", "stream"] }
serde = { version = "1.0.197", features = ["derive"] }
serde_json = "1.0.115"
thiserror = "1.0.58"
Expand Down
63 changes: 62 additions & 1 deletion src/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::borrow::Cow;

use serde::{Deserialize, Serialize};

use crate::Task;
use crate::{StreamTask, Task};

#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct Message<'a> {
Expand Down Expand Up @@ -126,6 +126,9 @@ struct ChatBody<'a> {
/// When no value is provided, the default value of 1 will be used.
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
/// Whether to stream the response or not.
#[serde(skip_serializing_if = "std::ops::Not::not")]
pub stream: bool,
}

impl<'a> ChatBody<'a> {
Expand All @@ -136,8 +139,14 @@ impl<'a> ChatBody<'a> {
maximum_tokens: task.maximum_tokens,
temperature: task.temperature,
top_p: task.top_p,
stream: false,
}
}

pub fn with_streaming(mut self) -> Self {
self.stream = true;
self
}
}

impl<'a> Task for TaskChat<'a> {
Expand All @@ -159,3 +168,55 @@ impl<'a> Task for TaskChat<'a> {
response.choices.pop().unwrap()
}
}

/// Partial message delivered by one chunk of a chat completion stream.
#[derive(Deserialize, Debug)]
pub struct StreamMessage {
    /// The role of the current chat completion. Will be assistant for the first chunk of every
    /// completion stream and missing for the remaining chunks.
    pub role: Option<String>,
    /// The content of the current chat completion. Will be empty for the first chunk of every
    /// completion stream and non-empty for the remaining chunks.
    pub content: String,
}

/// One chunk of a chat completion stream.
#[derive(Deserialize, Debug)]
pub struct ChatStreamChunk {
    /// The reason the model stopped generating tokens.
    /// The value is only set in the last chunk of a completion and null otherwise.
    pub finish_reason: Option<String>,
    /// Chat completion chunk generated by the model when streaming is enabled.
    pub delta: StreamMessage,
}

/// Event received from a chat completion stream. As the crate does not support multiple
/// chat completions, there will always be exactly one choice item.
#[derive(Deserialize, Debug)]
pub struct ChatEvent {
    /// The single chat completion chunk carried by this event.
    pub choices: Vec<ChatStreamChunk>,
}

impl<'a> StreamTask for TaskChat<'a> {
    type Output = ChatStreamChunk;

    type ResponseBody = ChatEvent;

    /// Builds the streaming request for the `/chat/completions` endpoint, serializing
    /// the task as the JSON body with `stream: true`.
    fn build_request(
        &self,
        client: &reqwest::Client,
        base: &str,
        model: &str,
    ) -> reqwest::RequestBuilder {
        // `self` is already `&TaskChat`; taking `&self` would pass a needless `&&TaskChat`
        // (clippy::needless_borrow).
        let body = ChatBody::new(model, self).with_streaming();
        client.post(format!("{base}/chat/completions")).json(&body)
    }

    /// Extracts the single chunk from the response event.
    fn body_to_output(mut response: Self::ResponseBody) -> Self::Output {
        // We always expect there to be exactly one choice, as the `n` parameter is not
        // supported by this crate.
        response
            .choices
            .pop()
            .expect("There must always be at least one choice")
    }
}
75 changes: 74 additions & 1 deletion src/completion.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use serde::{Deserialize, Serialize};

use crate::{http::Task, Prompt};
use crate::{http::Task, Prompt, StreamTask};

/// Completes a prompt. E.g. continues a text.
pub struct TaskCompletion<'a> {
Expand Down Expand Up @@ -142,6 +142,9 @@ struct BodyCompletion<'a> {
pub top_p: Option<f64>,
#[serde(skip_serializing_if = "<[_]>::is_empty")]
pub completion_bias_inclusion: &'a [&'a str],
/// If true, the response will be streamed.
#[serde(skip_serializing_if = "std::ops::Not::not")]
pub stream: bool,
}

impl<'a> BodyCompletion<'a> {
Expand All @@ -155,8 +158,13 @@ impl<'a> BodyCompletion<'a> {
top_k: task.sampling.top_k,
top_p: task.sampling.top_p,
completion_bias_inclusion: task.sampling.start_with_one_of,
stream: false,
}
}
pub fn with_streaming(mut self) -> Self {
self.stream = true;
self
}
}

#[derive(Deserialize, Debug, PartialEq, Eq)]
Expand Down Expand Up @@ -205,3 +213,68 @@ impl Task for TaskCompletion<'_> {
response.completions.pop().unwrap()
}
}

/// Describes a chunk of a completion stream.
#[derive(Deserialize, Debug)]
pub struct StreamChunk {
    /// The index of the stream that this chunk belongs to.
    /// This is relevant if multiple completion streams are requested (see parameter n).
    pub index: u32,
    /// The partial completion text carried by this chunk of the stream.
    pub completion: String,
}

/// Denotes the end of a completion stream.
///
/// The index of the stream that is being terminated is not deserialized.
/// It is only relevant if multiple completion streams are requested (see parameter n),
/// which is not supported by this crate yet.
#[derive(Deserialize, Debug)]
pub struct StreamSummary {
    /// Model name and version (if any) of the used model for inference.
    pub model_version: String,
    /// The reason why the model stopped generating new tokens.
    pub finish_reason: String,
}

/// Denotes the end of all completion streams.
#[derive(Deserialize, Debug)]
pub struct CompletionSummary {
    /// Number of prompt tokens combined across all completion tasks.
    /// In particular, if you set best_of or n to a number larger than 1 then we report the
    /// combined prompt token count for all best_of or n tasks.
    pub num_tokens_prompt_total: u32,
    /// Number of generated tokens combined across all completion tasks.
    /// If multiple completions are returned or best_of is set to a value greater than 1 then
    /// this value contains the combined generated token count.
    pub num_tokens_generated: u32,
}

/// Event received from a completion stream. The `type` field of the wire format
/// selects the variant (snake_case), e.g. `"stream_chunk"`.
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CompletionEvent {
    /// A chunk of generated completion text.
    StreamChunk(StreamChunk),
    /// End of one completion stream.
    StreamSummary(StreamSummary),
    /// End of all completion streams, with combined token counts.
    CompletionSummary(CompletionSummary),
}

impl StreamTask for TaskCompletion<'_> {
    type Output = CompletionEvent;

    type ResponseBody = CompletionEvent;

    /// Builds the streaming request for the `/complete` endpoint, serializing the task
    /// as the JSON body with `stream: true`.
    fn build_request(
        &self,
        client: &reqwest::Client,
        base: &str,
        model: &str,
    ) -> reqwest::RequestBuilder {
        // `self` is already `&TaskCompletion`; taking `&self` would pass a needless
        // `&&TaskCompletion` (clippy::needless_borrow).
        let body = BodyCompletion::new(model, self).with_streaming();
        client.post(format!("{base}/complete")).json(&body)
    }

    /// Completion events are already in their final shape; pass them through unchanged.
    fn body_to_output(response: Self::ResponseBody) -> Self::Output {
        response
    }
}
Loading

0 comments on commit 2ef4d7f

Please sign in to comment.