Skip to content

Feature/시각화 코드 작성 #100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions interface/lang2sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@

from llm_utils.connect_db import ConnectDB
from llm_utils.graph import builder
from llm_utils.display_chart import DisplayChart
from llm_utils.llm_response_parser import LLMResponseParser


DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리"
SIDEBAR_OPTIONS = {
"show_total_token_usage": "Show Total Token Usage",
Expand Down Expand Up @@ -177,6 +179,19 @@ def should_show(_key: str) -> bool:
st.dataframe(df.head(10) if len(df) > 10 else df)
except Exception as e:
st.error(f"쿼리 실행 중 오류 발생: {e}")
if should_show("show_chart"):
df = database.run_sql(sql)
st.markdown("**쿼리 결과 시각화:**")
display_code = DisplayChart(
question=res["refined_input"].content,
sql=sql,
df_metadata=f"Running df.dtypes gives:\n{df.dtypes}",
)
# plotly_code 변수도 따로 보관할 필요 없이 바로 그려도 됩니다
fig = display_code.get_plotly_figure(
plotly_code=display_code.generate_plotly_code(), df=df
)
st.plotly_chart(fig)


db = ConnectDB()
Expand Down
133 changes: 133 additions & 0 deletions llm_utils/display_chart.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import re
from llm_utils import llm_factory
from dotenv import load_dotenv
from langchain.chains.llm import LLMChain
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from langchain_core.messages import HumanMessage, SystemMessage
import pandas as pd
import os

import plotly
import plotly.express as px
import plotly.graph_objects as go


# Load settings from the .env file at import time, before the class below
# reads them through os.getenv().
load_dotenv()


class DisplayChart:
"""
SQL쿼리가 실행된 결과를 그래프로 시각화하는 Class입니다.

쿼리 결과를 비롯한 유저 질문, sql를 prompt에 입력하여
plotly코드를 출력하여 excute한 결과를 fig 객체로 반환합니다.
"""

def __init__(self, question, sql, df_metadata):
self.question = question
self.sql = sql
self.df_metadata = df_metadata

def llm_model_for_chart(self, message_log):
provider = os.getenv("LLM_PROVIDER")
if provider == "openai":
llm = ChatOpenAI(
model=os.getenv("OPEN_AI_LLM_MODEL", "gpt-4o"),
api_key=os.getenv("OPEN_AI_KEY"),
)
result = llm.invoke(message_log)
return result

def _extract_python_code(self, markdown_string: str) -> str:
# Strip whitespace to avoid indentation errors in LLM-generated code
markdown_string = markdown_string.content.split("```")[1][6:].strip()

# Regex pattern to match Python code blocks
pattern = r"```[\w\s]*python\n([\s\S]*?)```|```([\s\S]*?)```" # 여러 문자와 공백 뒤에 python이 나오고, 줄바꿈 이후의 모든 내용

# Find all matches in the markdown string
matches = re.findall(
pattern, markdown_string, re.IGNORECASE
) # 대소문자 구분 안함

# Extract the Python code from the matches
python_code = []
for match in matches:
python = match[0] if match[0] else match[1]
python_code.append(python.strip())

if len(python_code) == 0:
return markdown_string

return python_code[0]

def _sanitize_plotly_code(self, raw_plotly_code):
# Remove the fig.show() statement from the plotly code
plotly_code = raw_plotly_code.replace("fig.show()", "")

return plotly_code

def generate_plotly_code(self) -> str:
if self.question is not None:
system_msg = f"The following is a pandas DataFrame that contains the results of the query that answers the question the user asked: '{self.question}'"
else:
system_msg = "The following is a pandas DataFrame "

if self.sql is not None:
system_msg += (
f"\n\nThe DataFrame was produced using this query: {self.sql}\n\n"
)

system_msg += f"The following is information about the resulting pandas DataFrame 'df': \n{self.df_metadata}"

message_log = [
SystemMessage(content=system_msg),
HumanMessage(
content="Can you generate the Python plotly code to chart the results of the dataframe? Assume the data is in a pandas dataframe called 'df'. If there is only one value in the dataframe, use an Indicator. Respond with only Python code. Do not answer with any explanations -- just the code."
),
]

plotly_code = self.llm_model_for_chart(message_log)

return self._sanitize_plotly_code(self._extract_python_code(plotly_code))

def get_plotly_figure(
self, plotly_code: str, df: pd.DataFrame, dark_mode: bool = True
) -> plotly.graph_objs.Figure:

ldict = {"df": df, "px": px, "go": go}
try:
exec(plotly_code, globals(), ldict)
fig = ldict.get("fig", None)

except Exception as e:

# Inspect data types
numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
categorical_cols = df.select_dtypes(
include=["object", "category"]
).columns.tolist()

# Decision-making for plot type
if len(numeric_cols) >= 2:
# Use the first two numeric columns for a scatter plot
fig = px.scatter(df, x=numeric_cols[0], y=numeric_cols[1])
elif len(numeric_cols) == 1 and len(categorical_cols) >= 1:
# Use a bar plot if there's one numeric and one categorical column
fig = px.bar(df, x=categorical_cols[0], y=numeric_cols[0])
elif len(categorical_cols) >= 1 and df[categorical_cols[0]].nunique() < 10:
# Use a pie chart for categorical data with fewer unique values
fig = px.pie(df, names=categorical_cols[0])
else:
# Default to a simple line plot if above conditions are not met
fig = px.line(df)

if fig is None:
return None

if dark_mode:
fig.update_layout(template="plotly_dark")

return fig
4 changes: 4 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,9 @@ langchain-google-genai>=2.1.3,<3.0.0
langchain-ollama>=0.3.2,<0.4.0
langchain-huggingface>=0.1.2,<0.2.0
clickhouse_driver
plotly
matplotlib
ipython
kaleido
numpy<2.0
pytest>=8.3.5