Skip to content

Commit

Permalink
Merge pull request #37 from matlab-deep-learning/refactor-models
Browse files Browse the repository at this point in the history
Move model capability verification out of openAIChat.m, for maintain…
  • Loading branch information
ccreutzi authored May 21, 2024
2 parents f4f97b9 + 6d34b1f commit 1365049
Show file tree
Hide file tree
Showing 6 changed files with 217 additions and 24 deletions.
12 changes: 12 additions & 0 deletions +llms/+openai/models.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
function models = models
%MODELS - supported OpenAI models
%   models = llms.openai.models returns the string array of OpenAI model
%   names accepted by this package.

% Copyright 2024 The MathWorks, Inc.

% Grouped by model family; concatenation order matches the documented list.
gpt4o      = ["gpt-4o","gpt-4o-2024-05-13"];
gpt4Turbo  = ["gpt-4-turbo","gpt-4-turbo-2024-04-09"];
gpt4       = ["gpt-4","gpt-4-0613"];
gpt35Turbo = ["gpt-3.5-turbo","gpt-3.5-turbo-0125","gpt-3.5-turbo-1106"];

models = [gpt4o, gpt4Turbo, gpt4, gpt35Turbo];
end
13 changes: 13 additions & 0 deletions +llms/+openai/validateMessageSupported.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
function validateMessageSupported(message, model)
%validateMessageSupported - check that message is supported by model
%   validateMessageSupported(MESSAGE, MODEL) throws
%   "llms:invalidContentTypeForModel" when MESSAGE carries image content
%   and MODEL is not one of the image-capable OpenAI models. MESSAGE is a
%   struct with a "content" field; image content arrives as a cell array
%   of content parts where image parts have an "image_url" field.

% Copyright 2024 The MathWorks, Inc.

% only certain models support image input
if iscell(message.content) && any(cellfun(@(x) isfield(x,"image_url"), message.content))
    if ~ismember(model,["gpt-4-turbo","gpt-4-turbo-2024-04-09","gpt-4o","gpt-4o-2024-05-13"])
        error("llms:invalidContentTypeForModel", ...
            llms.utils.errorMessageCatalog.getMessage("llms:invalidContentTypeForModel", "Image content", model));
    end
end
end
16 changes: 16 additions & 0 deletions +llms/+openai/validateResponseFormat.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
function validateResponseFormat(format,model)
%validateResponseFormat - validate requested response format is available for selected model
%   Not all OpenAI models support JSON output: older GPT-4 releases reject
%   it (error), and for the rest a prompt instruction is still required
%   (warning).

% Copyright 2024 The MathWorks, Inc.

% Only the "json" format needs validation.
if format ~= "json"
    return
end

if ismember(model,["gpt-4","gpt-4-0613"])
    error("llms:invalidOptionAndValueForModel", ...
        llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionAndValueForModel", "ResponseFormat", "json", model));
end

warning("llms:warningJsonInstruction", ...
    llms.utils.errorMessageCatalog.getMessage("llms:warningJsonInstruction"))
end
Binary file not shown.
31 changes: 8 additions & 23 deletions openAIChat.m
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,7 @@
arguments
systemPrompt {llms.utils.mustBeTextOrEmpty} = []
nvp.Tools (1,:) {mustBeA(nvp.Tools, "openAIFunction")} = openAIFunction.empty
nvp.ModelName (1,1) string {mustBeMember(nvp.ModelName,[...
"gpt-4o","gpt-4o-2024-05-13",...
"gpt-4-turbo","gpt-4-turbo-2024-04-09",...
"gpt-4","gpt-4-0613", ...
"gpt-3.5-turbo","gpt-3.5-turbo-0125", ...
"gpt-3.5-turbo-1106",...
])} = "gpt-3.5-turbo"
nvp.ModelName (1,1) string {mustBeModel} = "gpt-3.5-turbo"
nvp.Temperature {mustBeValidTemperature} = 1
nvp.TopProbabilityMass {mustBeValidTopP} = 1
nvp.StopSequences {mustBeValidStop} = {}
Expand Down Expand Up @@ -160,16 +154,8 @@
this.StopSequences = nvp.StopSequences;

% ResponseFormat is only supported in the latest models
if nvp.ResponseFormat == "json"
if ismember(this.ModelName,["gpt-4","gpt-4-0613"])
error("llms:invalidOptionAndValueForModel", ...
llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionAndValueForModel", "ResponseFormat", "json", this.ModelName));
else
warning("llms:warningJsonInstruction", ...
llms.utils.errorMessageCatalog.getMessage("llms:warningJsonInstruction"))
end

end
llms.openai.validateResponseFormat(nvp.ResponseFormat, this.ModelName);
this.ResponseFormat = nvp.ResponseFormat;

this.PresencePenalty = nvp.PresencePenalty;
this.FrequencyPenalty = nvp.FrequencyPenalty;
Expand Down Expand Up @@ -219,12 +205,7 @@
messagesStruct = messages.Messages;
end

if iscell(messagesStruct{end}.content) && any(cellfun(@(x) isfield(x,"image_url"), messagesStruct{end}.content))
if ~ismember(this.ModelName,["gpt-4-turbo","gpt-4-turbo-2024-04-09","gpt-4o","gpt-4o-2024-05-13"])
error("llms:invalidContentTypeForModel", ...
llms.utils.errorMessageCatalog.getMessage("llms:invalidContentTypeForModel", "Image content", this.ModelName));
end
end
llms.openai.validateMessageSupported(messagesStruct{end}, this.ModelName);

if ~isempty(this.SystemPrompt)
messagesStruct = horzcat(this.SystemPrompt, messagesStruct);
Expand Down Expand Up @@ -334,3 +315,7 @@ function mustBeIntegerOrEmpty(value)
mustBeInteger(value)
end
end

function mustBeModel(model)
% Validation function: model must be one of the supported OpenAI models.
supportedModels = llms.openai.models;
mustBeMember(model, supportedModels);
end
169 changes: 168 additions & 1 deletion tests/topenAIChat.m
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ function saveEnvVar(testCase)
end

properties(TestParameter)
ValidConstructorInput = iGetValidConstructorInput();
InvalidConstructorInput = iGetInvalidConstructorInput();
InvalidGenerateInput = iGetInvalidGenerateInput();
InvalidValuesSetters = iGetInvalidValuesSetters();
Expand Down Expand Up @@ -65,6 +66,21 @@ function constructChatWithAllNVP(testCase)
testCase.verifyEqual(chat.PresencePenalty, presenceP);
end

function validConstructorCalls(testCase,ValidConstructorInput)
% Construct openAIChat from each valid input set, checking the expected
% warning (or its absence) and the resulting object's property values.
construct = @() openAIChat(ValidConstructorInput.Input{:});
if isempty(ValidConstructorInput.ExpectedWarning)
    chat = testCase.verifyWarningFree(construct);
else
    chat = testCase.verifyWarning(construct, ...
        ValidConstructorInput.ExpectedWarning);
end
expectedProps = ValidConstructorInput.VerifyProperties;
for propName = string(fieldnames(expectedProps)).'
    testCase.verifyEqual(chat.(propName), expectedProps.(propName), ...
        "Property " + propName);
end
end

function verySmallTimeOutErrors(testCase)
chat = openAIChat(TimeOut=0.0001, ApiKey="false-key");

Expand Down Expand Up @@ -126,7 +142,6 @@ function noStopSequencesNoMaxNumTokens(testCase)
end

function createOpenAIChatWithStreamFunc(testCase)

function seen = sf(str)
persistent data;
if isempty(data)
Expand Down Expand Up @@ -275,6 +290,158 @@ function createOpenAIChatWithOpenAIKeyLatestModel(testCase)
"Error", "MATLAB:notGreaterEqual"));
end

function validConstructorInput = iGetValidConstructorInput()
% Test parameter data: each field is one valid openAIChat constructor call,
% the warning (if any) it should raise, and the property values the
% constructed object must have. Each case overrides only the properties
% that differ from the defaults (see iConstructorCase).
% While it is valid to provide the key via an environment variable,
% this test set does not use that, for easier setup.
keyPair = {"ApiKey","this-is-not-a-real-key"};
validConstructorInput = struct( ...
    "JustKey", iConstructorCase(keyPair, '', struct()), ...
    "SystemPrompt", iConstructorCase([{"system prompt"}, keyPair], '', ...
        struct("SystemPrompt", {{struct("role","system","content","system prompt")}})), ...
    "Temperature", iConstructorCase([keyPair, {"Temperature",2}], '', ...
        struct("Temperature", 2)), ...
    "TopProbabilityMass", iConstructorCase([keyPair, {"TopProbabilityMass",0.2}], '', ...
        struct("TopProbabilityMass", 0.2)), ...
    "StopSequences", iConstructorCase([keyPair, {"StopSequences",["foo","bar"]}], '', ...
        struct("StopSequences", {["foo","bar"]})), ...
    "PresencePenalty", iConstructorCase([keyPair, {"PresencePenalty",0.1}], '', ...
        struct("PresencePenalty", 0.1)), ...
    "FrequencyPenalty", iConstructorCase([keyPair, {"FrequencyPenalty",0.1}], '', ...
        struct("FrequencyPenalty", 0.1)), ...
    "TimeOut", iConstructorCase([keyPair, {"TimeOut",0.1}], '', ...
        struct("TimeOut", 0.1)), ...
    "ResponseFormat", iConstructorCase([keyPair, {"ResponseFormat","json"}], ...
        "llms:warningJsonInstruction", ...
        struct("ResponseFormat", "json")) ...
    );
end

function oneCase = iConstructorCase(input, expectedWarning, overrides)
% Build one constructor test case: the default expected property values,
% with the fields named in OVERRIDES replaced.
%   input           - cell array of constructor arguments
%   expectedWarning - warning id to expect, or '' for none
%   overrides       - scalar struct of expected-property overrides
props = struct( ...
    "Temperature", 1, ...
    "TopProbabilityMass", 1, ...
    "StopSequences", {{}}, ...
    "PresencePenalty", 0, ...
    "FrequencyPenalty", 0, ...
    "TimeOut", 10, ...
    "FunctionNames", [], ...
    "ModelName", "gpt-3.5-turbo", ...
    "SystemPrompt", [], ...
    "ResponseFormat", "text");
for propName = string(fieldnames(overrides)).'
    props.(propName) = overrides.(propName);
end
oneCase = struct( ...
    "Input", {input}, ...
    "ExpectedWarning", expectedWarning, ...
    "VerifyProperties", props);
end

function invalidConstructorInput = iGetInvalidConstructorInput()
validFunction = openAIFunction("funName");
invalidConstructorInput = struct( ...
Expand Down

0 comments on commit 1365049

Please sign in to comment.