Skip to content

Commit 6fc116b

Browse files
committed
add prompt param
1 parent 328a92e commit 6fc116b

File tree

9 files changed

+80
-15
lines changed

9 files changed

+80
-15
lines changed

critical_thinking.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,14 @@ def save_to_json(self, file_path):
3535

3636
if __name__ == "__main__":
3737
import sys
38+
from params import get_param
39+
prompt_template_path = get_param("prompt_templates_path")
40+
3841
if len(sys.argv) != 3:
3942
print("Usage: <MainQuestion> <BackgroundKnowledge>")
4043
sys.exit(1)
4144
main_question = sys.argv[1]
4245
background_knowledge_path = sys.argv[2]
43-
think = CriticalThinking(main_question, "./prompt_templates/ptemplate_critical_thinking.txt", background_knowledge_path)
46+
think = CriticalThinking(main_question, prompt_template_path + "/ptemplate_critical_thinking.txt", background_knowledge_path)
4447
think.create()
4548
think.save_to_raw("test/result/critical_thinking.json")

evaluator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ def evaluate(self, template_path, ref_json_path):
6969

7070
if __name__ == "__main__":
7171
import sys
72+
from params import get_param
73+
prompt_template_path = get_param("prompt_templates_path")
7274

7375
if len(sys.argv) != 4 and len(sys.argv) != 5:
7476
print("Usage: <MainQuestion> <plan> <memory> [<reflection>]")
@@ -78,12 +80,12 @@ def evaluate(self, template_path, ref_json_path):
7880
mem_json_path = sys.argv[3]
7981
plan = Plan()
8082
plan.load_from_json(plan_json_path)
81-
mission_path= "./prompt_templates/ptemplate_mission.txt"
83+
mission_path= prompt_template_path + "/ptemplate_mission.txt"
8284
memory_stream = MemoryStream()
8385
memory_stream.load_from_json(mem_json_path)
8486
evaluator = Evaluator(main_question, mission_path, plan, memory_stream)
8587
if len(sys.argv) == 4:
8688
evaluator.merge_data()
8789
elif len(sys.argv) == 5:
8890
ref_json_path = sys.argv[4]
89-
evaluator.evaluate("./prompt_templates/ptemplate_evaluate.txt", ref_json_path)
91+
evaluator.evaluate(prompt_template_path + "/ptemplate_evaluate.txt", ref_json_path)

params.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"prompt_templates_path": "./prompt_templates",
3+
"documents_path": "../documents",
4+
"documents_data_path": "./documents_data",
5+
"reflections_data_path": "./reflections_data"
6+
}

params.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#!/usr/bin/python
2+
# -*- coding: utf-8 -*-
3+
4+
import json
5+
6+
def get_param(param_name: str):
7+
with open('./params.json', 'r') as file:
8+
param = json.load(file)
9+
return param.get(param_name)
10+
11+
if __name__ == "__main__":
12+
import sys
13+
if len(sys.argv) != 2:
14+
print("Usage: <param_name>")
15+
sys.exit(1)
16+
param_name = sys.argv[1]
17+
18+
print(get_param(param_name))

planner.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,8 @@ def save_strategy_history(self):
111111

112112
if __name__ == "__main__":
113113
import sys
114+
from params import get_param
115+
prompt_template_path = get_param("prompt_templates_path")
114116
if len(sys.argv) != 5:
115117
print("Usage: <MainQuestion> <doc_list.txt> <background_knowledge_path> <acquired_knowledge_path>")
116118
sys.exit(1)
@@ -124,9 +126,9 @@ def save_strategy_history(self):
124126
batched_list = [total_list[i:i+batch_size] for i in range(0, len(total_list), batch_size)]
125127
planner = Planner(
126128
main_question = main_question,
127-
mission_path= "./prompt_templates/ptemplate_mission.txt",
128-
strategy_path= "./prompt_templates/ptemplate_strategy.txt",
129-
query_plan_path= "./prompt_templates/ptemplate_query_plan.txt",
129+
mission_path= prompt_template_path + "/ptemplate_mission.txt",
130+
strategy_path= prompt_template_path + "/ptemplate_strategy.txt",
131+
query_plan_path= prompt_template_path + "/ptemplate_query_plan.txt",
130132
strategy_history_path="./test/strategy_history.json",
131133
background_knowledge_path = background_knowledge_path,
132134
acquired_knowledge_path = acquired_knowledge_path

prompt_template.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@ def get_prompt(self, **kwargs) -> str:
1515
return self.template.format(**kwargs)
1616

1717
if __name__ == "__main__":
18-
pt = PromptTemplate("./prompt_templates/ptemplate_query.txt")
18+
from params import get_param
19+
prompt_template_path = get_param("prompt_templates_path")
20+
21+
pt = PromptTemplate(prompt_template_path + "/ptemplate_query.txt")
1922
while True:
2023
target_doc_id = input("TargetDocID> ")
2124
question = input("question> ")

query.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,14 @@ def _save(self, sub_question: str, reply: str, point: int):
4646
if __name__ == "__main__":
4747
import sys
4848
from db_manager import get_qa
49+
from params import get_param
50+
param_prompt_template_path = get_param("prompt_templates_path")
51+
4952
db_dir = ".."
5053
doc_id = "DB"
5154
qa = get_qa(db_dir, doc_id)
5255
memory_stream = MemoryStream()
53-
prompt_template_path = "./prompt_templates/ptemplate_query.txt"
56+
prompt_template_path = param_prompt_template_path + "/ptemplate_query.txt"
5457
query = Query("1", "Athrillとは何ですか?", memory_stream, qa)
5558
while True:
5659
question = input("question> ")

tactical_plannig.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,22 +41,25 @@ def _generate_document_question(self, prompt_template_path, document_id, purpose
4141
if __name__ == "__main__":
4242
from query import Query
4343
from memory_stream import MemoryStream
44+
from params import get_param
45+
param_prompt_template_path = get_param("prompt_templates_path")
46+
4447
plan = Plan()
4548
plan.load_from_json("./test/result/plan.json")
4649
db_dir = "../documents/dbs"
4750
tactical_planning = TacticalPlanning(plan, db_dir)
4851
memory_stream = MemoryStream()
4952

5053
while True:
51-
ret = tactical_planning.generate_question("./prompt_templates/ptemplate_subq_detail.txt")
54+
ret = tactical_planning.generate_question(param_prompt_template_path + "/ptemplate_subq_detail.txt")
5255
if ret == None:
5356
print("END")
5457
break
5558
plan_id = ret[0]
5659
doc_id = ret[1]
5760
question = ret[2]
5861
qa = get_qa(db_dir, doc_id)
59-
prompt_template_path = "./prompt_templates/ptemplate_query.txt"
62+
prompt_template_path = param_prompt_template_path + "/ptemplate_query.txt"
6063
query = Query(doc_id, question, memory_stream, qa)
6164
memory_id = query.run(prompt_template_path, question)
6265
if memory_id < 0:

tools/do_query.bash

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,31 @@ then
77
exit 1
88
fi
99

10+
PRMT_TMP_PATH=`python3 params.py "prompt_templates_path"`
11+
if [ ! -d ${PRMT_TMP_PATH} ]
12+
then
13+
echo "ERROR: can not found path: ${PRMT_TMP_PATH}"
14+
exit 1
15+
fi
16+
DOCUMENT_PATH=`python3 params.py "documents_path"`
17+
if [ ! -d ${DOCUMENT_PATH} ]
18+
then
19+
echo "ERROR: can not found path: ${DOCUMENT_PATH}"
20+
exit 1
21+
fi
22+
DOC_DATA_PATH=`python3 params.py "documents_data_path"`
23+
if [ ! -d ${DOC_DATA_PATH} ]
24+
then
25+
echo "ERROR: can not found path: ${DOC_DATA_PATH}"
26+
exit 1
27+
fi
28+
REF_DATA_PATH=`python3 params.py "reflections_data_path"`
29+
if [ ! -d ${REF_DATA_PATH} ]
30+
then
31+
echo "ERROR: can not found path: ${REF_DATA_PATH}"
32+
exit 1
33+
fi
34+
1035
if [ -d test ]
1136
then
1237
:
@@ -47,8 +72,8 @@ query="`cat ${query_dir}/query.txt`"
4772

4873
DOCUMENT_TOKENS=2048
4974
REFLECTION_TOKENS=2048
50-
USE_BACKGROUND="FALSE"
51-
#USE_BACKGROUND="TRUE"
75+
#USE_BACKGROUND="FALSE"
76+
USE_BACKGROUND="TRUE"
5277
ADD_REFLECTION="TRUE"
5378
#ADD_REFLECTION="FALSE"
5479
export NEW_STARTEGY=
@@ -142,9 +167,9 @@ do
142167
echo "INFO: REFLECTING..."
143168
if [ -f "./test/result/reflection.json" ]
144169
then
145-
python3 reflection.py "$query" ../documents/document.list "./test/result/reflection.json" ${background_file} "./prompt_templates/ptemplate_reflection.txt"
170+
python3 reflection.py "$query" ../documents/document.list "./test/result/reflection.json" ${background_file} "${PRMT_TMP_PATH}/ptemplate_reflection.txt"
146171
else
147-
python3 reflection.py "$query" ../documents/document.list "./test/result/critical_thinking.json" ${background_file} "./prompt_templates/ptemplate_reflection.txt"
172+
python3 reflection.py "$query" ../documents/document.list "./test/result/critical_thinking.json" ${background_file} "${PRMT_TMP_PATH}/ptemplate_reflection.txt"
148173
fi
149174
cp ./test/result/reflection.json ./test/result/reflection_org.json
150175
python3 check_recover_json.py ./test/result/reflection.json
@@ -156,7 +181,7 @@ do
156181

157182
echo "INFO: ADD REFLECTION TERMS..."
158183
#cp ./test/result/reflection.json ./test/result/prev_reflection.json
159-
python3 reflection.py "$query" ../documents/document.list "./test/result/reflection.json" ${background_file} "./prompt_templates/ptemplate_reflection_addterms.txt"
184+
python3 reflection.py "$query" ../documents/document.list "./test/result/reflection.json" ${background_file} "${PRMT_TMP_PATH}/ptemplate_reflection_addterms.txt"
160185
python3 check_recover_json.py ./test/result/reflection.json
161186
if [ $? -ne 0 ]
162187
then

0 commit comments

Comments
 (0)