Refactor edge class and add missing type annotations

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-21 16:01:58 -03:00
commit 05f1535668
2 changed files with 38 additions and 13 deletions

View file

@ -1,18 +1,19 @@
from typing import TYPE_CHECKING, Any, List, Optional
from loguru import logger
from pydantic import BaseModel, Field
from langflow.graph.edge.utils import build_clean_params
from langflow.services.deps import get_monitor_service
from langflow.services.monitor.utils import log_message
from loguru import logger
from pydantic import BaseModel, Field
if TYPE_CHECKING:
from langflow.graph.vertex.base import Vertex
class SourceHandle(BaseModel):
baseClasses: List[str] = Field(..., description="List of base classes for the source handle.")
baseClasses: List[str] = Field(
..., description="List of base classes for the source handle."
)
dataType: str = Field(..., description="Data type for the source handle.")
id: str = Field(..., description="Unique identifier for the source handle.")
@ -20,7 +21,9 @@ class SourceHandle(BaseModel):
class TargetHandle(BaseModel):
fieldName: str = Field(..., description="Field name for the target handle.")
id: str = Field(..., description="Unique identifier for the target handle.")
inputTypes: Optional[List[str]] = Field(None, description="List of input types for the target handle.")
inputTypes: Optional[List[str]] = Field(
None, description="List of input types for the target handle."
)
type: str = Field(..., description="Type of the target handle.")
@ -49,16 +52,24 @@ class Edge:
def validate_handles(self, source, target) -> None:
if self.target_handle.inputTypes is None:
self.valid_handles = self.target_handle.type in self.source_handle.baseClasses
self.valid_handles = (
self.target_handle.type in self.source_handle.baseClasses
)
else:
self.valid_handles = (
any(baseClass in self.target_handle.inputTypes for baseClass in self.source_handle.baseClasses)
any(
baseClass in self.target_handle.inputTypes
for baseClass in self.source_handle.baseClasses
)
or self.target_handle.type in self.source_handle.baseClasses
)
if not self.valid_handles:
logger.debug(self.source_handle)
logger.debug(self.target_handle)
raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has invalid handles")
raise ValueError(
f"Edge between {source.vertex_type} and {target.vertex_type} "
f"has invalid handles"
)
def __setstate__(self, state):
self.source_id = state["source_id"]
@ -75,7 +86,11 @@ class Edge:
# Both lists contain strings and sometimes a string contains the value we are
# looking for e.g. comgin_out=["Chain"] and target_reqs=["LLMChain"]
# so we need to check if any of the strings in source_types is in target_reqs
self.valid = any(output in target_req for output in self.source_types for target_req in self.target_reqs)
self.valid = any(
output in target_req
for output in self.source_types
for target_req in self.target_reqs
)
# Get what type of input the target node is expecting
self.matched_type = next(
@ -86,7 +101,10 @@ class Edge:
if no_matched_type:
logger.debug(self.source_types)
logger.debug(self.target_reqs)
raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has no matched type")
raise ValueError(
f"Edge between {source.vertex_type} and {target.vertex_type} "
f"has no matched type"
)
def __repr__(self) -> str:
return (
@ -98,7 +116,11 @@ class Edge:
return hash(self.__repr__())
def __eq__(self, __value: object) -> bool:
return self.__repr__() == __value.__repr__() if isinstance(__value, Edge) else False
return (
self.__repr__() == __value.__repr__()
if isinstance(__value, Edge)
else False
)
class ContractEdge(Edge):
@ -142,7 +164,7 @@ class ContractEdge(Edge):
or isinstance(target.params.get("message"), dict)
):
await log_message(
sender_type=target.params.get("sender", ""),
sender=target.params.get("sender", ""),
sender_name=target.params.get("sender_name", ""),
message=target.params.get("message", {}),
session_id=target.params.get("session_id", ""),
@ -154,7 +176,9 @@ class ContractEdge(Edge):
return f"{self.source_id} -[{self.target_param}]-> {self.target_id}"
def log_transaction(edge: ContractEdge, source: "Vertex", target: "Vertex", status, error=None):
def log_transaction(
edge: ContractEdge, source: "Vertex", target: "Vertex", status, error=None
):
try:
monitor_service = get_monitor_service()
clean_params = build_clean_params(target)

View file

@ -363,6 +363,7 @@ def create_component_template(component):
component_template = build_custom_component_template(component_extractor)
component_template["output_types"] = component_output_types
return component_template