Skip to content

Commit 273a6af

Browse files
committed
Addressing code review feedback.
1 parent a5c3f10 commit 273a6af

9 files changed

+161
-133
lines changed

+llms/+internal/callOpenAIChatAPI.m

+88-58
Original file line numberDiff line numberDiff line change
@@ -3,93 +3,123 @@
33

44
%callOpenAIChatAPI Calls the openAI chat completions API.
55
%
6-
% MESSAGES and FUNCTIONS should be structs matching the json format
6+
% MESSAGES and FUNCTIONS should be structs matching the json format
77
% required by the OpenAI Chat Completions API.
88
% Ref: https://platform.openai.com/docs/guides/gpt/chat-completions-api
99
%
10-
% Currently, the supported NVP are, including the equivalent name in the API:
11-
% - FunctionCall (function_call)
10+
% Currently, the supported NVP are, including the equivalent name in the API:
11+
% - FunctionCall (function_call)
1212
% - ModelName (model)
1313
% - Temperature (temperature)
14-
% - TopProbabilityMass (top_p)
14+
% - TopProbabilityMass (top_p)
1515
% - NumCompletions (n)
16-
% - StopSequences (stop)
17-
% - MaxNumTokens (max_tokens)
18-
% - PresencePenalty (presence_penalty)
19-
% - FrequencyPenalty (frequence_penalty)
20-
% - ApiKey
16+
% - StopSequences (stop)
17+
% - MaxNumTokens (max_tokens)
18+
% - PresencePenalty (presence_penalty)
19+
% - FrequencyPenalty (frequence_penalty)
20+
% - ApiKey
2121
% More details on the parameters: https://platform.openai.com/docs/api-reference/chat/create
22+
%
23+
% Example
24+
%
25+
% % Create messages struct
26+
% messages = {struct("role", "system",...
27+
% "content", "You are a helpful assistant");
28+
% struct("role", "user", ...
29+
% "content", "What is the edit distance between hi and hello?")};
30+
%
31+
% % Create functions struct
32+
% functions = {struct("name", "editDistance", ...
33+
% "description", "Find edit distance between two strings or documents.", ...
34+
% "parameters", struct( ...
35+
% "type", "object", ...
36+
% "properties", struct(...
37+
% "str1", struct(...
38+
% "description", "Source string.", ...
39+
% "type", "string"),...
40+
% "str2", struct(...
41+
% "description", "Target string.", ...
42+
% "type", "string")),...
43+
% "required", ["str1", "str2"]))};
44+
%
45+
% % Define your API key
46+
% apiKey = "your-api-key-here"
47+
%
48+
% % Send a request
49+
% [text, message] = llms.internal.callOpenAIChatAPI(messages, functions, ApiKey=apiKey)
2250

2351
% Copyright 2023 The MathWorks, Inc.
2452

2553
arguments
2654
messages
2755
functions
28-
nvp.FunctionCall
29-
nvp.ModelName
30-
nvp.Temperature
31-
nvp.TopProbabilityMass
32-
nvp.NumCompletions
33-
nvp.StopSequences
34-
nvp.MaxNumTokens
35-
nvp.PresencePenalty
36-
nvp.FrequencyPenalty
37-
nvp.ApiKey
56+
nvp.FunctionCall = []
57+
nvp.ModelName = "gpt-3.5-turbo"
58+
nvp.Temperature = 1
59+
nvp.TopProbabilityMass = 1
60+
nvp.NumCompletions = 1
61+
nvp.StopSequences = []
62+
nvp.MaxNumTokens = inf
63+
nvp.PresencePenalty = 0
64+
nvp.FrequencyPenalty = 0
65+
nvp.ApiKey = ""
3866
end
3967

40-
END_POINT = "https://api.openai.com/v1/chat/completions";
68+
END_POINT = "https://api.openai.com/v1/chat/completions";
4169

42-
parameters = buildParametersCall(messages, functions, nvp);
70+
parameters = buildParametersCall(messages, functions, nvp);
4371

44-
response = llms.internal.sendRequest(parameters,nvp.ApiKey, END_POINT);
72+
response = llms.internal.sendRequest(parameters,nvp.ApiKey, END_POINT);
4573

46-
if response.StatusCode=="OK"
47-
message = response.Body.Data.choices(1).message;
48-
if isfield(message, "function_call")
49-
text = "";
50-
message.function_call.arguments = message.function_call.arguments;
51-
else
52-
text = string(message.content);
53-
end
54-
else
74+
% If call errors, "choices" will not be part of response.Body.Data, instead
75+
% we get response.Body.Data.error
76+
if response.StatusCode=="OK"
77+
% Outputs the first generation
78+
message = response.Body.Data.choices(1).message;
79+
if isfield(message, "function_call")
5580
text = "";
56-
message = struct();
81+
else
82+
text = string(message.content);
5783
end
84+
else
85+
text = "";
86+
message = struct();
87+
end
5888
end
5989

6090
function parameters = buildParametersCall(messages, functions, nvp)
61-
% Builds a struct in the format that is expected by the API, combining
62-
% MESSAGES, FUNCTIONS and parameters in NVP.
91+
% Builds a struct in the format that is expected by the API, combining
92+
% MESSAGES, FUNCTIONS and parameters in NVP.
6393

64-
parameters = struct();
65-
parameters.messages = messages;
66-
if ~isempty(functions)
67-
parameters.functions = functions;
68-
end
94+
parameters = struct();
95+
parameters.messages = messages;
96+
if ~isempty(functions)
97+
parameters.functions = functions;
98+
end
6999

70-
if ~isempty(nvp.FunctionCall)
71-
parameters.function_call = nvp.FunctionCall;
72-
end
100+
if ~isempty(nvp.FunctionCall)
101+
parameters.function_call = nvp.FunctionCall;
102+
end
103+
104+
parameters.model = nvp.ModelName;
73105

74-
parameters.model = nvp.ModelName;
106+
dict = mapNVPToParameters;
75107

76-
dict = mapNVPToParameters;
77-
78-
nvpOptions = keys(dict);
79-
for i=1:length(nvpOptions)
80-
if isfield(nvp, nvpOptions(i))
81-
parameters.(dict(nvpOptions(i))) = nvp.(nvpOptions(i));
82-
end
108+
nvpOptions = keys(dict);
109+
for i=1:length(nvpOptions)
110+
if isfield(nvp, nvpOptions(i))
111+
parameters.(dict(nvpOptions(i))) = nvp.(nvpOptions(i));
83112
end
84113
end
114+
end
85115

86116
function dict = mapNVPToParameters()
87-
dict = dictionary();
88-
dict("Temperature") = "temperature";
89-
dict("TopProbabilityMass") = "top_p";
90-
dict("NumCompletions") = "n";
91-
dict("StopSequences") = "stop";
92-
dict("MaxNumTokens") = "max_tokens";
93-
dict("PresencePenalty") = "presence_penalty";
94-
dict("FrequencyPenalty ") = "frequency_penalty";
117+
dict = dictionary();
118+
dict("Temperature") = "temperature";
119+
dict("TopProbabilityMass") = "top_p";
120+
dict("NumCompletions") = "n";
121+
dict("StopSequences") = "stop";
122+
dict("MaxNumTokens") = "max_tokens";
123+
dict("PresencePenalty") = "presence_penalty";
124+
dict("FrequencyPenalty ") = "frequency_penalty";
95125
end

+llms/+internal/checkEnvOrNVP.m renamed to +llms/+internal/getApiKeyFromNvpOrEnv.m

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
function key = checkEnvOrNVP(nvp)
1+
function key = getApiKeyFromNvpOrEnv(nvp)
22
% This function is undocumented and will change in a future release
33

4-
%checkEnvOrNVP Retrieves an API key from a Name-Value Pair struct or environment variable.
4+
%getApiKeyFromNvpOrEnv Retrieves an API key from a Name-Value Pair struct or environment variable.
55
%
66
% This function takes a struct nvp containing name-value pairs and checks
77
% if it contains a field called "ApiKey". If the field is not found,

+llms/+utils/errorMessageCatalog.m

+7-7
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,14 @@
4040
catalog("llms:mustBeVarName") = "Parameter name must begin with a letter and contain not more than 'namelengthmax' characters.";
4141
catalog("llms:parameterMustBeUnique") = "A parameter name equivalent to '{1}' already exists in Parameters. Redefining a parameter is not allowed.";
4242
catalog("llms:mustBeAssistantCall") = "Input struct must contain field 'role' with value 'assistant', and field 'content'.";
43-
catalog("llms:mustBeAssistantWithContent") = "Input struct must contain field 'content' containing text with one or more characters";
44-
catalog("llms:mustBeAssistantWithNameAndArguments") = "Field 'function_call' must be a struct with fields 'name' and 'arguments'";
45-
catalog("llms:assistantMustHaveTextNameAndArguments") = "Fields 'name' and 'arguments' must be text with one or more characters";
46-
catalog("llms:mustBeValidIndex") = "Value is larger than the number of elements in Messages ({1})";
43+
catalog("llms:mustBeAssistantWithContent") = "Input struct must contain field 'content' containing text with one or more characters.";
44+
catalog("llms:mustBeAssistantWithNameAndArguments") = "Field 'function_call' must be a struct with fields 'name' and 'arguments'.";
45+
catalog("llms:assistantMustHaveTextNameAndArguments") = "Fields 'name' and 'arguments' must be text with one or more characters.";
46+
catalog("llms:mustBeValidIndex") = "Value is larger than the number of elements in Messages ({1}).";
4747
catalog("llms:stopSequencesMustHaveMax4Elements") = "Number of elements must not be larger than 4.";
48-
catalog("llms:keyMustBeSpecified") = "API key not found as enviroment variable OPEN_API_KEY and not specified via ApiKey parameter.";
49-
catalog("llms:mustHaveMessages") = "Value must contain at least one message in Messages";
48+
catalog("llms:keyMustBeSpecified") = "API key not found as environment variable OPEN_API_KEY and not specified via ApiKey parameter.";
49+
catalog("llms:mustHaveMessages") = "Value must contain at least one message in Messages.";
5050
catalog("llms:mustSetFunctionsForCall") = "When no functions are defined, FunctionCall must not be specified.";
51-
catalog("llms:mustBeMessagesOrTxt") = "Messages must be text with one or more characters or an openAIMessages objects";
51+
catalog("llms:mustBeMessagesOrTxt") = "Messages must be text with one or more characters or an openAIMessages objects.";
5252
end
5353

examples/ExampleAgentCreation.mlx

2.59 KB
Binary file not shown.

license.txt

+1-5
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,4 @@ Redistribution and use in source and binary forms, with or without modification,
44
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
55
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
66
3. In all cases, the software is, and all modifications and derivatives of the software shall be, licensed to you solely for use in conjunction with MathWorks products and service offerings.
7-
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
8-
9-
10-
11-
7+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

openAIChat.m

+26-29
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
11
classdef(Sealed) openAIChat
2-
%openAIChat - Chat completion API from OpenAI.
2+
%openAIChat Chat completion API from OpenAI.
33
%
44
% CHAT = openAIChat(systemPrompt) creates an openAIChat object with the
55
% specified system prompt.
66
%
77
% CHAT = openAIChat(systemPrompt, Name=Value) specifies additional options
88
% using one or more name-value arguments:
99
%
10-
% 'Functions' - An array of openAIFunction objects representing
10+
% 'Functions' - An array of openAIFunction objects representing
1111
% custom functions to be used during chat completions.
1212
%
1313
% 'ModelName' - The name of the model to use for chat completions.
14-
% The default value is "gpt-3.5-turbo"
14+
% The default value is "gpt-3.5-turbo".
1515
%
1616
% 'Temperature' - The temperature value for controlling the randomness
17-
% of the output. The default value is 1
17+
% of the output. The default value is 1.
1818
%
1919
% 'TopProbabilityMass' - The top probability mass value for controlling the
20-
% diversity of the output. The default value is 1
20+
% diversity of the output. The default value is 1.
2121
%
2222
% 'StopSequences' - Vector of strings that when encountered, will
2323
% stop the generation of tokens. The default
@@ -36,11 +36,11 @@
3636
% generate - Generate a response using the openAIChat instance.
3737
%
3838
% openAIChat Properties:
39-
% ModelName - Model name
39+
% ModelName - Model name.
4040
%
41-
% Temperature - Temperature of generation
41+
% Temperature - Temperature of generation.
4242
%
43-
% TopProbabilityMass - Top probability mass to consider for generation
43+
% TopProbabilityMass - Top probability mass to consider for generation.
4444
%
4545
% StopSequences - Sequences to stop the generation of tokens.
4646
%
@@ -50,12 +50,12 @@
5050
% FrequencyPenalty - Penalty for using a token that is
5151
% frequent in the training data.
5252
%
53-
% SystemPrompt - System prompt
53+
% SystemPrompt - System prompt.
5454
%
5555
% AvailableModels - List of available models.
5656
%
57-
% FunctionsNames - Names of the functions that the model can
58-
% request calls
57+
% FunctionNames - Names of the functions that the model can
58+
% request calls.
5959

6060
% Copyright 2023 The MathWorks, Inc.
6161

@@ -83,8 +83,8 @@
8383
end
8484

8585
properties(SetAccess=private)
86-
%FUNCTIONSNAMES Names of the functions that the model can request calls
87-
FunctionsNames
86+
%FUNCTIONNAMES Names of the functions that the model can request calls
87+
FunctionNames
8888
end
8989

9090
properties(Access=private)
@@ -105,7 +105,9 @@
105105
arguments
106106
systemPrompt {llms.utils.mustBeTextOrEmpty} = []
107107
nvp.Functions (1,:) {mustBeA(nvp.Functions, "openAIFunction")} = openAIFunction.empty
108-
nvp.ModelName (1,1) {mustBeValidModelName} = "gpt-3.5-turbo"
108+
nvp.ModelName (1,1) {mustBeMember(nvp.ModelName,["gpt-4", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613",...
109+
"gpt-3.5-turbo", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k",...
110+
"gpt-3.5-turbo-16k-0613"])} = "gpt-3.5-turbo"
109111
nvp.Temperature (1,1) {mustBeValidTemperature} = 1
110112
nvp.TopProbabilityMass (1,1) {mustBeValidTopP} = 1
111113
nvp.StopSequences (1,:) {mustBeValidStop} = {}
@@ -116,11 +118,11 @@
116118

117119
if ~isempty(nvp.Functions)
118120
this.Functions = nvp.Functions;
119-
[this.FunctionsStruct, this.FunctionsNames] = functionAsStruct(nvp.Functions);
121+
[this.FunctionsStruct, this.FunctionNames] = functionAsStruct(nvp.Functions);
120122
else
121123
this.Functions = [];
122124
this.FunctionsStruct = [];
123-
this.FunctionsNames = [];
125+
this.FunctionNames = [];
124126
end
125127

126128
if ~isempty(systemPrompt)
@@ -136,7 +138,7 @@
136138
this.StopSequences = nvp.StopSequences;
137139
this.PresencePenalty = nvp.PresencePenalty;
138140
this.FrequencyPenalty = nvp.FrequencyPenalty;
139-
this.ApiKey = llms.internal.checkEnvOrNVP(nvp);
141+
this.ApiKey = llms.internal.getApiKeyFromNvpOrEnv(nvp);
140142
end
141143

142144
function [text, message, response] = generate(this, messages, nvp)
@@ -146,7 +148,7 @@
146148
% with the specified MESSAGES and optional
147149
% name-value pair arguments.
148150
%
149-
% [TEXT, MESSAGE, RESPONSE] = generate(_______, Name=Value) specifies additional options
151+
% [TEXT, MESSAGE, RESPONSE] = generate(__, Name=Value) specifies additional options
150152
% using one or more name-value arguments:
151153
%
152154
% NumCompletions - Number of completions to generate.
@@ -230,17 +232,17 @@
230232
function mustBeValidFunctionCall(this, functionCall)
231233
if ~isempty(functionCall)
232234
mustBeTextScalar(functionCall);
233-
if isempty(this.FunctionsNames)
235+
if isempty(this.FunctionNames)
234236
error("llms:mustSetFunctionsForCall", llms.utils.errorMessageCatalog.getMessage("llms:mustSetFunctionsForCall"));
235237
end
236-
mustBeMember(functionCall, ["none","auto", this.FunctionsNames]);
238+
mustBeMember(functionCall, ["none","auto", this.FunctionNames]);
237239
end
238240
end
239241

240242
function functionCall = convertFunctionCall(this, functionCall)
241243
% If functionCall is not empty, then it must be in
242244
% the format {"name", functionCall}
243-
if ~isempty(functionCall)&&ismember(functionCall, this.FunctionsNames)
245+
if ~isempty(functionCall)&&ismember(functionCall, this.FunctionNames)
244246
functionCall = struct("name", functionCall);
245247
end
246248

@@ -249,14 +251,14 @@ function mustBeValidFunctionCall(this, functionCall)
249251
end
250252

251253

252-
function [functionsStruct, functionsNames] = functionAsStruct(functions)
254+
function [functionsStruct, functionNames] = functionAsStruct(functions)
253255
numFunctions = numel(functions);
254256
functionsStruct = cell(1, numFunctions);
255-
functionsNames = strings(1, numFunctions);
257+
functionNames = strings(1, numFunctions);
256258

257259
for i = 1:numFunctions
258260
functionsStruct{i} = encodeStruct(functions(i));
259-
functionsNames(i) = functions(i).FunctionName;
261+
functionNames(i) = functions(i).FunctionName;
260262
end
261263
end
262264

@@ -284,11 +286,6 @@ function mustBeValidTopP(value)
284286
mustBeLessThanOrEqual(value,1);
285287
end
286288

287-
function mustBeValidModelName(value)
288-
mustBeNonzeroLengthText(value);
289-
mustBeMember(value,openAIChat.AvailableModels);
290-
end
291-
292289
function mustBeValidTemperature(value)
293290
mustBeNonnegative(value);
294291
mustBeLessThanOrEqual(value,2)

0 commit comments

Comments
 (0)