Skip to content

Commit 05c861b

Browse files
authored
Merge pull request #8 from matlab-deep-learning/AzureAPI
Adding support for the Azure API
2 parents 38edd99 + 1ac24ff commit 05c861b

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

50 files changed

+3671
-1397
lines changed

+llms/+azure/apiVersions.m

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
function versions = apiVersions
%APIVERSIONS - supported Azure OpenAI API versions
%
%   VERSIONS = APIVERSIONS returns a string row vector of the Azure
%   OpenAI REST API versions this package supports, newest first.

% Copyright 2024 The MathWorks, Inc.
versions = [...
    "2024-05-01-preview", ...
    "2024-04-01-preview", ...
    "2024-03-01-preview", ...
    "2024-02-01", ...
    "2023-05-15", ...
    ];
end

+llms/+internal/callAzureChatAPI.m

+130
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
function [text, message, response] = callAzureChatAPI(endpoint, deploymentID, messages, functions, nvp)
% This function is undocumented and will change in a future release

%callAzureChatAPI Calls the openAI chat completions API on Azure.
%
%   MESSAGES and FUNCTIONS should be structs matching the json format
%   required by the OpenAI Chat Completions API.
%   Ref: https://platform.openai.com/docs/guides/gpt/chat-completions-api
%
%   More details on the parameters: https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/chatgpt
%
%   Example
%
%   % Create messages struct
%   messages = {struct("role", "system",...
%       "content", "You are a helpful assistant");
%       struct("role", "user", ...
%       "content", "What is the edit distance between hi and hello?")};
%
%   % Create functions struct
%   functions = {struct("name", "editDistance", ...
%       "description", "Find edit distance between two strings or documents.", ...
%       "parameters", struct( ...
%       "type", "object", ...
%       "properties", struct(...
%           "str1", struct(...
%               "description", "Source string.", ...
%               "type", "string"),...
%           "str2", struct(...
%               "description", "Target string.", ...
%               "type", "string")),...
%       "required", ["str1", "str2"]))};
%
%   % Define your API key
%   apiKey = "your-api-key-here"
%
%   % Send a request
%   [text, message] = llms.internal.callAzureChatAPI(messages, functions, APIKey=apiKey)

% Copyright 2023-2024 The MathWorks, Inc.

arguments
    endpoint
    deploymentID
    messages
    functions
    nvp.ToolChoice
    nvp.APIVersion
    nvp.Temperature
    nvp.TopP
    nvp.NumCompletions
    nvp.StopSequences
    nvp.MaxNumTokens
    nvp.PresencePenalty
    nvp.FrequencyPenalty
    nvp.ResponseFormat
    nvp.Seed
    nvp.APIKey
    nvp.TimeOut
    nvp.StreamFun
end

% Azure routes requests per-deployment; the API version is a query parameter.
URL = endpoint + "openai/deployments/" + deploymentID + "/chat/completions?api-version=" + nvp.APIVersion;

parameters = buildParametersCall(messages, functions, nvp);

[response, streamedText] = llms.internal.sendRequest(parameters,nvp.APIKey, URL, nvp.TimeOut, nvp.StreamFun);

% If call errors, "choices" will not be part of response.Body.Data, instead
% we get response.Body.Data.error
if response.StatusCode=="OK"
    % Outputs the first generation
    if isempty(nvp.StreamFun)
        message = response.Body.Data.choices(1).message;
    else
        % When streaming, the accumulated streamed text is the content.
        message = struct("role", "assistant", ...
            "content", streamedText);
    end
    % FIX: the response message carries tool invocations in "tool_calls"
    % ("tool_choice" is a request-side parameter and never appears in the
    % response), so the original check could never detect a tool call.
    if isfield(message, "tool_calls")
        text = "";
    else
        text = string(message.content);
    end
else
    % Error responses carry response.Body.Data.error; return empty outputs
    % and let the caller inspect RESPONSE.
    text = "";
    message = struct();
end
end
89+
90+
function parameters = buildParametersCall(messages, functions, nvp)
% Assemble the request-body struct expected by the Azure chat completions
% API from MESSAGES, FUNCTIONS and the name-value arguments in NVP.

parameters = struct();
parameters.messages = messages;

% A streaming callback implies a streamed response.
parameters.stream = ~isempty(nvp.StreamFun);

if ~isempty(functions)
    parameters.tools = functions;
end

if ~isempty(nvp.ToolChoice)
    parameters.tool_choice = nvp.ToolChoice;
end

if ~isempty(nvp.Seed)
    parameters.seed = nvp.Seed;
end

% Copy the remaining options that were supplied, renaming each one to the
% snake_case field name the API expects.
nameMap = mapNVPToParameters;
for fieldName = keys(nameMap).'
    if isfield(nvp, fieldName)
        parameters.(nameMap(fieldName)) = nvp.(fieldName);
    end
end
end
120+
121+
function dict = mapNVPToParameters()
% Map user-facing name-value option names to the snake_case field names
% used by the chat completions API.
dict = dictionary( ...
    ["Temperature", "TopP", "NumCompletions", "StopSequences", ...
     "MaxNumTokens", "PresencePenalty", "FrequencyPenalty"], ...
    ["temperature", "top_p", "n", "stop", ...
     "max_tokens", "presence_penalty", "frequency_penalty"]);
end

+llms/+internal/callOllamaChatAPI.m

+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
function [text, message, response] = callOllamaChatAPI(model, messages, nvp)
% This function is undocumented and will change in a future release

%callOllamaChatAPI Calls the Ollama® chat completions API.
%
%   MESSAGES should be structs matching the json format required by the
%   Ollama Chat Completions API.
%   Ref: https://github.com/ollama/ollama/blob/main/docs/api.md
%
%   More details on the parameters: https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
%
%   Example
%
%   model = "mistral";
%
%   % Create messages struct
%   messages = {struct("role", "system",...
%       "content", "You are a helpful assistant");
%       struct("role", "user", ...
%       "content", "What is the edit distance between hi and hello?")};
%
%   % Send a request
%   [text, message] = llms.internal.callOllamaChatAPI(model, messages)

% Copyright 2023-2024 The MathWorks, Inc.

arguments
    model
    messages
    nvp.Temperature
    nvp.TopP
    nvp.TopK
    nvp.TailFreeSamplingZ
    nvp.StopSequences
    nvp.MaxNumTokens
    nvp.ResponseFormat
    nvp.Seed
    nvp.TimeOut
    nvp.StreamFun
end

URL = "http://localhost:11434/api/chat";

% The API requires "stop" to be a JSON array, never a bare string.
% Duplicating a scalar guarantees jsonencode always emits an array.
if isscalar(nvp.StopSequences)
    nvp.StopSequences = [nvp.StopSequences, nvp.StopSequences];
end

parameters = buildParametersCall(model, messages, nvp);

[response, streamedText] = llms.internal.sendRequest(parameters,[],URL,nvp.TimeOut,nvp.StreamFun);

% On failure, response.Body.Data carries an "error" field instead of a
% "message" field; return empty outputs and let the caller inspect RESPONSE.
if response.StatusCode ~= "OK"
    text = "";
    message = struct();
    return
end

% Success: take the generated message (or the accumulated streamed text).
if isempty(nvp.StreamFun)
    message = response.Body.Data.message;
else
    message = struct("role", "assistant", ...
        "content", streamedText);
end
text = string(message.content);
end
70+
71+
function parameters = buildParametersCall(model, messages, nvp)
% Assemble the request-body struct expected by the Ollama chat API from
% MODEL, MESSAGES and the name-value arguments in NVP.

parameters = struct();
parameters.model = model;
parameters.messages = messages;

% A streaming callback implies a streamed response.
parameters.stream = ~isempty(nvp.StreamFun);

% Sampling parameters live in a nested "options" object.
options = struct;
if ~isempty(nvp.Seed)
    options.seed = nvp.Seed;
end

% Copy each remaining option that was set to a nonempty, finite value,
% renaming it to the field name the API expects (Inf means "no limit" and
% is simply omitted).
nameMap = mapNVPToParameters;
for fieldName = keys(nameMap).'
    if isfield(nvp, fieldName) && ~isempty(nvp.(fieldName)) && ~isequaln(nvp.(fieldName), Inf)
        options.(nameMap(fieldName)) = nvp.(fieldName);
    end
end

parameters.options = options;
end
97+
98+
function dict = mapNVPToParameters()
% Map user-facing name-value option names to the field names used by the
% Ollama chat API.
dict = dictionary( ...
    ["Temperature", "TopP", "TopK", "TailFreeSamplingZ", ...
     "StopSequences", "MaxNumTokens"], ...
    ["temperature", "top_p", "top_k", "tfs_z", ...
     "stop", "num_predict"]);
end

+llms/+internal/callOpenAIChatAPI.m

+19-32
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,12 @@
11
function [text, message, response] = callOpenAIChatAPI(messages, functions, nvp)
2+
% This function is undocumented and will change in a future release
3+
24
%callOpenAIChatAPI Calls the openAI chat completions API.
35
%
46
% MESSAGES and FUNCTIONS should be structs matching the json format
57
% required by the OpenAI Chat Completions API.
68
% Ref: https://platform.openai.com/docs/guides/gpt/chat-completions-api
79
%
8-
% Currently, the supported NVP are, including the equivalent name in the API:
9-
% - ToolChoice (tool_choice)
10-
% - ModelName (model)
11-
% - Temperature (temperature)
12-
% - TopProbabilityMass (top_p)
13-
% - NumCompletions (n)
14-
% - StopSequences (stop)
15-
% - MaxNumTokens (max_tokens)
16-
% - PresencePenalty (presence_penalty)
17-
% - FrequencyPenalty (frequence_penalty)
18-
% - ResponseFormat (response_format)
19-
% - Seed (seed)
20-
% - ApiKey
21-
% - TimeOut
22-
% - StreamFun
2310
% More details on the parameters: https://platform.openai.com/docs/api-reference/chat/create
2411
%
2512
% Example
@@ -48,34 +35,34 @@
4835
% apiKey = "your-api-key-here"
4936
%
5037
% % Send a request
51-
% [text, message] = llms.internal.callOpenAIChatAPI(messages, functions, ApiKey=apiKey)
38+
% [text, message] = llms.internal.callOpenAIChatAPI(messages, functions, APIKey=apiKey)
5239

5340
% Copyright 2023-2024 The MathWorks, Inc.
5441

5542
arguments
5643
messages
5744
functions
58-
nvp.ToolChoice = []
59-
nvp.ModelName = "gpt-3.5-turbo"
60-
nvp.Temperature = 1
61-
nvp.TopProbabilityMass = 1
62-
nvp.NumCompletions = 1
63-
nvp.StopSequences = []
64-
nvp.MaxNumTokens = inf
65-
nvp.PresencePenalty = 0
66-
nvp.FrequencyPenalty = 0
67-
nvp.ResponseFormat = "text"
68-
nvp.Seed = []
69-
nvp.ApiKey = ""
70-
nvp.TimeOut = 10
71-
nvp.StreamFun = []
45+
nvp.ToolChoice
46+
nvp.ModelName
47+
nvp.Temperature
48+
nvp.TopP
49+
nvp.NumCompletions
50+
nvp.StopSequences
51+
nvp.MaxNumTokens
52+
nvp.PresencePenalty
53+
nvp.FrequencyPenalty
54+
nvp.ResponseFormat
55+
nvp.Seed
56+
nvp.APIKey
57+
nvp.TimeOut
58+
nvp.StreamFun
7259
end
7360

7461
END_POINT = "https://api.openai.com/v1/chat/completions";
7562

7663
parameters = buildParametersCall(messages, functions, nvp);
7764

78-
[response, streamedText] = llms.internal.sendRequest(parameters,nvp.ApiKey, END_POINT, nvp.TimeOut, nvp.StreamFun);
65+
[response, streamedText] = llms.internal.sendRequest(parameters,nvp.APIKey, END_POINT, nvp.TimeOut, nvp.StreamFun);
7966

8067
% If call errors, "choices" will not be part of response.Body.Data, instead
8168
% we get response.Body.Data.error
@@ -160,7 +147,7 @@
160147
function dict = mapNVPToParameters()
161148
dict = dictionary();
162149
dict("Temperature") = "temperature";
163-
dict("TopProbabilityMass") = "top_p";
150+
dict("TopP") = "top_p";
164151
dict("NumCompletions") = "n";
165152
dict("StopSequences") = "stop";
166153
dict("MaxNumTokens") = "max_tokens";
+13-13
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
1-
function key = getApiKeyFromNvpOrEnv(nvp)
1+
function key = getApiKeyFromNvpOrEnv(nvp,envVarName)
22
% This function is undocumented and will change in a future release
33

44
%getApiKeyFromNvpOrEnv Retrieves an API key from a Name-Value Pair struct or environment variable.
55
%
6-
% This function takes a struct nvp containing name-value pairs and checks
7-
% if it contains a field called "ApiKey". If the field is not found,
8-
% the function attempts to retrieve the API key from an environment
9-
% variable called "OPENAI_API_KEY". If both methods fail, the function
10-
% throws an error.
6+
% This function takes a struct nvp containing name-value pairs and checks if
7+
% it contains a field called "APIKey". If the field is not found, the
8+
% function attempts to retrieve the API key from an environment variable
9+
% whose name is given as the second argument. If both methods fail, the
10+
% function throws an error.
1111

12-
% Copyright 2023 The MathWorks, Inc.
12+
% Copyright 2023-2024 The MathWorks, Inc.
1313

14-
if isfield(nvp, "ApiKey")
15-
key = nvp.ApiKey;
14+
if isfield(nvp, "APIKey")
15+
key = nvp.APIKey;
1616
else
17-
if isenv("OPENAI_API_KEY")
18-
key = getenv("OPENAI_API_KEY");
17+
if isenv(envVarName)
18+
key = getenv(envVarName);
1919
else
20-
error("llms:keyMustBeSpecified", llms.utils.errorMessageCatalog.getMessage("llms:keyMustBeSpecified"));
20+
error("llms:keyMustBeSpecified", llms.utils.errorMessageCatalog.getMessage("llms:keyMustBeSpecified", envVarName));
2121
end
2222
end
23-
end
23+
end

+llms/+internal/gptPenalties.m

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
classdef (Abstract) gptPenalties
% This class is undocumented and will change in a future release

% gptPenalties - Abstract base class declaring the two GPT penalty
% properties, so concrete chat classes can share one definition and one
% validator (llms.utils.mustBeValidPenalty). Both default to 0.

% Copyright 2024 The MathWorks, Inc.
    properties
        %PRESENCEPENALTY Penalty for using a token in the response that has already been used.
        PresencePenalty {llms.utils.mustBeValidPenalty} = 0

        %FREQUENCYPENALTY Penalty for using a token that is frequent in the training data.
        FrequencyPenalty {llms.utils.mustBeValidPenalty} = 0
    end
end

0 commit comments

Comments
 (0)