From d978ae543886b8879006e58058f41aae46c04711 Mon Sep 17 00:00:00 2001
From: Gabriel Almeida
Date: Wed, 19 Apr 2023 11:02:32 -0300
Subject: [PATCH 1/9] refactor(cache): move cache-related functions to base.py
 module

feat(cache): add support for pandas and PIL Image objects caching
fix(interface): import cache-related functions from base.py module
test(cache): update import statements in cache-related test file
---
 src/backend/langflow/api/base.py           |  4 ++
 src/backend/langflow/cache/__init__.py     |  1 +
 .../langflow/cache/{utils.py => base.py}   | 38 ++++++++++++++++++-
 src/backend/langflow/graph/base.py         |  2 +-
 src/backend/langflow/interface/run.py      |  2 +-
 tests/test_cache.py                        |  2 +-
 6 files changed, 44 insertions(+), 5 deletions(-)
 rename src/backend/langflow/cache/{utils.py => base.py} (80%)

diff --git a/src/backend/langflow/api/base.py b/src/backend/langflow/api/base.py
index 084e04d65..8cddc52e4 100644
--- a/src/backend/langflow/api/base.py
+++ b/src/backend/langflow/api/base.py
@@ -3,6 +3,10 @@ from pydantic import BaseModel, validator
 from langflow.graph.utils import extract_input_variables_from_prompt
 
 
+class CacheResponse(BaseModel):
+    data: dict
+
+
 class Code(BaseModel):
     code: str
 
diff --git a/src/backend/langflow/cache/__init__.py b/src/backend/langflow/cache/__init__.py
index e69de29bb..f7aac380b 100644
--- a/src/backend/langflow/cache/__init__.py
+++ b/src/backend/langflow/cache/__init__.py
@@ -0,0 +1 @@
+from langflow.cache.base import add_pandas, add_image, get  # noqa
diff --git a/src/backend/langflow/cache/utils.py b/src/backend/langflow/cache/base.py
similarity index 80%
rename from src/backend/langflow/cache/utils.py
rename to src/backend/langflow/cache/base.py
index 310f3be80..d96614218 100644
--- a/src/backend/langflow/cache/utils.py
+++ b/src/backend/langflow/cache/base.py
@@ -1,14 +1,18 @@
-import base64
 import contextlib
 import functools
 import hashlib
+
 import json
 import os
 import tempfile
 from collections import OrderedDict
 from pathlib import Path
+from typing import Any
 
+from PIL import Image
+import dill
+import pandas as pd  # type: ignore
 
-import dill  # type: ignore
+CACHE = {}
 
 
 def create_cache_folder(func):
@@ -147,3 +151,33 @@ def load_cache(hash_val):
         with cache_path.open("rb") as cache_file:
             return dill.load(cache_file)
     return None
+
+
+def add_pandas(name: str, obj: Any):
+    if isinstance(obj, (pd.DataFrame, pd.Series)):
+        CACHE[name] = {"obj": obj, "type": "pandas"}
+    else:
+        raise ValueError("Object is not a pandas DataFrame or Series")
+
+
+def add_image(name: str, obj: Any):
+    if isinstance(obj, Image.Image):
+        CACHE[name] = {"obj": obj, "type": "image"}
+    else:
+        raise ValueError("Object is not a PIL Image")
+
+
+def get(name: str):
+    return CACHE.get(name, {}).get("obj", None)
+
+
+# get last added item
+def get_last():
+    obj_dict = list(CACHE.values())[-1]
+    if obj_dict["type"] == "pandas":
+        # return a csv string
+        return obj_dict["obj"].to_csv()
+    elif obj_dict["type"] == "image":
+        # return a base64 encoded string
+        return base64.b64encode(obj_dict["obj"].tobytes()).decode("utf-8")
+    return obj_dict["obj"]
diff --git a/src/backend/langflow/graph/base.py b/src/backend/langflow/graph/base.py
index ff586c6da..012250739 100644
--- a/src/backend/langflow/graph/base.py
+++ b/src/backend/langflow/graph/base.py
@@ -9,7 +9,7 @@ import warnings
 from copy import deepcopy
 from typing import Any, Dict, List, Optional
 
-from langflow.cache import utils as cache_utils
+from langflow.cache import base as cache_utils
 from langflow.graph.constants import DIRECT_TYPES
 from langflow.interface import loading
 from langflow.interface.listing import ALL_TYPES_DICT
diff --git a/src/backend/langflow/interface/run.py b/src/backend/langflow/interface/run.py
index 300e09e01..89fd7d784 100644
--- a/src/backend/langflow/interface/run.py
+++ b/src/backend/langflow/interface/run.py
@@ -2,7 +2,7 @@ import contextlib
 import io
 from typing import Any, Dict
 
-from langflow.cache.utils import compute_dict_hash, load_cache, memoize_dict
+from langflow.cache.base import compute_dict_hash, load_cache, memoize_dict
 from langflow.graph.graph import Graph
 from langflow.interface import loading
 from langflow.utils.logger import logger
diff --git a/tests/test_cache.py b/tests/test_cache.py
index 9c6ad30e3..131e015f3 100644
--- a/tests/test_cache.py
+++ b/tests/test_cache.py
@@ -3,7 +3,7 @@ import tempfile
 from pathlib import Path
 
 import pytest
-from langflow.cache.utils import PREFIX, save_cache
+from langflow.cache.base import PREFIX, save_cache
 from langflow.interface.run import load_langchain_object
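A note on the new module-level cache introduced in this patch: it is a plain dict keyed by name, storing type-tagged entries. A minimal usage sketch follows (the DataFrame is illustrative; note also that `get_last` still calls `base64` even though this patch removes the `base64` import — that import is restored in PATCH 3/9):

    import pandas as pd

    from langflow.cache import add_pandas, get

    df = pd.DataFrame({"a": [1, 2, 3]})
    add_pandas("my_df", df)    # stored as {"obj": df, "type": "pandas"}
    assert get("my_df") is df  # get() returns the cached object itself

    try:
        add_pandas("oops", object())  # anything but a DataFrame/Series is rejected
    except ValueError as err:
        print(err)  # Object is not a pandas DataFrame or Series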
From e4d0a39b0b5a54f05203c740781ff2e8895ba3e0 Mon Sep 17 00:00:00 2001
From: Gabriel Almeida
Date: Wed, 19 Apr 2023 11:49:17 -0300
Subject: [PATCH 2/9] refactor: remove unnecessary build and push scripts and
 Dockerfiles

The removed files were unnecessary and were removed to simplify the project.
---
 build_and_push                | 11 --------
 src/backend/build.Dockerfile  | 52 -----------------------------------
 src/backend/build_and_push    |  6 ----
 src/backend/run               |  8 ------
 src/frontend/build.Dockerfile |  5 ----
 src/frontend/build_and_push   | 11 --------
 6 files changed, 93 deletions(-)
 delete mode 100755 build_and_push
 delete mode 100644 src/backend/build.Dockerfile
 delete mode 100755 src/backend/build_and_push
 delete mode 100755 src/backend/run
 delete mode 100644 src/frontend/build.Dockerfile
 delete mode 100755 src/frontend/build_and_push

diff --git a/build_and_push b/build_and_push
deleted file mode 100755
index e9c9edf14..000000000
--- a/build_and_push
+++ /dev/null
@@ -1,11 +0,0 @@
-#! /bin/bash
-
-cd src/frontend
-docker build -t logspace/frontend_build -f build.Dockerfile .
-cd ../backend
-docker build -t logspace/backend_build -f build.Dockerfile .
-
-cd ../../
-VERSION=$(toml get --toml-path pyproject.toml tool.poetry.version)
-docker build --build-arg VERSION=$VERSION -t ibiscp/langflow:$VERSION .
-docker push ibiscp/langflow:$VERSION
diff --git a/src/backend/build.Dockerfile b/src/backend/build.Dockerfile
deleted file mode 100644
index 42452c691..000000000
--- a/src/backend/build.Dockerfile
+++ /dev/null
@@ -1,52 +0,0 @@
-# `python-base` sets up all our shared environment variables
-FROM python:3.10-slim
-
-# python
-ENV PYTHONUNBUFFERED=1 \
-    # prevents python creating .pyc files
-    PYTHONDONTWRITEBYTECODE=1 \
-    \
-    # pip
-    PIP_NO_CACHE_DIR=off \
-    PIP_DISABLE_PIP_VERSION_CHECK=on \
-    PIP_DEFAULT_TIMEOUT=100 \
-    \
-    # poetry
-    # https://python-poetry.org/docs/configuration/#using-environment-variables
-    POETRY_VERSION=1.4.0 \
-    # make poetry install to this location
-    POETRY_HOME="/opt/poetry" \
-    # make poetry create the virtual environment in the project's root
-    # it gets named `.venv`
-    POETRY_VIRTUALENVS_IN_PROJECT=true \
-    # do not ask any interactive question
-    POETRY_NO_INTERACTION=1 \
-    \
-    # paths
-    # this is where our requirements + virtual environment will live
-    PYSETUP_PATH="/opt/pysetup" \
-    VENV_PATH="/opt/pysetup/.venv"
-
-# prepend poetry and venv to path
-ENV PATH="$POETRY_HOME/bin:$VENV_PATH/bin:$PATH"
-
-RUN apt-get update \
-    && apt-get install --no-install-recommends -y \
-    # deps for installing poetry
-    curl \
-    # deps for building python deps
-    build-essential libpq-dev
-
-# install poetry - respects $POETRY_VERSION & $POETRY_HOME
-RUN curl -sSL https://install.python-poetry.org | python3 -
-
-# copy project requirement files here to ensure they will be cached.
-WORKDIR /app
-COPY poetry.lock pyproject.toml ./
-COPY langflow/ ./langflow
-
-# poetry install
-RUN poetry install --without dev
-
-# build wheel
-RUN poetry build -f wheel
diff --git a/src/backend/build_and_push b/src/backend/build_and_push
deleted file mode 100755
index d70366d72..000000000
--- a/src/backend/build_and_push
+++ /dev/null
@@ -1,6 +0,0 @@
-#! /bin/bash
-
-docker build -t logspace/backend_build -f build.Dockerfile .
-VERSION=$(toml get --toml-path pyproject.toml tool.poetry.version)
-docker build --build-arg VERSION=$VERSION -t ibiscp/langflow:$VERSION .
-docker push ibiscp/langflow:$VERSION
diff --git a/src/backend/run b/src/backend/run
deleted file mode 100755
index c9ba24768..000000000
--- a/src/backend/run
+++ /dev/null
@@ -1,8 +0,0 @@
-#! /bin/bash
-
-poetry remove langchain
-docker build -t logspace/backend_build -f build.Dockerfile .
-VERSION=$(toml get --toml-path pyproject.toml tool.poetry.version)
-docker build --build-arg VERSION=$VERSION -t ibiscp/langflow:$VERSION .
-docker run -p 5003:80 -d ibiscp/langflow:$VERSION
-poetry add --editable ../../../langchain
diff --git a/src/frontend/build.Dockerfile b/src/frontend/build.Dockerfile
deleted file mode 100644
index af2335b5c..000000000
--- a/src/frontend/build.Dockerfile
+++ /dev/null
@@ -1,5 +0,0 @@
-FROM node:14-alpine
-WORKDIR /app
-COPY . /app
-RUN npm install
-RUN npm run build
\ No newline at end of file
diff --git a/src/frontend/build_and_push b/src/frontend/build_and_push
deleted file mode 100755
index f223dd343..000000000
--- a/src/frontend/build_and_push
+++ /dev/null
@@ -1,11 +0,0 @@
-#! /bin/bash
-
-# Read the contents of the JSON file
-json=$(cat package.json)
-
-# Extract the value of the "version" field using jq
-VERSION=$(echo "$json" | jq -r '.version')
-
-docker build -t logspace/frontend_build -f build.Dockerfile .
-docker build --build-arg VERSION=$VERSION -t ibiscp/langflow_frontend:$VERSION .
-docker push ibiscp/langflow_frontend:$VERSION
logger.debug("Loaded langchain object") + + if langchain_object is None: + # Raise user facing error + raise ValueError( + "There was an error loading the langchain_object. Please, check all the nodes and try again." + ) + + # Generate result and thought + logger.debug("Generating result and thought") + result, intermediate_steps = get_result_and_steps( + langchain_object, chat_message.message + ) + + logger.debug("Generated result and intermediate_steps") + # Save the message to chat history + self.chat_history.add_message(client_id, chat_message) + + # Send a response back to the frontend, if needed + response = ChatResponse( + sender="bot", + message=result or "", + intermediate_steps=intermediate_steps or "", + type="end", + ) + await self.send_json(client_id, response.dict()) + + async def handle_websocket(self, client_id: str, websocket: WebSocket): + await self.connect(client_id, websocket) + try: + chat_history = self.chat_history.get_history(client_id) + await websocket.send_text(json.dumps(chat_history)) + + while True: + json_payload = await websocket.receive_text() + payload = json.loads(json_payload) + await self.process_message(client_id, payload) + except Exception as e: + # Handle any exceptions that might occur + print(f"Error: {e}") + finally: + self.disconnect(client_id) diff --git a/src/backend/langflow/api/schemas.py b/src/backend/langflow/api/schemas.py new file mode 100644 index 000000000..588c35287 --- /dev/null +++ b/src/backend/langflow/api/schemas.py @@ -0,0 +1,29 @@ +from typing import Any +from pydantic import BaseModel, validator + + +class ChatMessage(BaseModel): + """Chat message schema.""" + + sender: str + message: str + + @validator("sender") + def sender_must_be_bot_or_you(cls, v): + if v not in ["bot", "you"]: + raise ValueError("sender must be bot or you") + return v + + +class ChatResponse(ChatMessage): + """Chat response schema.""" + + intermediate_steps: str + type: str + data: Any = None + + @validator("type") + def validate_message_type(cls, v): + if v not in ["start", "stream", "end", "error", "info"]: + raise ValueError("type must be start, stream, end, error or info") + return v diff --git a/src/backend/langflow/cache/base.py b/src/backend/langflow/cache/base.py index d96614218..9dd5c1780 100644 --- a/src/backend/langflow/cache/base.py +++ b/src/backend/langflow/cache/base.py @@ -1,3 +1,4 @@ +import base64 import contextlib import functools import hashlib diff --git a/src/backend/langflow/interface/run.py b/src/backend/langflow/interface/run.py index 89fd7d784..f8920724a 100644 --- a/src/backend/langflow/interface/run.py +++ b/src/backend/langflow/interface/run.py @@ -86,7 +86,7 @@ def process_graph(data_graph: Dict[str, Any]): # Generate result and thought logger.debug("Generating result and thought") - result, thought = get_result_and_thought_using_graph(langchain_object, message) + result, thought = get_result_and_steps(langchain_object, message) logger.debug("Generated result and thought") # Save langchain_object to cache @@ -117,7 +117,7 @@ def process_graph_cached(data_graph: Dict[str, Any]): # Generate result and thought logger.debug("Generating result and thought") - result, thought = get_result_and_thought_using_graph(langchain_object, message) + result, thought = get_result_and_steps(langchain_object, message) logger.debug("Generated result and thought") return {"result": str(result), "thought": thought.strip()} @@ -183,7 +183,7 @@ def fix_memory_inputs(langchain_object): update_memory_keys(langchain_object, possible_new_mem_key) -def 
get_result_and_thought_using_graph(langchain_object, message: str): +def get_result_and_steps(langchain_object, message: str): """Get result and thought from extracted json""" try: if hasattr(langchain_object, "verbose"): diff --git a/src/backend/langflow/main.py b/src/backend/langflow/main.py index 176e46236..c1f2decd5 100644 --- a/src/backend/langflow/main.py +++ b/src/backend/langflow/main.py @@ -3,6 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware from langflow.api.endpoints import router as endpoints_router from langflow.api.validate import router as validate_router +from langflow.api.chat import router as chat_router def create_app(): @@ -23,6 +24,7 @@ def create_app(): app.include_router(endpoints_router) app.include_router(validate_router) + app.include_router(chat_router) return app diff --git a/tests/conftest.py b/tests/conftest.py index e6eb3562f..15da0d1ef 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,8 @@ +import json from pathlib import Path +from typing import AsyncGenerator +from httpx import AsyncClient + import pytest from fastapi.testclient import TestClient @@ -21,6 +25,15 @@ def get_text(): """ +@pytest.fixture() +async def async_client() -> AsyncGenerator: + from langflow.main import create_app + + app = create_app() + async with AsyncClient(app=app, base_url="http://testserver") as client: + yield client + + # Create client fixture for FastAPI @pytest.fixture(scope="module") def client(): @@ -30,3 +43,37 @@ def client(): with TestClient(app) as client: yield client + + +def get_graph(_type="basic"): + """Get a graph from a json file""" + from langflow.graph.graph import Graph + + if _type == "basic": + path = pytest.BASIC_EXAMPLE_PATH + elif _type == "complex": + path = pytest.COMPLEX_EXAMPLE_PATH + elif _type == "openapi": + path = pytest.OPENAPI_EXAMPLE_PATH + + with open(path, "r") as f: + flow_graph = json.load(f) + data_graph = flow_graph["data"] + nodes = data_graph["nodes"] + edges = data_graph["edges"] + return Graph(nodes, edges) + + +@pytest.fixture +def basic_graph(): + return get_graph() + + +@pytest.fixture +def complex_graph(): + return get_graph("complex") + + +@pytest.fixture +def openapi_graph(): + return get_graph("openapi") diff --git a/tests/test_graph.py b/tests/test_graph.py index c007bdd78..76451fe6a 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -15,7 +15,7 @@ from langflow.graph.nodes import ( ToolNode, WrapperNode, ) -from langflow.interface.run import get_result_and_thought_using_graph +from langflow.interface.run import get_result_and_steps from langflow.utils.payload import build_json, get_root_node # Test cases for the graph module @@ -24,38 +24,6 @@ from langflow.utils.payload import build_json, get_root_node # BASIC_EXAMPLE_PATH, COMPLEX_EXAMPLE_PATH, OPENAPI_EXAMPLE_PATH -def get_graph(_type="basic"): - """Get a graph from a json file""" - if _type == "basic": - path = pytest.BASIC_EXAMPLE_PATH - elif _type == "complex": - path = pytest.COMPLEX_EXAMPLE_PATH - elif _type == "openapi": - path = pytest.OPENAPI_EXAMPLE_PATH - - with open(path, "r") as f: - flow_graph = json.load(f) - data_graph = flow_graph["data"] - nodes = data_graph["nodes"] - edges = data_graph["edges"] - return Graph(nodes, edges) - - -@pytest.fixture -def basic_graph(): - return get_graph() - - -@pytest.fixture -def complex_graph(): - return get_graph("complex") - - -@pytest.fixture -def openapi_graph(): - return get_graph("openapi") - - def get_node_by_type(graph, node_type: Type[Node]) -> Union[Node, None]: """Get a 
node by type""" return next((node for node in graph.nodes if isinstance(node, node_type)), None) @@ -441,7 +409,7 @@ def test_get_result_and_thought(basic_graph): # now build again and check if FakeListLLM was used # Get the result and thought - result, thought = get_result_and_thought_using_graph(langchain_object, message) + result, thought = get_result_and_steps(langchain_object, message) # The result should be a str assert isinstance(result, str) # The thought should be a Thought diff --git a/tests/test_websocket.py b/tests/test_websocket.py new file mode 100644 index 000000000..9ce20bc45 --- /dev/null +++ b/tests/test_websocket.py @@ -0,0 +1,30 @@ +import json + + +def test_websocket_connection(client): + with client.websocket_connect("/ws") as websocket: + assert websocket.client == client + assert websocket.url.path == "/ws" + + +def test_chat_history(client): + chat_history = ["Test message 1", "Test message 2"] + + with client.websocket_connect("/ws") as websocket: + received_history = websocket.receive_text() + received_history = json.loads(received_history) + + assert received_history == chat_history + + +def test_send_message(client, basic_graph): + with client.websocket_connect("/ws") as websocket: + # Send the JSON payload through the WebSocket connection + websocket.send_text(basic_graph) + + # Receive and parse the response from the server + response = websocket.receive_text() + response = json.loads(response) + + # Test that the response is as expected + assert response == "Your response message here" From 9b1f86b68165c689c90924271647ca20f4daa01c Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Wed, 19 Apr 2023 14:20:05 -0300 Subject: [PATCH 4/9] refactor(chatComponent): simplify conditional statement in Chat component's error handling logic --- src/frontend/src/components/chatComponent/index.tsx | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/frontend/src/components/chatComponent/index.tsx b/src/frontend/src/components/chatComponent/index.tsx index 11516f004..1febb7539 100644 --- a/src/frontend/src/components/chatComponent/index.tsx +++ b/src/frontend/src/components/chatComponent/index.tsx @@ -93,7 +93,7 @@ export default function Chat({ flow, reactFlowInstance }: ChatType) { (errors: Array, t) => errors.concat( (template[t].required && template[t].show) && - (!template[t].value || template[t].value === "") && + (!template[t].value && template[t].value !== false && template[t].value === "") && !reactFlowInstance .getEdges() .some( @@ -102,12 +102,11 @@ export default function Chat({ flow, reactFlowInstance }: ChatType) { e.targetHandle.split("|")[2] === n.id ) ? [ - `${type} is missing ${ - template.display_name - ? template.display_name - : snakeToNormalCase(template[t].name) - }.`, - ] + `${type} is missing ${template.display_name + ? 
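A detail worth flagging in this patch: `process_message` constructs `ChatMessage(sender="user", ...)`, but the `sender` validator only accepts "bot" or "you", so that line raises at runtime (PATCH 6/9 switches it to "you"). A small sketch of the validators' behaviour, assuming only that the schemas import as written:

    from pydantic import ValidationError

    from langflow.api.schemas import ChatMessage, ChatResponse

    ChatMessage(sender="you", message="hi")  # ok
    resp = ChatResponse(sender="bot", message="done", intermediate_steps="", type="end")
    print(resp.dict())

    try:
        ChatMessage(sender="user", message="hi")  # what process_message passes above
    except ValidationError as err:
        print(err)  # sender must be bot or you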
From 9b1f86b68165c689c90924271647ca20f4daa01c Mon Sep 17 00:00:00 2001
From: Gabriel Almeida
Date: Wed, 19 Apr 2023 14:20:05 -0300
Subject: [PATCH 4/9] refactor(chatComponent): simplify conditional statement
 in Chat component's error handling logic

---
 src/frontend/src/components/chatComponent/index.tsx | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/src/frontend/src/components/chatComponent/index.tsx b/src/frontend/src/components/chatComponent/index.tsx
index 11516f004..1febb7539 100644
--- a/src/frontend/src/components/chatComponent/index.tsx
+++ b/src/frontend/src/components/chatComponent/index.tsx
@@ -93,7 +93,7 @@ export default function Chat({ flow, reactFlowInstance }: ChatType) {
       (errors: Array<string>, t) =>
         errors.concat(
           (template[t].required && template[t].show) &&
-            (!template[t].value || template[t].value === "") &&
+            (!template[t].value && template[t].value !== false && template[t].value === "") &&
            !reactFlowInstance
             .getEdges()
             .some(
                e.targetHandle.split("|")[2] === n.id
             )
             ? [
-                `${type} is missing ${
-                  template.display_name
-                    ? template.display_name
-                    : snakeToNormalCase(template[t].name)
-                }.`,
-              ]
+                `${type} is missing ${template.display_name
+                  ? template.display_name
+                  : snakeToNormalCase(template[t].name)
+                }.`,
+              ]
             : []
         ),
       [] as string[]
logger.debug("Generating result and thought") + result, intermediate_steps = await async_get_result_and_steps( + langchain_object, chat_message.message or "" + ) + logger.debug("Generated result and intermediate_steps") + except Exception as e: + # Log stack trace + logger.exception(e) + error_resp = ChatResponse( + sender="bot", message=str(e), type="error", intermediate_steps="" + ) + await self.send_json(client_id, error_resp) + return # Send a response back to the frontend, if needed response = ChatResponse( sender="bot", @@ -80,16 +88,16 @@ class ChatManager: intermediate_steps=intermediate_steps or "", type="end", ) - await self.send_json(client_id, response.dict()) + await self.send_json(client_id, response) async def handle_websocket(self, client_id: str, websocket: WebSocket): await self.connect(client_id, websocket) try: chat_history = self.chat_history.get_history(client_id) - await websocket.send_text(json.dumps(chat_history)) + await websocket.send_json(json.dumps(chat_history)) while True: - json_payload = await websocket.receive_text() + json_payload = await websocket.receive_json() payload = json.loads(json_payload) await self.process_message(client_id, payload) except Exception as e: diff --git a/src/backend/langflow/api/schemas.py b/src/backend/langflow/api/schemas.py index 588c35287..fd9ef0816 100644 --- a/src/backend/langflow/api/schemas.py +++ b/src/backend/langflow/api/schemas.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Union from pydantic import BaseModel, validator @@ -6,7 +6,7 @@ class ChatMessage(BaseModel): """Chat message schema.""" sender: str - message: str + message: Union[str, None] = None @validator("sender") def sender_must_be_bot_or_you(cls, v): @@ -21,6 +21,7 @@ class ChatResponse(ChatMessage): intermediate_steps: str type: str data: Any = None + data_type: str = "" @validator("type") def validate_message_type(cls, v): diff --git a/src/backend/langflow/interface/run.py b/src/backend/langflow/interface/run.py index f8920724a..110d3827f 100644 --- a/src/backend/langflow/interface/run.py +++ b/src/backend/langflow/interface/run.py @@ -31,7 +31,7 @@ def load_or_build_langchain_object(data_graph, is_first_message=False): return build_langchain_object_with_caching(data_graph) -@memoize_dict(maxsize=1) +@memoize_dict(maxsize=10) def build_langchain_object_with_caching(data_graph): """ Build langchain object from data_graph. 
@@ -235,6 +235,61 @@ def get_result_and_steps(langchain_object, message: str): return result, thought +async def async_get_result_and_steps(langchain_object, message: str): + """Get result and thought from extracted json""" + try: + if hasattr(langchain_object, "verbose"): + langchain_object.verbose = True + chat_input = None + memory_key = "" + if hasattr(langchain_object, "memory") and langchain_object.memory is not None: + memory_key = langchain_object.memory.memory_key + + if hasattr(langchain_object, "input_keys"): + for key in langchain_object.input_keys: + if key not in [memory_key, "chat_history"]: + chat_input = {key: message} + else: + chat_input = message # type: ignore + + if hasattr(langchain_object, "return_intermediate_steps"): + # https://github.com/hwchase17/langchain/issues/2068 + # Deactivating until we have a frontend solution + # to display intermediate steps + langchain_object.return_intermediate_steps = False + + fix_memory_inputs(langchain_object) + + with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer): + try: + if hasattr(langchain_object, "acall"): + output = await langchain_object.acall(chat_input) + else: + output = langchain_object(chat_input) + except ValueError as exc: + # make the error message more informative + logger.debug(f"Error: {str(exc)}") + output = langchain_object.run(chat_input) + + intermediate_steps = ( + output.get("intermediate_steps", []) if isinstance(output, dict) else [] + ) + + result = ( + output.get(langchain_object.output_keys[0]) + if isinstance(output, dict) + else output + ) + if intermediate_steps: + thought = format_intermediate_steps(intermediate_steps) + else: + thought = output_buffer.getvalue() + + except Exception as exc: + raise ValueError(f"Error: {str(exc)}") from exc + return result, thought + + def get_result_and_thought(extracted_json: Dict[str, Any], message: str): """Get result and thought from extracted json""" try: From 0a630cd70daa593f74ea3dbe700058dc811d906d Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Wed, 19 Apr 2023 22:23:31 -0300 Subject: [PATCH 6/9] refactor(chat_manager.py): move process_graph function outside of ChatManager class test(websocket.py): add tests for websocket connection, chat history, and sending messages --- src/backend/langflow/api/chat_manager.py | 46 +++++++++++------ tests/test_websocket.py | 63 ++++++++++++++++-------- 2 files changed, 74 insertions(+), 35 deletions(-) diff --git a/src/backend/langflow/api/chat_manager.py b/src/backend/langflow/api/chat_manager.py index 384de4d12..7ca04abf7 100644 --- a/src/backend/langflow/api/chat_manager.py +++ b/src/backend/langflow/api/chat_manager.py @@ -41,13 +41,13 @@ class ChatManager: async def send_json(self, client_id: str, message: ChatMessage): websocket = self.active_connections[client_id] self.chat_history.add_message(client_id, message) - await websocket.send_json(message.dict()) + await websocket.send_json(json.dumps(message.dict())) async def process_message(self, client_id: str, payload: Dict): # Process the graph data and chat message chat_message = payload.pop("message", "") - chat_message = ChatMessage(sender="user", message=chat_message) + chat_message = ChatMessage(sender="you", message=chat_message) self.chat_history.add_message(client_id, chat_message) graph_data = payload @@ -57,22 +57,14 @@ class ChatManager: await self.send_json(client_id, start_resp) is_first_message = len(self.chat_history.get_history(client_id=client_id)) == 0 - langchain_object = 
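The `acall`-with-fallback dispatch that `async_get_result_and_steps` adds above is a generally useful pattern; a self-contained sketch (the chain classes here are stand-ins, not langflow or langchain code):

    import asyncio


    class SyncChain:
        def __call__(self, inputs):
            return {"output": f"sync:{inputs}"}


    class AsyncChain(SyncChain):
        async def acall(self, inputs):
            return {"output": f"async:{inputs}"}


    async def run_chain(chain, inputs):
        # Prefer the coroutine API when the object exposes one, as the patch does.
        if hasattr(chain, "acall"):
            return await chain.acall(inputs)
        return chain(inputs)


    print(asyncio.run(run_chain(AsyncChain(), "hi")))  # {'output': 'async:hi'}
    print(asyncio.run(run_chain(SyncChain(), "hi")))   # {'output': 'sync:hi'}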
From 0a630cd70daa593f74ea3dbe700058dc811d906d Mon Sep 17 00:00:00 2001
From: Gabriel Almeida
Date: Wed, 19 Apr 2023 22:23:31 -0300
Subject: [PATCH 6/9] refactor(chat_manager.py): move process_graph function
 outside of ChatManager class

test(websocket.py): add tests for websocket connection, chat history, and
sending messages
---
 src/backend/langflow/api/chat_manager.py | 46 +++++++++++------
 tests/test_websocket.py                  | 63 ++++++++++++++++--------
 2 files changed, 74 insertions(+), 35 deletions(-)

diff --git a/src/backend/langflow/api/chat_manager.py b/src/backend/langflow/api/chat_manager.py
index 384de4d12..7ca04abf7 100644
--- a/src/backend/langflow/api/chat_manager.py
+++ b/src/backend/langflow/api/chat_manager.py
@@ -41,13 +41,13 @@ class ChatManager:
     async def send_json(self, client_id: str, message: ChatMessage):
         websocket = self.active_connections[client_id]
         self.chat_history.add_message(client_id, message)
-        await websocket.send_json(message.dict())
+        await websocket.send_json(json.dumps(message.dict()))
 
     async def process_message(self, client_id: str, payload: Dict):
         # Process the graph data and chat message
 
         chat_message = payload.pop("message", "")
-        chat_message = ChatMessage(sender="user", message=chat_message)
+        chat_message = ChatMessage(sender="you", message=chat_message)
         self.chat_history.add_message(client_id, chat_message)
 
         graph_data = payload
@@ -57,22 +57,14 @@ class ChatManager:
         start_resp = ChatResponse(
             sender="bot", message=None, type="start", intermediate_steps=""
         )
         await self.send_json(client_id, start_resp)
 
         is_first_message = len(self.chat_history.get_history(client_id=client_id)) == 0
-        langchain_object = load_or_build_langchain_object(graph_data, is_first_message)
-        logger.debug("Loaded langchain object")
-
-        if langchain_object is None:
-            # Raise user facing error
-            raise ValueError(
-                "There was an error loading the langchain_object. Please, check all the nodes and try again."
-            )
-
         # Generate result and thought
         try:
             logger.debug("Generating result and thought")
-            result, intermediate_steps = await async_get_result_and_steps(
-                langchain_object, chat_message.message or ""
+            result, intermediate_steps = await process_graph(
+                graph_data=graph_data,
+                is_first_message=is_first_message,
+                chat_message=chat_message,
             )
-            logger.debug("Generated result and intermediate_steps")
         except Exception as e:
             # Log stack trace
             logger.exception(e)
@@ -105,3 +97,29 @@ class ChatManager:
             print(f"Error: {e}")
         finally:
             self.disconnect(client_id)
+
+
+async def process_graph(
+    graph_data: Dict, is_first_message: bool, chat_message: ChatMessage
+):
+    langchain_object = load_or_build_langchain_object(graph_data, is_first_message)
+    logger.debug("Loaded langchain object")
+
+    if langchain_object is None:
+        # Raise user facing error
+        raise ValueError(
+            "There was an error loading the langchain_object. Please, check all the nodes and try again."
+        )
+
+    # Generate result and thought
+    try:
+        logger.debug("Generating result and thought")
+        result, intermediate_steps = await async_get_result_and_steps(
+            langchain_object, chat_message.message or ""
+        )
+        logger.debug("Generated result and intermediate_steps")
+        return result, intermediate_steps
+    except Exception as e:
+        # Log stack trace
+        logger.exception(e)
+        raise e
diff --git a/tests/test_websocket.py b/tests/test_websocket.py
index 9ce20bc45..41405867f 100644
--- a/tests/test_websocket.py
+++ b/tests/test_websocket.py
@@ -1,30 +1,51 @@
 import json
+from unittest.mock import patch
 
+from langflow.api.schemas import ChatMessage
+from fastapi.testclient import TestClient
 
-def test_websocket_connection(client):
-    with client.websocket_connect("/ws") as websocket:
-        assert websocket.client == client
-        assert websocket.url.path == "/ws"
 
+def test_websocket_connection(client: TestClient):
+    with client.websocket_connect("/ws/test_client") as websocket:
+        assert websocket.scope["client"] == ["testclient", 50000]
+        assert websocket.scope["path"] == "/ws/test_client"
 
-def test_chat_history(client):
-    chat_history = ["Test message 1", "Test message 2"]
 
-    with client.websocket_connect("/ws") as websocket:
-        received_history = websocket.receive_text()
-        received_history = json.loads(received_history)
+def test_chat_history(client: TestClient):
+    chat_history = []
 
-        assert received_history == chat_history
+    # Mock the process_graph function to return a specific value
+    with patch("langflow.api.chat_manager.process_graph") as mock_process_graph:
+        mock_process_graph.return_value = ("Hello, I'm a mock response!", "")
 
+        with client.websocket_connect("/ws/test_client") as websocket:
+            # First message should be the history
+            history = websocket.receive_json()
+            assert json.loads(history) == []  # Empty history
 
-def test_send_message(client, basic_graph):
-    with client.websocket_connect("/ws") as websocket:
-        # Send the JSON payload through the WebSocket connection
-        websocket.send_text(basic_graph)
+            # Send a message
+            payload = {"message": "Hello"}
+            websocket.send_json(json.dumps(payload))
 
-        # Receive and parse the response from the server
-        response = websocket.receive_text()
-        response = json.loads(response)
+            # Receive the response from the server
+            response = websocket.receive_json()
+            assert json.loads(response) == {
+                "sender": "bot",
+                "message": None,
+                "intermediate_steps": "",
+                "type": "start",
+                "data": None,
+                "data_type": "",
+            }
 
-        # Test that the response is as expected
-        assert response == "Your response message here"
+            # Send another message
+            payload = {"message": "How are you?"}
+            websocket.send_json(json.dumps(payload))
+
+            # Receive the response from the server
+            response = websocket.receive_json()
+            assert json.loads(response) == {
+                "sender": "bot",
+                "message": "Hello, I'm a mock response!",
+                "intermediate_steps": "",
+                "type": "end",
+                "data": None,
+                "data_type": "",
+            }
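The rewritten tests above rely on `unittest.mock.patch` replacing the coroutine function `process_graph` with an `AsyncMock` (automatic on Python 3.8+), so the awaited call resolves to `return_value`. A stand-alone illustration of that mechanism (toy functions, not langflow code):

    import asyncio
    from unittest.mock import patch


    async def fetch():
        return "real"


    async def caller():
        return await fetch()


    with patch(f"{__name__}.fetch") as mock_fetch:  # patch swaps in an AsyncMock here
        mock_fetch.return_value = "mocked"
        assert asyncio.run(caller()) == "mocked"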
From 3da30cc5bffd0681d9ebb29e424287e45cad6798 Mon Sep 17 00:00:00 2001
From: Gabriel Almeida
Date: Thu, 20 Apr 2023 11:09:11 -0300
Subject: [PATCH 7/9] refactor(cache): move cache functionality to a separate
 class

feat(cache): add support for multiple clients and context manager to set client_id
feat(cache): add observer pattern to notify on cache changes
feat(cache): add async observer pattern to notify on cache changes in async functions
feat(cache): add methods to add pandas DataFrame or Series and PIL Image to cache
feat(cache): add method to get an object from cache by key
feat(cache): add method to get the last added item in cache
---
 src/backend/langflow/cache/__init__.py |   2 +-
 src/backend/langflow/cache/base.py     |  30 ------
 src/backend/langflow/cache/manager.py  | 126 +++++++++++++++++++++++++
 3 files changed, 127 insertions(+), 31 deletions(-)
 create mode 100644 src/backend/langflow/cache/manager.py

diff --git a/src/backend/langflow/cache/__init__.py b/src/backend/langflow/cache/__init__.py
index f7aac380b..583d5ac6d 100644
--- a/src/backend/langflow/cache/__init__.py
+++ b/src/backend/langflow/cache/__init__.py
@@ -1 +1 @@
-from langflow.cache.base import add_pandas, add_image, get  # noqa
+from langflow.cache.manager import cache_manager  # noqa
diff --git a/src/backend/langflow/cache/base.py b/src/backend/langflow/cache/base.py
index 9dd5c1780..ba250da6b 100644
--- a/src/backend/langflow/cache/base.py
+++ b/src/backend/langflow/cache/base.py
@@ -152,33 +152,3 @@ def load_cache(hash_val):
         with cache_path.open("rb") as cache_file:
             return dill.load(cache_file)
     return None
-
-
-def add_pandas(name: str, obj: Any):
-    if isinstance(obj, (pd.DataFrame, pd.Series)):
-        CACHE[name] = {"obj": obj, "type": "pandas"}
-    else:
-        raise ValueError("Object is not a pandas DataFrame or Series")
-
-
-def add_image(name: str, obj: Any):
-    if isinstance(obj, Image.Image):
-        CACHE[name] = {"obj": obj, "type": "image"}
-    else:
-        raise ValueError("Object is not a PIL Image")
-
-
-def get(name: str):
-    return CACHE.get(name, {}).get("obj", None)
-
-
-# get last added item
-def get_last():
-    obj_dict = list(CACHE.values())[-1]
-    if obj_dict["type"] == "pandas":
-        # return a csv string
-        return obj_dict["obj"].to_csv()
-    elif obj_dict["type"] == "image":
-        # return a base64 encoded string
-        return base64.b64encode(obj_dict["obj"].tobytes()).decode("utf-8")
-    return obj_dict["obj"]
diff --git a/src/backend/langflow/cache/manager.py b/src/backend/langflow/cache/manager.py
new file mode 100644
index 000000000..ba34a3a8d
--- /dev/null
+++ b/src/backend/langflow/cache/manager.py
@@ -0,0 +1,126 @@
+from contextlib import contextmanager
+from typing import Any, Awaitable, Callable, List
+from PIL import Image
+import pandas as pd
+
+
+class Subject:
+    """Base class for implementing the observer pattern."""
+
+    def __init__(self):
+        self.observers: List[Callable[[], None]] = []
+
+    def attach(self, observer: Callable[[], None]):
+        """Attach an observer to the subject."""
+        self.observers.append(observer)
+
+    def detach(self, observer: Callable[[], None]):
+        """Detach an observer from the subject."""
+        self.observers.remove(observer)
+
+    def notify(self):
+        """Notify all observers about an event."""
+        for observer in self.observers:
+            if observer is None:
+                continue
+            observer()
+
+
+class AsyncSubject:
+    """Base class for implementing the async observer pattern."""
+
+    def __init__(self):
+        self.observers: List[Callable[[], Awaitable]] = []
+
+    def attach(self, observer: Callable[[], Awaitable]):
+        """Attach an observer to the subject."""
+        self.observers.append(observer)
+
+    def detach(self, observer: Callable[[], Awaitable]):
+        """Detach an observer from the subject."""
+        self.observers.remove(observer)
+
+    async def notify(self):
+        """Notify all observers about an event."""
+        for observer in self.observers:
+            if observer is None:
+                continue
+            await observer()
+
+
+class CacheManager(Subject):
+    """Manages cache for different clients and notifies observers on changes."""
+
+    def __init__(self):
+        super().__init__()
+        self.CACHE = {}
+        self.current_client_id = None
+
+    @contextmanager
+    def set_client_id(self, client_id: str):
+        """
+        Context manager to set the current client_id and associated cache.
+
+        Args:
+            client_id (str): The client identifier.
+        """
+        previous_client_id = self.current_client_id
+        self.current_client_id = client_id
+        self.current_cache = self.CACHE.setdefault(client_id, {})
+        try:
+            yield
+        finally:
+            self.current_client_id = previous_client_id
+            self.current_cache = self.CACHE.get(self.current_client_id, {})
+
+    def add_pandas(self, name: str, obj: Any):
+        """
+        Add a pandas DataFrame or Series to the current client's cache.
+
+        Args:
+            name (str): The cache key.
+            obj (Any): The pandas DataFrame or Series object.
+        """
+        if isinstance(obj, (pd.DataFrame, pd.Series)):
+            self.current_cache[name] = {"obj": obj, "type": "pandas"}
+            self.notify()
+        else:
+            raise ValueError("Object is not a pandas DataFrame or Series")
+
+    def add_image(self, name: str, obj: Any):
+        """
+        Add a PIL Image to the current client's cache.
+
+        Args:
+            name (str): The cache key.
+            obj (Any): The PIL Image object.
+        """
+        if isinstance(obj, Image.Image):
+            self.current_cache[name] = {"obj": obj, "type": "image"}
+            self.notify()
+        else:
+            raise ValueError("Object is not a PIL Image")
+
+    def get(self, name: str):
+        """
+        Get an object from the current client's cache.
+
+        Args:
+            name (str): The cache key.
+
+        Returns:
+            The cached object associated with the given cache key.
+        """
+        return self.current_cache[name]
+
+    def get_last(self):
+        """
+        Get the last added item in the current client's cache.
+
+        Returns:
+            The last added item in the cache.
+        """
+        return list(self.current_cache.values())[-1]
+
+
+cache_manager = CacheManager()
From 5169c0bc27960b678961fa94fed066a648c8efa4 Mon Sep 17 00:00:00 2001
From: Gabriel Almeida
Date: Thu, 20 Apr 2023 11:09:42 -0300
Subject: [PATCH 8/9] feat(chat_manager.py): add support for sending file
 responses

fix(schemas.py): add validation for file response type and data type
test(test_websocket.py): remove data and data_type fields from ChatResponse
messages in tests
---
 src/backend/langflow/api/chat_manager.py | 69 +++++++++++++++++++++---
 src/backend/langflow/api/schemas.py      | 20 +++++--
 tests/test_websocket.py                  |  4 --
 3 files changed, 77 insertions(+), 16 deletions(-)

diff --git a/src/backend/langflow/api/chat_manager.py b/src/backend/langflow/api/chat_manager.py
index 7ca04abf7..8dcaf05ac 100644
--- a/src/backend/langflow/api/chat_manager.py
+++ b/src/backend/langflow/api/chat_manager.py
@@ -1,22 +1,30 @@
+import asyncio
+import base64
+from io import BytesIO
 from typing import Dict, List
 from collections import defaultdict
 from fastapi import WebSocket
 import json
-from langflow.api.schemas import ChatMessage, ChatResponse
+from langflow.api.schemas import ChatMessage, ChatResponse, FileResponse
+from langflow.cache.manager import AsyncSubject
 
 from langflow.interface.run import (
     async_get_result_and_steps,
     load_or_build_langchain_object,
 )
 from langflow.utils.logger import logger
+from langflow.cache import cache_manager
+from PIL.Image import Image
 
 
-class ChatHistory:
+class ChatHistory(AsyncSubject):
     def __init__(self):
+        super().__init__()
         self.history: Dict[str, List[ChatMessage]] = defaultdict(list)
 
-    def add_message(self, client_id: str, message: ChatMessage):
+    async def add_message(self, client_id: str, message: ChatMessage):
         self.history[client_id].append(message)
+        await self.notify()
 
     def get_history(self, client_id: str) -> List[ChatMessage]:
         return self.history[client_id]
@@ -26,6 +34,44 @@ class ChatManager:
     def __init__(self):
         self.active_connections: Dict[str, WebSocket] = {}
         self.chat_history = ChatHistory()
+        self.chat_history.attach(self.on_chat_history_update)
+        self.cache_manager = cache_manager
+        self.cache_manager.attach(self.update)
+
+    async def on_chat_history_update(self):
+        """Send the last chat message to the client."""
+        client_id = self.cache_manager.current_client_id
+        if client_id in self.active_connections:
+            chat_response = self.chat_history.get_history(client_id)[-1]
+            if chat_response.sender == "bot":
+                # Process FileResponse
+                if isinstance(chat_response, FileResponse):
+                    # If data_type is pandas, convert to csv
+                    if chat_response.data_type == "pandas":
+                        chat_response.data = chat_response.data.to_csv()
+                    elif chat_response.data_type == "image":
+                        # Base64 encode the image
+                        chat_response.data = pil_to_base64(chat_response.data)
+
+                await self.send_json(client_id, chat_response)
+
+    def update(self):
+        if self.cache_manager.current_client_id in self.active_connections:
+            self.last_cached_object_dict = self.cache_manager.get_last()
+            # Add a new ChatResponse with the data
+            chat_response = FileResponse(
+                sender="bot",
+                message=None,
+                type="file",
+                data=self.last_cached_object_dict["obj"],
+                data_type=self.last_cached_object_dict["type"],
+            )
+
+            asyncio.create_task(
+                self.chat_history.add_message(
+                    self.cache_manager.current_client_id, chat_response
+                )
+            )
 
     async def connect(self, client_id: str, websocket: WebSocket):
         await websocket.accept()
@@ -40,7 +86,6 @@ class ChatManager:
 
     async def send_json(self, client_id: str, message: ChatMessage):
         websocket = self.active_connections[client_id]
-        self.chat_history.add_message(client_id, message)
         await websocket.send_json(json.dumps(message.dict()))
 
     async def process_message(self, client_id: str, payload: Dict):
         # Process the graph data and chat message
         chat_message = payload.pop("message", "")
         chat_message = ChatMessage(sender="you", message=chat_message)
-        self.chat_history.add_message(client_id, chat_message)
+        await self.chat_history.add_message(client_id, chat_message)
 
         graph_data = payload
         start_resp = ChatResponse(
             sender="bot", message=None, type="start", intermediate_steps=""
         )
-        await self.send_json(client_id, start_resp)
+        await self.chat_history.add_message(client_id, start_resp)
 
         is_first_message = len(self.chat_history.get_history(client_id=client_id)) == 0
         # Generate result and thought
         try:
             intermediate_steps=intermediate_steps or "",
             type="end",
         )
-        await self.send_json(client_id, response)
+        await self.chat_history.add_message(client_id, response)
 
     async def handle_websocket(self, client_id: str, websocket: WebSocket):
         await self.connect(client_id, websocket)
         try:
             chat_history = self.chat_history.get_history(client_id)
             await websocket.send_json(json.dumps(chat_history))
 
             while True:
                 json_payload = await websocket.receive_json()
                 payload = json.loads(json_payload)
-                await self.process_message(client_id, payload)
+                with self.cache_manager.set_client_id(client_id):
+                    await self.process_message(client_id, payload)
         except Exception as e:
             # Handle any exceptions that might occur
             print(f"Error: {e}")
         finally:
             self.disconnect(client_id)
@@ -123,3 +169,10 @@ async def process_graph(
         # Log stack trace
         logger.exception(e)
         raise e
+
+
+def pil_to_base64(image: Image) -> str:
+    buffered = BytesIO()
+    image.save(buffered, format="PNG")
+    img_str = base64.b64encode(buffered.getvalue())
+    return img_str.decode("utf-8")
diff --git a/src/backend/langflow/api/schemas.py b/src/backend/langflow/api/schemas.py
index fd9ef0816..1aefe5c8e 100644
--- a/src/backend/langflow/api/schemas.py
+++ b/src/backend/langflow/api/schemas.py
@@ -20,11 +20,23 @@ class ChatResponse(ChatMessage):
 
     intermediate_steps: str
     type: str
-    data: Any = None
-    data_type: str = ""
 
     @validator("type")
     def validate_message_type(cls, v):
-        if v not in ["start", "stream", "end", "error", "info"]:
-            raise ValueError("type must be start, stream, end, error or info")
+        if v not in ["start", "stream", "end", "error", "info", "file"]:
+            raise ValueError("type must be start, stream, end, error, info, or file")
         return v
+
+
+class FileResponse(ChatMessage):
+    """File response schema."""
+
+    data: Any
+    data_type: str
+    type: str = "file"
+
+    @validator("data_type")
+    def validate_data_type(cls, v):
+        if v not in ["image", "csv"]:
+            raise ValueError("data_type must be image or csv")
+        return v
diff --git a/tests/test_websocket.py b/tests/test_websocket.py
index 41405867f..74e147075 100644
--- a/tests/test_websocket.py
+++ b/tests/test_websocket.py
@@ -32,8 +32,6 @@ def test_chat_history(client: TestClient):
                 "message": None,
                 "intermediate_steps": "",
                 "type": "start",
-                "data": None,
-                "data_type": "",
             }
 
             # Send another message
             payload = {"message": "How are you?"}
@@ -46,6 +44,4 @@ def test_chat_history(client: TestClient):
                 "message": "Hello, I'm a mock response!",
                 "intermediate_steps": "",
                 "type": "end",
-                "data": None,
-                "data_type": "",
             }
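Unlike the `Image.tobytes()` call left over from PATCH 1/9, the `pil_to_base64` helper added here serialises the image as PNG first, so the base64 string round-trips cleanly. A quick check (reproducing the helper so the snippet runs on its own):

    import base64
    from io import BytesIO

    from PIL import Image


    def pil_to_base64(image: Image.Image) -> str:
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")


    img = Image.new("RGB", (2, 2), color="red")
    decoded = Image.open(BytesIO(base64.b64decode(pil_to_base64(img))))
    assert decoded.size == (2, 2)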
From ebc1f6a0dfef42e389c0db01678cb569d6bad809 Mon Sep 17 00:00:00 2001
From: Gabriel Almeida
Date: Sun, 23 Apr 2023 14:31:21 -0300
Subject: [PATCH 9/9] feat(api): add callback handler for streaming LLM
 responses

Add a new file `callback.py` that contains a new class
`StreamingLLMCallbackHandler` that inherits from `AsyncCallbackHandler`.
This class handles streaming LLM responses. It has a constructor that
takes a `websocket` parameter and sets it as an instance variable. It
also has an `on_llm_new_token` method that takes a `token` parameter and
sends a `ChatResponse` object to the `websocket` instance variable.

Update `chat_manager.py` to import the new `StreamingLLMCallbackHandler`
class. Add a new function `try_setting_streaming_options` that takes a
`langchain_object` and a `websocket` parameter. This function checks if
the `llm` attribute of the `langchain_object` is an instance of `OpenAI`,
`ChatOpenAI`, `AzureOpenAI`, or `AzureChatOpenAI`. If it is, it sets the
---
 src/backend/langflow/api/callback.py     | 18 ++++++++++++
 src/backend/langflow/api/chat.py         |  5 ++--
 src/backend/langflow/api/chat_manager.py | 36 ++++++++++++++++++++++--
 3 files changed, 54 insertions(+), 5 deletions(-)
 create mode 100644 src/backend/langflow/api/callback.py

diff --git a/src/backend/langflow/api/callback.py b/src/backend/langflow/api/callback.py
new file mode 100644
index 000000000..47a8d945c
--- /dev/null
+++ b/src/backend/langflow/api/callback.py
@@ -0,0 +1,18 @@
+from typing import Any
+from langchain.callbacks.base import AsyncCallbackHandler
+
+from langflow.api.schemas import ChatResponse
+
+
+# https://github.com/hwchase17/chat-langchain/blob/master/callback.py
+class StreamingLLMCallbackHandler(AsyncCallbackHandler):
+    """Callback handler for streaming LLM responses."""
+
+    def __init__(self, websocket):
+        self.websocket = websocket
+
+    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        resp = ChatResponse(
+            sender="bot", message=token, type="stream", intermediate_steps=""
+        )
+        await self.websocket.send_json(resp.dict())
diff --git a/src/backend/langflow/api/chat.py b/src/backend/langflow/api/chat.py
index 11b861c77..b2da73d52 100644
--- a/src/backend/langflow/api/chat.py
+++ b/src/backend/langflow/api/chat.py
@@ -1,5 +1,4 @@
 from fastapi import APIRouter, WebSocket
-from uuid import uuid4
 
 from langflow.api.chat_manager import ChatManager
 
@@ -7,6 +6,8 @@ router = APIRouter()
 chat_manager = ChatManager()
 
 
-@router.websocket("/ws/{client_id}")
+@router.websocket("/chat/{client_id}")
 async def websocket_endpoint(client_id: str, websocket: WebSocket):
     await chat_manager.handle_websocket(client_id, websocket)
+
+
diff --git a/src/backend/langflow/api/chat_manager.py b/src/backend/langflow/api/chat_manager.py
index 8dcaf05ac..5ce7d2452 100644
--- a/src/backend/langflow/api/chat_manager.py
+++ b/src/backend/langflow/api/chat_manager.py
@@ -5,9 +5,12 @@ from typing import Dict, List
 from collections import defaultdict
 from fastapi import WebSocket
 import json
+from langchain.llms import OpenAI, AzureOpenAI
+from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
 from langflow.api.schemas import ChatMessage, ChatResponse, FileResponse
 from langflow.cache.manager import AsyncSubject
-
+from langchain.callbacks.base import AsyncCallbackManager
+from langflow.api.callback import StreamingLLMCallbackHandler
 from langflow.interface.run import (
     async_get_result_and_steps,
     load_or_build_langchain_object,
@@ -90,7 +93,6 @@ class ChatManager:
 
     async def process_message(self, client_id: str, payload: Dict):
         # Process the graph data and chat message
-
         chat_message = payload.pop("message", "")
         chat_message = ChatMessage(sender="you", message=chat_message)
         await self.chat_history.add_message(client_id, chat_message)
@@ -105,10 +107,12 @@ class ChatManager:
         # Generate result and thought
         try:
             logger.debug("Generating result and thought")
+
             result, intermediate_steps = await process_graph(
                 graph_data=graph_data,
                 is_first_message=is_first_message,
                 chat_message=chat_message,
+                websocket=self.active_connections[client_id],
             )
         except Exception as e:
             # Log stack trace
@@ -129,6 +133,7 @@ class ChatManager:
 
     async def handle_websocket(self, client_id: str, websocket: WebSocket):
         await self.connect(client_id, websocket)
+
         try:
             chat_history = self.chat_history.get_history(client_id)
             await websocket.send_json(json.dumps(chat_history))
@@ -146,9 +151,13 @@ class ChatManager:
 
 
 async def process_graph(
-    graph_data: Dict, is_first_message: bool, chat_message: ChatMessage
+    graph_data: Dict,
+    is_first_message: bool,
+    chat_message: ChatMessage,
+    websocket: WebSocket,
 ):
     langchain_object = load_or_build_langchain_object(graph_data, is_first_message)
+    langchain_object = try_setting_streaming_options(langchain_object, websocket)
     logger.debug("Loaded langchain object")
@@ -171,6 +180,27 @@ async def process_graph(
         raise e
 
 
+def try_setting_streaming_options(langchain_object, websocket):
+    # If the LLM type is OpenAI or ChatOpenAI,
+    # set streaming to True
+    # First we need to find the LLM
+    llm = None
+    if hasattr(langchain_object, "llm"):
+        llm = langchain_object.llm
+    elif hasattr(langchain_object, "llm_chain") and hasattr(
+        langchain_object.llm_chain, "llm"
+    ):
+        llm = langchain_object.llm_chain.llm
+    if isinstance(llm, (OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI)):
+        llm.streaming = bool(hasattr(llm, "streaming"))
+
+    if hasattr(langchain_object, "callback_manager"):
+        stream_handler = StreamingLLMCallbackHandler(websocket)
+        stream_manager = AsyncCallbackManager([stream_handler])
+        langchain_object.callback_manager = stream_manager
+    return langchain_object
+
+
 def pil_to_base64(image: Image) -> str:
     buffered = BytesIO()
     image.save(buffered, format="PNG")
     img_str = base64.b64encode(buffered.getvalue())
     return img_str.decode("utf-8")
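The streaming handler simply forwards each token over the socket as a ChatResponse of type "stream". A minimal way to exercise it without a server (the fake websocket below is illustrative; it only needs an async `send_json`). Note also that `try_setting_streaming_options` sets `llm.streaming` to `bool(hasattr(llm, "streaming"))`, i.e. True whenever the attribute exists at all — presumably the intent was to enable streaming only when the LLM supports it.

    import asyncio

    from langflow.api.callback import StreamingLLMCallbackHandler


    class FakeWebSocket:
        async def send_json(self, payload):
            print("->", payload["type"], repr(payload["message"]))


    async def main():
        handler = StreamingLLMCallbackHandler(FakeWebSocket())
        for token in ["Hel", "lo", "!"]:
            await handler.on_llm_new_token(token)


    asyncio.run(main())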