fix: load agents correctly and add correct chat-model loading

This commit is contained in:
Gabriel Almeida 2023-04-02 10:41:40 -03:00
commit 58652f7c2b
9 changed files with 35 additions and 16 deletions

View file

@ -16,7 +16,7 @@ prompts:
llms:
- OpenAI
- OpenAIChat
- ChatOpenAI
tools:
- Search

View file

@ -179,7 +179,9 @@ class Node:
params=self.params,
)
except Exception as exc:
raise ValueError(f"Error building node {self.node_type}") from exc
raise ValueError(
f"Error building node {self.node_type}: {str(exc)}"
) from exc
if self._built_object is None:
raise ValueError(f"Node type {self.node_type} not found")

View file

@ -106,7 +106,10 @@ class Graph:
if node_type in prompt_creator.to_list():
nodes.append(PromptNode(node))
elif node_type in agent_creator.to_list():
elif (
node_type in agent_creator.to_list()
or node_lc_type in agent_creator.to_list()
):
nodes.append(AgentNode(node))
elif node_type in chain_creator.to_list():
nodes.append(ChainNode(node))
@ -118,7 +121,10 @@ class Graph:
nodes.append(ToolkitNode(node))
elif node_type in wrapper_creator.to_list():
nodes.append(WrapperNode(node))
elif node_type in llm_creator.to_list():
elif (
node_type in llm_creator.to_list()
or node_lc_type in llm_creator.to_list()
):
nodes.append(LLMNode(node))
else:
nodes.append(Node(node))

View file

@ -41,7 +41,7 @@ class AgentCreator(LangChainTypeCreator):
else agent.__name__
)
if agent_name in settings.agents or settings.dev:
names.append(name)
names.append(agent_name)
return names

View file

@ -8,7 +8,7 @@ from langchain.agents import Agent
from langchain.chains.base import Chain
from langchain.llms.base import BaseLLM
from langchain.tools import BaseTool
from langchain.chat_models.base import BaseChatModel
from langflow.interface.tools.util import get_tool_by_name
@ -31,14 +31,25 @@ def import_by_type(_type: str, name: str) -> Any:
func_dict = {
"agents": import_agent,
"prompts": import_prompt,
"llms": import_llm,
"llms": {"llm": import_llm, "chat": import_chat_llm},
"tools": import_tool,
"chains": import_chain,
"toolkits": import_toolkit,
"wrappers": import_wrapper,
"memory": import_memory,
}
return func_dict[_type](name)
if _type == "llms":
key = "chat" if "chat" in name.lower() else "llm"
loaded_func = func_dict[_type][key]
else:
loaded_func = func_dict[_type]
return loaded_func(name)
def import_chat_llm(llm: str) -> BaseChatModel:
    """Resolve a chat-model class from ``langchain.chat_models`` by its class name."""
    module_path = "langchain.chat_models." + llm
    return import_class(module_path)
def import_memory(memory: str) -> Any:

View file

@ -105,7 +105,7 @@ class TemplateFieldCreator(BaseModel, ABC):
if name == "OpenAI" and key == "model_name":
self.options = constants.OPENAI_MODELS
self.is_list = True
elif name == "OpenAIChat" and key == "model_name":
elif name == "ChatOpenAI" and key == "model_name":
self.options = constants.CHAT_OPENAI_MODELS
self.is_list = True
@ -216,6 +216,6 @@ class FrontendNode(BaseModel):
if name == "OpenAI" and key == "model_name":
field.options = constants.OPENAI_MODELS
field.is_list = True
elif name == "OpenAIChat" and key == "model_name":
elif name == "ChatOpenAI" and key == "model_name":
field.options = constants.CHAT_OPENAI_MODELS
field.is_list = True

View file

@ -327,7 +327,7 @@ def format_dict(d, name: Optional[str] = None):
if name == "OpenAI" and key == "model_name":
value["options"] = constants.OPENAI_MODELS
value["list"] = True
elif name == "OpenAIChat" and key == "model_name":
elif name == "ChatOpenAI" and key == "model_name":
value["options"] = constants.CHAT_OPENAI_MODELS
value["list"] = True

View file

@ -267,7 +267,7 @@
"y": 514.9920887988924
},
"data": {
"type": "OpenAIChat",
"type": "ChatOpenAI",
"node": {
"template": {
"cache": {
@ -365,7 +365,7 @@
"type": "bool",
"list": false
},
"_type": "OpenAIChat"
"_type": "ChatOpenAI"
},
"description": "Wrapper around OpenAI Chat large language models.To use, you should have the ``openai`` python package installed, and theenvironment variable ``OPENAI_API_KEY`` set with your API key.Any parameters that are valid to be passed to the openai.create call can be passedin, even if not explicitly saved on this class.",
"base_classes": [
@ -423,7 +423,7 @@
},
{
"source": "dndnode_36",
"sourceHandle": "OpenAIChat|dndnode_36|BaseLanguageModel|BaseLLM",
"sourceHandle": "ChatOpenAI|dndnode_36|BaseLanguageModel|BaseLLM",
"target": "dndnode_33",
"targetHandle": "BaseLanguageModel|llm|dndnode_33",
"className": "animate-pulse",

View file

@ -210,7 +210,7 @@ def test_format_dict():
}
assert format_dict(input_dict) == expected_output
# Test 7: Check class name-specific cases (OpenAI, OpenAIChat)
# Test 7: Check class name-specific cases (OpenAI, ChatOpenAI)
input_dict = {
"model_name": {"type": "str", "required": False},
}
@ -237,7 +237,7 @@ def test_format_dict():
},
}
assert format_dict(input_dict, "OpenAI") == expected_output_openai
assert format_dict(input_dict, "OpenAIChat") == expected_output_openai_chat
assert format_dict(input_dict, "ChatOpenAI") == expected_output_openai_chat
# Test 8: Replace dict type with str
input_dict = {