Refactor ChatVertex to support message streaming
This commit is contained in:
parent
a8e27bac6d
commit
ed66136f85
2 changed files with 67 additions and 1 deletions
|
|
@ -1,14 +1,16 @@
|
|||
import ast
|
||||
import json
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
from typing import AsyncIterator, Callable, Dict, Iterator, List, Optional, Union
|
||||
|
||||
import yaml
|
||||
from langchain_core.messages import AIMessage
|
||||
from loguru import logger
|
||||
|
||||
from langflow.graph.utils import UnbuiltObject, flatten_list
|
||||
from langflow.graph.vertex.base import StatefulVertex, StatelessVertex
|
||||
from langflow.interface.utils import extract_input_variables_from_prompt
|
||||
from langflow.schema import Record
|
||||
from langflow.services.monitor.utils import log_message
|
||||
from langflow.utils.schemas import ChatOutputResponse
|
||||
|
||||
|
||||
|
|
@ -314,6 +316,20 @@ class ChatVertex(StatelessVertex):
|
|||
super().__init__(data, graph=graph, base_type="custom_components", is_task=True)
|
||||
self.steps = [self._build, self._run]
|
||||
|
||||
def build_stream_url(self):
|
||||
return f"/api/v1/build/{self.graph.flow_id}/{self.id}/stream"
|
||||
|
||||
async def _build(self, user_id=None):
|
||||
"""
|
||||
Initiate the build process.
|
||||
"""
|
||||
logger.debug(f"Building {self.vertex_type}")
|
||||
await self._build_each_node_in_params_dict(user_id)
|
||||
await self._get_and_instantiate_class(user_id)
|
||||
self._validate_built_object()
|
||||
|
||||
self._built = True
|
||||
|
||||
def _built_object_repr(self):
|
||||
if self.task_id and self.is_task:
|
||||
if task := self.get_task():
|
||||
|
|
@ -332,6 +348,8 @@ class ChatVertex(StatelessVertex):
|
|||
artifacts = None
|
||||
sender = self.params.get("sender", None)
|
||||
sender_name = self.params.get("sender_name", None)
|
||||
message = self.params.get("message", None)
|
||||
stream_url = None
|
||||
if isinstance(self._built_object, AIMessage):
|
||||
artifacts = ChatOutputResponse.from_message(
|
||||
self._built_object,
|
||||
|
|
@ -345,8 +363,13 @@ class ChatVertex(StatelessVertex):
|
|||
message = dict_to_codeblock(self._built_object)
|
||||
elif isinstance(self._built_object, Record):
|
||||
message = self._built_object.text
|
||||
elif isinstance(message, (AsyncIterator, Iterator)):
|
||||
stream_url = self.build_stream_url()
|
||||
message = ""
|
||||
elif not isinstance(self._built_object, str):
|
||||
message = str(self._built_object)
|
||||
# if the message is a generator or iterator
|
||||
# it means that it is a stream of messages
|
||||
else:
|
||||
message = self._built_object
|
||||
|
||||
|
|
@ -354,14 +377,56 @@ class ChatVertex(StatelessVertex):
|
|||
message=message,
|
||||
sender=sender,
|
||||
sender_name=sender_name,
|
||||
stream_url=stream_url,
|
||||
)
|
||||
if artifacts:
|
||||
self.artifacts = artifacts.model_dump()
|
||||
if isinstance(self._built_object, (AsyncIterator, Iterator)):
|
||||
if self.params["as_record"]:
|
||||
self._built_object = Record(text=message, data=self.artifacts)
|
||||
else:
|
||||
self._built_object = message
|
||||
self._built_result = self._built_object
|
||||
|
||||
else:
|
||||
await super()._run(*args, **kwargs)
|
||||
|
||||
async def stream(self):
|
||||
iterator = self.params.get("message", None)
|
||||
if not isinstance(iterator, (AsyncIterator, Iterator)):
|
||||
raise ValueError("The message must be an iterator or an async iterator.")
|
||||
is_async = isinstance(iterator, AsyncIterator)
|
||||
complete_message = ""
|
||||
if is_async:
|
||||
async for message in iterator:
|
||||
message = message.content if hasattr(message, "content") else message
|
||||
message = message.text if hasattr(message, "text") else message
|
||||
yield message
|
||||
complete_message += message
|
||||
else:
|
||||
for message in iterator:
|
||||
message = message.content if hasattr(message, "content") else message
|
||||
message = message.text if hasattr(message, "text") else message
|
||||
yield message
|
||||
complete_message += message
|
||||
self._built_object = Record(text=complete_message, data=self.artifacts)
|
||||
self._built_result = complete_message
|
||||
# Update artifacts with the message
|
||||
# and remove the stream_url
|
||||
self.artifacts = ChatOutputResponse(
|
||||
message=complete_message,
|
||||
sender=self.params.get("sender", ""),
|
||||
sender_name=self.params.get("sender_name", ""),
|
||||
).model_dump()
|
||||
|
||||
await log_message(
|
||||
sender=self.params.get("sender", ""),
|
||||
sender_name=self.params.get("sender_name", ""),
|
||||
message=complete_message,
|
||||
session_id=self.params.get("session_id", ""),
|
||||
artifacts=self.artifacts,
|
||||
)
|
||||
|
||||
|
||||
class RoutingVertex(StatelessVertex):
|
||||
def __init__(self, data: Dict, graph):
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ class ChatOutputResponse(BaseModel):
|
|||
message: Union[str, List[Union[str, Dict]]]
|
||||
sender: Optional[str] = "Machine"
|
||||
sender_name: Optional[str] = "AI"
|
||||
stream_url: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_message(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue