Change honor method to be asynchronous in ContractEdge

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-01-26 17:06:16 -03:00
commit 20d9e51208
3 changed files with 38 additions and 23 deletions

View file

@ -105,7 +105,7 @@ class ContractEdge(Edge):
self.is_fulfilled = False # Whether the contract has been fulfilled.
self.result: Any = None
def honor(self, source: "Vertex", target: "Vertex") -> None:
async def honor(self, source: "Vertex", target: "Vertex") -> None:
"""
Fulfills the contract by setting the result of the source vertex to the target vertex's parameter.
If the edge is runnable, the source vertex is run with the message text and the target vertex's
@ -117,7 +117,7 @@ class ContractEdge(Edge):
return
if not source._built:
source.build()
await source.build()
if self.matched_type == "Text":
self.result = source._built_result
@ -144,7 +144,7 @@ class ContractEdge(Edge):
async def get_result(self, source: "Vertex", target: "Vertex"):
# Fulfill the contract if it has not been fulfilled.
if not self.is_fulfilled:
self.honor(source, target)
await self.honor(source, target)
log_transaction(self, source, target, "success")
# If the target vertex is a power component we log messages

View file

@ -4,7 +4,7 @@ from typing import Dict, Generator, List, Type, Union
from langchain.chains.base import Chain
from loguru import logger
from langflow.graph.edge.base import ContractEdge, Edge
from langflow.graph.edge.base import ContractEdge
from langflow.graph.graph.constants import lazy_load_vertex_dict
from langflow.graph.graph.utils import process_flow
from langflow.graph.vertex.base import Vertex
@ -33,6 +33,7 @@ class Graph:
self._vertices = self._graph_data["nodes"]
self._edges = self._graph_data["edges"]
self._build_graph()
def __getstate__(self):
@ -111,7 +112,7 @@ class Graph:
"""Returns a vertex by id."""
return self.vertex_map.get(vertex_id)
def get_vertex_edges(self, vertex_id: str) -> List[Union[Edge, ContractEdge]]:
def get_vertex_edges(self, vertex_id: str) -> List[ContractEdge]:
"""Returns a list of edges for a given vertex."""
return [edge for edge in self.edges if edge.source_id == vertex_id or edge.target_id == vertex_id]
@ -210,18 +211,19 @@ class Graph:
edges.append(ContractEdge(source, target, edge))
return edges
def _get_vertex_class(self, vertex_type: str, vertex_base_type: str) -> Type[Vertex]:
"""Returns the vertex class based on the vertex type."""
if vertex_type in FILE_TOOLS:
return FileToolVertex
if vertex_base_type == "CustomComponent":
return lazy_load_vertex_dict.get_custom_component_vertex_type()
def _get_vertex_class(self, node_type: str, node_lc_type: str, node_id: str) -> Type[Vertex]:
"""Returns the node class based on the node type."""
node_name = node_id.split("-")[0]
if node_name in lazy_load_vertex_dict.VERTEX_TYPE_MAP:
return lazy_load_vertex_dict.VERTEX_TYPE_MAP[node_name]
if vertex_base_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP:
return lazy_load_vertex_dict.VERTEX_TYPE_MAP[vertex_base_type]
if node_type in FILE_TOOLS:
return FileToolVertex
if node_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP:
return lazy_load_vertex_dict.VERTEX_TYPE_MAP[node_type]
return (
lazy_load_vertex_dict.VERTEX_TYPE_MAP[vertex_type]
if vertex_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP
lazy_load_vertex_dict.VERTEX_TYPE_MAP[node_lc_type]
if node_lc_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP
else Vertex
)
@ -233,7 +235,7 @@ class Graph:
vertex_type: str = vertex_data["type"] # type: ignore
vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore
VertexClass = self._get_vertex_class(vertex_type, vertex_base_type)
VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"])
vertex_instance = VertexClass(vertex, graph=self)
vertex_instance.set_top_level(self.top_level_vertices)
vertices.append(vertex_instance)

View file

@ -1,6 +1,8 @@
import ast
from typing import Callable, Dict, List, Optional, Union
from langchain_core.messages import AIMessage
from langflow.graph.utils import UnbuiltObject, flatten_list
from langflow.graph.vertex.base import StatefulVertex, StatelessVertex
from langflow.interface.utils import extract_input_variables_from_prompt
@ -318,17 +320,28 @@ class ChatVertex(StatelessVertex):
if self.artifacts and "repr" in self.artifacts:
return self.artifacts["repr"] or super()._built_object_repr()
def _run(self, *args, **kwargs):
async def _run(self, *args, **kwargs):
if self.is_power_component:
if self.vertex_type == "ChatOutput":
artifacts = None
sender = self.params.get("sender", None)
sender_name = self.params.get("sender_name", None)
self.artifacts = ChatOutputResponse(
message=str(self._built_object),
sender=sender,
sender_name=sender_name,
).model_dump()
message = ""
if isinstance(self._built_object, AIMessage):
artifacts = ChatOutputResponse.from_message(
self._built_object,
sender=sender,
sender_name=sender_name,
)
elif not isinstance(self._built_object, UnbuiltObject):
artifacts = ChatOutputResponse(
message=message,
sender=sender,
sender_name=sender_name,
)
if artifacts:
self.artifacts = artifacts.model_dump()
self._built_result = self._built_object
else:
super()._run(*args, **kwargs)
await super()._run(*args, **kwargs)