Merge remote-tracking branch 'origin/validation_fix' into db

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-06-15 07:37:26 -03:00
commit 3920eb50d6
26 changed files with 1041 additions and 330 deletions

View file

@ -1,26 +1,119 @@
import json
from fastapi import (
APIRouter,
HTTPException,
WebSocket,
WebSocketDisconnect,
WebSocketException,
status,
)
from fastapi.responses import StreamingResponse
from langflow.api.v1.schemas import BuiltResponse, InitResponse
from langflow.chat.manager import ChatManager
from langflow.graph.graph.base import Graph
from langflow.utils.logger import logger
router = APIRouter(tags=["Chat"])
chat_manager = ChatManager()
flow_data_store = {}
@router.websocket("/chat/{client_id}")
async def websocket_endpoint(client_id: str, websocket: WebSocket):
async def chat(client_id: str, websocket: WebSocket):
"""Websocket endpoint for chat."""
try:
await chat_manager.handle_websocket(client_id, websocket)
if client_id in chat_manager.in_memory_cache:
await chat_manager.handle_websocket(client_id, websocket)
else:
message = "Please, build the flow before sending messages"
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason=message)
except WebSocketException as exc:
logger.error(exc)
await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=str(exc))
except WebSocketDisconnect as exc:
@router.post("/build/init", response_model=InitResponse, status_code=201)
async def init_build(graph_data: dict):
"""Initialize the build by storing graph data and returning a unique session ID."""
try:
flow_id = graph_data.get("id")
flow_data_store[flow_id] = graph_data
return InitResponse(flowId=flow_id)
except Exception as exc:
logger.error(exc)
await websocket.close(code=status.WS_1000_NORMAL_CLOSURE, reason=str(exc))
return HTTPException(status_code=500, detail=str(exc))
@router.get("/build/{flow_id}/status", response_model=BuiltResponse)
async def build_status(flow_id: str):
"""Check the flow_id is in the flow_data_store."""
try:
built = flow_id in flow_data_store and not isinstance(
flow_data_store[flow_id], dict
)
return BuiltResponse(
built=built,
)
except Exception as exc:
logger.error(exc)
return HTTPException(status_code=500, detail=str(exc))
@router.get("/build/stream/{flow_id}", response_class=StreamingResponse)
async def stream_build(flow_id: str):
"""Stream the build process based on stored flow data."""
async def event_stream(flow_id):
final_response = json.dumps({"end_of_stream": True})
try:
if flow_id not in flow_data_store:
error_message = "Invalid session ID"
yield f"data: {json.dumps({'error': error_message})}\n\n"
return
graph_data = flow_data_store[flow_id].get("data")
if not graph_data:
error_message = "No data provided"
yield f"data: {json.dumps({'error': error_message})}\n\n"
return
logger.debug("Building langchain object")
graph = Graph.from_payload(graph_data)
for node in graph.generator_build():
try:
node.build()
params = node._built_object_repr()
valid = True
logger.debug(
f"Building node {params[:50]}{'...' if len(params) > 50 else ''}"
)
except Exception as exc:
params = str(exc)
valid = False
response = json.dumps(
{
"valid": valid,
"params": params,
"id": node.id,
}
)
yield f"data: {response}\n\n"
chat_manager.set_cache(flow_id, graph.build())
except Exception:
logger.error("Error while building the flow")
finally:
yield f"data: {final_response}\n\n"
try:
return StreamingResponse(event_stream(flow_id), media_type="text/event-stream")
except Exception as exc:
logger.error(exc)
raise HTTPException(status_code=500, detail=str(exc))

View file

@ -93,3 +93,11 @@ class FlowListCreate(BaseModel):
class FlowListRead(BaseModel):
flows: List[FlowRead]
class InitResponse(BaseModel):
flowId: str
class BuiltResponse(BaseModel):
built: bool

View file

@ -1 +1,7 @@
from langflow.cache.manager import cache_manager # noqa
from langflow.cache.manager import cache_manager
from langflow.cache.flow import InMemoryCache
__all__ = [
"cache_manager",
"InMemoryCache",
]

View file

@ -1,154 +1,95 @@
import base64
import contextlib
import functools
import hashlib
import json
import os
import tempfile
from collections import OrderedDict
from pathlib import Path
from typing import Any, Dict
import dill # type: ignore
CACHE: Dict[str, Any] = {}
import abc
def create_cache_folder(func):
def wrapper(*args, **kwargs):
# Get the destination folder
cache_path = Path(tempfile.gettempdir()) / PREFIX
# Create the destination folder if it doesn't exist
os.makedirs(cache_path, exist_ok=True)
return func(*args, **kwargs)
return wrapper
def memoize_dict(maxsize=128):
cache = OrderedDict()
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
hashed = compute_dict_hash(args[0])
key = (func.__name__, hashed, frozenset(kwargs.items()))
if key not in cache:
result = func(*args, **kwargs)
cache[key] = result
if len(cache) > maxsize:
cache.popitem(last=False)
else:
result = cache[key]
return result
def clear_cache():
cache.clear()
wrapper.clear_cache = clear_cache # type: ignore
wrapper.cache = cache # type: ignore
return wrapper
return decorator
PREFIX = "langflow_cache"
@create_cache_folder
def clear_old_cache_files(max_cache_size: int = 3):
cache_dir = Path(tempfile.gettempdir()) / PREFIX
cache_files = list(cache_dir.glob("*.dill"))
if len(cache_files) > max_cache_size:
cache_files_sorted_by_mtime = sorted(
cache_files, key=lambda x: x.stat().st_mtime, reverse=True
)
for cache_file in cache_files_sorted_by_mtime[max_cache_size:]:
with contextlib.suppress(OSError):
os.remove(cache_file)
def compute_dict_hash(graph_data):
graph_data = filter_json(graph_data)
cleaned_graph_json = json.dumps(graph_data, sort_keys=True)
return hashlib.sha256(cleaned_graph_json.encode("utf-8")).hexdigest()
def filter_json(json_data):
filtered_data = json_data.copy()
# Remove 'viewport' and 'chatHistory' keys
if "viewport" in filtered_data:
del filtered_data["viewport"]
if "chatHistory" in filtered_data:
del filtered_data["chatHistory"]
# Filter nodes
if "nodes" in filtered_data:
for node in filtered_data["nodes"]:
if "position" in node:
del node["position"]
if "positionAbsolute" in node:
del node["positionAbsolute"]
if "selected" in node:
del node["selected"]
if "dragging" in node:
del node["dragging"]
return filtered_data
@create_cache_folder
def save_binary_file(content: str, file_name: str, accepted_types: list[str]) -> str:
class BaseCache(abc.ABC):
"""
Save a binary file to the specified folder.
Args:
content: The content of the file as a bytes object.
file_name: The name of the file, including its extension.
Returns:
The path to the saved file.
Abstract base class for a cache.
"""
if not any(file_name.endswith(suffix) for suffix in accepted_types):
raise ValueError(f"File {file_name} is not accepted")
# Get the destination folder
cache_path = Path(tempfile.gettempdir()) / PREFIX
if not content:
raise ValueError("Please, reload the file in the loader.")
data = content.split(",")[1]
decoded_bytes = base64.b64decode(data)
@abc.abstractmethod
def get(self, key):
"""
Retrieve an item from the cache.
# Create the full file path
file_path = os.path.join(cache_path, file_name)
Args:
key: The key of the item to retrieve.
# Save the binary content to the file
with open(file_path, "wb") as file:
file.write(decoded_bytes)
Returns:
The value associated with the key, or None if the key is not found.
"""
pass
return file_path
@abc.abstractmethod
def set(self, key, value):
"""
Add an item to the cache.
Args:
key: The key of the item.
value: The value to cache.
"""
pass
@create_cache_folder
def save_cache(hash_val: str, chat_data, clean_old_cache_files: bool):
cache_path = Path(tempfile.gettempdir()) / PREFIX / f"{hash_val}.dill"
with cache_path.open("wb") as cache_file:
dill.dump(chat_data, cache_file)
@abc.abstractmethod
def delete(self, key):
"""
Remove an item from the cache.
if clean_old_cache_files:
clear_old_cache_files()
Args:
key: The key of the item to remove.
"""
pass
@abc.abstractmethod
def clear(self):
"""
Clear all items from the cache.
"""
pass
@create_cache_folder
def load_cache(hash_val):
cache_path = Path(tempfile.gettempdir()) / PREFIX / f"{hash_val}.dill"
if cache_path.exists():
with cache_path.open("rb") as cache_file:
return dill.load(cache_file)
return None
@abc.abstractmethod
def __contains__(self, key):
"""
Check if the key is in the cache.
Args:
key: The key of the item to check.
Returns:
True if the key is in the cache, False otherwise.
"""
pass
@abc.abstractmethod
def __getitem__(self, key):
"""
Retrieve an item from the cache using the square bracket notation.
Args:
key: The key of the item to retrieve.
Returns:
The value associated with the key, or None if the key is not found.
"""
pass
@abc.abstractmethod
def __setitem__(self, key, value):
"""
Add an item to the cache using the square bracket notation.
Args:
key: The key of the item.
value: The value to cache.
"""
pass
@abc.abstractmethod
def __delitem__(self, key):
"""
Remove an item from the cache using the square bracket notation.
Args:
key: The key of the item to remove.
"""
pass

146
src/backend/langflow/cache/flow.py vendored Normal file
View file

@ -0,0 +1,146 @@
import threading
import time
from collections import OrderedDict
from langflow.cache.base import BaseCache
class InMemoryCache(BaseCache):
"""
A simple in-memory cache using an OrderedDict.
This cache supports setting a maximum size and expiration time for cached items.
When the cache is full, it uses a Least Recently Used (LRU) eviction policy.
Thread-safe using a threading Lock.
Attributes:
max_size (int, optional): Maximum number of items to store in the cache.
expiration_time (int, optional): Time in seconds after which a cached item expires. Default is 1 hour.
Example:
cache = InMemoryCache(max_size=3, expiration_time=5)
# setting cache values
cache.set("a", 1)
cache.set("b", 2)
cache["c"] = 3
# getting cache values
a = cache.get("a")
b = cache["b"]
"""
def __init__(self, max_size=None, expiration_time=60 * 60):
"""
Initialize a new InMemoryCache instance.
Args:
max_size (int, optional): Maximum number of items to store in the cache.
expiration_time (int, optional): Time in seconds after which a cached item expires. Default is 1 hour.
"""
self._cache = OrderedDict()
self._lock = threading.Lock()
self.max_size = max_size
self.expiration_time = expiration_time
def get(self, key):
"""
Retrieve an item from the cache.
Args:
key: The key of the item to retrieve.
Returns:
The value associated with the key, or None if the key is not found or the item has expired.
"""
with self._lock:
if key in self._cache:
item = self._cache.pop(key)
if (
self.expiration_time is None
or time.time() - item["time"] < self.expiration_time
):
# Move the key to the end to make it recently used
self._cache[key] = item
return item["value"]
else:
self.delete(key)
return None
def set(self, key, value):
"""
Add an item to the cache.
If the cache is full, the least recently used item is evicted.
Args:
key: The key of the item.
value: The value to cache.
"""
with self._lock:
if key in self._cache:
# Remove existing key before re-inserting to update order
self.delete(key)
elif self.max_size and len(self._cache) >= self.max_size:
# Remove least recently used item
self._cache.popitem(last=False)
self._cache[key] = {"value": value, "time": time.time()}
def get_or_set(self, key, value):
"""
Retrieve an item from the cache. If the item does not exist, set it with the provided value.
Args:
key: The key of the item.
value: The value to cache if the item doesn't exist.
Returns:
The cached value associated with the key.
"""
with self._lock:
if key in self._cache:
return self.get(key)
self.set(key, value)
return value
def delete(self, key):
"""
Remove an item from the cache.
Args:
key: The key of the item to remove.
"""
# with self._lock:
self._cache.pop(key, None)
def clear(self):
"""
Clear all items from the cache.
"""
with self._lock:
self._cache.clear()
def __contains__(self, key):
"""Check if the key is in the cache."""
return key in self._cache
def __getitem__(self, key):
"""Retrieve an item from the cache using the square bracket notation."""
return self.get(key)
def __setitem__(self, key, value):
"""Add an item to the cache using the square bracket notation."""
self.set(key, value)
def __delitem__(self, key):
"""Remove an item from the cache using the square bracket notation."""
self.delete(key)
def __len__(self):
"""Return the number of items in the cache."""
return len(self._cache)
def __repr__(self):
"""Return a string representation of the InMemoryCache instance."""
return f"InMemoryCache(max_size={self.max_size}, expiration_time={self.expiration_time})"

View file

@ -54,7 +54,7 @@ class CacheManager(Subject):
def __init__(self):
super().__init__()
self.CACHE = {}
self._cache = {}
self.current_client_id = None
self.current_cache = {}
@ -68,12 +68,12 @@ class CacheManager(Subject):
"""
previous_client_id = self.current_client_id
self.current_client_id = client_id
self.current_cache = self.CACHE.setdefault(client_id, {})
self.current_cache = self._cache.setdefault(client_id, {})
try:
yield
finally:
self.current_client_id = previous_client_id
self.current_cache = self.CACHE.get(self.current_client_id, {})
self.current_cache = self._cache.get(self.current_client_id, {})
def add(self, name: str, obj: Any, obj_type: str, extension: Optional[str] = None):
"""

154
src/backend/langflow/cache/utils.py vendored Normal file
View file

@ -0,0 +1,154 @@
import base64
import contextlib
import functools
import hashlib
import json
import os
import tempfile
from collections import OrderedDict
from pathlib import Path
from typing import Any, Dict
import dill # type: ignore
CACHE: Dict[str, Any] = {}
def create_cache_folder(func):
def wrapper(*args, **kwargs):
# Get the destination folder
cache_path = Path(tempfile.gettempdir()) / PREFIX
# Create the destination folder if it doesn't exist
os.makedirs(cache_path, exist_ok=True)
return func(*args, **kwargs)
return wrapper
def memoize_dict(maxsize=128):
cache = OrderedDict()
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
hashed = compute_dict_hash(args[0])
key = (func.__name__, hashed, frozenset(kwargs.items()))
if key not in cache:
result = func(*args, **kwargs)
cache[key] = result
if len(cache) > maxsize:
cache.popitem(last=False)
else:
result = cache[key]
return result
def clear_cache():
cache.clear()
wrapper.clear_cache = clear_cache # type: ignore
wrapper.cache = cache # type: ignore
return wrapper
return decorator
PREFIX = "langflow_cache"
@create_cache_folder
def clear_old_cache_files(max_cache_size: int = 3):
cache_dir = Path(tempfile.gettempdir()) / PREFIX
cache_files = list(cache_dir.glob("*.dill"))
if len(cache_files) > max_cache_size:
cache_files_sorted_by_mtime = sorted(
cache_files, key=lambda x: x.stat().st_mtime, reverse=True
)
for cache_file in cache_files_sorted_by_mtime[max_cache_size:]:
with contextlib.suppress(OSError):
os.remove(cache_file)
def compute_dict_hash(graph_data):
graph_data = filter_json(graph_data)
cleaned_graph_json = json.dumps(graph_data, sort_keys=True)
return hashlib.sha256(cleaned_graph_json.encode("utf-8")).hexdigest()
def filter_json(json_data):
filtered_data = json_data.copy()
# Remove 'viewport' and 'chatHistory' keys
if "viewport" in filtered_data:
del filtered_data["viewport"]
if "chatHistory" in filtered_data:
del filtered_data["chatHistory"]
# Filter nodes
if "nodes" in filtered_data:
for node in filtered_data["nodes"]:
if "position" in node:
del node["position"]
if "positionAbsolute" in node:
del node["positionAbsolute"]
if "selected" in node:
del node["selected"]
if "dragging" in node:
del node["dragging"]
return filtered_data
@create_cache_folder
def save_binary_file(content: str, file_name: str, accepted_types: list[str]) -> str:
"""
Save a binary file to the specified folder.
Args:
content: The content of the file as a bytes object.
file_name: The name of the file, including its extension.
Returns:
The path to the saved file.
"""
if not any(file_name.endswith(suffix) for suffix in accepted_types):
raise ValueError(f"File {file_name} is not accepted")
# Get the destination folder
cache_path = Path(tempfile.gettempdir()) / PREFIX
if not content:
raise ValueError("Please, reload the file in the loader.")
data = content.split(",")[1]
decoded_bytes = base64.b64decode(data)
# Create the full file path
file_path = os.path.join(cache_path, file_name)
# Save the binary content to the file
with open(file_path, "wb") as file:
file.write(decoded_bytes)
return file_path
@create_cache_folder
def save_cache(hash_val: str, chat_data, clean_old_cache_files: bool):
cache_path = Path(tempfile.gettempdir()) / PREFIX / f"{hash_val}.dill"
with cache_path.open("wb") as cache_file:
dill.dump(chat_data, cache_file)
if clean_old_cache_files:
clear_old_cache_files()
@create_cache_folder
def load_cache(hash_val):
cache_path = Path(tempfile.gettempdir()) / PREFIX / f"{hash_val}.dill"
if cache_path.exists():
with cache_path.open("rb") as cache_file:
return dill.load(cache_file)
return None

View file

@ -10,7 +10,9 @@ from langflow.utils.logger import logger
import asyncio
import json
from typing import Dict, List
from typing import Any, Dict, List
from langflow.cache.flow import InMemoryCache
class ChatHistory(Subject):
@ -46,6 +48,7 @@ class ChatManager:
self.chat_history = ChatHistory()
self.cache_manager = cache_manager
self.cache_manager.attach(self.update)
self.in_memory_cache = InMemoryCache()
def on_chat_history_update(self):
"""Send the last chat message to the client."""
@ -99,24 +102,30 @@ class ChatManager:
websocket = self.active_connections[client_id]
await websocket.send_json(message.dict())
async def process_message(self, client_id: str, payload: Dict):
async def close_connection(self, client_id: str, code: int, reason: str):
if websocket := self.active_connections[client_id]:
await websocket.close(code=code, reason=reason)
self.disconnect(client_id)
async def process_message(
self, client_id: str, payload: Dict, langchain_object: Any
):
# Process the graph data and chat message
chat_message = payload.pop("message", "")
chat_message = ChatMessage(message=chat_message)
self.chat_history.add_message(client_id, chat_message)
graph_data = payload
# graph_data = payload
start_resp = ChatResponse(message=None, type="start", intermediate_steps="")
await self.send_json(client_id, start_resp)
is_first_message = len(self.chat_history.get_history(client_id=client_id)) <= 1
# is_first_message = len(self.chat_history.get_history(client_id=client_id)) <= 1
# Generate result and thought
try:
logger.debug("Generating result and thought")
result, intermediate_steps = await process_graph(
graph_data=graph_data,
is_first_message=is_first_message,
langchain_object=langchain_object,
chat_message=chat_message,
websocket=self.active_connections[client_id],
)
@ -149,6 +158,14 @@ class ChatManager:
await self.send_json(client_id, response)
self.chat_history.add_message(client_id, response)
def set_cache(self, client_id: str, langchain_object: Any) -> bool:
"""
Set the cache for a client.
"""
self.in_memory_cache.set(client_id, langchain_object)
return client_id in self.in_memory_cache
async def handle_websocket(self, client_id: str, websocket: WebSocket):
await self.connect(client_id, websocket)
@ -169,22 +186,24 @@ class ChatManager:
continue
with self.cache_manager.set_client_id(client_id):
await self.process_message(client_id, payload)
langchain_object = self.in_memory_cache.get(client_id)
await self.process_message(client_id, payload, langchain_object)
except Exception as e:
# Handle any exceptions that might occur
logger.exception(e)
# send a message to the client
await self.active_connections[client_id].close(
code=status.WS_1011_INTERNAL_ERROR, reason=str(e)[:120]
logger.error(e)
await self.close_connection(
client_id=client_id,
code=status.WS_1011_INTERNAL_ERROR,
reason=str(e)[:120],
)
self.disconnect(client_id)
finally:
try:
connection = self.active_connections.get(client_id)
if connection:
await connection.close(code=1000, reason="Client disconnected")
self.disconnect(client_id)
await self.close_connection(
client_id=client_id,
code=status.WS_1000_NORMAL_CLOSURE,
reason="Client disconnected",
)
except Exception as e:
logger.exception(e)
logger.error(e)
self.disconnect(client_id)

View file

@ -1,23 +1,15 @@
from fastapi import WebSocket
from langflow.api.v1.schemas import ChatMessage
from langflow.processing.process import (
load_or_build_langchain_object,
)
from langflow.processing.base import get_result_and_steps
from langflow.interface.utils import try_setting_streaming_options
from langflow.utils.logger import logger
from typing import Dict
async def process_graph(
graph_data: Dict,
is_first_message: bool,
langchain_object,
chat_message: ChatMessage,
websocket: WebSocket,
):
langchain_object = load_or_build_langchain_object(graph_data, is_first_message)
langchain_object = try_setting_streaming_options(langchain_object, websocket)
logger.debug("Loaded langchain object")

View file

@ -1,4 +1,4 @@
from typing import Dict, List, Type, Union
from typing import Dict, Generator, List, Type, Union
from langflow.graph.edge.base import Edge
from langflow.graph.graph.constants import VERTEX_TYPE_MAP
@ -106,6 +106,47 @@ class Graph:
raise ValueError("No root node found")
return root_node.build()
def topological_sort(self) -> List[Vertex]:
"""
Performs a topological sort of the vertices in the graph.
Returns:
List[Vertex]: A list of vertices in topological order.
Raises:
ValueError: If the graph contains a cycle.
"""
# States: 0 = unvisited, 1 = visiting, 2 = visited
state = {node: 0 for node in self.nodes}
sorted_vertices = []
def dfs(node):
if state[node] == 1:
# We have a cycle
raise ValueError(
"Graph contains a cycle, cannot perform topological sort"
)
if state[node] == 0:
state[node] = 1
for edge in node.edges:
if edge.source == node:
dfs(edge.target)
state[node] = 2
sorted_vertices.append(node)
# Visit each node
for node in self.nodes:
if state[node] == 0:
dfs(node)
return list(reversed(sorted_vertices))
def generator_build(self) -> Generator:
"""Builds each vertex in the graph and yields it."""
sorted_vertices = self.topological_sort()
logger.info("Sorted vertices: %s", sorted_vertices)
yield from sorted_vertices
def get_node_neighbors(self, node: Vertex) -> Dict[Vertex, int]:
"""Returns the neighbors of a node."""
neighbors: Dict[Vertex, int] = {}

View file

@ -1,4 +1,4 @@
from langflow.cache import base as cache_utils
from langflow.cache import utils as cache_utils
from langflow.graph.vertex.constants import DIRECT_TYPES
from langflow.interface import loading
from langflow.interface.listing import ALL_TYPES_DICT

View file

@ -1,4 +1,4 @@
from langflow.cache.base import compute_dict_hash, load_cache, memoize_dict
from langflow.cache.utils import compute_dict_hash, load_cache, memoize_dict
from langflow.graph import Graph
from langflow.utils.logger import logger

View file

@ -1,26 +1,18 @@
import {
classNames,
nodeColors,
nodeIcons,
toNormalCase,
toTitleCase,
} from "../../utils";
import { classNames, nodeColors, nodeIcons, toTitleCase } from "../../utils";
import ParameterComponent from "./components/parameterComponent";
import { typesContext } from "../../contexts/typesContext";
import { useContext, useState, useEffect, useRef, Fragment } from "react";
import { useContext, useState, useEffect, useRef } from "react";
import { NodeDataType } from "../../types/flow";
import { alertContext } from "../../contexts/alertContext";
import { PopUpContext } from "../../contexts/popUpContext";
import NodeModal from "../../modals/NodeModal";
import { useCallback } from "react";
import { TabsContext } from "../../contexts/tabsContext";
import { debounce } from "../../utils";
import Tooltip from "../../components/TooltipComponent";
import { NodeToolbar } from "reactflow";
import NodeToolbarComponent from "../../pages/FlowPage/components/nodeToolbarComponent";
import ShadTooltip from "../../components/ShadTooltipComponent";
import { postValidateNode } from "../../controllers/API";
import { useSSE } from "../../contexts/SSEContext";
export default function GenericNode({
data,
selected,
@ -31,46 +23,30 @@ export default function GenericNode({
const { setErrorData } = useContext(alertContext);
const showError = useRef(true);
const { types, deleteNode } = useContext(typesContext);
const { openPopUp } = useContext(PopUpContext);
const { closePopUp } = useContext(PopUpContext);
const { closePopUp, openPopUp } = useContext(PopUpContext);
const Icon = nodeIcons[data.type] || nodeIcons[types[data.type]];
const [validationStatus, setValidationStatus] = useState(null);
// State for outline color
const [isValid, setIsValid] = useState(false);
const { reactFlowInstance } = useContext(typesContext);
const [params, setParams] = useState([]);
const { sseData } = useSSE();
// useEffect(() => {
// if (reactFlowInstance) {
// setParams(Object.values(reactFlowInstance.toObject()));
// }
// }, [save]);
// New useEffect to watch for changes in sseData and update validation status
useEffect(() => {
if (reactFlowInstance) {
setParams(Object.values(reactFlowInstance.toObject()));
const relevantData = sseData[data.id];
if (relevantData) {
// Extract validation information from relevantData and update the validationStatus state
setValidationStatus(relevantData);
} else {
setValidationStatus(null);
}
}, []);
const validateNode = useCallback(
debounce(async () => {
try {
const response = await postValidateNode(
data.id,
reactFlowInstance.toObject()
);
if (response.status === 200) {
let jsonResponseParsed = await JSON.parse(response.data);
setValidationStatus(jsonResponseParsed);
}
} catch (error) {
// console.error("Error validating node:", error);
setValidationStatus("error");
}
}, 1000), // Adjust the debounce delay (500ms) as needed
[reactFlowInstance, data.id]
);
useEffect(() => {
if (params.length > 0) {
validateNode();
}
}, [params, validateNode]);
}, [sseData, data.id]);
if (!Icon) {
if (showError.current) {

View file

@ -0,0 +1,155 @@
import { useState, useContext } from "react";
import { Transition } from "@headlessui/react";
import { Zap } from "lucide-react";
import { validateNodes } from "../../../utils";
import { FlowType } from "../../../types/flow";
import Loading from "../../../components/ui/loading";
import { useSSE } from "../../../contexts/SSEContext";
import { typesContext } from "../../../contexts/typesContext";
import { alertContext } from "../../../contexts/alertContext";
import { postBuildInit } from "../../../controllers/API";
export default function BuildTrigger({
open,
flow,
setIsBuilt,
isBuilt,
}: {
open: boolean;
flow: FlowType;
setIsBuilt: any;
isBuilt: boolean;
}) {
const [isBuilding, setIsBuilding] = useState(false);
const { updateSSEData } = useSSE();
const { reactFlowInstance } = useContext(typesContext);
const { setErrorData } = useContext(alertContext);
async function handleBuild(flow: FlowType) {
const errors = validateNodes(reactFlowInstance);
if (errors.length > 0) {
setErrorData({
title: "Oops! Looks like you missed something",
list: errors,
});
return;
}
const minimumLoadingTime = 200; // in milliseconds
const startTime = Date.now();
setIsBuilding(true);
try {
const allNodesValid = await streamNodeData(flow);
await enforceMinimumLoadingTime(startTime, minimumLoadingTime);
setIsBuilt(allNodesValid);
} catch (error) {
console.error("Error:", error);
} finally {
setIsBuilding(false);
}
}
async function streamNodeData(flow: FlowType) {
// Step 1: Make a POST request to send the flow data and receive a unique session ID
const response = await postBuildInit(flow);
const { flowId } = response.data;
// Step 2: Use the session ID to establish an SSE connection using EventSource
let validationResults = [];
let finished = false;
const apiUrl = `/api/v1/build/stream/${flowId}`;
const eventSource = new EventSource(apiUrl);
try{
eventSource.onmessage = (event) => {
// If the event is parseable, return
if (!event.data) {
return;
}
const parsedData = JSON.parse(event.data);
// if the event is the end of the stream, close the connection
if (parsedData.end_of_stream) {
eventSource.close();
return;
}
// Otherwise, process the data
const isValid = processStreamResult(parsedData);
validationResults.push(isValid);
};
eventSource.onerror = (error) => {
console.error("EventSource failed:", error);
eventSource.close();
};
// Step 3: Wait for the stream to finish
while (!finished) {
await new Promise((resolve) => setTimeout(resolve, 100));
finished = validationResults.length === flow.data.nodes.length;
}
// Step 4: Return true if all nodes are valid, false otherwise
return validationResults.every((result) => result);
}
catch(e){
console.log(e)
eventSource.close();
return false;
}
}
function processStreamResult(parsedData) {
// Process each chunk of data here
// Parse the chunk and update the context
try {
updateSSEData({ [parsedData.id]: parsedData });
} catch (err) {
console.log("Error parsing stream data: ", err);
}
return parsedData.valid;
}
async function enforceMinimumLoadingTime(
startTime: number,
minimumLoadingTime: number
) {
const elapsedTime = Date.now() - startTime;
const remainingTime = minimumLoadingTime - elapsedTime;
if (remainingTime > 0) {
return new Promise((resolve) => setTimeout(resolve, remainingTime));
}
}
return (
<Transition
show={!open}
appear={true}
enter="transition ease-out duration-300"
enterFrom="translate-y-96"
enterTo="translate-y-0"
leave="transition ease-in duration-300"
leaveFrom="translate-y-0"
leaveTo="translate-y-96"
>
<div className={`fixed right-4` + (isBuilt ? " bottom-20" : " bottom-4")}>
<div
className="border flex justify-center align-center py-1 px-3 w-12 h-12 rounded-full bg-gradient-to-r from-blue-700 via-blue-600 to-blue-500 dark:border-gray-600 cursor-pointer"
onClick={() => {
if (!isBuilding) handleBuild(flow);
}}
>
<button>
<div className="flex gap-3 items-center">
{isBuilding ? (
// Render your loading animation here when isBuilding is true
<Loading style={{ color: "white" }} />
) : (
<Zap className="h-6 w-6" style={{ color: "white" }} />
)}
</div>
</button>
</div>
</div>
</Transition>
);
}

View file

@ -3,13 +3,26 @@ import {
Bars3CenterLeftIcon,
ChatBubbleBottomCenterTextIcon,
} from "@heroicons/react/24/outline";
import { MessagesSquare } from "lucide-react";
import { nodeColors } from "../../../utils";
import { PopUpContext } from "../../../contexts/popUpContext";
import { alertContext } from "../../../contexts/alertContext";
import { useContext } from "react";
import ChatModal from "../../../modals/chatModal";
export default function ChatTrigger({ open, setOpen }) {
const { openPopUp } = useContext(PopUpContext);
export default function ChatTrigger({ open, setOpen, isBuilt }) {
const { setErrorData } = useContext(alertContext);
function handleClick() {
if (isBuilt) {
setOpen(true);
} else {
setErrorData({
title: "Flow not built",
list: ["Please build the flow before chatting"],
});
}
}
return (
<Transition
show={!open}
@ -24,13 +37,11 @@ export default function ChatTrigger({ open, setOpen }) {
<div className="absolute bottom-4 right-3">
<div
className="border flex justify-center align-center py-1 px-3 w-12 h-12 rounded-full bg-gradient-to-r from-blue-500 via-blue-600 to-blue-700 dark:border-gray-600 cursor-pointer"
onClick={() => {
setOpen(true);
}}
onClick={handleClick}
>
<button>
<div className="flex gap-3 items-center">
<ChatBubbleBottomCenterTextIcon
<MessagesSquare
className="h-6 w-6 mt-1"
style={{ color: "white" }}
/>

View file

@ -1,13 +1,18 @@
import { useEffect, useRef, useState } from "react";
import { Context, useEffect, useRef, useState, useContext } from "react";
import ReactFlow, { useNodes } from "reactflow";
import { ChatMessageType, ChatType } from "../../types/chat";
import ChatTrigger from "./chatTrigger";
import BuildTrigger from "./buildTrigger";
import ChatModal from "../../modals/chatModal";
import _ from "lodash";
import _, { set } from "lodash";
import { getBuildStatus } from "../../controllers/API";
import { NodeType } from "../../types/flow";
export default function Chat({ flow }: ChatType) {
const [open, setOpen] = useState(false);
const [isBuilt, setIsBuilt] = useState(false);
useEffect(() => {
const handleKeyDown = (event: KeyboardEvent) => {
if (
@ -23,10 +28,58 @@ export default function Chat({ flow }: ChatType) {
document.removeEventListener("keydown", handleKeyDown);
};
}, []);
useEffect(() => {
// Define an async function within the useEffect hook
const fetchBuildStatus = async () => {
const response = await getBuildStatus(flow.id);
setIsBuilt(response.built);
};
// Call the async function
fetchBuildStatus();
}, [flow]);
const prevNodesRef = useRef<any[] | undefined>();
const nodes = useNodes();
useEffect(() => {
const prevNodes = prevNodesRef.current;
const currentNodes = nodes.map(
(node: NodeType) => node.data.node.template.value
);
if (
prevNodes &&
JSON.stringify(prevNodes) !== JSON.stringify(currentNodes)
) {
setIsBuilt(false);
console.log("Nodes changed");
}
prevNodesRef.current = currentNodes;
}, [nodes]);
return (
<>
<ChatModal key={flow.id} flow={flow} open={open} setOpen={setOpen} />
<ChatTrigger open={open} setOpen={setOpen} />
{isBuilt ? (
<div>
<BuildTrigger
open={open}
flow={flow}
setIsBuilt={setIsBuilt}
isBuilt={isBuilt}
/>
<ChatModal key={flow.id} flow={flow} open={open} setOpen={setOpen} />
<ChatTrigger open={open} setOpen={setOpen} isBuilt={isBuilt} />
</div>
) : (
<BuildTrigger
open={open}
flow={flow}
setIsBuilt={setIsBuilt}
isBuilt={isBuilt}
/>
)}
</>
);
}

View file

@ -0,0 +1,39 @@
import { SVGProps } from "react";
// https://github.com/feathericons/feather/issues/695#issuecomment-1503699643
// Indeterminate loading spinner built from feather's "circle" icon.
// The circle outline is drawn as a partial dash (strokeDasharray 63 with a
// 21-unit offset leaves a visible gap); two SMIL animations then (1) rotate
// the circle around its center every 2s and (2) ease the dash offset back and
// forth over 8s so the visible arc drifts — no CSS keyframes required.
// All standard SVG props (className, style, ...) are forwarded via {...props},
// so callers can override the default 24x24 size or the feather classes.
export const Loading = (props: SVGProps<SVGSVGElement>) => (
  <svg
    xmlns="http://www.w3.org/2000/svg"
    width={24}
    height={24}
    viewBox="0 0 24 24"
    fill="none"
    stroke="currentColor"
    strokeWidth={2}
    strokeLinecap="round"
    strokeLinejoin="round"
    className="feather feather-circle"
    {...props}
  >
    <circle cx={12} cy={12} r={10} strokeDasharray={63} strokeDashoffset={21}>
      <animateTransform
        attributeName="transform"
        type="rotate"
        from="0 12 12"
        to="360 12 12"
        dur="2s"
        repeatCount="indefinite"
      />
      <animate
        attributeName="stroke-dashoffset"
        dur="8s"
        repeatCount="indefinite"
        keyTimes="0; 0.5; 1"
        values="-16; -47; -16"
        calcMode="spline"
        keySplines="0.4 0 0.2 1; 0.4 0 0.2 1"
      />
    </circle>
  </svg>
);
export default Loading;

View file

@ -0,0 +1,35 @@
import {
createContext,
useContext,
useState,
useEffect,
useCallback,
} from "react";
// Default context value used before (or outside) an SSEProvider mount;
// updateSSEData is a no-op so stray consumers don't crash.
const initialValue = {
  updateSSEData: ({}) => {},
  sseData: {},
};
const SSEContext = createContext(initialValue);
// Convenience hook: read the nearest provider's { sseData, updateSSEData }.
export function useSSE() {
  return useContext(SSEContext);
}
// Provider that accumulates SSE (server-sent event, per the naming) payloads
// into one object; updateSSEData shallow-merges new entries into prior state.
export function SSEProvider({ children }) {
  const [sseData, setSSEData] = useState({});
  // useCallback keeps updateSSEData referentially stable across renders so
  // consumers depending on it are not re-rendered needlessly.
  const updateSSEData = useCallback((newData: any) => {
    setSSEData((prevData) => ({
      ...prevData,
      ...newData,
    }));
  }, []);
  return (
    <SSEContext.Provider value={{ sseData, updateSSEData }}>
      {children}
    </SSEContext.Provider>
  );
}

View file

@ -1,4 +1,9 @@
import { PromptTypeAPI, errorsTypeAPI } from "./../../types/api/index";
import {
BuildStatusTypeAPI,
PromptTypeAPI,
errorsTypeAPI,
InitTypeAPI,
} from "./../../types/api/index";
import { APIObjectType, sendAllProps } from "../../types/api/index";
import axios, { AxiosResponse } from "axios";
import { FlowStyleType, FlowType } from "../../types/flow";
@ -272,4 +277,14 @@ export async function getVersion() {
*/
export async function getHealth() {
return await axios.get("/health"); // Health is the only endpoint that doesn't require /api/v1
/**
 * Fetch whether the flow identified by `flowId` has been built.
 *
 * Declared to resolve with the payload itself (BuildStatusTypeAPI), so the
 * axios response must be unwrapped: callers read `response.built` directly.
 * (Returning the raw AxiosResponse — as before — made `.built` undefined at
 * every call site.)
 */
export async function getBuildStatus(
  flowId: string
): Promise<BuildStatusTypeAPI> {
  const response = await axios.get(`/api/v1/build/${flowId}/status`);
  return response.data;
}

/**
 * Initialize a build by POSTing the flow's graph to /build/init.
 * Returns the full AxiosResponse; its data carries the flowId used for the
 * follow-up status/stream requests.
 */
export async function postBuildInit(
  flow: FlowType
): Promise<AxiosResponse<InitTypeAPI>> {
  return await axios.post(`/api/v1/build/init`, flow);
}

View file

@ -22,7 +22,6 @@ import IntComponent from "../../components/intComponent";
import InputFileComponent from "../../components/inputFileComponent";
import PromptAreaComponent from "../../components/promptComponent";
import CodeAreaComponent from "../../components/codeAreaComponent";
import { TabsContext } from "../../contexts/tabsContext";
import {
Dialog,
DialogContent,
@ -33,7 +32,6 @@ import {
DialogTrigger,
} from "../../components/ui/dialog";
import { Button } from "../../components/ui/button";
import { Edit } from "lucide-react";
import { Badge } from "../../components/ui/badge";
export default function EditNodeModal({ data }: { data: NodeDataType }) {

View file

@ -3,7 +3,7 @@ import { ChatBubbleOvalLeftEllipsisIcon } from "@heroicons/react/24/outline";
import { Fragment, useContext, useEffect, useRef, useState } from "react";
import { FlowType, NodeType } from "../../types/flow";
import { alertContext } from "../../contexts/alertContext";
import { toNormalCase } from "../../utils";
import { toNormalCase, validateNodes } from "../../utils";
import { typesContext } from "../../contexts/typesContext";
import ChatMessage from "./chatMessage";
import { FaEraser } from "react-icons/fa";
@ -185,6 +185,17 @@ export default function ChatModal({
}://${host}${chatEndpoint}`;
}
// Build the websocket URL for a chat session.
// In development we always talk plain ws to the local backend on :7860;
// otherwise we mirror the page: wss on https pages, ws on http pages.
function getWebSocketUrl(chatId, isDevelopment = false) {
  const pageIsSecure = window.location.protocol === "https:";
  const scheme = isDevelopment ? "ws" : pageIsSecure ? "wss" : "ws";
  const host = isDevelopment ? "localhost:7860" : window.location.host;
  return `${scheme}://${host}/api/v1/chat/${chatId}`;
}
function connectWS() {
try {
const urlWs = getWebSocketUrl(
@ -269,53 +280,6 @@ export default function ChatModal({
if (ref.current) ref.current.scrollIntoView({ behavior: "smooth" });
}, [chatHistory]);
function validateNode(n: NodeType): Array<string> {
if (!n.data?.node?.template || !Object.keys(n.data.node.template)) {
setNoticeData({
title:
"We've noticed a potential issue with a node in the flow. Please review it and, if necessary, submit a bug report with your exported flow file. Thank you for your help!",
});
return [];
}
const {
type,
node: { template },
} = n.data;
return Object.keys(template).reduce(
(errors: Array<string>, t) =>
errors.concat(
template[t].required &&
template[t].show &&
(template[t].value === undefined ||
template[t].value === null ||
template[t].value === "") &&
!reactFlowInstance
.getEdges()
.some(
(e) =>
e.targetHandle.split("|")[1] === t &&
e.targetHandle.split("|")[2] === n.id
)
? [
`${type} is missing ${
template.display_name
? template.display_name
: toNormalCase(template[t].name)
}.`,
]
: []
),
[] as string[]
);
}
function validateNodes() {
return reactFlowInstance
.getNodes()
.flatMap((n: NodeType) => validateNode(n));
}
const ref = useRef(null);
@ -327,7 +291,7 @@ export default function ChatModal({
function sendMessage() {
if (chatValue !== "") {
let nodeValidationErrors = validateNodes();
let nodeValidationErrors = validateNodes(reactFlowInstance);
if (nodeValidationErrors.length === 0) {
setLockChat(true);
let message = chatValue;

View file

@ -38,3 +38,11 @@ export type errorsTypeAPI = {
imports: { errors: Array<string> };
};
export type PromptTypeAPI = { input_variables: Array<string> };
export type BuildStatusTypeAPI = {
built: boolean;
};
export type InitTypeAPI = {
flowId: string;
};

View file

@ -17,7 +17,7 @@ import {
Bars3CenterLeftIcon,
} from "@heroicons/react/24/outline";
import { Connection, Edge, Node, ReactFlowInstance } from "reactflow";
import { FlowType, NodeDataType, NodeType } from "./types/flow";
import { FlowType, NodeType } from "./types/flow";
import { APITemplateType } from "./types/api";
import _ from "lodash";
import { ChromaIcon } from "./icons/ChromaIcon";
@ -737,3 +737,56 @@ export function buildTweaks(flow) {
return acc;
}, {});
}
/**
 * Validate one node of the flow, returning a human-readable error string for
 * every required, visible template field that has neither a literal value nor
 * an incoming edge feeding it.
 */
export function validateNode(
  n: NodeType,
  reactFlowInstance: ReactFlowInstance
): Array<string> {
  // A node without template fields cannot be validated — surface a generic
  // warning instead of crashing on the missing structure.
  // Fix: the original tested `!Object.keys(...)`, which is always false
  // (Object.keys returns an array, and arrays are truthy); check .length.
  if (
    !n.data?.node?.template ||
    !Object.keys(n.data.node.template).length
  ) {
    return [
      "We've noticed a potential issue with a node in the flow. Please review it and, if necessary, submit a bug report with your exported flow file. Thank you for your help!",
    ];
  }
  const {
    type,
    node: { template },
  } = n.data;
  // A field is in error when it is required and shown, has no value, and no
  // edge targets it (targetHandle segments: [1] = field key, [2] = node id,
  // as the split("|") comparisons below demonstrate).
  return Object.keys(template).reduce(
    (errors: Array<string>, t) =>
      errors.concat(
        template[t].required &&
          template[t].show &&
          (template[t].value === undefined ||
            template[t].value === null ||
            template[t].value === "") &&
          !reactFlowInstance
            .getEdges()
            .some(
              (e) =>
                e.targetHandle.split("|")[1] === t &&
                e.targetHandle.split("|")[2] === n.id
            )
          ? [
              // Fix: read display_name from the field (template[t]), not from
              // the template map itself — the latter was always undefined, so
              // the friendly display name was never used in the message.
              `${type} is missing ${
                template[t].display_name
                  ? template[t].display_name
                  : toNormalCase(template[t].name)
              }.`,
            ]
          : []
      ),
    [] as string[]
  );
}
// Validate every node currently on the canvas; an empty canvas is itself an
// error. Returns the concatenated per-node error messages.
export function validateNodes(reactFlowInstance: ReactFlowInstance) {
  const currentNodes = reactFlowInstance.getNodes();
  if (currentNodes.length === 0) {
    return [
      "No nodes found in the flow. Please add at least one node to the flow.",
    ];
  }
  const allErrors: Array<string> = [];
  for (const node of currentNodes) {
    allErrors.push(...validateNode(node as NodeType, reactFlowInstance));
  }
  return allErrors;
}

View file

@ -12,11 +12,9 @@ const proxyTargets = apiRoutes.reduce((proxyObj, route) => {
changeOrigin: true,
secure: false,
ws: true,
// rewrite: (path) => `/api/v1${path}`,
};
return proxyObj;
}, {});
export default defineConfig(() => {
return {
build: {

View file

@ -66,6 +66,12 @@ def get_graph(_type="basic"):
return Graph(nodes, edges)
@pytest.fixture
def basic_graph_data():
    """Return the basic example flow file parsed from JSON.

    The path is published on the ``pytest`` module (BASIC_EXAMPLE_PATH)
    by the test configuration.
    """
    with open(pytest.BASIC_EXAMPLE_PATH, "r") as example_file:
        raw_contents = example_file.read()
    return json.loads(raw_contents)
@pytest.fixture
def basic_graph():
    # Default Graph instance built by the module's get_graph() helper
    # (called with its default "basic" type).
    return get_graph()

View file

@ -1,47 +1,47 @@
import json
from unittest.mock import patch
from fastapi import WebSocketDisconnect
from fastapi.testclient import TestClient
# from langflow.chat.manager import ChatManager
import pytest
def test_websocket_connection(client: TestClient):
    # Smoke test: the chat websocket route accepts a connection and the ASGI
    # scope is populated as expected. ("testclient", 50000) is the fixed
    # client address Starlette's TestClient injects into the scope.
    with client.websocket_connect("api/v1/chat/test_client") as websocket:
        assert websocket.scope["client"] == ["testclient", 50000]
        assert websocket.scope["path"] == "/api/v1/chat/test_client"
def test_init_build(client):
    """POST /build/init stores the graph data and echoes back its flow id.

    The route is declared with ``status_code=201`` (Created), so a
    successful init must return 201 — the previous assertion of 200 could
    never pass against that declaration.
    """
    response = client.post(
        "api/v1/build/init", json={"id": "test", "data": {"key": "value"}}
    )
    assert response.status_code == 201
    assert response.json() == {"flowId": "test"}
def test_chat_history(client: TestClient):
# Mock the process_graph function to return a specific value
with patch("langflow.chat.manager.process_graph") as mock_process_graph:
mock_process_graph.return_value = ("Hello, I'm a mock response!", "")
def test_stream_build(client):
client.post("/build/init", json={"id": "stream_test", "data": {"key": "value"}})
with client.websocket_connect("api/v1/chat/test_client") as websocket:
# First message should be the history
history = websocket.receive_json()
assert history == [] # Empty history
# Send a message
payload = {"message": "Hello"}
websocket.send_json(json.dumps(payload))
# Test the stream
response = client.get("api/v1/build/stream/stream_test")
assert response.status_code == 200
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
# Receive the response from the server
response = websocket.receive_json()
assert response == {
"is_bot": True,
"message": None,
"type": "start",
"intermediate_steps": "",
"files": [],
}
# Send another message
payload = {"message": "How are you?"}
websocket.send_json(json.dumps(payload))
# Receive the response from the server
response = websocket.receive_json()
assert response == {
"is_bot": True,
"message": "Hello, I'm a mock response!",
"type": "end",
"intermediate_steps": "",
"files": [],
}
def test_websocket_endpoint(client):
    # Connecting with a client_id that was never built must be rejected: the
    # endpoint closes the socket (WS_1008_POLICY_VIOLATION per the route),
    # which the TestClient surfaces as a WebSocketDisconnect.
    with pytest.raises(WebSocketDisconnect):
        with client.websocket_connect(
            "api/v1/chat/non_existing_client_id"
        ) as websocket:
            websocket.send_json({"type": "test"})
            # NOTE(review): if the server closes before replying, receive_json
            # raises here and the assertion below never executes — confirm the
            # rejection message is actually being checked.
            data = websocket.receive_json()
            assert "Please, build the flow before sending messages" in data["message"]
def test_websocket_endpoint_after_build(client, basic_graph_data):
# Assuming your websocket_endpoint uses chat_manager which caches data from stream_build
client.post("/build/init", json=basic_graph_data)
client.get("/build/stream/websocket_test")
# There should be more to test here, but it depends on the inner workings of your websocket handler
# and how your chat_manager and other classes behave. The following is just an example structure.
with pytest.raises(WebSocketDisconnect):
with client.websocket_connect("api/v1/chat/websocket_test") as websocket:
websocket.send_json({"type": "test"})
# Perform assertions here, based on what you expect the websocket to return
# data = websocket.receive_json()
# assert ...