Refactor edge validation logic and add conditional path support

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-29 23:08:32 -03:00
commit 1a144a2fc1

View file

@ -13,15 +13,22 @@ if TYPE_CHECKING:
class SourceHandle(BaseModel):
baseClasses: List[str] = Field(..., description="List of base classes for the source handle.")
baseClasses: List[str] = Field(
..., description="List of base classes for the source handle."
)
dataType: str = Field(..., description="Data type for the source handle.")
id: str = Field(..., description="Unique identifier for the source handle.")
conditionalPath: Optional[bool] = Field(
None, description="Conditional path for the source handle."
)
class TargetHandle(BaseModel):
fieldName: str = Field(..., description="Field name for the target handle.")
id: str = Field(..., description="Unique identifier for the target handle.")
inputTypes: Optional[List[str]] = Field(None, description="List of input types for the target handle.")
inputTypes: Optional[List[str]] = Field(
None, description="List of input types for the target handle."
)
type: str = Field(..., description="Type of the target handle.")
@ -50,16 +57,24 @@ class Edge:
def validate_handles(self, source, target) -> None:
if self.target_handle.inputTypes is None:
self.valid_handles = self.target_handle.type in self.source_handle.baseClasses
self.valid_handles = (
self.target_handle.type in self.source_handle.baseClasses
)
else:
self.valid_handles = (
any(baseClass in self.target_handle.inputTypes for baseClass in self.source_handle.baseClasses)
any(
baseClass in self.target_handle.inputTypes
for baseClass in self.source_handle.baseClasses
)
or self.target_handle.type in self.source_handle.baseClasses
)
if not self.valid_handles:
logger.debug(self.source_handle)
logger.debug(self.target_handle)
raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has invalid handles")
raise ValueError(
f"Edge between {source.vertex_type} and {target.vertex_type} "
f"has invalid handles"
)
def __setstate__(self, state):
self.source_id = state["source_id"]
@ -76,7 +91,11 @@ class Edge:
# Both lists contain strings and sometimes a string contains the value we are
# looking for e.g. comgin_out=["Chain"] and target_reqs=["LLMChain"]
# so we need to check if any of the strings in source_types is in target_reqs
self.valid = any(output in target_req for output in self.source_types for target_req in self.target_reqs)
self.valid = any(
output in target_req
for output in self.source_types
for target_req in self.target_reqs
)
# Get what type of input the target node is expecting
self.matched_type = next(
@ -87,7 +106,10 @@ class Edge:
if no_matched_type:
logger.debug(self.source_types)
logger.debug(self.target_reqs)
raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has no matched type")
raise ValueError(
f"Edge between {source.vertex_type} and {target.vertex_type} "
f"has no matched type"
)
def __repr__(self) -> str:
return (
@ -99,7 +121,11 @@ class Edge:
return hash(self.__repr__())
def __eq__(self, __value: object) -> bool:
return self.__repr__() == __value.__repr__() if isinstance(__value, Edge) else False
return (
self.__repr__() == __value.__repr__()
if isinstance(__value, Edge)
else False
)
class ContractEdge(Edge):
@ -156,7 +182,9 @@ class ContractEdge(Edge):
return f"{self.source_id} -[{self.target_param}]-> {self.target_id}"
def log_transaction(edge: ContractEdge, source: "Vertex", target: "Vertex", status, error=None):
def log_transaction(
edge: ContractEdge, source: "Vertex", target: "Vertex", status, error=None
):
try:
monitor_service = get_monitor_service()
clean_params = build_clean_params(target)