Skip to content

Commit c1815ad

Browse files
committed
Merge branch 'streaming' into 'main'
Streaming See merge request dferreir/llms-with-matlab!12
2 parents 744e646 + b7d1a73 commit c1815ad

File tree

5 files changed

+97
-9
lines changed

5 files changed

+97
-9
lines changed

+llms/+internal/callOpenAIChatAPI.m

+12-2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
% - FrequencyPenalty (frequence_penalty)
2020
% - ApiKey
2121
% - TimeOut
22+
% - StreamFun
2223
% More details on the parameters: https://platform.openai.com/docs/api-reference/chat/create
2324
%
2425
% Example
@@ -65,19 +66,25 @@
6566
nvp.FrequencyPenalty = 0
6667
nvp.ApiKey = ""
6768
nvp.TimeOut = 10
69+
nvp.StreamFun = []
6870
end
6971

7072
END_POINT = "https://api.openai.com/v1/chat/completions";
7173

7274
parameters = buildParametersCall(messages, functions, nvp);
7375

74-
response = llms.internal.sendRequest(parameters,nvp.ApiKey, END_POINT, nvp.TimeOut);
76+
[response, streamedText] = llms.internal.sendRequest(parameters,nvp.ApiKey, END_POINT, nvp.TimeOut, nvp.StreamFun);
7577

7678
% If call errors, "choices" will not be part of response.Body.Data, instead
7779
% we get response.Body.Data.error
7880
if response.StatusCode=="OK"
7981
% Outputs the first generation
80-
message = response.Body.Data.choices(1).message;
82+
if isempty(nvp.StreamFun)
83+
message = response.Body.Data.choices(1).message;
84+
else
85+
message = struct("role", "assistant", ...
86+
"content", streamedText);
87+
end
8188
if isfield(message, "function_call")
8289
text = "";
8390
else
@@ -95,6 +102,9 @@
95102

96103
parameters = struct();
97104
parameters.messages = messages;
105+
106+
parameters.stream = ~isempty(nvp.StreamFun);
107+
98108
if ~isempty(functions)
99109
parameters.functions = functions;
100110
end

+llms/+internal/sendRequest.m

+14-6
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
function response = sendRequest(parameters, token, endpoint, timeout)
2-
% This function is undocumented and will change in a future release
3-
1+
function [response, streamedText] = sendRequest(parameters, token, endpoint, timeout, streamFun)
42
%sendRequest Sends a request to an ENDPOINT using PARAMETERS and
53
% api key TOKEN. TIMEOUT is the nubmer of seconds to wait for initial
6-
% server connection.
4+
% server connection. STREAMFUN is an optional callback function.
75

86
% Copyright 2023 The MathWorks, Inc.
97

@@ -12,6 +10,7 @@
1210
token
1311
endpoint
1412
timeout
13+
streamFun
1514
end
1615

1716
% Define the headers for the API request
@@ -24,9 +23,18 @@
2423

2524
% Create a HTTPOptions object;
2625
httpOpts = matlab.net.http.HTTPOptions;
27-
% Set the ConnectTimeout option
2826

27+
% Set the ConnectTimeout option
2928
httpOpts.ConnectTimeout = timeout;
29+
3030
% Send the request and store the response
31-
response = send(request, matlab.net.URI(endpoint),httpOpts);
31+
if isempty(streamFun)
32+
response = send(request, matlab.net.URI(endpoint),httpOpts);
33+
streamedText = "";
34+
else
35+
% User defined a stream callback function
36+
consumer = llms.stream.responseStreamer(streamFun);
37+
response = send(request, matlab.net.URI(endpoint),httpOpts,consumer);
38+
streamedText = consumer.ResponseText;
39+
end
3240
end

+llms/+stream/responseStreamer.m

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
classdef responseStreamer < matlab.net.http.io.StringConsumer
2+
%responseStreamer Responsible for obtaining the streaming results from the
3+
%API
4+
5+
% Copyright 2023 The MathWorks, Inc.
6+
7+
properties
8+
ResponseText
9+
StreamFun
10+
end
11+
12+
methods
13+
function this = responseStreamer(streamFun)
14+
this.StreamFun = streamFun;
15+
end
16+
end
17+
18+
methods (Access=protected)
19+
function length = start(this)
20+
if this.Response.StatusCode ~= matlab.net.http.StatusCode.OK
21+
length = 0;
22+
else
23+
length = this.start@matlab.net.http.io.StringConsumer;
24+
end
25+
end
26+
end
27+
28+
methods
29+
function [len,stop] = putData(this, data)
30+
[len,stop] = this.putData@matlab.net.http.io.StringConsumer(data);
31+
32+
% Extract out the response text from the message
33+
str = native2unicode(data','UTF-8');
34+
str = split(str,newline);
35+
str = str(strlength(str)>0);
36+
str = erase(str,"data: ");
37+
38+
for i = 1:length(str)
39+
json = jsondecode(str{i});
40+
if strcmp(json.choices.finish_reason,'stop')
41+
stop = true;
42+
return
43+
else
44+
txt = json.choices.delta.content;
45+
this.StreamFun(txt);
46+
this.ResponseText = [this.ResponseText txt];
47+
end
48+
end
49+
end
50+
end
51+
end

openAIChat.m

+12-1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
% FrequencyPenalty - Penalty value for using a token that is frequent
3232
% in the training data. Default value is 0.
3333
%
34+
% StreamFun - Function to callback when streaming the
35+
% result
36+
%
3437
% openAIChat Functions:
3538
% openAIChat - Chat completion API from OpenAI.
3639
% generate - Generate a response using the openAIChat instance.
@@ -95,6 +98,7 @@
9598
Functions
9699
FunctionsStruct
97100
ApiKey
101+
StreamFun
98102
end
99103

100104
methods
@@ -112,6 +116,13 @@
112116
nvp.PresencePenalty {mustBeValidPenalty} = 0
113117
nvp.FrequencyPenalty {mustBeValidPenalty} = 0
114118
nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 10
119+
nvp.StreamFun (1,1) {mustBeA(nvp.StreamFun,'function_handle')}
120+
end
121+
122+
if isfield(nvp,"StreamFun")
123+
this.StreamFun = nvp.StreamFun;
124+
else
125+
this.StreamFun = [];
115126
end
116127

117128
if ~isempty(nvp.Functions)
@@ -182,7 +193,7 @@
182193
TopProbabilityMass=this.TopProbabilityMass, NumCompletions=nvp.NumCompletions,...
183194
StopSequences=this.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ...
184195
PresencePenalty=this.PresencePenalty, FrequencyPenalty=this.FrequencyPenalty, ...
185-
ApiKey=this.ApiKey,TimeOut=this.TimeOut);
196+
ApiKey=this.ApiKey,TimeOut=this.TimeOut, StreamFun=this.StreamFun);
186197
end
187198

188199
function this = set.Temperature(this, temperature)

tests/topenAIChat.m

+8
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,14 @@ function assignValueToProperty(property, value)
201201
function invalidConstructorInput = iGetInvalidConstructorInput
202202
validFunction = openAIFunction("funName");
203203
invalidConstructorInput = struct( ...
204+
"InvalidStreamFunType", struct( ...
205+
"Input",{{"StreamFun", "2" }},...
206+
"Error", "MATLAB:validators:mustBeA"), ...
207+
...
208+
"InvalidStreamFunSize", struct( ...
209+
"Input",{{"StreamFun", [1 1 1] }},...
210+
"Error", "MATLAB:validation:IncompatibleSize"), ...
211+
...
204212
"InvalidTimeOutType", struct( ...
205213
"Input",{{"TimeOut", "2" }},...
206214
"Error", "MATLAB:validators:mustBeReal"), ...

0 commit comments

Comments
 (0)