Merge remote-tracking branch 'origin/main' into dev

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-06-16 19:28:47 -03:00
commit 3a324ed45a
6 changed files with 536 additions and 18 deletions

View file

@ -3,7 +3,7 @@ agents:
- ZeroShotAgent
- JsonAgent
- CSVAgent
- initialize_agent
- AgentInitializer
- VectorStoreAgent
- VectorStoreRouterAgent
- SQLAgent

View file

@ -13,7 +13,7 @@ CUSTOM_NODES = {
"agents": {
"JsonAgent": frontend_node.agents.JsonAgentNode(),
"CSVAgent": frontend_node.agents.CSVAgentNode(),
"initialize_agent": frontend_node.agents.InitializeAgentNode(),
"AgentInitializer": frontend_node.agents.InitializeAgentNode(),
"VectorStoreAgent": frontend_node.agents.VectorStoreAgentNode(),
"VectorStoreRouterAgent": frontend_node.agents.VectorStoreRouterAgentNode(),
"SQLAgent": frontend_node.agents.SQLAgentNode(),

View file

@ -64,7 +64,9 @@ class JsonAgent(CustomAgentExecutor):
llm=llm,
prompt=prompt,
)
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names) # type: ignore
agent = ZeroShotAgent(
llm_chain=llm_chain, allowed_tools=tool_names # type: ignore
)
return cls.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
def run(self, *args, **kwargs):
@ -111,7 +113,9 @@ class CSVAgent(CustomAgentExecutor):
prompt=partial_prompt,
)
tool_names = {tool.name for tool in tools}
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) # type: ignore
agent = ZeroShotAgent(
llm_chain=llm_chain, allowed_tools=tool_names, **kwargs # type: ignore
)
return cls.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
@ -148,7 +152,9 @@ class VectorStoreAgent(CustomAgentExecutor):
prompt=prompt,
)
tool_names = {tool.name for tool in tools}
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) # type: ignore
agent = ZeroShotAgent(
llm_chain=llm_chain, allowed_tools=tool_names, **kwargs # type: ignore
)
return AgentExecutor.from_agent_and_tools(
agent=agent, tools=tools, verbose=True
)
@ -216,7 +222,9 @@ class SQLAgent(CustomAgentExecutor):
prompt=prompt,
)
tool_names = {tool.name for tool in tools} # type: ignore
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) # type: ignore
agent = ZeroShotAgent(
llm_chain=llm_chain, allowed_tools=tool_names, **kwargs # type: ignore
)
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools, # type: ignore
@ -263,7 +271,9 @@ class VectorStoreRouterAgent(CustomAgentExecutor):
prompt=prompt,
)
tool_names = {tool.name for tool in tools}
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) # type: ignore
agent = ZeroShotAgent(
llm_chain=llm_chain, allowed_tools=tool_names, **kwargs # type: ignore
)
return AgentExecutor.from_agent_and_tools(
agent=agent, tools=tools, verbose=True
)
@ -273,11 +283,11 @@ class VectorStoreRouterAgent(CustomAgentExecutor):
class InitializeAgent(CustomAgentExecutor):
"""Implementation of initialize_agent function"""
"""Implementation of AgentInitializer function"""
@staticmethod
def function_name():
return "initialize_agent"
return "AgentInitializer"
@classmethod
def initialize(

View file

@ -97,7 +97,7 @@ class TimeTravelGuideChain(BaseCustomConversationChain):
class CombineDocsChain(CustomChain):
"""Implementation of initialize_agent function"""
"""Implementation of load_qa_chain function"""
@staticmethod
def function_name():

View file

@ -17,14 +17,14 @@ from langflow.interface.importing.utils import import_class
from langflow.interface.agents.custom import CUSTOM_AGENTS
from langflow.interface.chains.custom import CUSTOM_CHAINS
## LLMs
# LLMs
llm_type_to_cls_dict = llms.type_to_cls_dict
llm_type_to_cls_dict["anthropic-chat"] = ChatAnthropic # type: ignore
llm_type_to_cls_dict["azure-chat"] = AzureChatOpenAI # type: ignore
llm_type_to_cls_dict["openai-chat"] = ChatOpenAI # type: ignore
## Toolkits
# Toolkits
toolkit_type_to_loader_dict: dict[str, Any] = {
toolkit_name: import_class(f"langchain.agents.agent_toolkits.{toolkit_name}")
# if toolkit_name is lower case it is a loader
@ -39,25 +39,25 @@ toolkit_type_to_cls_dict: dict[str, Any] = {
if not toolkit_name.islower()
}
## Memories
# Memories
memory_type_to_cls_dict: dict[str, Any] = {
memory_name: import_class(f"langchain.memory.{memory_name}")
for memory_name in memory.__all__
}
## Wrappers
# Wrappers
wrapper_type_to_cls_dict: dict[str, Any] = {
wrapper.__name__: wrapper for wrapper in [requests.RequestsWrapper]
}
## Embeddings
# Embeddings
embedding_type_to_cls_dict: dict[str, Any] = {
embedding_name: import_class(f"langchain.embeddings.{embedding_name}")
for embedding_name in embeddings.__all__
}
## Document Loaders
# Document Loaders
documentloaders_type_to_cls_dict: dict[str, Any] = {
documentloader_name: import_class(
f"langchain.document_loaders.{documentloader_name}"
@ -65,7 +65,7 @@ documentloaders_type_to_cls_dict: dict[str, Any] = {
for documentloader_name in document_loaders.__all__
}
## Text Splitters
# Text Splitters
textsplitter_type_to_cls_dict: dict[str, Any] = dict(
inspect.getmembers(text_splitter, inspect.isclass)
)

File diff suppressed because one or more lines are too long