diff --git a/Cargo.toml b/Cargo.toml
index 8bf99d5..be56da6 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -11,10 +11,12 @@ categories = ["api-bindings"]
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
+async-stream = "0.3.6"
 base64 = "0.22.0"
+futures-util = "0.3.31"
 image = "0.25.1"
 itertools = "0.13.0"
-reqwest = { version = "0.12.3", features = ["json"] }
+reqwest = { version = "0.12.3", features = ["json", "stream"] }
 serde = { version = "1.0.197", features = ["derive"] }
 serde_json = "1.0.115"
 thiserror = "1.0.58"
diff --git a/src/chat.rs b/src/chat.rs
index b4a29fd..35db048 100644
--- a/src/chat.rs
+++ b/src/chat.rs
@@ -2,7 +2,7 @@ use std::borrow::Cow;
 
 use serde::{Deserialize, Serialize};
 
-use crate::Task;
+use crate::{StreamTask, Task};
 
 #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
 pub struct Message<'a> {
@@ -126,6 +126,9 @@ struct ChatBody<'a> {
     /// When no value is provided, the default value of 1 will be used.
     #[serde(skip_serializing_if = "Option::is_none")]
     pub top_p: Option<f64>,
+    /// Whether to stream the response or not.
+    #[serde(skip_serializing_if = "std::ops::Not::not")]
+    pub stream: bool,
 }
 
 impl<'a> ChatBody<'a> {
@@ -136,8 +139,14 @@ impl<'a> ChatBody<'a> {
             maximum_tokens: task.maximum_tokens,
             temperature: task.temperature,
             top_p: task.top_p,
+            stream: false,
         }
     }
+
+    pub fn with_streaming(mut self) -> Self {
+        self.stream = true;
+        self
+    }
 }
 
 impl<'a> Task for TaskChat<'a> {
@@ -159,3 +168,55 @@ impl<'a> Task for TaskChat<'a> {
         response.choices.pop().unwrap()
     }
 }
+
+#[derive(Deserialize)]
+pub struct StreamMessage {
+    /// The role of the current chat completion. Will be assistant for the first chunk of every
+    /// completion stream and missing for the remaining chunks.
+    pub role: Option<String>,
+    /// The content of the current chat completion. Will be empty for the first chunk of every
+    /// completion stream and non-empty for the remaining chunks.
+    pub content: String,
+}
+
+/// One chunk of a chat completion stream.
+#[derive(Deserialize)]
+pub struct ChatStreamChunk {
+    /// The reason the model stopped generating tokens.
+    /// The value is only set in the last chunk of a completion and null otherwise.
+    pub finish_reason: Option<String>,
+    /// Chat completion chunk generated by the model when streaming is enabled.
+    pub delta: StreamMessage,
+}
+
+/// Event received from a chat completion stream. As the crate does not support multiple
+/// chat completions, there will always be exactly one choice item.
+#[derive(Deserialize)]
+pub struct ChatEvent {
+    pub choices: Vec<ChatStreamChunk>,
+}
+
+impl<'a> StreamTask for TaskChat<'a> {
+    type Output = ChatStreamChunk;
+
+    type ResponseBody = ChatEvent;
+
+    fn build_request(
+        &self,
+        client: &reqwest::Client,
+        base: &str,
+        model: &str,
+    ) -> reqwest::RequestBuilder {
+        let body = ChatBody::new(model, &self).with_streaming();
+        client.post(format!("{base}/chat/completions")).json(&body)
+    }
+
+    fn body_to_output(mut response: Self::ResponseBody) -> Self::Output {
+        // We always expect there to be exactly one choice, as the `n` parameter is not
+        // supported by this crate.
+        response
+            .choices
+            .pop()
+            .expect("There must always be at least one choice")
+    }
+}
diff --git a/src/completion.rs b/src/completion.rs
index 58e1a8f..70301ba 100644
--- a/src/completion.rs
+++ b/src/completion.rs
@@ -1,6 +1,6 @@
 use serde::{Deserialize, Serialize};
 
-use crate::{http::Task, Prompt};
+use crate::{http::Task, Prompt, StreamTask};
 
 /// Completes a prompt. E.g. continues a text.
 pub struct TaskCompletion<'a> {
@@ -142,6 +142,9 @@ struct BodyCompletion<'a> {
     pub top_p: Option<f64>,
     #[serde(skip_serializing_if = "<[_]>::is_empty")]
     pub completion_bias_inclusion: &'a [&'a str],
+    /// If true, the response will be streamed.
+    #[serde(skip_serializing_if = "std::ops::Not::not")]
+    pub stream: bool,
 }
 
 impl<'a> BodyCompletion<'a> {
@@ -155,8 +158,13 @@ impl<'a> BodyCompletion<'a> {
             top_k: task.sampling.top_k,
             top_p: task.sampling.top_p,
             completion_bias_inclusion: task.sampling.start_with_one_of,
+            stream: false,
         }
     }
+    pub fn with_streaming(mut self) -> Self {
+        self.stream = true;
+        self
+    }
 }
 
 #[derive(Deserialize, Debug, PartialEq, Eq)]
@@ -205,3 +213,68 @@ impl Task for TaskCompletion<'_> {
         response.completions.pop().unwrap()
     }
 }
+
+/// Describes a chunk of a completion stream.
+#[derive(Deserialize, Debug)]
+pub struct StreamChunk {
+    /// The index of the stream that this chunk belongs to.
+    /// This is relevant if multiple completion streams are requested (see parameter n).
+    pub index: u32,
+    /// The completion of the stream.
+    pub completion: String,
+}
+
+/// Denotes the end of a completion stream.
+///
+/// The index of the stream that is being terminated is not deserialized.
+/// It is only relevant if multiple completion streams are requested (see parameter n),
+/// which is not supported by this crate yet.
+#[derive(Deserialize)]
+pub struct StreamSummary {
+    /// Model name and version (if any) of the used model for inference.
+    pub model_version: String,
+    /// The reason why the model stopped generating new tokens.
+    pub finish_reason: String,
+}
+
+/// Denotes the end of all completion streams.
+#[derive(Deserialize)]
+pub struct CompletionSummary {
+    /// Number of tokens combined across all completion tasks.
+    /// In particular, if you set best_of or n to a number larger than 1 then we report the
+    /// combined prompt token count for all best_of or n tasks.
+    pub num_tokens_prompt_total: u32,
+    /// Number of tokens combined across all completion tasks.
+    /// If multiple completions are returned or best_of is set to a value greater than 1 then
+    /// this value contains the combined generated token count.
+    pub num_tokens_generated: u32,
+}
+
+#[derive(Deserialize)]
+#[serde(tag = "type")]
+#[serde(rename_all = "snake_case")]
+pub enum CompletionEvent {
+    StreamChunk(StreamChunk),
+    StreamSummary(StreamSummary),
+    CompletionSummary(CompletionSummary),
+}
+
+impl StreamTask for TaskCompletion<'_> {
+    type Output = CompletionEvent;
+
+    type ResponseBody = CompletionEvent;
+
+    fn build_request(
+        &self,
+        client: &reqwest::Client,
+        base: &str,
+        model: &str,
+    ) -> reqwest::RequestBuilder {
+        let body = BodyCompletion::new(model, &self).with_streaming();
+        client.post(format!("{base}/complete")).json(&body)
+    }
+
+    fn body_to_output(response: Self::ResponseBody) -> Self::Output {
+        response
+    }
+}
diff --git a/src/http.rs b/src/http.rs
index fd4b62a..bccbb76 100644
--- a/src/http.rs
+++ b/src/http.rs
@@ -1,11 +1,13 @@
-use std::{borrow::Cow, time::Duration};
+use std::{borrow::Cow, pin::Pin, time::Duration};
 
-use reqwest::{header, ClientBuilder, RequestBuilder, StatusCode};
+use futures_util::{stream::StreamExt, Stream};
+use reqwest::{header, ClientBuilder, RequestBuilder, Response, StatusCode};
 use serde::Deserialize;
 use thiserror::Error as ThisError;
 use tokenizers::Tokenizer;
 
-use crate::How;
+use crate::{How, StreamJob};
+use async_stream::stream;
 
 /// A job send to the Aleph Alpha Api using the http client. A job wraps all the knowledge required
 /// for the Aleph Alpha API to specify its result. Notably it includes the model(s) the job is
@@ -100,6 +102,36 @@ impl HttpClient {
         })
     }
 
+    /// Construct and execute a request building on top of a `RequestBuilder`.
+    async fn response(&self, builder: RequestBuilder, how: &How) -> Result<Response, Error> {
+        let query = if how.be_nice {
+            [("nice", "true")].as_slice()
+        } else {
+            // nice=false is default, so we just omit it.
+            [].as_slice()
+        };
+
+        let api_token = how
+            .api_token
+            .as_ref()
+            .or(self.api_token.as_ref())
+            .expect("API token needs to be set on client construction or per request");
+        let response = builder
+            .query(query)
+            .header(header::AUTHORIZATION, Self::header_from_token(api_token))
+            .timeout(how.client_timeout)
+            .send()
+            .await
+            .map_err(|reqwest_error| {
+                if reqwest_error.is_timeout() {
+                    Error::ClientTimeout(how.client_timeout)
+                } else {
+                    reqwest_error.into()
+                }
+            })?;
+        translate_http_error(response).await
+    }
+
     /// Execute a task with the aleph alpha API and fetch its result.
     ///
     /// ```no_run
@@ -126,38 +158,59 @@ impl HttpClient {
     /// }
     /// ```
     pub async fn output_of<T: Job>(&self, task: &T, how: &How) -> Result<T::Output, Error> {
-        let query = if how.be_nice {
-            [("nice", "true")].as_slice()
-        } else {
-            // nice=false is default, so we just omit it.
-            [].as_slice()
-        };
-
-        let api_token = how
-            .api_token
-            .as_ref()
-            .or(self.api_token.as_ref())
-            .expect("API token needs to be set on client construction or per request");
-        let response = task
-            .build_request(&self.http, &self.base)
-            .query(query)
-            .header(header::AUTHORIZATION, Self::header_from_token(api_token))
-            .timeout(how.client_timeout)
-            .send()
-            .await
-            .map_err(|reqwest_error| {
-                if reqwest_error.is_timeout() {
-                    Error::ClientTimeout(how.client_timeout)
-                } else {
-                    reqwest_error.into()
-                }
-            })?;
-        let response = translate_http_error(response).await?;
+        let builder = task.build_request(&self.http, &self.base);
+        let response = self.response(builder, how).await?;
         let response_body: T::ResponseBody = response.json().await?;
         let answer = task.body_to_output(response_body);
         Ok(answer)
     }
 
+    pub async fn stream_output_of<T: StreamJob>(
+        &self,
+        task: &T,
+        how: &How,
+    ) -> Result<Pin<Box<dyn Stream<Item = Result<T::Output, Error>> + Send>>, Error>
+    where
+        T::Output: 'static,
+    {
+        let builder = task.build_request(&self.http, &self.base);
+        let response = self.response(builder, how).await?;
+        let mut stream = response.bytes_stream();
+
+        Ok(Box::pin(stream! {
+            while let Some(item) = stream.next().await {
+                match item {
+                    Ok(bytes) => {
+                        let events = Self::parse_stream_event::<T::ResponseBody>(bytes.as_ref());
+                        for event in events {
+                            yield event.map(|b| T::body_to_output(b));
+                        }
+                    }
+                    Err(e) => {
+                        yield Err(e.into());
+                    }
+                }
+            }
+        }))
+    }
+
+    /// Take a byte slice (of an SSE) and parse it into a provided response body.
+    /// Each SSE event is expected to contain one or multiple JSON bodies prefixed by `data: `.
+    fn parse_stream_event<StreamBody>(bytes: &[u8]) -> Vec<Result<StreamBody, Error>>
+    where
+        StreamBody: for<'de> Deserialize<'de>,
+    {
+        String::from_utf8_lossy(bytes)
+            .split("data: ")
+            .skip(1)
+            .map(|s| {
+                serde_json::from_str(s).map_err(|e| Error::InvalidStream {
+                    deserialization_error: e.to_string(),
+                })
+            })
+            .collect()
+    }
+
     fn header_from_token(api_token: &str) -> header::HeaderValue {
         let mut auth_value = header::HeaderValue::from_str(&format!("Bearer {api_token}")).unwrap();
         // Consider marking security-sensitive headers with `set_sensitive`.
@@ -264,7 +317,87 @@ pub enum Error {
         deserialization_error
     )]
     InvalidTokenizer { deserialization_error: String },
+    /// Deserialization error of the stream event.
+    #[error(
+        "Stream event could not be correctly deserialized. Caused by:\n{}.",
+        deserialization_error
+    )]
+    InvalidStream { deserialization_error: String },
     /// Most likely either TLS errors creating the Client, or IO errors.
     #[error(transparent)]
     Other(#[from] reqwest::Error),
 }
+
+#[cfg(test)]
+mod tests {
+    use crate::{chat::ChatEvent, completion::CompletionEvent};
+
+    use super::*;
+
+    #[test]
+    fn stream_chunk_event_is_parsed() {
+        // Given some bytes
+        let bytes = b"data: {\"type\":\"stream_chunk\",\"index\":0,\"completion\":\" The New York Times, May 15\"}\n\n";
+
+        // When they are parsed
+        let events = HttpClient::parse_stream_event::<CompletionEvent>(bytes);
+        let event = events.first().unwrap().as_ref().unwrap();
+
+        // Then the event is a stream chunk
+        match event {
+            CompletionEvent::StreamChunk(chunk) => assert_eq!(chunk.index, 0),
+            _ => panic!("Expected a stream chunk"),
+        }
+    }
+
+    #[test]
+    fn completion_summary_event_is_parsed() {
+        // Given some bytes with a stream summary and a completion summary
+        let bytes = b"data: {\"type\":\"stream_summary\",\"index\":0,\"model_version\":\"2022-04\",\"finish_reason\":\"maximum_tokens\"}\n\ndata: {\"type\":\"completion_summary\",\"num_tokens_prompt_total\":1,\"num_tokens_generated\":7}\n\n";
+
+        // When they are parsed
+        let events = HttpClient::parse_stream_event::<CompletionEvent>(bytes);
+
+        // Then the first event is a stream summary and the last event is a completion summary
+        let first = events.first().unwrap().as_ref().unwrap();
+        match first {
+            CompletionEvent::StreamSummary(summary) => {
+                assert_eq!(summary.finish_reason, "maximum_tokens")
+            }
+            _ => panic!("Expected a stream summary"),
+        }
+        let second = events.last().unwrap().as_ref().unwrap();
+        match second {
+            CompletionEvent::CompletionSummary(summary) => {
+                assert_eq!(summary.num_tokens_generated, 7)
+            }
+            _ => panic!("Expected a completion summary"),
+        }
+    }
+
+    #[test]
+    fn chat_stream_chunk_event_is_parsed() {
+        // Given some bytes
+        let bytes = b"data: {\"id\":\"831e41b4-2382-4b08-990e-0a3859967f43\",\"choices\":[{\"finish_reason\":null,\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"\"},\"logprobs\":null}],\"created\":1729782822,\"model\":\"pharia-1-llm-7b-control\",\"system_fingerprint\":null,\"object\":\"chat.completion.chunk\",\"usage\":null}\n\n";
+
+        // When they are parsed
+        let events = HttpClient::parse_stream_event::<ChatEvent>(bytes);
+        let event = events.first().unwrap().as_ref().unwrap();
+
+        // Then the event is a chat stream chunk
+        assert_eq!(event.choices[0].delta.role.as_ref().unwrap(), "assistant");
+    }
+
+    #[test]
+    fn chat_stream_chunk_without_role_is_parsed() {
+        // Given some bytes without a role
+        let bytes = b"data: {\"id\":\"a3ceca7f-32b2-4a6c-89e7-bc8eb5327f76\",\"choices\":[{\"finish_reason\":null,\"index\":0,\"delta\":{\"content\":\"Hello! How can I help you today? If you have any questions or need assistance, feel free to ask.\"},\"logprobs\":null}],\"created\":1729784197,\"model\":\"pharia-1-llm-7b-control\",\"system_fingerprint\":null,\"object\":\"chat.completion.chunk\",\"usage\":null}\n\n";
+
+        // When they are parsed
+        let events = HttpClient::parse_stream_event::<ChatEvent>(bytes);
+        let event = events.first().unwrap().as_ref().unwrap();
+
+        // Then the event is a chat stream chunk
+        assert_eq!(event.choices[0].delta.content, "Hello! How can I help you today? If you have any questions or need assistance, feel free to ask.");
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index edc0283..e1bf8a5 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -31,15 +31,19 @@ mod http;
 mod image_preprocessing;
 mod prompt;
 mod semantic_embedding;
+mod stream;
 mod tokenization;
 
-use std::time::Duration;
+use std::{pin::Pin, time::Duration};
 
+use futures_util::Stream;
 use http::HttpClient;
 use semantic_embedding::{BatchSemanticEmbeddingOutput, SemanticEmbeddingOutput};
 use tokenizers::Tokenizer;
 
 pub use self::{
+    chat::{ChatEvent, ChatStreamChunk},
     chat::{ChatOutput, Message, TaskChat},
+    completion::{CompletionEvent, CompletionSummary, StreamChunk, StreamSummary},
     completion::{CompletionOutput, Sampling, Stopping, TaskCompletion},
     detokenization::{DetokenizationOutput, TaskDetokenization},
     explanation::{
@@ -51,6 +55,7 @@ pub use self::{
     semantic_embedding::{
         SemanticRepresentation, TaskBatchSemanticEmbedding, TaskSemanticEmbedding,
     },
+    stream::{StreamJob, StreamTask},
     tokenization::{TaskTokenization, TokenizationOutput},
 };
 
@@ -186,7 +191,47 @@ impl Client {
         how: &How,
     ) -> Result<CompletionOutput, Error> {
         self.http_client
-            .output_of(&task.with_model(model), how)
+            .output_of(&Task::with_model(task, model), how)
             .await
+    }
+
+    /// Instruct a model served by the aleph alpha API to continue writing a piece of text.
+    /// Stream the response as a series of events.
+    ///
+    /// ```no_run
+    /// use aleph_alpha_client::{Client, How, TaskCompletion, Error, CompletionEvent};
+    /// use futures_util::StreamExt;
+    ///
+    /// async fn print_stream_completion() -> Result<(), Error> {
+    ///     // Authenticate against API. Fetches token.
+    ///     let client = Client::with_authentication("AA_API_TOKEN")?;
+    ///
+    ///     // Name of the model we want to use. Large models usually give better answers, but
+    ///     // are also slower and more costly.
+    ///     let model = "luminous-base";
+    ///
+    ///     // The task we want to perform. Here we want to continue the sentence: "An apple a day
+    ///     // ..."
+    ///     let task = TaskCompletion::from_text("An apple a day");
+    ///
+    ///     // Retrieve stream from API
+    ///     let mut stream = client.stream_completion(&task, model, &How::default()).await?;
+    ///     while let Some(Ok(event)) = stream.next().await {
+    ///         if let CompletionEvent::StreamChunk(chunk) = event {
+    ///             println!("{}", chunk.completion);
+    ///         }
+    ///     }
+    ///     Ok(())
+    /// }
+    /// ```
+    pub async fn stream_completion(
+        &self,
+        task: &TaskCompletion<'_>,
+        model: &str,
+        how: &How,
+    ) -> Result<Pin<Box<dyn Stream<Item = Result<CompletionEvent, Error>> + Send>>, Error> {
+        self.http_client
+            .stream_output_of(&Task::with_model(task, model), how)
+            .await
     }
 
@@ -194,7 +239,7 @@ impl Client {
     /// ```no_run
     /// use aleph_alpha_client::{Client, How, TaskChat, Error, Message};
     ///
-    /// async fn chat() -> Result<(), Error> {
+    /// async fn print_chat() -> Result<(), Error> {
     ///     // Authenticate against API. Fetches token.
     ///     let client = Client::with_authentication("AA_API_TOKEN")?;
     ///
@@ -213,14 +258,49 @@ impl Client {
     ///     Ok(())
     /// }
     /// ```
-    pub async fn chat<'a>(
-        &'a self,
-        task: &'a TaskChat<'a>,
-        model: &'a str,
-        how: &'a How,
+    pub async fn chat(
+        &self,
+        task: &TaskChat<'_>,
+        model: &str,
+        how: &How,
     ) -> Result<ChatOutput, Error> {
         self.http_client
-            .output_of(&task.with_model(model), how)
+            .output_of(&Task::with_model(task, model), how)
+            .await
+    }
+
+    /// Send a chat message to a model. Stream the response as a series of events.
+    ///
+    /// ```no_run
+    /// use aleph_alpha_client::{Client, How, TaskChat, Error, Message};
+    /// use futures_util::StreamExt;
+    ///
+    /// async fn print_stream_chat() -> Result<(), Error> {
+    ///     // Authenticate against API. Fetches token.
+    ///     let client = Client::with_authentication("AA_API_TOKEN")?;
+    ///
+    ///     // Name of a model that supports chat.
+    ///     let model = "pharia-1-llm-7b-control";
+    ///
+    ///     // Create a chat task with a user message.
+    ///     let message = Message::user("Hello, how are you?");
+    ///     let task = TaskChat::with_message(message);
+    ///
+    ///     // Send the message to the model.
+    ///     let mut stream = client.stream_chat(&task, model, &How::default()).await?;
+    ///     while let Some(Ok(event)) = stream.next().await {
+    ///         println!("{}", event.delta.content);
+    ///     }
+    ///     Ok(())
+    /// }
+    /// ```
+    pub async fn stream_chat(
+        &self,
+        task: &TaskChat<'_>,
+        model: &str,
+        how: &How,
+    ) -> Result<Pin<Box<dyn Stream<Item = Result<ChatStreamChunk, Error>> + Send>>, Error> {
+        self.http_client
+            .stream_output_of(&StreamTask::with_model(task, model), how)
+            .await
     }
 
diff --git a/src/stream.rs b/src/stream.rs
new file mode 100644
index 0000000..0fe5e52
--- /dev/null
+++ b/src/stream.rs
@@ -0,0 +1,68 @@
+use reqwest::RequestBuilder;
+use serde::Deserialize;
+
+use crate::http::MethodJob;
+
+/// A job sent to the Aleph Alpha Api using the http client. A job wraps all the knowledge required
+/// for the Aleph Alpha API to specify its result. Notably it includes the model(s) the job is
+/// executed on. This allows this trait to hold in the presence of services which use more than one
+/// model and task type to achieve their result. On the other hand a bare [`crate::TaskCompletion`]
+/// cannot implement this trait directly, since its result would depend on what model is chosen to
+/// execute it. You can remedy this by turning a completion task into a job by calling
+/// [`StreamTask::with_model`].
+pub trait StreamJob {
+    /// Output returned by [`crate::Client::output_of`]
+    type Output: Send;
+
+    /// Expected answer of the Aleph Alpha API
+    type ResponseBody: for<'de> Deserialize<'de> + Send;
+
+    /// Prepare the request for the Aleph Alpha API. Authentication headers can be assumed to be
+    /// already set.
+    fn build_request(&self, client: &reqwest::Client, base: &str) -> RequestBuilder;
+
+    /// Parses the response of the server into higher level structs for the user.
+    fn body_to_output(response: Self::ResponseBody) -> Self::Output;
+}
+
+/// A task sent to the Aleph Alpha Api using the http client. Requires a model to be specified
+/// before it can be executed. Will return a stream of results.
+pub trait StreamTask {
+    /// Output returned by [`crate::Client::output_of`]
+    type Output: Send;
+
+    /// Expected answer of the Aleph Alpha API
+    type ResponseBody: for<'de> Deserialize<'de> + Send;
+
+    /// Prepare the request for the Aleph Alpha API. Authentication headers can be assumed to be
+    /// already set.
+    fn build_request(&self, client: &reqwest::Client, base: &str, model: &str) -> RequestBuilder;
+
+    /// Parses the response of the server into higher level structs for the user.
+    fn body_to_output(response: Self::ResponseBody) -> Self::Output;
+
+    /// Turn your task into [`StreamJob`] by annotating it with a model name.
+    fn with_model<'a>(&'a self, model: &'a str) -> MethodJob<'a, Self>
+    where
+        Self: Sized,
+    {
+        MethodJob { model, task: self }
+    }
+}
+
+impl<'a, T> StreamJob for MethodJob<'a, T>
+where
+    T: StreamTask,
+{
+    type Output = T::Output;
+
+    type ResponseBody = T::ResponseBody;
+
+    fn build_request(&self, client: &reqwest::Client, base: &str) -> RequestBuilder {
+        self.task.build_request(client, base, self.model)
+    }
+
+    fn body_to_output(response: T::ResponseBody) -> T::Output {
+        T::body_to_output(response)
+    }
+}
diff --git a/tests/integration.rs b/tests/integration.rs
index bb17fed..7be5332 100644
--- a/tests/integration.rs
+++ b/tests/integration.rs
@@ -1,12 +1,13 @@
 use std::{fs::File, io::BufReader, sync::OnceLock};
 
 use aleph_alpha_client::{
-    cosine_similarity, Client, Granularity, How, ImageScore, ItemExplanation, Message, Modality,
-    Prompt, PromptGranularity, Sampling, SemanticRepresentation, Stopping, Task,
+    cosine_similarity, Client, CompletionEvent, Granularity, How, ImageScore, ItemExplanation,
+    Message, Modality, Prompt, PromptGranularity, Sampling, SemanticRepresentation, Stopping, Task,
     TaskBatchSemanticEmbedding, TaskChat, TaskCompletion, TaskDetokenization, TaskExplanation,
     TaskSemanticEmbedding, TaskTokenization, TextScore,
 };
 use dotenv::dotenv;
+use futures_util::StreamExt;
 use image::ImageFormat;
 
 fn api_token() -> &'static str {
@@ -553,3 +554,60 @@ async fn fetch_tokenizer_for_pharia_1_llm_7b() {
     // Then
     assert_eq!(128_000, tokenizer.get_vocab_size(true));
 }
+
+#[tokio::test]
+async fn stream_completion() {
+    // Given a streaming completion task
+    let client = Client::with_authentication(api_token()).unwrap();
+    let task = TaskCompletion::from_text("").with_maximum_tokens(7);
+
+    // When the events are streamed and collected
+    let mut stream = client
+        .stream_completion(&task, "luminous-base", &How::default())
+        .await
+        .unwrap();
+
+    let mut events = Vec::new();
+    while let Some(Ok(event)) = stream.next().await {
+        events.push(event);
+    }
+
+    // Then there is at least one chunk, one stream summary and one completion summary
+    assert!(events.len() >= 3);
+    assert!(matches!(
+        events[events.len() - 3],
+        CompletionEvent::StreamChunk(_)
+    ));
+    assert!(matches!(
+        events[events.len() - 2],
+        CompletionEvent::StreamSummary(_)
+    ));
+    assert!(matches!(
+        events[events.len() - 1],
+        CompletionEvent::CompletionSummary(_)
+    ));
+}
+
+#[tokio::test]
+async fn stream_chat_with_pharia_1_llm_7b() {
+    // Given a streaming chat task
+    let client = Client::with_authentication(api_token()).unwrap();
+    let message = Message::user("Hello,");
+    let task = TaskChat::with_messages(vec![message]).with_maximum_tokens(7);
+
+    // When the events are streamed and collected
+    let mut stream = client
+        .stream_chat(&task, "pharia-1-llm-7b-control", &How::default())
+        .await
+        .unwrap();
+
+    let mut events = Vec::new();
+    while let Some(Ok(event)) = stream.next().await {
+        events.push(event);
+    }
+
+    // Then there are at least two chunks, with the second one having no role
+    assert!(events.len() >= 2);
+    assert_eq!(events[0].delta.role.as_ref().unwrap(), "assistant");
+    assert_eq!(events[1].delta.role, None);
+}