Initialize Agent and Memory implementations

This will pave our way to add multiple functions from langchain loading modules
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-04-02 10:50:10 -03:00 committed by GitHub
commit c4faf3f383
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 407 additions and 133 deletions

View file

@ -8,6 +8,7 @@ agents:
- ZeroShotAgent
- JsonAgent
- CSVAgent
- initialize_agent
prompts:
- PromptTemplate
@ -15,7 +16,7 @@ prompts:
llms:
- OpenAI
- OpenAIChat
- ChatOpenAI
tools:
- Search
@ -33,13 +34,16 @@ toolkits:
- OpenAPIToolkit
- JsonToolkit
embeddings:
#
memories:
- ConversationBufferMemory
vectorstores:
#
embeddings: []
vectorstores: []
documentloaders: []
documentloaders:
#
dev: false

View file

@ -1,9 +1,13 @@
from langflow.template import nodes
CUSTOM_NODES = {
"prompts": {**nodes.ZeroShotPromptNode().to_dict()},
"tools": {**nodes.PythonFunctionNode().to_dict(), **nodes.ToolNode().to_dict()},
"agents": {**nodes.JsonAgentNode().to_dict(), **nodes.CSVAgentNode().to_dict()},
"prompts": {"ZeroShotPrompt": nodes.ZeroShotPromptNode()},
"tools": {"PythonFunction": nodes.PythonFunctionNode(), "Tool": nodes.ToolNode()},
"agents": {
"JsonAgent": nodes.JsonAgentNode(),
"CSVAgent": nodes.CSVAgentNode(),
"initialize_agent": nodes.InitializeAgentNode(),
},
}

View file

@ -121,10 +121,10 @@ class Node:
f"Required input {key} for module {self.node_type} not found"
)
elif value["list"]:
if key in params:
if key not in params:
params[key] = []
if edge is not None:
params[key].append(edge.source)
else:
params[key] = [edge.source]
elif value["required"] or edge is not None:
params[key] = edge.source
elif value["required"] or value.get("value"):
@ -179,7 +179,9 @@ class Node:
params=self.params,
)
except Exception as exc:
raise ValueError(f"Error building node {self.node_type}") from exc
raise ValueError(
f"Error building node {self.node_type}: {str(exc)}"
) from exc
if self._built_object is None:
raise ValueError(f"Node type {self.node_type} not found")

View file

@ -106,7 +106,10 @@ class Graph:
if node_type in prompt_creator.to_list():
nodes.append(PromptNode(node))
elif node_type in agent_creator.to_list():
elif (
node_type in agent_creator.to_list()
or node_lc_type in agent_creator.to_list()
):
nodes.append(AgentNode(node))
elif node_type in chain_creator.to_list():
nodes.append(ChainNode(node))
@ -118,7 +121,10 @@ class Graph:
nodes.append(ToolkitNode(node))
elif node_type in wrapper_creator.to_list():
nodes.append(WrapperNode(node))
elif node_type in llm_creator.to_list():
elif (
node_type in llm_creator.to_list()
or node_lc_type in llm_creator.to_list()
):
nodes.append(LLMNode(node))
else:
nodes.append(Node(node))

View file

@ -31,12 +31,18 @@ class AgentCreator(LangChainTypeCreator):
except ValueError as exc:
raise ValueError("Agent not found") from exc
# Build the list of enabled agent names, preferring each agent's
# function_name() override when it defines one.
def to_list(self) -> List[str]:
return [
agent.__name__
for agent in self.type_to_loader_dict.values()
if agent.__name__ in settings.agents or settings.dev
]
names = []
for name, agent in self.type_to_loader_dict.items():
agent_name = (
agent.function_name()
if hasattr(agent, "function_name")
else agent.__name__
)
if agent_name in settings.agents or settings.dev:
names.append(agent_name)
return names
agent_creator = AgentCreator()

View file

@ -1,4 +1,4 @@
from typing import Any, Optional
from typing import Any, List, Optional
from langchain import LLMChain
from langchain.agents import AgentExecutor, ZeroShotAgent
@ -8,12 +8,19 @@ from langchain.agents.agent_toolkits.pandas.prompt import PREFIX as PANDAS_PREFI
from langchain.agents.agent_toolkits.pandas.prompt import SUFFIX as PANDAS_SUFFIX
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.schema import BaseLanguageModel
from langchain.llms.base import BaseLLM
from langchain.tools.python.tool import PythonAstREPLTool
from langchain.agents import initialize_agent, Tool
from langchain.memory.chat_memory import BaseChatMemory
class JsonAgent(AgentExecutor):
"""Json agent"""
@staticmethod
def function_name():
return "JsonAgent"
@classmethod
def initialize(cls, *args, **kwargs):
return cls.from_toolkit_and_llm(*args, **kwargs)
@ -46,6 +53,10 @@ class JsonAgent(AgentExecutor):
class CSVAgent(AgentExecutor):
"""CSV agent"""
@staticmethod
def function_name():
return "CSVAgent"
@classmethod
def initialize(cls, *args, **kwargs):
return cls.from_toolkit_and_llm(*args, **kwargs)
@ -87,7 +98,28 @@ class CSVAgent(AgentExecutor):
return super().run(*args, **kwargs)
class InitializeAgent(AgentExecutor):
    """Wrapper exposing langchain's ``initialize_agent`` factory as an agent node.

    Unlike JsonAgent/CSVAgent, ``initialize`` does not build the executor
    itself — it delegates entirely to langchain's ``initialize_agent``.
    """

    @staticmethod
    def function_name():
        # Name the creator/frontend uses to identify this agent type
        # (matches the "initialize_agent" key in CUSTOM_AGENTS).
        return "initialize_agent"

    @classmethod
    def initialize(
        cls, llm: BaseLLM, tools: List[Tool], agent: str, memory: BaseChatMemory
    ):
        # ``agent`` is the agent-type key understood by langchain
        # (e.g. one of loading.AGENT_TO_CLASS); returns an AgentExecutor.
        return initialize_agent(tools=tools, llm=llm, agent=agent, memory=memory)

    def __init__(self, *args, **kwargs):
        # Pass-through; kept so the class mirrors the other agent wrappers.
        super().__init__(*args, **kwargs)

    def run(self, *args, **kwargs):
        # Pass-through to AgentExecutor.run.
        return super().run(*args, **kwargs)
# Registry mapping the agent type name used by the frontend to the wrapper
# class that knows how to build it (keys match each class's function_name()).
CUSTOM_AGENTS = {
    "JsonAgent": JsonAgent,
    "CSVAgent": CSVAgent,
    "initialize_agent": InitializeAgent,
}

View file

@ -0,0 +1,45 @@
from langchain import LLMChain
from langchain.agents import AgentExecutor, ZeroShotAgent
from langchain.agents.agent_toolkits.json.prompt import JSON_PREFIX, JSON_SUFFIX
from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.schema import BaseLanguageModel
class MalfoyAgent(AgentExecutor):
    """Agent executor with a "Malfoy" persona prefix, built from a JSON toolkit.

    Construction is identical to JsonAgent's (same JSON prompt and
    ZeroShotAgent); only the ``prefix`` attribute differs. The original
    docstring said "Json agent" — a copy-paste leftover, fixed here.
    """

    # Persona prefix; not referenced anywhere in this class — presumably
    # consumed by the frontend/prompt layer (TODO confirm).
    prefix = "Malfoy: "

    @classmethod
    def initialize(cls, *args, **kwargs):
        """Alias for from_toolkit_and_llm, kept for a uniform creation API."""
        return cls.from_toolkit_and_llm(*args, **kwargs)

    @classmethod
    def from_toolkit_and_llm(cls, toolkit: JsonToolkit, llm: BaseLanguageModel):
        """Build the executor from a JsonToolkit and an LLM.

        Uses the stock JSON agent prompt (JSON_PREFIX/JSON_SUFFIX) with a
        ZeroShotAgent over the toolkit's tools; verbose output is enabled.
        """
        tools = toolkit.get_tools()
        tool_names = [tool.name for tool in tools]
        prompt = ZeroShotAgent.create_prompt(
            tools,
            prefix=JSON_PREFIX,
            suffix=JSON_SUFFIX,
            format_instructions=FORMAT_INSTRUCTIONS,
            input_variables=None,
        )
        llm_chain = LLMChain(
            llm=llm,
            prompt=prompt,
        )
        agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names)
        return cls.from_agent_and_tools(agent=agent, tools=tools, verbose=True)

    # NOTE: the original no-op __init__ and run overrides (pure super()
    # pass-throughs) were removed; AgentExecutor's own methods are used.
# Agents shipped with a fixed persona/prompt, keyed by their frontend name.
PREBUILT_AGENTS = {
    "MalfoyAgent": MalfoyAgent,
}

View file

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel
@ -20,7 +20,7 @@ class LangChainTypeCreator(BaseModel, ABC):
return self.type_dict
@abstractmethod
def get_signature(self, name: str) -> Optional[Dict[Any, Any]]:
def get_signature(self, name: str) -> Union[Optional[Dict[Any, Any]], FrontendNode]:
pass
@abstractmethod
@ -42,6 +42,8 @@ class LangChainTypeCreator(BaseModel, ABC):
signature = self.get_signature(name)
if signature is None:
raise ValueError(f"{name} not found")
if isinstance(signature, FrontendNode):
return signature
fields = [
TemplateField(
name=key,

View file

@ -6,15 +6,8 @@ from langchain.agents import agent_toolkits
from langchain.chat_models import ChatOpenAI
## Memory
# from langchain.memory.buffer_window import ConversationBufferWindowMemory
# from langchain.memory.chat_memory import ChatMessageHistory
# from langchain.memory.combined import CombinedMemory
# from langchain.memory.entity import ConversationEntityMemory
# from langchain.memory.kg import ConversationKGMemory
# from langchain.memory.readonly import ReadOnlySharedMemory
# from langchain.memory.simple import SimpleMemory
# from langchain.memory.summary import ConversationSummaryMemory
# from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
from langchain import memory
## Document Loaders
from langchain.document_loaders import (
AirbyteJSONLoader,
@ -104,23 +97,6 @@ llm_type_to_cls_dict = llms.type_to_cls_dict
llm_type_to_cls_dict["openai-chat"] = ChatOpenAI # type: ignore
## Memory
memory_type_to_cls_dict: dict[str, Any] = {
# "CombinedMemory": CombinedMemory,
# "ConversationBufferWindowMemory": ConversationBufferWindowMemory,
# "ConversationBufferMemory": ConversationBufferMemory,
# "SimpleMemory": SimpleMemory,
# "ConversationSummaryBufferMemory": ConversationSummaryBufferMemory,
# "ConversationKGMemory": ConversationKGMemory,
# "ConversationEntityMemory": ConversationEntityMemory,
# "ConversationSummaryMemory": ConversationSummaryMemory,
# "ChatMessageHistory": ChatMessageHistory,
# "ConversationStringBufferMemory": ConversationStringBufferMemory,
# "ReadOnlySharedMemory": ReadOnlySharedMemory,
}
## Chain
# from langchain.chains.loading import type_to_loader_dict
# from langchain.chains.conversation.base import ConversationChain
@ -142,6 +118,14 @@ toolkit_type_to_cls_dict: dict[str, Any] = {
if not toolkit_name.islower()
}
## Memory
# Build the memory registry dynamically from langchain.memory.__all__,
# importing each class by its dotted path (replaces the previous
# hand-maintained, mostly commented-out mapping).
memory_type_to_cls_dict: dict[str, Any] = {
    memory_name: import_class(f"langchain.memory.{memory_name}")
    for memory_name in memory.__all__
}
wrapper_type_to_cls_dict: dict[str, Any] = {
wrapper.__name__: wrapper for wrapper in [requests.RequestsWrapper]

View file

@ -8,7 +8,7 @@ from langchain.agents import Agent
from langchain.chains.base import Chain
from langchain.llms.base import BaseLLM
from langchain.tools import BaseTool
from langchain.chat_models.base import BaseChatModel
from langflow.interface.tools.util import get_tool_by_name
@ -31,13 +31,30 @@ def import_by_type(_type: str, name: str) -> Any:
func_dict = {
"agents": import_agent,
"prompts": import_prompt,
"llms": import_llm,
"llms": {"llm": import_llm, "chat": import_chat_llm},
"tools": import_tool,
"chains": import_chain,
"toolkits": import_toolkit,
"wrappers": import_wrapper,
"memory": import_memory,
}
return func_dict[_type](name)
if _type == "llms":
key = "chat" if "chat" in name.lower() else "llm"
loaded_func = func_dict[_type][key] # type: ignore
else:
loaded_func = func_dict[_type]
return loaded_func(name)
def import_chat_llm(llm: str) -> BaseChatModel:
    """Resolve a chat-model class in ``langchain.chat_models`` by its name."""
    class_path = "langchain.chat_models." + llm
    return import_class(class_path)
def import_memory(memory: str) -> Any:
    """Import a memory class from its name.

    NOTE(review): the argument handed to ``import_module`` is a full import
    statement ("from langchain.memory import X"), not a dotted module path —
    this only works if ``import_module`` here is the project helper that
    executes such statements, not ``importlib.import_module``. TODO confirm.
    """
    return import_module(f"from langchain.memory import {memory}")
def import_class(class_path: str) -> Any:

View file

@ -0,0 +1,65 @@
from typing import List, Optional
from langchain.prompts import PromptTemplate
from langflow.graph.utils import extract_input_variables_from_prompt
from langflow.template.base import Template, TemplateField
from langflow.template.nodes import PromptTemplateNode
from pydantic import root_validator
CHARACTER_PROMPT = """I want you to act like {character} from {series}.
I want you to respond and answer like {character}. do not write any explanations. only answer like {character}.
You must know all of the knowledge of {character}."""
class BaseCustomPrompt(PromptTemplate):
    """PromptTemplate whose template is pre-rendered from the model's fields.

    Subclasses declare ``template``, ``input_variables`` and matching field
    values (e.g. ``character``); the validator substitutes those values into
    the template, appends ``human_text`` and recomputes the remaining input
    variables (typically just ``input``).
    """

    template: str = ""
    # Human-readable description shown by the frontend node.
    description: Optional[str]
    # Suffix appended to the rendered template; carries the live {input} slot.
    human_text: str = "\n {input}"

    @root_validator(pre=False)
    def build_template(cls, values):
        """Render the declared input variables into the template.

        Runs after field validation: every name listed in input_variables is
        looked up as a field value and formatted into the template, then the
        variable list is rebuilt from whatever placeholders remain.
        """
        format_dict = {}
        for key in values.get("input_variables", []):
            # Each declared variable must exist as a field value (KeyError
            # otherwise — presumably guaranteed by the subclass definition).
            new_value = values[key]
            format_dict[key] = new_value

        values["template"] = values["template"].format(**format_dict)
        values["template"] = values["template"] + values["human_text"]
        # Re-derive input variables from the rendered template so only the
        # unfilled placeholders (e.g. {input}) remain.
        values["input_variables"] = extract_input_variables_from_prompt(
            values["template"]
        )
        return values

    def build_frontend_node(self) -> PromptTemplateNode:
        """Build the frontend node exposing the remaining variables as fields."""
        return PromptTemplateNode(
            template=Template(
                # NOTE(review): "test" looks like a placeholder type name —
                # confirm intended value.
                type_name="test",
                fields=[
                    TemplateField(name=field, field_type="str", required=True)
                    for field in self.input_variables
                ],
            ),
            description=self.description or "",
        )
class SeriesCharacterPrompt(BaseCustomPrompt):
    """Prompt that makes the AI role-play a given character from a series."""

    # Very descriptive on purpose: surfaced to the prompt generator/frontend.
    description: Optional[
        str
    ] = "A prompt that asks the AI to act like a character from a series."
    # Filled into CHARACTER_PROMPT by BaseCustomPrompt.build_template.
    character: str
    series: str
    human_text: str = "\n {input}"
    template: str = CHARACTER_PROMPT
    input_variables: List[str] = ["character", "series"]
if __name__ == "__main__":
    # Manual smoke test: build a character prompt and render it with a
    # sample user input.
    prompt = SeriesCharacterPrompt(character="Walter White", series="Breaking Bad")
    user_input = "I am the one who knocks"
    full_prompt = prompt.format(input=user_input)
    print(full_prompt)

View file

@ -57,7 +57,14 @@ def get_result_and_thought_using_graph(loaded_langchain, message: str):
loaded_langchain.verbose = True
try:
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
result = loaded_langchain(message)
chat_input = {}
for key in loaded_langchain.input_keys:
if key != "chat_history":
chat_input[key] = message
break
if hasattr(loaded_langchain, "run"):
loaded_langchain = loaded_langchain.run
result = loaded_langchain
result = (
result.get(loaded_langchain.output_keys[0])

View file

@ -1,8 +1,9 @@
from abc import ABC
from typing import Any, Dict, Optional, Union
from typing import Any, Callable, Dict, Optional, Union
from pydantic import BaseModel
from langflow.template.constants import FORCE_SHOW_FIELDS
from langflow.utils import constants
@ -20,8 +21,6 @@ class TemplateFieldCreator(BaseModel, ABC):
content: Union[str, None] = None
password: bool = False
options: list[str] = []
# _name will be used to store the name of the field
# in the template
name: str = ""
def to_dict(self):
@ -53,49 +52,37 @@ class TemplateFieldCreator(BaseModel, ABC):
if "List" in _type:
_type = _type.replace("List[", "")[:-1]
self.is_list = True
else:
self.is_list = False
# Replace 'Mapping' with 'dict'
if "Mapping" in _type:
_type = _type.replace("Mapping", "dict")
# Change type from str to Tool
self.field_type = "Tool" if key in ["allowed_tools"] else _type
self.field_type = "Tool" if key in {"allowed_tools"} else self.field_type
self.field_type = "int" if key in ["max_value_length"] else self.field_type
self.field_type = "int" if key in {"max_value_length"} else self.field_type
# Show or not field
self.show = bool(
(self.required and key not in ["input_variables"])
or key
in [
"allowed_tools",
"memory",
"prefix",
"examples",
"temperature",
"model_name",
"headers",
"max_value_length",
]
or key in FORCE_SHOW_FIELDS
or "api_key" in key
)
# Add password field
self.password = any(
text in key.lower() for text in ["password", "token", "api", "key"]
text in key.lower() for text in {"password", "token", "api", "key"}
)
# Add multline
self.multiline = key in [
self.multiline = key in {
"suffix",
"prefix",
"template",
"examples",
"code",
"headers",
]
}
# Replace dict type with str
if "dict" in self.field_type.lower():
@ -118,7 +105,7 @@ class TemplateFieldCreator(BaseModel, ABC):
if name == "OpenAI" and key == "model_name":
self.options = constants.OPENAI_MODELS
self.is_list = True
elif name == "OpenAIChat" and key == "model_name":
elif name == "ChatOpenAI" and key == "model_name":
self.options = constants.CHAT_OPENAI_MODELS
self.is_list = True
@ -131,13 +118,17 @@ class Template(BaseModel):
type_name: str
fields: list[TemplateField]
def process_fields(self, name: Optional[str] = None) -> None:
for field in self.fields:
signature = field.to_dict()
field.process_field(field.name, signature, name)
def process_fields(
self,
name: Optional[str] = None,
format_field_func: Union[Callable, None] = None,
):
if format_field_func:
for field in self.fields:
format_field_func(field, name)
def to_dict(self):
self.process_fields(self.type_name)
def to_dict(self, format_field_func=None):
self.process_fields(self.type_name, format_field_func)
result = {field.name: field.to_dict() for field in self.fields}
result["_type"] = self.type_name # type: ignore
return result
@ -152,8 +143,79 @@ class FrontendNode(BaseModel):
def to_dict(self):
return {
self.name: {
"template": self.template.to_dict(),
"template": self.template.to_dict(self.format_field),
"description": self.description,
"base_classes": self.base_classes,
}
}
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
    """Normalize a TemplateField in place before serialization for the frontend.

    Mirrors utils.format_dict: strips Optional/List wrappers from the type,
    rewrites special field types, decides visibility, flags password and
    multiline fields, and injects default values and model-name options.

    :param field: the field to mutate.
    :param name: owning class name, used for model-specific option lists.
    """
    key = field.name
    value = field.to_dict()
    _type = value["type"]

    # Remove 'Optional' wrapper
    if "Optional" in _type:
        _type = _type.replace("Optional[", "")[:-1]

    # Check for list type
    if "List" in _type:
        _type = _type.replace("List[", "")[:-1]
        field.is_list = True

    # Replace 'Mapping' with 'dict'
    if "Mapping" in _type:
        _type = _type.replace("Mapping", "dict")

    # NOTE(review): the stripped _type computed above is never written back
    # to field.field_type — presumably field_type was set earlier; confirm.
    # Change type from str to Tool
    field.field_type = "Tool" if key in {"allowed_tools"} else field.field_type
    field.field_type = "int" if key in {"max_value_length"} else field.field_type

    # Show the field when it is required (except input_variables),
    # force-shown, or an API key.
    field.show = bool(
        (field.required and key not in ["input_variables"])
        or key in FORCE_SHOW_FIELDS
        or "api_key" in key
    )

    # Mask fields that look like credentials.
    field.password = any(
        text in key.lower() for text in {"password", "token", "api", "key"}
    )

    # Use a multiline editor for prompt-like / code-like fields.
    field.multiline = key in {
        "suffix",
        "prefix",
        "template",
        "examples",
        "code",
        "headers",
    }

    # Replace dict type with str
    if "dict" in field.field_type.lower():
        field.field_type = "code"

    # Dict-valued fields are uploaded as JSON/YAML files instead.
    if key == "dict_":
        field.field_type = "file"
        field.suffixes = [".json", ".yaml", ".yml"]
        field.file_types = ["json", "yaml", "yml"]

    # Replace default value with actual value
    if "default" in value:
        field.value = value["default"]

    # Pre-fill headers with an example bearer-token payload.
    if key == "headers":
        field.value = """{'Authorization':
            'Bearer <token>'}"""

    # Add options to openai
    if name == "OpenAI" and key == "model_name":
        field.options = constants.OPENAI_MODELS
        field.is_list = True
    elif name == "ChatOpenAI" and key == "model_name":
        field.options = constants.CHAT_OPENAI_MODELS
        field.is_list = True

View file

@ -0,0 +1,11 @@
# Field names the frontend must always display, even when they are optional
# (membership-tested via `key in FORCE_SHOW_FIELDS`).
FORCE_SHOW_FIELDS = (
    "allowed_tools memory prefix examples temperature "
    "model_name headers max_value_length max_tokens"
).split()

View file

@ -1,7 +1,9 @@
from typing import Optional
from langchain.agents.mrkl import prompt
from langflow.template.base import FrontendNode, Template, TemplateField
from langflow.utils.constants import DEFAULT_PYTHON_FUNCTION
from langchain.agents import loading
class ZeroShotPromptNode(FrontendNode):
@ -48,6 +50,16 @@ class ZeroShotPromptNode(FrontendNode):
return super().to_dict()
class PromptTemplateNode(FrontendNode):
    """Frontend node wrapping a generic langchain PromptTemplate."""

    name: str = "PromptTemplate"
    # Template and description are supplied at construction time
    # (e.g. by BaseCustomPrompt.build_frontend_node).
    template: Template
    description: str
    base_classes: list[str] = ["BasePromptTemplate"]

    def to_dict(self):
        # Delegates to FrontendNode.to_dict; kept for symmetry with siblings.
        return super().to_dict()
class PythonFunctionNode(FrontendNode):
name: str = "PythonFunction"
template: Template = Template(
@ -141,6 +153,53 @@ class JsonAgentNode(FrontendNode):
return super().to_dict()
class InitializeAgentNode(FrontendNode):
    """Frontend node for langchain's ``initialize_agent`` function.

    Exposes the agent type (one of loading.AGENT_TO_CLASS), an optional
    memory, a list of tools and the LLM as template fields.
    """

    name: str = "initialize_agent"
    template: Template = Template(
        # Fixed typo: was "initailize_agent", now matches ``name``.
        type_name="initialize_agent",
        fields=[
            TemplateField(
                field_type="str",
                required=True,
                is_list=True,
                show=True,
                multiline=False,
                # Valid agent-type keys accepted by initialize_agent.
                options=list(loading.AGENT_TO_CLASS.keys()),
                name="agent",
            ),
            TemplateField(
                field_type="BaseChatMemory",
                required=False,
                show=True,
                name="memory",
            ),
            TemplateField(
                field_type="Tool",
                required=False,
                show=True,
                name="tools",
                is_list=True,
            ),
            TemplateField(
                field_type="BaseLanguageModel",
                required=True,
                show=True,
                name="llm",
            ),
        ],
    )
    # Fixed: previous description was copy-pasted from JsonAgentNode
    # ("Construct a json agent from an LLM and tools.").
    description: str = """Construct an agent from an LLM, tools, an agent type and optional memory."""
    base_classes: list[str] = ["AgentExecutor"]

    def to_dict(self):
        return super().to_dict()

    @staticmethod
    def format_field(field: TemplateField, name: Optional[str] = None) -> None:
        # Intentionally a no-op: the fields above are fully specified and must
        # not be rewritten by the generic formatting pass.
        pass
class CSVAgentNode(FrontendNode):
name: str = "CSVAgent"
template: Template = Template(

View file

@ -3,6 +3,7 @@ import inspect
import re
from typing import Dict, Optional
from langflow.template.constants import FORCE_SHOW_FIELDS
from langflow.utils import constants
@ -284,17 +285,7 @@ def format_dict(d, name: Optional[str] = None):
# Show or not field
value["show"] = bool(
(value["required"] and key not in ["input_variables"])
or key
in [
"allowed_tools",
"memory",
"prefix",
"examples",
"temperature",
"model_name",
"headers",
"max_value_length",
]
or key in FORCE_SHOW_FIELDS
or "api_key" in key
)
@ -336,7 +327,7 @@ def format_dict(d, name: Optional[str] = None):
if name == "OpenAI" and key == "model_name":
value["options"] = constants.OPENAI_MODELS
value["list"] = True
elif name == "OpenAIChat" and key == "model_name":
elif name == "ChatOpenAI" and key == "model_name":
value["options"] = constants.CHAT_OPENAI_MODELS
value["list"] = True

View file

@ -5,7 +5,7 @@ import { DropDownComponentType } from "../../types/components";
import { classNames } from "../../utils";
export default function Dropdown({value, options, onSelect}:DropDownComponentType) {
let [internalValue,setInternalValue] = useState(value??"choose an option")
let [internalValue,setInternalValue] = useState(value??"Choose an option")
return (
<>
<Listbox value={internalValue} onChange={(value)=>{

View file

@ -267,7 +267,7 @@
"y": 514.9920887988924
},
"data": {
"type": "OpenAIChat",
"type": "ChatOpenAI",
"node": {
"template": {
"cache": {
@ -365,7 +365,7 @@
"type": "bool",
"list": false
},
"_type": "OpenAIChat"
"_type": "ChatOpenAI"
},
"description": "Wrapper around OpenAI Chat large language models.To use, you should have the ``openai`` python package installed, and theenvironment variable ``OPENAI_API_KEY`` set with your API key.Any parameters that are valid to be passed to the openai.create call can be passedin, even if not explicitly saved on this class.",
"base_classes": [
@ -423,7 +423,7 @@
},
{
"source": "dndnode_36",
"sourceHandle": "OpenAIChat|dndnode_36|BaseLanguageModel|BaseLLM",
"sourceHandle": "ChatOpenAI|dndnode_36|BaseLanguageModel|BaseLLM",
"target": "dndnode_33",
"targetHandle": "BaseLanguageModel|llm|dndnode_33",
"className": "animate-pulse",

View file

@ -340,43 +340,21 @@ def test_build_params(basic_graph):
assert isinstance(llm_node.params["model_name"], str)
def test_build(basic_graph, complex_graph):
def test_build(basic_graph, complex_graph, openapi_graph):
"""Test Node's build method"""
# def build(self):
# # The params dict is used to build the module
# # it contains values and keys that point to nodes which
# # have their own params dict
# # When build is called, we iterate through the params dict
# # and if the value is a node, we call build on that node
# # and use the output of that build as the value for the param
# # if the value is not a node, then we use the value as the param
# # and continue
# # Another aspect is that the node_type is the class that we need to import
# # and instantiate with these built params
assert_agent_was_built(basic_graph)
assert_agent_was_built(complex_graph)
assert_agent_was_built(openapi_graph)
# # Build each node in the params dict
# for key, value in self.params.items():
# if isinstance(value, Node):
# self.params[key] = value.build()
# # Get the class from LANGCHAIN_TYPES_DICT
# # and instantiate it with the params
# # and return the instance
# return LANGCHAIN_TYPES_DICT[self.node_type](**self.params)
assert isinstance(basic_graph, Graph)
def assert_agent_was_built(graph):
"""Assert that the agent was built"""
assert isinstance(graph, Graph)
# Now we test the build method
# Build the Agent
agent = basic_graph.build()
result = graph.build()
# The agent should be a AgentExecutor
assert isinstance(agent, AgentExecutor)
# Now we test the complex example
assert isinstance(complex_graph, Graph)
# Now we test the build method
agent = complex_graph.build()
# The agent should be a AgentExecutor
assert isinstance(agent, AgentExecutor)
assert isinstance(result, AgentExecutor)
def test_agent_node_build(basic_graph):
@ -384,7 +362,6 @@ def test_agent_node_build(basic_graph):
assert agent_node is not None
built_object = agent_node.build()
assert built_object is not None
# Add any further assertions specific to the AgentNode's build() method
def test_tool_node_build(basic_graph):

View file

@ -210,7 +210,7 @@ def test_format_dict():
}
assert format_dict(input_dict) == expected_output
# Test 7: Check class name-specific cases (OpenAI, OpenAIChat)
# Test 7: Check class name-specific cases (OpenAI, ChatOpenAI)
input_dict = {
"model_name": {"type": "str", "required": False},
}
@ -237,7 +237,7 @@ def test_format_dict():
},
}
assert format_dict(input_dict, "OpenAI") == expected_output_openai
assert format_dict(input_dict, "OpenAIChat") == expected_output_openai_chat
assert format_dict(input_dict, "ChatOpenAI") == expected_output_openai_chat
# Test 8: Replace dict type with str
input_dict = {