Skip to content

Add support for OpenAI Developer message #3089

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.DeveloperMessage;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
Expand Down Expand Up @@ -129,6 +130,7 @@ public Message mapRow(ResultSet rs, int i) throws SQLException {
case USER -> new UserMessage(content);
case ASSISTANT -> new AssistantMessage(content);
case SYSTEM -> new SystemMessage(content);
case DEVELOPER -> new DeveloperMessage(content);
// The content is always stored empty for ToolResponseMessages.
// If we want to capture the actual content, we need to extend
// AddBatchPreparedStatement to support it.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.DeveloperMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
Expand Down Expand Up @@ -72,6 +73,7 @@ void saveMessagesSingleMessage(String content, MessageType messageType) {
case ASSISTANT -> new AssistantMessage(content + " - " + conversationId);
case USER -> new UserMessage(content + " - " + conversationId);
case SYSTEM -> new SystemMessage(content + " - " + conversationId);
case DEVELOPER -> new DeveloperMessage(content + " - " + conversationId);
default -> throw new IllegalArgumentException("Type not supported: " + messageType);
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.DeveloperMessage;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.ToolResponseMessage.ToolResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.content.Media;
import org.springframework.util.MimeType;
import org.testcontainers.containers.Neo4jContainer;
Expand Down Expand Up @@ -402,6 +402,7 @@ private Message createMessageByType(String content, MessageType messageType) {
case ASSISTANT -> new AssistantMessage(content);
case USER -> new UserMessage(content);
case SYSTEM -> new SystemMessage(content);
case DEVELOPER -> new DeveloperMessage(content);
case TOOL -> new ToolResponseMessage(List.of(new ToolResponse("id", "name", "responseData")));
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import com.azure.ai.openai.models.ChatRequestAssistantMessage;
import com.azure.ai.openai.models.ChatRequestMessage;
import com.azure.ai.openai.models.ChatRequestSystemMessage;
import com.azure.ai.openai.models.ChatRequestDeveloperMessage;
import com.azure.ai.openai.models.ChatRequestToolMessage;
import com.azure.ai.openai.models.ChatRequestUserMessage;
import com.azure.ai.openai.models.CompletionsFinishReason;
Expand Down Expand Up @@ -575,6 +576,8 @@ private List<ChatRequestMessage> fromSpringAiMessage(Message message) {
return List.of(new ChatRequestUserMessage(items));
case SYSTEM:
return List.of(new ChatRequestSystemMessage(message.getText()));
case DEVELOPER:
return List.of(new ChatRequestDeveloperMessage(message.getText()));
case ASSISTANT:
AssistantMessage assistantMessage = (AssistantMessage) message;
List<ChatCompletionsToolCall> toolCalls = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
* @author Ilayaperumal Gopinathan
* @author Alexandros Pappas
* @author Soby Chacko
* @author Andres da Silva Santos
* @see ChatModel
* @see StreamingChatModel
* @see OpenAiApi
Expand Down Expand Up @@ -552,7 +553,8 @@ private Map<String, String> mergeHttpHeaders(Map<String, String> runtimeHttpHead
ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {

List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map(message -> {
if (message.getMessageType() == MessageType.USER || message.getMessageType() == MessageType.SYSTEM) {
if (message.getMessageType() == MessageType.USER || message.getMessageType() == MessageType.SYSTEM
|| message.getMessageType() == MessageType.DEVELOPER) {
Object content = message.getText();
if (message instanceof UserMessage userMessage) {
if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
* @author Josh Long
* @author Arjen Poutsma
* @author Thomas Vitale
* @author Andres da Silva Santos
* @since 1.0.0
*/
public interface ChatClient {
Expand Down Expand Up @@ -133,6 +134,23 @@ interface PromptSystemSpec {

}

/**
* Specification for a prompt developer.
*/
interface PromptDeveloperSpec {

PromptDeveloperSpec text(String text);

PromptDeveloperSpec text(Resource text, Charset charset);

PromptDeveloperSpec text(Resource text);

PromptDeveloperSpec params(Map<String, Object> p);

PromptDeveloperSpec param(String k, Object v);

}

interface AdvisorSpec {

AdvisorSpec param(String k, Object v);
Expand Down Expand Up @@ -232,6 +250,14 @@ interface ChatClientRequestSpec {

ChatClientRequestSpec toolContext(Map<String, Object> toolContext);

ChatClientRequestSpec developer(String text);

ChatClientRequestSpec developer(Resource textResource, Charset charset);

ChatClientRequestSpec developer(Resource text);

ChatClientRequestSpec developer(Consumer<PromptDeveloperSpec> consumer);

ChatClientRequestSpec system(String text);

ChatClientRequestSpec system(Resource textResource, Charset charset);
Expand Down Expand Up @@ -277,6 +303,14 @@ interface Builder {

Builder defaultUser(Consumer<PromptUserSpec> userSpecConsumer);

Builder defaultDeveloper(String text);

Builder defaultDeveloper(Resource text, Charset charset);

Builder defaultDeveloper(Resource text);

Builder defaultDeveloper(Consumer<PromptDeveloperSpec> developerSpecConsumer);

Builder defaultSystem(String text);

Builder defaultSystem(Resource text, Charset charset);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
* @author Soby Chacko
* @author Dariusz Jedrzejczyk
* @author Thomas Vitale
* @author Andres da Silva Santos
* @since 1.0.0
*/
public class DefaultChatClient implements ChatClient {
Expand Down Expand Up @@ -288,6 +289,68 @@ protected Map<String, Object> params() {

}

public static class DefaultPromptDeveloperSpec implements PromptDeveloperSpec {

private final Map<String, Object> params = new HashMap<>();

@Nullable
private String text;

@Override
public PromptDeveloperSpec text(String text) {
Assert.hasText(text, "text cannot be null or empty");
this.text = text;
return this;
}

@Override
public PromptDeveloperSpec text(Resource text, Charset charset) {
Assert.notNull(text, "text cannot be null");
Assert.notNull(charset, "charset cannot be null");
try {
this.text(text.getContentAsString(charset));
}
catch (IOException e) {
throw new RuntimeException(e);
}
return this;
}

@Override
public PromptDeveloperSpec text(Resource text) {
Assert.notNull(text, "text cannot be null");
this.text(text, Charset.defaultCharset());
return this;
}

@Override
public PromptDeveloperSpec param(String key, Object value) {
Assert.hasText(key, "key cannot be null or empty");
Assert.notNull(value, "value cannot be null");
this.params.put(key, value);
return this;
}

@Override
public PromptDeveloperSpec params(Map<String, Object> params) {
Assert.notNull(params, "params cannot be null");
Assert.noNullElements(params.keySet(), "param keys cannot contain null elements");
Assert.noNullElements(params.values(), "param values cannot contain null elements");
this.params.putAll(params);
return this;
}

@Nullable
protected String text() {
return this.text;
}

protected Map<String, Object> params() {
return this.params;
}

}

public static class DefaultAdvisorSpec implements AdvisorSpec {

private final List<Advisor> advisors = new ArrayList<>();
Expand Down Expand Up @@ -577,6 +640,8 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe

private final Map<String, Object> systemParams = new HashMap<>();

private final Map<String, Object> developerParams = new HashMap<>();

private final List<Advisor> advisors = new ArrayList<>();

private final Map<String, Object> advisorParams = new HashMap<>();
Expand All @@ -591,27 +656,32 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe
@Nullable
private String systemText;

@Nullable
private String developerText;

@Nullable
private ChatOptions chatOptions;

/* copy constructor */
DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) {
this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.toolCallbacks,
ccr.messages, ccr.toolNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams,
ccr.observationRegistry, ccr.observationConvention, ccr.toolContext, ccr.templateRenderer);
this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.developerText,
ccr.developerParams, ccr.toolCallbacks, ccr.messages, ccr.toolNames, ccr.media, ccr.chatOptions,
ccr.advisors, ccr.advisorParams, ccr.observationRegistry, ccr.observationConvention,
ccr.toolContext, ccr.templateRenderer);
}

public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userText,
Map<String, Object> userParams, @Nullable String systemText, Map<String, Object> systemParams,
List<ToolCallback> toolCallbacks, List<Message> messages, List<String> toolNames, List<Media> media,
@Nullable ChatOptions chatOptions, List<Advisor> advisors, Map<String, Object> advisorParams,
ObservationRegistry observationRegistry,
@Nullable String developerText, Map<String, Object> developerParams, List<ToolCallback> toolCallbacks,
List<Message> messages, List<String> toolNames, List<Media> media, @Nullable ChatOptions chatOptions,
List<Advisor> advisors, Map<String, Object> advisorParams, ObservationRegistry observationRegistry,
@Nullable ChatClientObservationConvention observationConvention, Map<String, Object> toolContext,
@Nullable TemplateRenderer templateRenderer) {

Assert.notNull(chatModel, "chatModel cannot be null");
Assert.notNull(userParams, "userParams cannot be null");
Assert.notNull(systemParams, "systemParams cannot be null");
Assert.notNull(developerParams, "developerParams cannot be null");
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
Assert.notNull(messages, "messages cannot be null");
Assert.notNull(toolNames, "toolNames cannot be null");
Expand All @@ -629,6 +699,8 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userTe
this.userParams.putAll(userParams);
this.systemText = systemText;
this.systemParams.putAll(systemParams);
this.developerText = developerText;
this.developerParams.putAll(developerParams);

this.toolNames.addAll(toolNames);
this.toolCallbacks.addAll(toolCallbacks);
Expand Down Expand Up @@ -661,6 +733,15 @@ public Map<String, Object> getSystemParams() {
return this.systemParams;
}

@Nullable
public String getDeveloperText() {
return this.developerText;
}

public Map<String, Object> getDeveloperParams() {
return this.developerParams;
}

@Nullable
public ChatOptions getChatOptions() {
return this.chatOptions;
Expand Down Expand Up @@ -719,6 +800,10 @@ public Builder mutate() {
builder.defaultSystem(s -> s.text(this.systemText).params(this.systemParams));
}

if (StringUtils.hasText(this.developerText)) {
builder.defaultDeveloper(s -> s.text(this.developerText).params(this.developerParams));
}

if (this.chatOptions != null) {
builder.defaultOptions(this.chatOptions);
}
Expand Down Expand Up @@ -821,6 +906,41 @@ public ChatClientRequestSpec toolContext(Map<String, Object> toolContext) {
return this;
}

public ChatClientRequestSpec developer(String text) {
Assert.hasText(text, "text cannot be null or empty");
this.developerText = text;
return this;
}

public ChatClientRequestSpec developer(Resource text, Charset charset) {
Assert.notNull(text, "text cannot be null");
Assert.notNull(charset, "charset cannot be null");

try {
this.developerText = text.getContentAsString(charset);
}
catch (IOException e) {
throw new RuntimeException(e);
}
return this;
}

public ChatClientRequestSpec developer(Resource text) {
Assert.notNull(text, "text cannot be null");
return this.developer(text, Charset.defaultCharset());
}

public ChatClientRequestSpec developer(Consumer<PromptDeveloperSpec> consumer) {
Assert.notNull(consumer, "consumer cannot be null");

var developerSpec = new DefaultPromptDeveloperSpec();
consumer.accept(developerSpec);
this.developerText = StringUtils.hasText(developerSpec.text()) ? developerSpec.text() : this.developerText;
this.developerParams.putAll(developerSpec.params());

return this;
}

public ChatClientRequestSpec system(String text) {
Assert.hasText(text, "text cannot be null or empty");
this.systemText = text;
Expand Down
Loading