Add state activation for specific vertices
This commit is contained in:
parent
d846d806cc
commit
22884aeb7e
1 changed files with 20 additions and 10 deletions
|
|
@ -3,6 +3,8 @@ from collections import defaultdict, deque
|
|||
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Type, Union
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from loguru import logger
|
||||
|
||||
from langflow.graph.edge.base import ContractEdge
|
||||
from langflow.graph.graph.constants import lazy_load_vertex_dict
|
||||
from langflow.graph.graph.state_manager import GraphStateManager
|
||||
|
|
@ -20,7 +22,6 @@ from langflow.graph.vertex.types import (
|
|||
from langflow.interface.tools.constants import FILE_TOOLS
|
||||
from langflow.schema import Record
|
||||
from langflow.utils import payload
|
||||
from loguru import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.graph.schema import ResultData
|
||||
|
|
@ -43,6 +44,7 @@ class Graph:
|
|||
self.flow_id = flow_id
|
||||
self._is_input_vertices: List[str] = []
|
||||
self._is_output_vertices: List[str] = []
|
||||
self._is_state_vertices: List[str] = []
|
||||
self._has_session_id_vertices: List[str] = []
|
||||
self._sorted_vertices_layers: List[List[str]] = []
|
||||
self.run_id = None
|
||||
|
|
@ -73,16 +75,23 @@ class Graph:
|
|||
# all StateVertex in self.vertices that are not the caller
|
||||
# essentially notifying all the other vertices that the state has changed
|
||||
# This also has to activate their successors
|
||||
caller_vertex = self.get_vertex(caller)
|
||||
for vertex in self.vertices:
|
||||
if vertex.id != caller and isinstance(vertex, StateVertex):
|
||||
successors = self.get_all_successors(vertex)
|
||||
self.activated_vertices.add(vertex.id)
|
||||
for successor in successors:
|
||||
self.activated_vertices.add(successor.id)
|
||||
self.activate_state_vertices(name, caller)
|
||||
|
||||
self.state_manager.update_state(name, record)
|
||||
|
||||
def activate_state_vertices(self, name: str, caller: str):
|
||||
for vertex_id in self._is_state_vertices:
|
||||
vertex = self.get_vertex(vertex_id)
|
||||
if (
|
||||
name in vertex._raw_params["name"]
|
||||
and vertex_id != caller
|
||||
and isinstance(vertex, StateVertex)
|
||||
):
|
||||
successors = self.get_all_successors(vertex)
|
||||
self.activated_vertices.add(vertex_id)
|
||||
for successor in successors:
|
||||
self.activated_vertices.add(successor.id)
|
||||
|
||||
def reset_activated_vertices(self):
|
||||
self.activated_vertices = set()
|
||||
|
||||
|
|
@ -91,7 +100,8 @@ class Graph:
|
|||
) -> None:
|
||||
"""Appends the state of the graph."""
|
||||
if caller:
|
||||
self.state_manager.subscribe(name, caller)
|
||||
|
||||
self.activate_state_vertices(name, caller)
|
||||
|
||||
self.state_manager.append_state(name, record)
|
||||
|
||||
|
|
@ -113,7 +123,7 @@ class Graph:
|
|||
"""
|
||||
Defines the lists of vertices that are inputs, outputs, and have session_id.
|
||||
"""
|
||||
attributes = ["is_input", "is_output", "has_session_id"]
|
||||
attributes = ["is_input", "is_output", "has_session_id", "is_state"]
|
||||
for vertex in self.vertices:
|
||||
for attribute in attributes:
|
||||
if getattr(vertex, attribute):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue