Fixes State management and first implementation of vertex inactivation (#1805)

* Refactor RoutingVertex class in types.py to include a new step in the steps list

* Refactor Graph class to update run_manager.run_predecessors when activating vertices

* Refactor ShouldRunNextComponent class in ShouldRunNext.py to improve error handling and readability

* Refactor error handling in get_lifespan function in main.py and fix mismatch between models and database in DatabaseService class in service.py

* Fix vertex inactivation

* Fix inactivation of vertices and update buildUtils

* Refactor ShouldRunNextComponent class to improve error handling and readability

* Refactor langflow.graph.vertex.types imports in base.py

* Fix nullable constraint in langflow migration script

* Fix condition check in ShouldRunNextComponent class

* Refactor build_graph_maps function in base.py to accept optional parameters for edges and vertices

* Apply grayscale effect to ring-muted-foreground in applies.css

* Apply grayscale effect to ring-muted-foreground in applies.css
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-04-30 16:33:53 -03:00 committed by GitHub
commit 95813de618
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 106 additions and 68 deletions

View file

@ -33,7 +33,7 @@ def upgrade() -> None:
"created_at",
existing_type=sa.DATETIME(),
nullable=False,
existing_server_default=sa.text("(CURRENT_TIMESTAMP)"),
existing_server_default=sa.text("(CURRENT_TIMESTAMP)"), # type: ignore
)
# ### end Alembic commands ###
@ -53,7 +53,7 @@ def downgrade() -> None:
"created_at",
existing_type=sa.DATETIME(),
nullable=True,
existing_server_default=sa.text("(CURRENT_TIMESTAMP)"),
existing_server_default=sa.text("(CURRENT_TIMESTAMP)"), # type: ignore
)
# ### end Alembic commands ###

View file

@ -0,0 +1,30 @@
from langchain_core.messages import BaseMessage
from langchain_core.prompts import PromptTemplate
from langflow.custom import CustomComponent
from langflow.field_typing import BaseLanguageModel, Text
class ShouldRunNextComponent(CustomComponent):
display_name = "Should Run Next"
description = "Determines if a vertex is runnable."
def build(self, llm: BaseLanguageModel, question: str, context: str, retries: int = 3) -> Text:
template = "Given the following question and the context below, answer with a yes or no.\n\n{error_message}\n\nQuestion: {question}\n\nContext: {context}\n\nAnswer:"
prompt = PromptTemplate.from_template(template)
chain = prompt | llm
error_message = ""
for i in range(retries):
result = chain.invoke(dict(question=question, context=context, error_message=error_message))
if isinstance(result, BaseMessage):
content = result.content
elif isinstance(result, str):
content = result
if isinstance(content, str) and content.lower().strip() in ["yes", "no"]:
break
condition = str(content).lower().strip() == "yes"
self.status = f"Should Run Next: {condition}"
if condition is False:
self.stop()
return context

View file

@ -3,7 +3,7 @@ import uuid
from collections import defaultdict, deque
from functools import partial
from itertools import chain
from typing import TYPE_CHECKING, Callable, Coroutine, Dict, Generator, List, Optional, Type, Union
from typing import TYPE_CHECKING, Callable, Coroutine, Dict, Generator, List, Optional, Tuple, Type, Union
from loguru import logger
@ -14,7 +14,7 @@ from langflow.graph.graph.state_manager import GraphStateManager
from langflow.graph.graph.utils import process_flow
from langflow.graph.schema import InterfaceComponentTypes, RunOutputs
from langflow.graph.vertex.base import Vertex
from langflow.graph.vertex.types import ChatVertex, FileToolVertex, LLMVertex, RoutingVertex, StateVertex, ToolkitVertex
from langflow.graph.vertex.types import ChatVertex, FileToolVertex, LLMVertex, StateVertex, ToolkitVertex
from langflow.interface.tools.constants import FILE_TOOLS
from langflow.schema import Record
from langflow.schema.schema import INPUT_FIELD_NAME, InputType
@ -75,7 +75,7 @@ class Graph:
self.vertices: List[Vertex] = []
self.run_manager = RunnableVerticesManager()
self._build_graph()
self.build_graph_maps()
self.build_graph_maps(self.edges)
self.define_vertices_lists()
self.state_manager = GraphStateManager()
@ -130,6 +130,18 @@ class Graph:
):
vertices_ids.append(vertex_id)
successors = self.get_all_successors(vertex, flat=True)
# Update run_manager.run_predecessors because we are activating vertices
# The run_prdecessors is the predecessor map of the vertices
# we remove the vertex_id from the predecessor map whenever we run a vertex
# So we need to get all edges of the vertex and successors
# and run self.build_adjacency_maps(edges) to get the new predecessor map
# that is not complete but we can use to update the run_predecessors
edges_set = set()
for vertex in [vertex] + successors:
edges_set.update(vertex.edges)
edges = list(edges_set)
new_predecessor_map, _ = self.build_adjacency_maps(edges)
self.run_manager.run_predecessors.update(new_predecessor_map)
self.vertices_to_run.update(list(map(lambda x: x.id, successors)))
self.activated_vertices = vertices_ids
self.vertices_to_run.update(vertices_ids)
@ -401,14 +413,20 @@ class Graph:
"inactivated_vertices": self.inactivated_vertices,
}
def build_graph_maps(self):
def build_graph_maps(self, edges: Optional[List[ContractEdge]] = None, vertices: Optional[List[Vertex]] = None):
"""
Builds the adjacency maps for the graph.
"""
self.predecessor_map, self.successor_map = self.build_adjacency_maps()
if edges is None:
edges = self.edges
self.in_degree_map = self.build_in_degree()
self.parent_child_map = self.build_parent_child_map()
if vertices is None:
vertices = self.vertices
self.predecessor_map, self.successor_map = self.build_adjacency_maps(edges)
self.in_degree_map = self.build_in_degree(edges)
self.parent_child_map = self.build_parent_child_map(vertices)
def reset_inactivated_vertices(self):
"""
@ -433,9 +451,9 @@ class Graph:
for child_id in self.parent_child_map[vertex_id]:
self.mark_branch(child_id, state)
def build_parent_child_map(self):
def build_parent_child_map(self, vertices: List[Vertex]):
parent_child_map = defaultdict(list)
for vertex in self.vertices:
for vertex in vertices:
parent_child_map[vertex.id] = [child.id for child in self.get_successors(vertex)]
return parent_child_map
@ -559,6 +577,7 @@ class Graph:
self.update_vertex_from_another(self_vertex, other_vertex)
self.build_graph_maps()
self.define_vertices_lists()
self.increment_update_count()
return self
@ -944,8 +963,6 @@ class Graph:
node_name = node_id.split("-")[0]
if node_name in ["ChatOutput", "ChatInput"]:
return ChatVertex
elif node_name in ["ShouldRunNext"]:
return RoutingVertex
elif node_name in ["SharedState", "Notify", "Listen"]:
return StateVertex
elif node_base_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP:
@ -1277,17 +1294,17 @@ class Graph:
def remove_from_predecessors(self, vertex_id: str):
self.run_manager.remove_from_predecessors(vertex_id)
def build_in_degree(self):
in_degree = defaultdict(int)
for edge in self.edges:
def build_in_degree(self, edges: List[ContractEdge]) -> Dict[str, int]:
in_degree: Dict[str, int] = defaultdict(int)
for edge in edges:
in_degree[edge.target_id] += 1
return in_degree
def build_adjacency_maps(self):
def build_adjacency_maps(self, edges: List[ContractEdge]) -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]:
"""Returns the adjacency maps for the graph."""
predecessor_map = defaultdict(list)
successor_map = defaultdict(list)
for edge in self.edges:
for edge in edges:
predecessor_map[edge.target_id].append(edge.source_id)
successor_map[edge.source_id].append(edge.target_id)
return predecessor_map, successor_map

View file

@ -15,7 +15,6 @@ from langflow.interface.wrappers.base import wrapper_creator
from langflow.utils.lazy_load import LazyLoadDictBase
CHAT_COMPONENTS = ["ChatInput", "ChatOutput", "TextInput", "SessionID"]
ROUTING_COMPONENTS = ["ShouldRunNext"]
class VertexTypesDict(LazyLoadDictBase):
@ -51,7 +50,6 @@ class VertexTypesDict(LazyLoadDictBase):
**{t: types.CustomComponentVertex for t in custom_component_creator.to_list()},
**{t: types.RetrieverVertex for t in retriever_creator.to_list()},
**{t: types.ChatVertex for t in CHAT_COMPONENTS},
**{t: types.RoutingVertex for t in ROUTING_COMPONENTS},
}
def get_custom_component_vertex_type(self):

View file

@ -15,6 +15,7 @@ class RunnableVerticesManager:
def is_vertex_runnable(self, vertex_id: str) -> bool:
"""Determines if a vertex is runnable."""
return vertex_id in self.vertices_to_run and not self.run_predecessors.get(vertex_id)
def find_runnable_predecessors_for_successors(self, vertex_id: str) -> List[str]:

View file

@ -72,7 +72,6 @@ class Vertex:
self.load_from_db_fields: List[str] = []
self.parent_is_top_level = False
self.layer = None
self.should_run = True
self.result: Optional[ResultData] = None
try:
self.is_interface_component = self.vertex_type in InterfaceComponentTypes

View file

@ -1,6 +1,7 @@
import ast
import json
from typing import AsyncIterator, Callable, Dict, Iterator, List, Optional, Union
import yaml
from langchain_core.messages import AIMessage
from loguru import logger
@ -438,41 +439,6 @@ class ChatVertex(Vertex):
return self.vertex_type == InterfaceComponentTypes.ChatInput and self.is_input
class RoutingVertex(Vertex):
def __init__(self, data: Dict, graph):
super().__init__(data, graph=graph, base_type="custom_components")
self.use_result = True
self.steps = [self._build]
def _built_object_repr(self):
if self.artifacts and "repr" in self.artifacts:
return self.artifacts["repr"] or super()._built_object_repr()
return super()._built_object_repr()
@property
def successors_ids(self):
if isinstance(self._built_object, bool):
ids = super().successors_ids
if self._built_object:
return ids
return []
raise ValueError("RoutingVertex should return a boolean value.")
def _run(self, *args, **kwargs):
if self._built_object:
condition = self._built_object.get("condition")
result = self._built_object.get("result")
if condition is None:
raise ValueError("Condition is required for the routing vertex.")
if result is None:
raise ValueError("Result is required for the routing vertex.")
if condition is True:
self._built_result = result
else:
self.graph.mark_branch(self.id, "INACTIVE")
self._built_result = None
class StateVertex(Vertex):
def __init__(self, data: Dict, graph):
super().__init__(data, graph=graph, base_type="custom_components")

View file

@ -87,6 +87,14 @@ class CustomComponent(Component):
except Exception as e:
raise ValueError(f"Error updating state: {e}")
def stop(self):
if not self.vertex:
raise ValueError("Vertex is not set")
try:
self.graph.mark_branch(self.vertex.id, "INACTIVE")
except Exception as e:
raise ValueError(f"Error stopping {self.display_name}: {e}")
def append_state(self, name: str, value: Any):
if not self.vertex:
raise ValueError("Vertex is not set")

View file

@ -53,6 +53,7 @@ def get_lifespan(fix_migration=False, socketio_server=None):
except Exception as exc:
if "langflow migration --fix" not in str(exc):
logger.error(exc)
raise
# Shutdown message
rprint("[bold red]Shutting down Langflow...[/bold red]")
teardown_services()

View file

@ -133,7 +133,7 @@ class DatabaseService(Service):
alembic_cfg = Config(stdout=buffer)
# alembic_cfg.attributes["connection"] = session
alembic_cfg.set_main_option("script_location", str(self.script_location))
alembic_cfg.set_main_option("sqlalchemy.url", self.database_url.replace('%', '%%'))
alembic_cfg.set_main_option("sqlalchemy.url", self.database_url.replace("%", "%%"))
should_initialize_alembic = False
with Session(self.engine) as session:
@ -170,9 +170,7 @@ class DatabaseService(Service):
except util.exc.AutogenerateDiffsDetected as exc:
logger.error(f"AutogenerateDiffsDetected: {exc}")
if not fix:
raise RuntimeError(
"Something went wrong running migrations. Please, run `langflow migration --fix`"
) from exc
raise RuntimeError(f"There's a mismatch between the models and the database.\n{exc}")
if fix:
self.try_downgrade_upgrade_until_success(alembic_cfg)

View file

@ -478,8 +478,13 @@ const useFlowStore = create<FlowStoreType>((set, get) => ({
// const nextVertices will be the zip of vertexBuildData.next_vertices_ids and
// vertexBuildData.top_level_vertices
// the VertexLayerElementType as {id: next_vertices_id, layer: top_level_vertex}
// next_vertices_ids should be next_vertices_ids without the inactivated vertices
const next_vertices_ids = vertexBuildData.next_vertices_ids.filter(
(id) => !vertexBuildData.inactivated_vertices?.includes(id)
);
const nextVertices: VertexLayerElementType[] = zip(
vertexBuildData.next_vertices_ids,
next_vertices_ids,
vertexBuildData.top_level_vertices
).map(([id, reference]) => ({ id: id!, reference }));
@ -489,7 +494,7 @@ const useFlowStore = create<FlowStoreType>((set, get) => ({
];
const newIds = [
...get().verticesBuild!.verticesIds,
...vertexBuildData.next_vertices_ids,
...next_vertices_ids,
];
get().updateVerticesBuild({
verticesIds: newIds,
@ -598,7 +603,10 @@ const useFlowStore = create<FlowStoreType>((set, get) => ({
set({
verticesBuild: {
...verticesBuild,
// remove the vertices from the list of vertices ids
// that are going to be built
verticesIds: get().verticesBuild!.verticesIds.filter(
// keep the vertices that are not in the list of vertices to remove
(vertex) => !vertices.includes(vertex)
),
},

View file

@ -323,7 +323,7 @@
muted-foreground is too strong, maybe use a lighter shade of it?
*/
@apply border-none ring ring-muted-foreground;
@apply border-none ring grayscale;
}
.built-invalid-status {
@apply border-none ring ring-[#FF9090];

View file

@ -166,14 +166,26 @@ export async function buildVertices({
!useFlowStore
.getState()
.verticesBuild?.verticesIds.includes(element.id) &&
!useFlowStore
.getState()
.verticesBuild?.verticesIds.includes(element.reference ?? "") &&
onBuildUpdate
) {
// If it is, skip building and set the state to inactive
onBuildUpdate(
getInactiveVertexData(element.id),
BuildStatus.INACTIVE,
runId
);
if (element.id) {
onBuildUpdate(
getInactiveVertexData(element.id),
BuildStatus.INACTIVE,
runId
);
}
if (element.reference) {
onBuildUpdate(
getInactiveVertexData(element.reference),
BuildStatus.INACTIVE,
runId
);
}
buildResults.push(false);
return;
}