fix: set up creators

This commit is contained in:
Gabriel Almeida 2023-03-30 16:35:08 -03:00
commit cfbd22d2b0
10 changed files with 43 additions and 29 deletions

View file

@ -5,9 +5,8 @@
from copy import deepcopy
import types
from typing import Any, Dict, List, Optional, Union
from langflow.utils import payload
from langflow.interface.listing import ALL_TYPES_DICT, ALL_TOOLS_NAMES, TOOLS_DICT
from typing import Any, Dict, List
from langflow.interface.listing import ALL_TYPES_DICT, TOOLS_DICT
from langflow.interface import loading

View file

@ -26,3 +26,6 @@ class AgentCreator(LangChainTypeCreator):
for agent in self.type_to_loader_dict.values()
if agent.__name__ in settings.agents or settings.dev
]
agent_creator = AgentCreator()

View file

@ -1,4 +1,4 @@
from typing import Dict, List
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from abc import ABC, abstractmethod
from langflow.template.template import Template, Field, FrontendNode
@ -16,23 +16,25 @@ class LangChainTypeCreator(BaseModel, ABC):
pass
@abstractmethod
def get_signature(self, name: str) -> Dict:
def get_signature(self, name: str) -> Optional[Dict[Any, Any]]:
pass
@abstractmethod
def to_list(self) -> List[str]:
pass
def to_dict(self):
result = {self.type_name: {}} # type: Dict
def to_dict(self) -> Dict:
result: Dict = {self.type_name: {}}
for name in self.to_list():
result[self.type_name][name] = self.get_signature(name)
result[self.type_name][name] = self.frontend_node(name).to_dict()
return result
def frontend_node(self, name) -> FrontendNode:
signature = self.get_signature(name)
if signature is None:
raise ValueError(f"{name} not found")
fields = [
Field(
name=key,

View file

@ -1,7 +1,5 @@
from typing import Dict, List
from langflow.interface.base import LangChainTypeCreator
from langflow.interface.signature import get_chain_signature
from langflow.template.template import Field, FrontendNode, Template
from langflow.utils.util import build_template_from_function
from langflow.settings import settings
from langchain.chains import loading as chains_loading
@ -33,3 +31,6 @@ class ChainCreator(LangChainTypeCreator):
or settings.dev
)
]
chain_creator = ChainCreator()

View file

@ -1,5 +1,5 @@
from typing import Callable, Optional
from langchain import LLMChain, PromptTemplate
from langchain import LLMChain
from langchain.agents import AgentExecutor, ZeroShotAgent
from langflow.utils import validate
from pydantic import BaseModel, validator
@ -64,6 +64,12 @@ class JsonAgent(BaseModel):
prompt=prompt,
)
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names)
return AgentExecutor.from_agent_and_tools(
self.agent_executor = AgentExecutor.from_agent_and_tools(
agent=agent, tools=tools, verbose=True
)
def __call__(self, *args, **kwargs):
return self.agent_executor(*args, **kwargs)
def run(self, *args, **kwargs):
return self.agent_executor.run(*args, **kwargs)

View file

@ -25,3 +25,6 @@ class LLMCreator(LangChainTypeCreator):
for llm in self.type_to_loader_dict.values()
if llm.__name__ in settings.llms or settings.dev
]
llm_creator = LLMCreator()

View file

@ -25,3 +25,6 @@ class MemoryCreator(LangChainTypeCreator):
for memory in self.type_to_loader_dict.values()
if memory.__name__ in settings.memories or settings.dev
]
memory_creator = MemoryCreator()

View file

@ -30,3 +30,6 @@ class PromptCreator(LangChainTypeCreator):
or settings.dev
]
return library_prompts + list(custom_prompts.keys())
prompt_creator = PromptCreator()

View file

@ -18,7 +18,7 @@ class ToolCreator(LangChainTypeCreator):
@property
def type_to_loader_dict(self) -> Dict:
return ALL_TOOLS_NAMES
return util.get_tools_dict()
def get_signature(self, name: str) -> Dict | None:
"""Get the signature of a tool."""
@ -26,7 +26,7 @@ class ToolCreator(LangChainTypeCreator):
NODE_INPUTS = ["llm", "func"]
base_classes = ["Tool"]
all_tools = {}
for tool in ALL_TOOLS_NAMES:
for tool in self.type_to_loader_dict.keys():
if tool_params := util.get_tool_params(util.get_tool_by_name(tool)):
tool_name = tool_params.get("name") or str(tool)
all_tools[tool_name] = {"type": tool, "params": tool_params}
@ -126,3 +126,6 @@ class ToolCreator(LangChainTypeCreator):
# Add Tool
custom_tools = customs.get_custom_nodes("tools")
return tools + list(custom_tools.keys())
tool_creator = ToolCreator()

View file

@ -1,12 +1,9 @@
from langflow.interface.agents import AgentCreator
from langflow.interface.listing import list_type
from langflow.interface.llms import LLMCreator
from langflow.interface.memories.base import MemoryCreator
from langflow.interface.prompts import PromptCreator
from langflow.interface.signature import get_signature
from langchain import chains
from langflow.interface.chains import ChainCreator
from langflow.interface.tools import ToolCreator
from langflow.interface.agents.base import agent_creator
from langflow.interface.llms.base import llm_creator
from langflow.interface.memories.base import memory_creator
from langflow.interface.prompts.base import prompt_creator
from langflow.interface.chains.base import chain_creator
from langflow.interface.tools.base import tool_creator
def get_type_list():
@ -23,12 +20,6 @@ def get_type_list():
def build_langchain_types_dict():
"""Build a dictionary of all langchain types"""
chain_creator = ChainCreator()
agent_creator = AgentCreator()
prompt_creator = PromptCreator()
tool_creator = ToolCreator()
llm_creator = LLMCreator()
memory_creator = MemoryCreator()
all_types = {}