Merge remote-tracking branch 'origin/dev' into db

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-06-05 22:18:07 -03:00
commit 77ee6ecf59
26 changed files with 947 additions and 151 deletions

View file

@ -75,6 +75,7 @@ toolkits:
- JsonToolkit
- VectorStoreInfo
- VectorStoreRouterToolkit
- VectorStoreToolkit
tools:
- Search
- PAL-MATH

View file

@ -1,4 +1,5 @@
import re
from typing import Any, Union
def validate_prompt(prompt: str):
@ -17,3 +18,14 @@ def fix_prompt(prompt: str):
def extract_input_variables_from_prompt(prompt: str) -> list[str]:
"""Extract input variables from prompt."""
return re.findall(r"{(.*?)}", prompt)
def flatten_list(list_of_lists: list[Union[list, Any]]) -> list:
"""Flatten list of lists."""
new_list = []
for item in list_of_lists:
if isinstance(item, list):
new_list.extend(item)
else:
new_list.append(item)
return new_list

View file

@ -174,6 +174,12 @@ class Vertex:
# turn result which is a function into a coroutine
# so that it can be awaited
self.params["coroutine"] = sync_to_async(result)
if isinstance(result, list):
# If the result is a list, then we need to extend the list
# with the result but first check if the key exists
# if it doesn't, then we need to create a new list
if isinstance(self.params[key], list):
self.params[key].extend(result)
self.params[key] = result
elif isinstance(value, list) and all(

View file

@ -1,20 +1,20 @@
from typing import Any, Dict, List, Optional, Union
from langflow.graph.vertex.base import Vertex
from langflow.graph.utils import extract_input_variables_from_prompt
from langflow.graph.utils import extract_input_variables_from_prompt, flatten_list
class AgentVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="agents")
self.tools: List[ToolVertex] = []
self.tools: List[Union[ToolkitVertex, ToolVertex]] = []
self.chains: List[ChainVertex] = []
def _set_tools_and_chains(self) -> None:
for edge in self.edges:
source_node = edge.source
if isinstance(source_node, ToolVertex):
if isinstance(source_node, (ToolVertex, ToolkitVertex)):
self.tools.append(source_node)
elif isinstance(source_node, ChainVertex):
self.chains.append(source_node)
@ -32,13 +32,6 @@ class AgentVertex(Vertex):
self._build()
#! Cannot deepcopy VectorStore, VectorStoreRouter, or SQL agents
if self.vertex_type in [
"VectorStoreAgent",
"VectorStoreRouterAgent",
"SQLAgent",
]:
return self._built_object
return self._built_object
@ -47,70 +40,6 @@ class ToolVertex(Vertex):
super().__init__(data, base_type="tools")
class PromptVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="prompts")
def build(
self,
force: bool = False,
tools: Optional[Union[List[Vertex], List[ToolVertex]]] = None,
) -> Any:
if not self._built or force:
if (
"input_variables" not in self.params
or self.params["input_variables"] is None
):
self.params["input_variables"] = []
# Check if it is a ZeroShotPrompt and needs a tool
if "ShotPrompt" in self.vertex_type:
tools = (
[tool_node.build() for tool_node in tools]
if tools is not None
else []
)
self.params["tools"] = tools
prompt_params = [
key
for key, value in self.params.items()
if isinstance(value, str) and key != "format_instructions"
]
else:
prompt_params = ["template"]
for param in prompt_params:
prompt_text = self.params[param]
variables = extract_input_variables_from_prompt(prompt_text)
self.params["input_variables"].extend(variables)
self.params["input_variables"] = list(set(self.params["input_variables"]))
self._build()
return self._built_object
class ChainVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="chains")
def build(
self,
force: bool = False,
tools: Optional[Union[List[Vertex], List[ToolVertex]]] = None,
) -> Any:
if not self._built or force:
# Check if the chain requires a PromptNode
for key, value in self.params.items():
if isinstance(value, PromptVertex):
# Build the PromptNode, passing the tools if available
self.params[key] = value.build(tools=tools, force=force)
self._build()
#! Cannot deepcopy SQLDatabaseChain
if self.vertex_type in ["SQLDatabaseChain"]:
return self._built_object
return self._built_object
class LLMVertex(Vertex):
built_node_type = None
class_built_object = None
@ -196,3 +125,68 @@ class TextSplitterVertex(Vertex):
return f"""{self.vertex_type}({len(self._built_object)} documents)
\nDocuments: {self._built_object[:3]}..."""
return f"{self.vertex_type}()"
class ChainVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="chains")
def build(
self,
force: bool = False,
tools: Optional[List[Union[ToolkitVertex, ToolVertex]]] = None,
) -> Any:
if not self._built or force:
# Check if the chain requires a PromptVertex
for key, value in self.params.items():
if isinstance(value, PromptVertex):
# Build the PromptVertex, passing the tools if available
self.params[key] = value.build(tools=tools, force=force)
self._build()
return self._built_object
class PromptVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="prompts")
def build(
self,
force: bool = False,
tools: Optional[List[Union[ToolkitVertex, ToolVertex]]] = None,
) -> Any:
if not self._built or force:
if (
"input_variables" not in self.params
or self.params["input_variables"] is None
):
self.params["input_variables"] = []
# Check if it is a ZeroShotPrompt and needs a tool
if "ShotPrompt" in self.vertex_type:
tools = (
[tool_node.build() for tool_node in tools]
if tools is not None
else []
)
# flatten the list of tools if it is a list of lists
# first check if it is a list
if tools and isinstance(tools, list) and isinstance(tools[0], list):
tools = flatten_list(tools)
self.params["tools"] = tools
prompt_params = [
key
for key, value in self.params.items()
if isinstance(value, str) and key != "format_instructions"
]
else:
prompt_params = ["template"]
for param in prompt_params:
prompt_text = self.params[param]
variables = extract_input_variables_from_prompt(prompt_text)
self.params["input_variables"].extend(variables)
self.params["input_variables"] = list(set(self.params["input_variables"]))
self._build()
return self._built_object

View file

@ -69,7 +69,7 @@ class JsonAgent(CustomAgentExecutor):
@classmethod
def from_toolkit_and_llm(cls, toolkit: JsonToolkit, llm: BaseLanguageModel):
tools = toolkit.get_tools()
tools = toolkit if isinstance(toolkit, list) else toolkit.get_tools()
tool_names = {tool.name for tool in tools}
prompt = ZeroShotAgent.create_prompt(
tools,

View file

@ -110,8 +110,11 @@ def instantiate_tool(node_type, class_object, params):
def instantiate_toolkit(node_type, class_object, params):
loaded_toolkit = class_object(**params)
if toolkits_creator.has_create_function(node_type):
return load_toolkits_executor(node_type, loaded_toolkit, params)
# Commenting this out for now to use toolkits as normal tools
# if toolkits_creator.has_create_function(node_type):
# return load_toolkits_executor(node_type, loaded_toolkit, params)
if isinstance(loaded_toolkit, BaseToolkit):
return loaded_toolkit.get_tools()
return loaded_toolkit

View file

@ -42,24 +42,27 @@ class ToolkitCreator(LangChainTypeCreator):
def get_signature(self, name: str) -> Optional[Dict]:
try:
return build_template_from_class(name, self.type_to_loader_dict)
template = build_template_from_class(name, self.type_to_loader_dict)
# add Tool to base_classes
if "toolkit" in name.lower() and template:
template["base_classes"].append("Tool")
return template
except ValueError as exc:
raise ValueError("Prompt not found") from exc
raise ValueError("Toolkit not found") from exc
except AttributeError as exc:
logger.error(f"Prompt {name} not loaded: {exc}")
logger.error(f"Toolkit {name} not loaded: {exc}")
return None
def to_list(self) -> List[str]:
return list(self.type_to_loader_dict.keys())
def get_create_function(self, name: str) -> Callable:
if loader_name := self.create_functions.get(name, None):
# import loader
if loader_name := self.create_functions.get(name):
return import_module(
f"from langchain.agents.agent_toolkits import {loader_name[0]}"
)
else:
raise ValueError("Loader not found")
raise ValueError("Toolkit not found")
def has_create_function(self, name: str) -> bool:
# check if the function list is not empty

View file

@ -146,7 +146,7 @@ class CSVAgentNode(FrontendNode):
),
],
)
description: str = """Construct a json agent from a CSV and tools."""
description: str = """Construct a CSV agent from a CSV and tools."""
base_classes: list[str] = ["AgentExecutor"]
def to_dict(self):
@ -194,7 +194,7 @@ class InitializeAgentNode(FrontendNode):
),
],
)
description: str = """Construct a json agent from an LLM and tools."""
description: str = """Construct a zero shot agent from an LLM and tools."""
base_classes: list[str] = ["AgentExecutor", "function"]
def to_dict(self):

View file

@ -117,17 +117,30 @@ class FrontendNode(BaseModel):
) -> None:
"""Handles specific field values for certain fields."""
if key == "headers":
field.value = """{'Authorization':
'Bearer <token>'}"""
if name == "OpenAI" and key == "model_name":
field.options = constants.OPENAI_MODELS
field.is_list = True
elif name == "ChatOpenAI" and key == "model_name":
field.options = constants.CHAT_OPENAI_MODELS
field.is_list = True
elif (name == "Anthropic" or name == "ChatAnthropic") and key == "model_name":
field.options = constants.ANTHROPIC_MODELS
field.value = """{'Authorization': 'Bearer <token>'}"""
FrontendNode._handle_model_specific_field_values(field, key, name)
FrontendNode._handle_api_key_specific_field_values(field, key, name)
@staticmethod
def _handle_model_specific_field_values(
field: TemplateField, key: str, name: Optional[str] = None
) -> None:
"""Handles specific field values related to models."""
model_dict = {
"OpenAI": constants.OPENAI_MODELS,
"ChatOpenAI": constants.CHAT_OPENAI_MODELS,
"Anthropic": constants.ANTHROPIC_MODELS,
"ChatAnthropic": constants.ANTHROPIC_MODELS,
}
if name in model_dict and key == "model_name":
field.options = model_dict[name]
field.is_list = True
@staticmethod
def _handle_api_key_specific_field_values(
field: TemplateField, key: str, name: Optional[str] = None
) -> None:
"""Handles specific field values related to API keys."""
if "api_key" in key and "OpenAI" in str(name):
field.display_name = "OpenAI API Key"
field.required = False