Merge remote-tracking branch 'origin/chat_and_cache' into websocket
This commit is contained in:
commit
7f2ad60a35
21 changed files with 578 additions and 135 deletions
|
|
@ -1,11 +0,0 @@
|
|||
#! /bin/bash
|
||||
|
||||
cd src/frontend
|
||||
docker build -t logspace/frontend_build -f build.Dockerfile .
|
||||
cd ../backend
|
||||
docker build -t logspace/backend_build -f build.Dockerfile .
|
||||
|
||||
cd ../../
|
||||
VERSION=$(toml get --toml-path pyproject.toml tool.poetry.version)
|
||||
docker build --build-arg VERSION=$VERSION -t ibiscp/langflow:$VERSION .
|
||||
docker push ibiscp/langflow:$VERSION
|
||||
|
|
@ -1,52 +0,0 @@
|
|||
# `python-base` sets up all our shared environment variables
|
||||
FROM python:3.10-slim
|
||||
|
||||
# python
|
||||
ENV PYTHONUNBUFFERED=1 \
|
||||
# prevents python creating .pyc files
|
||||
PYTHONDONTWRITEBYTECODE=1 \
|
||||
\
|
||||
# pip
|
||||
PIP_NO_CACHE_DIR=off \
|
||||
PIP_DISABLE_PIP_VERSION_CHECK=on \
|
||||
PIP_DEFAULT_TIMEOUT=100 \
|
||||
\
|
||||
# poetry
|
||||
# https://python-poetry.org/docs/configuration/#using-environment-variables
|
||||
POETRY_VERSION=1.4.0 \
|
||||
# make poetry install to this location
|
||||
POETRY_HOME="/opt/poetry" \
|
||||
# make poetry create the virtual environment in the project's root
|
||||
# it gets named `.venv`
|
||||
POETRY_VIRTUALENVS_IN_PROJECT=true \
|
||||
# do not ask any interactive question
|
||||
POETRY_NO_INTERACTION=1 \
|
||||
\
|
||||
# paths
|
||||
# this is where our requirements + virtual environment will live
|
||||
PYSETUP_PATH="/opt/pysetup" \
|
||||
VENV_PATH="/opt/pysetup/.venv"
|
||||
|
||||
# prepend poetry and venv to path
|
||||
ENV PATH="$POETRY_HOME/bin:$VENV_PATH/bin:$PATH"
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install --no-install-recommends -y \
|
||||
# deps for installing poetry
|
||||
curl \
|
||||
# deps for building python deps
|
||||
build-essential libpq-dev
|
||||
|
||||
# install poetry - respects $POETRY_VERSION & $POETRY_HOME
|
||||
RUN curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
# copy project requirement files here to ensure they will be cached.
|
||||
WORKDIR /app
|
||||
COPY poetry.lock pyproject.toml ./
|
||||
COPY langflow/ ./langflow
|
||||
|
||||
# poetry install
|
||||
RUN poetry install --without dev
|
||||
|
||||
# build wheel
|
||||
RUN poetry build -f wheel
|
||||
|
|
@ -1,6 +0,0 @@
|
|||
#! /bin/bash
|
||||
|
||||
docker build -t logspace/backend_build -f build.Dockerfile .
|
||||
VERSION=$(toml get --toml-path pyproject.toml tool.poetry.version)
|
||||
docker build --build-arg VERSION=$VERSION -t ibiscp/langflow:$VERSION .
|
||||
docker push ibiscp/langflow:$VERSION
|
||||
|
|
@ -3,6 +3,10 @@ from pydantic import BaseModel, validator
|
|||
from langflow.graph.utils import extract_input_variables_from_prompt
|
||||
|
||||
|
||||
class CacheResponse(BaseModel):
|
||||
data: dict
|
||||
|
||||
|
||||
class Code(BaseModel):
|
||||
code: str
|
||||
|
||||
|
|
|
|||
18
src/backend/langflow/api/callback.py
Normal file
18
src/backend/langflow/api/callback.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
from typing import Any
|
||||
from langchain.callbacks.base import AsyncCallbackHandler
|
||||
|
||||
from langflow.api.schemas import ChatResponse
|
||||
|
||||
|
||||
# https://github.com/hwchase17/chat-langchain/blob/master/callback.py
|
||||
class StreamingLLMCallbackHandler(AsyncCallbackHandler):
|
||||
"""Callback handler for streaming LLM responses."""
|
||||
|
||||
def __init__(self, websocket):
|
||||
self.websocket = websocket
|
||||
|
||||
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
resp = ChatResponse(
|
||||
sender="bot", message=token, type="stream", intermediate_steps=""
|
||||
)
|
||||
await self.websocket.send_json(resp.dict())
|
||||
13
src/backend/langflow/api/chat.py
Normal file
13
src/backend/langflow/api/chat.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
from fastapi import APIRouter, WebSocket
|
||||
|
||||
from langflow.api.chat_manager import ChatManager
|
||||
|
||||
router = APIRouter()
|
||||
chat_manager = ChatManager()
|
||||
|
||||
|
||||
@router.websocket("/chat/{client_id}")
|
||||
async def websocket_endpoint(client_id: str, websocket: WebSocket):
|
||||
await chat_manager.handle_websocket(client_id, websocket)
|
||||
|
||||
|
||||
208
src/backend/langflow/api/chat_manager.py
Normal file
208
src/backend/langflow/api/chat_manager.py
Normal file
|
|
@ -0,0 +1,208 @@
|
|||
import asyncio
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from typing import Dict, List
|
||||
from collections import defaultdict
|
||||
from fastapi import WebSocket
|
||||
import json
|
||||
from langchain.llms import OpenAI, AzureOpenAI
|
||||
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
|
||||
from langflow.api.schemas import ChatMessage, ChatResponse, FileResponse
|
||||
from langflow.cache.manager import AsyncSubject
|
||||
from langchain.callbacks.base import AsyncCallbackManager
|
||||
from langflow.api.callback import StreamingLLMCallbackHandler
|
||||
from langflow.interface.run import (
|
||||
async_get_result_and_steps,
|
||||
load_or_build_langchain_object,
|
||||
)
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.cache import cache_manager
|
||||
from PIL.Image import Image
|
||||
|
||||
|
||||
class ChatHistory(AsyncSubject):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.history: Dict[str, List[ChatMessage]] = defaultdict(list)
|
||||
|
||||
async def add_message(self, client_id: str, message: ChatMessage):
|
||||
self.history[client_id].append(message)
|
||||
await self.notify()
|
||||
|
||||
def get_history(self, client_id: str) -> List[ChatMessage]:
|
||||
return self.history[client_id]
|
||||
|
||||
|
||||
class ChatManager:
|
||||
def __init__(self):
|
||||
self.active_connections: Dict[str, WebSocket] = {}
|
||||
self.chat_history = ChatHistory()
|
||||
self.chat_history.attach(self.on_chat_history_update)
|
||||
self.cache_manager = cache_manager
|
||||
self.cache_manager.attach(self.update)
|
||||
|
||||
async def on_chat_history_update(self):
|
||||
"""Send the last chat message to the client."""
|
||||
client_id = self.cache_manager.current_client_id
|
||||
if client_id in self.active_connections:
|
||||
chat_response = self.chat_history.get_history(client_id)[-1]
|
||||
if chat_response.sender == "bot":
|
||||
# Process FileResponse
|
||||
if isinstance(chat_response, FileResponse):
|
||||
# If data_type is pandas, convert to csv
|
||||
if chat_response.data_type == "pandas":
|
||||
chat_response.data = chat_response.data.to_csv()
|
||||
elif chat_response.data_type == "image":
|
||||
# Base64 encode the image
|
||||
chat_response.data = pil_to_base64(chat_response.data)
|
||||
|
||||
await self.send_json(client_id, chat_response)
|
||||
|
||||
def update(self):
|
||||
if self.cache_manager.current_client_id in self.active_connections:
|
||||
self.last_cached_object_dict = self.cache_manager.get_last()
|
||||
# Add a new ChatResponse with the data
|
||||
chat_response = FileResponse(
|
||||
sender="bot",
|
||||
message=None,
|
||||
type="file",
|
||||
data=self.last_cached_object_dict["obj"],
|
||||
data_type=self.last_cached_object_dict["type"],
|
||||
)
|
||||
|
||||
asyncio.create_task(
|
||||
self.chat_history.add_message(
|
||||
self.cache_manager.current_client_id, chat_response
|
||||
)
|
||||
)
|
||||
|
||||
async def connect(self, client_id: str, websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
self.active_connections[client_id] = websocket
|
||||
|
||||
def disconnect(self, client_id: str):
|
||||
del self.active_connections[client_id]
|
||||
|
||||
async def send_message(self, client_id: str, message: str):
|
||||
websocket = self.active_connections[client_id]
|
||||
await websocket.send_text(message)
|
||||
|
||||
async def send_json(self, client_id: str, message: ChatMessage):
|
||||
websocket = self.active_connections[client_id]
|
||||
await websocket.send_json(json.dumps(message.dict()))
|
||||
|
||||
async def process_message(self, client_id: str, payload: Dict):
|
||||
# Process the graph data and chat message
|
||||
chat_message = payload.pop("message", "")
|
||||
chat_message = ChatMessage(sender="you", message=chat_message)
|
||||
await self.chat_history.add_message(client_id, chat_message)
|
||||
|
||||
graph_data = payload
|
||||
start_resp = ChatResponse(
|
||||
sender="bot", message=None, type="start", intermediate_steps=""
|
||||
)
|
||||
await self.chat_history.add_message(client_id, start_resp)
|
||||
|
||||
is_first_message = len(self.chat_history.get_history(client_id=client_id)) == 0
|
||||
# Generate result and thought
|
||||
try:
|
||||
logger.debug("Generating result and thought")
|
||||
|
||||
result, intermediate_steps = await process_graph(
|
||||
graph_data=graph_data,
|
||||
is_first_message=is_first_message,
|
||||
chat_message=chat_message,
|
||||
websocket=self.active_connections[client_id],
|
||||
)
|
||||
except Exception as e:
|
||||
# Log stack trace
|
||||
logger.exception(e)
|
||||
error_resp = ChatResponse(
|
||||
sender="bot", message=str(e), type="error", intermediate_steps=""
|
||||
)
|
||||
await self.send_json(client_id, error_resp)
|
||||
return
|
||||
# Send a response back to the frontend, if needed
|
||||
response = ChatResponse(
|
||||
sender="bot",
|
||||
message=result or "",
|
||||
intermediate_steps=intermediate_steps or "",
|
||||
type="end",
|
||||
)
|
||||
await self.chat_history.add_message(client_id, response)
|
||||
|
||||
async def handle_websocket(self, client_id: str, websocket: WebSocket):
|
||||
await self.connect(client_id, websocket)
|
||||
|
||||
try:
|
||||
chat_history = self.chat_history.get_history(client_id)
|
||||
await websocket.send_json(json.dumps(chat_history))
|
||||
|
||||
while True:
|
||||
json_payload = await websocket.receive_json()
|
||||
payload = json.loads(json_payload)
|
||||
with self.cache_manager.set_client_id(client_id):
|
||||
await self.process_message(client_id, payload)
|
||||
except Exception as e:
|
||||
# Handle any exceptions that might occur
|
||||
print(f"Error: {e}")
|
||||
finally:
|
||||
self.disconnect(client_id)
|
||||
|
||||
|
||||
async def process_graph(
|
||||
graph_data: Dict,
|
||||
is_first_message: bool,
|
||||
chat_message: ChatMessage,
|
||||
websocket: WebSocket,
|
||||
):
|
||||
langchain_object = load_or_build_langchain_object(graph_data, is_first_message)
|
||||
langchain_object = try_setting_streaming_options(langchain_object, websocket)
|
||||
logger.debug("Loaded langchain object")
|
||||
|
||||
if langchain_object is None:
|
||||
# Raise user facing error
|
||||
raise ValueError(
|
||||
"There was an error loading the langchain_object. Please, check all the nodes and try again."
|
||||
)
|
||||
|
||||
# Generate result and thought
|
||||
try:
|
||||
logger.debug("Generating result and thought")
|
||||
result, intermediate_steps = await async_get_result_and_steps(
|
||||
langchain_object, chat_message.message or ""
|
||||
)
|
||||
logger.debug("Generated result and intermediate_steps")
|
||||
return result, intermediate_steps
|
||||
except Exception as e:
|
||||
# Log stack trace
|
||||
logger.exception(e)
|
||||
raise e
|
||||
|
||||
|
||||
def try_setting_streaming_options(langchain_object, websocket):
|
||||
# If the LLM type is OpenAI or ChatOpenAI,
|
||||
# set streaming to True
|
||||
# First we need to find the LLM
|
||||
llm = None
|
||||
if hasattr(langchain_object, "llm"):
|
||||
llm = langchain_object.llm
|
||||
elif hasattr(langchain_object, "llm_chain") and hasattr(
|
||||
langchain_object.llm_chain, "llm"
|
||||
):
|
||||
llm = langchain_object.llm_chain.llm
|
||||
if isinstance(llm, (OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI)):
|
||||
llm.streaming = bool(hasattr(llm, "streaming"))
|
||||
|
||||
if hasattr(langchain_object, "callback_manager"):
|
||||
stream_handler = StreamingLLMCallbackHandler(websocket)
|
||||
stream_manager = AsyncCallbackManager([stream_handler])
|
||||
langchain_object.callback_manager = stream_manager
|
||||
return langchain_object
|
||||
|
||||
|
||||
def pil_to_base64(image: Image) -> str:
|
||||
buffered = BytesIO()
|
||||
image.save(buffered, format="PNG")
|
||||
img_str = base64.b64encode(buffered.getvalue())
|
||||
return img_str.decode("utf-8")
|
||||
42
src/backend/langflow/api/schemas.py
Normal file
42
src/backend/langflow/api/schemas.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
from typing import Any, Union
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
"""Chat message schema."""
|
||||
|
||||
sender: str
|
||||
message: Union[str, None] = None
|
||||
|
||||
@validator("sender")
|
||||
def sender_must_be_bot_or_you(cls, v):
|
||||
if v not in ["bot", "you"]:
|
||||
raise ValueError("sender must be bot or you")
|
||||
return v
|
||||
|
||||
|
||||
class ChatResponse(ChatMessage):
|
||||
"""Chat response schema."""
|
||||
|
||||
intermediate_steps: str
|
||||
type: str
|
||||
|
||||
@validator("type")
|
||||
def validate_message_type(cls, v):
|
||||
if v not in ["start", "stream", "end", "error", "info", "file"]:
|
||||
raise ValueError("type must be start, stream, end, error, info, or file")
|
||||
return v
|
||||
|
||||
|
||||
class FileResponse(ChatMessage):
|
||||
"""File response schema."""
|
||||
|
||||
data: Any
|
||||
data_type: str
|
||||
type: str = "file"
|
||||
|
||||
@validator("data_type")
|
||||
def validate_data_type(cls, v):
|
||||
if v not in ["image", "csv"]:
|
||||
raise ValueError("data_type must be image or csv")
|
||||
return v
|
||||
1
src/backend/langflow/cache/__init__.py
vendored
1
src/backend/langflow/cache/__init__.py
vendored
|
|
@ -0,0 +1 @@
|
|||
from langflow.cache.manager import cache_manager # noqa
|
||||
|
|
@ -2,13 +2,18 @@ import base64
|
|||
import contextlib
|
||||
import functools
|
||||
import hashlib
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from PIL import Image
|
||||
import dill
|
||||
import pandas as pd # type: ignore
|
||||
|
||||
import dill # type: ignore
|
||||
CACHE = {}
|
||||
|
||||
|
||||
def create_cache_folder(func):
|
||||
126
src/backend/langflow/cache/manager.py
vendored
Normal file
126
src/backend/langflow/cache/manager.py
vendored
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
from contextlib import contextmanager
|
||||
from typing import Any, Awaitable, Callable, List
|
||||
from PIL import Image
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class Subject:
|
||||
"""Base class for implementing the observer pattern."""
|
||||
|
||||
def __init__(self):
|
||||
self.observers: List[Callable[[], None]] = []
|
||||
|
||||
def attach(self, observer: Callable[[], None]):
|
||||
"""Attach an observer to the subject."""
|
||||
self.observers.append(observer)
|
||||
|
||||
def detach(self, observer: Callable[[], None]):
|
||||
"""Detach an observer from the subject."""
|
||||
self.observers.remove(observer)
|
||||
|
||||
def notify(self):
|
||||
"""Notify all observers about an event."""
|
||||
for observer in self.observers:
|
||||
if observer is None:
|
||||
continue
|
||||
observer()
|
||||
|
||||
|
||||
class AsyncSubject:
|
||||
"""Base class for implementing the async observer pattern."""
|
||||
|
||||
def __init__(self):
|
||||
self.observers: List[Callable[[], Awaitable]] = []
|
||||
|
||||
def attach(self, observer: Callable[[], Awaitable]):
|
||||
"""Attach an observer to the subject."""
|
||||
self.observers.append(observer)
|
||||
|
||||
def detach(self, observer: Callable[[], Awaitable]):
|
||||
"""Detach an observer from the subject."""
|
||||
self.observers.remove(observer)
|
||||
|
||||
async def notify(self):
|
||||
"""Notify all observers about an event."""
|
||||
for observer in self.observers:
|
||||
if observer is None:
|
||||
continue
|
||||
await observer()
|
||||
|
||||
|
||||
class CacheManager(Subject):
|
||||
"""Manages cache for different clients and notifies observers on changes."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.CACHE = {}
|
||||
self.current_client_id = None
|
||||
|
||||
@contextmanager
|
||||
def set_client_id(self, client_id: str):
|
||||
"""
|
||||
Context manager to set the current client_id and associated cache.
|
||||
|
||||
Args:
|
||||
client_id (str): The client identifier.
|
||||
"""
|
||||
previous_client_id = self.current_client_id
|
||||
self.current_client_id = client_id
|
||||
self.current_cache = self.CACHE.setdefault(client_id, {})
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.current_client_id = previous_client_id
|
||||
self.current_cache = self.CACHE.get(self.current_client_id, {})
|
||||
|
||||
def add_pandas(self, name: str, obj: Any):
|
||||
"""
|
||||
Add a pandas DataFrame or Series to the current client's cache.
|
||||
|
||||
Args:
|
||||
name (str): The cache key.
|
||||
obj (Any): The pandas DataFrame or Series object.
|
||||
"""
|
||||
if isinstance(obj, (pd.DataFrame, pd.Series)):
|
||||
self.current_cache[name] = {"obj": obj, "type": "pandas"}
|
||||
self.notify()
|
||||
else:
|
||||
raise ValueError("Object is not a pandas DataFrame or Series")
|
||||
|
||||
def add_image(self, name: str, obj: Any):
|
||||
"""
|
||||
Add a PIL Image to the current client's cache.
|
||||
|
||||
Args:
|
||||
name (str): The cache key.
|
||||
obj (Any): The PIL Image object.
|
||||
"""
|
||||
if isinstance(obj, Image.Image):
|
||||
self.current_cache[name] = {"obj": obj, "type": "image"}
|
||||
self.notify()
|
||||
else:
|
||||
raise ValueError("Object is not a PIL Image")
|
||||
|
||||
def get(self, name: str):
|
||||
"""
|
||||
Get an object from the current client's cache.
|
||||
|
||||
Args:
|
||||
name (str): The cache key.
|
||||
|
||||
Returns:
|
||||
The cached object associated with the given cache key.
|
||||
"""
|
||||
return self.current_cache[name]
|
||||
|
||||
def get_last(self):
|
||||
"""
|
||||
Get the last added item in the current client's cache.
|
||||
|
||||
Returns:
|
||||
The last added item in the cache.
|
||||
"""
|
||||
return list(self.current_cache.values())[-1]
|
||||
|
||||
|
||||
cache_manager = CacheManager()
|
||||
|
|
@ -9,7 +9,7 @@ import warnings
|
|||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langflow.cache import utils as cache_utils
|
||||
from langflow.cache import base as cache_utils
|
||||
from langflow.graph.constants import DIRECT_TYPES
|
||||
from langflow.interface import loading
|
||||
from langflow.interface.listing import ALL_TYPES_DICT
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import io
|
|||
from typing import Any, Dict
|
||||
from chromadb.errors import NotEnoughElementsException
|
||||
|
||||
from langflow.cache.utils import compute_dict_hash, load_cache, memoize_dict
|
||||
from langflow.cache.base import compute_dict_hash, load_cache, memoize_dict
|
||||
from langflow.graph.graph import Graph
|
||||
from langflow.interface import loading
|
||||
from langflow.utils.logger import logger
|
||||
|
|
@ -32,7 +32,7 @@ def load_or_build_langchain_object(data_graph, is_first_message=False):
|
|||
return build_langchain_object_with_caching(data_graph)
|
||||
|
||||
|
||||
@memoize_dict(maxsize=1)
|
||||
@memoize_dict(maxsize=10)
|
||||
def build_langchain_object_with_caching(data_graph):
|
||||
"""
|
||||
Build langchain object from data_graph.
|
||||
|
|
@ -87,7 +87,7 @@ def process_graph(data_graph: Dict[str, Any]):
|
|||
|
||||
# Generate result and thought
|
||||
logger.debug("Generating result and thought")
|
||||
result, thought = get_result_and_thought_using_graph(langchain_object, message)
|
||||
result, thought = get_result_and_steps(langchain_object, message)
|
||||
logger.debug("Generated result and thought")
|
||||
|
||||
# Save langchain_object to cache
|
||||
|
|
@ -118,7 +118,7 @@ def process_graph_cached(data_graph: Dict[str, Any]):
|
|||
|
||||
# Generate result and thought
|
||||
logger.debug("Generating result and thought")
|
||||
result, thought = get_result_and_thought_using_graph(langchain_object, message)
|
||||
result, thought = get_result_and_steps(langchain_object, message)
|
||||
logger.debug("Generated result and thought")
|
||||
return {"result": str(result), "thought": thought.strip()}
|
||||
|
||||
|
|
@ -184,7 +184,7 @@ def fix_memory_inputs(langchain_object):
|
|||
update_memory_keys(langchain_object, possible_new_mem_key)
|
||||
|
||||
|
||||
def get_result_and_thought_using_graph(langchain_object, message: str):
|
||||
def get_result_and_steps(langchain_object, message: str):
|
||||
"""Get result and thought from extracted json"""
|
||||
try:
|
||||
if hasattr(langchain_object, "verbose"):
|
||||
|
|
@ -240,6 +240,61 @@ def get_result_and_thought_using_graph(langchain_object, message: str):
|
|||
return result, thought
|
||||
|
||||
|
||||
async def async_get_result_and_steps(langchain_object, message: str):
|
||||
"""Get result and thought from extracted json"""
|
||||
try:
|
||||
if hasattr(langchain_object, "verbose"):
|
||||
langchain_object.verbose = True
|
||||
chat_input = None
|
||||
memory_key = ""
|
||||
if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
|
||||
memory_key = langchain_object.memory.memory_key
|
||||
|
||||
if hasattr(langchain_object, "input_keys"):
|
||||
for key in langchain_object.input_keys:
|
||||
if key not in [memory_key, "chat_history"]:
|
||||
chat_input = {key: message}
|
||||
else:
|
||||
chat_input = message # type: ignore
|
||||
|
||||
if hasattr(langchain_object, "return_intermediate_steps"):
|
||||
# https://github.com/hwchase17/langchain/issues/2068
|
||||
# Deactivating until we have a frontend solution
|
||||
# to display intermediate steps
|
||||
langchain_object.return_intermediate_steps = False
|
||||
|
||||
fix_memory_inputs(langchain_object)
|
||||
|
||||
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
|
||||
try:
|
||||
if hasattr(langchain_object, "acall"):
|
||||
output = await langchain_object.acall(chat_input)
|
||||
else:
|
||||
output = langchain_object(chat_input)
|
||||
except ValueError as exc:
|
||||
# make the error message more informative
|
||||
logger.debug(f"Error: {str(exc)}")
|
||||
output = langchain_object.run(chat_input)
|
||||
|
||||
intermediate_steps = (
|
||||
output.get("intermediate_steps", []) if isinstance(output, dict) else []
|
||||
)
|
||||
|
||||
result = (
|
||||
output.get(langchain_object.output_keys[0])
|
||||
if isinstance(output, dict)
|
||||
else output
|
||||
)
|
||||
if intermediate_steps:
|
||||
thought = format_intermediate_steps(intermediate_steps)
|
||||
else:
|
||||
thought = output_buffer.getvalue()
|
||||
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Error: {str(exc)}") from exc
|
||||
return result, thought
|
||||
|
||||
|
||||
def get_result_and_thought(extracted_json: Dict[str, Any], message: str):
|
||||
"""Get result and thought from extracted json"""
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||
|
||||
from langflow.api.endpoints import router as endpoints_router
|
||||
from langflow.api.validate import router as validate_router
|
||||
from langflow.api.chat import router as chat_router
|
||||
|
||||
|
||||
def create_app():
|
||||
|
|
@ -23,6 +24,7 @@ def create_app():
|
|||
|
||||
app.include_router(endpoints_router)
|
||||
app.include_router(validate_router)
|
||||
app.include_router(chat_router)
|
||||
return app
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +0,0 @@
|
|||
#! /bin/bash
|
||||
|
||||
poetry remove langchain
|
||||
docker build -t logspace/backend_build -f build.Dockerfile .
|
||||
VERSION=$(toml get --toml-path pyproject.toml tool.poetry.version)
|
||||
docker build --build-arg VERSION=$VERSION -t ibiscp/langflow:$VERSION .
|
||||
docker run -p 5003:80 -d ibiscp/langflow:$VERSION
|
||||
poetry add --editable ../../../langchain
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
FROM node:14-alpine
|
||||
WORKDIR /app
|
||||
COPY . /app
|
||||
RUN npm install
|
||||
RUN npm run build
|
||||
|
|
@ -1,11 +0,0 @@
|
|||
#! /bin/bash
|
||||
|
||||
# Read the contents of the JSON file
|
||||
json=$(cat package.json)
|
||||
|
||||
# Extract the value of the "version" field using jq
|
||||
VERSION=$(echo "$json" | jq -r '.version')
|
||||
|
||||
docker build -t logspace/frontend_build -f build.Dockerfile .
|
||||
docker build --build-arg VERSION=$VERSION -t ibiscp/langflow_frontend:$VERSION .
|
||||
docker push ibiscp/langflow_frontend:$VERSION
|
||||
|
|
@ -1,4 +1,8 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
|
@ -21,6 +25,15 @@ def get_text():
|
|||
"""
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
async def async_client() -> AsyncGenerator:
|
||||
from langflow.main import create_app
|
||||
|
||||
app = create_app()
|
||||
async with AsyncClient(app=app, base_url="http://testserver") as client:
|
||||
yield client
|
||||
|
||||
|
||||
# Create client fixture for FastAPI
|
||||
@pytest.fixture(scope="module")
|
||||
def client():
|
||||
|
|
@ -30,3 +43,37 @@ def client():
|
|||
|
||||
with TestClient(app) as client:
|
||||
yield client
|
||||
|
||||
|
||||
def get_graph(_type="basic"):
|
||||
"""Get a graph from a json file"""
|
||||
from langflow.graph.graph import Graph
|
||||
|
||||
if _type == "basic":
|
||||
path = pytest.BASIC_EXAMPLE_PATH
|
||||
elif _type == "complex":
|
||||
path = pytest.COMPLEX_EXAMPLE_PATH
|
||||
elif _type == "openapi":
|
||||
path = pytest.OPENAPI_EXAMPLE_PATH
|
||||
|
||||
with open(path, "r") as f:
|
||||
flow_graph = json.load(f)
|
||||
data_graph = flow_graph["data"]
|
||||
nodes = data_graph["nodes"]
|
||||
edges = data_graph["edges"]
|
||||
return Graph(nodes, edges)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def basic_graph():
|
||||
return get_graph()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def complex_graph():
|
||||
return get_graph("complex")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_graph():
|
||||
return get_graph("openapi")
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import tempfile
|
|||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from langflow.cache.utils import PREFIX, save_cache
|
||||
from langflow.cache.base import PREFIX, save_cache
|
||||
from langflow.interface.run import load_langchain_object
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from langflow.graph.nodes import (
|
|||
ToolNode,
|
||||
WrapperNode,
|
||||
)
|
||||
from langflow.interface.run import get_result_and_thought_using_graph
|
||||
from langflow.interface.run import get_result_and_steps
|
||||
from langflow.utils.payload import build_json, get_root_node
|
||||
|
||||
# Test cases for the graph module
|
||||
|
|
@ -24,38 +24,6 @@ from langflow.utils.payload import build_json, get_root_node
|
|||
# BASIC_EXAMPLE_PATH, COMPLEX_EXAMPLE_PATH, OPENAPI_EXAMPLE_PATH
|
||||
|
||||
|
||||
def get_graph(_type="basic"):
|
||||
"""Get a graph from a json file"""
|
||||
if _type == "basic":
|
||||
path = pytest.BASIC_EXAMPLE_PATH
|
||||
elif _type == "complex":
|
||||
path = pytest.COMPLEX_EXAMPLE_PATH
|
||||
elif _type == "openapi":
|
||||
path = pytest.OPENAPI_EXAMPLE_PATH
|
||||
|
||||
with open(path, "r") as f:
|
||||
flow_graph = json.load(f)
|
||||
data_graph = flow_graph["data"]
|
||||
nodes = data_graph["nodes"]
|
||||
edges = data_graph["edges"]
|
||||
return Graph(nodes, edges)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def basic_graph():
|
||||
return get_graph()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def complex_graph():
|
||||
return get_graph("complex")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_graph():
|
||||
return get_graph("openapi")
|
||||
|
||||
|
||||
def get_node_by_type(graph, node_type: Type[Node]) -> Union[Node, None]:
|
||||
"""Get a node by type"""
|
||||
return next((node for node in graph.nodes if isinstance(node, node_type)), None)
|
||||
|
|
@ -441,7 +409,7 @@ def test_get_result_and_thought(basic_graph):
|
|||
# now build again and check if FakeListLLM was used
|
||||
|
||||
# Get the result and thought
|
||||
result, thought = get_result_and_thought_using_graph(langchain_object, message)
|
||||
result, thought = get_result_and_steps(langchain_object, message)
|
||||
# The result should be a str
|
||||
assert isinstance(result, str)
|
||||
# The thought should be a Thought
|
||||
|
|
|
|||
47
tests/test_websocket.py
Normal file
47
tests/test_websocket.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
import json
|
||||
from unittest.mock import patch
|
||||
from langflow.api.schemas import ChatMessage
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
def test_websocket_connection(client: TestClient):
|
||||
with client.websocket_connect("/ws/test_client") as websocket:
|
||||
assert websocket.scope["client"] == ["testclient", 50000]
|
||||
assert websocket.scope["path"] == "/ws/test_client"
|
||||
|
||||
|
||||
def test_chat_history(client: TestClient):
|
||||
chat_history = []
|
||||
|
||||
# Mock the process_graph function to return a specific value
|
||||
with patch("langflow.api.chat_manager.process_graph") as mock_process_graph:
|
||||
mock_process_graph.return_value = ("Hello, I'm a mock response!", "")
|
||||
|
||||
with client.websocket_connect("/ws/test_client") as websocket:
|
||||
# First message should be the history
|
||||
history = websocket.receive_json()
|
||||
assert json.loads(history) == [] # Empty history
|
||||
# Send a message
|
||||
payload = {"message": "Hello"}
|
||||
websocket.send_json(json.dumps(payload))
|
||||
|
||||
# Receive the response from the server
|
||||
response = websocket.receive_json()
|
||||
assert json.loads(response) == {
|
||||
"sender": "bot",
|
||||
"message": None,
|
||||
"intermediate_steps": "",
|
||||
"type": "start",
|
||||
}
|
||||
# Send another message
|
||||
payload = {"message": "How are you?"}
|
||||
websocket.send_json(json.dumps(payload))
|
||||
|
||||
# Receive the response from the server
|
||||
response = websocket.receive_json()
|
||||
assert json.loads(response) == {
|
||||
"sender": "bot",
|
||||
"message": "Hello, I'm a mock response!",
|
||||
"intermediate_steps": "",
|
||||
"type": "end",
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue