Merge remote-tracking branch 'origin/chat_and_cache' into websocket

anovazzi1 2023-04-25 15:42:06 -03:00
commit 7f2ad60a35
21 changed files with 578 additions and 135 deletions

View file

@@ -1,11 +0,0 @@
#! /bin/bash
cd src/frontend
docker build -t logspace/frontend_build -f build.Dockerfile .
cd ../backend
docker build -t logspace/backend_build -f build.Dockerfile .
cd ../../
VERSION=$(toml get --toml-path pyproject.toml tool.poetry.version)
docker build --build-arg VERSION=$VERSION -t ibiscp/langflow:$VERSION .
docker push ibiscp/langflow:$VERSION

View file

@@ -1,52 +0,0 @@
# `python-base` sets up all our shared environment variables
FROM python:3.10-slim
# python
ENV PYTHONUNBUFFERED=1 \
# prevents python creating .pyc files
PYTHONDONTWRITEBYTECODE=1 \
\
# pip
PIP_NO_CACHE_DIR=off \
PIP_DISABLE_PIP_VERSION_CHECK=on \
PIP_DEFAULT_TIMEOUT=100 \
\
# poetry
# https://python-poetry.org/docs/configuration/#using-environment-variables
POETRY_VERSION=1.4.0 \
# make poetry install to this location
POETRY_HOME="/opt/poetry" \
# make poetry create the virtual environment in the project's root
# it gets named `.venv`
POETRY_VIRTUALENVS_IN_PROJECT=true \
# do not ask any interactive question
POETRY_NO_INTERACTION=1 \
\
# paths
# this is where our requirements + virtual environment will live
PYSETUP_PATH="/opt/pysetup" \
VENV_PATH="/opt/pysetup/.venv"
# prepend poetry and venv to path
ENV PATH="$POETRY_HOME/bin:$VENV_PATH/bin:$PATH"
RUN apt-get update \
&& apt-get install --no-install-recommends -y \
# deps for installing poetry
curl \
# deps for building python deps
build-essential libpq-dev
# install poetry - respects $POETRY_VERSION & $POETRY_HOME
RUN curl -sSL https://install.python-poetry.org | python3 -
# copy project requirement files here to ensure they will be cached.
WORKDIR /app
COPY poetry.lock pyproject.toml ./
COPY langflow/ ./langflow
# poetry install
RUN poetry install --without dev
# build wheel
RUN poetry build -f wheel

View file

@@ -1,6 +0,0 @@
#! /bin/bash
docker build -t logspace/backend_build -f build.Dockerfile .
VERSION=$(toml get --toml-path pyproject.toml tool.poetry.version)
docker build --build-arg VERSION=$VERSION -t ibiscp/langflow:$VERSION .
docker push ibiscp/langflow:$VERSION

View file

@@ -3,6 +3,10 @@ from pydantic import BaseModel, validator
from langflow.graph.utils import extract_input_variables_from_prompt
class CacheResponse(BaseModel):
data: dict
class Code(BaseModel):
code: str

View file

@@ -0,0 +1,18 @@
from typing import Any
from langchain.callbacks.base import AsyncCallbackHandler
from langflow.api.schemas import ChatResponse
# https://github.com/hwchase17/chat-langchain/blob/master/callback.py
class StreamingLLMCallbackHandler(AsyncCallbackHandler):
"""Callback handler for streaming LLM responses."""
def __init__(self, websocket):
self.websocket = websocket
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
resp = ChatResponse(
sender="bot", message=token, type="stream", intermediate_steps=""
)
await self.websocket.send_json(resp.dict())
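To make the handler's role concrete, here is a minimal sketch of driving it directly; the FakeWebSocket stand-in and the token list are illustrative assumptions, not part of this commit:

import asyncio
from langflow.api.callback import StreamingLLMCallbackHandler

class FakeWebSocket:
    """Stand-in for a FastAPI WebSocket; only send_json is used by the handler."""
    async def send_json(self, data):
        print(data)

async def demo():
    handler = StreamingLLMCallbackHandler(FakeWebSocket())
    # Each token emitted by the LLM is forwarded as a "stream" ChatResponse.
    for token in ["Hello", ",", " world"]:
        await handler.on_llm_new_token(token)

asyncio.run(demo())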

View file

@@ -0,0 +1,13 @@
from fastapi import APIRouter, WebSocket
from langflow.api.chat_manager import ChatManager
router = APIRouter()
chat_manager = ChatManager()
@router.websocket("/chat/{client_id}")
async def websocket_endpoint(client_id: str, websocket: WebSocket):
await chat_manager.handle_websocket(client_id, websocket)
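A hedged sketch of hitting this endpoint with FastAPI's TestClient (the same approach the tests at the end of this diff take); the client id is arbitrary, and the path assumes the router is mounted without a prefix in create_app:

from fastapi.testclient import TestClient
from langflow.main import create_app

client = TestClient(create_app())
with client.websocket_connect("/chat/demo_client") as websocket:
    # handle_websocket sends the client's chat history as the first frame.
    history = websocket.receive_json()
    print(history)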

View file

@@ -0,0 +1,208 @@
import asyncio
import base64
from io import BytesIO
from typing import Dict, List
from collections import defaultdict
from fastapi import WebSocket
import json
from langchain.llms import OpenAI, AzureOpenAI
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
from langflow.api.schemas import ChatMessage, ChatResponse, FileResponse
from langflow.cache.manager import AsyncSubject
from langchain.callbacks.base import AsyncCallbackManager
from langflow.api.callback import StreamingLLMCallbackHandler
from langflow.interface.run import (
async_get_result_and_steps,
load_or_build_langchain_object,
)
from langflow.utils.logger import logger
from langflow.cache import cache_manager
from PIL.Image import Image
class ChatHistory(AsyncSubject):
def __init__(self):
super().__init__()
self.history: Dict[str, List[ChatMessage]] = defaultdict(list)
async def add_message(self, client_id: str, message: ChatMessage):
self.history[client_id].append(message)
await self.notify()
def get_history(self, client_id: str) -> List[ChatMessage]:
return self.history[client_id]
class ChatManager:
def __init__(self):
self.active_connections: Dict[str, WebSocket] = {}
self.chat_history = ChatHistory()
self.chat_history.attach(self.on_chat_history_update)
self.cache_manager = cache_manager
self.cache_manager.attach(self.update)
async def on_chat_history_update(self):
"""Send the last chat message to the client."""
client_id = self.cache_manager.current_client_id
if client_id in self.active_connections:
chat_response = self.chat_history.get_history(client_id)[-1]
if chat_response.sender == "bot":
# Process FileResponse
if isinstance(chat_response, FileResponse):
# If data_type is pandas, convert to csv
if chat_response.data_type == "pandas":
chat_response.data = chat_response.data.to_csv()
elif chat_response.data_type == "image":
# Base64 encode the image
chat_response.data = pil_to_base64(chat_response.data)
await self.send_json(client_id, chat_response)
def update(self):
if self.cache_manager.current_client_id in self.active_connections:
self.last_cached_object_dict = self.cache_manager.get_last()
# Add a new ChatResponse with the data
chat_response = FileResponse(
sender="bot",
message=None,
type="file",
data=self.last_cached_object_dict["obj"],
data_type=self.last_cached_object_dict["type"],
)
asyncio.create_task(
self.chat_history.add_message(
self.cache_manager.current_client_id, chat_response
)
)
async def connect(self, client_id: str, websocket: WebSocket):
await websocket.accept()
self.active_connections[client_id] = websocket
def disconnect(self, client_id: str):
del self.active_connections[client_id]
async def send_message(self, client_id: str, message: str):
websocket = self.active_connections[client_id]
await websocket.send_text(message)
async def send_json(self, client_id: str, message: ChatMessage):
websocket = self.active_connections[client_id]
await websocket.send_json(json.dumps(message.dict()))
async def process_message(self, client_id: str, payload: Dict):
# Process the graph data and chat message
chat_message = payload.pop("message", "")
chat_message = ChatMessage(sender="you", message=chat_message)
await self.chat_history.add_message(client_id, chat_message)
graph_data = payload
start_resp = ChatResponse(
sender="bot", message=None, type="start", intermediate_steps=""
)
await self.chat_history.add_message(client_id, start_resp)
is_first_message = len(self.chat_history.get_history(client_id=client_id)) == 0
# Generate result and thought
try:
logger.debug("Generating result and thought")
result, intermediate_steps = await process_graph(
graph_data=graph_data,
is_first_message=is_first_message,
chat_message=chat_message,
websocket=self.active_connections[client_id],
)
except Exception as e:
# Log stack trace
logger.exception(e)
error_resp = ChatResponse(
sender="bot", message=str(e), type="error", intermediate_steps=""
)
await self.send_json(client_id, error_resp)
return
# Send a response back to the frontend, if needed
response = ChatResponse(
sender="bot",
message=result or "",
intermediate_steps=intermediate_steps or "",
type="end",
)
await self.chat_history.add_message(client_id, response)
async def handle_websocket(self, client_id: str, websocket: WebSocket):
await self.connect(client_id, websocket)
try:
chat_history = self.chat_history.get_history(client_id)
await websocket.send_json(json.dumps(chat_history))
while True:
json_payload = await websocket.receive_json()
payload = json.loads(json_payload)
with self.cache_manager.set_client_id(client_id):
await self.process_message(client_id, payload)
except Exception as e:
# Handle any exceptions that might occur
print(f"Error: {e}")
finally:
self.disconnect(client_id)
async def process_graph(
graph_data: Dict,
is_first_message: bool,
chat_message: ChatMessage,
websocket: WebSocket,
):
langchain_object = load_or_build_langchain_object(graph_data, is_first_message)
langchain_object = try_setting_streaming_options(langchain_object, websocket)
logger.debug("Loaded langchain object")
if langchain_object is None:
# Raise user facing error
raise ValueError(
"There was an error loading the langchain_object. Please, check all the nodes and try again."
)
# Generate result and thought
try:
logger.debug("Generating result and thought")
result, intermediate_steps = await async_get_result_and_steps(
langchain_object, chat_message.message or ""
)
logger.debug("Generated result and intermediate_steps")
return result, intermediate_steps
except Exception as e:
# Log stack trace
logger.exception(e)
raise e
def try_setting_streaming_options(langchain_object, websocket):
# If the LLM type is OpenAI or ChatOpenAI,
# set streaming to True
# First we need to find the LLM
llm = None
if hasattr(langchain_object, "llm"):
llm = langchain_object.llm
elif hasattr(langchain_object, "llm_chain") and hasattr(
langchain_object.llm_chain, "llm"
):
llm = langchain_object.llm_chain.llm
if isinstance(llm, (OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI)):
llm.streaming = bool(hasattr(llm, "streaming"))
if hasattr(langchain_object, "callback_manager"):
stream_handler = StreamingLLMCallbackHandler(websocket)
stream_manager = AsyncCallbackManager([stream_handler])
langchain_object.callback_manager = stream_manager
return langchain_object
def pil_to_base64(image: Image) -> str:
buffered = BytesIO()
image.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue())
return img_str.decode("utf-8")
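The exchange below sketches the wire protocol that process_message and handle_websocket implement; the payload is a placeholder (a real client also sends the exported flow data), and frames are JSON-encoded twice, matching send_json(json.dumps(...)) above and the tests at the end of this diff:

import json
from fastapi.testclient import TestClient
from langflow.main import create_app

client = TestClient(create_app())
with client.websocket_connect("/chat/demo_client") as websocket:
    websocket.receive_json()  # first frame: the (possibly empty) chat history
    payload = {"message": "Hi"}  # a real client also includes the exported flow (graph) data
    websocket.send_json(json.dumps(payload))  # double-encoded, because handle_websocket json.loads() the frame
    frame = json.loads(websocket.receive_json())
    # frame is a ChatResponse dict: "start", then "stream"/"end", or "error" if the flow fails to build
    print(frame["type"])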

View file

@@ -0,0 +1,42 @@
from typing import Any, Union
from pydantic import BaseModel, validator
class ChatMessage(BaseModel):
"""Chat message schema."""
sender: str
message: Union[str, None] = None
@validator("sender")
def sender_must_be_bot_or_you(cls, v):
if v not in ["bot", "you"]:
raise ValueError("sender must be bot or you")
return v
class ChatResponse(ChatMessage):
"""Chat response schema."""
intermediate_steps: str
type: str
@validator("type")
def validate_message_type(cls, v):
if v not in ["start", "stream", "end", "error", "info", "file"]:
raise ValueError("type must be start, stream, end, error, info, or file")
return v
class FileResponse(ChatMessage):
"""File response schema."""
data: Any
data_type: str
type: str = "file"
@validator("data_type")
def validate_data_type(cls, v):
if v not in ["image", "csv"]:
raise ValueError("data_type must be image or csv")
return v
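A brief sketch of the validation these schemas enforce; the field values are arbitrary examples:

from langflow.api.schemas import ChatResponse, FileResponse

# "end" is one of the allowed types; an unknown type or sender raises a pydantic ValidationError.
resp = ChatResponse(sender="bot", message="done", intermediate_steps="", type="end")
print(resp.dict())

csv_resp = FileResponse(sender="bot", message=None, data="a,b\n1,2\n", data_type="csv")
print(csv_resp.type)  # defaults to "file"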

View file

@@ -0,0 +1 @@
from langflow.cache.manager import cache_manager # noqa

View file

@@ -2,13 +2,18 @@ import base64
import contextlib
import functools
import hashlib
import json
import os
import tempfile
from collections import OrderedDict
from pathlib import Path
from typing import Any
from PIL import Image
import dill
import pandas as pd # type: ignore
import dill # type: ignore
CACHE = {}
def create_cache_folder(func):

src/backend/langflow/cache/manager.py vendored Normal file
View file

@@ -0,0 +1,126 @@
from contextlib import contextmanager
from typing import Any, Awaitable, Callable, List
from PIL import Image
import pandas as pd
class Subject:
"""Base class for implementing the observer pattern."""
def __init__(self):
self.observers: List[Callable[[], None]] = []
def attach(self, observer: Callable[[], None]):
"""Attach an observer to the subject."""
self.observers.append(observer)
def detach(self, observer: Callable[[], None]):
"""Detach an observer from the subject."""
self.observers.remove(observer)
def notify(self):
"""Notify all observers about an event."""
for observer in self.observers:
if observer is None:
continue
observer()
class AsyncSubject:
"""Base class for implementing the async observer pattern."""
def __init__(self):
self.observers: List[Callable[[], Awaitable]] = []
def attach(self, observer: Callable[[], Awaitable]):
"""Attach an observer to the subject."""
self.observers.append(observer)
def detach(self, observer: Callable[[], Awaitable]):
"""Detach an observer from the subject."""
self.observers.remove(observer)
async def notify(self):
"""Notify all observers about an event."""
for observer in self.observers:
if observer is None:
continue
await observer()
class CacheManager(Subject):
"""Manages cache for different clients and notifies observers on changes."""
def __init__(self):
super().__init__()
self.CACHE = {}
self.current_client_id = None
@contextmanager
def set_client_id(self, client_id: str):
"""
Context manager to set the current client_id and associated cache.
Args:
client_id (str): The client identifier.
"""
previous_client_id = self.current_client_id
self.current_client_id = client_id
self.current_cache = self.CACHE.setdefault(client_id, {})
try:
yield
finally:
self.current_client_id = previous_client_id
self.current_cache = self.CACHE.get(self.current_client_id, {})
def add_pandas(self, name: str, obj: Any):
"""
Add a pandas DataFrame or Series to the current client's cache.
Args:
name (str): The cache key.
obj (Any): The pandas DataFrame or Series object.
"""
if isinstance(obj, (pd.DataFrame, pd.Series)):
self.current_cache[name] = {"obj": obj, "type": "pandas"}
self.notify()
else:
raise ValueError("Object is not a pandas DataFrame or Series")
def add_image(self, name: str, obj: Any):
"""
Add a PIL Image to the current client's cache.
Args:
name (str): The cache key.
obj (Any): The PIL Image object.
"""
if isinstance(obj, Image.Image):
self.current_cache[name] = {"obj": obj, "type": "image"}
self.notify()
else:
raise ValueError("Object is not a PIL Image")
def get(self, name: str):
"""
Get an object from the current client's cache.
Args:
name (str): The cache key.
Returns:
The cached object associated with the given cache key.
"""
return self.current_cache[name]
def get_last(self):
"""
Get the last added item in the current client's cache.
Returns:
The last added item in the cache.
"""
return list(self.current_cache.values())[-1]
cache_manager = CacheManager()
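A hedged usage sketch tying CacheManager to the observer classes above; the DataFrame and the toy observer are illustrative only. add_pandas stores the object under the active client's cache and notify() fires every attached observer, which is exactly how ChatManager.update gets called:

import pandas as pd
from langflow.cache import cache_manager

def on_cache_update():
    print("cache changed:", cache_manager.get_last()["type"])

cache_manager.attach(on_cache_update)          # CacheManager is a (sync) Subject
with cache_manager.set_client_id("demo_client"):
    df = pd.DataFrame({"a": [1, 2, 3]})        # arbitrary example data
    cache_manager.add_pandas("my_table", df)   # stores {"obj": df, "type": "pandas"} and notifies observers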

View file

@@ -9,7 +9,7 @@ import warnings
from copy import deepcopy
from typing import Any, Dict, List, Optional
from langflow.cache import utils as cache_utils
from langflow.cache import base as cache_utils
from langflow.graph.constants import DIRECT_TYPES
from langflow.interface import loading
from langflow.interface.listing import ALL_TYPES_DICT

View file

@@ -3,7 +3,7 @@ import io
from typing import Any, Dict
from chromadb.errors import NotEnoughElementsException
from langflow.cache.utils import compute_dict_hash, load_cache, memoize_dict
from langflow.cache.base import compute_dict_hash, load_cache, memoize_dict
from langflow.graph.graph import Graph
from langflow.interface import loading
from langflow.utils.logger import logger
@@ -32,7 +32,7 @@ def load_or_build_langchain_object(data_graph, is_first_message=False):
return build_langchain_object_with_caching(data_graph)
@memoize_dict(maxsize=1)
@memoize_dict(maxsize=10)
def build_langchain_object_with_caching(data_graph):
"""
Build langchain object from data_graph.
@@ -87,7 +87,7 @@ def process_graph(data_graph: Dict[str, Any]):
# Generate result and thought
logger.debug("Generating result and thought")
result, thought = get_result_and_thought_using_graph(langchain_object, message)
result, thought = get_result_and_steps(langchain_object, message)
logger.debug("Generated result and thought")
# Save langchain_object to cache
@@ -118,7 +118,7 @@ def process_graph_cached(data_graph: Dict[str, Any]):
# Generate result and thought
logger.debug("Generating result and thought")
result, thought = get_result_and_thought_using_graph(langchain_object, message)
result, thought = get_result_and_steps(langchain_object, message)
logger.debug("Generated result and thought")
return {"result": str(result), "thought": thought.strip()}
@@ -184,7 +184,7 @@ def fix_memory_inputs(langchain_object):
update_memory_keys(langchain_object, possible_new_mem_key)
def get_result_and_thought_using_graph(langchain_object, message: str):
def get_result_and_steps(langchain_object, message: str):
"""Get result and thought from extracted json"""
try:
if hasattr(langchain_object, "verbose"):
@@ -240,6 +240,61 @@ def get_result_and_thought_using_graph(langchain_object, message: str):
return result, thought
async def async_get_result_and_steps(langchain_object, message: str):
"""Get result and thought from extracted json"""
try:
if hasattr(langchain_object, "verbose"):
langchain_object.verbose = True
chat_input = None
memory_key = ""
if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
memory_key = langchain_object.memory.memory_key
if hasattr(langchain_object, "input_keys"):
for key in langchain_object.input_keys:
if key not in [memory_key, "chat_history"]:
chat_input = {key: message}
else:
chat_input = message # type: ignore
if hasattr(langchain_object, "return_intermediate_steps"):
# https://github.com/hwchase17/langchain/issues/2068
# Deactivating until we have a frontend solution
# to display intermediate steps
langchain_object.return_intermediate_steps = False
fix_memory_inputs(langchain_object)
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
try:
if hasattr(langchain_object, "acall"):
output = await langchain_object.acall(chat_input)
else:
output = langchain_object(chat_input)
except ValueError as exc:
# make the error message more informative
logger.debug(f"Error: {str(exc)}")
output = langchain_object.run(chat_input)
intermediate_steps = (
output.get("intermediate_steps", []) if isinstance(output, dict) else []
)
result = (
output.get(langchain_object.output_keys[0])
if isinstance(output, dict)
else output
)
if intermediate_steps:
thought = format_intermediate_steps(intermediate_steps)
else:
thought = output_buffer.getvalue()
except Exception as exc:
raise ValueError(f"Error: {str(exc)}") from exc
return result, thought
def get_result_and_thought(extracted_json: Dict[str, Any], message: str):
"""Get result and thought from extracted json"""
try:

View file

@@ -3,6 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware
from langflow.api.endpoints import router as endpoints_router
from langflow.api.validate import router as validate_router
from langflow.api.chat import router as chat_router
def create_app():
@@ -23,6 +24,7 @@ def create_app():
app.include_router(endpoints_router)
app.include_router(validate_router)
app.include_router(chat_router)
return app
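For context, a minimal sketch of serving the app so the new /chat websocket route is reachable; running under uvicorn and the chosen port are assumptions, not part of this diff:

import uvicorn
from langflow.main import create_app

if __name__ == "__main__":
    # Port is arbitrary here; any ASGI server with websocket support will do.
    uvicorn.run(create_app(), host="127.0.0.1", port=7860)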

View file

@@ -1,8 +0,0 @@
#! /bin/bash
poetry remove langchain
docker build -t logspace/backend_build -f build.Dockerfile .
VERSION=$(toml get --toml-path pyproject.toml tool.poetry.version)
docker build --build-arg VERSION=$VERSION -t ibiscp/langflow:$VERSION .
docker run -p 5003:80 -d ibiscp/langflow:$VERSION
poetry add --editable ../../../langchain

View file

@@ -1,5 +0,0 @@
FROM node:14-alpine
WORKDIR /app
COPY . /app
RUN npm install
RUN npm run build

View file

@@ -1,11 +0,0 @@
#! /bin/bash
# Read the contents of the JSON file
json=$(cat package.json)
# Extract the value of the "version" field using jq
VERSION=$(echo "$json" | jq -r '.version')
docker build -t logspace/frontend_build -f build.Dockerfile .
docker build --build-arg VERSION=$VERSION -t ibiscp/langflow_frontend:$VERSION .
docker push ibiscp/langflow_frontend:$VERSION

View file

@@ -1,4 +1,8 @@
import json
from pathlib import Path
from typing import AsyncGenerator
from httpx import AsyncClient
import pytest
from fastapi.testclient import TestClient
@@ -21,6 +25,15 @@ def get_text():
"""
@pytest.fixture()
async def async_client() -> AsyncGenerator:
from langflow.main import create_app
app = create_app()
async with AsyncClient(app=app, base_url="http://testserver") as client:
yield client
# Create client fixture for FastAPI
@pytest.fixture(scope="module")
def client():
@@ -30,3 +43,37 @@ def client():
with TestClient(app) as client:
yield client
def get_graph(_type="basic"):
"""Get a graph from a json file"""
from langflow.graph.graph import Graph
if _type == "basic":
path = pytest.BASIC_EXAMPLE_PATH
elif _type == "complex":
path = pytest.COMPLEX_EXAMPLE_PATH
elif _type == "openapi":
path = pytest.OPENAPI_EXAMPLE_PATH
with open(path, "r") as f:
flow_graph = json.load(f)
data_graph = flow_graph["data"]
nodes = data_graph["nodes"]
edges = data_graph["edges"]
return Graph(nodes, edges)
@pytest.fixture
def basic_graph():
return get_graph()
@pytest.fixture
def complex_graph():
return get_graph("complex")
@pytest.fixture
def openapi_graph():
return get_graph("openapi")
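An illustrative test that consumes the new fixtures; the endpoint path and the pytest-asyncio marker are assumptions, not part of this commit:

import pytest

@pytest.mark.asyncio
async def test_with_fixtures(async_client, basic_graph):
    assert basic_graph.nodes  # the basic example flow defines at least one node
    response = await async_client.get("/health")  # hypothetical endpoint, for illustration only
    assert response.status_code in (200, 404)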

View file

@@ -3,7 +3,7 @@ import tempfile
from pathlib import Path
import pytest
from langflow.cache.utils import PREFIX, save_cache
from langflow.cache.base import PREFIX, save_cache
from langflow.interface.run import load_langchain_object

View file

@@ -15,7 +15,7 @@ from langflow.graph.nodes import (
ToolNode,
WrapperNode,
)
from langflow.interface.run import get_result_and_thought_using_graph
from langflow.interface.run import get_result_and_steps
from langflow.utils.payload import build_json, get_root_node
# Test cases for the graph module
@@ -24,38 +24,6 @@ from langflow.utils.payload import build_json, get_root_node
# BASIC_EXAMPLE_PATH, COMPLEX_EXAMPLE_PATH, OPENAPI_EXAMPLE_PATH
def get_graph(_type="basic"):
"""Get a graph from a json file"""
if _type == "basic":
path = pytest.BASIC_EXAMPLE_PATH
elif _type == "complex":
path = pytest.COMPLEX_EXAMPLE_PATH
elif _type == "openapi":
path = pytest.OPENAPI_EXAMPLE_PATH
with open(path, "r") as f:
flow_graph = json.load(f)
data_graph = flow_graph["data"]
nodes = data_graph["nodes"]
edges = data_graph["edges"]
return Graph(nodes, edges)
@pytest.fixture
def basic_graph():
return get_graph()
@pytest.fixture
def complex_graph():
return get_graph("complex")
@pytest.fixture
def openapi_graph():
return get_graph("openapi")
def get_node_by_type(graph, node_type: Type[Node]) -> Union[Node, None]:
"""Get a node by type"""
return next((node for node in graph.nodes if isinstance(node, node_type)), None)
@@ -441,7 +409,7 @@ def test_get_result_and_thought(basic_graph):
# now build again and check if FakeListLLM was used
# Get the result and thought
result, thought = get_result_and_thought_using_graph(langchain_object, message)
result, thought = get_result_and_steps(langchain_object, message)
# The result should be a str
assert isinstance(result, str)
# The thought should be a Thought

tests/test_websocket.py Normal file
View file

@@ -0,0 +1,47 @@
import json
from unittest.mock import patch
from langflow.api.schemas import ChatMessage
from fastapi.testclient import TestClient
def test_websocket_connection(client: TestClient):
with client.websocket_connect("/ws/test_client") as websocket:
assert websocket.scope["client"] == ["testclient", 50000]
assert websocket.scope["path"] == "/ws/test_client"
def test_chat_history(client: TestClient):
chat_history = []
# Mock the process_graph function to return a specific value
with patch("langflow.api.chat_manager.process_graph") as mock_process_graph:
mock_process_graph.return_value = ("Hello, I'm a mock response!", "")
with client.websocket_connect("/ws/test_client") as websocket:
# First message should be the history
history = websocket.receive_json()
assert json.loads(history) == [] # Empty history
# Send a message
payload = {"message": "Hello"}
websocket.send_json(json.dumps(payload))
# Receive the response from the server
response = websocket.receive_json()
assert json.loads(response) == {
"sender": "bot",
"message": None,
"intermediate_steps": "",
"type": "start",
}
# Send another message
payload = {"message": "How are you?"}
websocket.send_json(json.dumps(payload))
# Receive the response from the server
response = websocket.receive_json()
assert json.loads(response) == {
"sender": "bot",
"message": "Hello, I'm a mock response!",
"intermediate_steps": "",
"type": "end",
}
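Following the same mocking pattern, a hedged sketch of covering the error path (an exception raised inside process_graph should surface as a type "error" frame); this test is an illustration, not part of the commit:

import json
from unittest.mock import patch
from fastapi.testclient import TestClient

def test_error_response(client: TestClient):
    with patch("langflow.api.chat_manager.process_graph") as mock_process_graph:
        mock_process_graph.side_effect = ValueError("boom")
        with client.websocket_connect("/ws/test_client") as websocket:
            websocket.receive_json()  # history frame
            websocket.send_json(json.dumps({"message": "Hello"}))
            assert json.loads(websocket.receive_json())["type"] == "start"
            error_frame = json.loads(websocket.receive_json())
            assert error_frame["type"] == "error"
            assert "boom" in error_frame["message"]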