Skip to content

Commit

Permalink
Fix log probability information
Browse files Browse the repository at this point in the history
  • Loading branch information
ricken07 authored and tzolov committed Apr 26, 2024
1 parent 7e03a15 commit 5f9ecdd
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
private ChatCompletion toChatCompletion(ChatCompletionChunk chunk) {
List<Choice> choices = chunk.choices()
.stream()
.map(cc -> new Choice(cc.index(), cc.delta(), cc.finishReason()))
.map(cc -> new Choice(cc.index(), cc.delta(), cc.finishReason(), cc.logprobs()))
.toList();

return new ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.model(), choices, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -575,17 +575,65 @@ public record ChatCompletion(
* @param index The index of the choice in the list of choices.
* @param message A chat completion message generated by the model.
* @param finishReason The reason the model stopped generating tokens.
* @param logprobs Log probability information for the choice.
*/
@JsonInclude(Include.NON_NULL)
public record Choice(
// @formatter:off
@JsonProperty("index") Integer index,
@JsonProperty("message") ChatCompletionMessage message,
@JsonProperty("finish_reason") ChatCompletionFinishReason finishReason) {
@JsonProperty("finish_reason") ChatCompletionFinishReason finishReason,
@JsonProperty("logprobs") LogProbs logprobs) {
// @formatter:on
}
}

/**
 * Log probability information for the choice, modeled now in anticipation of
 * future API changes.
 *
 * @param content A list of message content tokens with log probability information.
 */
@JsonInclude(Include.NON_NULL)
public record LogProbs(@JsonProperty("content") List<Content> content) {

/**
 * Message content tokens with log probability information.
 *
 * @param token The token.
 * @param logprob The log probability of the token.
 * @param probBytes A list of integers representing the UTF-8 bytes representation
 * of the token (serialized to/from the JSON field {@code "bytes"}). Useful in
 * instances where characters are represented by multiple tokens and their byte
 * representations must be combined to generate the correct text representation.
 * Can be null if there is no bytes representation for the token.
 * @param topLogprobs List of the most likely tokens and their log probability, at
 * this token position. In rare cases, there may be fewer than the number of
 * requested top_logprobs returned.
 */
@JsonInclude(Include.NON_NULL)
public record Content(@JsonProperty("token") String token, @JsonProperty("logprob") Float logprob,
@JsonProperty("bytes") List<Integer> probBytes,
@JsonProperty("top_logprobs") List<TopLogProbs> topLogprobs) {

/**
 * The most likely tokens and their log probability, at this token position.
 *
 * @param token The token.
 * @param logprob The log probability of the token.
 * @param probBytes A list of integers representing the UTF-8 bytes
 * representation of the token (serialized to/from the JSON field
 * {@code "bytes"}). Useful in instances where characters are
 * represented by multiple tokens and their byte representations must be
 * combined to generate the correct text representation. Can be null if there
 * is no bytes representation for the token.
 */
@JsonInclude(Include.NON_NULL)
public record TopLogProbs(@JsonProperty("token") String token, @JsonProperty("logprob") Float logprob,
@JsonProperty("bytes") List<Integer> probBytes) {
}
}
}

/**
* Represents a streamed chunk of a chat completion response returned by model, based
* on the provided input.
Expand Down Expand Up @@ -614,13 +662,15 @@ public record ChatCompletionChunk(
* @param index The index of the choice in the list of choices.
* @param delta A chat completion delta generated by streamed model responses.
* @param finishReason The reason the model stopped generating tokens.
* @param logprobs Log probability information for the choice.
*/
@JsonInclude(Include.NON_NULL)
public record ChunkChoice(
// @formatter:off
@JsonProperty("index") Integer index,
@JsonProperty("delta") ChatCompletionMessage delta,
@JsonProperty("finish_reason") ChatCompletionFinishReason finishReason) {
@JsonProperty("finish_reason") ChatCompletionFinishReason finishReason,
@JsonProperty("logprobs") LogProbs logprobs) {
// @formatter:on
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.ChatCompletionFunction;
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.Role;
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.ToolCall;
import org.springframework.ai.mistralai.api.MistralAiApi.LogProbs;
import org.springframework.util.CollectionUtils;

/**
Expand Down Expand Up @@ -83,8 +84,10 @@ private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) {
.toList();

var role = current.delta().role() != null ? current.delta().role() : Role.ASSISTANT;
current = new ChunkChoice(current.index(), new ChatCompletionMessage(current.delta().content(),
role, current.delta().name(), toolCallsWithID), current.finishReason());
current = new ChunkChoice(
current.index(), new ChatCompletionMessage(current.delta().content(), role,
current.delta().name(), toolCallsWithID),
current.finishReason(), current.logprobs());
}
}
return current;
Expand All @@ -95,8 +98,9 @@ private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) {
Integer index = (current.index() != null ? current.index() : previous.index());

ChatCompletionMessage message = merge(previous.delta(), current.delta());
LogProbs logprobs = (current.logprobs() != null ? current.logprobs() : previous.logprobs());

return new ChunkChoice(index, message, finishReason);
return new ChunkChoice(index, message, finishReason, logprobs);
}

private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompletionMessage current) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ public void beforeEach() {
public void mistralAiChatTransientError() {

var choice = new ChatCompletion.Choice(0, new ChatCompletionMessage("Response", Role.ASSISTANT),
ChatCompletionFinishReason.STOP);
ChatCompletionFinishReason.STOP, null);
ChatCompletion expectedChatCompletion = new ChatCompletion("id", "chat.completion", 789l, "model",
List.of(choice), new MistralAiApi.Usage(10, 10, 10));

Expand Down Expand Up @@ -137,7 +137,7 @@ public void mistralAiChatNonTransientError() {
public void mistralAiChatStreamTransientError() {

var choice = new ChatCompletionChunk.ChunkChoice(0, new ChatCompletionMessage("Response", Role.ASSISTANT),
ChatCompletionFinishReason.STOP);
ChatCompletionFinishReason.STOP, null);
ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", "chat.completion.chunk", 789l,
"model", List.of(choice));

Expand Down

0 comments on commit 5f9ecdd

Please sign in to comment.