diff --git a/api/src/main/java/com/launchableinc/openai/completion/chat/ChatCompletionRequest.java b/api/src/main/java/com/launchableinc/openai/completion/chat/ChatCompletionRequest.java index 4bc5f5da..54332ce0 100644 --- a/api/src/main/java/com/launchableinc/openai/completion/chat/ChatCompletionRequest.java +++ b/api/src/main/java/com/launchableinc/openai/completion/chat/ChatCompletionRequest.java @@ -27,6 +27,14 @@ public class ChatCompletionRequest { */ List messages; + /** + * An object specifying the format that the model must output. Compatible with GPT-4 Turbo and all + * GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. + * https://platform.openai.com/docs/api-reference/chat/create#chat-create-response_format + */ + @JsonProperty("response_format") + ChatResponseFormat responseFormat; + /** * What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output * more random, while lower values like 0.2 will make it more focused and deterministic.
We diff --git a/api/src/main/java/com/launchableinc/openai/completion/chat/ChatResponseFormat.java b/api/src/main/java/com/launchableinc/openai/completion/chat/ChatResponseFormat.java new file mode 100644 index 00000000..19048aac --- /dev/null +++ b/api/src/main/java/com/launchableinc/openai/completion/chat/ChatResponseFormat.java @@ -0,0 +1,31 @@ +package com.launchableinc.openai.completion.chat; + +import com.fasterxml.jackson.annotation.JsonValue; +import lombok.Builder; +import lombok.Data; + +/* + * OpenAI API Document:https://platform.openai.com/docs/api-reference/chat/create#chat-create-response_format + */ +@Data +@Builder +public class ChatResponseFormat { + + private ResponseFormat type; + + public enum ResponseFormat { + TEXT("text"), JSON("json_object"); + + private final String value; + + ResponseFormat(String value) { + this.value = value; + } + + @JsonValue + public String getValue() { + return value; + } + } + +} diff --git a/service/src/test/java/com/launchableinc/openai/service/ChatCompletionTest.java b/service/src/test/java/com/launchableinc/openai/service/ChatCompletionTest.java index dc624533..76295293 100644 --- a/service/src/test/java/com/launchableinc/openai/service/ChatCompletionTest.java +++ b/service/src/test/java/com/launchableinc/openai/service/ChatCompletionTest.java @@ -2,9 +2,12 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonPropertyDescription; +import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ObjectNode; import com.launchableinc.openai.completion.chat.*; +import com.launchableinc.openai.completion.chat.ChatResponseFormat.ResponseFormat; import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -77,6 +80,37 @@ void createChatCompletion() { assertEquals(5, choices.size()); } + @Test + void createChatCompletion_with_json_mode() { + final List messages = new ArrayList<>(); + final ChatMessage systemMessage = new ChatMessage(ChatMessageRole.SYSTEM.value(), + "Generate a random name and age json object. name field is a object that has first and last fields. age is a number."); + messages.add(systemMessage); + + ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest + .builder() + .model("gpt-3.5-turbo-1106") + .messages(messages) + .maxTokens(50) + .logitBias(new HashMap<>()) + .responseFormat(ChatResponseFormat.builder().type(ResponseFormat.JSON).build()) + .build(); + + ChatCompletionChoice choices = service.createChatCompletion(chatCompletionRequest) + .getChoices().get(0); + assertTrue(isValidJson(choices.getMessage().getContent())); + } + + private boolean isValidJson(String jsonString) { + ObjectMapper objectMapper = new ObjectMapper(); + try { + objectMapper.readTree(jsonString); + return true; + } catch (JsonProcessingException e) { + return false; + } + } + @Test void streamChatCompletion() { final List messages = new ArrayList<>();