Skip to content

Commit 76a1cc5

Browse files
authored
Merge pull request #4 from muhrifqii/dev/chat-memory
Chat Memory
2 parents 0605dba + 72b4a88 commit 76a1cc5

File tree

12 files changed

+205
-93
lines changed

12 files changed

+205
-93
lines changed

README.md

+11
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,17 @@ Sample project to showcase a java spring boot application with Generative AI pow
1111
- Flyway
1212
- [Ollama](https://ollama.com)
1313

14+
## Prerequisites - Ollama
15+
16+
- Install Ollama on your local machine
17+
- Run `ollama run llama3.1`
18+
19+
## Running the project
20+
21+
```bash
22+
make up
23+
```
24+
1425
## Video on The Stream Chat Response Feature
1526

1627

Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
11
package com.muhrifqii.llm.api.usecases;
22

33
import com.muhrifqii.llm.api.datamodels.conversations.Message;
4-
import com.muhrifqii.llm.api.datamodels.conversations.UserMessage;
54

65
import reactor.core.publisher.Flux;
76
import reactor.core.publisher.Mono;
87

98
public interface MessageStoreUsecase {
109
Flux<Message> getMessages(String conversationId, String cursor);
1110

12-
Mono<Message> saveUserMessage(UserMessage userMessage);
13-
14-
Mono<Message> saveAssistantMessage(Message message);
11+
Mono<Message> saveMessage(Message message, boolean newEntity);
1512
}

llm/api/src/main/java/com/muhrifqii/llm/api/utils/DateUtils.java

+7
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,11 @@ public static String toIsoString(LocalDateTime date) {
3030
.map(DateTimeFormatter.ISO_DATE_TIME::format)
3131
.orElse("");
3232
}
33+
34+
public static LocalDateTime fromIsoString(String date) {
35+
return Optional.ofNullable(date)
36+
.map(DateTimeFormatter.ISO_DATE_TIME::parse)
37+
.map(LocalDateTime::from)
38+
.orElse(null);
39+
}
3340
}
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,37 @@
11
package com.muhrifqii.llm.configurations;
22

3+
import java.time.Duration;
4+
35
import org.springframework.ai.chat.client.ChatClient;
4-
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
56
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
6-
import org.springframework.ai.chat.memory.ChatMemory;
7+
import org.springframework.boot.web.client.ClientHttpRequestFactories;
8+
import org.springframework.boot.web.client.ClientHttpRequestFactorySettings;
9+
import org.springframework.boot.web.client.RestClientCustomizer;
710
import org.springframework.context.annotation.Bean;
811
import org.springframework.context.annotation.Configuration;
912

10-
import com.muhrifqii.llm.annotations.MemCachedChatMemory;
13+
import com.muhrifqii.llm.services.MessageStoreAdvisor;
1114

1215
@Configuration
1316
public class ChatModelConfig {
1417
@Bean
1518
ChatClient chatClient(
1619
ChatClient.Builder builder,
17-
@MemCachedChatMemory ChatMemory chatMemory) {
20+
MessageStoreAdvisor messageStoreAdvisor) {
1821
return builder
1922
.defaultSystem(
2023
"You are a Pokédex AI named Slaking.AI, a highly advanced AI that specializes in providing detailed and accurate information about Pokémon. You have access to all known data about Pokémon species, including their types, abilities, evolutions, habitat, and more. Your responses should be concise, factual, and directly related to the Pokémon in question. Ensure to offer relevant insights based on the user's query, and avoid speculation. Your goal is to assist users in learning everything they need to know about any Pokémon they ask about, much like a Pokédex would in the Pokémon world")
2124
.defaultAdvisors(
22-
new SimpleLoggerAdvisor(),
23-
new MessageChatMemoryAdvisor(chatMemory))
25+
messageStoreAdvisor,
26+
new SimpleLoggerAdvisor())
2427
.build();
2528
}
29+
30+
@Bean
31+
RestClientCustomizer restClientCustomizer() {
32+
return restClientBuilder -> restClientBuilder
33+
.requestFactory(ClientHttpRequestFactories.get(ClientHttpRequestFactorySettings.DEFAULTS
34+
.withConnectTimeout(Duration.ofSeconds(180))
35+
.withReadTimeout(Duration.ofSeconds(180))));
36+
}
2637
}

llm/ollama-provider/src/main/java/com/muhrifqii/llm/repositories/MessageEntity.java

+18-10
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,29 @@
77
import org.springframework.ai.chat.messages.MessageType;
88
import org.springframework.data.annotation.CreatedDate;
99
import org.springframework.data.annotation.Id;
10+
import org.springframework.data.annotation.Transient;
1011
import org.springframework.data.domain.Persistable;
1112
import org.springframework.data.relational.core.mapping.Table;
1213
import org.springframework.lang.Nullable;
1314

14-
import lombok.Builder;
15+
import lombok.Getter;
16+
import lombok.Setter;
17+
import lombok.experimental.Accessors;
1518

1619
@Table("ai_messages")
17-
@Builder
18-
public record MessageEntity(
19-
@Id String id,
20-
String coversationId,
21-
String content,
22-
String messageType,
23-
@CreatedDate LocalDateTime createdAt)
24-
implements Persistable<String>, Message {
20+
@Getter
21+
@Setter
22+
@Accessors(fluent = true)
23+
public class MessageEntity implements Persistable<String>, Message {
24+
@Id
25+
private String id;
26+
private String conversationId;
27+
private String content;
28+
private String messageType;
29+
@CreatedDate
30+
private LocalDateTime createdAt;
31+
@Transient
32+
private boolean newEntity;
2533

2634
@Override
2735
@Nullable
@@ -31,7 +39,7 @@ public String getId() {
3139

3240
@Override
3341
public boolean isNew() {
34-
return createdAt == null;
42+
return createdAt == null || newEntity;
3543
}
3644

3745
@Override

llm/ollama-provider/src/main/java/com/muhrifqii/llm/services/ChatHelper.java

+25-6
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ public static Message newMessage(String conversationID) {
3535
.id(generateId())
3636
.conversationId(conversationID)
3737
.messageType(MessageType.USER.getValue())
38-
.createdAt(DateUtils.nowIsoString())
3938
.build();
4039
}
4140

@@ -81,20 +80,40 @@ public static com.muhrifqii.llm.repositories.Conversation updateConversation(
8180
public static Message mapMessage(MessageEntity source) {
8281
return Message.builder()
8382
.id(source.id())
84-
.conversationId(source.coversationId())
83+
.conversationId(source.conversationId())
8584
.content(source.content())
8685
.messageType(source.messageType())
8786
.createdAt(DateUtils.toIsoString(source.createdAt()))
8887
.build();
8988
}
9089

91-
public static MessageEntity mapFromAi(String conversationId, org.springframework.ai.chat.messages.Message source) {
92-
return MessageEntity.builder()
90+
public static MessageEntity mapMessage(Message source, boolean newEntity) {
91+
return new MessageEntity()
92+
.id(source.id())
93+
.conversationId(source.conversationId())
94+
.content(source.content())
95+
.messageType(source.messageType())
96+
.createdAt(DateUtils.fromIsoString(source.createdAt()))
97+
.newEntity(newEntity);
98+
}
99+
100+
public static Message mapFromAi(String conversationId, org.springframework.ai.chat.messages.Message source) {
101+
return Message.builder()
102+
.id(generateId())
103+
.conversationId(conversationId)
104+
.content(source.getContent())
105+
.messageType(source.getMessageType().getValue())
106+
.createdAt(DateUtils.nowIsoString())
107+
.build();
108+
}
109+
110+
public static Message mapFromAi(String conversationId, org.springframework.ai.chat.messages.UserMessage source) {
111+
return Message.builder()
93112
.id(generateId())
94-
.coversationId(conversationId)
113+
.conversationId(conversationId)
95114
.content(source.getContent())
96115
.messageType(source.getMessageType().getValue())
97-
.createdAt(DateUtils.now())
116+
.createdAt(DateUtils.nowIsoString())
98117
.build();
99118
}
100119

llm/ollama-provider/src/main/java/com/muhrifqii/llm/services/ChatService.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ public class ChatService implements ChatServiceTrait {
2727

2828
@Override
2929
public Mono<Message> chat(String conversationID, UserMessage message) {
30-
return Mono.fromCallable(() -> chatClient
30+
return Mono.fromSupplier(() -> chatClient
3131
.prompt()
3232
.user(message.content())
3333
.advisors(advSpec -> chatMemoryAdvisorSpec(advSpec, conversationID))
@@ -71,7 +71,7 @@ private Flux<Message> chatPromptStream(Message source, UserMessage userMessage)
7171
}
7272

7373
private void chatMemoryAdvisorSpec(AdvisorSpec advisorSpec, String conversationId) {
74-
if (Constants.EMPTY_SLUG.equals(conversationId)) {
74+
if (Constants.EMPTY_SLUG.equals(conversationId) || conversationId == null) {
7575
return;
7676
}
7777
advisorSpec.param(AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY, conversationId);

llm/ollama-provider/src/main/java/com/muhrifqii/llm/services/ConversationService.java

+3-8
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import com.muhrifqii.llm.api.datamodels.HandledError;
1111
import com.muhrifqii.llm.api.datamodels.conversations.Conversation;
1212
import com.muhrifqii.llm.api.datamodels.conversations.Message;
13-
import com.muhrifqii.llm.api.datamodels.conversations.UserMessage;
1413
import com.muhrifqii.llm.api.traits.ConversationServiceTrait;
1514
import com.muhrifqii.llm.repositories.ConversationRepository;
1615
import com.muhrifqii.llm.repositories.MessageRepository;
@@ -74,13 +73,9 @@ public Flux<Message> getMessages(String conversationId, String cursor) {
7473
}
7574

7675
@Override
77-
public Mono<Message> saveAssistantMessage(Message message) {
78-
return null;
79-
}
80-
81-
@Override
82-
public Mono<Message> saveUserMessage(UserMessage userMessage) {
83-
return null;
76+
public Mono<Message> saveMessage(Message message, boolean newEntity) {
77+
return messageRepository.save(ChatHelper.mapMessage(message, newEntity))
78+
.map(ChatHelper::mapMessage);
8479
}
8580

8681
private Mono<Conversation> throwIfConversationIdNotExist(String id) {

llm/ollama-provider/src/main/java/com/muhrifqii/llm/services/DbPersistedChatMemory.java

-55
This file was deleted.

llm/ollama-provider/src/main/java/com/muhrifqii/llm/services/MemcachedChatMemory.java

+20-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import java.util.Map;
66
import java.util.concurrent.ConcurrentHashMap;
77

8+
import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor;
89
import org.springframework.ai.chat.memory.ChatMemory;
910
import org.springframework.ai.chat.messages.Message;
1011
import org.springframework.stereotype.Component;
@@ -22,22 +23,40 @@ public class MemcachedChatMemory implements ChatMemory {
2223

2324
@Override
2425
public void add(String conversationId, List<Message> messages) {
26+
if (!isConversationIdValid(conversationId)) {
27+
return;
28+
}
29+
2530
log.debug("add:{} with {} messages", conversationId, messages.size());
2631
this.conversationHistory.putIfAbsent(conversationId, new ArrayList<>());
2732
this.conversationHistory.get(conversationId).addAll(messages);
2833
}
2934

3035
@Override
3136
public List<Message> get(String conversationId, int lastN) {
37+
if (!isConversationIdValid(conversationId)) {
38+
return List.of();
39+
}
40+
3241
log.debug("get:{}:{}", lastN, conversationId);
3342
List<Message> all = this.conversationHistory.get(conversationId);
34-
return all != null ? all.stream().skip(Math.max(0, all.size() - lastN)).toList() : List.of();
43+
return all != null ? all.stream()
44+
.skip(Math.max(0, all.size() - lastN))
45+
.toList() : List.of();
3546
}
3647

3748
@Override
3849
public void clear(String conversationId) {
50+
if (!isConversationIdValid(conversationId)) {
51+
return;
52+
}
53+
3954
log.debug("clear:{}", conversationId);
4055
this.conversationHistory.remove(conversationId);
4156
}
4257

58+
private boolean isConversationIdValid(String conversationId) {
59+
return !AbstractChatMemoryAdvisor.DEFAULT_CHAT_MEMORY_CONVERSATION_ID.equals(conversationId);
60+
}
61+
4362
}

0 commit comments

Comments
 (0)