Skip to content

Commit

Permalink
Adding support to Azure API
Browse files Browse the repository at this point in the history
  • Loading branch information
Angel Vega Alvarez committed Feb 9, 2024
1 parent a038630 commit a32e681
Show file tree
Hide file tree
Showing 11 changed files with 919 additions and 109 deletions.
138 changes: 138 additions & 0 deletions +llms/+internal/callAzureChatAPI.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
function [text, message, response] = callAzureChatAPI(resourceName, deploymentID, messages, functions, nvp)
%callAzureChatAPI Calls the Azure OpenAI chat completions API.
%
%   MESSAGES and FUNCTIONS should be structs matching the json format
%   required by the OpenAI Chat Completions API.
%   Ref: https://platform.openai.com/docs/guides/gpt/chat-completions-api
%
%   Currently, the supported NVP are, including the equivalent name in the API:
%    - ToolChoice (tool_choice)
%    - APIVersion (api-version, part of the request URL)
%    - Temperature (temperature)
%    - TopProbabilityMass (top_p)
%    - NumCompletions (n)
%    - StopSequences (stop)
%    - MaxNumTokens (max_tokens)
%    - PresencePenalty (presence_penalty)
%    - FrequencyPenalty (frequency_penalty)
%    - ResponseFormat (response_format)
%    - Seed (seed)
%    - ApiKey
%    - TimeOut
%    - StreamFun
%   More details on the parameters: https://platform.openai.com/docs/api-reference/chat/create
%
%   Example
%
%   % Create messages struct
%   messages = {struct("role", "system",...
%       "content", "You are a helpful assistant");
%       struct("role", "user", ...
%       "content", "What is the edit distance between hi and hello?")};
%
%   % Create functions struct
%   functions = {struct("name", "editDistance", ...
%       "description", "Find edit distance between two strings or documents.", ...
%       "parameters", struct( ...
%       "type", "object", ...
%       "properties", struct(...
%           "str1", struct(...
%               "description", "Source string.", ...
%               "type", "string"),...
%           "str2", struct(...
%               "description", "Target string.", ...
%               "type", "string")),...
%       "required", ["str1", "str2"]))};
%
%   % Define your API key
%   apiKey = "your-api-key-here"
%
%   % Send a request
%   [text, message] = llms.internal.callAzureChatAPI("my-resource", "my-deployment", messages, functions, ApiKey=apiKey)

%   Copyright 2023-2024 The MathWorks, Inc.

arguments
    resourceName
    deploymentID
    messages
    functions
    nvp.ToolChoice = []
    nvp.APIVersion = "2023-05-15"
    nvp.Temperature = 1
    nvp.TopProbabilityMass = 1
    nvp.NumCompletions = 1
    nvp.StopSequences = []
    nvp.MaxNumTokens = inf
    nvp.PresencePenalty = 0
    nvp.FrequencyPenalty = 0
    nvp.ResponseFormat = "text"
    nvp.Seed = []
    nvp.ApiKey = ""
    nvp.TimeOut = 10
    nvp.StreamFun = []
end

% Azure routes requests per resource and deployment; the model is implied
% by the deployment, so no model name is sent in the body.
END_POINT = "https://" + resourceName + ".openai.azure.com/openai/deployments/" + deploymentID + "/chat/completions?api-version=" + nvp.APIVersion;

parameters = buildParametersCall(messages, functions, nvp);

[response, streamedText] = llms.internal.sendRequest(parameters,nvp.ApiKey, END_POINT, nvp.TimeOut, nvp.StreamFun);

% If call errors, "choices" will not be part of response.Body.Data, instead
% we get response.Body.Data.error
if response.StatusCode=="OK"
    % Outputs the first generation
    if isempty(nvp.StreamFun)
        message = response.Body.Data.choices(1).message;
    else
        message = struct("role", "assistant", ...
            "content", streamedText);
    end
    % When the model responds with a tool call, the response message
    % carries a "tool_calls" field and "content" may be empty/missing.
    % ("tool_choice" is a request-side parameter, never a response field.)
    if isfield(message, "tool_calls")
        text = "";
    else
        text = string(message.content);
    end
else
    text = "";
    message = struct();
end
end

function parameters = buildParametersCall(messages, functions, nvp)
% Builds a struct in the format that is expected by the API, combining
% MESSAGES, FUNCTIONS and parameters in NVP.

parameters = struct();
parameters.messages = messages;

% Streaming is enabled whenever a streaming callback was provided.
parameters.stream = ~isempty(nvp.StreamFun);

% Only include tool-related fields when they are actually supplied;
% sending empty "tools"/"tool_choice" values is rejected by the API
% (same pattern as the "seed" guard below).
if ~isempty(functions)
    parameters.tools = functions;
end

if ~isempty(nvp.ToolChoice)
    parameters.tool_choice = nvp.ToolChoice;
end

if ~isempty(nvp.Seed)
    parameters.seed = nvp.Seed;
end

% Translate the remaining NVP names to their API field names.
dict = mapNVPToParameters;

nvpOptions = keys(dict);
for opt = nvpOptions.'
    if isfield(nvp, opt)
        parameters.(dict(opt)) = nvp.(opt);
    end
end
end

function dict = mapNVPToParameters()
% Returns a dictionary translating name-value pair argument names to the
% corresponding field names of the chat completions API request body.
nvpNames = ["Temperature"; "TopProbabilityMass"; "NumCompletions"; ...
            "StopSequences"; "MaxNumTokens"; "PresencePenalty"; ...
            "FrequencyPenalty"];
apiNames = ["temperature"; "top_p"; "n"; ...
            "stop"; "max_tokens"; "presence_penalty"; ...
            "frequency_penalty"];
dict = dictionary(nvpNames, apiNames);
end
90 changes: 90 additions & 0 deletions +llms/+internal/textGenerator.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
classdef (Abstract) textGenerator
%textGenerator Abstract base class holding the generation parameters
%   shared by chat objects, with property setters that delegate
%   validation to the llms.utils.mustBeValid* helpers.

    properties
        %TEMPERATURE Temperature of generation.
        Temperature

        %TOPPROBABILITYMASS Top probability mass to consider for generation (nucleus sampling).
        TopProbabilityMass

        %STOPSEQUENCES Sequences to stop the generation of tokens.
        StopSequences

        %PRESENCEPENALTY Penalty for using a token in the response that has already been used.
        PresencePenalty

        %FREQUENCYPENALTY Penalty for using a token that is frequent in the training data.
        FrequencyPenalty
    end

    properties (SetAccess=protected)
        %TIMEOUT Connection timeout in seconds (default 10 secs)
        TimeOut

        %FUNCTIONNAMES Names of the functions that the model can request to call.
        FunctionNames

        %SYSTEMPROMPT System prompt.
        SystemPrompt = []

        %RESPONSEFORMAT Response format, "text" or "json"
        ResponseFormat
    end

    properties (Access=protected)
        % Internal state set by concrete subclasses; semantics defined
        % where they are assigned -- not visible in this file.
        Tools
        FunctionsStruct
        ApiKey
        StreamFun
    end


    methods
        % Each setter validates via the corresponding llms.utils helper
        % before assigning, so invalid values never reach the property.
        function this = set.Temperature(this, temperature)
            arguments
                this
                temperature
            end
            llms.utils.mustBeValidTemperature(temperature);
            this.Temperature = temperature;
        end

        function this = set.TopProbabilityMass(this,topP)
            arguments
                this
                topP
            end
            llms.utils.mustBeValidTopP(topP);
            this.TopProbabilityMass = topP;
        end

        function this = set.StopSequences(this,stop)
            arguments
                this
                stop
            end
            llms.utils.mustBeValidStop(stop);
            this.StopSequences = stop;
        end

        function this = set.PresencePenalty(this,penalty)
            arguments
                this
                penalty
            end
            llms.utils.mustBeValidPenalty(penalty)
            this.PresencePenalty = penalty;
        end

        function this = set.FrequencyPenalty(this,penalty)
            arguments
                this
                penalty
            end
            llms.utils.mustBeValidPenalty(penalty)
            this.FrequencyPenalty = penalty;
        end

    end

end
3 changes: 3 additions & 0 deletions +llms/+utils/errorMessageCatalog.m
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,7 @@
catalog("llms:promptLimitCharacter") = "Prompt must have a maximum length of {1} characters for ModelName '{2}'";
catalog("llms:pngExpected") = "Argument must be a PNG image.";
catalog("llms:warningJsonInstruction") = "When using JSON mode, you must also prompt the model to produce JSON yourself via a system or user message.";
catalog("llms:invalidOptionsForOpenAIBackEnd") = "The parameters Resource Name, Deployment ID and API Version are not compatible with OpenAI.";
catalog("llms:invalidOptionsForAzureBackEnd") = "The parameter Model Name is not compatible with Azure.";

end
3 changes: 3 additions & 0 deletions +llms/+utils/mustBeValidPenalty.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
function mustBeValidPenalty(value)
% mustBeValidPenalty Errors unless VALUE is a real, nonsparse numeric scalar in [-2, 2].
% Range matches the presence/frequency penalty bounds accepted by the
% chat completions API -- TODO confirm against current API docs.
validateattributes(value, {'numeric'}, {'real', 'scalar', 'nonsparse', '<=', 2, '>=', -2})
end
10 changes: 10 additions & 0 deletions +llms/+utils/mustBeValidStop.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
function mustBeValidStop(value)
% mustBeValidStop Errors unless VALUE is empty, or is a vector of
% nonzero-length text with at most 4 elements.

% Empty means "no stop sequences" and is always acceptable.
if isempty(value)
    return
end

mustBeVector(value);
mustBeNonzeroLengthText(value);

% This restriction is set by the OpenAI API
if numel(value) > 4
    error("llms:stopSequencesMustHaveMax4Elements", ...
        llms.utils.errorMessageCatalog.getMessage("llms:stopSequencesMustHaveMax4Elements"));
end
end
3 changes: 3 additions & 0 deletions +llms/+utils/mustBeValidTemperature.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
function mustBeValidTemperature(value)
% mustBeValidTemperature Errors unless VALUE is a real, nonsparse,
% nonnegative numeric scalar no greater than 2 (the API's temperature range).
validateattributes(value, {'numeric'}, {'real', 'scalar', 'nonnegative', 'nonsparse', '<=', 2})
end
3 changes: 3 additions & 0 deletions +llms/+utils/mustBeValidTopP.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
function mustBeValidTopP(value)
% mustBeValidTopP Errors unless VALUE is a real, nonsparse, nonnegative
% numeric scalar no greater than 1 (top_p is a probability mass).
validateattributes(value, {'numeric'}, {'real', 'scalar', 'nonnegative', 'nonsparse', '<=', 1})
end
29 changes: 29 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,35 @@ messages = addUserMessageWithImages(messages,"What is in the image?",image_path)
% Should output the description of the image
```
## Establishing a connection to Chat Completions API using Azure®
If you would like to connect MATLAB to Chat Completions API via Azure® instead of directly with OpenAI, you will have to create an `azureChat` object.
In addition to your Azure API key, you will also need the name of your Azure OpenAI Resource before you can connect.
In order to create the chat assistant, you must specify your Azure OpenAI Resource and the LLM you want to use:
```matlab
chat = azureChat(YOUR_RESOURCE_NAME, YOUR_DEPLOYMENT_NAME, "You are a helpful AI assistant");
```
The `azureChat` object also allows you to specify additional options in the same way as the `openAIChat` object.
However, the `ModelName` option is not available, because the LLM is already determined by the deployment you specify when creating the chat assistant.
On the other hand, the `azureChat` object offers an additional `APIVersion` option that lets you choose which version of the Azure OpenAI API to use.
After establishing your connection with Azure, you can continue using the `generate` function and other objects in the same way as if you had established a connection directly with OpenAI:
```matlab
% Initialize the Azure Chat object, passing a system prompt and specifying the API version
chat = azureChat(YOUR_RESOURCE_NAME, YOUR_DEPLOYMENT_NAME, "You are a helpful AI assistant", APIVersion="2023-12-01-preview");
% Create an openAIMessages object to start the conversation history
history = openAIMessages;
% Ask your question and store it in the history, create the response using the generate method, and store the response in the history
history = addUserMessage(history,"What is an eigenvalue?");
[txt, response] = generate(chat, history)
history = addResponseMessage(history, response);
```
### Obtaining embeddings
You can extract embeddings from your text with OpenAI using the function `extractOpenAIEmbeddings` as follows:
Expand Down
Loading

0 comments on commit a32e681

Please sign in to comment.