Add state activation for specific vertices

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-02 01:18:51 -03:00
commit 22884aeb7e

View file

@ -3,6 +3,8 @@ from collections import defaultdict, deque
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Type, Union
from langchain.chains.base import Chain
from loguru import logger
from langflow.graph.edge.base import ContractEdge
from langflow.graph.graph.constants import lazy_load_vertex_dict
from langflow.graph.graph.state_manager import GraphStateManager
@ -20,7 +22,6 @@ from langflow.graph.vertex.types import (
from langflow.interface.tools.constants import FILE_TOOLS
from langflow.schema import Record
from langflow.utils import payload
from loguru import logger
if TYPE_CHECKING:
from langflow.graph.schema import ResultData
@ -43,6 +44,7 @@ class Graph:
self.flow_id = flow_id
self._is_input_vertices: List[str] = []
self._is_output_vertices: List[str] = []
self._is_state_vertices: List[str] = []
self._has_session_id_vertices: List[str] = []
self._sorted_vertices_layers: List[List[str]] = []
self.run_id = None
@ -73,16 +75,23 @@ class Graph:
# all StateVertex in self.vertices that are not the caller
# essentially notifying all the other vertices that the state has changed
# This also has to activate their successors
caller_vertex = self.get_vertex(caller)
for vertex in self.vertices:
if vertex.id != caller and isinstance(vertex, StateVertex):
successors = self.get_all_successors(vertex)
self.activated_vertices.add(vertex.id)
for successor in successors:
self.activated_vertices.add(successor.id)
self.activate_state_vertices(name, caller)
self.state_manager.update_state(name, record)
def activate_state_vertices(self, name: str, caller: str):
for vertex_id in self._is_state_vertices:
vertex = self.get_vertex(vertex_id)
if (
name in vertex._raw_params["name"]
and vertex_id != caller
and isinstance(vertex, StateVertex)
):
successors = self.get_all_successors(vertex)
self.activated_vertices.add(vertex_id)
for successor in successors:
self.activated_vertices.add(successor.id)
def reset_activated_vertices(self):
self.activated_vertices = set()
@ -91,7 +100,8 @@ class Graph:
) -> None:
"""Appends the state of the graph."""
if caller:
self.state_manager.subscribe(name, caller)
self.activate_state_vertices(name, caller)
self.state_manager.append_state(name, record)
@ -113,7 +123,7 @@ class Graph:
"""
Defines the lists of vertices that are inputs, outputs, and have session_id.
"""
attributes = ["is_input", "is_output", "has_session_id"]
attributes = ["is_input", "is_output", "has_session_id", "is_state"]
for vertex in self.vertices:
for attribute in attributes:
if getattr(vertex, attribute):