Merge remote-tracking branch 'origin/dev' into db
This commit is contained in:
commit
77ee6ecf59
26 changed files with 947 additions and 151 deletions
|
|
@ -75,6 +75,7 @@ toolkits:
|
|||
- JsonToolkit
|
||||
- VectorStoreInfo
|
||||
- VectorStoreRouterToolkit
|
||||
- VectorStoreToolkit
|
||||
tools:
|
||||
- Search
|
||||
- PAL-MATH
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import re
|
||||
from typing import Any, Union
|
||||
|
||||
|
||||
def validate_prompt(prompt: str):
|
||||
|
|
@ -17,3 +18,14 @@ def fix_prompt(prompt: str):
|
|||
def extract_input_variables_from_prompt(prompt: str) -> list[str]:
|
||||
"""Extract input variables from prompt."""
|
||||
return re.findall(r"{(.*?)}", prompt)
|
||||
|
||||
|
||||
def flatten_list(list_of_lists: list[Union[list, Any]]) -> list:
|
||||
"""Flatten list of lists."""
|
||||
new_list = []
|
||||
for item in list_of_lists:
|
||||
if isinstance(item, list):
|
||||
new_list.extend(item)
|
||||
else:
|
||||
new_list.append(item)
|
||||
return new_list
|
||||
|
|
|
|||
|
|
@ -174,6 +174,12 @@ class Vertex:
|
|||
# turn result which is a function into a coroutine
|
||||
# so that it can be awaited
|
||||
self.params["coroutine"] = sync_to_async(result)
|
||||
if isinstance(result, list):
|
||||
# If the result is a list, then we need to extend the list
|
||||
# with the result but first check if the key exists
|
||||
# if it doesn't, then we need to create a new list
|
||||
if isinstance(self.params[key], list):
|
||||
self.params[key].extend(result)
|
||||
|
||||
self.params[key] = result
|
||||
elif isinstance(value, list) and all(
|
||||
|
|
|
|||
|
|
@ -1,20 +1,20 @@
|
|||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.graph.utils import extract_input_variables_from_prompt
|
||||
from langflow.graph.utils import extract_input_variables_from_prompt, flatten_list
|
||||
|
||||
|
||||
class AgentVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="agents")
|
||||
|
||||
self.tools: List[ToolVertex] = []
|
||||
self.tools: List[Union[ToolkitVertex, ToolVertex]] = []
|
||||
self.chains: List[ChainVertex] = []
|
||||
|
||||
def _set_tools_and_chains(self) -> None:
|
||||
for edge in self.edges:
|
||||
source_node = edge.source
|
||||
if isinstance(source_node, ToolVertex):
|
||||
if isinstance(source_node, (ToolVertex, ToolkitVertex)):
|
||||
self.tools.append(source_node)
|
||||
elif isinstance(source_node, ChainVertex):
|
||||
self.chains.append(source_node)
|
||||
|
|
@ -32,13 +32,6 @@ class AgentVertex(Vertex):
|
|||
|
||||
self._build()
|
||||
|
||||
#! Cannot deepcopy VectorStore, VectorStoreRouter, or SQL agents
|
||||
if self.vertex_type in [
|
||||
"VectorStoreAgent",
|
||||
"VectorStoreRouterAgent",
|
||||
"SQLAgent",
|
||||
]:
|
||||
return self._built_object
|
||||
return self._built_object
|
||||
|
||||
|
||||
|
|
@ -47,70 +40,6 @@ class ToolVertex(Vertex):
|
|||
super().__init__(data, base_type="tools")
|
||||
|
||||
|
||||
class PromptVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="prompts")
|
||||
|
||||
def build(
|
||||
self,
|
||||
force: bool = False,
|
||||
tools: Optional[Union[List[Vertex], List[ToolVertex]]] = None,
|
||||
) -> Any:
|
||||
if not self._built or force:
|
||||
if (
|
||||
"input_variables" not in self.params
|
||||
or self.params["input_variables"] is None
|
||||
):
|
||||
self.params["input_variables"] = []
|
||||
# Check if it is a ZeroShotPrompt and needs a tool
|
||||
if "ShotPrompt" in self.vertex_type:
|
||||
tools = (
|
||||
[tool_node.build() for tool_node in tools]
|
||||
if tools is not None
|
||||
else []
|
||||
)
|
||||
self.params["tools"] = tools
|
||||
prompt_params = [
|
||||
key
|
||||
for key, value in self.params.items()
|
||||
if isinstance(value, str) and key != "format_instructions"
|
||||
]
|
||||
else:
|
||||
prompt_params = ["template"]
|
||||
for param in prompt_params:
|
||||
prompt_text = self.params[param]
|
||||
variables = extract_input_variables_from_prompt(prompt_text)
|
||||
self.params["input_variables"].extend(variables)
|
||||
self.params["input_variables"] = list(set(self.params["input_variables"]))
|
||||
|
||||
self._build()
|
||||
return self._built_object
|
||||
|
||||
|
||||
class ChainVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="chains")
|
||||
|
||||
def build(
|
||||
self,
|
||||
force: bool = False,
|
||||
tools: Optional[Union[List[Vertex], List[ToolVertex]]] = None,
|
||||
) -> Any:
|
||||
if not self._built or force:
|
||||
# Check if the chain requires a PromptNode
|
||||
for key, value in self.params.items():
|
||||
if isinstance(value, PromptVertex):
|
||||
# Build the PromptNode, passing the tools if available
|
||||
self.params[key] = value.build(tools=tools, force=force)
|
||||
|
||||
self._build()
|
||||
|
||||
#! Cannot deepcopy SQLDatabaseChain
|
||||
if self.vertex_type in ["SQLDatabaseChain"]:
|
||||
return self._built_object
|
||||
return self._built_object
|
||||
|
||||
|
||||
class LLMVertex(Vertex):
|
||||
built_node_type = None
|
||||
class_built_object = None
|
||||
|
|
@ -196,3 +125,68 @@ class TextSplitterVertex(Vertex):
|
|||
return f"""{self.vertex_type}({len(self._built_object)} documents)
|
||||
\nDocuments: {self._built_object[:3]}..."""
|
||||
return f"{self.vertex_type}()"
|
||||
|
||||
|
||||
class ChainVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="chains")
|
||||
|
||||
def build(
|
||||
self,
|
||||
force: bool = False,
|
||||
tools: Optional[List[Union[ToolkitVertex, ToolVertex]]] = None,
|
||||
) -> Any:
|
||||
if not self._built or force:
|
||||
# Check if the chain requires a PromptVertex
|
||||
for key, value in self.params.items():
|
||||
if isinstance(value, PromptVertex):
|
||||
# Build the PromptVertex, passing the tools if available
|
||||
self.params[key] = value.build(tools=tools, force=force)
|
||||
|
||||
self._build()
|
||||
|
||||
return self._built_object
|
||||
|
||||
|
||||
class PromptVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="prompts")
|
||||
|
||||
def build(
|
||||
self,
|
||||
force: bool = False,
|
||||
tools: Optional[List[Union[ToolkitVertex, ToolVertex]]] = None,
|
||||
) -> Any:
|
||||
if not self._built or force:
|
||||
if (
|
||||
"input_variables" not in self.params
|
||||
or self.params["input_variables"] is None
|
||||
):
|
||||
self.params["input_variables"] = []
|
||||
# Check if it is a ZeroShotPrompt and needs a tool
|
||||
if "ShotPrompt" in self.vertex_type:
|
||||
tools = (
|
||||
[tool_node.build() for tool_node in tools]
|
||||
if tools is not None
|
||||
else []
|
||||
)
|
||||
# flatten the list of tools if it is a list of lists
|
||||
# first check if it is a list
|
||||
if tools and isinstance(tools, list) and isinstance(tools[0], list):
|
||||
tools = flatten_list(tools)
|
||||
self.params["tools"] = tools
|
||||
prompt_params = [
|
||||
key
|
||||
for key, value in self.params.items()
|
||||
if isinstance(value, str) and key != "format_instructions"
|
||||
]
|
||||
else:
|
||||
prompt_params = ["template"]
|
||||
for param in prompt_params:
|
||||
prompt_text = self.params[param]
|
||||
variables = extract_input_variables_from_prompt(prompt_text)
|
||||
self.params["input_variables"].extend(variables)
|
||||
self.params["input_variables"] = list(set(self.params["input_variables"]))
|
||||
|
||||
self._build()
|
||||
return self._built_object
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ class JsonAgent(CustomAgentExecutor):
|
|||
|
||||
@classmethod
|
||||
def from_toolkit_and_llm(cls, toolkit: JsonToolkit, llm: BaseLanguageModel):
|
||||
tools = toolkit.get_tools()
|
||||
tools = toolkit if isinstance(toolkit, list) else toolkit.get_tools()
|
||||
tool_names = {tool.name for tool in tools}
|
||||
prompt = ZeroShotAgent.create_prompt(
|
||||
tools,
|
||||
|
|
|
|||
|
|
@ -110,8 +110,11 @@ def instantiate_tool(node_type, class_object, params):
|
|||
|
||||
def instantiate_toolkit(node_type, class_object, params):
|
||||
loaded_toolkit = class_object(**params)
|
||||
if toolkits_creator.has_create_function(node_type):
|
||||
return load_toolkits_executor(node_type, loaded_toolkit, params)
|
||||
# Commenting this out for now to use toolkits as normal tools
|
||||
# if toolkits_creator.has_create_function(node_type):
|
||||
# return load_toolkits_executor(node_type, loaded_toolkit, params)
|
||||
if isinstance(loaded_toolkit, BaseToolkit):
|
||||
return loaded_toolkit.get_tools()
|
||||
return loaded_toolkit
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -42,24 +42,27 @@ class ToolkitCreator(LangChainTypeCreator):
|
|||
|
||||
def get_signature(self, name: str) -> Optional[Dict]:
|
||||
try:
|
||||
return build_template_from_class(name, self.type_to_loader_dict)
|
||||
template = build_template_from_class(name, self.type_to_loader_dict)
|
||||
# add Tool to base_classes
|
||||
if "toolkit" in name.lower() and template:
|
||||
template["base_classes"].append("Tool")
|
||||
return template
|
||||
except ValueError as exc:
|
||||
raise ValueError("Prompt not found") from exc
|
||||
raise ValueError("Toolkit not found") from exc
|
||||
except AttributeError as exc:
|
||||
logger.error(f"Prompt {name} not loaded: {exc}")
|
||||
logger.error(f"Toolkit {name} not loaded: {exc}")
|
||||
return None
|
||||
|
||||
def to_list(self) -> List[str]:
|
||||
return list(self.type_to_loader_dict.keys())
|
||||
|
||||
def get_create_function(self, name: str) -> Callable:
|
||||
if loader_name := self.create_functions.get(name, None):
|
||||
# import loader
|
||||
if loader_name := self.create_functions.get(name):
|
||||
return import_module(
|
||||
f"from langchain.agents.agent_toolkits import {loader_name[0]}"
|
||||
)
|
||||
else:
|
||||
raise ValueError("Loader not found")
|
||||
raise ValueError("Toolkit not found")
|
||||
|
||||
def has_create_function(self, name: str) -> bool:
|
||||
# check if the function list is not empty
|
||||
|
|
|
|||
|
|
@ -146,7 +146,7 @@ class CSVAgentNode(FrontendNode):
|
|||
),
|
||||
],
|
||||
)
|
||||
description: str = """Construct a json agent from a CSV and tools."""
|
||||
description: str = """Construct a CSV agent from a CSV and tools."""
|
||||
base_classes: list[str] = ["AgentExecutor"]
|
||||
|
||||
def to_dict(self):
|
||||
|
|
@ -194,7 +194,7 @@ class InitializeAgentNode(FrontendNode):
|
|||
),
|
||||
],
|
||||
)
|
||||
description: str = """Construct a json agent from an LLM and tools."""
|
||||
description: str = """Construct a zero shot agent from an LLM and tools."""
|
||||
base_classes: list[str] = ["AgentExecutor", "function"]
|
||||
|
||||
def to_dict(self):
|
||||
|
|
|
|||
|
|
@ -117,17 +117,30 @@ class FrontendNode(BaseModel):
|
|||
) -> None:
|
||||
"""Handles specific field values for certain fields."""
|
||||
if key == "headers":
|
||||
field.value = """{'Authorization':
|
||||
'Bearer <token>'}"""
|
||||
if name == "OpenAI" and key == "model_name":
|
||||
field.options = constants.OPENAI_MODELS
|
||||
field.is_list = True
|
||||
elif name == "ChatOpenAI" and key == "model_name":
|
||||
field.options = constants.CHAT_OPENAI_MODELS
|
||||
field.is_list = True
|
||||
elif (name == "Anthropic" or name == "ChatAnthropic") and key == "model_name":
|
||||
field.options = constants.ANTHROPIC_MODELS
|
||||
field.value = """{'Authorization': 'Bearer <token>'}"""
|
||||
FrontendNode._handle_model_specific_field_values(field, key, name)
|
||||
FrontendNode._handle_api_key_specific_field_values(field, key, name)
|
||||
|
||||
@staticmethod
|
||||
def _handle_model_specific_field_values(
|
||||
field: TemplateField, key: str, name: Optional[str] = None
|
||||
) -> None:
|
||||
"""Handles specific field values related to models."""
|
||||
model_dict = {
|
||||
"OpenAI": constants.OPENAI_MODELS,
|
||||
"ChatOpenAI": constants.CHAT_OPENAI_MODELS,
|
||||
"Anthropic": constants.ANTHROPIC_MODELS,
|
||||
"ChatAnthropic": constants.ANTHROPIC_MODELS,
|
||||
}
|
||||
if name in model_dict and key == "model_name":
|
||||
field.options = model_dict[name]
|
||||
field.is_list = True
|
||||
|
||||
@staticmethod
|
||||
def _handle_api_key_specific_field_values(
|
||||
field: TemplateField, key: str, name: Optional[str] = None
|
||||
) -> None:
|
||||
"""Handles specific field values related to API keys."""
|
||||
if "api_key" in key and "OpenAI" in str(name):
|
||||
field.display_name = "OpenAI API Key"
|
||||
field.required = False
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue