From 3d4ab248588712240ce6653afbf0607623bc3f6a Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Fri, 1 Mar 2024 23:57:01 -0300 Subject: [PATCH] Add activated_vertices to VertexBuildResponse and update state management in Graph --- src/backend/langflow/api/v1/schemas.py | 1 + src/backend/langflow/graph/graph/base.py | 50 ++++++++++++ src/backend/langflow/graph/vertex/base.py | 26 +++--- src/backend/langflow/graph/vertex/types.py | 92 +++++++++++++++------- 4 files changed, 125 insertions(+), 44 deletions(-) diff --git a/src/backend/langflow/api/v1/schemas.py b/src/backend/langflow/api/v1/schemas.py index 044d98507..88610d637 100644 --- a/src/backend/langflow/api/v1/schemas.py +++ b/src/backend/langflow/api/v1/schemas.py @@ -230,6 +230,7 @@ class ResultDataResponse(BaseModel): class VertexBuildResponse(BaseModel): id: Optional[str] = None inactivated_vertices: Optional[List[str]] = None + activated_vertices: Optional[List[str]] = None valid: bool params: Optional[str] """JSON string of the params.""" diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 0b5af21f6..558fac152 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -16,9 +16,11 @@ from langflow.graph.vertex.types import ( FileToolVertex, LLMVertex, RoutingVertex, + StateVertex, ToolkitVertex, ) from langflow.interface.tools.constants import FILE_TOOLS +from langflow.schema import Record from langflow.utils import payload if TYPE_CHECKING: @@ -55,6 +57,7 @@ class Graph: self._vertices = self._graph_data["nodes"] self._edges = self._graph_data["edges"] self.inactivated_vertices: set = set() + self.activated_vertices: set = set() self.edges: List[ContractEdge] = [] self.vertices: List[Vertex] = [] self._build_graph() @@ -62,6 +65,37 @@ class Graph: self.define_vertices_lists() self.state_manager = GraphStateManager() + def update_state( + self, name: str, record: Union[str, Record], caller: Optional[str] = None + ) -> None: + """Updates the state of the graph.""" + if caller: + # If there is a caller which is a vertex_id, I want to activate + # all StateVertex in self.vertices that are not the caller + # essentially notifying all the other vertices that the state has changed + # This also has to activate their successors + caller_vertex = self.get_vertex(caller) + for vertex in self.vertices: + if vertex.id != caller and isinstance(vertex, StateVertex): + successors = self.get_all_successors(vertex) + self.activated_vertices.add(vertex.id) + for successor in successors: + self.activated_vertices.add(successor.id) + + self.state_manager.update_state(name, record) + + def reset_activated_vertices(self): + self.activated_vertices = set() + + def append_state( + self, name: str, record: Union[str, Record], caller: Optional[str] = None + ) -> None: + """Appends the state of the graph.""" + if caller: + self.state_manager.subscribe(name, caller) + + self.state_manager.append_state(name, record) + def set_run_id(self, run_id: str): for vertex in self.vertices: self.state_manager.subscribe(run_id, vertex.update_graph_state) @@ -500,6 +534,20 @@ class Graph: for source_id in self.predecessor_map.get(vertex.id, []) ] + def get_all_successors(self, vertex, recursive=True): + # Recursively get the successors of the current vertex + successors = vertex.successors + if not successors: + return [] + successors_result = [] + for successor in successors: + # Just return a list of successors + if recursive: + next_successors = self.get_all_successors(successor) + successors_result.extend(next_successors) + successors_result.append(successor) + return successors_result + def get_successors(self, vertex): """Returns the successors of a vertex.""" return [ @@ -561,6 +609,8 @@ class Graph: return ChatVertex elif node_name in ["ShouldRunNext"]: return RoutingVertex + elif node_name in ["SharedState"]: + return StateVertex elif node_base_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP: return lazy_load_vertex_dict.VERTEX_TYPE_MAP[node_base_type] elif node_name in lazy_load_vertex_dict.VERTEX_TYPE_MAP: diff --git a/src/backend/langflow/graph/vertex/base.py b/src/backend/langflow/graph/vertex/base.py index a550b5b0f..406652388 100644 --- a/src/backend/langflow/graph/vertex/base.py +++ b/src/backend/langflow/graph/vertex/base.py @@ -88,12 +88,9 @@ class Vertex: def update_graph_state(self, key, new_state, append: bool): if append: - if key in self.graph_state: - self.graph_state[key].append(new_state) - else: - self.graph_state[key] = [new_state] + self.graph.append_state(key, new_state, caller=self.id) else: - self.graph_state[key] = new_state + self.graph.update_state(key, new_state, caller=self.id) def set_state(self, state: str): self.state = VertexStates[state] @@ -511,7 +508,16 @@ class Vertex: self.params[key] = [] self.params[key].extend(built) else: - self.params[key].append(built) + try: + if self.params[key] == built: + continue + + self.params[key].append(built) + except AttributeError as e: + logger.exception(e) + raise ValueError( + f"Params {key} ({self.params[key]}) is not a list and cannot be extended with {built}" + ) from e def _handle_func(self, key, result): """ @@ -670,11 +676,3 @@ class Vertex: if self._built_object is not None else "Failed to build 😵‍💫" ) - - -class StatefulVertex(Vertex): - pass - - -class StatelessVertex(Vertex): - pass diff --git a/src/backend/langflow/graph/vertex/types.py b/src/backend/langflow/graph/vertex/types.py index c4f33df40..ba1cd2998 100644 --- a/src/backend/langflow/graph/vertex/types.py +++ b/src/backend/langflow/graph/vertex/types.py @@ -1,6 +1,7 @@ import ast import json -from typing import AsyncIterator, Callable, Dict, Iterator, List, Optional, Union +from typing import (AsyncIterator, Callable, Dict, Iterator, List, Optional, + Union) import yaml from langchain_core.messages import AIMessage @@ -8,14 +9,14 @@ from loguru import logger from langflow.graph.schema import INPUT_FIELD_NAME from langflow.graph.utils import UnbuiltObject, flatten_list, serialize_field -from langflow.graph.vertex.base import StatefulVertex, StatelessVertex +from langflow.graph.vertex.base import Vertex from langflow.interface.utils import extract_input_variables_from_prompt from langflow.schema import Record from langflow.services.monitor.utils import log_vertex_build from langflow.utils.schemas import ChatOutputResponse -class AgentVertex(StatelessVertex): +class AgentVertex(Vertex): def __init__(self, data: Dict, graph, params: Optional[Dict] = None): super().__init__(data, graph=graph, base_type="agents", params=params) @@ -58,12 +59,12 @@ class AgentVertex(StatelessVertex): await self._build(user_id=user_id) -class ToolVertex(StatelessVertex): +class ToolVertex(Vertex): def __init__(self, data: Dict, graph, params: Optional[Dict] = None): super().__init__(data, graph=graph, base_type="tools", params=params) -class LLMVertex(StatelessVertex): +class LLMVertex(Vertex): built_node_type = None class_built_object = None @@ -86,7 +87,7 @@ class LLMVertex(StatelessVertex): self.class_built_object = self._built_object -class ToolkitVertex(StatelessVertex): +class ToolkitVertex(Vertex): def __init__(self, data: Dict, graph, params=None): super().__init__(data, graph=graph, base_type="toolkits", params=params) @@ -100,7 +101,7 @@ class FileToolVertex(ToolVertex): ) -class WrapperVertex(StatelessVertex): +class WrapperVertex(Vertex): def __init__(self, data: Dict, graph, params=None): super().__init__(data, graph=graph, base_type="wrappers") self.steps: List[Callable] = [self._custom_build] @@ -114,7 +115,7 @@ class WrapperVertex(StatelessVertex): await self._build(user_id=user_id) -class DocumentLoaderVertex(StatefulVertex): +class DocumentLoaderVertex(Vertex): def __init__(self, data: Dict, graph, params: Optional[Dict] = None): super().__init__(data, graph=graph, base_type="documentloaders", params=params) @@ -123,21 +124,23 @@ class DocumentLoaderVertex(StatefulVertex): # show how many documents are in the list? if not isinstance(self._built_object, UnbuiltObject): - avg_length = sum(len(doc.page_content) for doc in self._built_object if hasattr(doc, "page_content")) / len( - self._built_object - ) + avg_length = sum( + len(doc.page_content) + for doc in self._built_object + if hasattr(doc, "page_content") + ) / len(self._built_object) return f"""{self.display_name}({len(self._built_object)} documents) \nAvg. Document Length (characters): {int(avg_length)} Documents: {self._built_object[:3]}...""" return f"{self.vertex_type}()" -class EmbeddingVertex(StatefulVertex): +class EmbeddingVertex(Vertex): def __init__(self, data: Dict, graph, params: Optional[Dict] = None): super().__init__(data, graph=graph, base_type="embeddings", params=params) -class VectorStoreVertex(StatefulVertex): +class VectorStoreVertex(Vertex): def __init__(self, data: Dict, graph, params=None): super().__init__(data, graph=graph, base_type="vectorstores") @@ -179,17 +182,17 @@ class VectorStoreVertex(StatefulVertex): self.remove_docs_and_texts_from_params() -class MemoryVertex(StatefulVertex): +class MemoryVertex(Vertex): def __init__(self, data: Dict, graph): super().__init__(data, graph=graph, base_type="memory") -class RetrieverVertex(StatefulVertex): +class RetrieverVertex(Vertex): def __init__(self, data: Dict, graph): super().__init__(data, graph=graph, base_type="retrievers") -class TextSplitterVertex(StatefulVertex): +class TextSplitterVertex(Vertex): def __init__(self, data: Dict, graph, params: Optional[Dict] = None): super().__init__(data, graph=graph, base_type="textsplitters", params=params) @@ -198,14 +201,16 @@ class TextSplitterVertex(StatefulVertex): # show how many documents are in the list? if not isinstance(self._built_object, UnbuiltObject): - avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(self._built_object) + avg_length = sum(len(doc.page_content) for doc in self._built_object) / len( + self._built_object + ) return f"""{self.vertex_type}({len(self._built_object)} documents) \nAvg. Document Length (characters): {int(avg_length)} \nDocuments: {self._built_object[:3]}...""" return f"{self.vertex_type}()" -class ChainVertex(StatelessVertex): +class ChainVertex(Vertex): def __init__(self, data: Dict, graph): super().__init__(data, graph=graph, base_type="chains") self.steps = [self._custom_build] @@ -235,7 +240,7 @@ class ChainVertex(StatelessVertex): return super()._built_object_repr() -class PromptVertex(StatelessVertex): +class PromptVertex(Vertex): def __init__(self, data: Dict, graph): super().__init__(data, graph=graph, base_type="prompts") self.steps: List[Callable] = [self._custom_build] @@ -245,18 +250,27 @@ class PromptVertex(StatelessVertex): user_id = kwargs.get("user_id", None) tools = kwargs.get("tools", []) if not self._built or force: - if "input_variables" not in self.params or self.params["input_variables"] is None: + if ( + "input_variables" not in self.params + or self.params["input_variables"] is None + ): self.params["input_variables"] = [] # Check if it is a ZeroShotPrompt and needs a tool if "ShotPrompt" in self.vertex_type: - tools = [tool_node.build(user_id=user_id) for tool_node in tools] if tools is not None else [] + tools = ( + [tool_node.build(user_id=user_id) for tool_node in tools] + if tools is not None + else [] + ) # flatten the list of tools if it is a list of lists # first check if it is a list if tools and isinstance(tools, list) and isinstance(tools[0], list): tools = flatten_list(tools) self.params["tools"] = tools prompt_params = [ - key for key, value in self.params.items() if isinstance(value, str) and key != "format_instructions" + key + for key, value in self.params.items() + if isinstance(value, str) and key != "format_instructions" ] else: prompt_params = ["template"] @@ -266,14 +280,20 @@ class PromptVertex(StatelessVertex): prompt_text = self.params[param] variables = extract_input_variables_from_prompt(prompt_text) self.params["input_variables"].extend(variables) - self.params["input_variables"] = list(set(self.params["input_variables"])) + self.params["input_variables"] = list( + set(self.params["input_variables"]) + ) elif isinstance(self.params, dict): self.params.pop("input_variables", None) await self._build(user_id=user_id) def _built_object_repr(self): - if not self.artifacts or self._built_object is None or not hasattr(self._built_object, "format"): + if ( + not self.artifacts + or self._built_object is None + or not hasattr(self._built_object, "format") + ): return super()._built_object_repr() elif isinstance(self._built_object, UnbuiltObject): return super()._built_object_repr() @@ -285,7 +305,9 @@ class PromptVertex(StatelessVertex): # so the prompt format doesn't break artifacts.pop("handle_keys", None) try: - if not hasattr(self._built_object, "template") and hasattr(self._built_object, "prompt"): + if not hasattr(self._built_object, "template") and hasattr( + self._built_object, "prompt" + ): template = self._built_object.prompt.template else: template = self._built_object.template @@ -293,17 +315,21 @@ class PromptVertex(StatelessVertex): if value: replace_key = "{" + key + "}" template = template.replace(replace_key, value) - return template if isinstance(template, str) else f"{self.vertex_type}({template})" + return ( + template + if isinstance(template, str) + else f"{self.vertex_type}({template})" + ) except KeyError: return str(self._built_object) -class OutputParserVertex(StatelessVertex): +class OutputParserVertex(Vertex): def __init__(self, data: Dict, graph): super().__init__(data, graph=graph, base_type="output_parsers") -class CustomComponentVertex(StatelessVertex): +class CustomComponentVertex(Vertex): def __init__(self, data: Dict, graph): super().__init__(data, graph=graph, base_type="custom_components") @@ -312,7 +338,7 @@ class CustomComponentVertex(StatelessVertex): return self.artifacts["repr"] or super()._built_object_repr() -class ChatVertex(StatelessVertex): +class ChatVertex(Vertex): def __init__(self, data: Dict, graph): super().__init__(data, graph=graph, base_type="custom_components", is_task=True) self.steps = [self._build, self._run] @@ -431,7 +457,7 @@ class ChatVertex(StatelessVertex): pass -class RoutingVertex(StatelessVertex): +class RoutingVertex(Vertex): def __init__(self, data: Dict, graph): super().__init__(data, graph=graph, base_type="custom_components") self.use_result = True @@ -457,6 +483,12 @@ class RoutingVertex(StatelessVertex): self._built_result = None +class StateVertex(Vertex): + def __init__(self, data: Dict, graph): + super().__init__(data, graph=graph, base_type="custom_components") + self.steps = [self._build] + + def dict_to_codeblock(d: dict) -> str: serialized = {key: serialize_field(val) for key, val in d.items()} json_str = json.dumps(serialized, indent=4)