class DisplayChart:
    """Render the result of an executed SQL query as a Plotly chart.

    The user's question, the SQL text and a description of the result
    DataFrame are placed into a prompt; the LLM answers with Plotly code,
    which is executed, and the ``fig`` object it creates is returned.
    """

    # One fenced Markdown block: an optional language tag that must be
    # terminated by a newline, then the body up to the closing fence (or the
    # end of the text when the reply was cut off before the fence closed).
    _CODE_FENCE_RE = re.compile(
        r"```(?:[ \t]*([\w+-]+)[ \t]*\r?\n)?([\s\S]*?)(?:```|\Z)"
    )
    _PYTHON_TAGS = frozenset({"python", "python3", "py"})

    def __init__(self, question, sql, df_metadata):
        """Store the prompt ingredients.

        Args:
            question: The user's question, or ``None``.
            sql: The SQL query that produced the DataFrame, or ``None``.
            df_metadata: Text describing the DataFrame (the caller passes
                the output of ``df.dtypes``).
        """
        self.question = question
        self.sql = sql
        self.df_metadata = df_metadata

    def llm_model_for_chart(self, message_log):
        """Send ``message_log`` to the configured chat model.

        The provider is read from the ``LLM_PROVIDER`` environment variable;
        only ``"openai"`` is handled here.

        Args:
            message_log: List of chat messages to send.

        Returns:
            Whatever ``ChatOpenAI.invoke`` returns for the messages.

        Raises:
            ValueError: If ``LLM_PROVIDER`` is unset or not ``"openai"``.
                Previously this path failed with an ``UnboundLocalError``.
        """
        provider = os.getenv("LLM_PROVIDER")
        if provider != "openai":
            raise ValueError(
                f"Unsupported LLM_PROVIDER for chart generation: {provider!r} "
                "(only 'openai' is supported)"
            )

        llm = ChatOpenAI(
            model=os.getenv("OPEN_AI_LLM_MODEL", "gpt-4o"),
            api_key=os.getenv("OPEN_AI_KEY"),
        )
        return llm.invoke(message_log)

    def _extract_python_code(self, markdown_string) -> str:
        """Pull the Python source out of an LLM reply.

        Args:
            markdown_string: Either the reply text itself or an object whose
                ``content`` attribute holds it (e.g. a chat message).

        Returns:
            The body of the first fenced block tagged ``python``/``py``;
            otherwise the body of the first non-empty fenced block; otherwise
            the whole reply. The result is always stripped of surrounding
            whitespace so it can be executed without indentation errors.
        """
        text = getattr(markdown_string, "content", markdown_string)
        if not isinstance(text, str):
            text = str(text)

        first_untagged = None
        for match in self._CODE_FENCE_RE.finditer(text):
            tag = match.group(1)
            body = match.group(2).strip()
            if not body:
                continue
            if tag and tag.lower() in self._PYTHON_TAGS:
                return body
            if first_untagged is None:
                first_untagged = body

        if first_untagged is not None:
            return first_untagged
        # No fence at all: the model obeyed "respond with only Python code".
        return text.strip()

    def _sanitize_plotly_code(self, raw_plotly_code):
        """Remove ``fig.show()`` calls so the figure is returned, not displayed."""
        return raw_plotly_code.replace("fig.show()", "")

    def generate_plotly_code(self) -> str:
        """Ask the LLM for Plotly code that charts the query result.

        Returns:
            Python source that is expected to build a figure named ``fig``
            from a DataFrame named ``df``.
        """
        if self.question is not None:
            system_msg = f"The following is a pandas DataFrame that contains the results of the query that answers the question the user asked: '{self.question}'"
        else:
            system_msg = "The following is a pandas DataFrame "

        if self.sql is not None:
            system_msg += (
                f"\n\nThe DataFrame was produced using this query: {self.sql}\n\n"
            )

        system_msg += f"The following is information about the resulting pandas DataFrame 'df': \n{self.df_metadata}"

        message_log = [
            SystemMessage(content=system_msg),
            HumanMessage(
                content="Can you generate the Python plotly code to chart the results of the dataframe? Assume the data is in a pandas dataframe called 'df'. If there is only one value in the dataframe, use an Indicator. Respond with only Python code. Do not answer with any explanations -- just the code."
            ),
        ]

        plotly_code = self.llm_model_for_chart(message_log)

        return self._sanitize_plotly_code(self._extract_python_code(plotly_code))

    def _fallback_figure(self, df):
        """Pick a generic chart from the column dtypes of ``df``."""
        numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
        categorical_cols = df.select_dtypes(
            include=["object", "category"]
        ).columns.tolist()

        if len(numeric_cols) >= 2:
            # Two numeric columns: scatter the first against the second.
            return px.scatter(df, x=numeric_cols[0], y=numeric_cols[1])
        if len(numeric_cols) == 1 and categorical_cols:
            # One numeric and at least one categorical column: bar chart.
            return px.bar(df, x=categorical_cols[0], y=numeric_cols[0])
        if categorical_cols and df[categorical_cols[0]].nunique() < 10:
            # Only categorical data with few distinct values: pie chart.
            return px.pie(df, names=categorical_cols[0])
        # Nothing better applies: plain line plot.
        return px.line(df)

    def get_plotly_figure(
        self, plotly_code: str, df: pd.DataFrame, dark_mode: bool = True
    ) -> "plotly.graph_objs.Figure | None":
        """Execute ``plotly_code`` and return the ``fig`` it defines.

        Args:
            plotly_code: Python source expected to assign a figure to ``fig``.
            df: DataFrame exposed to the code under the name ``df``.
            dark_mode: When true, apply the ``plotly_dark`` template.

        Returns:
            The figure, a heuristic fallback chart when the code raises, or
            ``None`` when the code ran but did not define ``fig``.
        """
        # Use ONE namespace for exec. With separate globals/locals dicts,
        # functions and lambdas inside the generated code resolve free names
        # in globals only, so they could not see ``df`` or any variable the
        # snippet itself assigned, and failed with NameError. Copying the
        # module globals keeps ``px``/``go``/``pd`` visible to the snippet
        # while preventing it from mutating this module.
        namespace = dict(globals())
        namespace.pop("fig", None)
        namespace["df"] = df

        try:
            # SECURITY NOTE(review): this executes LLM-generated code with
            # full interpreter privileges. Run only in a trusted/sandboxed
            # environment.
            exec(plotly_code, namespace)
        except Exception as exc:
            import logging

            logging.getLogger(__name__).warning(
                "Generated plotly code failed; using fallback chart: %s", exc
            )
            fig = self._fallback_figure(df)
        else:
            fig = namespace.get("fig")

        if fig is None:
            return None

        if dark_mode:
            fig.update_layout(template="plotly_dark")

        return fig
diff --git a/interface/lang2sql.py b/interface/lang2sql.py index 6e45e15..6006228 100644 --- a/interface/lang2sql.py +++ b/interface/lang2sql.py @@ -179,17 +179,17 @@ def should_show(_key: str) -> bool: st.dataframe(df.head(10) if len(df) > 10 else df) except Exception as e: st.error(f"쿼리 실행 중 오류 발생: {e}") - if should_show("show_chart") and df is not None: + if should_show("show_chart"): + df = database.run_sql(sql) st.markdown("**쿼리 결과 시각화:**") display_code = DisplayChart( question=res["refined_input"].content, sql=sql, - df_metadata=f"Running df.dtypes gives:\n{df.dtypes}" + df_metadata=f"Running df.dtypes gives:\n{df.dtypes}", ) # plotly_code 변수도 따로 보관할 필요 없이 바로 그려도 됩니다 fig = display_code.get_plotly_figure( - plotly_code=display_code.generate_plotly_code(), - df=df + plotly_code=display_code.generate_plotly_code(), df=df ) st.plotly_chart(fig)