Add ContractEdge class and related methods

This commit is contained in:
anovazzi1 2024-01-19 21:41:10 -03:00
commit 11eb254622

View file

@ -95,3 +95,115 @@ class Edge:
def __eq__(self, __value: object) -> bool:
return self.__repr__() == __value.__repr__() if isinstance(__value, Edge) else False
class ContractEdge(Edge):
def __init__(self, source: "Vertex", target: "Vertex", raw_edge: dict):
super().__init__(source, target, raw_edge)
self.is_fulfilled = False # Whether the contract has been fulfilled.
self.result: Any = None
def honor(self, source, target) -> None:
"""
Fulfills the contract by setting the result of the source vertex to the target vertex's parameter.
If the edge is runnable, the source vertex is run with the message text and the target vertex's
root_field param is set to the
result. If the edge is not runnable, the target vertex's parameter is set to the result.
:param message: The message object to be processed if the edge is runnable.
"""
if self.is_fulfilled:
return
if not source._built:
source.build()
if self.matched_type == "Text":
self.result = source._built_result
else:
self.result = source._built_object
target.params[self.target_param] = self.result
self.is_fulfilled = True
def build_clean_params(self, target: "Vertex") -> dict:
"""
Cleans the parameters of the target vertex.
"""
# Removes all keys that the values aren't python types like str, int, bool, etc.
params = {
key: value
for key, value in target.params.items()
if isinstance(value, (str, int, bool, float, list, dict))
}
# if it is a list we need to check if the contents are python types
for key, value in params.items():
if isinstance(value, list):
params[key] = [item for item in value if isinstance(item, (str, int, bool, float, list, dict))]
return params
async def get_result(self, source, target):
# Fulfill the contract if it has not been fulfilled.
if not self.is_fulfilled:
self.honor(source, target)
log_transaction(self, source, target, "success")
# If the target vertex is a power component we log messages
if (
target.vertex_type == "ChatOutput"
and isinstance(target.params.get("message"), str)
or isinstance(target.params.get("message"), dict)
):
await log_message(
sender_type=target.params.get("sender", ""),
sender_name=target.params.get("sender_name", ""),
message=target.params.get("message", {}),
session_id=target.params.get("session_id", ""),
artifacts=target.artifacts,
)
return self.result
def __repr__(self) -> str:
return f"{self.source_id} -[{self.target_param}]-> {self.target_id}"
def log_transaction(edge: ContractEdge, source: "Vertex", target: "Vertex", status, error=None):
try:
monitor_service = get_monitor_service()
clean_params = edge.build_clean_params(target)
data = {
"source": source.vertex_type,
"target": target.vertex_type,
"target_args": clean_params,
"timestamp": monitor_service.get_timestamp(),
"status": status,
"error": error,
}
monitor_service.add_row(table_name="transactions", data=data)
except Exception as e:
logger.error(f"Error logging transaction: {e}")
async def log_message(
sender_type: str,
sender_name: str,
message: str,
session_id: str,
artifacts: Optional[dict] = None,
):
try:
from langflow.graph.vertex.base import Vertex
if isinstance(session_id, Vertex):
session_id = await session_id.build() # type: ignore
monitor_service = get_monitor_service()
row = {
"sender_type": sender_type,
"sender_name": sender_name,
"message": message,
"artifacts": artifacts or {},
"session_id": session_id,
"timestamp": monitor_service.get_timestamp(),
}
monitor_service.add_row(table_name="messages", data=row)
except Exception as e:
logger.error(f"Error logging message: {e}")