From 47397153f483f1de223b4778468a526d2ef3a67c Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Thu, 25 Jan 2024 13:03:29 -0300 Subject: [PATCH] Add new schemas and vertex type for CustomComponent --- src/backend/langflow/graph/__init__.py | 2 + src/backend/langflow/graph/edge/schema.py | 34 +++ src/backend/langflow/graph/graph/constants.py | 18 +- src/backend/langflow/graph/graph/utils.py | 87 +++---- src/backend/langflow/graph/vertex/types.py | 225 ++++++++++-------- src/backend/langflow/utils/schemas.py | 10 + 6 files changed, 220 insertions(+), 156 deletions(-) create mode 100644 src/backend/langflow/graph/edge/schema.py create mode 100644 src/backend/langflow/utils/schemas.py diff --git a/src/backend/langflow/graph/__init__.py b/src/backend/langflow/graph/__init__.py index e63b9dcc0..04db8e5a1 100644 --- a/src/backend/langflow/graph/__init__.py +++ b/src/backend/langflow/graph/__init__.py @@ -15,6 +15,7 @@ from langflow.graph.vertex.types import ( VectorStoreVertex, WrapperVertex, RetrieverVertex, + CustomComponentVertex, ) __all__ = [ @@ -34,4 +35,5 @@ __all__ = [ "VectorStoreVertex", "WrapperVertex", "RetrieverVertex", + "CustomComponentVertex", ] diff --git a/src/backend/langflow/graph/edge/schema.py b/src/backend/langflow/graph/edge/schema.py new file mode 100644 index 000000000..ea691beae --- /dev/null +++ b/src/backend/langflow/graph/edge/schema.py @@ -0,0 +1,34 @@ +from typing import Any, List +from pydantic import BaseModel + + +class ResultPair(BaseModel): + result: Any + extra: Any + + +class Payload(BaseModel): + result_pairs: List[ResultPair] = [] + + def __iter__(self): + return iter(self.result_pairs) + + def add_result_pair(self, result: Any, extra: Any = None) -> None: + self.result_pairs.append(ResultPair(result=result, extra=extra)) + + def get_last_result_pair(self) -> ResultPair: + return self.result_pairs[-1] + + # format all but the last result pair + # into a string + def format(self, sep: str = "\n") -> str: + # Result: the result + # Extra: the extra if it exists don't show if it doesn't + return sep.join( + [ + f"Result: {result_pair.result}\nExtra: {result_pair.extra}" + if result_pair.extra is not None + else f"Result: {result_pair.result}" + for result_pair in self.result_pairs[:-1] + ] + ) diff --git a/src/backend/langflow/graph/graph/constants.py b/src/backend/langflow/graph/graph/constants.py index 9514764b7..b8617e56a 100644 --- a/src/backend/langflow/graph/graph/constants.py +++ b/src/backend/langflow/graph/graph/constants.py @@ -1,22 +1,25 @@ from langflow.graph.vertex import types from langflow.interface.agents.base import agent_creator from langflow.interface.chains.base import chain_creator -from langflow.interface.custom.base import custom_component_creator from langflow.interface.document_loaders.base import documentloader_creator from langflow.interface.embeddings.base import embedding_creator from langflow.interface.llms.base import llm_creator from langflow.interface.memories.base import memory_creator -from langflow.interface.output_parsers.base import output_parser_creator from langflow.interface.prompts.base import prompt_creator -from langflow.interface.retrievers.base import retriever_creator from langflow.interface.text_splitters.base import textsplitter_creator from langflow.interface.toolkits.base import toolkits_creator from langflow.interface.tools.base import tool_creator from langflow.interface.vector_store.base import vectorstore_creator from langflow.interface.wrappers.base import wrapper_creator +from langflow.interface.output_parsers.base import output_parser_creator +from langflow.interface.retrievers.base import retriever_creator +from langflow.interface.custom.base import custom_component_creator from langflow.utils.lazy_load import LazyLoadDictBase +chat_components = ["ChatInput", "ChatOutput", "TextInput", "SessionID"] + + class VertexTypesDict(LazyLoadDictBase): def __init__(self): self._all_types_dict = None @@ -32,9 +35,6 @@ class VertexTypesDict(LazyLoadDictBase): "Custom": ["Custom Tool", "Python Function"], } - def get_custom_component_vertex_type(self): - return types.CustomComponentVertex - def get_type_dict(self): return { **{t: types.PromptVertex for t in prompt_creator.to_list()}, @@ -50,8 +50,12 @@ class VertexTypesDict(LazyLoadDictBase): **{t: types.DocumentLoaderVertex for t in documentloader_creator.to_list()}, **{t: types.TextSplitterVertex for t in textsplitter_creator.to_list()}, **{t: types.OutputParserVertex for t in output_parser_creator.to_list()}, - **{t: types.CustomComponentVertex for t in custom_component_creator.to_list()}, + **{ + t: types.CustomComponentVertex + for t in custom_component_creator.to_list() + }, **{t: types.RetrieverVertex for t in retriever_creator.to_list()}, + **{t: types.ChatVertex for t in chat_components}, } diff --git a/src/backend/langflow/graph/graph/utils.py b/src/backend/langflow/graph/graph/utils.py index d43fc4d84..71b81fea1 100644 --- a/src/backend/langflow/graph/graph/utils.py +++ b/src/backend/langflow/graph/graph/utils.py @@ -1,6 +1,5 @@ -import copy from collections import deque -from typing import Dict, List +import copy def find_last_node(nodes, edges): @@ -29,14 +28,23 @@ def ungroup_node(group_node_data, base_flow): g_edges = flow["data"]["edges"] # Redirect edges to the correct proxy node - updated_edges = get_updated_edges(base_flow, g_nodes, g_edges, group_node_data["id"]) + updated_edges = get_updated_edges( + base_flow, g_nodes, g_edges, group_node_data["id"] + ) # Update template values update_template(template, g_nodes) - nodes = [n for n in base_flow["nodes"] if n["id"] != group_node_data["id"]] + g_nodes + nodes = [ + n for n in base_flow["nodes"] if n["id"] != group_node_data["id"] + ] + g_nodes edges = ( - [e for e in base_flow["edges"] if e["target"] != group_node_data["id"] and e["source"] != group_node_data["id"]] + [ + e + for e in base_flow["edges"] + if e["target"] != group_node_data["id"] + and e["source"] != group_node_data["id"] + ] + g_edges + updated_edges ) @@ -47,38 +55,6 @@ def ungroup_node(group_node_data, base_flow): return nodes -def raw_topological_sort(nodes, edges) -> List[Dict]: - # Redefine the above function but using the nodes and self._edges - # which are dicts instead of Vertex and Edge objects - # nodes have an id, edges have a source and target keys - # return a list of node ids in topological order - - # States: 0 = unvisited, 1 = visiting, 2 = visited - state = {node["id"]: 0 for node in nodes} - nodes_dict = {node["id"]: node for node in nodes} - sorted_vertices = [] - - def dfs(node): - if state[node] == 1: - # We have a cycle - raise ValueError("Graph contains a cycle, cannot perform topological sort") - if state[node] == 0: - state[node] = 1 - for edge in edges: - if edge["source"] == node: - dfs(edge["target"]) - state[node] = 2 - sorted_vertices.append(node) - - # Visit each node - for node in nodes: - if state[node["id"]] == 0: - dfs(node["id"]) - - reverse_sorted = list(reversed(sorted_vertices)) - return [nodes_dict[node_id] for node_id in reverse_sorted] - - def process_flow(flow_object): cloned_flow = copy.deepcopy(flow_object) processed_nodes = set() # To keep track of processed nodes @@ -90,7 +66,11 @@ def process_flow(flow_object): if node_id in processed_nodes: return - if node.get("data") and node["data"].get("node") and node["data"]["node"].get("flow"): + if ( + node.get("data") + and node["data"].get("node") + and node["data"]["node"].get("flow") + ): process_flow(node["data"]["node"]["flow"]["data"]) new_nodes = ungroup_node(node["data"], cloned_flow) # Add new nodes to the queue for future processing @@ -99,8 +79,7 @@ def process_flow(flow_object): # Mark node as processed processed_nodes.add(node_id) - sorted_nodes_list = raw_topological_sort(cloned_flow["nodes"], cloned_flow["edges"]) - nodes_to_process = deque(sorted_nodes_list) + nodes_to_process = deque(cloned_flow["nodes"]) while nodes_to_process: node = nodes_to_process.popleft() @@ -129,23 +108,29 @@ def update_template(template, g_nodes): if node_index != -1: display_name = None show = g_nodes[node_index]["data"]["node"]["template"][field]["show"] - advanced = g_nodes[node_index]["data"]["node"]["template"][field]["advanced"] + advanced = g_nodes[node_index]["data"]["node"]["template"][field][ + "advanced" + ] if "display_name" in g_nodes[node_index]["data"]["node"]["template"][field]: - display_name = g_nodes[node_index]["data"]["node"]["template"][field]["display_name"] + display_name = g_nodes[node_index]["data"]["node"]["template"][field][ + "display_name" + ] else: - display_name = g_nodes[node_index]["data"]["node"]["template"][field]["name"] + display_name = g_nodes[node_index]["data"]["node"]["template"][field][ + "name" + ] g_nodes[node_index]["data"]["node"]["template"][field] = value g_nodes[node_index]["data"]["node"]["template"][field]["show"] = show - g_nodes[node_index]["data"]["node"]["template"][field]["advanced"] = advanced - g_nodes[node_index]["data"]["node"]["template"][field]["display_name"] = display_name + g_nodes[node_index]["data"]["node"]["template"][field][ + "advanced" + ] = advanced + g_nodes[node_index]["data"]["node"]["template"][field][ + "display_name" + ] = display_name -def update_target_handle( - new_edge, - g_nodes, - group_node_id, -): +def update_target_handle(new_edge, g_nodes, group_node_id): """ Updates the target handle of a given edge if it is a proxy node. @@ -162,8 +147,6 @@ def update_target_handle( proxy_id = target_handle["proxy"]["id"] if node := next((n for n in g_nodes if n["id"] == proxy_id), None): set_new_target_handle(proxy_id, new_edge, target_handle, node) - else: - raise ValueError(f"Group node {group_node_id} has an invalid target proxy node {proxy_id}") return new_edge diff --git a/src/backend/langflow/graph/vertex/types.py b/src/backend/langflow/graph/vertex/types.py index 87e0ab879..19422b0c8 100644 --- a/src/backend/langflow/graph/vertex/types.py +++ b/src/backend/langflow/graph/vertex/types.py @@ -1,17 +1,19 @@ import ast -from typing import Any, Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union from langflow.graph.utils import UnbuiltObject, flatten_list -from langflow.graph.vertex.base import Vertex +from langflow.graph.vertex.base import StatefulVertex, StatelessVertex from langflow.interface.utils import extract_input_variables_from_prompt +from langflow.utils.schemas import ChatOutputResponse -class AgentVertex(Vertex): +class AgentVertex(StatelessVertex): def __init__(self, data: Dict, graph, params: Optional[Dict] = None): super().__init__(data, graph=graph, base_type="agents", params=params) self.tools: List[Union[ToolkitVertex, ToolVertex]] = [] self.chains: List[ChainVertex] = [] + self.steps: List[Callable] = [self._custom_build, self._run] def __getstate__(self): state = super().__getstate__() @@ -26,84 +28,85 @@ class AgentVertex(Vertex): def _set_tools_and_chains(self) -> None: for edge in self.edges: - if not hasattr(edge, "source_id"): + if not hasattr(edge, "source"): continue - source_node = self.graph.get_vertex(edge.source_id) + source_node = edge.source if isinstance(source_node, (ToolVertex, ToolkitVertex)): self.tools.append(source_node) elif isinstance(source_node, ChainVertex): self.chains.append(source_node) - async def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any: - if not self._built or force: - self._set_tools_and_chains() - # First, build the tools - for tool_node in self.tools: - await tool_node.build(user_id=user_id) + async def _custom_build(self, *args, **kwargs): + user_id = kwargs.get("user_id", None) + self._set_tools_and_chains() + # First, build the tools + for tool_node in self.tools: + await tool_node.build(user_id=user_id) - # Next, build the chains and the rest - for chain_node in self.chains: - await chain_node.build(tools=self.tools, user_id=user_id) + # Next, build the chains and the rest + for chain_node in self.chains: + await chain_node.build(tools=self.tools, user_id=user_id) - await self._build(user_id=user_id) - - return self._built_object + await self._build(user_id=user_id) -class ToolVertex(Vertex): - def __init__( - self, - data: Dict, - graph, - params: Optional[Dict] = None, - ): +class ToolVertex(StatelessVertex): + def __init__(self, data: Dict, graph, params: Optional[Dict] = None): super().__init__(data, graph=graph, base_type="tools", params=params) -class LLMVertex(Vertex): +class LLMVertex(StatelessVertex): built_node_type = None class_built_object = None def __init__(self, data: Dict, graph, params: Optional[Dict] = None): super().__init__(data, graph=graph, base_type="llms", params=params) + self.steps: List[Callable] = [self._custom_build] - async def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any: + async def _custom_build(self, *args, **kwargs): # LLM is different because some models might take up too much memory - # or time to load. So we only load them when we need them.ß + # or time to load. So we only load them when we need them. + # Avoid deepcopying the LLM + # that are loaded from a file + force = kwargs.get("force", False) + user_id = kwargs.get("user_id", None) if self.vertex_type == self.built_node_type: - return self.class_built_object + self._built_object = self.class_built_object if not self._built or force: await self._build(user_id=user_id) self.built_node_type = self.vertex_type self.class_built_object = self._built_object - # Avoid deepcopying the LLM - # that are loaded from a file - return self._built_object -class ToolkitVertex(Vertex): +class ToolkitVertex(StatelessVertex): def __init__(self, data: Dict, graph, params=None): super().__init__(data, graph=graph, base_type="toolkits", params=params) class FileToolVertex(ToolVertex): def __init__(self, data: Dict, graph, params=None): - super().__init__(data, graph=graph, params=params) + super().__init__( + data, + params=params, + graph=graph, + ) -class WrapperVertex(Vertex): - def __init__(self, data: Dict, graph): +class WrapperVertex(StatelessVertex): + def __init__(self, data: Dict, graph, params=None): super().__init__(data, graph=graph, base_type="wrappers") + self.steps: List[Callable] = [self._custom_build] - async def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any: + async def _custom_build(self, *args, **kwargs): + force = kwargs.get("force", False) + user_id = kwargs.get("user_id", None) if not self._built or force: if "headers" in self.params: self.params["headers"] = ast.literal_eval(self.params["headers"]) await self._build(user_id=user_id) - return self._built_object -class DocumentLoaderVertex(Vertex): +class DocumentLoaderVertex(StatefulVertex): def __init__(self, data: Dict, graph, params: Optional[Dict] = None): super().__init__(data, graph=graph, base_type="documentloaders", params=params) @@ -111,7 +114,7 @@ class DocumentLoaderVertex(Vertex): # This built_object is a list of documents. Maybe we should # show how many documents are in the list? - if self._built_object and not isinstance(self._built_object, UnbuiltObject): + if not isinstance(self._built_object, UnbuiltObject): avg_length = sum(len(doc.page_content) for doc in self._built_object if hasattr(doc, "page_content")) / len( self._built_object ) @@ -121,12 +124,12 @@ class DocumentLoaderVertex(Vertex): return f"{self.vertex_type}()" -class EmbeddingVertex(Vertex): +class EmbeddingVertex(StatefulVertex): def __init__(self, data: Dict, graph, params: Optional[Dict] = None): super().__init__(data, graph=graph, base_type="embeddings", params=params) -class VectorStoreVertex(Vertex): +class VectorStoreVertex(StatefulVertex): def __init__(self, data: Dict, graph, params=None): super().__init__(data, graph=graph, base_type="vectorstores") @@ -135,6 +138,15 @@ class VectorStoreVertex(Vertex): # VectorStores may contain databse connections # so we need to define the __reduce__ method and the __setstate__ method # to avoid pickling errors + def clean_edges_for_pickling(self): + # for each edge that has self as source + # we need to clear the _built_object of the target + # so that we don't try to pickle a database connection + for edge in self.edges: + if edge.source == self: + edge.target._built_object = None + edge.target._built = False + edge.target.params[edge.target_param] = self def remove_docs_and_texts_from_params(self): # remove documents and texts from params @@ -142,33 +154,34 @@ class VectorStoreVertex(Vertex): self.params.pop("documents", None) self.params.pop("texts", None) - # def __getstate__(self): - # # We want to save the params attribute - # # and if "documents" or "texts" are in the params - # # we want to remove them because they have already - # # been processed. - # params = self.params.copy() - # params.pop("documents", None) - # params.pop("texts", None) + def __getstate__(self): + # We want to save the params attribute + # and if "documents" or "texts" are in the params + # we want to remove them because they have already + # been processed. + params = self.params.copy() + params.pop("documents", None) + params.pop("texts", None) + self.clean_edges_for_pickling() - # return super().__getstate__() + return super().__getstate__() def __setstate__(self, state): super().__setstate__(state) self.remove_docs_and_texts_from_params() -class MemoryVertex(Vertex): +class MemoryVertex(StatefulVertex): def __init__(self, data: Dict, graph): super().__init__(data, graph=graph, base_type="memory") -class RetrieverVertex(Vertex): +class RetrieverVertex(StatefulVertex): def __init__(self, data: Dict, graph): super().__init__(data, graph=graph, base_type="retrievers") -class TextSplitterVertex(Vertex): +class TextSplitterVertex(StatefulVertex): def __init__(self, data: Dict, graph, params: Optional[Dict] = None): super().__init__(data, graph=graph, base_type="textsplitters", params=params) @@ -176,7 +189,7 @@ class TextSplitterVertex(Vertex): # This built_object is a list of documents. Maybe we should # show how many documents are in the list? - if self._built_object and not isinstance(self._built_object, UnbuiltObject): + if not isinstance(self._built_object, UnbuiltObject): avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(self._built_object) return f"""{self.vertex_type}({len(self._built_object)} documents) \nAvg. Document Length (characters): {int(avg_length)} @@ -184,54 +197,51 @@ class TextSplitterVertex(Vertex): return f"{self.vertex_type}()" -class ChainVertex(Vertex): +class ChainVertex(StatelessVertex): def __init__(self, data: Dict, graph): super().__init__(data, graph=graph, base_type="chains") + self.steps = [self._custom_build, self._run] - async def build( - self, - force: bool = False, - user_id=None, - *args, - **kwargs, - ) -> Any: - if not self._built or force: - # Temporarily remove the code from the params - self.params.pop("code", None) - # Check if the chain requires a PromptVertex + async def _custom_build(self, *args, **kwargs): + force = kwargs.get("force", False) + user_id = kwargs.get("user_id", None) + # Remove this once LLMChain is CustomComponent + self.params.pop("code", None) + for key, value in self.params.items(): + if isinstance(value, PromptVertex): + # Build the PromptVertex, passing the tools if available + tools = kwargs.get("tools", None) + self.params[key] = value.build(tools=tools, pinned=force) - # Temporarily remove "code" from the params - self.params.pop("code", None) + await self._build(user_id=user_id) - for key, value in self.params.items(): - if isinstance(value, PromptVertex): - # Build the PromptVertex, passing the tools if available - tools = kwargs.get("tools", None) - self.params[key] = await value.build(tools=tools, force=force) + def set_artifacts(self) -> None: + if isinstance(self._built_object, UnbuiltObject): + return + if self._built_object and hasattr(self._built_object, "input_keys"): + self.artifacts = dict(input_keys=self._built_object.input_keys) - await self._build(user_id=user_id) - - return self._built_object + def _built_object_repr(self): + if isinstance(self._built_object, str): + return self._built_object + return super()._built_object_repr() -class PromptVertex(Vertex): +class PromptVertex(StatelessVertex): def __init__(self, data: Dict, graph): super().__init__(data, graph=graph, base_type="prompts") + self.steps: List[Callable] = [self._custom_build] - async def build( - self, - force: bool = False, - user_id=None, - tools: Optional[List[Union[ToolkitVertex, ToolVertex]]] = None, - *args, - **kwargs, - ) -> Any: + async def _custom_build(self, *args, **kwargs): + force = kwargs.get("force", False) + user_id = kwargs.get("user_id", None) + tools = kwargs.get("tools", []) if not self._built or force: if "input_variables" not in self.params or self.params["input_variables"] is None: self.params["input_variables"] = [] # Check if it is a ZeroShotPrompt and needs a tool if "ShotPrompt" in self.vertex_type: - tools = [await tool_node.build(user_id=user_id) for tool_node in tools] if tools is not None else [] + tools = [tool_node.build(user_id=user_id) for tool_node in tools] if tools is not None else [] # flatten the list of tools if it is a list of lists # first check if it is a list if tools and isinstance(tools, list) and isinstance(tools[0], list): @@ -253,11 +263,12 @@ class PromptVertex(Vertex): self.params.pop("input_variables", None) await self._build(user_id=user_id) - return self._built_object def _built_object_repr(self): if not self.artifacts or self._built_object is None or not hasattr(self._built_object, "format"): return super()._built_object_repr() + elif isinstance(self._built_object, UnbuiltObject): + return super()._built_object_repr() # We'll build the prompt with the artifacts # to show the user what the prompt looks like # with the variables filled in @@ -265,15 +276,10 @@ class PromptVertex(Vertex): # Remove the handle_keys from the artifacts # so the prompt format doesn't break artifacts.pop("handle_keys", None) - template = "" try: - if ( - not hasattr(self._built_object, "template") - and hasattr(self._built_object, "prompt") - and not isinstance(self._built_object, UnbuiltObject) - ): + if not hasattr(self._built_object, "template") and hasattr(self._built_object, "prompt"): template = self._built_object.prompt.template - elif not isinstance(self._built_object, UnbuiltObject) and hasattr(self._built_object, "template"): + else: template = self._built_object.template for key, value in artifacts.items(): if value: @@ -284,14 +290,24 @@ class PromptVertex(Vertex): return str(self._built_object) -class OutputParserVertex(Vertex): +class OutputParserVertex(StatelessVertex): def __init__(self, data: Dict, graph): super().__init__(data, graph=graph, base_type="output_parsers") -class CustomComponentVertex(Vertex): +class CustomComponentVertex(StatelessVertex): def __init__(self, data: Dict, graph): - super().__init__(data, graph=graph, base_type="custom_components", is_task=False) + super().__init__(data, graph=graph, base_type="custom_components") + + def _built_object_repr(self): + if self.artifacts and "repr" in self.artifacts: + return self.artifacts["repr"] or super()._built_object_repr() + + +class ChatVertex(StatelessVertex): + def __init__(self, data: Dict, graph): + super().__init__(data, graph=graph, base_type="custom_components", is_task=True) + self.steps = [self._build, self._run] def _built_object_repr(self): if self.task_id and self.is_task: @@ -301,3 +317,18 @@ class CustomComponentVertex(Vertex): return f"Task {self.task_id} is not running" if self.artifacts and "repr" in self.artifacts: return self.artifacts["repr"] or super()._built_object_repr() + + def _run(self, *args, **kwargs): + if self.is_power_component: + if self.vertex_type == "ChatOutput": + sender = self.params.get("sender", None) + sender_name = self.params.get("sender_name", None) + self.artifacts = ChatOutputResponse( + message=str(self._built_object), + sender=sender, + sender_name=sender_name, + ).dict() + self._built_result = self._built_object + + else: + super()._run(*args, **kwargs) diff --git a/src/backend/langflow/utils/schemas.py b/src/backend/langflow/utils/schemas.py new file mode 100644 index 000000000..8f0372d8f --- /dev/null +++ b/src/backend/langflow/utils/schemas.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel +from typing import Optional + + +class ChatOutputResponse(BaseModel): + """Chat output response schema.""" + + message: str + sender: Optional[str] = "Machine" + sender_name: Optional[str] = "AI"