langflow/src/backend/base/langflow/graph/edge/base.py
2024-03-30 16:49:50 -03:00

182 lines
7.9 KiB
Python

from typing import TYPE_CHECKING, Any, List, Optional
from loguru import logger
from pydantic import BaseModel, Field
from langflow.graph.edge.utils import build_clean_params
from langflow.schema.schema import INPUT_FIELD_NAME
from langflow.services.deps import get_monitor_service
from langflow.services.monitor.utils import log_message
if TYPE_CHECKING:
from langflow.graph.vertex.base import Vertex
class SourceHandle(BaseModel):
    """Schema of the ``sourceHandle`` payload carried in an edge's ``data``."""

    # All three fields are required (no default value).
    baseClasses: List[str] = Field(
        default=...,
        description="List of base classes for the source handle.",
    )
    dataType: str = Field(
        default=...,
        description="Data type for the source handle.",
    )
    id: str = Field(
        default=...,
        description="Unique identifier for the source handle.",
    )
class TargetHandle(BaseModel):
    """Schema of the ``targetHandle`` payload carried in an edge's ``data``."""

    fieldName: str = Field(
        default=...,
        description="Field name for the target handle.",
    )
    id: str = Field(
        default=...,
        description="Unique identifier for the target handle.",
    )
    # The only optional field: it defaults to None when the payload omits it.
    inputTypes: Optional[List[str]] = Field(
        default=None,
        description="List of input types for the target handle.",
    )
    type: str = Field(
        default=...,
        description="Type of the target handle.",
    )
class Edge:
    """Connection from a source vertex to one parameter of a target vertex.

    The raw ``edge`` dict is accepted in two shapes:

    * ``edge["data"]`` is non-empty and holds ``sourceHandle`` /
      ``targetHandle`` dicts. They are parsed into ``SourceHandle`` /
      ``TargetHandle`` models and checked by ``validate_handles``.
    * ``edge["data"]`` is missing or empty. ``sourceHandle`` /
      ``targetHandle`` are then read from the top level of ``edge`` as
      ``|``-separated strings and an error is logged.

    In both cases ``validate_edge`` runs at the end of ``__init__`` so an
    incompatible connection raises ``ValueError`` at construction time.
    """

    def __init__(self, source: "Vertex", target: "Vertex", edge: dict):
        """Build the edge and validate it immediately.

        :param source: Vertex the edge starts from.
        :param target: Vertex the edge points to.
        :param edge: Raw edge dict (see the class docstring for its shapes).
        :raises ValueError: If the handles or the vertex types do not match.
        """
        self.source_id: str = source.id if source else ""
        self.target_id: str = target.id if target else ""
        if data := edge.get("data", {}):
            self._source_handle = data.get("sourceHandle", {})
            self._target_handle = data.get("targetHandle", {})
            self.source_handle: SourceHandle = SourceHandle(**self._source_handle)
            self.target_handle: TargetHandle = TargetHandle(**self._target_handle)
            self.target_param = self.target_handle.fieldName
            # validate handles
            self.validate_handles(source, target)
        else:
            # Logging here because this is a breaking change
            logger.error("Edge data is empty")
            self._source_handle = edge.get("sourceHandle", "")
            self._target_handle = edge.get("targetHandle", "")
            # 'BaseLoader;BaseOutputParser|documents|PromptTemplate-zmTlD'
            # target_param is documents
            self.target_param = self._target_handle.split("|")[1]
        # Validate in __init__ to fail fast
        self.validate_edge(source, target)

    def validate_handles(self, source, target) -> None:
        """Check that the source handle can feed the target handle.

        The handles are compatible when the target handle's ``type`` is one of
        the source handle's ``baseClasses`` or, when the target declares
        ``inputTypes``, when any source base class appears in that list. The
        outcome is stored in ``self.valid_handles``.

        :raises ValueError: If the handles are not compatible.
        """
        if self.target_handle.inputTypes is None:
            self.valid_handles = self.target_handle.type in self.source_handle.baseClasses
        else:
            self.valid_handles = (
                any(baseClass in self.target_handle.inputTypes for baseClass in self.source_handle.baseClasses)
                or self.target_handle.type in self.source_handle.baseClasses
            )
        if not self.valid_handles:
            logger.debug(self.source_handle)
            logger.debug(self.target_handle)
            raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} has invalid handles")

    def __setstate__(self, state):
        """Restore the edge from a pickled ``state`` dict.

        ``source_id``, ``target_id`` and ``target_param`` are mandatory (a
        missing key raises ``KeyError``). Every other attribute this class
        sets is restored when present and defaults to ``None`` otherwise, so
        ``__repr__``, ``__hash__`` and ``__eq__`` keep working on a restored
        edge instead of raising ``AttributeError``.
        """
        self.source_id = state["source_id"]
        self.target_id = state["target_id"]
        self.target_param = state["target_param"]
        self.source_handle = state.get("source_handle")
        self.target_handle = state.get("target_handle")
        # Raw handles are what __eq__ compares.
        self._source_handle = state.get("_source_handle")
        self._target_handle = state.get("_target_handle")
        # Validation results; matched_type is read by __repr__ and __hash__.
        self.valid_handles = state.get("valid_handles")
        self.source_types = state.get("source_types")
        self.target_reqs = state.get("target_reqs")
        self.valid = state.get("valid")
        self.matched_type = state.get("matched_type")

    def validate_edge(self, source, target) -> None:
        """Check that an output type of ``source`` is an input of ``target``.

        Sets ``source_types``, ``target_reqs``, ``valid`` and ``matched_type``.

        :raises ValueError: If no source output type is listed among the
            target's required or optional inputs.
        """
        # Validate that the outputs of the source node are valid inputs
        # for the target node
        self.source_types = source.output
        self.target_reqs = target.required_inputs + target.optional_inputs
        # Both lists contain strings and sometimes a string contains the value we are
        # looking for e.g. source_types=["Chain"] and target_reqs=["LLMChain"]
        # so we need to check if any of the strings in source_types is in target_reqs
        self.valid = any(output in target_req for output in self.source_types for target_req in self.target_reqs)
        # Get what type of input the target node is expecting.
        # NOTE: unlike `valid` above (substring match), this needs an exact
        # entry in target_reqs, so `valid` can be True while this still raises.
        self.matched_type = next(
            (output for output in self.source_types if output in self.target_reqs),
            None,
        )
        no_matched_type = self.matched_type is None
        if no_matched_type:
            logger.debug(self.source_types)
            logger.debug(self.target_reqs)
            raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} has no matched type")

    def __repr__(self) -> str:
        return (
            f"Edge(source={self.source_id}, target={self.target_id}, target_param={self.target_param}"
            f", matched_type={self.matched_type})"
        )

    def __hash__(self) -> int:
        # Hash follows the textual representation (subclasses may override it).
        return hash(self.__repr__())

    def __eq__(self, __o: object) -> bool:
        """Edges are equal when their raw handles and target parameter match."""
        if not isinstance(__o, Edge):
            return False
        return (
            self._source_handle == __o._source_handle
            and self._target_handle == __o._target_handle
            and self.target_param == __o.target_param
        )
class ContractEdge(Edge):
    """Edge that hands the source vertex's built value to the target vertex.

    ``is_fulfilled`` records whether the value has already been handed over
    and ``result`` keeps the value that was delivered.
    """

    def __init__(self, source: "Vertex", target: "Vertex", raw_edge: dict):
        super().__init__(source, target, raw_edge)
        # Nothing has been delivered to the target yet.
        self.is_fulfilled = False
        self.result: Any = None

    async def honor(self, source: "Vertex", target: "Vertex") -> None:
        """
        Fulfill the contract by copying the source vertex's value into the target's params.

        Does nothing when the contract is already fulfilled. Otherwise the value is
        ``source._built_result`` when the matched type is ``"Text"`` and
        ``source._built_object`` for any other matched type; it is stored in
        ``self.result`` and in ``target.params[self.target_param]``.

        :param source: The vertex providing the value; it must already be built.
        :param target: The vertex whose params receive the value.
        :raises ValueError: If the source vertex has not been built.
        """
        if self.is_fulfilled:
            return

        # The system should be read-only, so an unbuilt source vertex is
        # reported instead of being built from here.
        if not source._built:
            raise ValueError(f"Source vertex {source.id} is not built.")

        wants_text = self.matched_type == "Text"
        self.result = source._built_result if wants_text else source._built_object
        target.params[self.target_param] = self.result
        self.is_fulfilled = True

    async def get_result_from_source(self, source: "Vertex", target: "Vertex"):
        """
        Return the value carried by this edge, honoring the contract first if needed.

        A "success" transaction is logged on every call. When the target is a
        ``ChatOutput`` vertex whose input field holds a str or a dict, and its
        ``message`` param is not an empty string, the message is logged as well.

        :param source: The vertex providing the value.
        :param target: The vertex receiving the value.
        :return: The value stored in ``self.result``.
        """
        if not self.is_fulfilled:
            await self.honor(source, target)

        log_transaction(self, source, target, "success")

        # Messages are only logged for ChatOutput targets.
        if target.vertex_type != "ChatOutput":
            return self.result

        chat_input = target.params.get(INPUT_FIELD_NAME)
        if not isinstance(chat_input, (str, dict)):
            return self.result

        # An empty "message" param means there is nothing to log.
        if target.params.get("message") == "":
            return self.result

        await log_message(
            sender=target.params.get("sender", ""),
            sender_name=target.params.get("sender_name", ""),
            message=target.params.get(INPUT_FIELD_NAME, {}),
            session_id=target.params.get("session_id", ""),
            artifacts=target.artifacts,
        )
        return self.result

    def __repr__(self) -> str:
        return f"{self.source_id} -[{self.target_param}]-> {self.target_id}"
def log_transaction(edge: ContractEdge, source: "Vertex", target: "Vertex", status, error=None):
    """
    Record one edge traversal in the monitor service's ``transactions`` table.

    This is best-effort: any exception raised while building or storing the row
    is logged once and swallowed, so a monitoring failure never reaches the caller.

    :param edge: The edge being traversed (not part of the stored row).
    :param source: The source vertex; its ``vertex_type`` is stored as ``source``.
    :param target: The target vertex; its ``vertex_type`` and cleaned params are stored.
    :param status: Status value stored with the row.
    :param error: Optional error value stored with the row.
    """
    try:
        monitor_service = get_monitor_service()
        clean_params = build_clean_params(target)
        data = {
            "source": source.vertex_type,
            "target": target.vertex_type,
            "target_args": clean_params,
            "timestamp": monitor_service.get_timestamp(),
            "status": status,
            "error": error,
        }
        monitor_service.add_row(table_name="transactions", data=data)
    except Exception as e:
        # Report the failure exactly once; `e` is only bound inside this handler.
        logger.error(f"Error logging transaction: {e}")