Introduce id and model in ChatResponseMetadata
* Extend ChatResponseMetadata for Anthropic (blocking, streaming)
* Add ChatResponseMetadata for Mistral AI (blocking)
* Extend ChatResponseMetadata for OpenAI (blocking)
* Deprecate gpt-4-vision-preview and replace its usage in tests, since OpenAI now rejects calls to it (see: https://platform.openai.com/docs/deprecations)

Fixes gh-936

Signed-off-by: Thomas Vitale <[email protected]>
ThomasVitale authored and tzolov committed Jun 24, 2024
1 parent 25a0372 commit 3227311
Showing 17 changed files with 267 additions and 31 deletions.
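For orientation, a minimal sketch of how client code can read the new metadata fields once this commit is applied (the prompt text, class name, and variable names are illustrative, not part of the change):

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;

class ResponseMetadataSketch {

	// chatModel is any configured implementation, e.g. the Anthropic or Mistral AI chat model
	void inspectMetadata(ChatModel chatModel) {
		ChatResponse response = ChatClient.create(chatModel).prompt()
			.user("Hello")
			.call()
			.chatResponse();

		ChatResponseMetadata metadata = response.getMetadata();
		String id = metadata.getId();       // provider-assigned response id, newly surfaced here
		String model = metadata.getModel(); // model that actually served the request
	}

}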
@@ -30,42 +30,53 @@
* {@link ChatResponseMetadata} implementation for {@literal AnthropicApi}.
*
* @author Christian Tzolov
* @author Thomas Vitale
* @see ChatResponseMetadata
* @see RateLimit
* @see Usage
* @since 1.0.0
*/
public class AnthropicChatResponseMetadata extends HashMap<String, Object> implements ChatResponseMetadata {

- protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, usage: %3$s, rateLimit: %4$s }";
+ protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, model: %3$s, usage: %4$s, rateLimit: %5$s }";

public static AnthropicChatResponseMetadata from(AnthropicApi.ChatCompletion result) {
Assert.notNull(result, "Anthropic ChatCompletionResult must not be null");
AnthropicUsage usage = AnthropicUsage.from(result.usage());
- return new AnthropicChatResponseMetadata(result.id(), usage);
+ return new AnthropicChatResponseMetadata(result.id(), result.model(), usage);
}

private final String id;

private final String model;

@Nullable
private RateLimit rateLimit;

private final Usage usage;

- protected AnthropicChatResponseMetadata(String id, AnthropicUsage usage) {
- this(id, usage, null);
+ protected AnthropicChatResponseMetadata(String id, String model, AnthropicUsage usage) {
+ this(id, model, usage, null);
}

- protected AnthropicChatResponseMetadata(String id, AnthropicUsage usage, @Nullable AnthropicRateLimit rateLimit) {
+ protected AnthropicChatResponseMetadata(String id, String model, AnthropicUsage usage,
+ @Nullable AnthropicRateLimit rateLimit) {
this.id = id;
this.model = model;
this.usage = usage;
this.rateLimit = rateLimit;
}

@Override
public String getId() {
return this.id;
}

@Override
public String getModel() {
return this.model;
}

@Override
@Nullable
public RateLimit getRateLimit() {
@@ -86,7 +97,7 @@ public AnthropicChatResponseMetadata withRateLimit(RateLimit rateLimit) {

@Override
public String toString() {
- return AI_METADATA_STRING.formatted(getClass().getName(), getId(), getUsage(), getRateLimit());
+ return AI_METADATA_STRING.formatted(getClass().getName(), getId(), getModel(), getUsage(), getRateLimit());
}

}
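In addition to the new id and model accessors, the Anthropic metadata keeps exposing rate-limit data. Since getRateLimit() is @Nullable, a defensive read looks roughly like this — a sketch that assumes the getRequestsRemaining()/getTokensRemaining() accessors from Spring AI's existing RateLimit interface:

	// Sketch: reading rate-limit details from the metadata type above
	void logRateLimit(AnthropicChatResponseMetadata metadata) {
		RateLimit rateLimit = metadata.getRateLimit(); // @Nullable: may be absent
		if (rateLimit != null) {
			System.out.println("requests remaining: " + rateLimit.getRequestsRemaining()
					+ ", tokens remaining: " + rateLimit.getTokensRemaining());
		}
	}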
@@ -31,6 +31,7 @@

import org.springframework.ai.anthropic.api.AnthropicApi;
import org.springframework.ai.anthropic.api.tool.MockWeatherService;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
@@ -241,4 +242,43 @@ void functionCallTest() {
assertThat(generation.getOutput().getContent()).contains("30", "10", "15");
}

@Test
void validateCallResponseMetadata() {
String model = AnthropicApi.ChatModel.CLAUDE_2_1.getModelName();
// @formatter:off
ChatResponse response = ChatClient.create(chatModel).prompt()
.options(AnthropicChatOptions.builder().withModel(model).build())
.user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did")
.call()
.chatResponse();
// @formatter:on

logger.info(response.toString());
validateChatResponseMetadata(response, model);
}

@Test
void validateStreamCallResponseMetadata() {
String model = AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getModelName();
// @formatter:off
ChatResponse response = ChatClient.create(chatModel).prompt()
.options(AnthropicChatOptions.builder().withModel(model).build())
.user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did")
.stream()
.chatResponse()
.blockLast();
// @formatter:on

logger.info(response.toString());
validateChatResponseMetadata(response, model);
}

private static void validateChatResponseMetadata(ChatResponse response, String model) {
assertThat(response.getMetadata().getId()).isNotEmpty();
assertThat(response.getMetadata().getModel()).containsIgnoringCase(model);
assertThat(response.getMetadata().getUsage().getPromptTokens()).isPositive();
assertThat(response.getMetadata().getUsage().getGenerationTokens()).isPositive();
assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive();
}

}
@@ -29,6 +29,7 @@
* {@literal Microsoft Azure OpenAI Service}.
*
* @author John Blum
* @author Thomas Vitale
* @see ChatResponseMetadata
* @since 0.7.1
*/
@@ -59,6 +60,7 @@ protected AzureOpenAiChatResponseMetadata(String id, AzureOpenAiUsage usage, PromptMetadata promptMetadata) {
this.promptMetadata = promptMetadata;
}

@Override
public String getId() {
return this.id;
}
@@ -30,6 +30,7 @@
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage;
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.ToolCall;
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest;
import org.springframework.ai.mistralai.metadata.MistralAiChatResponseMetadata;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
import org.springframework.ai.model.function.FunctionCallbackContext;
@@ -119,7 +120,7 @@ public ChatResponse call(Prompt prompt) {
.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null)))
.toList();

- return new ChatResponse(generations);
+ return new ChatResponse(generations, MistralAiChatResponseMetadata.from(chatCompletion));
});
}

@@ -0,0 +1,62 @@
package org.springframework.ai.mistralai.metadata;

import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.mistralai.api.MistralAiApi;
import org.springframework.util.Assert;

import java.util.HashMap;

/**
* {@link ChatResponseMetadata} implementation for {@literal Mistral AI}.
*
* @author Thomas Vitale
* @see ChatResponseMetadata
* @see Usage
* @since 1.0.0
*/
public class MistralAiChatResponseMetadata extends HashMap<String, Object> implements ChatResponseMetadata {

protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, model: %3$s, usage: %4$s }";

public static MistralAiChatResponseMetadata from(MistralAiApi.ChatCompletion result) {
Assert.notNull(result, "Mistral AI ChatCompletion must not be null");
MistralAiUsage usage = MistralAiUsage.from(result.usage());
return new MistralAiChatResponseMetadata(result.id(), result.model(), usage);
}

private final String id;

private final String model;

private final Usage usage;

protected MistralAiChatResponseMetadata(String id, String model, MistralAiUsage usage) {
this.id = id;
this.model = model;
this.usage = usage;
}

@Override
public String getId() {
return this.id;
}

@Override
public String getModel() {
return this.model;
}

@Override
public Usage getUsage() {
Usage usage = this.usage;
return usage != null ? usage : new EmptyUsage();
}

@Override
public String toString() {
return AI_METADATA_STRING.formatted(getClass().getName(), getId(), getModel(), getUsage());
}

}
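Note the EmptyUsage fallback in getUsage(): consumers can read token counts without a null check. A brief sketch of the intended call pattern (the logUsage method name is illustrative):

	// Sketch: getUsage() never returns null thanks to the EmptyUsage fallback above
	void logUsage(MistralAiChatResponseMetadata metadata) {
		Usage usage = metadata.getUsage();
		System.out.printf("prompt=%d, generation=%d, total=%d%n",
				usage.getPromptTokens(), usage.getGenerationTokens(), usage.getTotalTokens());
	}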
@@ -0,0 +1,51 @@
package org.springframework.ai.mistralai.metadata;

import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.mistralai.api.MistralAiApi;
import org.springframework.util.Assert;

/**
* {@link Usage} implementation for {@literal Mistral AI}.
*
* @author Thomas Vitale
* @since 1.0.0
* @see <a href="https://docs.mistral.ai/api/">Chat Completion API</a>
*/
public class MistralAiUsage implements Usage {

public static MistralAiUsage from(MistralAiApi.Usage usage) {
return new MistralAiUsage(usage);
}

private final MistralAiApi.Usage usage;

protected MistralAiUsage(MistralAiApi.Usage usage) {
Assert.notNull(usage, "Mistral AI Usage must not be null");
this.usage = usage;
}

protected MistralAiApi.Usage getUsage() {
return this.usage;
}

@Override
public Long getPromptTokens() {
return getUsage().promptTokens().longValue();
}

@Override
public Long getGenerationTokens() {
return getUsage().completionTokens().longValue();
}

@Override
public Long getTotalTokens() {
return getUsage().totalTokens().longValue();
}

@Override
public String toString() {
return getUsage().toString();
}

}
@@ -249,4 +249,23 @@ void streamFunctionCallTest() {
assertThat(content).containsAnyOf("15.0", "15");
}

@Test
void validateCallResponseMetadata() {
String model = MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getModelName();
// @formatter:off
ChatResponse response = ChatClient.create(chatModel).prompt()
.options(MistralAiChatOptions.builder().withModel(model).build())
.user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did")
.call()
.chatResponse();
// @formatter:on

logger.info(response.toString());
assertThat(response.getMetadata().getId()).isNotEmpty();
assertThat(response.getMetadata().getModel()).containsIgnoringCase(model);
assertThat(response.getMetadata().getUsage().getPromptTokens()).isPositive();
assertThat(response.getMetadata().getUsage().getGenerationTokens()).isPositive();
assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive();
}

}
@@ -72,6 +72,7 @@
* @author Hyunjoon Choi
* @author Mariusz Bernacki
* @author luocongqiu
* @author Thomas Vitale
* @see ChatModel
* @see StreamingChatModel
* @see OpenAiApi
@@ -41,6 +41,7 @@
/**
* @author Christian Tzolov
* @author Mariusz Bernacki
* @author Thomas Vitale
* @since 0.8.0
*/
@JsonInclude(Include.NON_NULL)
@@ -66,8 +67,7 @@ public class OpenAiChatOptions implements FunctionCallingOptions, ChatOptions {
private @JsonProperty("logit_bias") Map<String, Integer> logitBias;
/**
* Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities
- * of each output token returned in the 'content' of 'message'. This option is currently not available
- * on the 'gpt-4-vision-preview' model.
+ * of each output token returned in the 'content' of 'message'.
*/
private @JsonProperty("logprobs") Boolean logprobs;
/**
@@ -48,6 +48,7 @@
* @author Christian Tzolov
* @author Michael Lavelle
* @author Mariusz Bernacki
* @author Thomas Vitale
*/
public class OpenAiApi {

@@ -124,7 +125,6 @@ public enum ChatModel implements ModelDescription {
*/
GPT_4_O("gpt-4o"),


/**
* GPT-4 Turbo with Vision
* The latest GPT-4 Turbo model with vision capabilities.
@@ -134,7 +134,7 @@ public enum ChatModel implements ModelDescription {
GPT_4_TURBO("gpt-4-turbo"),

/**
- * GPT-4 Turbo with Vision model. Vision requests can now use JSON mode and function calling
+ * GPT-4 Turbo with Vision model. Vision requests can now use JSON mode and function calling.
*/
GPT_4_TURBO_2204_04_09("gpt-4-turbo-2024-04-09"),

@@ -162,6 +162,7 @@ public enum ChatModel implements ModelDescription {
* Returns a maximum of 4,096 output tokens
* Context window: 128k tokens
*/
@Deprecated(since = "1.0.0-M2", forRemoval = true) // Replaced by GPT_4_O
GPT_4_VISION_PREVIEW("gpt-4-vision-preview"),

/**
@@ -178,6 +179,7 @@
* function calling support.
* Context window: 32k tokens
*/
@Deprecated(since = "1.0.0-M2", forRemoval = true) // Replaced by GPT_4_O
GPT_4_32K("gpt-4-32k"),

/**
@@ -296,8 +298,7 @@ public Function(String description, String name, String jsonSchema) {
* vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100
* or 100 should result in a ban or exclusive selection of the relevant token.
* @param logprobs Whether to return log probabilities of the output tokens or not. If true, returns the log
- * probabilities of each output token returned in the 'content' of 'message'. This option is currently not available
- * on the 'gpt-4-vision-preview' model.
+ * probabilities of each output token returned in the 'content' of 'message'.
* @param topLogprobs An integer between 0 and 5 specifying the number of most likely tokens to return at each token
* position, each with an associated log probability. 'logprobs' must be set to 'true' if this parameter is used.
* @param maxTokens The maximum number of tokens to generate in the chat completion. The total length of input
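Since gpt-4-vision-preview and gpt-4-32k are now marked for removal, callers pinning either model can move to gpt-4o. A minimal migration sketch using this module's options builder (class and method names are illustrative):

import org.springframework.ai.openai.OpenAiChatOptions;

class ModelMigrationSketch {

	// Sketch: swap the deprecated model id for its replacement
	OpenAiChatOptions migratedOptions() {
		return OpenAiChatOptions.builder()
			.withModel("gpt-4o") // was: "gpt-4-vision-preview"
			.build();
	}

}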
