Skip to content

Commit

Permalink
Fix Geminie GenerativeModel handling between calls
Browse files Browse the repository at this point in the history
 Resolves #560
  • Loading branch information
tzolov committed Apr 7, 2024
1 parent 2b421a4 commit e268975
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,6 @@ public class VertexAiGeminiChatClient

private final GenerationConfig generationConfig;

private GenerativeModel generativeModel;

public enum GeminiMessageType {

USER("user"),
Expand Down Expand Up @@ -140,7 +138,6 @@ public VertexAiGeminiChatClient(VertexAI vertexAI, VertexAiGeminiChatOptions opt
this.vertexAI = vertexAI;
this.defaultOptions = options;
this.generationConfig = toGenerationConfig(options);
this.generativeModel = new GenerativeModel(options.getModel(), vertexAI);
}

// https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini
Expand Down Expand Up @@ -204,7 +201,8 @@ private GeminiRequest createGeminiRequest(Prompt prompt) {
Set<String> functionsForThisRequest = new HashSet<>();

GenerationConfig generationConfig = this.generationConfig;
GenerativeModel generativeModel = this.generativeModel;

GenerativeModel generativeModel = new GenerativeModel(this.defaultOptions.getModel(), this.vertexAI);

VertexAiGeminiChatOptions updatedRuntimeOptions = null;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ public void functionCallTestInferredOpenApiSchema() {
public void functionCallTestInferredOpenApiSchemaStream() {

UserMessage userMessage = new UserMessage(
"What's the weather like in San Francisco, in Paris and in Tokyo? Use Multi-turn function calling.");
"What's the weather like in San Francisco, in Paris and in Tokyo, Japan? Use Multi-turn function calling. Provide answer for all requested locations.");

List<Message> messages = new ArrayList<>(List.of(userMessage));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,13 @@ void functionCallTest() {

assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15");

response = chatClient
.call(new Prompt(List.of(systemMessage, userMessage), VertexAiGeminiChatOptions.builder().build()));

logger.info("Response: {}", response);

assertThat(response.getResult().getOutput().getContent()).doesNotContain("30", "10", "15");

});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,15 @@ void functionCallTest() {
logger.info("Response: {}", response);

assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15");

// Verify that no function call is made.
response = chatClient
.call(new Prompt(List.of(systemMessage, userMessage), VertexAiGeminiChatOptions.builder().build()));

logger.info("Response: {}", response);

assertThat(response.getResult().getOutput().getContent()).doesNotContain("30", "10", "15");

});
}

Expand Down

0 comments on commit e268975

Please sign in to comment.