diff --git a/src/backend/langflow/graph/vertex/base.py b/src/backend/langflow/graph/vertex/base.py index 64e005756..3fc085b25 100644 --- a/src/backend/langflow/graph/vertex/base.py +++ b/src/backend/langflow/graph/vertex/base.py @@ -5,6 +5,7 @@ from enum import Enum from typing import TYPE_CHECKING, Any, Callable, Coroutine, Dict, List, Optional from langflow.graph.utils import UnbuiltObject, UnbuiltResult +from langflow.graph.vertex.utils import generate_result from langflow.interface.initialize import loading from langflow.interface.listing import lazy_load_dict from langflow.utils.constants import DIRECT_TYPES @@ -32,6 +33,10 @@ class Vertex: is_task: bool = False, params: Optional[Dict] = None, ) -> None: + # is_external means that the Vertex send or receives data from + # an external source (e.g the chat) + self.has_external_input = False + self.has_external_output = False self.graph = graph self.id: str = data["id"] self._data = data @@ -283,7 +288,7 @@ class Vertex: self._built = True - async def _run(self, user_id: str, inputs: Optional[dict] = None): + async def _run(self, user_id: str, inputs: Optional[dict] = None, session_id: Optional[str] = None): # user_id is just for compatibility with the other build methods inputs = inputs or {} # inputs = {key: value or "" for key, value in inputs.items()} @@ -297,12 +302,9 @@ class Vertex: # inputs = self._built_object.prompt.partial_variables if isinstance(self._built_object, str): self._built_result = self._built_object - elif hasattr(self._built_object, "run") and not isinstance(self._built_object, UnbuiltObject): - try: - result = self._built_object.run(inputs) - self._built_result = result - except Exception as exc: - logger.error(f"Error running {self.vertex_type}: {exc}") + + result = generate_result(self._built_object, inputs, self.has_external_output, session_id) + self._built_result = result async def _build_each_node_in_params_dict(self, user_id=None): """ @@ -505,3 +507,4 @@ class StatefulVertex(Vertex): class StatelessVertex(Vertex): pass + pass diff --git a/src/backend/langflow/graph/vertex/utils.py b/src/backend/langflow/graph/vertex/utils.py index e1d439f4b..a462623cd 100644 --- a/src/backend/langflow/graph/vertex/utils.py +++ b/src/backend/langflow/graph/vertex/utils.py @@ -1,5 +1,39 @@ +from typing import Any, Optional + +from langchain_core.messages import BaseMessage +from langchain_core.runnables import Runnable +from langflow.services.deps import Union, get_socket_service from langflow.utils.constants import PYTHON_BASIC_TYPES def is_basic_type(obj): return type(obj) in PYTHON_BASIC_TYPES + + +async def invoke_lc_runnable( + built_object: Runnable, inputs: dict, has_external_output: bool, session_id: Optional[str] = None +) -> Union[str, BaseMessage]: + if has_external_output: + socketio_service = get_socket_service() + result = "" + stream = built_object.astream(inputs) + async for chunk in stream: + await socketio_service.emit_token(session_id, chunk) + result += chunk + return built_object.invoke(inputs) + + +async def generate_result(built_object: Any, inputs: dict, has_external_output: bool, session_id: Optional[str] = None): + # If the built_object is instance of Runnable + # we can call `invoke` or `stream` on it + # if it has_external_outputl, we need to call `stream` if it has it + # if not, we call `invoke` if it has it + if isinstance(built_object, Runnable): + result = await invoke_lc_runnable( + built_object=built_object, inputs=inputs, has_external_output=has_external_output, session_id=session_id + ) + else: + result = built_object + return result + result = built_object + return result