From 5d87eb021fb056c9dae29abe3b7ac38c06b7979f Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Fri, 1 Mar 2024 22:51:53 -0300 Subject: [PATCH] Update __eq__ methods --- src/backend/langflow/graph/edge/base.py | 50 +++++++++++++++++------ src/backend/langflow/graph/vertex/base.py | 10 ++++- 2 files changed, 46 insertions(+), 14 deletions(-) diff --git a/src/backend/langflow/graph/edge/base.py b/src/backend/langflow/graph/edge/base.py index 4992d0b47..c49ec714c 100644 --- a/src/backend/langflow/graph/edge/base.py +++ b/src/backend/langflow/graph/edge/base.py @@ -1,19 +1,20 @@ from typing import TYPE_CHECKING, Any, List, Optional -from loguru import logger -from pydantic import BaseModel, Field - from langflow.graph.edge.utils import build_clean_params from langflow.graph.schema import INPUT_FIELD_NAME from langflow.services.deps import get_monitor_service from langflow.services.monitor.utils import log_message +from loguru import logger +from pydantic import BaseModel, Field if TYPE_CHECKING: from langflow.graph.vertex.base import Vertex class SourceHandle(BaseModel): - baseClasses: List[str] = Field(..., description="List of base classes for the source handle.") + baseClasses: List[str] = Field( + ..., description="List of base classes for the source handle." + ) dataType: str = Field(..., description="Data type for the source handle.") id: str = Field(..., description="Unique identifier for the source handle.") @@ -21,7 +22,9 @@ class SourceHandle(BaseModel): class TargetHandle(BaseModel): fieldName: str = Field(..., description="Field name for the target handle.") id: str = Field(..., description="Unique identifier for the target handle.") - inputTypes: Optional[List[str]] = Field(None, description="List of input types for the target handle.") + inputTypes: Optional[List[str]] = Field( + None, description="List of input types for the target handle." + ) type: str = Field(..., description="Type of the target handle.") @@ -50,16 +53,24 @@ class Edge: def validate_handles(self, source, target) -> None: if self.target_handle.inputTypes is None: - self.valid_handles = self.target_handle.type in self.source_handle.baseClasses + self.valid_handles = ( + self.target_handle.type in self.source_handle.baseClasses + ) else: self.valid_handles = ( - any(baseClass in self.target_handle.inputTypes for baseClass in self.source_handle.baseClasses) + any( + baseClass in self.target_handle.inputTypes + for baseClass in self.source_handle.baseClasses + ) or self.target_handle.type in self.source_handle.baseClasses ) if not self.valid_handles: logger.debug(self.source_handle) logger.debug(self.target_handle) - raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has invalid handles") + raise ValueError( + f"Edge between {source.vertex_type} and {target.vertex_type} " + f"has invalid handles" + ) def __setstate__(self, state): self.source_id = state["source_id"] @@ -76,7 +87,11 @@ class Edge: # Both lists contain strings and sometimes a string contains the value we are # looking for e.g. comgin_out=["Chain"] and target_reqs=["LLMChain"] # so we need to check if any of the strings in source_types is in target_reqs - self.valid = any(output in target_req for output in self.source_types for target_req in self.target_reqs) + self.valid = any( + output in target_req + for output in self.source_types + for target_req in self.target_reqs + ) # Get what type of input the target node is expecting self.matched_type = next( @@ -87,7 +102,10 @@ class Edge: if no_matched_type: logger.debug(self.source_types) logger.debug(self.target_reqs) - raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has no matched type") + raise ValueError( + f"Edge between {source.vertex_type} and {target.vertex_type} " + f"has no matched type" + ) def __repr__(self) -> str: return ( @@ -98,8 +116,12 @@ class Edge: def __hash__(self) -> int: return hash(self.__repr__()) - def __eq__(self, __value: object) -> bool: - return self.__repr__() == __value.__repr__() if isinstance(__value, Edge) else False + def __eq__(self, __o: object) -> bool: + # Create a better way to compare edges + return ( + self._source_handle == __o._source_handle + and self._target_handle == __o._target_handle + ) class ContractEdge(Edge): @@ -156,7 +178,9 @@ class ContractEdge(Edge): return f"{self.source_id} -[{self.target_param}]-> {self.target_id}" -def log_transaction(edge: ContractEdge, source: "Vertex", target: "Vertex", status, error=None): +def log_transaction( + edge: ContractEdge, source: "Vertex", target: "Vertex", status, error=None +): try: monitor_service = get_monitor_service() clean_params = build_clean_params(target) diff --git a/src/backend/langflow/graph/vertex/base.py b/src/backend/langflow/graph/vertex/base.py index 4e711329f..818f40d58 100644 --- a/src/backend/langflow/graph/vertex/base.py +++ b/src/backend/langflow/graph/vertex/base.py @@ -682,7 +682,15 @@ class Vertex: def __eq__(self, __o: object) -> bool: try: - return self.id == __o.id if isinstance(__o, Vertex) else False + if not isinstance(__o, Vertex): + return False + # We should create a more robust comparison + # for the Vertex class + ids_are_equal = self.id == __o.id + # self._data is a dict and we need to compare them + # to check if they are equal + data_are_equal = self._data == __o._data + return ids_are_equal and data_are_equal except AttributeError: return False