diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml
index a73aeca..242b833 100644
--- a/.github/workflows/pre-commit.yml
+++ b/.github/workflows/pre-commit.yml
@@ -1,21 +1,25 @@
-name: Pre-commit Black Check
+name: Pre-commit Auto-fix and Comment
 
 on:
   push:
-    branches:
-      - master
+    branches: [ master ]
   pull_request:
-    branches:
-      - master
+    branches: [ master ]
 
 jobs:
   pre-commit:
-    name: Run Pre-commit Hooks
     runs-on: ubuntu-latest
+    permissions:
+      contents: write
+      pull-requests: write
 
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4
+        with:
+          token: ${{ secrets.GITHUB_TOKEN }}
+          fetch-depth: 0
+          ref: ${{ github.head_ref || github.ref_name }}
 
       - name: Set up Python
         uses: actions/setup-python@v4
@@ -23,7 +27,33 @@ jobs:
           python-version: "3.x"
 
       - name: Install pre-commit
-        run: pip install pre-commit
+        run: python -m pip install --upgrade pre-commit
+
+      - name: Run pre-commit (auto-fix)
+        id: precommit_run
+        continue-on-error: true
+        run: |
+          pre-commit run --all-files --verbose --hook-stage manual
 
-      - name: Run pre-commit
-        run: pre-commit run --all-files --verbose
+      - name: Commit & push if changed
+        run: |
+          git config user.name "github-actions[bot]"
+          git config user.email "github-actions[bot]@users.noreply.github.com"
+          if ! git diff --quiet; then
+            git add -u
+            git commit -m "style: apply pre-commit fixes (black, isort)" || true
+            git push
+            echo "FIXES_PUSHED=true" >> $GITHUB_ENV
+          fi
+
+      - name: Comment on PR
+        if: env.FIXES_PUSHED == 'true' && github.event_name == 'pull_request'
+        uses: actions/github-script@v7
+        with:
+          script: |
+            github.rest.issues.createComment({
+              owner: context.repo.owner,
+              repo: context.repo.repo,
+              issue_number: context.issue.number,
+              body: "✅ Pre-commit hooks fixed formatting (black, isort) and the bot pushed the changes."
+            })
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 30c473d..463dc09 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,5 +1,16 @@
 repos:
+  - repo: https://github.com/pycqa/isort
+    rev: 6.0.0
+    hooks:
+      - id: isort
   - repo: https://github.com/psf/black
     rev: 25.1.0
     hooks:
       - id: black
+  - repo: https://github.com/pycqa/flake8
+    rev: 7.1.2
+    hooks:
+      - id: flake8
+        args:
+          - --ignore=E501,E722,E262,F401,F841,W503,E226
+          - --max-line-length=88
\ No newline at end of file
diff --git a/data_utils/datahub_source.py b/data_utils/datahub_source.py
index 9be2c7d..e91e0bc 100644
--- a/data_utils/datahub_source.py
+++ b/data_utils/datahub_source.py
@@ -1,13 +1,18 @@
-from datahub.metadata.schema_classes import DatasetPropertiesClass, SchemaMetadataClass
-from datahub.emitter.rest_emitter import DatahubRestEmitter
-from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph
-from datahub.metadata.schema_classes import UpstreamLineageClass
 from collections import defaultdict
+
 import requests
+from datahub.emitter.rest_emitter import DatahubRestEmitter
+from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph
+from datahub.metadata.schema_classes import (
+    DatasetPropertiesClass,
+    SchemaMetadataClass,
+    UpstreamLineageClass,
+)
+
 from data_utils.queries import (
-    ROOT_GLOSSARY_NODES_QUERY,
     GLOSSARY_NODE_QUERY,
     LIST_QUERIES_QUERY,
+    ROOT_GLOSSARY_NODES_QUERY,
 )
 
 
diff --git a/evaluation/gen_answer.py b/evaluation/gen_answer.py
index 65feb91..9848bb6 100644
--- a/evaluation/gen_answer.py
+++ b/evaluation/gen_answer.py
@@ -1,10 +1,9 @@
+import uuid
 from argparse import ArgumentParser
 
-from langchain_core.messages import HumanMessage
-
-from utils import load_question_json, save_answer_json
+from langchain_core.messages import HumanMessage
 from tqdm import tqdm
-import uuid
+from utils import load_question_json, save_answer_json
 
 from llm_utils.graph import builder
 
diff --git a/evaluation/gen_persona.py b/evaluation/gen_persona.py
index 4884d63..52ad642 100644
--- a/evaluation/gen_persona.py
+++ b/evaluation/gen_persona.py
@@ -1,12 +1,12 @@
 import os
+from argparse import ArgumentParser
 
-from utils import save_persona_json, pretty_print_persona
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_openai.chat_models import ChatOpenAI
 from persona_class import PersonaList
+from utils import pretty_print_persona, save_persona_json
 
 from llm_utils.tools import _get_table_info
-from langchain_openai.chat_models import ChatOpenAI
-from langchain_core.prompts import ChatPromptTemplate
-from argparse import ArgumentParser
 
 
 def get_table_des_string(tables_desc):
diff --git a/evaluation/gen_question.py b/evaluation/gen_question.py
index 25f9c66..202810f 100644
--- a/evaluation/gen_question.py
+++ b/evaluation/gen_question.py
@@ -1,10 +1,10 @@
-from utils import load_persona_json, save_question_json
+import os
+from argparse import ArgumentParser
+
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_openai.chat_models import ChatOpenAI
 from tqdm import tqdm
-
-from argparse import ArgumentParser
-import os
+from utils import load_persona_json, save_question_json
 
 
 def get_persona_prompt(persona):
diff --git a/evaluation/persona_class.py b/evaluation/persona_class.py
index d276558..6b34628 100644
--- a/evaluation/persona_class.py
+++ b/evaluation/persona_class.py
@@ -1,6 +1,7 @@
-from pydantic import BaseModel
 from typing import List
 
+from pydantic import BaseModel
+
 
 class Persona(BaseModel):
     name: str
diff --git a/evaluation/utils.py b/evaluation/utils.py
index f569d7e..9b448d9 100644
--- a/evaluation/utils.py
+++ b/evaluation/utils.py
@@ -1,7 +1,8 @@
 import json
-from persona_class import PersonaList
-from glob import glob
 import os
+from glob import glob
+
+from persona_class import PersonaList
 
 
 def save_persona_json(data, filepath):
diff --git a/interface/viz_eval.py b/interface/viz_eval.py
index 2989ac8..01556ce 100644
--- a/interface/viz_eval.py
+++ b/interface/viz_eval.py
@@ -1,9 +1,9 @@
-import streamlit as st
-import json
 import glob
-import pandas as pd
+import json
 import os
 
+import pandas as pd
+import streamlit as st
 
 st.set_page_config(layout="wide", page_title="Lang2SQL 평가 시각화")
 
diff --git a/llm_utils/chains.py b/llm_utils/chains.py
index 81d957e..319cc12 100644
--- a/llm_utils/chains.py
+++ b/llm_utils/chains.py
@@ -1,15 +1,16 @@
 import os
+
+from dotenv import load_dotenv
 from langchain_core.prompts import (
     ChatPromptTemplate,
     MessagesPlaceholder,
     SystemMessagePromptTemplate,
 )
-from .llm_factory import get_llm
-
-from dotenv import load_dotenv
 
 from prompt.template_loader import get_prompt_template
 
+from .llm_factory import get_llm
+
 
 env_path = os.path.join(os.getcwd(), ".env")
 if os.path.exists(env_path):
diff --git a/llm_utils/connect_db.py b/llm_utils/connect_db.py
index aa2c099..afea7c9 100644
--- a/llm_utils/connect_db.py
+++ b/llm_utils/connect_db.py
@@ -1,5 +1,6 @@
 import os
 from typing import Union
+
 import pandas as pd
 from clickhouse_driver import Client
 from dotenv import load_dotenv
diff --git a/llm_utils/graph.py b/llm_utils/graph.py
index a6f5137..cad94f0 100644
--- a/llm_utils/graph.py
+++ b/llm_utils/graph.py
@@ -1,20 +1,17 @@
-import os
 import json
+import os
 
-from typing_extensions import TypedDict, Annotated
+from langchain.chains.sql_database.prompt import SQL_PROMPTS
 from langgraph.graph import END, StateGraph
 from langgraph.graph.message import add_messages
-from langchain.chains.sql_database.prompt import SQL_PROMPTS
 from pydantic import BaseModel, Field
-from .llm_factory import get_llm
-
-from llm_utils.chains import (
-    query_refiner_chain,
-    query_maker_chain,
-)
+from typing_extensions import Annotated, TypedDict
 
+from llm_utils.chains import query_maker_chain, query_refiner_chain
 from llm_utils.tools import get_info_from_db
 
+from .llm_factory import get_llm
+
 # 노드 식별자 정의
 QUERY_REFINER = "query_refiner"
 GET_TABLE_INFO = "get_table_info"
diff --git a/llm_utils/llm_factory.py b/llm_utils/llm_factory.py
index bdb4d64..d0117e3 100644
--- a/llm_utils/llm_factory.py
+++ b/llm_utils/llm_factory.py
@@ -4,7 +4,8 @@
 from dotenv import load_dotenv
 
 from langchain.llms.base import BaseLanguageModel
-from langchain_aws import ChatBedrockConverse, BedrockEmbeddings
+from langchain_aws import BedrockEmbeddings, ChatBedrockConverse
+from langchain_community.llms.bedrock import Bedrock
 from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
 from langchain_huggingface import (
     ChatHuggingFace,
@@ -13,12 +14,11 @@
 )
 from langchain_ollama import ChatOllama, OllamaEmbeddings
 from langchain_openai import (
+    AzureChatOpenAI,
     AzureOpenAIEmbeddings,
     ChatOpenAI,
-    AzureChatOpenAI,
     OpenAIEmbeddings,
 )
-from langchain_community.llms.bedrock import Bedrock
 
 # .env 파일 로딩
 load_dotenv()
diff --git a/llm_utils/prompts_class.py b/llm_utils/prompts_class.py
index ceeadbd..dc57647 100644
--- a/llm_utils/prompts_class.py
+++ b/llm_utils/prompts_class.py
@@ -1,6 +1,6 @@
-from langchain.chains.sql_database.prompt import SQL_PROMPTS
 import os
 
+from langchain.chains.sql_database.prompt import SQL_PROMPTS
 from langchain_core.prompts import load_prompt
 
 
diff --git a/llm_utils/tools.py b/llm_utils/tools.py
index 31c2a09..5ccf077 100644
--- a/llm_utils/tools.py
+++ b/llm_utils/tools.py
@@ -1,11 +1,11 @@
 import os
-from typing import List, Dict, Optional, TypeVar, Callable, Iterable, Any
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Callable, Dict, Iterable, List, Optional, TypeVar
 
 from langchain.schema import Document
+from tqdm import tqdm
 
 from data_utils.datahub_source import DatahubMetadataFetcher
-from tqdm import tqdm
-from concurrent.futures import ThreadPoolExecutor
 
 T = TypeVar("T")
 R = TypeVar("R")
diff --git a/setup.py b/setup.py
index 71a31ac..9851412 100644
--- a/setup.py
+++ b/setup.py
@@ -1,5 +1,5 @@
 # setup.py
-from setuptools import setup, find_packages
+from setuptools import find_packages, setup
 
 with open("docs/README.md", "r", encoding="utf-8") as fh:
     long_description = fh.read()