diff --git a/src/backend/langflow/graph/edge/base.py b/src/backend/langflow/graph/edge/base.py index ddde55728..2cca07100 100644 --- a/src/backend/langflow/graph/edge/base.py +++ b/src/backend/langflow/graph/edge/base.py @@ -1,18 +1,19 @@ from typing import TYPE_CHECKING, Any, List, Optional -from loguru import logger -from pydantic import BaseModel, Field - from langflow.graph.edge.utils import build_clean_params from langflow.services.deps import get_monitor_service from langflow.services.monitor.utils import log_message +from loguru import logger +from pydantic import BaseModel, Field if TYPE_CHECKING: from langflow.graph.vertex.base import Vertex class SourceHandle(BaseModel): - baseClasses: List[str] = Field(..., description="List of base classes for the source handle.") + baseClasses: List[str] = Field( + ..., description="List of base classes for the source handle." + ) dataType: str = Field(..., description="Data type for the source handle.") id: str = Field(..., description="Unique identifier for the source handle.") @@ -20,7 +21,9 @@ class SourceHandle(BaseModel): class TargetHandle(BaseModel): fieldName: str = Field(..., description="Field name for the target handle.") id: str = Field(..., description="Unique identifier for the target handle.") - inputTypes: Optional[List[str]] = Field(None, description="List of input types for the target handle.") + inputTypes: Optional[List[str]] = Field( + None, description="List of input types for the target handle." + ) type: str = Field(..., description="Type of the target handle.") @@ -49,16 +52,24 @@ class Edge: def validate_handles(self, source, target) -> None: if self.target_handle.inputTypes is None: - self.valid_handles = self.target_handle.type in self.source_handle.baseClasses + self.valid_handles = ( + self.target_handle.type in self.source_handle.baseClasses + ) else: self.valid_handles = ( - any(baseClass in self.target_handle.inputTypes for baseClass in self.source_handle.baseClasses) + any( + baseClass in self.target_handle.inputTypes + for baseClass in self.source_handle.baseClasses + ) or self.target_handle.type in self.source_handle.baseClasses ) if not self.valid_handles: logger.debug(self.source_handle) logger.debug(self.target_handle) - raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has invalid handles") + raise ValueError( + f"Edge between {source.vertex_type} and {target.vertex_type} " + f"has invalid handles" + ) def __setstate__(self, state): self.source_id = state["source_id"] @@ -75,7 +86,11 @@ class Edge: # Both lists contain strings and sometimes a string contains the value we are # looking for e.g. comgin_out=["Chain"] and target_reqs=["LLMChain"] # so we need to check if any of the strings in source_types is in target_reqs - self.valid = any(output in target_req for output in self.source_types for target_req in self.target_reqs) + self.valid = any( + output in target_req + for output in self.source_types + for target_req in self.target_reqs + ) # Get what type of input the target node is expecting self.matched_type = next( @@ -86,7 +101,10 @@ class Edge: if no_matched_type: logger.debug(self.source_types) logger.debug(self.target_reqs) - raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has no matched type") + raise ValueError( + f"Edge between {source.vertex_type} and {target.vertex_type} " + f"has no matched type" + ) def __repr__(self) -> str: return ( @@ -98,7 +116,11 @@ class Edge: return hash(self.__repr__()) def __eq__(self, __value: object) -> bool: - return self.__repr__() == __value.__repr__() if isinstance(__value, Edge) else False + return ( + self.__repr__() == __value.__repr__() + if isinstance(__value, Edge) + else False + ) class ContractEdge(Edge): @@ -142,7 +164,7 @@ class ContractEdge(Edge): or isinstance(target.params.get("message"), dict) ): await log_message( - sender_type=target.params.get("sender", ""), + sender=target.params.get("sender", ""), sender_name=target.params.get("sender_name", ""), message=target.params.get("message", {}), session_id=target.params.get("session_id", ""), @@ -154,7 +176,9 @@ class ContractEdge(Edge): return f"{self.source_id} -[{self.target_param}]-> {self.target_id}" -def log_transaction(edge: ContractEdge, source: "Vertex", target: "Vertex", status, error=None): +def log_transaction( + edge: ContractEdge, source: "Vertex", target: "Vertex", status, error=None +): try: monitor_service = get_monitor_service() clean_params = build_clean_params(target) diff --git a/src/backend/langflow/interface/custom/utils.py b/src/backend/langflow/interface/custom/utils.py index ce15768e1..cc0f2bd70 100644 --- a/src/backend/langflow/interface/custom/utils.py +++ b/src/backend/langflow/interface/custom/utils.py @@ -363,6 +363,7 @@ def create_component_template(component): component_template = build_custom_component_template(component_extractor) component_template["output_types"] = component_output_types + return component_template