From 95813de61862d23badd642152497946bef8f20f6 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 30 Apr 2024 16:33:53 -0300 Subject: [PATCH] Fixes State management and first implementation of vertex inactivation (#1805) * Refactor RoutingVertex class in types.py to include a new step in the steps list * Refactor Graph class to update run_manager.run_predecessors when activating vertices * Refactor ShouldRunNextComponent class in ShouldRunNext.py to improve error handling and readability * Refactor error handling in get_lifespan function in main.py and fix mismatch between models and database in DatabaseService class in service.py * Fix vertex inactivation * Fix inactivation of vertices and update buildUtils * Refactor ShouldRunNextComponent class to improve error handling and readability * Refactor langflow.graph.vertex.types imports in base.py * Fix nullable constraint in langflow migration script * Fix condition check in ShouldRunNextComponent class * Refactor build_graph_maps function in base.py to accept optional parameters for edges and vertices * Apply grayscale effect to ring-muted-foreground in applies.css * Apply grayscale effect to ring-muted-foreground in applies.css --- .../versions/6e7b581b5648_fix_nullable.py | 4 +- .../components/helpers/ShouldRunNext.py | 30 ++++++++++++ src/backend/base/langflow/graph/graph/base.py | 49 +++++++++++++------ .../base/langflow/graph/graph/constants.py | 2 - .../graph/graph/runnable_vertices_manager.py | 1 + .../base/langflow/graph/vertex/base.py | 1 - .../base/langflow/graph/vertex/types.py | 36 +------------- .../custom_component/custom_component.py | 8 +++ src/backend/base/langflow/main.py | 1 + .../langflow/services/database/service.py | 6 +-- src/frontend/src/stores/flowStore.ts | 12 ++++- src/frontend/src/style/applies.css | 2 +- src/frontend/src/utils/buildUtils.ts | 22 +++++++-- 13 files changed, 106 insertions(+), 68 deletions(-) create mode 100644 src/backend/base/langflow/components/helpers/ShouldRunNext.py diff --git a/src/backend/base/langflow/alembic/versions/6e7b581b5648_fix_nullable.py b/src/backend/base/langflow/alembic/versions/6e7b581b5648_fix_nullable.py index 4ecd15714..71e81128c 100644 --- a/src/backend/base/langflow/alembic/versions/6e7b581b5648_fix_nullable.py +++ b/src/backend/base/langflow/alembic/versions/6e7b581b5648_fix_nullable.py @@ -33,7 +33,7 @@ def upgrade() -> None: "created_at", existing_type=sa.DATETIME(), nullable=False, - existing_server_default=sa.text("(CURRENT_TIMESTAMP)"), + existing_server_default=sa.text("(CURRENT_TIMESTAMP)"), # type: ignore ) # ### end Alembic commands ### @@ -53,7 +53,7 @@ def downgrade() -> None: "created_at", existing_type=sa.DATETIME(), nullable=True, - existing_server_default=sa.text("(CURRENT_TIMESTAMP)"), + existing_server_default=sa.text("(CURRENT_TIMESTAMP)"), # type: ignore ) # ### end Alembic commands ### diff --git a/src/backend/base/langflow/components/helpers/ShouldRunNext.py b/src/backend/base/langflow/components/helpers/ShouldRunNext.py new file mode 100644 index 000000000..0d20706ea --- /dev/null +++ b/src/backend/base/langflow/components/helpers/ShouldRunNext.py @@ -0,0 +1,30 @@ +from langchain_core.messages import BaseMessage +from langchain_core.prompts import PromptTemplate + +from langflow.custom import CustomComponent +from langflow.field_typing import BaseLanguageModel, Text + + +class ShouldRunNextComponent(CustomComponent): + display_name = "Should Run Next" + description = "Determines if a vertex is runnable." + + def build(self, llm: BaseLanguageModel, question: str, context: str, retries: int = 3) -> Text: + template = "Given the following question and the context below, answer with a yes or no.\n\n{error_message}\n\nQuestion: {question}\n\nContext: {context}\n\nAnswer:" + + prompt = PromptTemplate.from_template(template) + chain = prompt | llm + error_message = "" + for i in range(retries): + result = chain.invoke(dict(question=question, context=context, error_message=error_message)) + if isinstance(result, BaseMessage): + content = result.content + elif isinstance(result, str): + content = result + if isinstance(content, str) and content.lower().strip() in ["yes", "no"]: + break + condition = str(content).lower().strip() == "yes" + self.status = f"Should Run Next: {condition}" + if condition is False: + self.stop() + return context diff --git a/src/backend/base/langflow/graph/graph/base.py b/src/backend/base/langflow/graph/graph/base.py index a0661f4f2..bddfd6795 100644 --- a/src/backend/base/langflow/graph/graph/base.py +++ b/src/backend/base/langflow/graph/graph/base.py @@ -3,7 +3,7 @@ import uuid from collections import defaultdict, deque from functools import partial from itertools import chain -from typing import TYPE_CHECKING, Callable, Coroutine, Dict, Generator, List, Optional, Type, Union +from typing import TYPE_CHECKING, Callable, Coroutine, Dict, Generator, List, Optional, Tuple, Type, Union from loguru import logger @@ -14,7 +14,7 @@ from langflow.graph.graph.state_manager import GraphStateManager from langflow.graph.graph.utils import process_flow from langflow.graph.schema import InterfaceComponentTypes, RunOutputs from langflow.graph.vertex.base import Vertex -from langflow.graph.vertex.types import ChatVertex, FileToolVertex, LLMVertex, RoutingVertex, StateVertex, ToolkitVertex +from langflow.graph.vertex.types import ChatVertex, FileToolVertex, LLMVertex, StateVertex, ToolkitVertex from langflow.interface.tools.constants import FILE_TOOLS from langflow.schema import Record from langflow.schema.schema import INPUT_FIELD_NAME, InputType @@ -75,7 +75,7 @@ class Graph: self.vertices: List[Vertex] = [] self.run_manager = RunnableVerticesManager() self._build_graph() - self.build_graph_maps() + self.build_graph_maps(self.edges) self.define_vertices_lists() self.state_manager = GraphStateManager() @@ -130,6 +130,18 @@ class Graph: ): vertices_ids.append(vertex_id) successors = self.get_all_successors(vertex, flat=True) + # Update run_manager.run_predecessors because we are activating vertices + # The run_prdecessors is the predecessor map of the vertices + # we remove the vertex_id from the predecessor map whenever we run a vertex + # So we need to get all edges of the vertex and successors + # and run self.build_adjacency_maps(edges) to get the new predecessor map + # that is not complete but we can use to update the run_predecessors + edges_set = set() + for vertex in [vertex] + successors: + edges_set.update(vertex.edges) + edges = list(edges_set) + new_predecessor_map, _ = self.build_adjacency_maps(edges) + self.run_manager.run_predecessors.update(new_predecessor_map) self.vertices_to_run.update(list(map(lambda x: x.id, successors))) self.activated_vertices = vertices_ids self.vertices_to_run.update(vertices_ids) @@ -401,14 +413,20 @@ class Graph: "inactivated_vertices": self.inactivated_vertices, } - def build_graph_maps(self): + def build_graph_maps(self, edges: Optional[List[ContractEdge]] = None, vertices: Optional[List[Vertex]] = None): """ Builds the adjacency maps for the graph. """ - self.predecessor_map, self.successor_map = self.build_adjacency_maps() + if edges is None: + edges = self.edges - self.in_degree_map = self.build_in_degree() - self.parent_child_map = self.build_parent_child_map() + if vertices is None: + vertices = self.vertices + + self.predecessor_map, self.successor_map = self.build_adjacency_maps(edges) + + self.in_degree_map = self.build_in_degree(edges) + self.parent_child_map = self.build_parent_child_map(vertices) def reset_inactivated_vertices(self): """ @@ -433,9 +451,9 @@ class Graph: for child_id in self.parent_child_map[vertex_id]: self.mark_branch(child_id, state) - def build_parent_child_map(self): + def build_parent_child_map(self, vertices: List[Vertex]): parent_child_map = defaultdict(list) - for vertex in self.vertices: + for vertex in vertices: parent_child_map[vertex.id] = [child.id for child in self.get_successors(vertex)] return parent_child_map @@ -559,6 +577,7 @@ class Graph: self.update_vertex_from_another(self_vertex, other_vertex) self.build_graph_maps() + self.define_vertices_lists() self.increment_update_count() return self @@ -944,8 +963,6 @@ class Graph: node_name = node_id.split("-")[0] if node_name in ["ChatOutput", "ChatInput"]: return ChatVertex - elif node_name in ["ShouldRunNext"]: - return RoutingVertex elif node_name in ["SharedState", "Notify", "Listen"]: return StateVertex elif node_base_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP: @@ -1277,17 +1294,17 @@ class Graph: def remove_from_predecessors(self, vertex_id: str): self.run_manager.remove_from_predecessors(vertex_id) - def build_in_degree(self): - in_degree = defaultdict(int) - for edge in self.edges: + def build_in_degree(self, edges: List[ContractEdge]) -> Dict[str, int]: + in_degree: Dict[str, int] = defaultdict(int) + for edge in edges: in_degree[edge.target_id] += 1 return in_degree - def build_adjacency_maps(self): + def build_adjacency_maps(self, edges: List[ContractEdge]) -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]: """Returns the adjacency maps for the graph.""" predecessor_map = defaultdict(list) successor_map = defaultdict(list) - for edge in self.edges: + for edge in edges: predecessor_map[edge.target_id].append(edge.source_id) successor_map[edge.source_id].append(edge.target_id) return predecessor_map, successor_map diff --git a/src/backend/base/langflow/graph/graph/constants.py b/src/backend/base/langflow/graph/graph/constants.py index 658c3c68a..f948a0753 100644 --- a/src/backend/base/langflow/graph/graph/constants.py +++ b/src/backend/base/langflow/graph/graph/constants.py @@ -15,7 +15,6 @@ from langflow.interface.wrappers.base import wrapper_creator from langflow.utils.lazy_load import LazyLoadDictBase CHAT_COMPONENTS = ["ChatInput", "ChatOutput", "TextInput", "SessionID"] -ROUTING_COMPONENTS = ["ShouldRunNext"] class VertexTypesDict(LazyLoadDictBase): @@ -51,7 +50,6 @@ class VertexTypesDict(LazyLoadDictBase): **{t: types.CustomComponentVertex for t in custom_component_creator.to_list()}, **{t: types.RetrieverVertex for t in retriever_creator.to_list()}, **{t: types.ChatVertex for t in CHAT_COMPONENTS}, - **{t: types.RoutingVertex for t in ROUTING_COMPONENTS}, } def get_custom_component_vertex_type(self): diff --git a/src/backend/base/langflow/graph/graph/runnable_vertices_manager.py b/src/backend/base/langflow/graph/graph/runnable_vertices_manager.py index 875f43472..713aead65 100644 --- a/src/backend/base/langflow/graph/graph/runnable_vertices_manager.py +++ b/src/backend/base/langflow/graph/graph/runnable_vertices_manager.py @@ -15,6 +15,7 @@ class RunnableVerticesManager: def is_vertex_runnable(self, vertex_id: str) -> bool: """Determines if a vertex is runnable.""" + return vertex_id in self.vertices_to_run and not self.run_predecessors.get(vertex_id) def find_runnable_predecessors_for_successors(self, vertex_id: str) -> List[str]: diff --git a/src/backend/base/langflow/graph/vertex/base.py b/src/backend/base/langflow/graph/vertex/base.py index f300b2dda..e250e9419 100644 --- a/src/backend/base/langflow/graph/vertex/base.py +++ b/src/backend/base/langflow/graph/vertex/base.py @@ -72,7 +72,6 @@ class Vertex: self.load_from_db_fields: List[str] = [] self.parent_is_top_level = False self.layer = None - self.should_run = True self.result: Optional[ResultData] = None try: self.is_interface_component = self.vertex_type in InterfaceComponentTypes diff --git a/src/backend/base/langflow/graph/vertex/types.py b/src/backend/base/langflow/graph/vertex/types.py index d7a98df8a..87f4856b2 100644 --- a/src/backend/base/langflow/graph/vertex/types.py +++ b/src/backend/base/langflow/graph/vertex/types.py @@ -1,6 +1,7 @@ import ast import json from typing import AsyncIterator, Callable, Dict, Iterator, List, Optional, Union + import yaml from langchain_core.messages import AIMessage from loguru import logger @@ -438,41 +439,6 @@ class ChatVertex(Vertex): return self.vertex_type == InterfaceComponentTypes.ChatInput and self.is_input -class RoutingVertex(Vertex): - def __init__(self, data: Dict, graph): - super().__init__(data, graph=graph, base_type="custom_components") - self.use_result = True - self.steps = [self._build] - - def _built_object_repr(self): - if self.artifacts and "repr" in self.artifacts: - return self.artifacts["repr"] or super()._built_object_repr() - return super()._built_object_repr() - - @property - def successors_ids(self): - if isinstance(self._built_object, bool): - ids = super().successors_ids - if self._built_object: - return ids - return [] - raise ValueError("RoutingVertex should return a boolean value.") - - def _run(self, *args, **kwargs): - if self._built_object: - condition = self._built_object.get("condition") - result = self._built_object.get("result") - if condition is None: - raise ValueError("Condition is required for the routing vertex.") - if result is None: - raise ValueError("Result is required for the routing vertex.") - if condition is True: - self._built_result = result - else: - self.graph.mark_branch(self.id, "INACTIVE") - self._built_result = None - - class StateVertex(Vertex): def __init__(self, data: Dict, graph): super().__init__(data, graph=graph, base_type="custom_components") diff --git a/src/backend/base/langflow/interface/custom/custom_component/custom_component.py b/src/backend/base/langflow/interface/custom/custom_component/custom_component.py index e6f0c4652..1638ebb2c 100644 --- a/src/backend/base/langflow/interface/custom/custom_component/custom_component.py +++ b/src/backend/base/langflow/interface/custom/custom_component/custom_component.py @@ -87,6 +87,14 @@ class CustomComponent(Component): except Exception as e: raise ValueError(f"Error updating state: {e}") + def stop(self): + if not self.vertex: + raise ValueError("Vertex is not set") + try: + self.graph.mark_branch(self.vertex.id, "INACTIVE") + except Exception as e: + raise ValueError(f"Error stopping {self.display_name}: {e}") + def append_state(self, name: str, value: Any): if not self.vertex: raise ValueError("Vertex is not set") diff --git a/src/backend/base/langflow/main.py b/src/backend/base/langflow/main.py index 81797dc80..697cfa226 100644 --- a/src/backend/base/langflow/main.py +++ b/src/backend/base/langflow/main.py @@ -53,6 +53,7 @@ def get_lifespan(fix_migration=False, socketio_server=None): except Exception as exc: if "langflow migration --fix" not in str(exc): logger.error(exc) + raise # Shutdown message rprint("[bold red]Shutting down Langflow...[/bold red]") teardown_services() diff --git a/src/backend/base/langflow/services/database/service.py b/src/backend/base/langflow/services/database/service.py index 5ca20c642..14c79f85b 100644 --- a/src/backend/base/langflow/services/database/service.py +++ b/src/backend/base/langflow/services/database/service.py @@ -133,7 +133,7 @@ class DatabaseService(Service): alembic_cfg = Config(stdout=buffer) # alembic_cfg.attributes["connection"] = session alembic_cfg.set_main_option("script_location", str(self.script_location)) - alembic_cfg.set_main_option("sqlalchemy.url", self.database_url.replace('%', '%%')) + alembic_cfg.set_main_option("sqlalchemy.url", self.database_url.replace("%", "%%")) should_initialize_alembic = False with Session(self.engine) as session: @@ -170,9 +170,7 @@ class DatabaseService(Service): except util.exc.AutogenerateDiffsDetected as exc: logger.error(f"AutogenerateDiffsDetected: {exc}") if not fix: - raise RuntimeError( - "Something went wrong running migrations. Please, run `langflow migration --fix`" - ) from exc + raise RuntimeError(f"There's a mismatch between the models and the database.\n{exc}") if fix: self.try_downgrade_upgrade_until_success(alembic_cfg) diff --git a/src/frontend/src/stores/flowStore.ts b/src/frontend/src/stores/flowStore.ts index bcc8f6309..f6077c5e9 100644 --- a/src/frontend/src/stores/flowStore.ts +++ b/src/frontend/src/stores/flowStore.ts @@ -478,8 +478,13 @@ const useFlowStore = create((set, get) => ({ // const nextVertices will be the zip of vertexBuildData.next_vertices_ids and // vertexBuildData.top_level_vertices // the VertexLayerElementType as {id: next_vertices_id, layer: top_level_vertex} + + // next_vertices_ids should be next_vertices_ids without the inactivated vertices + const next_vertices_ids = vertexBuildData.next_vertices_ids.filter( + (id) => !vertexBuildData.inactivated_vertices?.includes(id) + ); const nextVertices: VertexLayerElementType[] = zip( - vertexBuildData.next_vertices_ids, + next_vertices_ids, vertexBuildData.top_level_vertices ).map(([id, reference]) => ({ id: id!, reference })); @@ -489,7 +494,7 @@ const useFlowStore = create((set, get) => ({ ]; const newIds = [ ...get().verticesBuild!.verticesIds, - ...vertexBuildData.next_vertices_ids, + ...next_vertices_ids, ]; get().updateVerticesBuild({ verticesIds: newIds, @@ -598,7 +603,10 @@ const useFlowStore = create((set, get) => ({ set({ verticesBuild: { ...verticesBuild, + // remove the vertices from the list of vertices ids + // that are going to be built verticesIds: get().verticesBuild!.verticesIds.filter( + // keep the vertices that are not in the list of vertices to remove (vertex) => !vertices.includes(vertex) ), }, diff --git a/src/frontend/src/style/applies.css b/src/frontend/src/style/applies.css index 268f838eb..4b016eb73 100644 --- a/src/frontend/src/style/applies.css +++ b/src/frontend/src/style/applies.css @@ -323,7 +323,7 @@ muted-foreground is too strong, maybe use a lighter shade of it? */ - @apply border-none ring ring-muted-foreground; + @apply border-none ring grayscale; } .built-invalid-status { @apply border-none ring ring-[#FF9090]; diff --git a/src/frontend/src/utils/buildUtils.ts b/src/frontend/src/utils/buildUtils.ts index bc2cc62e7..6b462fe04 100644 --- a/src/frontend/src/utils/buildUtils.ts +++ b/src/frontend/src/utils/buildUtils.ts @@ -166,14 +166,26 @@ export async function buildVertices({ !useFlowStore .getState() .verticesBuild?.verticesIds.includes(element.id) && + !useFlowStore + .getState() + .verticesBuild?.verticesIds.includes(element.reference ?? "") && onBuildUpdate ) { // If it is, skip building and set the state to inactive - onBuildUpdate( - getInactiveVertexData(element.id), - BuildStatus.INACTIVE, - runId - ); + if (element.id) { + onBuildUpdate( + getInactiveVertexData(element.id), + BuildStatus.INACTIVE, + runId + ); + } + if (element.reference) { + onBuildUpdate( + getInactiveVertexData(element.reference), + BuildStatus.INACTIVE, + runId + ); + } buildResults.push(false); return; }