Skip to content

Commit e00e13c

Browse files
committed
add similarity check
1 parent 526bfd7 commit e00e13c

File tree

5 files changed

+211
-100
lines changed

5 files changed

+211
-100
lines changed

GenerativeAgentForDoq.ipynb

Lines changed: 106 additions & 100 deletions
Large diffs are not rendered by default.

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ pip3 install pypdf
7474
pip3 install langchain
7575
pip3 install unstructured
7676
pip3 install tabulate
77+
pip3 install scikit-learn
7778
```
7879

7980
Please set the OpenAI API key as an environment variable.

data_model/openai_libs.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#!/usr/bin/python
2+
# -*- coding: utf-8 -*-
3+
4+
import openai
5+
import tiktoken
6+
from tiktoken.core import Encoding
7+
from openai.embeddings_utils import cosine_similarity
8+
9+
llm_model = "gpt-4-0613"
10+
embedding_model = "text-embedding-ada-002"
11+
12+
# Embedding
def get_embedding(text_input: str) -> list:
    """Return the embedding vector for *text_input* via the OpenAI API.

    Newlines are replaced with spaces, per OpenAI's guidance for the
    embeddings endpoint. Uses the module-level ``embedding_model``
    (reading a global needs no ``global`` declaration).

    Raises whatever ``openai.Embedding.create`` raises on API failure.
    """
    # Convert the text to a vector with the embeddings endpoint.
    response = openai.Embedding.create(
        input = text_input.replace("\n", " "),  # input text
        model = embedding_model,                # embedding model name
    )

    # The payload nests the vector under data[0].embedding.
    embeddings = response['data'][0]['embedding']

    return embeddings
26+
27+
def get_score(text1: str, text2: str):
    """Cosine similarity between the embeddings of the two input texts."""
    embedding_a = get_embedding(text1)
    embedding_b = get_embedding(text2)
    return cosine_similarity(embedding_a, embedding_b)
32+
33+
34+
def get_tokenlen(data: str):
    """Number of tokens *data* occupies under the ``llm_model`` tokenizer."""
    encoder: Encoding = tiktoken.encoding_for_model(llm_model)
    return len(encoder.encode(data))

data_model/similarity_extractor.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
#!/usr/bin/python
2+
# -*- coding: utf-8 -*-
3+
4+
from data_model_accessor import DataModelAccessor
5+
import json
6+
from openai_libs import get_score, get_tokenlen
7+
8+
class SimilarityExtractor:
    """Select the data-model files most similar to a query, subject to a
    token budget, and extract their JSON models under a single key."""

    def __init__(self, accessor: "DataModelAccessor", maxtoken_num: int):
        # maxtoken_num: combined token budget for the selected files.
        self.maxtoken_num = maxtoken_num
        self.accessor = accessor

    def get_filelist(self, query: str):
        """Return files ranked most-similar-first, greedily accepted until
        the next file would exceed the token budget.

        BUG FIX: the original used the bare name ``accessor`` (NameError
        at runtime); it must go through ``self.accessor``.
        """
        scores = self._calc_scores(query, self.accessor.get_filelist())
        result = []
        token_sum = 0
        for entry in scores:
            # Stop at the first file that would overflow the budget.
            if token_sum + entry["tokens"] > self.maxtoken_num:
                break
            result.append(entry["file"])
            token_sum += entry["tokens"]
        return result

    def _calc_scores(self, query: str, filelist: list):
        """Score each file's JSON content against *query*; highest first."""
        scores = []
        for entry in filelist:
            json_data = self.accessor.get_data_model(entry).get_json_data()
            json_str = json.dumps(json_data)
            scores.append({
                "file": entry,
                "tokens": get_tokenlen(json_str),
                "score": get_score(query, json_str),
            })
        scores.sort(key=lambda x: x["score"], reverse=True)
        return scores

    def extract(self, head_name: str, filelists: list):
        """Wrap the JSON models for *filelists* under the key *head_name*."""
        models = self.accessor.get_json_models(filelists)
        return {head_name: models}
47+
48+
49+
if __name__ == "__main__":
    # CLI entry point: rank the data-model files under <dir> by
    # similarity to <query> and print the extracted "knowledges" JSON.
    # (`json` is already imported at module level; the original's local
    # re-import was redundant. `dir` renamed to avoid shadowing the builtin.)
    import sys

    if len(sys.argv) != 3:
        print("Usage: <query> <dir>")
        sys.exit(1)
    query = sys.argv[1]
    model_dir = sys.argv[2]

    accessor = DataModelAccessor(model_dir)
    extractor = SimilarityExtractor(accessor, 1024)

    filelist = extractor.get_filelist(query)
    data = extractor.extract("knowledges", filelist)
    data_str = json.dumps(data, indent=4, ensure_ascii=False)
    print(data_str)

tools/do_query.bash

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ do
4444
rm -rf test/*.json
4545
echo "INFO: CRITICAL THINKING"
4646
python3 critical_thinking.py "$query" ${background_file}
47+
python3 data_model/reflection_data_persistentor.py tmpdir ./test/result/critical_thinking.json
4748
echo "INFO: GETTING DOCUMENTS"
4849
get_docs "${query}" critical_thinking.json
4950
documents=`cat tmp.list`
@@ -88,6 +89,8 @@ do
8889
else
8990
python3 reflection.py "$query" ../documents/document.list "./test/result/critical_thinking.json" ${background_file} "./prompt_templates/ptemplate_reflection.txt"
9091
fi
92+
python3 data_model/reflection_data_persistentor.py tmpdir ./test/result/reflection.json
93+
9194
cp ./test/result/reflection.json ./test/result/prev_reflection.json
9295
python3 reflection.py "$query" ../documents/document.list "./test/result/reflection.json" ${background_file} "./prompt_templates/ptemplate_reflection_addterms.txt"
9396
mv ./test/result/reflection.json ./test/result/next_reflection.json

0 commit comments

Comments
 (0)