diff --git a/src/backend/langflow/graph/edge/base.py b/src/backend/langflow/graph/edge/base.py index dc7eab328..27d25f928 100644 --- a/src/backend/langflow/graph/edge/base.py +++ b/src/backend/langflow/graph/edge/base.py @@ -1,22 +1,67 @@ from langflow.utils.logger import logger from typing import TYPE_CHECKING +from pydantic import BaseModel, Field +from typing import List, Optional if TYPE_CHECKING: from langflow.graph.vertex.base import Vertex +class SourceHandle(BaseModel): + baseClasses: List[str] = Field( + ..., description="List of base classes for the source handle." + ) + dataType: str = Field(..., description="Data type for the source handle.") + id: str = Field(..., description="Unique identifier for the source handle.") + + +class TargetHandle(BaseModel): + fieldName: str = Field(..., description="Field name for the target handle.") + id: str = Field(..., description="Unique identifier for the target handle.") + inputTypes: Optional[List[str]] = Field( + None, description="List of input types for the target handle." + ) + type: str = Field(..., description="Type of the target handle.") + + class Edge: def __init__(self, source: "Vertex", target: "Vertex", edge: dict): self.source: "Vertex" = source self.target: "Vertex" = target - self.source_handle = edge.get("sourceHandle", "") - self.target_handle = edge.get("targetHandle", "") - # 'BaseLoader;BaseOutputParser|documents|PromptTemplate-zmTlD' - # target_param is documents - self.target_param = self.target_handle.split("|")[1] - + data = edge.get("data", {}) + if not data: + raise ValueError("Edge data is empty") + self._source_handle = data.get("sourceHandle", {}) + self._target_handle = data.get("targetHandle", {}) + self.source_handle: SourceHandle = SourceHandle(**self._source_handle) + self.target_handle: TargetHandle = TargetHandle(**self._target_handle) + self.target_param = self.target_handle.fieldName + # validate handles + self.validate_handles() + # Validate in __init__ to fail fast self.validate_edge() + def validate_handles(self) -> None: + if self.target_handle.inputTypes is None: + self.valid_handles = ( + self.target_handle.type in self.source_handle.baseClasses + ) + else: + self.valid_handles = ( + any( + baseClass in self.target_handle.inputTypes + for baseClass in self.source_handle.baseClasses + ) + or self.target_handle.type in self.source_handle.baseClasses + ) + if not self.valid_handles: + logger.debug(self.source_handle) + logger.debug(self.target_handle) + raise ValueError( + f"Edge between {self.source.vertex_type} and {self.target.vertex_type} " + f"has invalid handles" + ) + def validate_edge(self) -> None: # Validate that the outputs of the source node are valid inputs # for the target node