diff --git a/src/unstract/sdk/__init__.py b/src/unstract/sdk/__init__.py index 19fbe1db..396122e4 100644 --- a/src/unstract/sdk/__init__.py +++ b/src/unstract/sdk/__init__.py @@ -1,4 +1,4 @@ -__version__ = "v0.70.1" +__version__ = "v0.71.0" def get_sdk_version() -> str: diff --git a/src/unstract/sdk/constants.py b/src/unstract/sdk/constants.py index eb91d08a..3b1df597 100644 --- a/src/unstract/sdk/constants.py +++ b/src/unstract/sdk/constants.py @@ -57,8 +57,7 @@ class LogStage: class LogState: - """State of logs INPUT_UPDATE tag for update the FE input component - OUTPUT_UPDATE tag for update the FE output component.""" + """Tags to update corresponding FE component.""" INPUT_UPDATE = "INPUT_UPDATE" OUTPUT_UPDATE = "OUTPUT_UPDATE" @@ -175,3 +174,10 @@ class UsageKwargs: FILE_NAME = "file_name" WORKFLOW_ID = "workflow_id" EXECUTION_ID = "execution_id" + + +class RequestHeader: + """Keys used in request headers.""" + + REQUEST_ID = "X-Request-ID" + AUTHORIZATION = "Authorization" diff --git a/src/unstract/sdk/platform.py b/src/unstract/sdk/platform.py index dd9c8d7f..1fce27cc 100644 --- a/src/unstract/sdk/platform.py +++ b/src/unstract/sdk/platform.py @@ -1,7 +1,13 @@ from typing import Any import requests -from unstract.sdk.constants import LogLevel, ToolEnv +from requests import ConnectionError, RequestException, Response +from unstract.sdk.constants import ( + MimeType, + PromptStudioKeys, + RequestHeader, + ToolEnv, +) from unstract.sdk.helper import SdkHelper from unstract.sdk.tool.base import BaseTool @@ -18,8 +24,11 @@ def __init__( tool: BaseTool, platform_host: str, platform_port: str, + request_id: str | None = None, ) -> None: - """Args: + """Constructor for base class to connect to platform service. + + Args: tool (AbstractTool): Instance of AbstractTool platform_host (str): Host of platform service platform_port (str): Port of platform service @@ -30,28 +39,115 @@ def __init__( self.tool = tool self.base_url = SdkHelper.get_platform_base_url(platform_host, platform_port) self.bearer_token = tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY) + self.request_id = request_id class PlatformHelper(PlatformBase): - """Implementation of `UnstractPlatformBase` to interact with platform - service. + """Implementation of `PlatformBase`. Notes: - PLATFORM_SERVICE_API_KEY environment variable is required. """ - def __init__(self, tool: BaseTool, platform_host: str, platform_port: str): - """Constructor of the implementation of `UnstractPlatformBase` + def __init__( + self, + tool: BaseTool, + platform_host: str, + platform_port: str, + request_id: str | None = None, + ) -> None: + """Constructor for helper to connect to platform service. Args: tool (AbstractTool): Instance of AbstractTool platform_host (str): Host of platform service platform_port (str): Port of platform service + request_id (str | None, optional): Request ID for the service. + Defaults to None. """ super().__init__( - tool=tool, platform_host=platform_host, platform_port=platform_port + tool=tool, + platform_host=platform_host, + platform_port=platform_port, + request_id=request_id, ) + def _get_headers(self, headers: dict[str, str] | None = None) -> dict[str, str]: + """Get default headers for requests. + + Returns: + dict[str, str]: Default headers including request ID and authorization + """ + request_headers = { + RequestHeader.REQUEST_ID: self.request_id, + RequestHeader.AUTHORIZATION: f"Bearer {self.bearer_token}", + } + if headers: + request_headers.update(headers) + return request_headers + + def _call_service( + self, + url_path: str, + payload: dict[str, Any] | None = None, + params: dict[str, str] | None = None, + headers: dict[str, str] | None = None, + method: str = "GET", + ) -> dict[str, Any]: + """Talks to platform-service to make GET / POST calls. + + Only GET calls are made to platform-service though functionality exists. + + Args: + url_path (str): URL path to the service endpoint + payload (dict, optional): Payload to send in the request body + params (dict, optional): Query parameters to include in the request + headers (dict, optional): Headers to include in the request + method (str): HTTP method to use for the request (GET or POST) + + Returns: + dict: Response from the platform service + + Sample Response: + { + "status": "OK", + "error": "", + structure_output : {} + } + """ + url: str = f"{self.base_url}/{url_path}" + req_headers = self._get_headers(headers) + response: Response = Response() + try: + if method.upper() == "POST": + response = requests.post( + url=url, json=payload, params=params, headers=req_headers + ) + elif method.upper() == "GET": + response = requests.get(url=url, params=params, headers=req_headers) + else: + raise ValueError(f"Unsupported HTTP method: {method}") + + response.raise_for_status() + except ConnectionError as connect_err: + msg = "Unable to connect to platform service. Please contact admin." + msg += " \n" + str(connect_err) + self.tool.stream_error_and_exit(msg) + except RequestException as e: + # Extract error information from the response if available + error_message = str(e) + content_type = response.headers.get("Content-Type", "").lower() + if MimeType.JSON in content_type: + response_json = response.json() + if "error" in response_json: + error_message = response_json["error"] + elif response.text: + error_message = response.text + self.tool.stream_error_and_exit( + f"Error from platform service. {error_message}" + ) + return response.json() + def get_platform_details(self) -> dict[str, Any] | None: """Obtains platform details associated with the platform key. @@ -60,18 +156,30 @@ def get_platform_details(self) -> dict[str, Any] | None: Returns: Optional[dict[str, Any]]: Dictionary containing the platform details """ - url = f"{self.base_url}/platform_details" - headers = {"Authorization": f"Bearer {self.bearer_token}"} - response = requests.get(url, headers=headers) - if response.status_code != 200: - self.tool.stream_log( - ( - "Error while retrieving platform details: " - f"[{response.status_code}] {response.reason}" - ), - level=LogLevel.ERROR, - ) - return None - else: - platform_details: dict[str, Any] = response.json().get("details") - return platform_details + response = self._call_service( + url_path="platform_details", + payload=None, + params=None, + headers=None, + method="GET", + ) + return response.get("details") + + def get_exported_tool(self, prompt_registry_id: str) -> dict[str, Any]: + """Get exported custom tool by the help of unstract DB tool. + + Args: + prompt_registry_id (str): ID of the prompt_registry_id + tool (AbstractTool): Instance of AbstractTool + Required env variables: + PLATFORM_HOST: Host of platform service + PLATFORM_PORT: Port of platform service + """ + query_params = {PromptStudioKeys.PROMPT_REGISTRY_ID: prompt_registry_id} + return self._call_service( + url_path="custom_tool_instance", + payload=None, + params=query_params, + headers=None, + method="GET", + ) diff --git a/src/unstract/sdk/prompt.py b/src/unstract/sdk/prompt.py index 4d94ae93..c21d4184 100644 --- a/src/unstract/sdk/prompt.py +++ b/src/unstract/sdk/prompt.py @@ -1,15 +1,63 @@ +import functools import logging -from typing import Any +from collections.abc import Callable +from typing import Any, ParamSpec, TypeVar import requests +from deprecated import deprecated from requests import ConnectionError, RequestException, Response -from unstract.sdk.constants import LogLevel, MimeType, PromptStudioKeys, ToolEnv +from unstract.sdk.constants import ( + MimeType, + RequestHeader, + ToolEnv, +) from unstract.sdk.helper import SdkHelper +from unstract.sdk.platform import PlatformHelper from unstract.sdk.tool.base import BaseTool from unstract.sdk.utils.common_utils import log_elapsed logger = logging.getLogger(__name__) +P = ParamSpec("P") +R = TypeVar("R") + + +def handle_service_exceptions(context: str) -> Callable[[Callable[P, R]], Callable[P, R]]: + """Decorator to handle exceptions in PromptTool service calls. + + Args: + context (str): Context string describing where the error occurred + Returns: + Callable: Decorated function that handles service exceptions + """ + + def decorator(func: Callable[P, R]) -> Callable[P, R]: + @functools.wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + try: + return func(*args, **kwargs) + except ConnectionError as e: + msg = f"Error while {context}. Unable to connect to prompt service." + logger.error(f"{msg}\n{e}") + args[0].tool.stream_error_and_exit(msg, e) + except RequestException as e: + error_message = str(e) + response = getattr(e, "response", None) + if response is not None: + if ( + MimeType.JSON in response.headers.get("Content-Type", "").lower() + and "error" in response.json() + ): + error_message = response.json()["error"] + elif response.text: + error_message = response.text + msg = f"Error while {context}. {error_message}" + args[0].tool.stream_error_and_exit(msg, e) + + return wrapper + + return decorator + class PromptTool: """Class to handle prompt service methods for Unstract Tools.""" @@ -20,6 +68,7 @@ def __init__( prompt_host: str, prompt_port: str, is_public_call: bool = False, + request_id: str | None = None, ) -> None: """Class to interact with prompt-service. @@ -32,10 +81,12 @@ def __init__( self.tool = tool self.base_url = SdkHelper.get_platform_base_url(prompt_host, prompt_port) self.is_public_call = is_public_call + self.request_id = request_id if not is_public_call: self.bearer_token = tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY) @log_elapsed(operation="ANSWER_PROMPTS") + @handle_service_exceptions("answering prompt(s)") def answer_prompt( self, payload: dict[str, Any], @@ -45,28 +96,31 @@ def answer_prompt( url_path = "answer-prompt" if self.is_public_call: url_path = "answer-prompt-public" - return self._post_call( + return self._call_service( url_path=url_path, payload=payload, params=params, headers=headers ) @log_elapsed(operation="INDEX") + @handle_service_exceptions("indexing") def index( self, payload: dict[str, Any], params: dict[str, str] | None = None, headers: dict[str, str] | None = None, - ) -> dict[str, Any]: + ) -> str: url_path = "index" if self.is_public_call: url_path = "index-public" - return self._post_call( + prompt_service_response = self._call_service( url_path=url_path, payload=payload, params=params, headers=headers, ) + return prompt_service_response.get("doc_id") @log_elapsed(operation="EXTRACT") + @handle_service_exceptions("extracting") def extract( self, payload: dict[str, Any], @@ -76,120 +130,100 @@ def extract( url_path = "extract" if self.is_public_call: url_path = "extract-public" - return self._post_call( + prompt_service_response = self._call_service( url_path=url_path, payload=payload, params=params, headers=headers, ) + return prompt_service_response.get("extracted_text") + @log_elapsed(operation="SINGLE_PASS_EXTRACTION") + @handle_service_exceptions("single pass extraction") def single_pass_extraction( self, payload: dict[str, Any], params: dict[str, str] | None = None, headers: dict[str, str] | None = None, ) -> dict[str, Any]: - return self._post_call( + return self._call_service( url_path="single-pass-extraction", payload=payload, params=params, headers=headers, ) + @log_elapsed(operation="SUMMARIZATION") + @handle_service_exceptions("summarizing") def summarize( self, payload: dict[str, Any], params: dict[str, str] | None = None, headers: dict[str, str] | None = None, ) -> dict[str, Any]: - return self._post_call( + return self._call_service( url_path="summarize", payload=payload, params=params, headers=headers, ) - def _post_call( + def _get_headers(self, headers: dict[str, str] | None = None) -> dict[str, str]: + """Get default headers for requests. + + Returns: + dict[str, str]: Default headers including request ID and authorization + """ + request_headers = {RequestHeader.REQUEST_ID: self.request_id} + if self.is_public_call: + return request_headers + request_headers.update( + {RequestHeader.AUTHORIZATION: f"Bearer {self.bearer_token}"} + ) + + if headers: + request_headers.update(headers) + return request_headers + + def _call_service( self, url_path: str, - payload: dict[str, Any], + payload: dict[str, Any] | None = None, params: dict[str, str] | None = None, headers: dict[str, str] | None = None, + method: str = "POST", ) -> dict[str, Any]: """Communicates to prompt service to fetch response for the prompt. + Only POST calls are made to prompt-service though functionality exists. + Args: url_path (str): URL path to the service endpoint - payload (dict): Payload to send in the request body + payload (dict, optional): Payload to send in the request body params (dict, optional): Query parameters to include in the request headers (dict, optional): Headers to include in the request + method (str): HTTP method to use for the request (GET or POST) Returns: dict: Response from the prompt service - - Sample Response: - { - "status": "OK", - "error": "", - "cost": 0, - structure_output : {} - } """ - result: dict[str, Any] = { - "status": "ERROR", - "error": "", - "cost": 0, - "structure_output": "", - "status_code": 500, - } url: str = f"{self.base_url}/{url_path}" - - default_headers = {} - - if not self.is_public_call: - default_headers = {"Authorization": f"Bearer {self.bearer_token}"} - - if headers: - default_headers.update(headers) - + req_headers = self._get_headers(headers) response: Response = Response() - try: + if method.upper() == "POST": response = requests.post( - url=url, json=payload, params=params, headers=default_headers + url=url, json=payload, params=params, headers=req_headers ) - response.raise_for_status() - result["status"] = "OK" - result["structure_output"] = response.text - result["status_code"] = 200 - except ConnectionError as connect_err: - msg = "Unable to connect to prompt service. Please contact admin." - self._stringify_and_stream_err(connect_err, msg) - result["error"] = msg - except RequestException as e: - # Extract error information from the response if available - error_message = str(e) - content_type = response.headers.get("Content-Type", "").lower() - if MimeType.JSON in content_type: - response_json = response.json() - if "error" in response_json: - error_message = response_json["error"] - elif response.text: - error_message = response.text - result["error"] = error_message - result["status_code"] = response.status_code - self.tool.stream_log( - f"Error while fetching response for prompt: {result['error']}", - level=LogLevel.ERROR, - ) - return result + elif method.upper() == "GET": + response = requests.get(url=url, params=params, headers=req_headers) + else: + raise ValueError(f"Unsupported HTTP method: {method}") - def _stringify_and_stream_err(self, err: RequestException, msg: str) -> None: - error_message = str(err) - trace = f"{msg}: {error_message}" - self.tool.stream_log(trace, level=LogLevel.ERROR) - logger.error(trace) + response.raise_for_status() + return response.json() @staticmethod + @deprecated(version="v0.71.0", reason="Use remote FS APIs from SDK") def get_exported_tool( tool: BaseTool, prompt_registry_id: str ) -> dict[str, Any] | None: @@ -202,25 +236,9 @@ def get_exported_tool( PLATFORM_HOST: Host of platform service PLATFORM_PORT: Port of platform service """ - platform_host = tool.get_env_or_die(ToolEnv.PLATFORM_HOST) - platform_port = tool.get_env_or_die(ToolEnv.PLATFORM_PORT) - base_url = SdkHelper.get_platform_base_url(platform_host, platform_port) - bearer_token = tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY) - url = f"{base_url}/custom_tool_instance" - query_params = {PromptStudioKeys.PROMPT_REGISTRY_ID: prompt_registry_id} - headers = {"Authorization": f"Bearer {bearer_token}"} - response = requests.get(url, headers=headers, params=query_params) - if response.status_code == 200: - return response.json() - elif response.status_code == 404: - tool.stream_error_and_exit( - f"Exported tool '{prompt_registry_id}' is not found" - ) - return None - else: - tool.stream_error_and_exit( - f"Error while retrieving tool metadata " - "for the exported tool: " - f"{prompt_registry_id} / {response.reason}" - ) - return None + platform_helper: PlatformHelper = PlatformHelper( + tool=tool, + platform_port=tool.get_env_or_die(ToolEnv.PLATFORM_PORT), + platform_host=tool.get_env_or_die(ToolEnv.PLATFORM_HOST), + ) + return platform_helper.get_exported_tool(prompt_registry_id=prompt_registry_id) diff --git a/src/unstract/sdk/static/tool_template/v1/Dockerfile b/src/unstract/sdk/static/tool_template/v1/Dockerfile index de8b2a49..2b67d9ac 100644 --- a/src/unstract/sdk/static/tool_template/v1/Dockerfile +++ b/src/unstract/sdk/static/tool_template/v1/Dockerfile @@ -1,7 +1,7 @@ FROM python:3.9-slim LABEL maintainer="Zipstack Inc." -ENV UNSTRACT_ENTRYPOINT "python /app/src/main.py" + # Set the working directory in the container WORKDIR /app diff --git a/src/unstract/sdk/tool/executor.py b/src/unstract/sdk/tool/executor.py index 24ccfa50..c791a157 100644 --- a/src/unstract/sdk/tool/executor.py +++ b/src/unstract/sdk/tool/executor.py @@ -17,6 +17,11 @@ class ToolExecutor: """Takes care of executing a tool's intended command.""" def __init__(self, tool: BaseTool) -> None: + """Constructor for executor. + + Args: + tool (AbstractTool): Instance of AbstractTool + """ self.tool = tool def execute(self, args: argparse.Namespace) -> None: @@ -62,7 +67,7 @@ def execute_run(self, args: argparse.Namespace) -> None: settings = validator.validate_pre_execution(settings=settings) self.tool.stream_log( - f"Executing for file: {self.tool.get_exec_metadata['source_name']}, " + f"Executing for file: '{self.tool.get_exec_metadata['source_name']}', " f"with tool settings: {settings}" ) diff --git a/src/unstract/sdk/tool/stream.py b/src/unstract/sdk/tool/stream.py index 4fb2c531..0409c8f1 100644 --- a/src/unstract/sdk/tool/stream.py +++ b/src/unstract/sdk/tool/stream.py @@ -8,7 +8,6 @@ from unstract.sdk.constants import Command, LogLevel, LogStage, ToolEnv from unstract.sdk.exceptions import SdkError from unstract.sdk.utils import Utils -from unstract.sdk.utils.common_utils import UNSTRACT_TO_PY_LOG_LEVEL class StreamMixin: @@ -53,7 +52,8 @@ def _configure_logger(self) -> None: if rootlogger.hasHandlers(): return handler = logging.StreamHandler() - handler.setLevel(level=UNSTRACT_TO_PY_LOG_LEVEL[self.log_level]) + py_log_level = getattr(logging, self.log_level.value, logging.INFO) + handler.setLevel(level=py_log_level) # Determine if OpenTelemetry trace context should be included in logs otel_trace_context = ( @@ -70,7 +70,6 @@ def _configure_logger(self) -> None: ) ) rootlogger.addHandler(handler) - rootlogger.setLevel(level=UNSTRACT_TO_PY_LOG_LEVEL[self.log_level]) noisy_lib_list = [ "asyncio", @@ -122,17 +121,18 @@ def stream_log( } print(json.dumps(record)) - def stream_error_and_exit(self, message: str) -> None: + def stream_error_and_exit(self, message: str, err: Exception | None = None) -> None: """Stream error log and exit. Args: message (str): Error message + err (Exception): Actual exception that occurred """ self.stream_log(message, level=LogLevel.ERROR) if self._exec_by_tool: exit(1) else: - raise SdkError(f"SDK Error: {message}") + raise SdkError(message, actual_err=err) def get_env_or_die(self, env_key: str) -> str: """Returns the value of an env variable. diff --git a/src/unstract/sdk/utils/common_utils.py b/src/unstract/sdk/utils/common_utils.py index 364135c9..fceb8cca 100644 --- a/src/unstract/sdk/utils/common_utils.py +++ b/src/unstract/sdk/utils/common_utils.py @@ -60,14 +60,6 @@ class CommonUtils(Utils): logging.ERROR: LogLevel.ERROR, } -# Mapping from Unstract log level to python counterpart -UNSTRACT_TO_PY_LOG_LEVEL = { - LogLevel.DEBUG: logging.DEBUG, - LogLevel.INFO: logging.INFO, - LogLevel.WARN: logging.WARNING, - LogLevel.ERROR: logging.ERROR, -} - def log_elapsed(operation): """Adds an elapsed time log.