From 2bb626fc663929ac9efc0403353187d09efd2fb5 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Wed, 30 Aug 2023 14:21:25 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(base.py):=20fix=20validation?= =?UTF-8?q?=20of=20source=20and=20target=20handles=20in=20Edge=20class=20c?= =?UTF-8?q?onstructor=20=E2=9C=A8=20feat(base.py):=20add=20SourceHandle=20?= =?UTF-8?q?and=20TargetHandle=20models=20to=20represent=20source=20and=20t?= =?UTF-8?q?arget=20handles=20in=20Edge=20class?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/langflow/graph/edge/base.py | 57 ++++++++++++++++++++++--- 1 file changed, 51 insertions(+), 6 deletions(-) diff --git a/src/backend/langflow/graph/edge/base.py b/src/backend/langflow/graph/edge/base.py index dc7eab328..27d25f928 100644 --- a/src/backend/langflow/graph/edge/base.py +++ b/src/backend/langflow/graph/edge/base.py @@ -1,22 +1,67 @@ from langflow.utils.logger import logger from typing import TYPE_CHECKING +from pydantic import BaseModel, Field +from typing import List, Optional if TYPE_CHECKING: from langflow.graph.vertex.base import Vertex +class SourceHandle(BaseModel): + baseClasses: List[str] = Field( + ..., description="List of base classes for the source handle." + ) + dataType: str = Field(..., description="Data type for the source handle.") + id: str = Field(..., description="Unique identifier for the source handle.") + + +class TargetHandle(BaseModel): + fieldName: str = Field(..., description="Field name for the target handle.") + id: str = Field(..., description="Unique identifier for the target handle.") + inputTypes: Optional[List[str]] = Field( + None, description="List of input types for the target handle." + ) + type: str = Field(..., description="Type of the target handle.") + + class Edge: def __init__(self, source: "Vertex", target: "Vertex", edge: dict): self.source: "Vertex" = source self.target: "Vertex" = target - self.source_handle = edge.get("sourceHandle", "") - self.target_handle = edge.get("targetHandle", "") - # 'BaseLoader;BaseOutputParser|documents|PromptTemplate-zmTlD' - # target_param is documents - self.target_param = self.target_handle.split("|")[1] - + data = edge.get("data", {}) + if not data: + raise ValueError("Edge data is empty") + self._source_handle = data.get("sourceHandle", {}) + self._target_handle = data.get("targetHandle", {}) + self.source_handle: SourceHandle = SourceHandle(**self._source_handle) + self.target_handle: TargetHandle = TargetHandle(**self._target_handle) + self.target_param = self.target_handle.fieldName + # validate handles + self.validate_handles() + # Validate in __init__ to fail fast self.validate_edge() + def validate_handles(self) -> None: + if self.target_handle.inputTypes is None: + self.valid_handles = ( + self.target_handle.type in self.source_handle.baseClasses + ) + else: + self.valid_handles = ( + any( + baseClass in self.target_handle.inputTypes + for baseClass in self.source_handle.baseClasses + ) + or self.target_handle.type in self.source_handle.baseClasses + ) + if not self.valid_handles: + logger.debug(self.source_handle) + logger.debug(self.target_handle) + raise ValueError( + f"Edge between {self.source.vertex_type} and {self.target.vertex_type} " + f"has invalid handles" + ) + def validate_edge(self) -> None: # Validate that the outputs of the source node are valid inputs # for the target node