Refactor graph class to handle vertex additions and removals

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-01 18:21:59 -03:00
commit cac6473361

View file

@ -3,8 +3,6 @@ from collections import defaultdict, deque
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Type, Union
from langchain.chains.base import Chain
from loguru import logger
from langflow.graph.edge.base import ContractEdge
from langflow.graph.graph.constants import lazy_load_vertex_dict
from langflow.graph.graph.state_manager import GraphStateManager
@ -20,6 +18,7 @@ from langflow.graph.vertex.types import (
)
from langflow.interface.tools.constants import FILE_TOOLS
from langflow.utils import payload
from loguru import logger
if TYPE_CHECKING:
from langflow.graph.schema import ResultData
@ -270,6 +269,16 @@ class Graph:
# Find vertices that are in self but not in other (removed vertices)
removed_vertex_ids = existing_vertex_ids - other_vertex_ids
# Remove vertices that are not in the other graph
for vertex_id in removed_vertex_ids:
self.remove_vertex(vertex_id)
# Add new vertices
for vertex_id in new_vertex_ids:
new_vertex = other.get_vertex(vertex_id)
new_vertex.graph = self
self._add_vertex(new_vertex)
# Update existing vertices that have changed
for vertex_id in existing_vertex_ids.intersection(other_vertex_ids):
self_vertex = self.get_vertex(vertex_id)
@ -291,15 +300,6 @@ class Graph:
self_vertex.set_top_level(self.top_level_vertices)
self.reset_all_edges_of_vertex(self_vertex)
# Remove vertices
for vertex_id in removed_vertex_ids:
self.remove_vertex(vertex_id)
# Add new vertices
for vertex_id in new_vertex_ids:
new_vertex = other.get_vertex(vertex_id)
self._add_vertex(new_vertex)
self.build_graph_maps()
self.increment_update_count()
return self
@ -612,7 +612,7 @@ class Graph:
excluded = set() # To keep track of vertices that should be excluded
stack = [vertex_id] # Use a list as a stack for DFS
def get_successors(vertex):
def get_successors(vertex, recursive=True):
# Recursively get the successors of the current vertex
successors = vertex.successors
if not successors:
@ -620,8 +620,9 @@ class Graph:
successors_result = []
for successor in successors:
# Just return a list of successors
next_successors = get_successors(successor)
successors_result.extend(next_successors)
if recursive:
next_successors = get_successors(successor)
successors_result.extend(next_successors)
successors_result.append(successor)
return successors_result