Refactor get_vertex_edges method to return a list of Union[Edge, ContractEdge]
This commit is contained in:
parent
ec7a5725ff
commit
97eb790653
1 changed files with 3 additions and 4 deletions
|
|
@ -2,15 +2,14 @@ from collections import defaultdict, deque
|
|||
from typing import Dict, Generator, List, Type, Union
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from loguru import logger
|
||||
|
||||
from langflow.graph.edge.base import Edge
|
||||
from langflow.graph.edge.base import ContractEdge, Edge
|
||||
from langflow.graph.graph.constants import lazy_load_vertex_dict
|
||||
from langflow.graph.graph.utils import process_flow
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.graph.vertex.types import FileToolVertex, LLMVertex, ToolkitVertex
|
||||
from langflow.interface.tools.constants import FILE_TOOLS
|
||||
from langflow.utils import payload
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class Graph:
|
||||
|
|
@ -111,7 +110,7 @@ class Graph:
|
|||
"""Returns a vertex by id."""
|
||||
return self.vertex_map.get(vertex_id)
|
||||
|
||||
def get_vertex_edges(self, vertex_id: str) -> List[Edge]:
|
||||
def get_vertex_edges(self, vertex_id: str) -> List[Union[Edge, ContractEdge]]:
|
||||
"""Returns a list of edges for a given vertex."""
|
||||
return [edge for edge in self.edges if edge.source_id == vertex_id or edge.target_id == vertex_id]
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue