Refactor graph class to handle vertex additions and removals
This commit is contained in:
parent
fb433d90d9
commit
cac6473361
1 changed files with 15 additions and 14 deletions
|
|
@ -3,8 +3,6 @@ from collections import defaultdict, deque
|
|||
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Type, Union
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from loguru import logger
|
||||
|
||||
from langflow.graph.edge.base import ContractEdge
|
||||
from langflow.graph.graph.constants import lazy_load_vertex_dict
|
||||
from langflow.graph.graph.state_manager import GraphStateManager
|
||||
|
|
@ -20,6 +18,7 @@ from langflow.graph.vertex.types import (
|
|||
)
|
||||
from langflow.interface.tools.constants import FILE_TOOLS
|
||||
from langflow.utils import payload
|
||||
from loguru import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.graph.schema import ResultData
|
||||
|
|
@ -270,6 +269,16 @@ class Graph:
|
|||
# Find vertices that are in self but not in other (removed vertices)
|
||||
removed_vertex_ids = existing_vertex_ids - other_vertex_ids
|
||||
|
||||
# Remove vertices that are not in the other graph
|
||||
for vertex_id in removed_vertex_ids:
|
||||
self.remove_vertex(vertex_id)
|
||||
|
||||
# Add new vertices
|
||||
for vertex_id in new_vertex_ids:
|
||||
new_vertex = other.get_vertex(vertex_id)
|
||||
new_vertex.graph = self
|
||||
self._add_vertex(new_vertex)
|
||||
|
||||
# Update existing vertices that have changed
|
||||
for vertex_id in existing_vertex_ids.intersection(other_vertex_ids):
|
||||
self_vertex = self.get_vertex(vertex_id)
|
||||
|
|
@ -291,15 +300,6 @@ class Graph:
|
|||
self_vertex.set_top_level(self.top_level_vertices)
|
||||
self.reset_all_edges_of_vertex(self_vertex)
|
||||
|
||||
# Remove vertices
|
||||
for vertex_id in removed_vertex_ids:
|
||||
self.remove_vertex(vertex_id)
|
||||
|
||||
# Add new vertices
|
||||
for vertex_id in new_vertex_ids:
|
||||
new_vertex = other.get_vertex(vertex_id)
|
||||
self._add_vertex(new_vertex)
|
||||
|
||||
self.build_graph_maps()
|
||||
self.increment_update_count()
|
||||
return self
|
||||
|
|
@ -612,7 +612,7 @@ class Graph:
|
|||
excluded = set() # To keep track of vertices that should be excluded
|
||||
stack = [vertex_id] # Use a list as a stack for DFS
|
||||
|
||||
def get_successors(vertex):
|
||||
def get_successors(vertex, recursive=True):
|
||||
# Recursively get the successors of the current vertex
|
||||
successors = vertex.successors
|
||||
if not successors:
|
||||
|
|
@ -620,8 +620,9 @@ class Graph:
|
|||
successors_result = []
|
||||
for successor in successors:
|
||||
# Just return a list of successors
|
||||
next_successors = get_successors(successor)
|
||||
successors_result.extend(next_successors)
|
||||
if recursive:
|
||||
next_successors = get_successors(successor)
|
||||
successors_result.extend(next_successors)
|
||||
successors_result.append(successor)
|
||||
return successors_result
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue