OpenAI: Add support for streaming token usage.
 - OpenAiApi: add StreamOptions class and ChatCompletionRequest#streamOptions field.
 - add OpenAiChatOptions.Builder#withStreamUsage(boolean) to set/unset the StreamOptions.
 - add a boolean (get/set)StreamUsage() to OpenAiChatOptions that internally sets the streamOptions.
   This later enables the "spring.ai.openai.chat.options.stream-usage" property.
 - update the OpenAI property documentation.
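
For reference, a minimal usage sketch based on the new options API and the integration tests in this commit (the prompt text is illustrative, and `chatModel` is assumed to be an already configured OpenAiChatModel):

OpenAiChatOptions options = OpenAiChatOptions.builder()
        .withStreamUsage(true) // sends "stream_options": {"include_usage": true} with the request
        .build();

// With stream usage enabled, the final streamed ChatResponse carries the
// aggregated token usage for the entire request.
Usage usage = chatModel.stream(new Prompt("List two colors of the Polish flag.", options))
        .blockLast()
        .getMetadata()
        .getUsage();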

Co-authored-by: Christian Tzolov <[email protected]>
didalgolab and tzolov committed Jun 24, 2024
1 parent 12dbc1e commit 20e4b56
Showing 9 changed files with 138 additions and 16 deletions.
@@ -222,7 +222,12 @@ public Flux<ChatResponse> stream(Prompt prompt) {
return generation;
}).toList();

if (chatCompletion.usage() != null) {
return new ChatResponse(generations, OpenAiChatResponseMetadata.from(chatCompletion));
}
else {
return new ChatResponse(generations);
}
}
catch (Exception e) {
logger.error("Error processing chat completion", e);
@@ -245,7 +250,7 @@ private OpenAiApi.ChatCompletion chunkToChatCompletion(OpenAiApi.ChatCompletionC
.toList();

return new OpenAiApi.ChatCompletion(chunk.id(), choices, chunk.created(), chunk.model(),
chunk.systemFingerprint(), "chat.completion", null);
chunk.systemFingerprint(), "chat.completion", chunk.usage());
}

/**
@@ -306,6 +311,12 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
request, ChatCompletionRequest.class);
}

// Remove `streamOptions` from the request if it is not a streaming request
if (request.streamOptions() != null && !stream) {
logger.warn("Removing streamOptions from the request as it is not a streaming request!");
request = request.withStreamOptions(null);
}

return request;
}

@@ -27,6 +27,7 @@
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ResponseFormat;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.StreamOptions;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ToolChoiceBuilder;
import org.springframework.ai.openai.api.OpenAiApi.FunctionTool;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
@@ -39,6 +40,7 @@

/**
* @author Christian Tzolov
* @author Mariusz Bernacki
* @since 0.8.0
*/
@JsonInclude(Include.NON_NULL)
@@ -93,6 +95,10 @@ public class OpenAiChatOptions implements FunctionCallingOptions, ChatOptions {
* "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON.
*/
private @JsonProperty("response_format") ResponseFormat responseFormat;
/**
* Options for streaming response. Included in the request only when a streaming completion is requested.
*/
private @JsonProperty("stream_options") StreamOptions streamOptions;
/**
* This feature is in Beta. If specified, our system will make a best effort to sample
* deterministically, such that repeated requests with the same seed and parameters should return the same result.
@@ -226,6 +232,13 @@ public Builder withResponseFormat(ResponseFormat responseFormat) {
return this;
}

public Builder withStreamUsage(boolean enableStreamUsage) {
this.options.streamOptions = enableStreamUsage ? StreamOptions.INCLUDE_USAGE : null;
return this;
}

public Builder withSeed(Integer seed) {
this.options.seed = seed;
return this;
@@ -284,6 +297,14 @@ public OpenAiChatOptions build() {

}

public Boolean getStreamUsage() {
return this.streamOptions != null;
}

public void setStreamUsage(Boolean enableStreamUsage) {
this.streamOptions = (enableStreamUsage) ? StreamOptions.INCLUDE_USAGE : null;
}

public String getModel() {
return this.model;
}
@@ -356,6 +377,14 @@ public void setResponseFormat(ResponseFormat responseFormat) {
this.responseFormat = responseFormat;
}

public StreamOptions getStreamOptions() {
return streamOptions;
}

public void setStreamOptions(StreamOptions streamOptions) {
this.streamOptions = streamOptions;
}

public Integer getSeed() {
return this.seed;
}
@@ -446,6 +475,7 @@ public int hashCode() {
result = prime * result + ((n == null) ? 0 : n.hashCode());
result = prime * result + ((presencePenalty == null) ? 0 : presencePenalty.hashCode());
result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode());
result = prime * result + ((streamOptions == null) ? 0 : streamOptions.hashCode());
result = prime * result + ((seed == null) ? 0 : seed.hashCode());
result = prime * result + ((stop == null) ? 0 : stop.hashCode());
result = prime * result + ((temperature == null) ? 0 : temperature.hashCode());
@@ -519,6 +549,12 @@ else if (!this.presencePenalty.equals(other.presencePenalty))
}
else if (!this.responseFormat.equals(other.responseFormat))
return false;
if (this.streamOptions == null) {
if (other.streamOptions != null)
return false;
}
else if (!this.streamOptions.equals(other.streamOptions))
return false;
if (this.seed == null) {
if (other.seed != null)
return false;
@@ -586,6 +622,7 @@ public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) {
.withN(fromOptions.getN())
.withPresencePenalty(fromOptions.getPresencePenalty())
.withResponseFormat(fromOptions.getResponseFormat())
.withStreamUsage(fromOptions.getStreamUsage())
.withSeed(fromOptions.getSeed())
.withStop(fromOptions.getStop())
.withTemperature(fromOptions.getTemperature())
@@ -47,6 +47,7 @@
*
* @author Christian Tzolov
* @author Michael Lavelle
* @author Mariusz Bernacki
*/
public class OpenAiApi {

@@ -314,6 +315,7 @@ public Function(String description, String name, String jsonSchema) {
* @param stop Up to 4 sequences where the API will stop generating further tokens.
* @param stream If set, partial message deltas will be sent. Tokens will be sent as data-only server-sent events as
* they become available, with the stream terminated by a data: [DONE] message.
* @param streamOptions Options for streaming response. Only set this when you set stream: true.
* @param temperature What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output
* more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend
* altering this or top_p but not both.
@@ -345,6 +347,7 @@ public record ChatCompletionRequest (
@JsonProperty("seed") Integer seed,
@JsonProperty("stop") List<String> stop,
@JsonProperty("stream") Boolean stream,
@JsonProperty("stream_options") StreamOptions streamOptions,
@JsonProperty("temperature") Float temperature,
@JsonProperty("top_p") Float topP,
@JsonProperty("tools") List<FunctionTool> tools,
@@ -360,7 +363,7 @@ public record ChatCompletionRequest (
*/
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Float temperature) {
this(messages, model, null, null, null, null, null, null, null,
null, null, null, false, null, temperature, null,
null, null, null);
}

@@ -375,7 +378,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
*/
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Float temperature, boolean stream) {
this(messages, model, null, null, null, null, null, null, null,
null, null, null, stream, null, temperature, null,
null, null, null);
}

@@ -391,7 +394,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
List<FunctionTool> tools, Object toolChoice) {
this(messages, model, null, null, null, null, null, null, null,
null, null, null, false, null, 0.8f, null,
tools, toolChoice, null);
}

@@ -404,10 +407,22 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
*/
public ChatCompletionRequest(List<ChatCompletionMessage> messages, Boolean stream) {
this(messages, null, null, null, null, null, null, null, null,
null, null, null, stream, null, null, null,
null, null, null);
}

/**
* Sets the {@link StreamOptions} for this request.
*
* @param streamOptions The new stream options to use.
* @return A new {@link ChatCompletionRequest} with the specified stream options.
*/
public ChatCompletionRequest withStreamOptions(StreamOptions streamOptions) {
return new ChatCompletionRequest(messages, model, frequencyPenalty, logitBias, logprobs, topLogprobs, maxTokens, n, presencePenalty,
responseFormat, seed, stop, stream, streamOptions, temperature, topP,
tools, toolChoice, user);
}

/**
* Helper factory that creates a tool_choice of type 'none', 'auto' or selected function by name.
*/
@@ -437,6 +452,20 @@ public static Object FUNCTION(String functionName) {
public record ResponseFormat(
@JsonProperty("type") String type) {
}

/**
* @param includeUsage If set, an additional chunk will be streamed
* before the data: [DONE] message. The usage field on this chunk
* shows the token usage statistics for the entire request, and
* the choices field will always be an empty array. All other chunks
* will also include a usage field, but with a null value.
*/
@JsonInclude(Include.NON_NULL)
public record StreamOptions(
@JsonProperty("include_usage") Boolean includeUsage) {

public static final StreamOptions INCLUDE_USAGE = new StreamOptions(true);
}
}

/**
@@ -742,7 +771,8 @@ public record ChatCompletionChunk(
@JsonProperty("created") Long created,
@JsonProperty("model") String model,
@JsonProperty("system_fingerprint") String systemFingerprint,
@JsonProperty("object") String object) {
@JsonProperty("object") String object,
@JsonProperty("usage") Usage usage) {

/**
* Chat completion choice.
@@ -825,7 +855,7 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
// Flux<Flux<ChatCompletionChunk>> -> Flux<Mono<ChatCompletionChunk>>
.concatMapIterable(window -> {
Mono<ChatCompletionChunk> monoChunk = window.reduce(
new ChatCompletionChunk(null, null, null, null, null, null, null),
(previous, current) -> this.chunkMerger.merge(previous, current));
return List.of(monoChunk);
})
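For completeness, a sketch of exercising the new low-level types directly. This assumes the single-String-argument OpenAiApi constructor; the model name and prompt are illustrative:

OpenAiApi api = new OpenAiApi(System.getenv("OPENAI_API_KEY")); // assumed constructor

ChatCompletionRequest request = new ChatCompletionRequest(
        List.of(new ChatCompletionMessage("Hello", Role.USER)), "gpt-4o", 0.8f, true)
    .withStreamOptions(StreamOptions.INCLUDE_USAGE);

// usage() is null on every chunk except the final one streamed before the [DONE] message.
api.chatCompletionStream(request)
    .filter(chunk -> chunk.usage() != null)
    .subscribe(chunk -> System.out.println(chunk.usage()));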
@@ -28,6 +28,7 @@
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ChatCompletionFunction;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.Role;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ToolCall;
import org.springframework.ai.openai.api.OpenAiApi.Usage;
import org.springframework.util.CollectionUtils;

/**
@@ -58,13 +59,14 @@ public ChatCompletionChunk merge(ChatCompletionChunk previous, ChatCompletionChu
String systemFingerprint = (current.systemFingerprint() != null ? current.systemFingerprint()
: previous.systemFingerprint());
String object = (current.object() != null ? current.object() : previous.object());
Usage usage = (current.usage() != null ? current.usage() : previous.usage());

ChunkChoice previousChoice0 = (CollectionUtils.isEmpty(previous.choices()) ? null : previous.choices().get(0));
ChunkChoice currentChoice0 = (CollectionUtils.isEmpty(current.choices()) ? null : current.choices().get(0));

ChunkChoice choice = merge(previousChoice0, currentChoice0);
List<ChunkChoice> chunkChoices = choice == null ? List.of() : List.of(choice);
return new ChatCompletionChunk(id, chunkChoices, created, model, systemFingerprint, object, usage);
}

private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) {
@@ -61,7 +61,7 @@ public class MessageTypeContentTests {
ArgumentCaptor<ChatCompletionRequest> promptCaptor;

Flux<ChatCompletionChunk> fluxResponse = Flux
.generate(() -> new ChatCompletionChunk("id", List.of(), 0l, "model", "fp", "object"), (state, sink) -> {
.generate(() -> new ChatCompletionChunk("id", List.of(), 0l, "model", "fp", "object", null), (state, sink) -> {
sink.next(state);
sink.complete();
return state;
@@ -15,6 +15,8 @@
*/
package org.springframework.ai.openai.chat;

import static org.assertj.core.api.Assertions.assertThat;

import java.io.IOException;
import java.net.URL;
import java.util.ArrayList;
@@ -29,14 +31,12 @@
import org.junit.jupiter.params.provider.ValueSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Media;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
@@ -56,7 +56,7 @@
import org.springframework.core.io.Resource;
import org.springframework.util.MimeTypeUtils;

import reactor.core.publisher.Flux;

@SpringBootTest(classes = OpenAiTestConfiguration.class)
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
@@ -103,6 +103,24 @@ void streamRoleTest() {

}

@Test
void streamingWithTokenUsage() {
var promptOptions = OpenAiChatOptions.builder().withStreamUsage(true).withSeed(1).build();

var prompt = new Prompt("List two colors of the Polish flag. Be brief.", promptOptions);
var streamingTokenUsage = this.chatModel.stream(prompt).blockLast().getMetadata().getUsage();
var referenceTokenUsage = this.chatModel.call(prompt).getMetadata().getUsage();

assertThat(streamingTokenUsage.getPromptTokens()).isGreaterThan(0);
assertThat(streamingTokenUsage.getGenerationTokens()).isGreaterThan(0);
assertThat(streamingTokenUsage.getTotalTokens()).isGreaterThan(0);

assertThat(streamingTokenUsage.getPromptTokens()).isEqualTo(referenceTokenUsage.getPromptTokens());
assertThat(streamingTokenUsage.getGenerationTokens()).isEqualTo(referenceTokenUsage.getGenerationTokens());
assertThat(streamingTokenUsage.getTotalTokens()).isEqualTo(referenceTokenUsage.getTotalTokens());

}

@Test
void listOutputConverter() {
DefaultConversionService conversionService = new DefaultConversionService();
@@ -167,7 +167,7 @@ public void openAiChatStreamTransientError() {
var choice = new ChatCompletionChunk.ChunkChoice(ChatCompletionFinishReason.STOP, 0,
new ChatCompletionMessage("Response", Role.ASSISTANT), null);
ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", List.of(choice), 666l, "model", null,
null, null);

when(openAiApi.chatCompletionStream(isA(ChatCompletionRequest.class)))
.thenThrow(new TransientAiException("Transient Error 1"))
@@ -105,6 +105,7 @@ The prefix `spring.ai.openai.chat` is the property prefix that lets you configur
| spring.ai.openai.chat.options.toolChoice | Controls which (if any) function is called by the model. none means the model will not call a function and instead generates a message. auto means the model can pick between generating a message or calling a function. Specifying a particular function via {"type": "function", "function": {"name": "my_function"}} forces the model to call that function. none is the default when no functions are present. auto is the default if functions are present. | -
| spring.ai.openai.chat.options.user | A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. | -
| spring.ai.openai.chat.options.functions | List of functions, identified by their names, to enable for function calling in a single prompt request. Functions with those names must exist in the functionCallbacks registry. | -
| spring.ai.openai.chat.options.stream-usage | (For streaming only) Set to add an additional chunk with token usage statistics for the entire request. The `choices` field for this chunk is an empty array, and all other chunks will also include a usage field, but with a null value. | false
|====
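
For example, to enable streaming token usage via configuration:

spring.ai.openai.chat.options.stream-usage=true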

NOTE: You can override the common `spring.ai.openai.base-url` and `spring.ai.openai.api-key` for the `ChatModel` and `EmbeddingModel` implementations.
@@ -26,6 +26,7 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.embedding.EmbeddingResponse;
@@ -113,6 +114,28 @@ void generateStreaming() {
});
}

@Test
void streamingWithTokenUsage() {
contextRunner.withPropertyValues("spring.ai.openai.chat.options.stream-usage=true").run(context -> {
OpenAiChatModel chatModel = context.getBean(OpenAiChatModel.class);

Flux<ChatResponse> responseFlux = chatModel.stream(new Prompt(new UserMessage("Hello")));

Usage[] streamingTokenUsage = new Usage[1];
String response = responseFlux.collectList().block().stream().map(chatResponse -> {
streamingTokenUsage[0] = chatResponse.getMetadata().getUsage();
return (chatResponse.getResult() != null) ? chatResponse.getResult().getOutput().getContent() : "";
}).collect(Collectors.joining());

assertThat(streamingTokenUsage[0].getPromptTokens()).isGreaterThan(0);
assertThat(streamingTokenUsage[0].getGenerationTokens()).isGreaterThan(0);
assertThat(streamingTokenUsage[0].getTotalTokens()).isGreaterThan(0);

assertThat(response).isNotEmpty();
logger.info("Response: " + response);
});
}

@Test
void embedding() {
contextRunner.run(context -> {