🐛 fix(base.py): fix validation of source and target handles in Edge class constructor

 feat(base.py): add SourceHandle and TargetHandle models to represent source and target handles in Edge class
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-08-30 14:21:25 -03:00
commit 2bb626fc66

View file

@ -1,22 +1,67 @@
from langflow.utils.logger import logger
from typing import TYPE_CHECKING
from pydantic import BaseModel, Field
from typing import List, Optional
if TYPE_CHECKING:
from langflow.graph.vertex.base import Vertex
class SourceHandle(BaseModel):
baseClasses: List[str] = Field(
..., description="List of base classes for the source handle."
)
dataType: str = Field(..., description="Data type for the source handle.")
id: str = Field(..., description="Unique identifier for the source handle.")
class TargetHandle(BaseModel):
fieldName: str = Field(..., description="Field name for the target handle.")
id: str = Field(..., description="Unique identifier for the target handle.")
inputTypes: Optional[List[str]] = Field(
None, description="List of input types for the target handle."
)
type: str = Field(..., description="Type of the target handle.")
class Edge:
def __init__(self, source: "Vertex", target: "Vertex", edge: dict):
self.source: "Vertex" = source
self.target: "Vertex" = target
self.source_handle = edge.get("sourceHandle", "")
self.target_handle = edge.get("targetHandle", "")
# 'BaseLoader;BaseOutputParser|documents|PromptTemplate-zmTlD'
# target_param is documents
self.target_param = self.target_handle.split("|")[1]
data = edge.get("data", {})
if not data:
raise ValueError("Edge data is empty")
self._source_handle = data.get("sourceHandle", {})
self._target_handle = data.get("targetHandle", {})
self.source_handle: SourceHandle = SourceHandle(**self._source_handle)
self.target_handle: TargetHandle = TargetHandle(**self._target_handle)
self.target_param = self.target_handle.fieldName
# validate handles
self.validate_handles()
# Validate in __init__ to fail fast
self.validate_edge()
def validate_handles(self) -> None:
if self.target_handle.inputTypes is None:
self.valid_handles = (
self.target_handle.type in self.source_handle.baseClasses
)
else:
self.valid_handles = (
any(
baseClass in self.target_handle.inputTypes
for baseClass in self.source_handle.baseClasses
)
or self.target_handle.type in self.source_handle.baseClasses
)
if not self.valid_handles:
logger.debug(self.source_handle)
logger.debug(self.target_handle)
raise ValueError(
f"Edge between {self.source.vertex_type} and {self.target.vertex_type} "
f"has invalid handles"
)
def validate_edge(self) -> None:
# Validate that the outputs of the source node are valid inputs
# for the target node