Update __eq__ methods
This commit is contained in:
parent
457934b9d9
commit
5d87eb021f
2 changed files with 46 additions and 14 deletions
|
|
@ -1,19 +1,20 @@
|
|||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langflow.graph.edge.utils import build_clean_params
|
||||
from langflow.graph.schema import INPUT_FIELD_NAME
|
||||
from langflow.services.deps import get_monitor_service
|
||||
from langflow.services.monitor.utils import log_message
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
|
||||
|
||||
class SourceHandle(BaseModel):
|
||||
baseClasses: List[str] = Field(..., description="List of base classes for the source handle.")
|
||||
baseClasses: List[str] = Field(
|
||||
..., description="List of base classes for the source handle."
|
||||
)
|
||||
dataType: str = Field(..., description="Data type for the source handle.")
|
||||
id: str = Field(..., description="Unique identifier for the source handle.")
|
||||
|
||||
|
|
@ -21,7 +22,9 @@ class SourceHandle(BaseModel):
|
|||
class TargetHandle(BaseModel):
|
||||
fieldName: str = Field(..., description="Field name for the target handle.")
|
||||
id: str = Field(..., description="Unique identifier for the target handle.")
|
||||
inputTypes: Optional[List[str]] = Field(None, description="List of input types for the target handle.")
|
||||
inputTypes: Optional[List[str]] = Field(
|
||||
None, description="List of input types for the target handle."
|
||||
)
|
||||
type: str = Field(..., description="Type of the target handle.")
|
||||
|
||||
|
||||
|
|
@ -50,16 +53,24 @@ class Edge:
|
|||
|
||||
def validate_handles(self, source, target) -> None:
|
||||
if self.target_handle.inputTypes is None:
|
||||
self.valid_handles = self.target_handle.type in self.source_handle.baseClasses
|
||||
self.valid_handles = (
|
||||
self.target_handle.type in self.source_handle.baseClasses
|
||||
)
|
||||
else:
|
||||
self.valid_handles = (
|
||||
any(baseClass in self.target_handle.inputTypes for baseClass in self.source_handle.baseClasses)
|
||||
any(
|
||||
baseClass in self.target_handle.inputTypes
|
||||
for baseClass in self.source_handle.baseClasses
|
||||
)
|
||||
or self.target_handle.type in self.source_handle.baseClasses
|
||||
)
|
||||
if not self.valid_handles:
|
||||
logger.debug(self.source_handle)
|
||||
logger.debug(self.target_handle)
|
||||
raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has invalid handles")
|
||||
raise ValueError(
|
||||
f"Edge between {source.vertex_type} and {target.vertex_type} "
|
||||
f"has invalid handles"
|
||||
)
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.source_id = state["source_id"]
|
||||
|
|
@ -76,7 +87,11 @@ class Edge:
|
|||
# Both lists contain strings and sometimes a string contains the value we are
|
||||
# looking for e.g. comgin_out=["Chain"] and target_reqs=["LLMChain"]
|
||||
# so we need to check if any of the strings in source_types is in target_reqs
|
||||
self.valid = any(output in target_req for output in self.source_types for target_req in self.target_reqs)
|
||||
self.valid = any(
|
||||
output in target_req
|
||||
for output in self.source_types
|
||||
for target_req in self.target_reqs
|
||||
)
|
||||
# Get what type of input the target node is expecting
|
||||
|
||||
self.matched_type = next(
|
||||
|
|
@ -87,7 +102,10 @@ class Edge:
|
|||
if no_matched_type:
|
||||
logger.debug(self.source_types)
|
||||
logger.debug(self.target_reqs)
|
||||
raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has no matched type")
|
||||
raise ValueError(
|
||||
f"Edge between {source.vertex_type} and {target.vertex_type} "
|
||||
f"has no matched type"
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
|
|
@ -98,8 +116,12 @@ class Edge:
|
|||
def __hash__(self) -> int:
|
||||
return hash(self.__repr__())
|
||||
|
||||
def __eq__(self, __value: object) -> bool:
|
||||
return self.__repr__() == __value.__repr__() if isinstance(__value, Edge) else False
|
||||
def __eq__(self, __o: object) -> bool:
|
||||
# Create a better way to compare edges
|
||||
return (
|
||||
self._source_handle == __o._source_handle
|
||||
and self._target_handle == __o._target_handle
|
||||
)
|
||||
|
||||
|
||||
class ContractEdge(Edge):
|
||||
|
|
@ -156,7 +178,9 @@ class ContractEdge(Edge):
|
|||
return f"{self.source_id} -[{self.target_param}]-> {self.target_id}"
|
||||
|
||||
|
||||
def log_transaction(edge: ContractEdge, source: "Vertex", target: "Vertex", status, error=None):
|
||||
def log_transaction(
|
||||
edge: ContractEdge, source: "Vertex", target: "Vertex", status, error=None
|
||||
):
|
||||
try:
|
||||
monitor_service = get_monitor_service()
|
||||
clean_params = build_clean_params(target)
|
||||
|
|
|
|||
|
|
@ -682,7 +682,15 @@ class Vertex:
|
|||
|
||||
def __eq__(self, __o: object) -> bool:
|
||||
try:
|
||||
return self.id == __o.id if isinstance(__o, Vertex) else False
|
||||
if not isinstance(__o, Vertex):
|
||||
return False
|
||||
# We should create a more robust comparison
|
||||
# for the Vertex class
|
||||
ids_are_equal = self.id == __o.id
|
||||
# self._data is a dict and we need to compare them
|
||||
# to check if they are equal
|
||||
data_are_equal = self._data == __o._data
|
||||
return ids_are_equal and data_are_equal
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue