Skip to content

Commit 7e789c3

Browse files
authored
Merge pull request #6 from matlab-deep-learning/dev-parallelfunctioncall-jsonmode-vision-dalle
Updating the API to the newest version.
2 parents a4baf34 + 510002a commit 7e789c3

19 files changed

+1141
-178
lines changed

+llms/+internal/callOpenAIChatAPI.m

+29-13
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
function [text, message, response] = callOpenAIChatAPI(messages, functions, nvp)
2-
% This function is undocumented and will change in a future release
3-
42
%callOpenAIChatAPI Calls the openAI chat completions API.
53
%
64
% MESSAGES and FUNCTIONS should be structs matching the json format
75
% required by the OpenAI Chat Completions API.
86
% Ref: https://platform.openai.com/docs/guides/gpt/chat-completions-api
97
%
108
% Currently, the supported NVP are, including the equivalent name in the API:
11-
% - FunctionCall (function_call)
9+
% - ToolChoice (tool_choice)
1210
% - ModelName (model)
1311
% - Temperature (temperature)
1412
% - TopProbabilityMass (top_p)
@@ -17,6 +15,8 @@
1715
% - MaxNumTokens (max_tokens)
1816
% - PresencePenalty (presence_penalty)
1917
% - FrequencyPenalty (frequence_penalty)
18+
% - ResponseFormat (response_format)
19+
% - Seed (seed)
2020
% - ApiKey
2121
% - TimeOut
2222
% - StreamFun
@@ -50,12 +50,12 @@
5050
% % Send a request
5151
% [text, message] = llms.internal.callOpenAIChatAPI(messages, functions, ApiKey=apiKey)
5252

53-
% Copyright 2023 The MathWorks, Inc.
53+
% Copyright 2023-2024 The MathWorks, Inc.
5454

5555
arguments
5656
messages
5757
functions
58-
nvp.FunctionCall = []
58+
nvp.ToolChoice = []
5959
nvp.ModelName = "gpt-3.5-turbo"
6060
nvp.Temperature = 1
6161
nvp.TopProbabilityMass = 1
@@ -64,6 +64,8 @@
6464
nvp.MaxNumTokens = inf
6565
nvp.PresencePenalty = 0
6666
nvp.FrequencyPenalty = 0
67+
nvp.ResponseFormat = "text"
68+
nvp.Seed = []
6769
nvp.ApiKey = ""
6870
nvp.TimeOut = 10
6971
nvp.StreamFun = []
@@ -85,7 +87,7 @@
8587
message = struct("role", "assistant", ...
8688
"content", streamedText);
8789
end
88-
if isfield(message, "function_call")
90+
if isfield(message, "tool_choice")
8991
text = "";
9092
else
9193
text = string(message.content);
@@ -105,22 +107,36 @@
105107

106108
parameters.stream = ~isempty(nvp.StreamFun);
107109

108-
if ~isempty(functions)
109-
parameters.functions = functions;
110+
if ~isempty(functions) && ~strcmp(nvp.ModelName,'gpt-4-vision-preview')
111+
parameters.tools = functions;
112+
end
113+
114+
if ~isempty(nvp.ToolChoice) && ~strcmp(nvp.ModelName,'gpt-4-vision-preview')
115+
parameters.tool_choice = nvp.ToolChoice;
116+
end
117+
118+
if ismember(nvp.ModelName,["gpt-3.5-turbo-1106","gpt-4-1106-preview"])
119+
if strcmp(nvp.ResponseFormat,"json")
120+
parameters.response_format = struct('type','json_object');
121+
end
110122
end
111123

112-
if ~isempty(nvp.FunctionCall)
113-
parameters.function_call = nvp.FunctionCall;
124+
if ~isempty(nvp.Seed)
125+
parameters.seed = nvp.Seed;
114126
end
115127

116128
parameters.model = nvp.ModelName;
117129

118130
dict = mapNVPToParameters;
119131

120132
nvpOptions = keys(dict);
121-
for i=1:length(nvpOptions)
122-
if isfield(nvp, nvpOptions(i))
123-
parameters.(dict(nvpOptions(i))) = nvp.(nvpOptions(i));
133+
if strcmp(nvp.ModelName,'gpt-4-vision-preview')
134+
nvpOptions(ismember(nvpOptions,["MaxNumTokens","StopSequences"])) = [];
135+
end
136+
137+
for opt = nvpOptions.'
138+
if isfield(nvp, opt)
139+
parameters.(dict(opt)) = nvp.(opt);
124140
end
125141
end
126142
end

+llms/+utils/errorMessageCatalog.m

+12-9
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
11
classdef errorMessageCatalog
2-
% This class is undocumented and will change in a future release
3-
42
%errorMessageCatalog Stores the error messages from this repository
53

6-
% Copyright 2023 The MathWorks, Inc.
4+
% Copyright 2023-2024 The MathWorks, Inc.
5+
76
properties(Constant)
87
%CATALOG dictionary mapping error ids to error msgs
98
Catalog = buildErrorMessageCatalog;
109
end
1110

1211
methods(Static)
1312
function msg = getMessage(messageId, slot)
14-
% This function is undocumented and will change in a future release
15-
1613
%getMessage returns error message given a messageID and a SLOT.
1714
% The value in SLOT should be ordered, where the n-th element
1815
% will replace the value "{n}".
@@ -41,13 +38,19 @@
4138
catalog("llms:parameterMustBeUnique") = "A parameter name equivalent to '{1}' already exists in Parameters. Redefining a parameter is not allowed.";
4239
catalog("llms:mustBeAssistantCall") = "Input struct must contain field 'role' with value 'assistant', and field 'content'.";
4340
catalog("llms:mustBeAssistantWithContent") = "Input struct must contain field 'content' containing text with one or more characters.";
44-
catalog("llms:mustBeAssistantWithNameAndArguments") = "Field 'function_call' must be a struct with fields 'name' and 'arguments'.";
41+
catalog("llms:mustBeAssistantWithIdAndFunction") = "Field 'tool_call' must be a struct with fields 'id' and 'function'.";
42+
catalog("llms:mustBeAssistantWithNameAndArguments") = "Field 'function' must be a struct with fields 'name' and 'arguments'.";
4543
catalog("llms:assistantMustHaveTextNameAndArguments") = "Fields 'name' and 'arguments' must be text with one or more characters.";
4644
catalog("llms:mustBeValidIndex") = "Value is larger than the number of elements in Messages ({1}).";
4745
catalog("llms:stopSequencesMustHaveMax4Elements") = "Number of elements must not be larger than 4.";
4846
catalog("llms:keyMustBeSpecified") = "API key not found as environment variable OPENAI_API_KEY and not specified via ApiKey parameter.";
4947
catalog("llms:mustHaveMessages") = "Value must contain at least one message in Messages.";
50-
catalog("llms:mustSetFunctionsForCall") = "When no functions are defined, FunctionCall must not be specified.";
48+
catalog("llms:mustSetFunctionsForCall") = "When no functions are defined, ToolChoice must not be specified.";
5149
catalog("llms:mustBeMessagesOrTxt") = "Messages must be text with one or more characters or an openAIMessages objects.";
52-
end
53-
50+
catalog("llms:invalidOptionAndValueForModel") = "'{1}' with value '{2}' is not supported for ModelName '{3}'";
51+
catalog("llms:invalidOptionForModel") = "{1} is not supported for ModelName '{2}'";
52+
catalog("llms:functionNotAvailableForModel") = "This function is not supported for ModelName '{1}'";
53+
catalog("llms:promptLimitCharacter") = "Prompt must have a maximum length of {1} characters for ModelName '{2}'";
54+
catalog("llms:pngExpected") = "Argument must be a PNG image.";
55+
catalog("llms:warningJsonInstruction") = "When using JSON mode, you must also prompt the model to produce JSON yourself via a system or user message.";
56+
end

.gitattributes

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
*.fig binary
2+
*.mat binary
3+
*.mdl binary diff merge=mlAutoMerge
4+
*.mdlp binary
5+
*.mexa64 binary
6+
*.mexw64 binary
7+
*.mexmaci64 binary
8+
*.mlapp binary
9+
*.mldatx binary
10+
*.mlproj binary
11+
*.mlx binary
12+
*.p binary
13+
*.sfx binary
14+
*.sldd binary
15+
*.slreqx binary merge=mlAutoMerge
16+
*.slmx binary merge=mlAutoMerge
17+
*.sltx binary
18+
*.slxc binary
19+
*.slx binary merge=mlAutoMerge
20+
*.slxp binary
21+
22+
## Other common binary file types
23+
*.docx binary
24+
*.exe binary
25+
*.jpg binary
26+
*.pdf binary
27+
*.png binary
28+
*.xlsx binary

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
*.env
2+
*.asv
3+
startup.m

0 commit comments

Comments
 (0)