diff --git a/build_and_push b/build_and_push deleted file mode 100755 index e9c9edf14..000000000 --- a/build_and_push +++ /dev/null @@ -1,11 +0,0 @@ -#! /bin/bash - -cd src/frontend -docker build -t logspace/frontend_build -f build.Dockerfile . -cd ../backend -docker build -t logspace/backend_build -f build.Dockerfile . - -cd ../../ -VERSION=$(toml get --toml-path pyproject.toml tool.poetry.version) -docker build --build-arg VERSION=$VERSION -t ibiscp/langflow:$VERSION . -docker push ibiscp/langflow:$VERSION diff --git a/src/backend/build.Dockerfile b/src/backend/build.Dockerfile deleted file mode 100644 index 42452c691..000000000 --- a/src/backend/build.Dockerfile +++ /dev/null @@ -1,52 +0,0 @@ -# `python-base` sets up all our shared environment variables -FROM python:3.10-slim - -# python -ENV PYTHONUNBUFFERED=1 \ - # prevents python creating .pyc files - PYTHONDONTWRITEBYTECODE=1 \ - \ - # pip - PIP_NO_CACHE_DIR=off \ - PIP_DISABLE_PIP_VERSION_CHECK=on \ - PIP_DEFAULT_TIMEOUT=100 \ - \ - # poetry - # https://python-poetry.org/docs/configuration/#using-environment-variables - POETRY_VERSION=1.4.0 \ - # make poetry install to this location - POETRY_HOME="/opt/poetry" \ - # make poetry create the virtual environment in the project's root - # it gets named `.venv` - POETRY_VIRTUALENVS_IN_PROJECT=true \ - # do not ask any interactive question - POETRY_NO_INTERACTION=1 \ - \ - # paths - # this is where our requirements + virtual environment will live - PYSETUP_PATH="/opt/pysetup" \ - VENV_PATH="/opt/pysetup/.venv" - -# prepend poetry and venv to path -ENV PATH="$POETRY_HOME/bin:$VENV_PATH/bin:$PATH" - -RUN apt-get update \ - && apt-get install --no-install-recommends -y \ - # deps for installing poetry - curl \ - # deps for building python deps - build-essential libpq-dev - -# install poetry - respects $POETRY_VERSION & $POETRY_HOME -RUN curl -sSL https://install.python-poetry.org | python3 - - -# copy project requirement files here to ensure they will be cached. -WORKDIR /app -COPY poetry.lock pyproject.toml ./ -COPY langflow/ ./langflow - -# poetry install -RUN poetry install --without dev - -# build wheel -RUN poetry build -f wheel diff --git a/src/backend/build_and_push b/src/backend/build_and_push deleted file mode 100755 index d70366d72..000000000 --- a/src/backend/build_and_push +++ /dev/null @@ -1,6 +0,0 @@ -#! /bin/bash - -docker build -t logspace/backend_build -f build.Dockerfile . -VERSION=$(toml get --toml-path pyproject.toml tool.poetry.version) -docker build --build-arg VERSION=$VERSION -t ibiscp/langflow:$VERSION . -docker push ibiscp/langflow:$VERSION diff --git a/src/backend/langflow/api/base.py b/src/backend/langflow/api/base.py index 084e04d65..8cddc52e4 100644 --- a/src/backend/langflow/api/base.py +++ b/src/backend/langflow/api/base.py @@ -3,6 +3,10 @@ from pydantic import BaseModel, validator from langflow.graph.utils import extract_input_variables_from_prompt +class CacheResponse(BaseModel): + data: dict + + class Code(BaseModel): code: str diff --git a/src/backend/langflow/api/callback.py b/src/backend/langflow/api/callback.py new file mode 100644 index 000000000..47a8d945c --- /dev/null +++ b/src/backend/langflow/api/callback.py @@ -0,0 +1,18 @@ +from typing import Any +from langchain.callbacks.base import AsyncCallbackHandler + +from langflow.api.schemas import ChatResponse + + +# https://github.com/hwchase17/chat-langchain/blob/master/callback.py +class StreamingLLMCallbackHandler(AsyncCallbackHandler): + """Callback handler for streaming LLM responses.""" + + def __init__(self, websocket): + self.websocket = websocket + + async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + resp = ChatResponse( + sender="bot", message=token, type="stream", intermediate_steps="" + ) + await self.websocket.send_json(resp.dict()) diff --git a/src/backend/langflow/api/chat.py b/src/backend/langflow/api/chat.py new file mode 100644 index 000000000..b2da73d52 --- /dev/null +++ b/src/backend/langflow/api/chat.py @@ -0,0 +1,13 @@ +from fastapi import APIRouter, WebSocket + +from langflow.api.chat_manager import ChatManager + +router = APIRouter() +chat_manager = ChatManager() + + +@router.websocket("/chat/{client_id}") +async def websocket_endpoint(client_id: str, websocket: WebSocket): + await chat_manager.handle_websocket(client_id, websocket) + + diff --git a/src/backend/langflow/api/chat_manager.py b/src/backend/langflow/api/chat_manager.py new file mode 100644 index 000000000..5ce7d2452 --- /dev/null +++ b/src/backend/langflow/api/chat_manager.py @@ -0,0 +1,208 @@ +import asyncio +import base64 +from io import BytesIO +from typing import Dict, List +from collections import defaultdict +from fastapi import WebSocket +import json +from langchain.llms import OpenAI, AzureOpenAI +from langchain.chat_models import ChatOpenAI, AzureChatOpenAI +from langflow.api.schemas import ChatMessage, ChatResponse, FileResponse +from langflow.cache.manager import AsyncSubject +from langchain.callbacks.base import AsyncCallbackManager +from langflow.api.callback import StreamingLLMCallbackHandler +from langflow.interface.run import ( + async_get_result_and_steps, + load_or_build_langchain_object, +) +from langflow.utils.logger import logger +from langflow.cache import cache_manager +from PIL.Image import Image + + +class ChatHistory(AsyncSubject): + def __init__(self): + super().__init__() + self.history: Dict[str, List[ChatMessage]] = defaultdict(list) + + async def add_message(self, client_id: str, message: ChatMessage): + self.history[client_id].append(message) + await self.notify() + + def get_history(self, client_id: str) -> List[ChatMessage]: + return self.history[client_id] + + +class ChatManager: + def __init__(self): + self.active_connections: Dict[str, WebSocket] = {} + self.chat_history = ChatHistory() + self.chat_history.attach(self.on_chat_history_update) + self.cache_manager = cache_manager + self.cache_manager.attach(self.update) + + async def on_chat_history_update(self): + """Send the last chat message to the client.""" + client_id = self.cache_manager.current_client_id + if client_id in self.active_connections: + chat_response = self.chat_history.get_history(client_id)[-1] + if chat_response.sender == "bot": + # Process FileResponse + if isinstance(chat_response, FileResponse): + # If data_type is pandas, convert to csv + if chat_response.data_type == "pandas": + chat_response.data = chat_response.data.to_csv() + elif chat_response.data_type == "image": + # Base64 encode the image + chat_response.data = pil_to_base64(chat_response.data) + + await self.send_json(client_id, chat_response) + + def update(self): + if self.cache_manager.current_client_id in self.active_connections: + self.last_cached_object_dict = self.cache_manager.get_last() + # Add a new ChatResponse with the data + chat_response = FileResponse( + sender="bot", + message=None, + type="file", + data=self.last_cached_object_dict["obj"], + data_type=self.last_cached_object_dict["type"], + ) + + asyncio.create_task( + self.chat_history.add_message( + self.cache_manager.current_client_id, chat_response + ) + ) + + async def connect(self, client_id: str, websocket: WebSocket): + await websocket.accept() + self.active_connections[client_id] = websocket + + def disconnect(self, client_id: str): + del self.active_connections[client_id] + + async def send_message(self, client_id: str, message: str): + websocket = self.active_connections[client_id] + await websocket.send_text(message) + + async def send_json(self, client_id: str, message: ChatMessage): + websocket = self.active_connections[client_id] + await websocket.send_json(json.dumps(message.dict())) + + async def process_message(self, client_id: str, payload: Dict): + # Process the graph data and chat message + chat_message = payload.pop("message", "") + chat_message = ChatMessage(sender="you", message=chat_message) + await self.chat_history.add_message(client_id, chat_message) + + graph_data = payload + start_resp = ChatResponse( + sender="bot", message=None, type="start", intermediate_steps="" + ) + await self.chat_history.add_message(client_id, start_resp) + + is_first_message = len(self.chat_history.get_history(client_id=client_id)) == 0 + # Generate result and thought + try: + logger.debug("Generating result and thought") + + result, intermediate_steps = await process_graph( + graph_data=graph_data, + is_first_message=is_first_message, + chat_message=chat_message, + websocket=self.active_connections[client_id], + ) + except Exception as e: + # Log stack trace + logger.exception(e) + error_resp = ChatResponse( + sender="bot", message=str(e), type="error", intermediate_steps="" + ) + await self.send_json(client_id, error_resp) + return + # Send a response back to the frontend, if needed + response = ChatResponse( + sender="bot", + message=result or "", + intermediate_steps=intermediate_steps or "", + type="end", + ) + await self.chat_history.add_message(client_id, response) + + async def handle_websocket(self, client_id: str, websocket: WebSocket): + await self.connect(client_id, websocket) + + try: + chat_history = self.chat_history.get_history(client_id) + await websocket.send_json(json.dumps(chat_history)) + + while True: + json_payload = await websocket.receive_json() + payload = json.loads(json_payload) + with self.cache_manager.set_client_id(client_id): + await self.process_message(client_id, payload) + except Exception as e: + # Handle any exceptions that might occur + print(f"Error: {e}") + finally: + self.disconnect(client_id) + + +async def process_graph( + graph_data: Dict, + is_first_message: bool, + chat_message: ChatMessage, + websocket: WebSocket, +): + langchain_object = load_or_build_langchain_object(graph_data, is_first_message) + langchain_object = try_setting_streaming_options(langchain_object, websocket) + logger.debug("Loaded langchain object") + + if langchain_object is None: + # Raise user facing error + raise ValueError( + "There was an error loading the langchain_object. Please, check all the nodes and try again." + ) + + # Generate result and thought + try: + logger.debug("Generating result and thought") + result, intermediate_steps = await async_get_result_and_steps( + langchain_object, chat_message.message or "" + ) + logger.debug("Generated result and intermediate_steps") + return result, intermediate_steps + except Exception as e: + # Log stack trace + logger.exception(e) + raise e + + +def try_setting_streaming_options(langchain_object, websocket): + # If the LLM type is OpenAI or ChatOpenAI, + # set streaming to True + # First we need to find the LLM + llm = None + if hasattr(langchain_object, "llm"): + llm = langchain_object.llm + elif hasattr(langchain_object, "llm_chain") and hasattr( + langchain_object.llm_chain, "llm" + ): + llm = langchain_object.llm_chain.llm + if isinstance(llm, (OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI)): + llm.streaming = bool(hasattr(llm, "streaming")) + + if hasattr(langchain_object, "callback_manager"): + stream_handler = StreamingLLMCallbackHandler(websocket) + stream_manager = AsyncCallbackManager([stream_handler]) + langchain_object.callback_manager = stream_manager + return langchain_object + + +def pil_to_base64(image: Image) -> str: + buffered = BytesIO() + image.save(buffered, format="PNG") + img_str = base64.b64encode(buffered.getvalue()) + return img_str.decode("utf-8") diff --git a/src/backend/langflow/api/schemas.py b/src/backend/langflow/api/schemas.py new file mode 100644 index 000000000..1aefe5c8e --- /dev/null +++ b/src/backend/langflow/api/schemas.py @@ -0,0 +1,42 @@ +from typing import Any, Union +from pydantic import BaseModel, validator + + +class ChatMessage(BaseModel): + """Chat message schema.""" + + sender: str + message: Union[str, None] = None + + @validator("sender") + def sender_must_be_bot_or_you(cls, v): + if v not in ["bot", "you"]: + raise ValueError("sender must be bot or you") + return v + + +class ChatResponse(ChatMessage): + """Chat response schema.""" + + intermediate_steps: str + type: str + + @validator("type") + def validate_message_type(cls, v): + if v not in ["start", "stream", "end", "error", "info", "file"]: + raise ValueError("type must be start, stream, end, error, info, or file") + return v + + +class FileResponse(ChatMessage): + """File response schema.""" + + data: Any + data_type: str + type: str = "file" + + @validator("data_type") + def validate_data_type(cls, v): + if v not in ["image", "csv"]: + raise ValueError("data_type must be image or csv") + return v diff --git a/src/backend/langflow/cache/__init__.py b/src/backend/langflow/cache/__init__.py index e69de29bb..583d5ac6d 100644 --- a/src/backend/langflow/cache/__init__.py +++ b/src/backend/langflow/cache/__init__.py @@ -0,0 +1 @@ +from langflow.cache.manager import cache_manager # noqa diff --git a/src/backend/langflow/cache/utils.py b/src/backend/langflow/cache/base.py similarity index 97% rename from src/backend/langflow/cache/utils.py rename to src/backend/langflow/cache/base.py index 310f3be80..ba250da6b 100644 --- a/src/backend/langflow/cache/utils.py +++ b/src/backend/langflow/cache/base.py @@ -2,13 +2,18 @@ import base64 import contextlib import functools import hashlib + import json import os import tempfile from collections import OrderedDict from pathlib import Path +from typing import Any +from PIL import Image +import dill +import pandas as pd # type: ignore -import dill # type: ignore +CACHE = {} def create_cache_folder(func): diff --git a/src/backend/langflow/cache/manager.py b/src/backend/langflow/cache/manager.py new file mode 100644 index 000000000..ba34a3a8d --- /dev/null +++ b/src/backend/langflow/cache/manager.py @@ -0,0 +1,126 @@ +from contextlib import contextmanager +from typing import Any, Awaitable, Callable, List +from PIL import Image +import pandas as pd + + +class Subject: + """Base class for implementing the observer pattern.""" + + def __init__(self): + self.observers: List[Callable[[], None]] = [] + + def attach(self, observer: Callable[[], None]): + """Attach an observer to the subject.""" + self.observers.append(observer) + + def detach(self, observer: Callable[[], None]): + """Detach an observer from the subject.""" + self.observers.remove(observer) + + def notify(self): + """Notify all observers about an event.""" + for observer in self.observers: + if observer is None: + continue + observer() + + +class AsyncSubject: + """Base class for implementing the async observer pattern.""" + + def __init__(self): + self.observers: List[Callable[[], Awaitable]] = [] + + def attach(self, observer: Callable[[], Awaitable]): + """Attach an observer to the subject.""" + self.observers.append(observer) + + def detach(self, observer: Callable[[], Awaitable]): + """Detach an observer from the subject.""" + self.observers.remove(observer) + + async def notify(self): + """Notify all observers about an event.""" + for observer in self.observers: + if observer is None: + continue + await observer() + + +class CacheManager(Subject): + """Manages cache for different clients and notifies observers on changes.""" + + def __init__(self): + super().__init__() + self.CACHE = {} + self.current_client_id = None + + @contextmanager + def set_client_id(self, client_id: str): + """ + Context manager to set the current client_id and associated cache. + + Args: + client_id (str): The client identifier. + """ + previous_client_id = self.current_client_id + self.current_client_id = client_id + self.current_cache = self.CACHE.setdefault(client_id, {}) + try: + yield + finally: + self.current_client_id = previous_client_id + self.current_cache = self.CACHE.get(self.current_client_id, {}) + + def add_pandas(self, name: str, obj: Any): + """ + Add a pandas DataFrame or Series to the current client's cache. + + Args: + name (str): The cache key. + obj (Any): The pandas DataFrame or Series object. + """ + if isinstance(obj, (pd.DataFrame, pd.Series)): + self.current_cache[name] = {"obj": obj, "type": "pandas"} + self.notify() + else: + raise ValueError("Object is not a pandas DataFrame or Series") + + def add_image(self, name: str, obj: Any): + """ + Add a PIL Image to the current client's cache. + + Args: + name (str): The cache key. + obj (Any): The PIL Image object. + """ + if isinstance(obj, Image.Image): + self.current_cache[name] = {"obj": obj, "type": "image"} + self.notify() + else: + raise ValueError("Object is not a PIL Image") + + def get(self, name: str): + """ + Get an object from the current client's cache. + + Args: + name (str): The cache key. + + Returns: + The cached object associated with the given cache key. + """ + return self.current_cache[name] + + def get_last(self): + """ + Get the last added item in the current client's cache. + + Returns: + The last added item in the cache. + """ + return list(self.current_cache.values())[-1] + + +cache_manager = CacheManager() diff --git a/src/backend/langflow/graph/base.py b/src/backend/langflow/graph/base.py index 6d998eed6..7bdbd3b8c 100644 --- a/src/backend/langflow/graph/base.py +++ b/src/backend/langflow/graph/base.py @@ -9,7 +9,7 @@ import warnings from copy import deepcopy from typing import Any, Dict, List, Optional -from langflow.cache import utils as cache_utils +from langflow.cache import base as cache_utils from langflow.graph.constants import DIRECT_TYPES from langflow.interface import loading from langflow.interface.listing import ALL_TYPES_DICT diff --git a/src/backend/langflow/interface/run.py b/src/backend/langflow/interface/run.py index deba28586..5fb4f0045 100644 --- a/src/backend/langflow/interface/run.py +++ b/src/backend/langflow/interface/run.py @@ -3,7 +3,7 @@ import io from typing import Any, Dict from chromadb.errors import NotEnoughElementsException -from langflow.cache.utils import compute_dict_hash, load_cache, memoize_dict +from langflow.cache.base import compute_dict_hash, load_cache, memoize_dict from langflow.graph.graph import Graph from langflow.interface import loading from langflow.utils.logger import logger @@ -32,7 +32,7 @@ def load_or_build_langchain_object(data_graph, is_first_message=False): return build_langchain_object_with_caching(data_graph) -@memoize_dict(maxsize=1) +@memoize_dict(maxsize=10) def build_langchain_object_with_caching(data_graph): """ Build langchain object from data_graph. @@ -87,7 +87,7 @@ def process_graph(data_graph: Dict[str, Any]): # Generate result and thought logger.debug("Generating result and thought") - result, thought = get_result_and_thought_using_graph(langchain_object, message) + result, thought = get_result_and_steps(langchain_object, message) logger.debug("Generated result and thought") # Save langchain_object to cache @@ -118,7 +118,7 @@ def process_graph_cached(data_graph: Dict[str, Any]): # Generate result and thought logger.debug("Generating result and thought") - result, thought = get_result_and_thought_using_graph(langchain_object, message) + result, thought = get_result_and_steps(langchain_object, message) logger.debug("Generated result and thought") return {"result": str(result), "thought": thought.strip()} @@ -184,7 +184,7 @@ def fix_memory_inputs(langchain_object): update_memory_keys(langchain_object, possible_new_mem_key) -def get_result_and_thought_using_graph(langchain_object, message: str): +def get_result_and_steps(langchain_object, message: str): """Get result and thought from extracted json""" try: if hasattr(langchain_object, "verbose"): @@ -240,6 +240,61 @@ def get_result_and_thought_using_graph(langchain_object, message: str): return result, thought +async def async_get_result_and_steps(langchain_object, message: str): + """Get result and thought from extracted json""" + try: + if hasattr(langchain_object, "verbose"): + langchain_object.verbose = True + chat_input = None + memory_key = "" + if hasattr(langchain_object, "memory") and langchain_object.memory is not None: + memory_key = langchain_object.memory.memory_key + + if hasattr(langchain_object, "input_keys"): + for key in langchain_object.input_keys: + if key not in [memory_key, "chat_history"]: + chat_input = {key: message} + else: + chat_input = message # type: ignore + + if hasattr(langchain_object, "return_intermediate_steps"): + # https://github.com/hwchase17/langchain/issues/2068 + # Deactivating until we have a frontend solution + # to display intermediate steps + langchain_object.return_intermediate_steps = False + + fix_memory_inputs(langchain_object) + + with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer): + try: + if hasattr(langchain_object, "acall"): + output = await langchain_object.acall(chat_input) + else: + output = langchain_object(chat_input) + except ValueError as exc: + # make the error message more informative + logger.debug(f"Error: {str(exc)}") + output = langchain_object.run(chat_input) + + intermediate_steps = ( + output.get("intermediate_steps", []) if isinstance(output, dict) else [] + ) + + result = ( + output.get(langchain_object.output_keys[0]) + if isinstance(output, dict) + else output + ) + if intermediate_steps: + thought = format_intermediate_steps(intermediate_steps) + else: + thought = output_buffer.getvalue() + + except Exception as exc: + raise ValueError(f"Error: {str(exc)}") from exc + return result, thought + + def get_result_and_thought(extracted_json: Dict[str, Any], message: str): """Get result and thought from extracted json""" try: diff --git a/src/backend/langflow/main.py b/src/backend/langflow/main.py index 176e46236..c1f2decd5 100644 --- a/src/backend/langflow/main.py +++ b/src/backend/langflow/main.py @@ -3,6 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware from langflow.api.endpoints import router as endpoints_router from langflow.api.validate import router as validate_router +from langflow.api.chat import router as chat_router def create_app(): @@ -23,6 +24,7 @@ def create_app(): app.include_router(endpoints_router) app.include_router(validate_router) + app.include_router(chat_router) return app diff --git a/src/backend/run b/src/backend/run deleted file mode 100755 index c9ba24768..000000000 --- a/src/backend/run +++ /dev/null @@ -1,8 +0,0 @@ -#! /bin/bash - -poetry remove langchain -docker build -t logspace/backend_build -f build.Dockerfile . -VERSION=$(toml get --toml-path pyproject.toml tool.poetry.version) -docker build --build-arg VERSION=$VERSION -t ibiscp/langflow:$VERSION . -docker run -p 5003:80 -d ibiscp/langflow:$VERSION -poetry add --editable ../../../langchain diff --git a/src/frontend/build.Dockerfile b/src/frontend/build.Dockerfile deleted file mode 100644 index af2335b5c..000000000 --- a/src/frontend/build.Dockerfile +++ /dev/null @@ -1,5 +0,0 @@ -FROM node:14-alpine -WORKDIR /app -COPY . /app -RUN npm install -RUN npm run build \ No newline at end of file diff --git a/src/frontend/build_and_push b/src/frontend/build_and_push deleted file mode 100755 index f223dd343..000000000 --- a/src/frontend/build_and_push +++ /dev/null @@ -1,11 +0,0 @@ -#! /bin/bash - -# Read the contents of the JSON file -json=$(cat package.json) - -# Extract the value of the "version" field using jq -VERSION=$(echo "$json" | jq -r '.version') - -docker build -t logspace/frontend_build -f build.Dockerfile . -docker build --build-arg VERSION=$VERSION -t ibiscp/langflow_frontend:$VERSION . -docker push ibiscp/langflow_frontend:$VERSION diff --git a/tests/conftest.py b/tests/conftest.py index e6eb3562f..15da0d1ef 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,8 @@ +import json from pathlib import Path +from typing import AsyncGenerator +from httpx import AsyncClient + import pytest from fastapi.testclient import TestClient @@ -21,6 +25,15 @@ def get_text(): """ +@pytest.fixture() +async def async_client() -> AsyncGenerator: + from langflow.main import create_app + + app = create_app() + async with AsyncClient(app=app, base_url="http://testserver") as client: + yield client + + # Create client fixture for FastAPI @pytest.fixture(scope="module") def client(): @@ -30,3 +43,37 @@ def client(): with TestClient(app) as client: yield client + + +def get_graph(_type="basic"): + """Get a graph from a json file""" + from langflow.graph.graph import Graph + + if _type == "basic": + path = pytest.BASIC_EXAMPLE_PATH + elif _type == "complex": + path = pytest.COMPLEX_EXAMPLE_PATH + elif _type == "openapi": + path = pytest.OPENAPI_EXAMPLE_PATH + + with open(path, "r") as f: + flow_graph = json.load(f) + data_graph = flow_graph["data"] + nodes = data_graph["nodes"] + edges = data_graph["edges"] + return Graph(nodes, edges) + + +@pytest.fixture +def basic_graph(): + return get_graph() + + +@pytest.fixture +def complex_graph(): + return get_graph("complex") + + +@pytest.fixture +def openapi_graph(): + return get_graph("openapi") diff --git a/tests/test_cache.py b/tests/test_cache.py index 9c6ad30e3..131e015f3 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -3,7 +3,7 @@ import tempfile from pathlib import Path import pytest -from langflow.cache.utils import PREFIX, save_cache +from langflow.cache.base import PREFIX, save_cache from langflow.interface.run import load_langchain_object diff --git a/tests/test_graph.py b/tests/test_graph.py index c007bdd78..76451fe6a 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -15,7 +15,7 @@ from langflow.graph.nodes import ( ToolNode, WrapperNode, ) -from langflow.interface.run import get_result_and_thought_using_graph +from langflow.interface.run import get_result_and_steps from langflow.utils.payload import build_json, get_root_node # Test cases for the graph module @@ -24,38 +24,6 @@ from langflow.utils.payload import build_json, get_root_node # BASIC_EXAMPLE_PATH, COMPLEX_EXAMPLE_PATH, OPENAPI_EXAMPLE_PATH -def get_graph(_type="basic"): - """Get a graph from a json file""" - if _type == "basic": - path = pytest.BASIC_EXAMPLE_PATH - elif _type == "complex": - path = pytest.COMPLEX_EXAMPLE_PATH - elif _type == "openapi": - path = pytest.OPENAPI_EXAMPLE_PATH - - with open(path, "r") as f: - flow_graph = json.load(f) - data_graph = flow_graph["data"] - nodes = data_graph["nodes"] - edges = data_graph["edges"] - return Graph(nodes, edges) - - -@pytest.fixture -def basic_graph(): - return get_graph() - - -@pytest.fixture -def complex_graph(): - return get_graph("complex") - - -@pytest.fixture -def openapi_graph(): - return get_graph("openapi") - - def get_node_by_type(graph, node_type: Type[Node]) -> Union[Node, None]: """Get a node by type""" return next((node for node in graph.nodes if isinstance(node, node_type)), None) @@ -441,7 +409,7 @@ def test_get_result_and_thought(basic_graph): # now build again and check if FakeListLLM was used # Get the result and thought - result, thought = get_result_and_thought_using_graph(langchain_object, message) + result, thought = get_result_and_steps(langchain_object, message) # The result should be a str assert isinstance(result, str) # The thought should be a Thought diff --git a/tests/test_websocket.py b/tests/test_websocket.py new file mode 100644 index 000000000..74e147075 --- /dev/null +++ b/tests/test_websocket.py @@ -0,0 +1,47 @@ +import json +from unittest.mock import patch +from langflow.api.schemas import ChatMessage +from fastapi.testclient import TestClient + + +def test_websocket_connection(client: TestClient): + with client.websocket_connect("/ws/test_client") as websocket: + assert websocket.scope["client"] == ["testclient", 50000] + assert websocket.scope["path"] == "/ws/test_client" + + +def test_chat_history(client: TestClient): + chat_history = [] + + # Mock the process_graph function to return a specific value + with patch("langflow.api.chat_manager.process_graph") as mock_process_graph: + mock_process_graph.return_value = ("Hello, I'm a mock response!", "") + + with client.websocket_connect("/ws/test_client") as websocket: + # First message should be the history + history = websocket.receive_json() + assert json.loads(history) == [] # Empty history + # Send a message + payload = {"message": "Hello"} + websocket.send_json(json.dumps(payload)) + + # Receive the response from the server + response = websocket.receive_json() + assert json.loads(response) == { + "sender": "bot", + "message": None, + "intermediate_steps": "", + "type": "start", + } + # Send another message + payload = {"message": "How are you?"} + websocket.send_json(json.dumps(payload)) + + # Receive the response from the server + response = websocket.receive_json() + assert json.loads(response) == { + "sender": "bot", + "message": "Hello, I'm a mock response!", + "intermediate_steps": "", + "type": "end", + }