Skip to content

Commit 002548e

Browse files
authored
Merge pull request #100 from #79
Feature/시각화 코드 작성
2 parents 73be0fd + 6312b3e commit 002548e

File tree

3 files changed

+152
-0
lines changed

3 files changed

+152
-0
lines changed

interface/lang2sql.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111

1212
from llm_utils.connect_db import ConnectDB
1313
from llm_utils.graph import builder
14+
from llm_utils.display_chart import DisplayChart
1415
from llm_utils.llm_response_parser import LLMResponseParser
1516

17+
1618
DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리"
1719
SIDEBAR_OPTIONS = {
1820
"show_total_token_usage": "Show Total Token Usage",
@@ -177,6 +179,19 @@ def should_show(_key: str) -> bool:
177179
st.dataframe(df.head(10) if len(df) > 10 else df)
178180
except Exception as e:
179181
st.error(f"쿼리 실행 중 오류 발생: {e}")
182+
if should_show("show_chart"):
183+
df = database.run_sql(sql)
184+
st.markdown("**쿼리 결과 시각화:**")
185+
display_code = DisplayChart(
186+
question=res["refined_input"].content,
187+
sql=sql,
188+
df_metadata=f"Running df.dtypes gives:\n{df.dtypes}",
189+
)
190+
# plotly_code 변수도 따로 보관할 필요 없이 바로 그려도 됩니다
191+
fig = display_code.get_plotly_figure(
192+
plotly_code=display_code.generate_plotly_code(), df=df
193+
)
194+
st.plotly_chart(fig)
180195

181196

182197
db = ConnectDB()

llm_utils/display_chart.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import re
2+
from llm_utils import llm_factory
3+
from dotenv import load_dotenv
4+
from langchain.chains.llm import LLMChain
5+
from langchain_openai import ChatOpenAI
6+
from langchain_core.prompts import PromptTemplate
7+
from langchain_core.messages import HumanMessage, SystemMessage
8+
import pandas as pd
9+
import os
10+
11+
import plotly
12+
import plotly.express as px
13+
import plotly.graph_objects as go
14+
15+
16+
# Load environment variables from the .env file
17+
load_dotenv()
18+
19+
20+
class DisplayChart:
    """Visualize the result of an executed SQL query as a plotly chart.

    The user's question, the SQL text and a description of the result
    DataFrame are placed into a prompt; the LLM answers with plotly code,
    which is executed and whose ``fig`` object is returned.
    """

    # Matches one fenced code block. An optional ``python``/``py`` language
    # tag is dropped only when it is followed by a newline, and the closing
    # fence is optional so a truncated reply still yields its code.
    _CODE_FENCE_RE = re.compile(
        r"```(?:[ \t]*(?:python|py)[ \t]*(?=\r?\n))?([\s\S]*?)(?:```|\Z)",
        re.IGNORECASE,
    )

    def __init__(self, question, sql, df_metadata):
        """Store the prompt ingredients.

        Args:
            question: The user's question, or ``None``.
            sql: The SQL text that produced the DataFrame, or ``None``.
            df_metadata: Text describing the DataFrame, embedded verbatim
                into the system prompt.
        """
        self.question = question
        self.sql = sql
        self.df_metadata = df_metadata

    def llm_model_for_chart(self, message_log):
        """Send ``message_log`` to the configured chat model and return its reply.

        Raises:
            ValueError: If the ``LLM_PROVIDER`` environment variable is not
                ``"openai"``, the only provider handled here.
        """
        provider = os.getenv("LLM_PROVIDER")
        if provider != "openai":
            # Fail with a clear message instead of an UnboundLocalError on
            # the never-assigned ``llm`` variable.
            raise ValueError(
                f"Unsupported LLM_PROVIDER for chart generation: {provider!r}"
            )

        llm = ChatOpenAI(
            model=os.getenv("OPEN_AI_LLM_MODEL", "gpt-4o"),
            api_key=os.getenv("OPEN_AI_KEY"),
        )
        return llm.invoke(message_log)

    def _extract_python_code(self, markdown_string) -> str:
        """Return the Python code contained in an LLM reply.

        Args:
            markdown_string: Either the reply text itself or an object that
                exposes the text as a ``content`` attribute.

        Returns:
            The stripped body of the first non-empty fenced code block, or
            the stripped reply itself when it contains no fenced block.
        """
        text = getattr(markdown_string, "content", markdown_string)
        if not isinstance(text, str):
            text = str(text)

        for match in self._CODE_FENCE_RE.finditer(text):
            # Strip whitespace to avoid indentation errors in generated code.
            code = match.group(1).strip()
            if code:
                return code

        return text.strip()

    def _sanitize_plotly_code(self, raw_plotly_code):
        """Remove ``fig.show()`` so executing the code does not open a viewer."""
        return raw_plotly_code.replace("fig.show()", "")

    def generate_plotly_code(self) -> str:
        """Ask the LLM for plotly code that charts the DataFrame.

        Returns:
            The generated Python source with code fences and ``fig.show()``
            removed.
        """
        if self.question is not None:
            system_msg = (
                "The following is a pandas DataFrame that contains the results "
                "of the query that answers the question the user asked: "
                f"'{self.question}'"
            )
        else:
            system_msg = "The following is a pandas DataFrame "

        if self.sql is not None:
            system_msg += (
                f"\n\nThe DataFrame was produced using this query: {self.sql}\n\n"
            )

        system_msg += (
            "The following is information about the resulting pandas "
            f"DataFrame 'df': \n{self.df_metadata}"
        )

        message_log = [
            SystemMessage(content=system_msg),
            HumanMessage(
                content=(
                    "Can you generate the Python plotly code to chart the "
                    "results of the dataframe? Assume the data is in a pandas "
                    "dataframe called 'df'. If there is only one value in the "
                    "dataframe, use an Indicator. Respond with only Python "
                    "code. Do not answer with any explanations -- just the code."
                )
            ),
        ]

        reply = self.llm_model_for_chart(message_log)
        return self._sanitize_plotly_code(self._extract_python_code(reply))

    def get_plotly_figure(
        self, plotly_code: str, df: pd.DataFrame, dark_mode: bool = True
    ) -> "plotly.graph_objs.Figure | None":
        """Execute ``plotly_code`` and return the figure it creates.

        Args:
            plotly_code: Python source expected to assign a figure to ``fig``.
            df: The DataFrame exposed to the code under the name ``df``.
            dark_mode: When true, apply the ``plotly_dark`` template.

        Returns:
            The figure bound to ``fig`` by the executed code; a heuristic
            fallback chart when execution raises; ``None`` when the code runs
            but never defines ``fig``.
        """
        # SECURITY(review): this executes LLM-generated code with exec().
        # Treat the model output as untrusted; consider sandboxing.
        #
        # A single namespace is used on purpose: with separate globals and
        # locals, names such as ``df`` are not visible inside lambdas,
        # functions or (before Python 3.12) comprehensions of the executed
        # code, which made otherwise valid generated code fail.
        namespace = {"df": df, "px": px, "go": go, "pd": pd, "plotly": plotly}
        try:
            exec(plotly_code, namespace)
            fig = namespace.get("fig")
        except Exception:
            # Best-effort: keep the fallback chart, but record why the
            # generated code failed instead of discarding the error.
            import logging

            logging.getLogger(__name__).warning(
                "Generated plotly code failed; using a fallback chart.",
                exc_info=True,
            )
            fig = self._fallback_figure(df)

        if fig is None:
            return None

        if dark_mode:
            fig.update_layout(template="plotly_dark")

        return fig

    def _fallback_figure(self, df):
        """Pick a simple chart type from the column dtypes of ``df``."""
        numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
        categorical_cols = df.select_dtypes(
            include=["object", "category"]
        ).columns.tolist()

        if len(numeric_cols) >= 2:
            # Use the first two numeric columns for a scatter plot.
            return px.scatter(df, x=numeric_cols[0], y=numeric_cols[1])
        if len(numeric_cols) == 1 and categorical_cols:
            # One numeric and at least one categorical column: bar plot.
            return px.bar(df, x=categorical_cols[0], y=numeric_cols[0])
        if categorical_cols and df[categorical_cols[0]].nunique() < 10:
            # Categorical data with few distinct values: pie chart.
            return px.pie(df, names=categorical_cols[0])
        # Default to a simple line plot when nothing above applies.
        return px.line(df)

requirements.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,9 @@ langchain-google-genai>=2.1.3,<3.0.0
1717
langchain-ollama>=0.3.2,<0.4.0
1818
langchain-huggingface>=0.1.2,<0.2.0
1919
clickhouse_driver
20+
plotly
21+
matplotlib
22+
ipython
23+
kaleido
2024
numpy<2.0
2125
pytest>=8.3.5

0 commit comments

Comments
 (0)