From 60e09a3628fc624705bb8092fb576a85c94ff82e Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Tue, 25 Apr 2023 19:11:04 -0300 Subject: [PATCH 1/3] feat(langflow): add support for async functions in Node's func parameter fix(langflow): fix Node's func parameter to be a coroutine function if it is a sync function --- src/backend/langflow/__init__.py | 5 ++++- src/backend/langflow/graph/base.py | 23 ++++++++++++++++------- src/backend/langflow/utils/util.py | 16 ++++++++++++++++ src/backend/langflow/utils/validate.py | 2 +- 4 files changed, 37 insertions(+), 9 deletions(-) diff --git a/src/backend/langflow/__init__.py b/src/backend/langflow/__init__.py index 1be5464c2..fb06fe1a7 100644 --- a/src/backend/langflow/__init__.py +++ b/src/backend/langflow/__init__.py @@ -1 +1,4 @@ -from langflow.interface.loading import load_flow_from_json # noqa +from langflow.interface.loading import load_flow_from_json +from langflow.cache import cache_manager + +__all__ = ["load_flow_from_json", "cache_manager"] diff --git a/src/backend/langflow/graph/base.py b/src/backend/langflow/graph/base.py index 012250739..a4e6725da 100644 --- a/src/backend/langflow/graph/base.py +++ b/src/backend/langflow/graph/base.py @@ -4,6 +4,7 @@ # - Build each inner agent first, then build the outer agent import contextlib +import inspect import types import warnings from copy import deepcopy @@ -14,6 +15,7 @@ from langflow.graph.constants import DIRECT_TYPES from langflow.interface import loading from langflow.interface.listing import ALL_TYPES_DICT from langflow.utils.logger import logger +from langflow.utils.util import sync_to_async class Node: @@ -158,13 +160,20 @@ class Node: continue result = value.build() # If the key is "func", then we need to use the run method - if key == "func" and not isinstance(result, types.FunctionType): - # func can be PythonFunction(code='\ndef upper_case(text: str) -> str:\n return text.upper()\n') - # so we need to check if there is an attribute called run - if hasattr(result, "run"): - result = result.run # type: ignore - elif hasattr(result, "get_function"): - result = result.get_function() # type: ignore + if key == "func": + if not isinstance(result, types.FunctionType): + # func can be PythonFunction(code='\ndef upper_case(text: str) -> str:\n return text.upper()\n') + # so we need to check if there is an attribute called run + if hasattr(result, "run"): + result = result.run # type: ignore + elif hasattr(result, "get_function"): + result = result.get_function() # type: ignore + elif inspect.iscoroutinefunction(result): + self.params["coroutine"] = result + else: + # turn result which is a function into a coroutine + # so that it can be awaited + self.params["coroutine"] = sync_to_async(result) self.params[key] = result elif isinstance(value, list) and all( diff --git a/src/backend/langflow/utils/util.py b/src/backend/langflow/utils/util.py index b31a3bed1..eddd59ce1 100644 --- a/src/backend/langflow/utils/util.py +++ b/src/backend/langflow/utils/util.py @@ -1,3 +1,5 @@ +import asyncio +from functools import partial, wraps import importlib import inspect import re @@ -301,3 +303,17 @@ def update_verbose(d: dict, new_value: bool) -> dict: elif k == "verbose": d[k] = new_value return d + + +def sync_to_async(func): + """ + Decorator to convert a sync function to an async function. + """ + + @wraps(func) + async def async_wrapper(*args, **kwargs): + loop = asyncio.get_event_loop() + func_call = partial(func, *args, **kwargs) + return await loop.run_in_executor(None, func_call) + + return async_wrapper diff --git a/src/backend/langflow/utils/validate.py b/src/backend/langflow/utils/validate.py index d1353bd77..59d22a143 100644 --- a/src/backend/langflow/utils/validate.py +++ b/src/backend/langflow/utils/validate.py @@ -155,7 +155,7 @@ def create_function(code, function_name): exec_globals[function_name] = locals()[function_name] # Return a function that imports necessary modules and calls the target function - def wrapped_function(*args, **kwargs): + async def wrapped_function(*args, **kwargs): for module_name, module in exec_globals.items(): if isinstance(module, type(importlib)): globals()[module_name] = module From 57826f12482c8cbd980e66469393d63e45584498 Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Tue, 25 Apr 2023 20:26:17 -0300 Subject: [PATCH 2/3] fix(api/chat.py): catch and log exceptions in websocket endpoint fix(api/chat_manager.py): remove async from ChatHistory.add_message and on_chat_history_update fix(interface/run.py): remove async from async_get_result_and_steps refactor(utils/util.py): remove unused code and simplify sync_to_async decorator --- src/backend/langflow/api/chat.py | 8 ++++- src/backend/langflow/api/chat_manager.py | 38 +++++++++++++----------- src/backend/langflow/interface/run.py | 10 +++---- src/backend/langflow/utils/util.py | 4 +-- 4 files changed, 34 insertions(+), 26 deletions(-) diff --git a/src/backend/langflow/api/chat.py b/src/backend/langflow/api/chat.py index d5c2dc879..e25d0d2f1 100644 --- a/src/backend/langflow/api/chat.py +++ b/src/backend/langflow/api/chat.py @@ -1,6 +1,7 @@ from fastapi import APIRouter, WebSocket from langflow.api.chat_manager import ChatManager +from langflow.utils.logger import logger router = APIRouter() chat_manager = ChatManager() @@ -9,4 +10,9 @@ chat_manager = ChatManager() @router.websocket("/chat/{client_id}") async def websocket_endpoint(client_id: str, websocket: WebSocket): """Websocket endpoint for chat.""" - await chat_manager.handle_websocket(client_id, websocket) + try: + await chat_manager.handle_websocket(client_id, websocket) + except Exception as e: + # Log stack trace + logger.exception(e) + raise e diff --git a/src/backend/langflow/api/chat_manager.py b/src/backend/langflow/api/chat_manager.py index 8f407a791..5b6f25eff 100644 --- a/src/backend/langflow/api/chat_manager.py +++ b/src/backend/langflow/api/chat_manager.py @@ -8,11 +8,12 @@ import json from langchain.llms import OpenAI, AzureOpenAI from langchain.chat_models import ChatOpenAI, AzureChatOpenAI from langflow.api.schemas import ChatMessage, ChatResponse, FileResponse -from langflow.cache.manager import AsyncSubject +from langflow.cache.manager import AsyncSubject, Subject from langchain.callbacks.base import AsyncCallbackManager from langflow.api.callback import StreamingLLMCallbackHandler from langflow.interface.run import ( async_get_result_and_steps, + get_result_and_steps, load_or_build_langchain_object, ) from langflow.utils.logger import logger @@ -20,21 +21,23 @@ from langflow.cache import cache_manager from PIL.Image import Image -class ChatHistory(AsyncSubject): +class ChatHistory(Subject): def __init__(self): super().__init__() self.history: Dict[str, List[ChatMessage]] = defaultdict(list) - async def add_message(self, client_id: str, message: ChatMessage): + def add_message(self, client_id: str, message: ChatMessage): """Add a message to the chat history.""" self.history[client_id].append(message) - await self.notify() + self.notify() - def get_history(self, client_id: str) -> List[ChatMessage]: + def get_history(self, client_id: str, filter=True) -> List[ChatMessage]: """Get the chat history for a client.""" if history := self.history.get(client_id, []): - return [msg for msg in history if msg.type not in ["start", "stream"]] + if filter: + return [msg for msg in history if msg.type not in ["start", "stream"]] + return history else: return [] @@ -47,11 +50,11 @@ class ChatManager: self.cache_manager = cache_manager self.cache_manager.attach(self.update) - async def on_chat_history_update(self): + def on_chat_history_update(self): """Send the last chat message to the client.""" client_id = self.cache_manager.current_client_id if client_id in self.active_connections: - chat_response = self.chat_history.get_history(client_id)[-1] + chat_response = self.chat_history.get_history(client_id, filter=False)[-1] if chat_response.is_bot: # Process FileResponse if isinstance(chat_response, FileResponse): @@ -61,8 +64,11 @@ class ChatManager: elif chat_response.data_type == "image": # Base64 encode the image chat_response.data = pil_to_base64(chat_response.data) + # get event loop + loop = asyncio.get_event_loop() - await self.send_json(client_id, chat_response) + coroutine = self.send_json(client_id, chat_response) + asyncio.run_coroutine_threadsafe(coroutine, loop) def update(self): if self.cache_manager.current_client_id in self.active_connections: @@ -75,10 +81,8 @@ class ChatManager: data_type=self.last_cached_object_dict["type"], ) - asyncio.create_task( - self.chat_history.add_message( - self.cache_manager.current_client_id, chat_response - ) + self.chat_history.add_message( + self.cache_manager.current_client_id, chat_response ) async def connect(self, client_id: str, websocket: WebSocket): @@ -100,11 +104,11 @@ class ChatManager: # Process the graph data and chat message chat_message = payload.pop("message", "") chat_message = ChatMessage(message=chat_message) - await self.chat_history.add_message(client_id, chat_message) + self.chat_history.add_message(client_id, chat_message) graph_data = payload start_resp = ChatResponse(message=None, type="start", intermediate_steps="") - await self.chat_history.add_message(client_id, start_resp) + self.chat_history.add_message(client_id, start_resp) is_first_message = len(self.chat_history.get_history(client_id=client_id)) == 0 # Generate result and thought @@ -127,7 +131,7 @@ class ChatManager: intermediate_steps=intermediate_steps or "", type="end", ) - await self.chat_history.add_message(client_id, response) + self.chat_history.add_message(client_id, response) async def handle_websocket(self, client_id: str, websocket: WebSocket): await self.connect(client_id, websocket) @@ -173,7 +177,7 @@ async def process_graph( # Generate result and thought try: logger.debug("Generating result and thought") - result, intermediate_steps = await async_get_result_and_steps( + result, intermediate_steps = get_result_and_steps( langchain_object, chat_message.message or "" ) logger.debug("Generated result and intermediate_steps") diff --git a/src/backend/langflow/interface/run.py b/src/backend/langflow/interface/run.py index 5fb4f0045..c823ba531 100644 --- a/src/backend/langflow/interface/run.py +++ b/src/backend/langflow/interface/run.py @@ -240,7 +240,7 @@ def get_result_and_steps(langchain_object, message: str): return result, thought -async def async_get_result_and_steps(langchain_object, message: str): +def async_get_result_and_steps(langchain_object, message: str): """Get result and thought from extracted json""" try: if hasattr(langchain_object, "verbose"): @@ -267,10 +267,10 @@ async def async_get_result_and_steps(langchain_object, message: str): with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer): try: - if hasattr(langchain_object, "acall"): - output = await langchain_object.acall(chat_input) - else: - output = langchain_object(chat_input) + # if hasattr(langchain_object, "acall"): + # output = await langchain_object.acall(chat_input) + # else: + output = langchain_object(chat_input) except ValueError as exc: # make the error message more informative logger.debug(f"Error: {str(exc)}") diff --git a/src/backend/langflow/utils/util.py b/src/backend/langflow/utils/util.py index eddd59ce1..080137c26 100644 --- a/src/backend/langflow/utils/util.py +++ b/src/backend/langflow/utils/util.py @@ -312,8 +312,6 @@ def sync_to_async(func): @wraps(func) async def async_wrapper(*args, **kwargs): - loop = asyncio.get_event_loop() - func_call = partial(func, *args, **kwargs) - return await loop.run_in_executor(None, func_call) + return func(*args, **kwargs) return async_wrapper From 42a17e3aaf94eeafff2e8f54f3c728a3dec473d1 Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Tue, 25 Apr 2023 20:50:49 -0300 Subject: [PATCH 3/3] refactor(chat_manager.py): remove redundant or condition in line 129 feat(chat_manager.py): add strip() method to intermediate_steps to remove leading/trailing whitespaces --- src/backend/langflow/api/chat_manager.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/backend/langflow/api/chat_manager.py b/src/backend/langflow/api/chat_manager.py index 5b6f25eff..b71f652d1 100644 --- a/src/backend/langflow/api/chat_manager.py +++ b/src/backend/langflow/api/chat_manager.py @@ -126,9 +126,10 @@ class ChatManager: logger.exception(e) raise e # Send a response back to the frontend, if needed + intermediate_steps = intermediate_steps or "" response = ChatResponse( message=result or "", - intermediate_steps=intermediate_steps or "", + intermediate_steps=intermediate_steps.strip(), type="end", ) self.chat_history.add_message(client_id, response)