From ed66136f85594e18874bfa6b28a109c240a86d11 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Fri, 23 Feb 2024 19:15:43 -0300 Subject: [PATCH] Refactor ChatVertex to support message streaming --- src/backend/langflow/graph/vertex/types.py | 67 +++++++++++++++++++++- src/backend/langflow/utils/schemas.py | 1 + 2 files changed, 67 insertions(+), 1 deletion(-) diff --git a/src/backend/langflow/graph/vertex/types.py b/src/backend/langflow/graph/vertex/types.py index e80eef594..705b5d114 100644 --- a/src/backend/langflow/graph/vertex/types.py +++ b/src/backend/langflow/graph/vertex/types.py @@ -1,14 +1,16 @@ import ast import json -from typing import Callable, Dict, List, Optional, Union +from typing import AsyncIterator, Callable, Dict, Iterator, List, Optional, Union import yaml from langchain_core.messages import AIMessage +from loguru import logger from langflow.graph.utils import UnbuiltObject, flatten_list from langflow.graph.vertex.base import StatefulVertex, StatelessVertex from langflow.interface.utils import extract_input_variables_from_prompt from langflow.schema import Record +from langflow.services.monitor.utils import log_message from langflow.utils.schemas import ChatOutputResponse @@ -314,6 +316,20 @@ class ChatVertex(StatelessVertex): super().__init__(data, graph=graph, base_type="custom_components", is_task=True) self.steps = [self._build, self._run] + def build_stream_url(self): + return f"/api/v1/build/{self.graph.flow_id}/{self.id}/stream" + + async def _build(self, user_id=None): + """ + Initiate the build process. + """ + logger.debug(f"Building {self.vertex_type}") + await self._build_each_node_in_params_dict(user_id) + await self._get_and_instantiate_class(user_id) + self._validate_built_object() + + self._built = True + def _built_object_repr(self): if self.task_id and self.is_task: if task := self.get_task(): @@ -332,6 +348,8 @@ class ChatVertex(StatelessVertex): artifacts = None sender = self.params.get("sender", None) sender_name = self.params.get("sender_name", None) + message = self.params.get("message", None) + stream_url = None if isinstance(self._built_object, AIMessage): artifacts = ChatOutputResponse.from_message( self._built_object, @@ -345,8 +363,13 @@ class ChatVertex(StatelessVertex): message = dict_to_codeblock(self._built_object) elif isinstance(self._built_object, Record): message = self._built_object.text + elif isinstance(message, (AsyncIterator, Iterator)): + stream_url = self.build_stream_url() + message = "" elif not isinstance(self._built_object, str): message = str(self._built_object) + # if the message is a generator or iterator + # it means that it is a stream of messages else: message = self._built_object @@ -354,14 +377,56 @@ class ChatVertex(StatelessVertex): message=message, sender=sender, sender_name=sender_name, + stream_url=stream_url, ) if artifacts: self.artifacts = artifacts.model_dump() + if isinstance(self._built_object, (AsyncIterator, Iterator)): + if self.params["as_record"]: + self._built_object = Record(text=message, data=self.artifacts) + else: + self._built_object = message self._built_result = self._built_object else: await super()._run(*args, **kwargs) + async def stream(self): + iterator = self.params.get("message", None) + if not isinstance(iterator, (AsyncIterator, Iterator)): + raise ValueError("The message must be an iterator or an async iterator.") + is_async = isinstance(iterator, AsyncIterator) + complete_message = "" + if is_async: + async for message in iterator: + message = message.content if hasattr(message, "content") else message + message = message.text if hasattr(message, "text") else message + yield message + complete_message += message + else: + for message in iterator: + message = message.content if hasattr(message, "content") else message + message = message.text if hasattr(message, "text") else message + yield message + complete_message += message + self._built_object = Record(text=complete_message, data=self.artifacts) + self._built_result = complete_message + # Update artifacts with the message + # and remove the stream_url + self.artifacts = ChatOutputResponse( + message=complete_message, + sender=self.params.get("sender", ""), + sender_name=self.params.get("sender_name", ""), + ).model_dump() + + await log_message( + sender=self.params.get("sender", ""), + sender_name=self.params.get("sender_name", ""), + message=complete_message, + session_id=self.params.get("session_id", ""), + artifacts=self.artifacts, + ) + class RoutingVertex(StatelessVertex): def __init__(self, data: Dict, graph): diff --git a/src/backend/langflow/utils/schemas.py b/src/backend/langflow/utils/schemas.py index bcde2eb2f..8d0b2db12 100644 --- a/src/backend/langflow/utils/schemas.py +++ b/src/backend/langflow/utils/schemas.py @@ -10,6 +10,7 @@ class ChatOutputResponse(BaseModel): message: Union[str, List[Union[str, Dict]]] sender: Optional[str] = "Machine" sender_name: Optional[str] = "AI" + stream_url: Optional[str] = None @classmethod def from_message(