Add new schemas and vertex type for CustomComponent

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-01-25 13:03:29 -03:00
commit 47397153f4
6 changed files with 220 additions and 156 deletions

View file

@ -15,6 +15,7 @@ from langflow.graph.vertex.types import (
VectorStoreVertex,
WrapperVertex,
RetrieverVertex,
CustomComponentVertex,
)
__all__ = [
@ -34,4 +35,5 @@ __all__ = [
"VectorStoreVertex",
"WrapperVertex",
"RetrieverVertex",
"CustomComponentVertex",
]

View file

@ -0,0 +1,34 @@
from typing import Any, List
from pydantic import BaseModel
class ResultPair(BaseModel):
    """A single build result together with optional auxiliary data."""

    # The primary value produced by a step; intentionally untyped (Any).
    result: Any
    # Optional companion data for the result; None when absent.
    extra: Any
class Payload(BaseModel):
    """Ordered collection of (result, extra) pairs accumulated during a run."""

    # NOTE: pydantic deep-copies field defaults, so the mutable-looking []
    # is safe here (each instance gets its own list).
    result_pairs: List[ResultPair] = []

    def __iter__(self):
        """Iterate over the stored pairs in insertion order."""
        return iter(self.result_pairs)

    def add_result_pair(self, result: Any, extra: Any = None) -> None:
        """Append a new pair; *extra* defaults to None."""
        pair = ResultPair(result=result, extra=extra)
        self.result_pairs.append(pair)

    def get_last_result_pair(self) -> ResultPair:
        """Return the most recently added pair (IndexError when empty)."""
        return self.result_pairs[-1]

    def format(self, sep: str = "\n") -> str:
        """Render every pair except the last as text, joined by *sep*.

        Each entry shows "Result: ..." and, only when extra is not None,
        a second "Extra: ..." line.
        """
        rendered = []
        for pair in self.result_pairs[:-1]:
            if pair.extra is None:
                rendered.append(f"Result: {pair.result}")
            else:
                rendered.append(f"Result: {pair.result}\nExtra: {pair.extra}")
        return sep.join(rendered)

View file

@ -1,22 +1,25 @@
from langflow.graph.vertex import types
from langflow.interface.agents.base import agent_creator
from langflow.interface.chains.base import chain_creator
from langflow.interface.custom.base import custom_component_creator
from langflow.interface.document_loaders.base import documentloader_creator
from langflow.interface.embeddings.base import embedding_creator
from langflow.interface.llms.base import llm_creator
from langflow.interface.memories.base import memory_creator
from langflow.interface.output_parsers.base import output_parser_creator
from langflow.interface.prompts.base import prompt_creator
from langflow.interface.retrievers.base import retriever_creator
from langflow.interface.text_splitters.base import textsplitter_creator
from langflow.interface.toolkits.base import toolkits_creator
from langflow.interface.tools.base import tool_creator
from langflow.interface.vector_store.base import vectorstore_creator
from langflow.interface.wrappers.base import wrapper_creator
from langflow.interface.output_parsers.base import output_parser_creator
from langflow.interface.retrievers.base import retriever_creator
from langflow.interface.custom.base import custom_component_creator
from langflow.utils.lazy_load import LazyLoadDictBase
chat_components = ["ChatInput", "ChatOutput", "TextInput", "SessionID"]
class VertexTypesDict(LazyLoadDictBase):
def __init__(self):
self._all_types_dict = None
@ -32,9 +35,6 @@ class VertexTypesDict(LazyLoadDictBase):
"Custom": ["Custom Tool", "Python Function"],
}
def get_custom_component_vertex_type(self):
return types.CustomComponentVertex
def get_type_dict(self):
return {
**{t: types.PromptVertex for t in prompt_creator.to_list()},
@ -50,8 +50,12 @@ class VertexTypesDict(LazyLoadDictBase):
**{t: types.DocumentLoaderVertex for t in documentloader_creator.to_list()},
**{t: types.TextSplitterVertex for t in textsplitter_creator.to_list()},
**{t: types.OutputParserVertex for t in output_parser_creator.to_list()},
**{t: types.CustomComponentVertex for t in custom_component_creator.to_list()},
**{
t: types.CustomComponentVertex
for t in custom_component_creator.to_list()
},
**{t: types.RetrieverVertex for t in retriever_creator.to_list()},
**{t: types.ChatVertex for t in chat_components},
}

View file

@ -1,6 +1,5 @@
import copy
from collections import deque
from typing import Dict, List
import copy
def find_last_node(nodes, edges):
@ -29,14 +28,23 @@ def ungroup_node(group_node_data, base_flow):
g_edges = flow["data"]["edges"]
# Redirect edges to the correct proxy node
updated_edges = get_updated_edges(base_flow, g_nodes, g_edges, group_node_data["id"])
updated_edges = get_updated_edges(
base_flow, g_nodes, g_edges, group_node_data["id"]
)
# Update template values
update_template(template, g_nodes)
nodes = [n for n in base_flow["nodes"] if n["id"] != group_node_data["id"]] + g_nodes
nodes = [
n for n in base_flow["nodes"] if n["id"] != group_node_data["id"]
] + g_nodes
edges = (
[e for e in base_flow["edges"] if e["target"] != group_node_data["id"] and e["source"] != group_node_data["id"]]
[
e
for e in base_flow["edges"]
if e["target"] != group_node_data["id"]
and e["source"] != group_node_data["id"]
]
+ g_edges
+ updated_edges
)
@ -47,38 +55,6 @@ def ungroup_node(group_node_data, base_flow):
return nodes
def raw_topological_sort(nodes, edges) -> List[Dict]:
    """Topologically sort *nodes* (dicts with an "id" key) using *edges*.

    Edges are dicts with "source" and "target" keys referring to node ids.
    Returns the node dicts ordered so every edge source precedes its target.

    Raises:
        ValueError: if the graph contains a cycle.
    """
    # States: 0 = unvisited, 1 = visiting, 2 = visited
    state = {node["id"]: 0 for node in nodes}
    nodes_dict = {node["id"]: node for node in nodes}
    # Pre-build an adjacency map so DFS is O(V + E) instead of rescanning
    # the whole edge list for every node (the original was O(V * E)).
    adjacency: Dict[str, List[str]] = defaultdict(list)
    for edge in edges:
        adjacency[edge["source"]].append(edge["target"])
    sorted_vertices: List[str] = []

    def dfs(node):
        if state[node] == 1:
            # A back-edge to a node still on the DFS stack means a cycle.
            raise ValueError("Graph contains a cycle, cannot perform topological sort")
        if state[node] == 0:
            state[node] = 1
            for target in adjacency[node]:
                dfs(target)
            state[node] = 2
            sorted_vertices.append(node)

    # Visit each node (covers disconnected components).
    for node in nodes:
        if state[node["id"]] == 0:
            dfs(node["id"])

    # Post-order DFS reversed yields sources before their targets.
    return [nodes_dict[node_id] for node_id in reversed(sorted_vertices)]
def process_flow(flow_object):
cloned_flow = copy.deepcopy(flow_object)
processed_nodes = set() # To keep track of processed nodes
@ -90,7 +66,11 @@ def process_flow(flow_object):
if node_id in processed_nodes:
return
if node.get("data") and node["data"].get("node") and node["data"]["node"].get("flow"):
if (
node.get("data")
and node["data"].get("node")
and node["data"]["node"].get("flow")
):
process_flow(node["data"]["node"]["flow"]["data"])
new_nodes = ungroup_node(node["data"], cloned_flow)
# Add new nodes to the queue for future processing
@ -99,8 +79,7 @@ def process_flow(flow_object):
# Mark node as processed
processed_nodes.add(node_id)
sorted_nodes_list = raw_topological_sort(cloned_flow["nodes"], cloned_flow["edges"])
nodes_to_process = deque(sorted_nodes_list)
nodes_to_process = deque(cloned_flow["nodes"])
while nodes_to_process:
node = nodes_to_process.popleft()
@ -129,23 +108,29 @@ def update_template(template, g_nodes):
if node_index != -1:
display_name = None
show = g_nodes[node_index]["data"]["node"]["template"][field]["show"]
advanced = g_nodes[node_index]["data"]["node"]["template"][field]["advanced"]
advanced = g_nodes[node_index]["data"]["node"]["template"][field][
"advanced"
]
if "display_name" in g_nodes[node_index]["data"]["node"]["template"][field]:
display_name = g_nodes[node_index]["data"]["node"]["template"][field]["display_name"]
display_name = g_nodes[node_index]["data"]["node"]["template"][field][
"display_name"
]
else:
display_name = g_nodes[node_index]["data"]["node"]["template"][field]["name"]
display_name = g_nodes[node_index]["data"]["node"]["template"][field][
"name"
]
g_nodes[node_index]["data"]["node"]["template"][field] = value
g_nodes[node_index]["data"]["node"]["template"][field]["show"] = show
g_nodes[node_index]["data"]["node"]["template"][field]["advanced"] = advanced
g_nodes[node_index]["data"]["node"]["template"][field]["display_name"] = display_name
g_nodes[node_index]["data"]["node"]["template"][field][
"advanced"
] = advanced
g_nodes[node_index]["data"]["node"]["template"][field][
"display_name"
] = display_name
def update_target_handle(
new_edge,
g_nodes,
group_node_id,
):
def update_target_handle(new_edge, g_nodes, group_node_id):
"""
Updates the target handle of a given edge if it is a proxy node.
@ -162,8 +147,6 @@ def update_target_handle(
proxy_id = target_handle["proxy"]["id"]
if node := next((n for n in g_nodes if n["id"] == proxy_id), None):
set_new_target_handle(proxy_id, new_edge, target_handle, node)
else:
raise ValueError(f"Group node {group_node_id} has an invalid target proxy node {proxy_id}")
return new_edge

View file

@ -1,17 +1,19 @@
import ast
from typing import Any, Dict, List, Optional, Union
from typing import Callable, Dict, List, Optional, Union
from langflow.graph.utils import UnbuiltObject, flatten_list
from langflow.graph.vertex.base import Vertex
from langflow.graph.vertex.base import StatefulVertex, StatelessVertex
from langflow.interface.utils import extract_input_variables_from_prompt
from langflow.utils.schemas import ChatOutputResponse
class AgentVertex(Vertex):
class AgentVertex(StatelessVertex):
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
super().__init__(data, graph=graph, base_type="agents", params=params)
self.tools: List[Union[ToolkitVertex, ToolVertex]] = []
self.chains: List[ChainVertex] = []
self.steps: List[Callable] = [self._custom_build, self._run]
def __getstate__(self):
state = super().__getstate__()
@ -26,84 +28,85 @@ class AgentVertex(Vertex):
def _set_tools_and_chains(self) -> None:
for edge in self.edges:
if not hasattr(edge, "source_id"):
if not hasattr(edge, "source"):
continue
source_node = self.graph.get_vertex(edge.source_id)
source_node = edge.source
if isinstance(source_node, (ToolVertex, ToolkitVertex)):
self.tools.append(source_node)
elif isinstance(source_node, ChainVertex):
self.chains.append(source_node)
async def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
if not self._built or force:
self._set_tools_and_chains()
# First, build the tools
for tool_node in self.tools:
await tool_node.build(user_id=user_id)
async def _custom_build(self, *args, **kwargs):
user_id = kwargs.get("user_id", None)
self._set_tools_and_chains()
# First, build the tools
for tool_node in self.tools:
await tool_node.build(user_id=user_id)
# Next, build the chains and the rest
for chain_node in self.chains:
await chain_node.build(tools=self.tools, user_id=user_id)
# Next, build the chains and the rest
for chain_node in self.chains:
await chain_node.build(tools=self.tools, user_id=user_id)
await self._build(user_id=user_id)
return self._built_object
await self._build(user_id=user_id)
class ToolVertex(Vertex):
def __init__(
self,
data: Dict,
graph,
params: Optional[Dict] = None,
):
class ToolVertex(StatelessVertex):
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
super().__init__(data, graph=graph, base_type="tools", params=params)
class LLMVertex(Vertex):
class LLMVertex(StatelessVertex):
built_node_type = None
class_built_object = None
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
super().__init__(data, graph=graph, base_type="llms", params=params)
self.steps: List[Callable] = [self._custom_build]
async def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
async def _custom_build(self, *args, **kwargs):
# LLM is different because some models might take up too much memory
# or time to load. So we only load them when we need them.ß
# or time to load. So we only load them when we need them.
# Avoid deepcopying the LLM
# that are loaded from a file
force = kwargs.get("force", False)
user_id = kwargs.get("user_id", None)
if self.vertex_type == self.built_node_type:
return self.class_built_object
self._built_object = self.class_built_object
if not self._built or force:
await self._build(user_id=user_id)
self.built_node_type = self.vertex_type
self.class_built_object = self._built_object
# Avoid deepcopying the LLM
# that are loaded from a file
return self._built_object
class ToolkitVertex(Vertex):
class ToolkitVertex(StatelessVertex):
def __init__(self, data: Dict, graph, params=None):
super().__init__(data, graph=graph, base_type="toolkits", params=params)
class FileToolVertex(ToolVertex):
def __init__(self, data: Dict, graph, params=None):
super().__init__(data, graph=graph, params=params)
super().__init__(
data,
params=params,
graph=graph,
)
class WrapperVertex(Vertex):
def __init__(self, data: Dict, graph):
class WrapperVertex(StatelessVertex):
def __init__(self, data: Dict, graph, params=None):
super().__init__(data, graph=graph, base_type="wrappers")
self.steps: List[Callable] = [self._custom_build]
async def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
async def _custom_build(self, *args, **kwargs):
force = kwargs.get("force", False)
user_id = kwargs.get("user_id", None)
if not self._built or force:
if "headers" in self.params:
self.params["headers"] = ast.literal_eval(self.params["headers"])
await self._build(user_id=user_id)
return self._built_object
class DocumentLoaderVertex(Vertex):
class DocumentLoaderVertex(StatefulVertex):
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
super().__init__(data, graph=graph, base_type="documentloaders", params=params)
@ -111,7 +114,7 @@ class DocumentLoaderVertex(Vertex):
# This built_object is a list of documents. Maybe we should
# show how many documents are in the list?
if self._built_object and not isinstance(self._built_object, UnbuiltObject):
if not isinstance(self._built_object, UnbuiltObject):
avg_length = sum(len(doc.page_content) for doc in self._built_object if hasattr(doc, "page_content")) / len(
self._built_object
)
@ -121,12 +124,12 @@ class DocumentLoaderVertex(Vertex):
return f"{self.vertex_type}()"
class EmbeddingVertex(Vertex):
class EmbeddingVertex(StatefulVertex):
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
super().__init__(data, graph=graph, base_type="embeddings", params=params)
class VectorStoreVertex(Vertex):
class VectorStoreVertex(StatefulVertex):
def __init__(self, data: Dict, graph, params=None):
super().__init__(data, graph=graph, base_type="vectorstores")
@ -135,6 +138,15 @@ class VectorStoreVertex(Vertex):
# VectorStores may contain database connections
# so we need to define the __reduce__ method and the __setstate__ method
# to avoid pickling errors
def clean_edges_for_pickling(self):
# for each edge that has self as source
# we need to clear the _built_object of the target
# so that we don't try to pickle a database connection
for edge in self.edges:
if edge.source == self:
edge.target._built_object = None
edge.target._built = False
edge.target.params[edge.target_param] = self
def remove_docs_and_texts_from_params(self):
# remove documents and texts from params
@ -142,33 +154,34 @@ class VectorStoreVertex(Vertex):
self.params.pop("documents", None)
self.params.pop("texts", None)
# def __getstate__(self):
# # We want to save the params attribute
# # and if "documents" or "texts" are in the params
# # we want to remove them because they have already
# # been processed.
# params = self.params.copy()
# params.pop("documents", None)
# params.pop("texts", None)
def __getstate__(self):
# We want to save the params attribute
# and if "documents" or "texts" are in the params
# we want to remove them because they have already
# been processed.
params = self.params.copy()
params.pop("documents", None)
params.pop("texts", None)
self.clean_edges_for_pickling()
# return super().__getstate__()
return super().__getstate__()
def __setstate__(self, state):
super().__setstate__(state)
self.remove_docs_and_texts_from_params()
class MemoryVertex(Vertex):
class MemoryVertex(StatefulVertex):
def __init__(self, data: Dict, graph):
super().__init__(data, graph=graph, base_type="memory")
class RetrieverVertex(Vertex):
class RetrieverVertex(StatefulVertex):
def __init__(self, data: Dict, graph):
super().__init__(data, graph=graph, base_type="retrievers")
class TextSplitterVertex(Vertex):
class TextSplitterVertex(StatefulVertex):
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
super().__init__(data, graph=graph, base_type="textsplitters", params=params)
@ -176,7 +189,7 @@ class TextSplitterVertex(Vertex):
# This built_object is a list of documents. Maybe we should
# show how many documents are in the list?
if self._built_object and not isinstance(self._built_object, UnbuiltObject):
if not isinstance(self._built_object, UnbuiltObject):
avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(self._built_object)
return f"""{self.vertex_type}({len(self._built_object)} documents)
\nAvg. Document Length (characters): {int(avg_length)}
@ -184,54 +197,51 @@ class TextSplitterVertex(Vertex):
return f"{self.vertex_type}()"
class ChainVertex(Vertex):
class ChainVertex(StatelessVertex):
def __init__(self, data: Dict, graph):
super().__init__(data, graph=graph, base_type="chains")
self.steps = [self._custom_build, self._run]
async def build(
self,
force: bool = False,
user_id=None,
*args,
**kwargs,
) -> Any:
if not self._built or force:
# Temporarily remove the code from the params
self.params.pop("code", None)
# Check if the chain requires a PromptVertex
async def _custom_build(self, *args, **kwargs):
force = kwargs.get("force", False)
user_id = kwargs.get("user_id", None)
# Remove this once LLMChain is CustomComponent
self.params.pop("code", None)
for key, value in self.params.items():
if isinstance(value, PromptVertex):
# Build the PromptVertex, passing the tools if available
tools = kwargs.get("tools", None)
self.params[key] = value.build(tools=tools, pinned=force)
# Temporarily remove "code" from the params
self.params.pop("code", None)
await self._build(user_id=user_id)
for key, value in self.params.items():
if isinstance(value, PromptVertex):
# Build the PromptVertex, passing the tools if available
tools = kwargs.get("tools", None)
self.params[key] = await value.build(tools=tools, force=force)
def set_artifacts(self) -> None:
if isinstance(self._built_object, UnbuiltObject):
return
if self._built_object and hasattr(self._built_object, "input_keys"):
self.artifacts = dict(input_keys=self._built_object.input_keys)
await self._build(user_id=user_id)
return self._built_object
def _built_object_repr(self):
if isinstance(self._built_object, str):
return self._built_object
return super()._built_object_repr()
class PromptVertex(Vertex):
class PromptVertex(StatelessVertex):
def __init__(self, data: Dict, graph):
super().__init__(data, graph=graph, base_type="prompts")
self.steps: List[Callable] = [self._custom_build]
async def build(
self,
force: bool = False,
user_id=None,
tools: Optional[List[Union[ToolkitVertex, ToolVertex]]] = None,
*args,
**kwargs,
) -> Any:
async def _custom_build(self, *args, **kwargs):
force = kwargs.get("force", False)
user_id = kwargs.get("user_id", None)
tools = kwargs.get("tools", [])
if not self._built or force:
if "input_variables" not in self.params or self.params["input_variables"] is None:
self.params["input_variables"] = []
# Check if it is a ZeroShotPrompt and needs a tool
if "ShotPrompt" in self.vertex_type:
tools = [await tool_node.build(user_id=user_id) for tool_node in tools] if tools is not None else []
tools = [tool_node.build(user_id=user_id) for tool_node in tools] if tools is not None else []
# flatten the list of tools if it is a list of lists
# first check if it is a list
if tools and isinstance(tools, list) and isinstance(tools[0], list):
@ -253,11 +263,12 @@ class PromptVertex(Vertex):
self.params.pop("input_variables", None)
await self._build(user_id=user_id)
return self._built_object
def _built_object_repr(self):
if not self.artifacts or self._built_object is None or not hasattr(self._built_object, "format"):
return super()._built_object_repr()
elif isinstance(self._built_object, UnbuiltObject):
return super()._built_object_repr()
# We'll build the prompt with the artifacts
# to show the user what the prompt looks like
# with the variables filled in
@ -265,15 +276,10 @@ class PromptVertex(Vertex):
# Remove the handle_keys from the artifacts
# so the prompt format doesn't break
artifacts.pop("handle_keys", None)
template = ""
try:
if (
not hasattr(self._built_object, "template")
and hasattr(self._built_object, "prompt")
and not isinstance(self._built_object, UnbuiltObject)
):
if not hasattr(self._built_object, "template") and hasattr(self._built_object, "prompt"):
template = self._built_object.prompt.template
elif not isinstance(self._built_object, UnbuiltObject) and hasattr(self._built_object, "template"):
else:
template = self._built_object.template
for key, value in artifacts.items():
if value:
@ -284,14 +290,24 @@ class PromptVertex(Vertex):
return str(self._built_object)
class OutputParserVertex(Vertex):
class OutputParserVertex(StatelessVertex):
def __init__(self, data: Dict, graph):
super().__init__(data, graph=graph, base_type="output_parsers")
class CustomComponentVertex(Vertex):
class CustomComponentVertex(StatelessVertex):
def __init__(self, data: Dict, graph):
super().__init__(data, graph=graph, base_type="custom_components", is_task=False)
super().__init__(data, graph=graph, base_type="custom_components")
def _built_object_repr(self):
if self.artifacts and "repr" in self.artifacts:
return self.artifacts["repr"] or super()._built_object_repr()
class ChatVertex(StatelessVertex):
def __init__(self, data: Dict, graph):
super().__init__(data, graph=graph, base_type="custom_components", is_task=True)
self.steps = [self._build, self._run]
def _built_object_repr(self):
if self.task_id and self.is_task:
@ -301,3 +317,18 @@ class CustomComponentVertex(Vertex):
return f"Task {self.task_id} is not running"
if self.artifacts and "repr" in self.artifacts:
return self.artifacts["repr"] or super()._built_object_repr()
def _run(self, *args, **kwargs):
if self.is_power_component:
if self.vertex_type == "ChatOutput":
sender = self.params.get("sender", None)
sender_name = self.params.get("sender_name", None)
self.artifacts = ChatOutputResponse(
message=str(self._built_object),
sender=sender,
sender_name=sender_name,
).dict()
self._built_result = self._built_object
else:
super()._run(*args, **kwargs)

View file

@ -0,0 +1,10 @@
from pydantic import BaseModel
from typing import Optional
class ChatOutputResponse(BaseModel):
    """Chat output response schema."""

    # The chat message text (required).
    message: str
    # Identifier of the message sender; defaults to "Machine".
    sender: Optional[str] = "Machine"
    # Display name of the sender; defaults to "AI".
    sender_name: Optional[str] = "AI"