From d7bc9b7526e835339909cb388a5cc5ca9b70aee2 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 14 Dec 2023 23:09:00 -0500 Subject: [PATCH] x --- .../tool_usage/agents/__init__.py | 10 +- .../tool_usage/agents/anthropic_tool_agent.py | 228 ++++++++++++++++++ .../tool_usage/agents/experimental/encoder.py | 189 +++++++++++++++ .../tool_usage/agents/experimental/prompts.py | 2 +- poetry.lock | 13 +- pyproject.toml | 1 + .../agents/test_anthropic_tool_parsing.py | 17 ++ 7 files changed, 457 insertions(+), 3 deletions(-) create mode 100644 langchain_benchmarks/tool_usage/agents/anthropic_tool_agent.py create mode 100644 tests/unit_tests/agents/test_anthropic_tool_parsing.py diff --git a/langchain_benchmarks/tool_usage/agents/__init__.py b/langchain_benchmarks/tool_usage/agents/__init__.py index 4e9f2896..4ad7c7ee 100644 --- a/langchain_benchmarks/tool_usage/agents/__init__.py +++ b/langchain_benchmarks/tool_usage/agents/__init__.py @@ -1,7 +1,15 @@ from langchain_benchmarks.tool_usage.agents.adapters import apply_agent_executor_adapter +from langchain_benchmarks.tool_usage.agents.anthropic_tool_user import ( + AnthropicToolUserFactory, +) from langchain_benchmarks.tool_usage.agents.experimental.factory import ( CustomAgentFactory, ) from langchain_benchmarks.tool_usage.agents.openai_functions import OpenAIAgentFactory -__all__ = ["OpenAIAgentFactory", "apply_agent_executor_adapter", "CustomAgentFactory"] +__all__ = [ + "OpenAIAgentFactory", + "apply_agent_executor_adapter", + "CustomAgentFactory", + "AnthropicToolUserFactory", +] diff --git a/langchain_benchmarks/tool_usage/agents/anthropic_tool_agent.py b/langchain_benchmarks/tool_usage/agents/anthropic_tool_agent.py new file mode 100644 index 00000000..df698c86 --- /dev/null +++ b/langchain_benchmarks/tool_usage/agents/anthropic_tool_agent.py @@ -0,0 +1,228 @@ +""" +Module contains re-implementation of the anthropic tool agent SDK using +langchain primitives. +""" +import re +from typing import Dict, Optional, Union +from typing import List, Sequence, Tuple + +import xmltodict +from langchain.agents import AgentOutputParser +from langchain.prompts.chat import ChatPromptTemplate +from langchain.pydantic_v1 import BaseModel, Field +from langchain.schema.runnable import Runnable +from langchain.tools import StructuredTool +from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish +from langchain_core.exceptions import OutputParserException +from langchain_core.language_models import BaseChatModel, BaseLanguageModel +from langchain_core.messages import AIMessage +from langchain_core.messages import BaseMessage, HumanMessage +from langchain_core.prompts import MessagesPlaceholder +from typing_extensions import NotRequired, TypedDict + +from langchain_benchmarks import RateLimiter +from langchain_benchmarks.rate_limiting import with_rate_limit +from langchain_benchmarks.tool_usage.agents.experimental.encoder import ( + AstPrinter, + FunctionResult, + AnthropicXMLEncoder, +) +from langchain_benchmarks.tool_usage.agents.experimental.prompts import ( + _ANTHROPIC_TOOL_USER_PROMPT, +) +from langchain_benchmarks.tool_usage.agents.experimental.tool_utils import ( + convert_tool_to_function_definition, +) + + +class _ToolInvocationRequest(BaseModel): + """Light-weight pydantic model for validating the raw tool invocation request. + + The purpose of this model, is to make sure that whatever as parsed from + the raw llm output has `tool_name` and potential `arguments` fields, and + nothing else. + """ + + tool_name: str + # OK parameterless tools which do not take arguments + arguments: Optional[Dict] = Field(default_factory=dict) + + +class AnthropicToolParser(AgentOutputParser): + """A generalized parser that makes it easier to parameterize different parsing.""" + + def parse(self, text: str) -> Union[AgentFinish, AgentAction]: + """Parse the output of the agent.""" + wrapping_xml_tag = "function_calls" + open_tag = f"<{wrapping_xml_tag}>" + close_tag = f"" + if open_tag in text: + # This is a hack to make sure that is always present + # in the output if . may be a stop sequence for the + # language model, so depending on implementation + # the stop sequence may be cut off. + # There might be a better way to do this, but this works and + # is simple. + if not self.require_closing_xml_tag: + text += close_tag + + pattern = rf"{open_tag}(?P.*?){close_tag}" + match = re.search(pattern, text, re.DOTALL) + if match: + content = match.group("invocation").strip() + return parse_invocation(content, self.wrapping_xml_tag) + + return AgentFinish( + log=text, + return_values={ + "output": text, + }, + ) + + +def parse_invocation(text: str, tag: str) -> AgentAction: + """Parse the content of the function invocation. + + Args: + text: The text to parse. + tag: The tag that wraps the function invocation request. + + Returns: + An AgentAction that corresponds to the function invocation. + + Raises: + OutputParserException: If the parsing fails. + + This exception is meant to be caught by the agent executor and + handled appropriately to provide feedback to the LLM. + """ + + ai_content = f"<{tag}>{text}\n" + try: + function_calls = xmltodict.parse(ai_content, force_list=("function_calls",)) + except Exception as e: + # Convert this to something controllable by the user. + err_msg = () + + raise OutputParserException( + error=e, + llm_output=ai_content, + observation=err_msg, + send_to_llm=True, + ) + + try: + request = _ToolInvocationRequest.validate(result) + except Exception as e: # Using broad exception since it's not just ValidationError + # Can also raise DictError if result is not a dict. + err_msg = ( + f"ERROR: Please use the format " + f'<{tag}>{{"tool_name": $TOOL_NAME, "arguments": $ARGUMENTS}}\n' + ) + raise OutputParserException( + error=e, + llm_output=ai_content, + send_to_llm=True, + observation=err_msg, + ) + + return AgentActionMessageLog( + message_log=[AIMessage(content=ai_content)], + tool=request.tool_name, + tool_input=request.arguments, + log=f"\nInvoking {request.tool_name}: {request.arguments}\n\t", + ) + + +def format_steps_for_chat( + intermediate_steps: List[Tuple[AgentAction, str]], + ast_printer: AstPrinter, +) -> List[BaseMessage]: + """Format the steps.""" + messages = [] + for action, observation in intermediate_steps: + # Action messages contains the tool invocation request from the LLM + # Now add the result of the tool invocation. + + if action.tool == "_Exception": + messages.append( + AIMessage( + content=action.log, + ) + ) + messages.append( + # Tool input is the error message for the exception + HumanMessage(content=action.tool_input) + ) + else: + messages.extend(action.messages) + function_result: FunctionResult = { + "name": action.tool, + "error": None, + "result": observation, + } + messages.append( + HumanMessage( + content=ast_printer.visit_function_result(function_result), + ) + ) + + return messages + + +# PUBLIC API + + +class AgentInput(TypedDict): + """The input to the agent.""" + + input: str + """The input to the agent.""" + intermediate_steps: List[Tuple[AgentAction, str]] + """The intermediate steps taken by the agent.""" + examples: NotRequired[List[BaseMessage]] + """A list of messages that can be used to form example traces.""" + + +def create_agent( + model: Union[BaseChatModel, BaseLanguageModel], + tools: Sequence[StructuredTool], + parser: AgentOutputParser, + *, + rate_limiter: Optional[RateLimiter] = None, +) -> Runnable[AgentInput, Union[AgentAction, AgentFinish]]: + """Create an agent for a chat model.""" + + function_definitions = [convert_tool_to_function_definition(tool) for tool in tools] + ast_printer_ = AnthropicXMLEncoder() + tool_description = ast_printer_.visit_function_definitions(function_definitions) + + template = ChatPromptTemplate.from_messages( + [ + ("system", _ANTHROPIC_TOOL_USER_PROMPT), + MessagesPlaceholder("examples"), # Can use to add example traces + ("human", "{input}"), + MessagesPlaceholder("history"), + ] + ).partial(tool_description=tool_description) + + # For the time being, hard-coding the fact that we're using a tag. + model = model.bind(stop=[""]) + + if rate_limiter: + # Apply a rate limiter if it was provided + model = with_rate_limit(model, rate_limiter) + + agent = ( + { + "input": lambda x: x["input"], + "history": lambda x: format_steps_for_chat( + x["intermediate_steps"], ast_printer_ + ), + "examples": lambda x: x.get("examples", []), + } + | template + | model + | parser + ) + return agent diff --git a/langchain_benchmarks/tool_usage/agents/experimental/encoder.py b/langchain_benchmarks/tool_usage/agents/experimental/encoder.py index c6799609..ba82df96 100644 --- a/langchain_benchmarks/tool_usage/agents/experimental/encoder.py +++ b/langchain_benchmarks/tool_usage/agents/experimental/encoder.py @@ -74,10 +74,20 @@ def visit_function_definitions( def visit_function_invocation(self, function_invocation: FunctionInvocation) -> str: """Render a function invocation.""" + @abc.abstractmethod + def visit_function_invocations( + self, function_invocations: List[FunctionInvocation] + ) -> str: + """Render a function invocation.""" + @abc.abstractmethod def visit_function_result(self, function_result: FunctionResult) -> str: """Render a function result.""" + @abc.abstractmethod + def visit_function_results(self, function_results: List[FunctionResult]) -> str: + """Render a function result.""" + class AstPrinter(Visitor): """Print the AST.""" @@ -154,6 +164,18 @@ def visit_function_invocation(self, invocation: FunctionInvocation) -> str: ) return "\n".join(lines) + def visit_function_invocations( + self, function_invocations: List[FunctionInvocation] + ) -> str: + """Render a function invocation.""" + strs = [ + self.visit_function_invocation(function_invocation) + for function_invocation in function_invocations + ] + return ( + "\n" + "\n".join(strs) + "\n" + ) + def visit_function_result(self, function_result: FunctionResult) -> str: """Render a function result.""" lines = [ @@ -180,6 +202,158 @@ def visit_function_result(self, function_result: FunctionResult) -> str: return "\n".join(lines) + def visit_function_results(self, function_results: List[FunctionResult]) -> str: + """Render a function result.""" + strs = [ + self.visit_function_result(function_result) + for function_result in function_results + ] + return "\n" + "\n".join(strs) + "\n" + + +class AnthropicXMLEncoder(AstPrinter): + """Adapter for Anthropic tool usage api. + + As described here: https://github.com/anthropics/anthropic-tools/tree/main + """ + + def visit_function_definition(self, function_definition: FunctionDefinition) -> str: + """Render a function. + + Function definition example: + + + get_time_of_day + + get_time_of_day(time_zone: str) -> str - Retrieve the current time of day + + Args: + time_zone: The time zone to get the current time for, + + Returns: + time format + + + + time_zone + str + + + + + """ + parameters_lines = [] + + for parameter in function_definition["parameters"]: + parameters_lines.extend( + [ + "", + f"{parameter['name']}", + f"{parameter['type']}", + f"{parameter['description']}", + "", + ] + ) + lines = [ + "", + f"{function_definition['name']}", + "", + f"{function_definition['description']}", + "", + "", + *parameters_lines, + "", + "", + ] + return "\n".join(lines) + + def visit_function_definitions( + self, function_definitions: List[FunctionDefinition] + ) -> str: + """Render a function.""" + strs = [ + self.visit_function_definition(function_definition) + for function_definition in function_definitions + ] + + lines = [ + "", + *strs, + "", + ] + return "\n".join(lines) + + def visit_function_invocation(self, invocation: FunctionInvocation) -> str: + """Render a function invocation. + + + get_time_of_day + + UTC + + + """ + arguments_as_strings = [ + f"<{argument['name']}>{argument['value']}" + for argument in invocation["arguments"] + ] + lines = [ + "", + f"{invocation['name']}", + "", + *arguments_as_strings, + "", + "", + ] + return "\n".join(lines) + + def visit_function_invocations(self, invocations: List[FunctionInvocation]) -> str: + """Render a function invocation.""" + strs = [ + self.visit_function_invocation(invocation) for invocation in invocations + ] + + lines = [ + "", + *strs, + "", + ] + return "\n".join(lines) + + def visit_function_result(self, function_result: FunctionResult) -> str: + """Render a function result. + + + + get_time_of_day + + 02:57:27 + + + + """ + lines = [ + "", + f"{function_result['name']}", + f"{function_result['result']}", + "", + ] + return "\n".join(lines) + + def visit_function_results(self, function_results: List[FunctionResult]) -> str: + """Render a function result.""" + strs = [ + self.visit_function_result(function_result) + for function_result in function_results + ] + + lines = [ + "", + *strs, + "", + ] + return "\n".join(lines) + class TypeScriptEncoder(AstPrinter): def visit_function_definition(self, function_definition: FunctionDefinition) -> str: @@ -238,3 +412,18 @@ def visit_function_result(self, function_result: FunctionResult) -> str: if function_result.get("id"): lines.append(f"// ID: {function_result['id']}") return "\n".join(lines) + + def visit_function_results(self, function_results: List[FunctionResult]) -> str: + """Render a function result.""" + strs = [ + self.visit_function_result(function_result) + for function_result in function_results + ] + return "\n".join(strs) + + def visit_function_invocations(self, invocations: List[FunctionInvocation]) -> str: + """Render a function invocation.""" + strs = [ + self.visit_function_invocation(invocation) for invocation in invocations + ] + return "\n".join(strs) diff --git a/langchain_benchmarks/tool_usage/agents/experimental/prompts.py b/langchain_benchmarks/tool_usage/agents/experimental/prompts.py index 9abc051e..6d7b67c9 100644 --- a/langchain_benchmarks/tool_usage/agents/experimental/prompts.py +++ b/langchain_benchmarks/tool_usage/agents/experimental/prompts.py @@ -1,4 +1,4 @@ -AGENT_INSTRUCTIONS_XML_FORMAT = """\ +_ANTHROPIC_TOOL_USER_PROMPT = """\ In this environment you have access to a set of tools you can use to answer the user's question. You may call them like this: diff --git a/poetry.lock b/poetry.lock index 2ec957bd..e4db049b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3860,6 +3860,17 @@ files = [ {file = "widgetsnbextension-4.0.9.tar.gz", hash = "sha256:3c1f5e46dc1166dfd40a42d685e6a51396fd34ff878742a3e47c6f0cc4a2a385"}, ] +[[package]] +name = "xmltodict" +version = "0.13.0" +description = "Makes working with XML feel like you are working with JSON" +optional = false +python-versions = ">=3.4" +files = [ + {file = "xmltodict-0.13.0-py2.py3-none-any.whl", hash = "sha256:aa89e8fd76320154a40d19a0df04a4695fb9dc5ba977cbb68ab3e4eb225e7852"}, + {file = "xmltodict-0.13.0.tar.gz", hash = "sha256:341595a488e3e01a85a9d8911d8912fd922ede5fecc4dce437eb4b6c8d037e56"}, +] + [[package]] name = "y-py" version = "0.6.2" @@ -4083,4 +4094,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.8.1" -content-hash = "91171e1e590780b3d7df5efcf5eaddddabbe2715294add5ccf14f52cd3fa3b6d" +content-hash = "f01a0553fe50c69a84eb318ce208dcbc61edd208286c22ab294aaac242b508dc" diff --git a/pyproject.toml b/pyproject.toml index f62f52d3..ff1211c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ langsmith = ">=0.0.70" tqdm = "^4" ipywidgets = "^8" tabulate = ">=0.8.0" +xmltodict = "^0.13.0" [tool.poetry.group.dev.dependencies] jupyterlab = "^3.6.1" diff --git a/tests/unit_tests/agents/test_anthropic_tool_parsing.py b/tests/unit_tests/agents/test_anthropic_tool_parsing.py new file mode 100644 index 00000000..f53e66bd --- /dev/null +++ b/tests/unit_tests/agents/test_anthropic_tool_parsing.py @@ -0,0 +1,17 @@ +from langchain_benchmarks.tool_usage.agents.anthropic_tool_agent import parse_invocation +from xmltodict import parse + + +def test_parse_invocation() -> None: + """Test parsing a tool invocation.""" + invocation = parse_invocation( + """ + + get_time_of_day + + UTC + + + """, + "function_calls" + )