Merge remote-tracking branch 'origin/validation_fix' into db
This commit is contained in:
commit
3920eb50d6
26 changed files with 1041 additions and 330 deletions
|
|
@ -1,26 +1,119 @@
|
|||
import json
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
HTTPException,
|
||||
WebSocket,
|
||||
WebSocketDisconnect,
|
||||
WebSocketException,
|
||||
status,
|
||||
)
|
||||
from fastapi.responses import StreamingResponse
|
||||
from langflow.api.v1.schemas import BuiltResponse, InitResponse
|
||||
|
||||
from langflow.chat.manager import ChatManager
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.utils.logger import logger
|
||||
|
||||
router = APIRouter(tags=["Chat"])
|
||||
chat_manager = ChatManager()
|
||||
flow_data_store = {}
|
||||
|
||||
|
||||
@router.websocket("/chat/{client_id}")
|
||||
async def websocket_endpoint(client_id: str, websocket: WebSocket):
|
||||
async def chat(client_id: str, websocket: WebSocket):
|
||||
"""Websocket endpoint for chat."""
|
||||
try:
|
||||
await chat_manager.handle_websocket(client_id, websocket)
|
||||
if client_id in chat_manager.in_memory_cache:
|
||||
await chat_manager.handle_websocket(client_id, websocket)
|
||||
else:
|
||||
message = "Please, build the flow before sending messages"
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason=message)
|
||||
except WebSocketException as exc:
|
||||
logger.error(exc)
|
||||
await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=str(exc))
|
||||
except WebSocketDisconnect as exc:
|
||||
|
||||
|
||||
@router.post("/build/init", response_model=InitResponse, status_code=201)
|
||||
async def init_build(graph_data: dict):
|
||||
"""Initialize the build by storing graph data and returning a unique session ID."""
|
||||
|
||||
try:
|
||||
flow_id = graph_data.get("id")
|
||||
|
||||
flow_data_store[flow_id] = graph_data
|
||||
|
||||
return InitResponse(flowId=flow_id)
|
||||
except Exception as exc:
|
||||
logger.error(exc)
|
||||
await websocket.close(code=status.WS_1000_NORMAL_CLOSURE, reason=str(exc))
|
||||
return HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.get("/build/{flow_id}/status", response_model=BuiltResponse)
|
||||
async def build_status(flow_id: str):
|
||||
"""Check the flow_id is in the flow_data_store."""
|
||||
try:
|
||||
built = flow_id in flow_data_store and not isinstance(
|
||||
flow_data_store[flow_id], dict
|
||||
)
|
||||
|
||||
return BuiltResponse(
|
||||
built=built,
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(exc)
|
||||
return HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.get("/build/stream/{flow_id}", response_class=StreamingResponse)
|
||||
async def stream_build(flow_id: str):
|
||||
"""Stream the build process based on stored flow data."""
|
||||
|
||||
async def event_stream(flow_id):
|
||||
final_response = json.dumps({"end_of_stream": True})
|
||||
try:
|
||||
if flow_id not in flow_data_store:
|
||||
error_message = "Invalid session ID"
|
||||
yield f"data: {json.dumps({'error': error_message})}\n\n"
|
||||
return
|
||||
|
||||
graph_data = flow_data_store[flow_id].get("data")
|
||||
|
||||
if not graph_data:
|
||||
error_message = "No data provided"
|
||||
yield f"data: {json.dumps({'error': error_message})}\n\n"
|
||||
return
|
||||
|
||||
logger.debug("Building langchain object")
|
||||
graph = Graph.from_payload(graph_data)
|
||||
for node in graph.generator_build():
|
||||
try:
|
||||
node.build()
|
||||
params = node._built_object_repr()
|
||||
valid = True
|
||||
logger.debug(
|
||||
f"Building node {params[:50]}{'...' if len(params) > 50 else ''}"
|
||||
)
|
||||
except Exception as exc:
|
||||
params = str(exc)
|
||||
valid = False
|
||||
|
||||
response = json.dumps(
|
||||
{
|
||||
"valid": valid,
|
||||
"params": params,
|
||||
"id": node.id,
|
||||
}
|
||||
)
|
||||
yield f"data: {response}\n\n"
|
||||
|
||||
chat_manager.set_cache(flow_id, graph.build())
|
||||
except Exception:
|
||||
logger.error("Error while building the flow")
|
||||
finally:
|
||||
yield f"data: {final_response}\n\n"
|
||||
|
||||
try:
|
||||
return StreamingResponse(event_stream(flow_id), media_type="text/event-stream")
|
||||
except Exception as exc:
|
||||
logger.error(exc)
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
|
|
|||
|
|
@ -93,3 +93,11 @@ class FlowListCreate(BaseModel):
|
|||
|
||||
class FlowListRead(BaseModel):
|
||||
flows: List[FlowRead]
|
||||
|
||||
|
||||
class InitResponse(BaseModel):
|
||||
flowId: str
|
||||
|
||||
|
||||
class BuiltResponse(BaseModel):
|
||||
built: bool
|
||||
|
|
|
|||
8
src/backend/langflow/cache/__init__.py
vendored
8
src/backend/langflow/cache/__init__.py
vendored
|
|
@ -1 +1,7 @@
|
|||
from langflow.cache.manager import cache_manager # noqa
|
||||
from langflow.cache.manager import cache_manager
|
||||
from langflow.cache.flow import InMemoryCache
|
||||
|
||||
__all__ = [
|
||||
"cache_manager",
|
||||
"InMemoryCache",
|
||||
]
|
||||
|
|
|
|||
223
src/backend/langflow/cache/base.py
vendored
223
src/backend/langflow/cache/base.py
vendored
|
|
@ -1,154 +1,95 @@
|
|||
import base64
|
||||
import contextlib
|
||||
import functools
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import dill # type: ignore
|
||||
|
||||
CACHE: Dict[str, Any] = {}
|
||||
import abc
|
||||
|
||||
|
||||
def create_cache_folder(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
# Get the destination folder
|
||||
cache_path = Path(tempfile.gettempdir()) / PREFIX
|
||||
|
||||
# Create the destination folder if it doesn't exist
|
||||
os.makedirs(cache_path, exist_ok=True)
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def memoize_dict(maxsize=128):
|
||||
cache = OrderedDict()
|
||||
|
||||
def decorator(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
hashed = compute_dict_hash(args[0])
|
||||
key = (func.__name__, hashed, frozenset(kwargs.items()))
|
||||
if key not in cache:
|
||||
result = func(*args, **kwargs)
|
||||
cache[key] = result
|
||||
if len(cache) > maxsize:
|
||||
cache.popitem(last=False)
|
||||
else:
|
||||
result = cache[key]
|
||||
return result
|
||||
|
||||
def clear_cache():
|
||||
cache.clear()
|
||||
|
||||
wrapper.clear_cache = clear_cache # type: ignore
|
||||
wrapper.cache = cache # type: ignore
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
PREFIX = "langflow_cache"
|
||||
|
||||
|
||||
@create_cache_folder
|
||||
def clear_old_cache_files(max_cache_size: int = 3):
|
||||
cache_dir = Path(tempfile.gettempdir()) / PREFIX
|
||||
cache_files = list(cache_dir.glob("*.dill"))
|
||||
|
||||
if len(cache_files) > max_cache_size:
|
||||
cache_files_sorted_by_mtime = sorted(
|
||||
cache_files, key=lambda x: x.stat().st_mtime, reverse=True
|
||||
)
|
||||
|
||||
for cache_file in cache_files_sorted_by_mtime[max_cache_size:]:
|
||||
with contextlib.suppress(OSError):
|
||||
os.remove(cache_file)
|
||||
|
||||
|
||||
def compute_dict_hash(graph_data):
|
||||
graph_data = filter_json(graph_data)
|
||||
|
||||
cleaned_graph_json = json.dumps(graph_data, sort_keys=True)
|
||||
return hashlib.sha256(cleaned_graph_json.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def filter_json(json_data):
|
||||
filtered_data = json_data.copy()
|
||||
|
||||
# Remove 'viewport' and 'chatHistory' keys
|
||||
if "viewport" in filtered_data:
|
||||
del filtered_data["viewport"]
|
||||
if "chatHistory" in filtered_data:
|
||||
del filtered_data["chatHistory"]
|
||||
|
||||
# Filter nodes
|
||||
if "nodes" in filtered_data:
|
||||
for node in filtered_data["nodes"]:
|
||||
if "position" in node:
|
||||
del node["position"]
|
||||
if "positionAbsolute" in node:
|
||||
del node["positionAbsolute"]
|
||||
if "selected" in node:
|
||||
del node["selected"]
|
||||
if "dragging" in node:
|
||||
del node["dragging"]
|
||||
|
||||
return filtered_data
|
||||
|
||||
|
||||
@create_cache_folder
|
||||
def save_binary_file(content: str, file_name: str, accepted_types: list[str]) -> str:
|
||||
class BaseCache(abc.ABC):
|
||||
"""
|
||||
Save a binary file to the specified folder.
|
||||
|
||||
Args:
|
||||
content: The content of the file as a bytes object.
|
||||
file_name: The name of the file, including its extension.
|
||||
|
||||
Returns:
|
||||
The path to the saved file.
|
||||
Abstract base class for a cache.
|
||||
"""
|
||||
if not any(file_name.endswith(suffix) for suffix in accepted_types):
|
||||
raise ValueError(f"File {file_name} is not accepted")
|
||||
|
||||
# Get the destination folder
|
||||
cache_path = Path(tempfile.gettempdir()) / PREFIX
|
||||
if not content:
|
||||
raise ValueError("Please, reload the file in the loader.")
|
||||
data = content.split(",")[1]
|
||||
decoded_bytes = base64.b64decode(data)
|
||||
@abc.abstractmethod
|
||||
def get(self, key):
|
||||
"""
|
||||
Retrieve an item from the cache.
|
||||
|
||||
# Create the full file path
|
||||
file_path = os.path.join(cache_path, file_name)
|
||||
Args:
|
||||
key: The key of the item to retrieve.
|
||||
|
||||
# Save the binary content to the file
|
||||
with open(file_path, "wb") as file:
|
||||
file.write(decoded_bytes)
|
||||
Returns:
|
||||
The value associated with the key, or None if the key is not found.
|
||||
"""
|
||||
pass
|
||||
|
||||
return file_path
|
||||
@abc.abstractmethod
|
||||
def set(self, key, value):
|
||||
"""
|
||||
Add an item to the cache.
|
||||
|
||||
Args:
|
||||
key: The key of the item.
|
||||
value: The value to cache.
|
||||
"""
|
||||
pass
|
||||
|
||||
@create_cache_folder
|
||||
def save_cache(hash_val: str, chat_data, clean_old_cache_files: bool):
|
||||
cache_path = Path(tempfile.gettempdir()) / PREFIX / f"{hash_val}.dill"
|
||||
with cache_path.open("wb") as cache_file:
|
||||
dill.dump(chat_data, cache_file)
|
||||
@abc.abstractmethod
|
||||
def delete(self, key):
|
||||
"""
|
||||
Remove an item from the cache.
|
||||
|
||||
if clean_old_cache_files:
|
||||
clear_old_cache_files()
|
||||
Args:
|
||||
key: The key of the item to remove.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def clear(self):
|
||||
"""
|
||||
Clear all items from the cache.
|
||||
"""
|
||||
pass
|
||||
|
||||
@create_cache_folder
|
||||
def load_cache(hash_val):
|
||||
cache_path = Path(tempfile.gettempdir()) / PREFIX / f"{hash_val}.dill"
|
||||
if cache_path.exists():
|
||||
with cache_path.open("rb") as cache_file:
|
||||
return dill.load(cache_file)
|
||||
return None
|
||||
@abc.abstractmethod
|
||||
def __contains__(self, key):
|
||||
"""
|
||||
Check if the key is in the cache.
|
||||
|
||||
Args:
|
||||
key: The key of the item to check.
|
||||
|
||||
Returns:
|
||||
True if the key is in the cache, False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def __getitem__(self, key):
|
||||
"""
|
||||
Retrieve an item from the cache using the square bracket notation.
|
||||
|
||||
Args:
|
||||
key: The key of the item to retrieve.
|
||||
|
||||
Returns:
|
||||
The value associated with the key, or None if the key is not found.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def __setitem__(self, key, value):
|
||||
"""
|
||||
Add an item to the cache using the square bracket notation.
|
||||
|
||||
Args:
|
||||
key: The key of the item.
|
||||
value: The value to cache.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def __delitem__(self, key):
|
||||
"""
|
||||
Remove an item from the cache using the square bracket notation.
|
||||
|
||||
Args:
|
||||
key: The key of the item to remove.
|
||||
"""
|
||||
pass
|
||||
|
|
|
|||
146
src/backend/langflow/cache/flow.py
vendored
Normal file
146
src/backend/langflow/cache/flow.py
vendored
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
import threading
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
|
||||
from langflow.cache.base import BaseCache
|
||||
|
||||
|
||||
class InMemoryCache(BaseCache):
|
||||
"""
|
||||
A simple in-memory cache using an OrderedDict.
|
||||
|
||||
This cache supports setting a maximum size and expiration time for cached items.
|
||||
When the cache is full, it uses a Least Recently Used (LRU) eviction policy.
|
||||
Thread-safe using a threading Lock.
|
||||
|
||||
Attributes:
|
||||
max_size (int, optional): Maximum number of items to store in the cache.
|
||||
expiration_time (int, optional): Time in seconds after which a cached item expires. Default is 1 hour.
|
||||
|
||||
Example:
|
||||
|
||||
cache = InMemoryCache(max_size=3, expiration_time=5)
|
||||
|
||||
# setting cache values
|
||||
cache.set("a", 1)
|
||||
cache.set("b", 2)
|
||||
cache["c"] = 3
|
||||
|
||||
# getting cache values
|
||||
a = cache.get("a")
|
||||
b = cache["b"]
|
||||
"""
|
||||
|
||||
def __init__(self, max_size=None, expiration_time=60 * 60):
|
||||
"""
|
||||
Initialize a new InMemoryCache instance.
|
||||
|
||||
Args:
|
||||
max_size (int, optional): Maximum number of items to store in the cache.
|
||||
expiration_time (int, optional): Time in seconds after which a cached item expires. Default is 1 hour.
|
||||
"""
|
||||
self._cache = OrderedDict()
|
||||
self._lock = threading.Lock()
|
||||
self.max_size = max_size
|
||||
self.expiration_time = expiration_time
|
||||
|
||||
def get(self, key):
|
||||
"""
|
||||
Retrieve an item from the cache.
|
||||
|
||||
Args:
|
||||
key: The key of the item to retrieve.
|
||||
|
||||
Returns:
|
||||
The value associated with the key, or None if the key is not found or the item has expired.
|
||||
"""
|
||||
with self._lock:
|
||||
if key in self._cache:
|
||||
item = self._cache.pop(key)
|
||||
if (
|
||||
self.expiration_time is None
|
||||
or time.time() - item["time"] < self.expiration_time
|
||||
):
|
||||
# Move the key to the end to make it recently used
|
||||
self._cache[key] = item
|
||||
return item["value"]
|
||||
else:
|
||||
self.delete(key)
|
||||
return None
|
||||
|
||||
def set(self, key, value):
|
||||
"""
|
||||
Add an item to the cache.
|
||||
|
||||
If the cache is full, the least recently used item is evicted.
|
||||
|
||||
Args:
|
||||
key: The key of the item.
|
||||
value: The value to cache.
|
||||
"""
|
||||
with self._lock:
|
||||
if key in self._cache:
|
||||
# Remove existing key before re-inserting to update order
|
||||
self.delete(key)
|
||||
elif self.max_size and len(self._cache) >= self.max_size:
|
||||
# Remove least recently used item
|
||||
self._cache.popitem(last=False)
|
||||
self._cache[key] = {"value": value, "time": time.time()}
|
||||
|
||||
def get_or_set(self, key, value):
|
||||
"""
|
||||
Retrieve an item from the cache. If the item does not exist, set it with the provided value.
|
||||
|
||||
Args:
|
||||
key: The key of the item.
|
||||
value: The value to cache if the item doesn't exist.
|
||||
|
||||
Returns:
|
||||
The cached value associated with the key.
|
||||
"""
|
||||
with self._lock:
|
||||
if key in self._cache:
|
||||
return self.get(key)
|
||||
self.set(key, value)
|
||||
return value
|
||||
|
||||
def delete(self, key):
|
||||
"""
|
||||
Remove an item from the cache.
|
||||
|
||||
Args:
|
||||
key: The key of the item to remove.
|
||||
"""
|
||||
# with self._lock:
|
||||
self._cache.pop(key, None)
|
||||
|
||||
def clear(self):
|
||||
"""
|
||||
Clear all items from the cache.
|
||||
"""
|
||||
with self._lock:
|
||||
self._cache.clear()
|
||||
|
||||
def __contains__(self, key):
|
||||
"""Check if the key is in the cache."""
|
||||
return key in self._cache
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Retrieve an item from the cache using the square bracket notation."""
|
||||
return self.get(key)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
"""Add an item to the cache using the square bracket notation."""
|
||||
self.set(key, value)
|
||||
|
||||
def __delitem__(self, key):
|
||||
"""Remove an item from the cache using the square bracket notation."""
|
||||
self.delete(key)
|
||||
|
||||
def __len__(self):
|
||||
"""Return the number of items in the cache."""
|
||||
return len(self._cache)
|
||||
|
||||
def __repr__(self):
|
||||
"""Return a string representation of the InMemoryCache instance."""
|
||||
return f"InMemoryCache(max_size={self.max_size}, expiration_time={self.expiration_time})"
|
||||
6
src/backend/langflow/cache/manager.py
vendored
6
src/backend/langflow/cache/manager.py
vendored
|
|
@ -54,7 +54,7 @@ class CacheManager(Subject):
|
|||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.CACHE = {}
|
||||
self._cache = {}
|
||||
self.current_client_id = None
|
||||
self.current_cache = {}
|
||||
|
||||
|
|
@ -68,12 +68,12 @@ class CacheManager(Subject):
|
|||
"""
|
||||
previous_client_id = self.current_client_id
|
||||
self.current_client_id = client_id
|
||||
self.current_cache = self.CACHE.setdefault(client_id, {})
|
||||
self.current_cache = self._cache.setdefault(client_id, {})
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.current_client_id = previous_client_id
|
||||
self.current_cache = self.CACHE.get(self.current_client_id, {})
|
||||
self.current_cache = self._cache.get(self.current_client_id, {})
|
||||
|
||||
def add(self, name: str, obj: Any, obj_type: str, extension: Optional[str] = None):
|
||||
"""
|
||||
|
|
|
|||
154
src/backend/langflow/cache/utils.py
vendored
Normal file
154
src/backend/langflow/cache/utils.py
vendored
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
import base64
|
||||
import contextlib
|
||||
import functools
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import dill # type: ignore
|
||||
|
||||
CACHE: Dict[str, Any] = {}
|
||||
|
||||
|
||||
def create_cache_folder(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
# Get the destination folder
|
||||
cache_path = Path(tempfile.gettempdir()) / PREFIX
|
||||
|
||||
# Create the destination folder if it doesn't exist
|
||||
os.makedirs(cache_path, exist_ok=True)
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def memoize_dict(maxsize=128):
|
||||
cache = OrderedDict()
|
||||
|
||||
def decorator(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
hashed = compute_dict_hash(args[0])
|
||||
key = (func.__name__, hashed, frozenset(kwargs.items()))
|
||||
if key not in cache:
|
||||
result = func(*args, **kwargs)
|
||||
cache[key] = result
|
||||
if len(cache) > maxsize:
|
||||
cache.popitem(last=False)
|
||||
else:
|
||||
result = cache[key]
|
||||
return result
|
||||
|
||||
def clear_cache():
|
||||
cache.clear()
|
||||
|
||||
wrapper.clear_cache = clear_cache # type: ignore
|
||||
wrapper.cache = cache # type: ignore
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
PREFIX = "langflow_cache"
|
||||
|
||||
|
||||
@create_cache_folder
|
||||
def clear_old_cache_files(max_cache_size: int = 3):
|
||||
cache_dir = Path(tempfile.gettempdir()) / PREFIX
|
||||
cache_files = list(cache_dir.glob("*.dill"))
|
||||
|
||||
if len(cache_files) > max_cache_size:
|
||||
cache_files_sorted_by_mtime = sorted(
|
||||
cache_files, key=lambda x: x.stat().st_mtime, reverse=True
|
||||
)
|
||||
|
||||
for cache_file in cache_files_sorted_by_mtime[max_cache_size:]:
|
||||
with contextlib.suppress(OSError):
|
||||
os.remove(cache_file)
|
||||
|
||||
|
||||
def compute_dict_hash(graph_data):
|
||||
graph_data = filter_json(graph_data)
|
||||
|
||||
cleaned_graph_json = json.dumps(graph_data, sort_keys=True)
|
||||
return hashlib.sha256(cleaned_graph_json.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def filter_json(json_data):
|
||||
filtered_data = json_data.copy()
|
||||
|
||||
# Remove 'viewport' and 'chatHistory' keys
|
||||
if "viewport" in filtered_data:
|
||||
del filtered_data["viewport"]
|
||||
if "chatHistory" in filtered_data:
|
||||
del filtered_data["chatHistory"]
|
||||
|
||||
# Filter nodes
|
||||
if "nodes" in filtered_data:
|
||||
for node in filtered_data["nodes"]:
|
||||
if "position" in node:
|
||||
del node["position"]
|
||||
if "positionAbsolute" in node:
|
||||
del node["positionAbsolute"]
|
||||
if "selected" in node:
|
||||
del node["selected"]
|
||||
if "dragging" in node:
|
||||
del node["dragging"]
|
||||
|
||||
return filtered_data
|
||||
|
||||
|
||||
@create_cache_folder
|
||||
def save_binary_file(content: str, file_name: str, accepted_types: list[str]) -> str:
|
||||
"""
|
||||
Save a binary file to the specified folder.
|
||||
|
||||
Args:
|
||||
content: The content of the file as a bytes object.
|
||||
file_name: The name of the file, including its extension.
|
||||
|
||||
Returns:
|
||||
The path to the saved file.
|
||||
"""
|
||||
if not any(file_name.endswith(suffix) for suffix in accepted_types):
|
||||
raise ValueError(f"File {file_name} is not accepted")
|
||||
|
||||
# Get the destination folder
|
||||
cache_path = Path(tempfile.gettempdir()) / PREFIX
|
||||
if not content:
|
||||
raise ValueError("Please, reload the file in the loader.")
|
||||
data = content.split(",")[1]
|
||||
decoded_bytes = base64.b64decode(data)
|
||||
|
||||
# Create the full file path
|
||||
file_path = os.path.join(cache_path, file_name)
|
||||
|
||||
# Save the binary content to the file
|
||||
with open(file_path, "wb") as file:
|
||||
file.write(decoded_bytes)
|
||||
|
||||
return file_path
|
||||
|
||||
|
||||
@create_cache_folder
|
||||
def save_cache(hash_val: str, chat_data, clean_old_cache_files: bool):
|
||||
cache_path = Path(tempfile.gettempdir()) / PREFIX / f"{hash_val}.dill"
|
||||
with cache_path.open("wb") as cache_file:
|
||||
dill.dump(chat_data, cache_file)
|
||||
|
||||
if clean_old_cache_files:
|
||||
clear_old_cache_files()
|
||||
|
||||
|
||||
@create_cache_folder
|
||||
def load_cache(hash_val):
|
||||
cache_path = Path(tempfile.gettempdir()) / PREFIX / f"{hash_val}.dill"
|
||||
if cache_path.exists():
|
||||
with cache_path.open("rb") as cache_file:
|
||||
return dill.load(cache_file)
|
||||
return None
|
||||
|
|
@ -10,7 +10,9 @@ from langflow.utils.logger import logger
|
|||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Dict, List
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langflow.cache.flow import InMemoryCache
|
||||
|
||||
|
||||
class ChatHistory(Subject):
|
||||
|
|
@ -46,6 +48,7 @@ class ChatManager:
|
|||
self.chat_history = ChatHistory()
|
||||
self.cache_manager = cache_manager
|
||||
self.cache_manager.attach(self.update)
|
||||
self.in_memory_cache = InMemoryCache()
|
||||
|
||||
def on_chat_history_update(self):
|
||||
"""Send the last chat message to the client."""
|
||||
|
|
@ -99,24 +102,30 @@ class ChatManager:
|
|||
websocket = self.active_connections[client_id]
|
||||
await websocket.send_json(message.dict())
|
||||
|
||||
async def process_message(self, client_id: str, payload: Dict):
|
||||
async def close_connection(self, client_id: str, code: int, reason: str):
|
||||
if websocket := self.active_connections[client_id]:
|
||||
await websocket.close(code=code, reason=reason)
|
||||
self.disconnect(client_id)
|
||||
|
||||
async def process_message(
|
||||
self, client_id: str, payload: Dict, langchain_object: Any
|
||||
):
|
||||
# Process the graph data and chat message
|
||||
chat_message = payload.pop("message", "")
|
||||
chat_message = ChatMessage(message=chat_message)
|
||||
self.chat_history.add_message(client_id, chat_message)
|
||||
|
||||
graph_data = payload
|
||||
# graph_data = payload
|
||||
start_resp = ChatResponse(message=None, type="start", intermediate_steps="")
|
||||
await self.send_json(client_id, start_resp)
|
||||
|
||||
is_first_message = len(self.chat_history.get_history(client_id=client_id)) <= 1
|
||||
# is_first_message = len(self.chat_history.get_history(client_id=client_id)) <= 1
|
||||
# Generate result and thought
|
||||
try:
|
||||
logger.debug("Generating result and thought")
|
||||
|
||||
result, intermediate_steps = await process_graph(
|
||||
graph_data=graph_data,
|
||||
is_first_message=is_first_message,
|
||||
langchain_object=langchain_object,
|
||||
chat_message=chat_message,
|
||||
websocket=self.active_connections[client_id],
|
||||
)
|
||||
|
|
@ -149,6 +158,14 @@ class ChatManager:
|
|||
await self.send_json(client_id, response)
|
||||
self.chat_history.add_message(client_id, response)
|
||||
|
||||
def set_cache(self, client_id: str, langchain_object: Any) -> bool:
|
||||
"""
|
||||
Set the cache for a client.
|
||||
"""
|
||||
|
||||
self.in_memory_cache.set(client_id, langchain_object)
|
||||
return client_id in self.in_memory_cache
|
||||
|
||||
async def handle_websocket(self, client_id: str, websocket: WebSocket):
|
||||
await self.connect(client_id, websocket)
|
||||
|
||||
|
|
@ -169,22 +186,24 @@ class ChatManager:
|
|||
continue
|
||||
|
||||
with self.cache_manager.set_client_id(client_id):
|
||||
await self.process_message(client_id, payload)
|
||||
langchain_object = self.in_memory_cache.get(client_id)
|
||||
await self.process_message(client_id, payload, langchain_object)
|
||||
|
||||
except Exception as e:
|
||||
# Handle any exceptions that might occur
|
||||
logger.exception(e)
|
||||
# send a message to the client
|
||||
await self.active_connections[client_id].close(
|
||||
code=status.WS_1011_INTERNAL_ERROR, reason=str(e)[:120]
|
||||
logger.error(e)
|
||||
await self.close_connection(
|
||||
client_id=client_id,
|
||||
code=status.WS_1011_INTERNAL_ERROR,
|
||||
reason=str(e)[:120],
|
||||
)
|
||||
self.disconnect(client_id)
|
||||
finally:
|
||||
try:
|
||||
connection = self.active_connections.get(client_id)
|
||||
if connection:
|
||||
await connection.close(code=1000, reason="Client disconnected")
|
||||
self.disconnect(client_id)
|
||||
await self.close_connection(
|
||||
client_id=client_id,
|
||||
code=status.WS_1000_NORMAL_CLOSURE,
|
||||
reason="Client disconnected",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
logger.error(e)
|
||||
self.disconnect(client_id)
|
||||
|
|
|
|||
|
|
@ -1,23 +1,15 @@
|
|||
from fastapi import WebSocket
|
||||
from langflow.api.v1.schemas import ChatMessage
|
||||
from langflow.processing.process import (
|
||||
load_or_build_langchain_object,
|
||||
)
|
||||
from langflow.processing.base import get_result_and_steps
|
||||
from langflow.interface.utils import try_setting_streaming_options
|
||||
from langflow.utils.logger import logger
|
||||
|
||||
|
||||
from typing import Dict
|
||||
|
||||
|
||||
async def process_graph(
|
||||
graph_data: Dict,
|
||||
is_first_message: bool,
|
||||
langchain_object,
|
||||
chat_message: ChatMessage,
|
||||
websocket: WebSocket,
|
||||
):
|
||||
langchain_object = load_or_build_langchain_object(graph_data, is_first_message)
|
||||
langchain_object = try_setting_streaming_options(langchain_object, websocket)
|
||||
logger.debug("Loaded langchain object")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Dict, List, Type, Union
|
||||
from typing import Dict, Generator, List, Type, Union
|
||||
|
||||
from langflow.graph.edge.base import Edge
|
||||
from langflow.graph.graph.constants import VERTEX_TYPE_MAP
|
||||
|
|
@ -106,6 +106,47 @@ class Graph:
|
|||
raise ValueError("No root node found")
|
||||
return root_node.build()
|
||||
|
||||
def topological_sort(self) -> List[Vertex]:
|
||||
"""
|
||||
Performs a topological sort of the vertices in the graph.
|
||||
|
||||
Returns:
|
||||
List[Vertex]: A list of vertices in topological order.
|
||||
|
||||
Raises:
|
||||
ValueError: If the graph contains a cycle.
|
||||
"""
|
||||
# States: 0 = unvisited, 1 = visiting, 2 = visited
|
||||
state = {node: 0 for node in self.nodes}
|
||||
sorted_vertices = []
|
||||
|
||||
def dfs(node):
|
||||
if state[node] == 1:
|
||||
# We have a cycle
|
||||
raise ValueError(
|
||||
"Graph contains a cycle, cannot perform topological sort"
|
||||
)
|
||||
if state[node] == 0:
|
||||
state[node] = 1
|
||||
for edge in node.edges:
|
||||
if edge.source == node:
|
||||
dfs(edge.target)
|
||||
state[node] = 2
|
||||
sorted_vertices.append(node)
|
||||
|
||||
# Visit each node
|
||||
for node in self.nodes:
|
||||
if state[node] == 0:
|
||||
dfs(node)
|
||||
|
||||
return list(reversed(sorted_vertices))
|
||||
|
||||
def generator_build(self) -> Generator:
|
||||
"""Builds each vertex in the graph and yields it."""
|
||||
sorted_vertices = self.topological_sort()
|
||||
logger.info("Sorted vertices: %s", sorted_vertices)
|
||||
yield from sorted_vertices
|
||||
|
||||
def get_node_neighbors(self, node: Vertex) -> Dict[Vertex, int]:
|
||||
"""Returns the neighbors of a node."""
|
||||
neighbors: Dict[Vertex, int] = {}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from langflow.cache import base as cache_utils
|
||||
from langflow.cache import utils as cache_utils
|
||||
from langflow.graph.vertex.constants import DIRECT_TYPES
|
||||
from langflow.interface import loading
|
||||
from langflow.interface.listing import ALL_TYPES_DICT
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from langflow.cache.base import compute_dict_hash, load_cache, memoize_dict
|
||||
from langflow.cache.utils import compute_dict_hash, load_cache, memoize_dict
|
||||
from langflow.graph import Graph
|
||||
from langflow.utils.logger import logger
|
||||
|
||||
|
|
|
|||
|
|
@ -1,26 +1,18 @@
|
|||
import {
|
||||
classNames,
|
||||
nodeColors,
|
||||
nodeIcons,
|
||||
toNormalCase,
|
||||
toTitleCase,
|
||||
} from "../../utils";
|
||||
import { classNames, nodeColors, nodeIcons, toTitleCase } from "../../utils";
|
||||
import ParameterComponent from "./components/parameterComponent";
|
||||
import { typesContext } from "../../contexts/typesContext";
|
||||
import { useContext, useState, useEffect, useRef, Fragment } from "react";
|
||||
import { useContext, useState, useEffect, useRef } from "react";
|
||||
import { NodeDataType } from "../../types/flow";
|
||||
import { alertContext } from "../../contexts/alertContext";
|
||||
import { PopUpContext } from "../../contexts/popUpContext";
|
||||
import NodeModal from "../../modals/NodeModal";
|
||||
import { useCallback } from "react";
|
||||
import { TabsContext } from "../../contexts/tabsContext";
|
||||
import { debounce } from "../../utils";
|
||||
import Tooltip from "../../components/TooltipComponent";
|
||||
import { NodeToolbar } from "reactflow";
|
||||
import NodeToolbarComponent from "../../pages/FlowPage/components/nodeToolbarComponent";
|
||||
|
||||
import ShadTooltip from "../../components/ShadTooltipComponent";
|
||||
import { postValidateNode } from "../../controllers/API";
|
||||
import { useSSE } from "../../contexts/SSEContext";
|
||||
|
||||
export default function GenericNode({
|
||||
data,
|
||||
selected,
|
||||
|
|
@ -31,46 +23,30 @@ export default function GenericNode({
|
|||
const { setErrorData } = useContext(alertContext);
|
||||
const showError = useRef(true);
|
||||
const { types, deleteNode } = useContext(typesContext);
|
||||
const { openPopUp } = useContext(PopUpContext);
|
||||
const { closePopUp } = useContext(PopUpContext);
|
||||
|
||||
const { closePopUp, openPopUp } = useContext(PopUpContext);
|
||||
|
||||
const Icon = nodeIcons[data.type] || nodeIcons[types[data.type]];
|
||||
const [validationStatus, setValidationStatus] = useState(null);
|
||||
// State for outline color
|
||||
const [isValid, setIsValid] = useState(false);
|
||||
const { reactFlowInstance } = useContext(typesContext);
|
||||
const [params, setParams] = useState([]);
|
||||
const { sseData } = useSSE();
|
||||
|
||||
// useEffect(() => {
|
||||
// if (reactFlowInstance) {
|
||||
// setParams(Object.values(reactFlowInstance.toObject()));
|
||||
// }
|
||||
// }, [save]);
|
||||
|
||||
// New useEffect to watch for changes in sseData and update validation status
|
||||
useEffect(() => {
|
||||
if (reactFlowInstance) {
|
||||
setParams(Object.values(reactFlowInstance.toObject()));
|
||||
const relevantData = sseData[data.id];
|
||||
if (relevantData) {
|
||||
// Extract validation information from relevantData and update the validationStatus state
|
||||
setValidationStatus(relevantData);
|
||||
} else {
|
||||
setValidationStatus(null);
|
||||
}
|
||||
}, []);
|
||||
|
||||
const validateNode = useCallback(
|
||||
debounce(async () => {
|
||||
try {
|
||||
const response = await postValidateNode(
|
||||
data.id,
|
||||
reactFlowInstance.toObject()
|
||||
);
|
||||
|
||||
if (response.status === 200) {
|
||||
let jsonResponseParsed = await JSON.parse(response.data);
|
||||
setValidationStatus(jsonResponseParsed);
|
||||
}
|
||||
} catch (error) {
|
||||
// console.error("Error validating node:", error);
|
||||
setValidationStatus("error");
|
||||
}
|
||||
}, 1000), // Adjust the debounce delay (500ms) as needed
|
||||
[reactFlowInstance, data.id]
|
||||
);
|
||||
useEffect(() => {
|
||||
if (params.length > 0) {
|
||||
validateNode();
|
||||
}
|
||||
}, [params, validateNode]);
|
||||
}, [sseData, data.id]);
|
||||
|
||||
if (!Icon) {
|
||||
if (showError.current) {
|
||||
|
|
|
|||
155
src/frontend/src/components/chatComponent/buildTrigger/index.tsx
Normal file
155
src/frontend/src/components/chatComponent/buildTrigger/index.tsx
Normal file
|
|
@ -0,0 +1,155 @@
|
|||
import { useState, useContext } from "react";
|
||||
import { Transition } from "@headlessui/react";
|
||||
import { Zap } from "lucide-react";
|
||||
import { validateNodes } from "../../../utils";
|
||||
import { FlowType } from "../../../types/flow";
|
||||
import Loading from "../../../components/ui/loading";
|
||||
import { useSSE } from "../../../contexts/SSEContext";
|
||||
import { typesContext } from "../../../contexts/typesContext";
|
||||
import { alertContext } from "../../../contexts/alertContext";
|
||||
import { postBuildInit } from "../../../controllers/API";
|
||||
|
||||
export default function BuildTrigger({
|
||||
open,
|
||||
flow,
|
||||
setIsBuilt,
|
||||
isBuilt,
|
||||
}: {
|
||||
open: boolean;
|
||||
flow: FlowType;
|
||||
setIsBuilt: any;
|
||||
isBuilt: boolean;
|
||||
}) {
|
||||
const [isBuilding, setIsBuilding] = useState(false);
|
||||
|
||||
const { updateSSEData } = useSSE();
|
||||
const { reactFlowInstance } = useContext(typesContext);
|
||||
const { setErrorData } = useContext(alertContext);
|
||||
|
||||
async function handleBuild(flow: FlowType) {
|
||||
const errors = validateNodes(reactFlowInstance);
|
||||
if (errors.length > 0) {
|
||||
setErrorData({
|
||||
title: "Oops! Looks like you missed something",
|
||||
list: errors,
|
||||
});
|
||||
return;
|
||||
}
|
||||
const minimumLoadingTime = 200; // in milliseconds
|
||||
const startTime = Date.now();
|
||||
setIsBuilding(true);
|
||||
|
||||
try {
|
||||
const allNodesValid = await streamNodeData(flow);
|
||||
await enforceMinimumLoadingTime(startTime, minimumLoadingTime);
|
||||
setIsBuilt(allNodesValid);
|
||||
} catch (error) {
|
||||
console.error("Error:", error);
|
||||
} finally {
|
||||
setIsBuilding(false);
|
||||
}
|
||||
}
|
||||
|
||||
async function streamNodeData(flow: FlowType) {
|
||||
// Step 1: Make a POST request to send the flow data and receive a unique session ID
|
||||
const response = await postBuildInit(flow);
|
||||
const { flowId } = response.data;
|
||||
|
||||
// Step 2: Use the session ID to establish an SSE connection using EventSource
|
||||
let validationResults = [];
|
||||
let finished = false;
|
||||
const apiUrl = `/api/v1/build/stream/${flowId}`;
|
||||
const eventSource = new EventSource(apiUrl);
|
||||
try{
|
||||
eventSource.onmessage = (event) => {
|
||||
// If the event is parseable, return
|
||||
if (!event.data) {
|
||||
return;
|
||||
}
|
||||
const parsedData = JSON.parse(event.data);
|
||||
// if the event is the end of the stream, close the connection
|
||||
if (parsedData.end_of_stream) {
|
||||
eventSource.close();
|
||||
|
||||
return;
|
||||
}
|
||||
// Otherwise, process the data
|
||||
const isValid = processStreamResult(parsedData);
|
||||
validationResults.push(isValid);
|
||||
};
|
||||
|
||||
eventSource.onerror = (error) => {
|
||||
console.error("EventSource failed:", error);
|
||||
eventSource.close();
|
||||
};
|
||||
// Step 3: Wait for the stream to finish
|
||||
while (!finished) {
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
finished = validationResults.length === flow.data.nodes.length;
|
||||
}
|
||||
// Step 4: Return true if all nodes are valid, false otherwise
|
||||
return validationResults.every((result) => result);
|
||||
}
|
||||
catch(e){
|
||||
console.log(e)
|
||||
eventSource.close();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
function processStreamResult(parsedData) {
|
||||
// Process each chunk of data here
|
||||
// Parse the chunk and update the context
|
||||
try {
|
||||
updateSSEData({ [parsedData.id]: parsedData });
|
||||
} catch (err) {
|
||||
console.log("Error parsing stream data: ", err);
|
||||
}
|
||||
return parsedData.valid;
|
||||
}
|
||||
|
||||
async function enforceMinimumLoadingTime(
|
||||
startTime: number,
|
||||
minimumLoadingTime: number
|
||||
) {
|
||||
const elapsedTime = Date.now() - startTime;
|
||||
const remainingTime = minimumLoadingTime - elapsedTime;
|
||||
|
||||
if (remainingTime > 0) {
|
||||
return new Promise((resolve) => setTimeout(resolve, remainingTime));
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Transition
|
||||
show={!open}
|
||||
appear={true}
|
||||
enter="transition ease-out duration-300"
|
||||
enterFrom="translate-y-96"
|
||||
enterTo="translate-y-0"
|
||||
leave="transition ease-in duration-300"
|
||||
leaveFrom="translate-y-0"
|
||||
leaveTo="translate-y-96"
|
||||
>
|
||||
<div className={`fixed right-4` + (isBuilt ? " bottom-20" : " bottom-4")}>
|
||||
<div
|
||||
className="border flex justify-center align-center py-1 px-3 w-12 h-12 rounded-full bg-gradient-to-r from-blue-700 via-blue-600 to-blue-500 dark:border-gray-600 cursor-pointer"
|
||||
onClick={() => {
|
||||
if (!isBuilding) handleBuild(flow);
|
||||
}}
|
||||
>
|
||||
<button>
|
||||
<div className="flex gap-3 items-center">
|
||||
{isBuilding ? (
|
||||
// Render your loading animation here when isBuilding is true
|
||||
<Loading style={{ color: "white" }} />
|
||||
) : (
|
||||
<Zap className="h-6 w-6" style={{ color: "white" }} />
|
||||
)}
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</Transition>
|
||||
);
|
||||
}
|
||||
|
|
@ -3,13 +3,26 @@ import {
|
|||
Bars3CenterLeftIcon,
|
||||
ChatBubbleBottomCenterTextIcon,
|
||||
} from "@heroicons/react/24/outline";
|
||||
import { MessagesSquare } from "lucide-react";
|
||||
import { nodeColors } from "../../../utils";
|
||||
import { PopUpContext } from "../../../contexts/popUpContext";
|
||||
import { alertContext } from "../../../contexts/alertContext";
|
||||
import { useContext } from "react";
|
||||
import ChatModal from "../../../modals/chatModal";
|
||||
|
||||
export default function ChatTrigger({ open, setOpen }) {
|
||||
const { openPopUp } = useContext(PopUpContext);
|
||||
export default function ChatTrigger({ open, setOpen, isBuilt }) {
|
||||
const { setErrorData } = useContext(alertContext);
|
||||
|
||||
function handleClick() {
|
||||
if (isBuilt) {
|
||||
setOpen(true);
|
||||
} else {
|
||||
setErrorData({
|
||||
title: "Flow not built",
|
||||
list: ["Please build the flow before chatting"],
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Transition
|
||||
show={!open}
|
||||
|
|
@ -24,13 +37,11 @@ export default function ChatTrigger({ open, setOpen }) {
|
|||
<div className="absolute bottom-4 right-3">
|
||||
<div
|
||||
className="border flex justify-center align-center py-1 px-3 w-12 h-12 rounded-full bg-gradient-to-r from-blue-500 via-blue-600 to-blue-700 dark:border-gray-600 cursor-pointer"
|
||||
onClick={() => {
|
||||
setOpen(true);
|
||||
}}
|
||||
onClick={handleClick}
|
||||
>
|
||||
<button>
|
||||
<div className="flex gap-3 items-center">
|
||||
<ChatBubbleBottomCenterTextIcon
|
||||
<MessagesSquare
|
||||
className="h-6 w-6 mt-1"
|
||||
style={{ color: "white" }}
|
||||
/>
|
||||
|
|
|
|||
|
|
@ -1,13 +1,18 @@
|
|||
import { useEffect, useRef, useState } from "react";
|
||||
|
||||
import { Context, useEffect, useRef, useState, useContext } from "react";
|
||||
import ReactFlow, { useNodes } from "reactflow";
|
||||
import { ChatMessageType, ChatType } from "../../types/chat";
|
||||
import ChatTrigger from "./chatTrigger";
|
||||
import BuildTrigger from "./buildTrigger";
|
||||
import ChatModal from "../../modals/chatModal";
|
||||
|
||||
import _ from "lodash";
|
||||
import _, { set } from "lodash";
|
||||
import { getBuildStatus } from "../../controllers/API";
|
||||
import { NodeType } from "../../types/flow";
|
||||
|
||||
export default function Chat({ flow }: ChatType) {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [isBuilt, setIsBuilt] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
const handleKeyDown = (event: KeyboardEvent) => {
|
||||
if (
|
||||
|
|
@ -23,10 +28,58 @@ export default function Chat({ flow }: ChatType) {
|
|||
document.removeEventListener("keydown", handleKeyDown);
|
||||
};
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
// Define an async function within the useEffect hook
|
||||
const fetchBuildStatus = async () => {
|
||||
const response = await getBuildStatus(flow.id);
|
||||
setIsBuilt(response.built);
|
||||
};
|
||||
|
||||
// Call the async function
|
||||
fetchBuildStatus();
|
||||
}, [flow]);
|
||||
|
||||
const prevNodesRef = useRef<any[] | undefined>();
|
||||
const nodes = useNodes();
|
||||
useEffect(() => {
|
||||
const prevNodes = prevNodesRef.current;
|
||||
const currentNodes = nodes.map(
|
||||
(node: NodeType) => node.data.node.template.value
|
||||
);
|
||||
|
||||
if (
|
||||
prevNodes &&
|
||||
JSON.stringify(prevNodes) !== JSON.stringify(currentNodes)
|
||||
) {
|
||||
setIsBuilt(false);
|
||||
console.log("Nodes changed");
|
||||
}
|
||||
|
||||
prevNodesRef.current = currentNodes;
|
||||
}, [nodes]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<ChatModal key={flow.id} flow={flow} open={open} setOpen={setOpen} />
|
||||
<ChatTrigger open={open} setOpen={setOpen} />
|
||||
{isBuilt ? (
|
||||
<div>
|
||||
<BuildTrigger
|
||||
open={open}
|
||||
flow={flow}
|
||||
setIsBuilt={setIsBuilt}
|
||||
isBuilt={isBuilt}
|
||||
/>
|
||||
<ChatModal key={flow.id} flow={flow} open={open} setOpen={setOpen} />
|
||||
<ChatTrigger open={open} setOpen={setOpen} isBuilt={isBuilt} />
|
||||
</div>
|
||||
) : (
|
||||
<BuildTrigger
|
||||
open={open}
|
||||
flow={flow}
|
||||
setIsBuilt={setIsBuilt}
|
||||
isBuilt={isBuilt}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
39
src/frontend/src/components/ui/loading.tsx
Normal file
39
src/frontend/src/components/ui/loading.tsx
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
import { SVGProps } from "react";
|
||||
|
||||
// https://github.com/feathericons/feather/issues/695#issuecomment-1503699643
|
||||
export const Loading = (props: SVGProps<SVGSVGElement>) => (
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width={24}
|
||||
height={24}
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
strokeWidth={2}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
className="feather feather-circle"
|
||||
{...props}
|
||||
>
|
||||
<circle cx={12} cy={12} r={10} strokeDasharray={63} strokeDashoffset={21}>
|
||||
<animateTransform
|
||||
attributeName="transform"
|
||||
type="rotate"
|
||||
from="0 12 12"
|
||||
to="360 12 12"
|
||||
dur="2s"
|
||||
repeatCount="indefinite"
|
||||
/>
|
||||
<animate
|
||||
attributeName="stroke-dashoffset"
|
||||
dur="8s"
|
||||
repeatCount="indefinite"
|
||||
keyTimes="0; 0.5; 1"
|
||||
values="-16; -47; -16"
|
||||
calcMode="spline"
|
||||
keySplines="0.4 0 0.2 1; 0.4 0 0.2 1"
|
||||
/>
|
||||
</circle>
|
||||
</svg>
|
||||
);
|
||||
export default Loading;
|
||||
35
src/frontend/src/contexts/SSEContext.tsx
Normal file
35
src/frontend/src/contexts/SSEContext.tsx
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
import {
|
||||
createContext,
|
||||
useContext,
|
||||
useState,
|
||||
useEffect,
|
||||
useCallback,
|
||||
} from "react";
|
||||
|
||||
const initialValue = {
|
||||
updateSSEData: ({}) => {},
|
||||
sseData: {},
|
||||
};
|
||||
|
||||
const SSEContext = createContext(initialValue);
|
||||
|
||||
export function useSSE() {
|
||||
return useContext(SSEContext);
|
||||
}
|
||||
|
||||
export function SSEProvider({ children }) {
|
||||
const [sseData, setSSEData] = useState({});
|
||||
|
||||
const updateSSEData = useCallback((newData: any) => {
|
||||
setSSEData((prevData) => ({
|
||||
...prevData,
|
||||
...newData,
|
||||
}));
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<SSEContext.Provider value={{ sseData, updateSSEData }}>
|
||||
{children}
|
||||
</SSEContext.Provider>
|
||||
);
|
||||
}
|
||||
|
|
@ -1,4 +1,9 @@
|
|||
import { PromptTypeAPI, errorsTypeAPI } from "./../../types/api/index";
|
||||
import {
|
||||
BuildStatusTypeAPI,
|
||||
PromptTypeAPI,
|
||||
errorsTypeAPI,
|
||||
InitTypeAPI,
|
||||
} from "./../../types/api/index";
|
||||
import { APIObjectType, sendAllProps } from "../../types/api/index";
|
||||
import axios, { AxiosResponse } from "axios";
|
||||
import { FlowStyleType, FlowType } from "../../types/flow";
|
||||
|
|
@ -272,4 +277,14 @@ export async function getVersion() {
|
|||
*/
|
||||
export async function getHealth() {
|
||||
return await axios.get("/health"); // Health is the only endpoint that doesn't require /api/v1
|
||||
export async function getBuildStatus(
|
||||
flowId: string
|
||||
): Promise<BuildStatusTypeAPI> {
|
||||
return await axios.get(`/api/v1/build/${flowId}/status`);
|
||||
}
|
||||
|
||||
export async function postBuildInit(
|
||||
flow: FlowType
|
||||
): Promise<AxiosResponse<InitTypeAPI>> {
|
||||
return await axios.post(`/api/v1/build/init`, flow);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ import IntComponent from "../../components/intComponent";
|
|||
import InputFileComponent from "../../components/inputFileComponent";
|
||||
import PromptAreaComponent from "../../components/promptComponent";
|
||||
import CodeAreaComponent from "../../components/codeAreaComponent";
|
||||
import { TabsContext } from "../../contexts/tabsContext";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
|
|
@ -33,7 +32,6 @@ import {
|
|||
DialogTrigger,
|
||||
} from "../../components/ui/dialog";
|
||||
import { Button } from "../../components/ui/button";
|
||||
import { Edit } from "lucide-react";
|
||||
import { Badge } from "../../components/ui/badge";
|
||||
|
||||
export default function EditNodeModal({ data }: { data: NodeDataType }) {
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import { ChatBubbleOvalLeftEllipsisIcon } from "@heroicons/react/24/outline";
|
|||
import { Fragment, useContext, useEffect, useRef, useState } from "react";
|
||||
import { FlowType, NodeType } from "../../types/flow";
|
||||
import { alertContext } from "../../contexts/alertContext";
|
||||
import { toNormalCase } from "../../utils";
|
||||
import { toNormalCase, validateNodes } from "../../utils";
|
||||
import { typesContext } from "../../contexts/typesContext";
|
||||
import ChatMessage from "./chatMessage";
|
||||
import { FaEraser } from "react-icons/fa";
|
||||
|
|
@ -185,6 +185,17 @@ export default function ChatModal({
|
|||
}://${host}${chatEndpoint}`;
|
||||
}
|
||||
|
||||
function getWebSocketUrl(chatId, isDevelopment = false) {
|
||||
const isSecureProtocol = window.location.protocol === "https:";
|
||||
const webSocketProtocol = isSecureProtocol ? "wss" : "ws";
|
||||
const host = isDevelopment ? "localhost:7860" : window.location.host;
|
||||
const chatEndpoint = `/api/v1/chat/${chatId}`;
|
||||
|
||||
return `${
|
||||
isDevelopment ? "ws" : webSocketProtocol
|
||||
}://${host}${chatEndpoint}`;
|
||||
}
|
||||
|
||||
function connectWS() {
|
||||
try {
|
||||
const urlWs = getWebSocketUrl(
|
||||
|
|
@ -269,53 +280,6 @@ export default function ChatModal({
|
|||
if (ref.current) ref.current.scrollIntoView({ behavior: "smooth" });
|
||||
}, [chatHistory]);
|
||||
|
||||
function validateNode(n: NodeType): Array<string> {
|
||||
if (!n.data?.node?.template || !Object.keys(n.data.node.template)) {
|
||||
setNoticeData({
|
||||
title:
|
||||
"We've noticed a potential issue with a node in the flow. Please review it and, if necessary, submit a bug report with your exported flow file. Thank you for your help!",
|
||||
});
|
||||
return [];
|
||||
}
|
||||
|
||||
const {
|
||||
type,
|
||||
node: { template },
|
||||
} = n.data;
|
||||
|
||||
return Object.keys(template).reduce(
|
||||
(errors: Array<string>, t) =>
|
||||
errors.concat(
|
||||
template[t].required &&
|
||||
template[t].show &&
|
||||
(template[t].value === undefined ||
|
||||
template[t].value === null ||
|
||||
template[t].value === "") &&
|
||||
!reactFlowInstance
|
||||
.getEdges()
|
||||
.some(
|
||||
(e) =>
|
||||
e.targetHandle.split("|")[1] === t &&
|
||||
e.targetHandle.split("|")[2] === n.id
|
||||
)
|
||||
? [
|
||||
`${type} is missing ${
|
||||
template.display_name
|
||||
? template.display_name
|
||||
: toNormalCase(template[t].name)
|
||||
}.`,
|
||||
]
|
||||
: []
|
||||
),
|
||||
[] as string[]
|
||||
);
|
||||
}
|
||||
|
||||
function validateNodes() {
|
||||
return reactFlowInstance
|
||||
.getNodes()
|
||||
.flatMap((n: NodeType) => validateNode(n));
|
||||
}
|
||||
|
||||
const ref = useRef(null);
|
||||
|
||||
|
|
@ -327,7 +291,7 @@ export default function ChatModal({
|
|||
|
||||
function sendMessage() {
|
||||
if (chatValue !== "") {
|
||||
let nodeValidationErrors = validateNodes();
|
||||
let nodeValidationErrors = validateNodes(reactFlowInstance);
|
||||
if (nodeValidationErrors.length === 0) {
|
||||
setLockChat(true);
|
||||
let message = chatValue;
|
||||
|
|
|
|||
|
|
@ -38,3 +38,11 @@ export type errorsTypeAPI = {
|
|||
imports: { errors: Array<string> };
|
||||
};
|
||||
export type PromptTypeAPI = { input_variables: Array<string> };
|
||||
|
||||
export type BuildStatusTypeAPI = {
|
||||
built: boolean;
|
||||
};
|
||||
|
||||
export type InitTypeAPI = {
|
||||
flowId: string;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ import {
|
|||
Bars3CenterLeftIcon,
|
||||
} from "@heroicons/react/24/outline";
|
||||
import { Connection, Edge, Node, ReactFlowInstance } from "reactflow";
|
||||
import { FlowType, NodeDataType, NodeType } from "./types/flow";
|
||||
import { FlowType, NodeType } from "./types/flow";
|
||||
import { APITemplateType } from "./types/api";
|
||||
import _ from "lodash";
|
||||
import { ChromaIcon } from "./icons/ChromaIcon";
|
||||
|
|
@ -737,3 +737,56 @@ export function buildTweaks(flow) {
|
|||
return acc;
|
||||
}, {});
|
||||
}
|
||||
export function validateNode(
|
||||
n: NodeType,
|
||||
reactFlowInstance: ReactFlowInstance
|
||||
): Array<string> {
|
||||
if (!n.data?.node?.template || !Object.keys(n.data.node.template)) {
|
||||
return [
|
||||
"We've noticed a potential issue with a node in the flow. Please review it and, if necessary, submit a bug report with your exported flow file. Thank you for your help!",
|
||||
];
|
||||
}
|
||||
|
||||
const {
|
||||
type,
|
||||
node: { template },
|
||||
} = n.data;
|
||||
|
||||
return Object.keys(template).reduce(
|
||||
(errors: Array<string>, t) =>
|
||||
errors.concat(
|
||||
template[t].required &&
|
||||
template[t].show &&
|
||||
(template[t].value === undefined ||
|
||||
template[t].value === null ||
|
||||
template[t].value === "") &&
|
||||
!reactFlowInstance
|
||||
.getEdges()
|
||||
.some(
|
||||
(e) =>
|
||||
e.targetHandle.split("|")[1] === t &&
|
||||
e.targetHandle.split("|")[2] === n.id
|
||||
)
|
||||
? [
|
||||
`${type} is missing ${
|
||||
template.display_name
|
||||
? template.display_name
|
||||
: toNormalCase(template[t].name)
|
||||
}.`,
|
||||
]
|
||||
: []
|
||||
),
|
||||
[] as string[]
|
||||
);
|
||||
}
|
||||
|
||||
export function validateNodes(reactFlowInstance: ReactFlowInstance) {
|
||||
if (reactFlowInstance.getNodes().length === 0) {
|
||||
return [
|
||||
"No nodes found in the flow. Please add at least one node to the flow.",
|
||||
];
|
||||
}
|
||||
return reactFlowInstance
|
||||
.getNodes()
|
||||
.flatMap((n: NodeType) => validateNode(n, reactFlowInstance));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,11 +12,9 @@ const proxyTargets = apiRoutes.reduce((proxyObj, route) => {
|
|||
changeOrigin: true,
|
||||
secure: false,
|
||||
ws: true,
|
||||
// rewrite: (path) => `/api/v1${path}`,
|
||||
};
|
||||
return proxyObj;
|
||||
}, {});
|
||||
|
||||
export default defineConfig(() => {
|
||||
return {
|
||||
build: {
|
||||
|
|
|
|||
|
|
@ -66,6 +66,12 @@ def get_graph(_type="basic"):
|
|||
return Graph(nodes, edges)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def basic_graph_data():
|
||||
with open(pytest.BASIC_EXAMPLE_PATH, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def basic_graph():
|
||||
return get_graph()
|
||||
|
|
|
|||
|
|
@ -1,47 +1,47 @@
|
|||
import json
|
||||
from unittest.mock import patch
|
||||
from fastapi import WebSocketDisconnect
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
# from langflow.chat.manager import ChatManager
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def test_websocket_connection(client: TestClient):
|
||||
with client.websocket_connect("api/v1/chat/test_client") as websocket:
|
||||
assert websocket.scope["client"] == ["testclient", 50000]
|
||||
assert websocket.scope["path"] == "/api/v1/chat/test_client"
|
||||
def test_init_build(client):
|
||||
response = client.post(
|
||||
"api/v1/build/init", json={"id": "test", "data": {"key": "value"}}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"flowId": "test"}
|
||||
|
||||
|
||||
def test_chat_history(client: TestClient):
|
||||
# Mock the process_graph function to return a specific value
|
||||
with patch("langflow.chat.manager.process_graph") as mock_process_graph:
|
||||
mock_process_graph.return_value = ("Hello, I'm a mock response!", "")
|
||||
def test_stream_build(client):
|
||||
client.post("/build/init", json={"id": "stream_test", "data": {"key": "value"}})
|
||||
|
||||
with client.websocket_connect("api/v1/chat/test_client") as websocket:
|
||||
# First message should be the history
|
||||
history = websocket.receive_json()
|
||||
assert history == [] # Empty history
|
||||
# Send a message
|
||||
payload = {"message": "Hello"}
|
||||
websocket.send_json(json.dumps(payload))
|
||||
# Test the stream
|
||||
response = client.get("api/v1/build/stream/stream_test")
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
|
||||
|
||||
# Receive the response from the server
|
||||
response = websocket.receive_json()
|
||||
assert response == {
|
||||
"is_bot": True,
|
||||
"message": None,
|
||||
"type": "start",
|
||||
"intermediate_steps": "",
|
||||
"files": [],
|
||||
}
|
||||
# Send another message
|
||||
payload = {"message": "How are you?"}
|
||||
websocket.send_json(json.dumps(payload))
|
||||
|
||||
# Receive the response from the server
|
||||
response = websocket.receive_json()
|
||||
assert response == {
|
||||
"is_bot": True,
|
||||
"message": "Hello, I'm a mock response!",
|
||||
"type": "end",
|
||||
"intermediate_steps": "",
|
||||
"files": [],
|
||||
}
|
||||
def test_websocket_endpoint(client):
|
||||
with pytest.raises(WebSocketDisconnect):
|
||||
with client.websocket_connect(
|
||||
"api/v1/chat/non_existing_client_id"
|
||||
) as websocket:
|
||||
websocket.send_json({"type": "test"})
|
||||
data = websocket.receive_json()
|
||||
assert "Please, build the flow before sending messages" in data["message"]
|
||||
|
||||
|
||||
def test_websocket_endpoint_after_build(client, basic_graph_data):
|
||||
# Assuming your websocket_endpoint uses chat_manager which caches data from stream_build
|
||||
client.post("/build/init", json=basic_graph_data)
|
||||
client.get("/build/stream/websocket_test")
|
||||
|
||||
# There should be more to test here, but it depends on the inner workings of your websocket handler
|
||||
# and how your chat_manager and other classes behave. The following is just an example structure.
|
||||
with pytest.raises(WebSocketDisconnect):
|
||||
with client.websocket_connect("api/v1/chat/websocket_test") as websocket:
|
||||
websocket.send_json({"type": "test"})
|
||||
# Perform assertions here, based on what you expect the websocket to return
|
||||
# data = websocket.receive_json()
|
||||
# assert ...
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue