diff --git a/src/backend/langflow/api/v1/chat.py b/src/backend/langflow/api/v1/chat.py index 96647c0d1..7acdf32ce 100644 --- a/src/backend/langflow/api/v1/chat.py +++ b/src/backend/langflow/api/v1/chat.py @@ -1,26 +1,119 @@ +import json from fastapi import ( APIRouter, + HTTPException, WebSocket, - WebSocketDisconnect, WebSocketException, status, ) +from fastapi.responses import StreamingResponse +from langflow.api.v1.schemas import BuiltResponse, InitResponse from langflow.chat.manager import ChatManager +from langflow.graph.graph.base import Graph from langflow.utils.logger import logger router = APIRouter(tags=["Chat"]) chat_manager = ChatManager() +flow_data_store = {} @router.websocket("/chat/{client_id}") -async def websocket_endpoint(client_id: str, websocket: WebSocket): +async def chat(client_id: str, websocket: WebSocket): """Websocket endpoint for chat.""" try: - await chat_manager.handle_websocket(client_id, websocket) + if client_id in chat_manager.in_memory_cache: + await chat_manager.handle_websocket(client_id, websocket) + else: + message = "Please, build the flow before sending messages" + await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason=message) except WebSocketException as exc: logger.error(exc) await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=str(exc)) - except WebSocketDisconnect as exc: + + +@router.post("/build/init", response_model=InitResponse, status_code=201) +async def init_build(graph_data: dict): + """Initialize the build by storing graph data and returning a unique session ID.""" + + try: + flow_id = graph_data.get("id") + + flow_data_store[flow_id] = graph_data + + return InitResponse(flowId=flow_id) + except Exception as exc: logger.error(exc) - await websocket.close(code=status.WS_1000_NORMAL_CLOSURE, reason=str(exc)) + return HTTPException(status_code=500, detail=str(exc)) + + +@router.get("/build/{flow_id}/status", response_model=BuiltResponse) +async def build_status(flow_id: str): + """Check the flow_id is in the flow_data_store.""" + try: + built = flow_id in flow_data_store and not isinstance( + flow_data_store[flow_id], dict + ) + + return BuiltResponse( + built=built, + ) + + except Exception as exc: + logger.error(exc) + return HTTPException(status_code=500, detail=str(exc)) + + +@router.get("/build/stream/{flow_id}", response_class=StreamingResponse) +async def stream_build(flow_id: str): + """Stream the build process based on stored flow data.""" + + async def event_stream(flow_id): + final_response = json.dumps({"end_of_stream": True}) + try: + if flow_id not in flow_data_store: + error_message = "Invalid session ID" + yield f"data: {json.dumps({'error': error_message})}\n\n" + return + + graph_data = flow_data_store[flow_id].get("data") + + if not graph_data: + error_message = "No data provided" + yield f"data: {json.dumps({'error': error_message})}\n\n" + return + + logger.debug("Building langchain object") + graph = Graph.from_payload(graph_data) + for node in graph.generator_build(): + try: + node.build() + params = node._built_object_repr() + valid = True + logger.debug( + f"Building node {params[:50]}{'...' if len(params) > 50 else ''}" + ) + except Exception as exc: + params = str(exc) + valid = False + + response = json.dumps( + { + "valid": valid, + "params": params, + "id": node.id, + } + ) + yield f"data: {response}\n\n" + + chat_manager.set_cache(flow_id, graph.build()) + except Exception: + logger.error("Error while building the flow") + finally: + yield f"data: {final_response}\n\n" + + try: + return StreamingResponse(event_stream(flow_id), media_type="text/event-stream") + except Exception as exc: + logger.error(exc) + raise HTTPException(status_code=500, detail=str(exc)) diff --git a/src/backend/langflow/api/v1/schemas.py b/src/backend/langflow/api/v1/schemas.py index aae4a1df3..714f0df7f 100644 --- a/src/backend/langflow/api/v1/schemas.py +++ b/src/backend/langflow/api/v1/schemas.py @@ -93,3 +93,11 @@ class FlowListCreate(BaseModel): class FlowListRead(BaseModel): flows: List[FlowRead] + + +class InitResponse(BaseModel): + flowId: str + + +class BuiltResponse(BaseModel): + built: bool diff --git a/src/backend/langflow/cache/__init__.py b/src/backend/langflow/cache/__init__.py index 583d5ac6d..723aa9e18 100644 --- a/src/backend/langflow/cache/__init__.py +++ b/src/backend/langflow/cache/__init__.py @@ -1 +1,7 @@ -from langflow.cache.manager import cache_manager # noqa +from langflow.cache.manager import cache_manager +from langflow.cache.flow import InMemoryCache + +__all__ = [ + "cache_manager", + "InMemoryCache", +] diff --git a/src/backend/langflow/cache/base.py b/src/backend/langflow/cache/base.py index 0f1ff5d92..96639774b 100644 --- a/src/backend/langflow/cache/base.py +++ b/src/backend/langflow/cache/base.py @@ -1,154 +1,95 @@ -import base64 -import contextlib -import functools -import hashlib -import json -import os -import tempfile -from collections import OrderedDict -from pathlib import Path -from typing import Any, Dict - -import dill # type: ignore - -CACHE: Dict[str, Any] = {} +import abc -def create_cache_folder(func): - def wrapper(*args, **kwargs): - # Get the destination folder - cache_path = Path(tempfile.gettempdir()) / PREFIX - - # Create the destination folder if it doesn't exist - os.makedirs(cache_path, exist_ok=True) - - return func(*args, **kwargs) - - return wrapper - - -def memoize_dict(maxsize=128): - cache = OrderedDict() - - def decorator(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - hashed = compute_dict_hash(args[0]) - key = (func.__name__, hashed, frozenset(kwargs.items())) - if key not in cache: - result = func(*args, **kwargs) - cache[key] = result - if len(cache) > maxsize: - cache.popitem(last=False) - else: - result = cache[key] - return result - - def clear_cache(): - cache.clear() - - wrapper.clear_cache = clear_cache # type: ignore - wrapper.cache = cache # type: ignore - return wrapper - - return decorator - - -PREFIX = "langflow_cache" - - -@create_cache_folder -def clear_old_cache_files(max_cache_size: int = 3): - cache_dir = Path(tempfile.gettempdir()) / PREFIX - cache_files = list(cache_dir.glob("*.dill")) - - if len(cache_files) > max_cache_size: - cache_files_sorted_by_mtime = sorted( - cache_files, key=lambda x: x.stat().st_mtime, reverse=True - ) - - for cache_file in cache_files_sorted_by_mtime[max_cache_size:]: - with contextlib.suppress(OSError): - os.remove(cache_file) - - -def compute_dict_hash(graph_data): - graph_data = filter_json(graph_data) - - cleaned_graph_json = json.dumps(graph_data, sort_keys=True) - return hashlib.sha256(cleaned_graph_json.encode("utf-8")).hexdigest() - - -def filter_json(json_data): - filtered_data = json_data.copy() - - # Remove 'viewport' and 'chatHistory' keys - if "viewport" in filtered_data: - del filtered_data["viewport"] - if "chatHistory" in filtered_data: - del filtered_data["chatHistory"] - - # Filter nodes - if "nodes" in filtered_data: - for node in filtered_data["nodes"]: - if "position" in node: - del node["position"] - if "positionAbsolute" in node: - del node["positionAbsolute"] - if "selected" in node: - del node["selected"] - if "dragging" in node: - del node["dragging"] - - return filtered_data - - -@create_cache_folder -def save_binary_file(content: str, file_name: str, accepted_types: list[str]) -> str: +class BaseCache(abc.ABC): """ - Save a binary file to the specified folder. - - Args: - content: The content of the file as a bytes object. - file_name: The name of the file, including its extension. - - Returns: - The path to the saved file. + Abstract base class for a cache. """ - if not any(file_name.endswith(suffix) for suffix in accepted_types): - raise ValueError(f"File {file_name} is not accepted") - # Get the destination folder - cache_path = Path(tempfile.gettempdir()) / PREFIX - if not content: - raise ValueError("Please, reload the file in the loader.") - data = content.split(",")[1] - decoded_bytes = base64.b64decode(data) + @abc.abstractmethod + def get(self, key): + """ + Retrieve an item from the cache. - # Create the full file path - file_path = os.path.join(cache_path, file_name) + Args: + key: The key of the item to retrieve. - # Save the binary content to the file - with open(file_path, "wb") as file: - file.write(decoded_bytes) + Returns: + The value associated with the key, or None if the key is not found. + """ + pass - return file_path + @abc.abstractmethod + def set(self, key, value): + """ + Add an item to the cache. + Args: + key: The key of the item. + value: The value to cache. + """ + pass -@create_cache_folder -def save_cache(hash_val: str, chat_data, clean_old_cache_files: bool): - cache_path = Path(tempfile.gettempdir()) / PREFIX / f"{hash_val}.dill" - with cache_path.open("wb") as cache_file: - dill.dump(chat_data, cache_file) + @abc.abstractmethod + def delete(self, key): + """ + Remove an item from the cache. - if clean_old_cache_files: - clear_old_cache_files() + Args: + key: The key of the item to remove. + """ + pass + @abc.abstractmethod + def clear(self): + """ + Clear all items from the cache. + """ + pass -@create_cache_folder -def load_cache(hash_val): - cache_path = Path(tempfile.gettempdir()) / PREFIX / f"{hash_val}.dill" - if cache_path.exists(): - with cache_path.open("rb") as cache_file: - return dill.load(cache_file) - return None + @abc.abstractmethod + def __contains__(self, key): + """ + Check if the key is in the cache. + + Args: + key: The key of the item to check. + + Returns: + True if the key is in the cache, False otherwise. + """ + pass + + @abc.abstractmethod + def __getitem__(self, key): + """ + Retrieve an item from the cache using the square bracket notation. + + Args: + key: The key of the item to retrieve. + + Returns: + The value associated with the key, or None if the key is not found. + """ + pass + + @abc.abstractmethod + def __setitem__(self, key, value): + """ + Add an item to the cache using the square bracket notation. + + Args: + key: The key of the item. + value: The value to cache. + """ + pass + + @abc.abstractmethod + def __delitem__(self, key): + """ + Remove an item from the cache using the square bracket notation. + + Args: + key: The key of the item to remove. + """ + pass diff --git a/src/backend/langflow/cache/flow.py b/src/backend/langflow/cache/flow.py new file mode 100644 index 000000000..6d8fee977 --- /dev/null +++ b/src/backend/langflow/cache/flow.py @@ -0,0 +1,146 @@ +import threading +import time +from collections import OrderedDict + +from langflow.cache.base import BaseCache + + +class InMemoryCache(BaseCache): + """ + A simple in-memory cache using an OrderedDict. + + This cache supports setting a maximum size and expiration time for cached items. + When the cache is full, it uses a Least Recently Used (LRU) eviction policy. + Thread-safe using a threading Lock. + + Attributes: + max_size (int, optional): Maximum number of items to store in the cache. + expiration_time (int, optional): Time in seconds after which a cached item expires. Default is 1 hour. + + Example: + + cache = InMemoryCache(max_size=3, expiration_time=5) + + # setting cache values + cache.set("a", 1) + cache.set("b", 2) + cache["c"] = 3 + + # getting cache values + a = cache.get("a") + b = cache["b"] + """ + + def __init__(self, max_size=None, expiration_time=60 * 60): + """ + Initialize a new InMemoryCache instance. + + Args: + max_size (int, optional): Maximum number of items to store in the cache. + expiration_time (int, optional): Time in seconds after which a cached item expires. Default is 1 hour. + """ + self._cache = OrderedDict() + self._lock = threading.Lock() + self.max_size = max_size + self.expiration_time = expiration_time + + def get(self, key): + """ + Retrieve an item from the cache. + + Args: + key: The key of the item to retrieve. + + Returns: + The value associated with the key, or None if the key is not found or the item has expired. + """ + with self._lock: + if key in self._cache: + item = self._cache.pop(key) + if ( + self.expiration_time is None + or time.time() - item["time"] < self.expiration_time + ): + # Move the key to the end to make it recently used + self._cache[key] = item + return item["value"] + else: + self.delete(key) + return None + + def set(self, key, value): + """ + Add an item to the cache. + + If the cache is full, the least recently used item is evicted. + + Args: + key: The key of the item. + value: The value to cache. + """ + with self._lock: + if key in self._cache: + # Remove existing key before re-inserting to update order + self.delete(key) + elif self.max_size and len(self._cache) >= self.max_size: + # Remove least recently used item + self._cache.popitem(last=False) + self._cache[key] = {"value": value, "time": time.time()} + + def get_or_set(self, key, value): + """ + Retrieve an item from the cache. If the item does not exist, set it with the provided value. + + Args: + key: The key of the item. + value: The value to cache if the item doesn't exist. + + Returns: + The cached value associated with the key. + """ + with self._lock: + if key in self._cache: + return self.get(key) + self.set(key, value) + return value + + def delete(self, key): + """ + Remove an item from the cache. + + Args: + key: The key of the item to remove. + """ + # with self._lock: + self._cache.pop(key, None) + + def clear(self): + """ + Clear all items from the cache. + """ + with self._lock: + self._cache.clear() + + def __contains__(self, key): + """Check if the key is in the cache.""" + return key in self._cache + + def __getitem__(self, key): + """Retrieve an item from the cache using the square bracket notation.""" + return self.get(key) + + def __setitem__(self, key, value): + """Add an item to the cache using the square bracket notation.""" + self.set(key, value) + + def __delitem__(self, key): + """Remove an item from the cache using the square bracket notation.""" + self.delete(key) + + def __len__(self): + """Return the number of items in the cache.""" + return len(self._cache) + + def __repr__(self): + """Return a string representation of the InMemoryCache instance.""" + return f"InMemoryCache(max_size={self.max_size}, expiration_time={self.expiration_time})" diff --git a/src/backend/langflow/cache/manager.py b/src/backend/langflow/cache/manager.py index 947f5ce21..13b281008 100644 --- a/src/backend/langflow/cache/manager.py +++ b/src/backend/langflow/cache/manager.py @@ -54,7 +54,7 @@ class CacheManager(Subject): def __init__(self): super().__init__() - self.CACHE = {} + self._cache = {} self.current_client_id = None self.current_cache = {} @@ -68,12 +68,12 @@ class CacheManager(Subject): """ previous_client_id = self.current_client_id self.current_client_id = client_id - self.current_cache = self.CACHE.setdefault(client_id, {}) + self.current_cache = self._cache.setdefault(client_id, {}) try: yield finally: self.current_client_id = previous_client_id - self.current_cache = self.CACHE.get(self.current_client_id, {}) + self.current_cache = self._cache.get(self.current_client_id, {}) def add(self, name: str, obj: Any, obj_type: str, extension: Optional[str] = None): """ diff --git a/src/backend/langflow/cache/utils.py b/src/backend/langflow/cache/utils.py new file mode 100644 index 000000000..0f1ff5d92 --- /dev/null +++ b/src/backend/langflow/cache/utils.py @@ -0,0 +1,154 @@ +import base64 +import contextlib +import functools +import hashlib +import json +import os +import tempfile +from collections import OrderedDict +from pathlib import Path +from typing import Any, Dict + +import dill # type: ignore + +CACHE: Dict[str, Any] = {} + + +def create_cache_folder(func): + def wrapper(*args, **kwargs): + # Get the destination folder + cache_path = Path(tempfile.gettempdir()) / PREFIX + + # Create the destination folder if it doesn't exist + os.makedirs(cache_path, exist_ok=True) + + return func(*args, **kwargs) + + return wrapper + + +def memoize_dict(maxsize=128): + cache = OrderedDict() + + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + hashed = compute_dict_hash(args[0]) + key = (func.__name__, hashed, frozenset(kwargs.items())) + if key not in cache: + result = func(*args, **kwargs) + cache[key] = result + if len(cache) > maxsize: + cache.popitem(last=False) + else: + result = cache[key] + return result + + def clear_cache(): + cache.clear() + + wrapper.clear_cache = clear_cache # type: ignore + wrapper.cache = cache # type: ignore + return wrapper + + return decorator + + +PREFIX = "langflow_cache" + + +@create_cache_folder +def clear_old_cache_files(max_cache_size: int = 3): + cache_dir = Path(tempfile.gettempdir()) / PREFIX + cache_files = list(cache_dir.glob("*.dill")) + + if len(cache_files) > max_cache_size: + cache_files_sorted_by_mtime = sorted( + cache_files, key=lambda x: x.stat().st_mtime, reverse=True + ) + + for cache_file in cache_files_sorted_by_mtime[max_cache_size:]: + with contextlib.suppress(OSError): + os.remove(cache_file) + + +def compute_dict_hash(graph_data): + graph_data = filter_json(graph_data) + + cleaned_graph_json = json.dumps(graph_data, sort_keys=True) + return hashlib.sha256(cleaned_graph_json.encode("utf-8")).hexdigest() + + +def filter_json(json_data): + filtered_data = json_data.copy() + + # Remove 'viewport' and 'chatHistory' keys + if "viewport" in filtered_data: + del filtered_data["viewport"] + if "chatHistory" in filtered_data: + del filtered_data["chatHistory"] + + # Filter nodes + if "nodes" in filtered_data: + for node in filtered_data["nodes"]: + if "position" in node: + del node["position"] + if "positionAbsolute" in node: + del node["positionAbsolute"] + if "selected" in node: + del node["selected"] + if "dragging" in node: + del node["dragging"] + + return filtered_data + + +@create_cache_folder +def save_binary_file(content: str, file_name: str, accepted_types: list[str]) -> str: + """ + Save a binary file to the specified folder. + + Args: + content: The content of the file as a bytes object. + file_name: The name of the file, including its extension. + + Returns: + The path to the saved file. + """ + if not any(file_name.endswith(suffix) for suffix in accepted_types): + raise ValueError(f"File {file_name} is not accepted") + + # Get the destination folder + cache_path = Path(tempfile.gettempdir()) / PREFIX + if not content: + raise ValueError("Please, reload the file in the loader.") + data = content.split(",")[1] + decoded_bytes = base64.b64decode(data) + + # Create the full file path + file_path = os.path.join(cache_path, file_name) + + # Save the binary content to the file + with open(file_path, "wb") as file: + file.write(decoded_bytes) + + return file_path + + +@create_cache_folder +def save_cache(hash_val: str, chat_data, clean_old_cache_files: bool): + cache_path = Path(tempfile.gettempdir()) / PREFIX / f"{hash_val}.dill" + with cache_path.open("wb") as cache_file: + dill.dump(chat_data, cache_file) + + if clean_old_cache_files: + clear_old_cache_files() + + +@create_cache_folder +def load_cache(hash_val): + cache_path = Path(tempfile.gettempdir()) / PREFIX / f"{hash_val}.dill" + if cache_path.exists(): + with cache_path.open("rb") as cache_file: + return dill.load(cache_file) + return None diff --git a/src/backend/langflow/chat/manager.py b/src/backend/langflow/chat/manager.py index d24057b68..7c3a08240 100644 --- a/src/backend/langflow/chat/manager.py +++ b/src/backend/langflow/chat/manager.py @@ -10,7 +10,9 @@ from langflow.utils.logger import logger import asyncio import json -from typing import Dict, List +from typing import Any, Dict, List + +from langflow.cache.flow import InMemoryCache class ChatHistory(Subject): @@ -46,6 +48,7 @@ class ChatManager: self.chat_history = ChatHistory() self.cache_manager = cache_manager self.cache_manager.attach(self.update) + self.in_memory_cache = InMemoryCache() def on_chat_history_update(self): """Send the last chat message to the client.""" @@ -99,24 +102,30 @@ class ChatManager: websocket = self.active_connections[client_id] await websocket.send_json(message.dict()) - async def process_message(self, client_id: str, payload: Dict): + async def close_connection(self, client_id: str, code: int, reason: str): + if websocket := self.active_connections[client_id]: + await websocket.close(code=code, reason=reason) + self.disconnect(client_id) + + async def process_message( + self, client_id: str, payload: Dict, langchain_object: Any + ): # Process the graph data and chat message chat_message = payload.pop("message", "") chat_message = ChatMessage(message=chat_message) self.chat_history.add_message(client_id, chat_message) - graph_data = payload + # graph_data = payload start_resp = ChatResponse(message=None, type="start", intermediate_steps="") await self.send_json(client_id, start_resp) - is_first_message = len(self.chat_history.get_history(client_id=client_id)) <= 1 + # is_first_message = len(self.chat_history.get_history(client_id=client_id)) <= 1 # Generate result and thought try: logger.debug("Generating result and thought") result, intermediate_steps = await process_graph( - graph_data=graph_data, - is_first_message=is_first_message, + langchain_object=langchain_object, chat_message=chat_message, websocket=self.active_connections[client_id], ) @@ -149,6 +158,14 @@ class ChatManager: await self.send_json(client_id, response) self.chat_history.add_message(client_id, response) + def set_cache(self, client_id: str, langchain_object: Any) -> bool: + """ + Set the cache for a client. + """ + + self.in_memory_cache.set(client_id, langchain_object) + return client_id in self.in_memory_cache + async def handle_websocket(self, client_id: str, websocket: WebSocket): await self.connect(client_id, websocket) @@ -169,22 +186,24 @@ class ChatManager: continue with self.cache_manager.set_client_id(client_id): - await self.process_message(client_id, payload) + langchain_object = self.in_memory_cache.get(client_id) + await self.process_message(client_id, payload, langchain_object) except Exception as e: # Handle any exceptions that might occur - logger.exception(e) - # send a message to the client - await self.active_connections[client_id].close( - code=status.WS_1011_INTERNAL_ERROR, reason=str(e)[:120] + logger.error(e) + await self.close_connection( + client_id=client_id, + code=status.WS_1011_INTERNAL_ERROR, + reason=str(e)[:120], ) - self.disconnect(client_id) finally: try: - connection = self.active_connections.get(client_id) - if connection: - await connection.close(code=1000, reason="Client disconnected") - self.disconnect(client_id) + await self.close_connection( + client_id=client_id, + code=status.WS_1000_NORMAL_CLOSURE, + reason="Client disconnected", + ) except Exception as e: - logger.exception(e) + logger.error(e) self.disconnect(client_id) diff --git a/src/backend/langflow/chat/utils.py b/src/backend/langflow/chat/utils.py index 410a442be..2e2ee367f 100644 --- a/src/backend/langflow/chat/utils.py +++ b/src/backend/langflow/chat/utils.py @@ -1,23 +1,15 @@ from fastapi import WebSocket from langflow.api.v1.schemas import ChatMessage -from langflow.processing.process import ( - load_or_build_langchain_object, -) from langflow.processing.base import get_result_and_steps from langflow.interface.utils import try_setting_streaming_options from langflow.utils.logger import logger -from typing import Dict - - async def process_graph( - graph_data: Dict, - is_first_message: bool, + langchain_object, chat_message: ChatMessage, websocket: WebSocket, ): - langchain_object = load_or_build_langchain_object(graph_data, is_first_message) langchain_object = try_setting_streaming_options(langchain_object, websocket) logger.debug("Loaded langchain object") diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 425825039..ae4b37e2c 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Type, Union +from typing import Dict, Generator, List, Type, Union from langflow.graph.edge.base import Edge from langflow.graph.graph.constants import VERTEX_TYPE_MAP @@ -106,6 +106,47 @@ class Graph: raise ValueError("No root node found") return root_node.build() + def topological_sort(self) -> List[Vertex]: + """ + Performs a topological sort of the vertices in the graph. + + Returns: + List[Vertex]: A list of vertices in topological order. + + Raises: + ValueError: If the graph contains a cycle. + """ + # States: 0 = unvisited, 1 = visiting, 2 = visited + state = {node: 0 for node in self.nodes} + sorted_vertices = [] + + def dfs(node): + if state[node] == 1: + # We have a cycle + raise ValueError( + "Graph contains a cycle, cannot perform topological sort" + ) + if state[node] == 0: + state[node] = 1 + for edge in node.edges: + if edge.source == node: + dfs(edge.target) + state[node] = 2 + sorted_vertices.append(node) + + # Visit each node + for node in self.nodes: + if state[node] == 0: + dfs(node) + + return list(reversed(sorted_vertices)) + + def generator_build(self) -> Generator: + """Builds each vertex in the graph and yields it.""" + sorted_vertices = self.topological_sort() + logger.info("Sorted vertices: %s", sorted_vertices) + yield from sorted_vertices + def get_node_neighbors(self, node: Vertex) -> Dict[Vertex, int]: """Returns the neighbors of a node.""" neighbors: Dict[Vertex, int] = {} diff --git a/src/backend/langflow/graph/vertex/base.py b/src/backend/langflow/graph/vertex/base.py index bb6ff34dc..2900ba538 100644 --- a/src/backend/langflow/graph/vertex/base.py +++ b/src/backend/langflow/graph/vertex/base.py @@ -1,4 +1,4 @@ -from langflow.cache import base as cache_utils +from langflow.cache import utils as cache_utils from langflow.graph.vertex.constants import DIRECT_TYPES from langflow.interface import loading from langflow.interface.listing import ALL_TYPES_DICT diff --git a/src/backend/langflow/interface/run.py b/src/backend/langflow/interface/run.py index 89f71fd8b..6fe4c089b 100644 --- a/src/backend/langflow/interface/run.py +++ b/src/backend/langflow/interface/run.py @@ -1,4 +1,4 @@ -from langflow.cache.base import compute_dict_hash, load_cache, memoize_dict +from langflow.cache.utils import compute_dict_hash, load_cache, memoize_dict from langflow.graph import Graph from langflow.utils.logger import logger diff --git a/src/frontend/src/CustomNodes/GenericNode/index.tsx b/src/frontend/src/CustomNodes/GenericNode/index.tsx index 783c86072..fdf830408 100644 --- a/src/frontend/src/CustomNodes/GenericNode/index.tsx +++ b/src/frontend/src/CustomNodes/GenericNode/index.tsx @@ -1,26 +1,18 @@ -import { - classNames, - nodeColors, - nodeIcons, - toNormalCase, - toTitleCase, -} from "../../utils"; +import { classNames, nodeColors, nodeIcons, toTitleCase } from "../../utils"; import ParameterComponent from "./components/parameterComponent"; import { typesContext } from "../../contexts/typesContext"; -import { useContext, useState, useEffect, useRef, Fragment } from "react"; +import { useContext, useState, useEffect, useRef } from "react"; import { NodeDataType } from "../../types/flow"; import { alertContext } from "../../contexts/alertContext"; import { PopUpContext } from "../../contexts/popUpContext"; import NodeModal from "../../modals/NodeModal"; -import { useCallback } from "react"; -import { TabsContext } from "../../contexts/tabsContext"; -import { debounce } from "../../utils"; import Tooltip from "../../components/TooltipComponent"; import { NodeToolbar } from "reactflow"; import NodeToolbarComponent from "../../pages/FlowPage/components/nodeToolbarComponent"; import ShadTooltip from "../../components/ShadTooltipComponent"; -import { postValidateNode } from "../../controllers/API"; +import { useSSE } from "../../contexts/SSEContext"; + export default function GenericNode({ data, selected, @@ -31,46 +23,30 @@ export default function GenericNode({ const { setErrorData } = useContext(alertContext); const showError = useRef(true); const { types, deleteNode } = useContext(typesContext); - const { openPopUp } = useContext(PopUpContext); - const { closePopUp } = useContext(PopUpContext); + + const { closePopUp, openPopUp } = useContext(PopUpContext); const Icon = nodeIcons[data.type] || nodeIcons[types[data.type]]; const [validationStatus, setValidationStatus] = useState(null); // State for outline color - const [isValid, setIsValid] = useState(false); - const { reactFlowInstance } = useContext(typesContext); - const [params, setParams] = useState([]); + const { sseData } = useSSE(); + // useEffect(() => { + // if (reactFlowInstance) { + // setParams(Object.values(reactFlowInstance.toObject())); + // } + // }, [save]); + + // New useEffect to watch for changes in sseData and update validation status useEffect(() => { - if (reactFlowInstance) { - setParams(Object.values(reactFlowInstance.toObject())); + const relevantData = sseData[data.id]; + if (relevantData) { + // Extract validation information from relevantData and update the validationStatus state + setValidationStatus(relevantData); + } else { + setValidationStatus(null); } - }, []); - - const validateNode = useCallback( - debounce(async () => { - try { - const response = await postValidateNode( - data.id, - reactFlowInstance.toObject() - ); - - if (response.status === 200) { - let jsonResponseParsed = await JSON.parse(response.data); - setValidationStatus(jsonResponseParsed); - } - } catch (error) { - // console.error("Error validating node:", error); - setValidationStatus("error"); - } - }, 1000), // Adjust the debounce delay (500ms) as needed - [reactFlowInstance, data.id] - ); - useEffect(() => { - if (params.length > 0) { - validateNode(); - } - }, [params, validateNode]); + }, [sseData, data.id]); if (!Icon) { if (showError.current) { diff --git a/src/frontend/src/components/chatComponent/buildTrigger/index.tsx b/src/frontend/src/components/chatComponent/buildTrigger/index.tsx new file mode 100644 index 000000000..56d587236 --- /dev/null +++ b/src/frontend/src/components/chatComponent/buildTrigger/index.tsx @@ -0,0 +1,155 @@ +import { useState, useContext } from "react"; +import { Transition } from "@headlessui/react"; +import { Zap } from "lucide-react"; +import { validateNodes } from "../../../utils"; +import { FlowType } from "../../../types/flow"; +import Loading from "../../../components/ui/loading"; +import { useSSE } from "../../../contexts/SSEContext"; +import { typesContext } from "../../../contexts/typesContext"; +import { alertContext } from "../../../contexts/alertContext"; +import { postBuildInit } from "../../../controllers/API"; + +export default function BuildTrigger({ + open, + flow, + setIsBuilt, + isBuilt, +}: { + open: boolean; + flow: FlowType; + setIsBuilt: any; + isBuilt: boolean; +}) { + const [isBuilding, setIsBuilding] = useState(false); + + const { updateSSEData } = useSSE(); + const { reactFlowInstance } = useContext(typesContext); + const { setErrorData } = useContext(alertContext); + + async function handleBuild(flow: FlowType) { + const errors = validateNodes(reactFlowInstance); + if (errors.length > 0) { + setErrorData({ + title: "Oops! Looks like you missed something", + list: errors, + }); + return; + } + const minimumLoadingTime = 200; // in milliseconds + const startTime = Date.now(); + setIsBuilding(true); + + try { + const allNodesValid = await streamNodeData(flow); + await enforceMinimumLoadingTime(startTime, minimumLoadingTime); + setIsBuilt(allNodesValid); + } catch (error) { + console.error("Error:", error); + } finally { + setIsBuilding(false); + } + } + + async function streamNodeData(flow: FlowType) { + // Step 1: Make a POST request to send the flow data and receive a unique session ID + const response = await postBuildInit(flow); + const { flowId } = response.data; + + // Step 2: Use the session ID to establish an SSE connection using EventSource + let validationResults = []; + let finished = false; + const apiUrl = `/api/v1/build/stream/${flowId}`; + const eventSource = new EventSource(apiUrl); + try{ + eventSource.onmessage = (event) => { + // If the event is parseable, return + if (!event.data) { + return; + } + const parsedData = JSON.parse(event.data); + // if the event is the end of the stream, close the connection + if (parsedData.end_of_stream) { + eventSource.close(); + + return; + } + // Otherwise, process the data + const isValid = processStreamResult(parsedData); + validationResults.push(isValid); + }; + + eventSource.onerror = (error) => { + console.error("EventSource failed:", error); + eventSource.close(); + }; + // Step 3: Wait for the stream to finish + while (!finished) { + await new Promise((resolve) => setTimeout(resolve, 100)); + finished = validationResults.length === flow.data.nodes.length; + } + // Step 4: Return true if all nodes are valid, false otherwise + return validationResults.every((result) => result); + } + catch(e){ + console.log(e) + eventSource.close(); + return false; + } + } + + function processStreamResult(parsedData) { + // Process each chunk of data here + // Parse the chunk and update the context + try { + updateSSEData({ [parsedData.id]: parsedData }); + } catch (err) { + console.log("Error parsing stream data: ", err); + } + return parsedData.valid; + } + + async function enforceMinimumLoadingTime( + startTime: number, + minimumLoadingTime: number + ) { + const elapsedTime = Date.now() - startTime; + const remainingTime = minimumLoadingTime - elapsedTime; + + if (remainingTime > 0) { + return new Promise((resolve) => setTimeout(resolve, remainingTime)); + } + } + + return ( + +
+
{ + if (!isBuilding) handleBuild(flow); + }} + > + +
+
+
+ ); +} diff --git a/src/frontend/src/components/chatComponent/chatTrigger/index.tsx b/src/frontend/src/components/chatComponent/chatTrigger/index.tsx index daa49f34f..1651a91bc 100644 --- a/src/frontend/src/components/chatComponent/chatTrigger/index.tsx +++ b/src/frontend/src/components/chatComponent/chatTrigger/index.tsx @@ -3,13 +3,26 @@ import { Bars3CenterLeftIcon, ChatBubbleBottomCenterTextIcon, } from "@heroicons/react/24/outline"; +import { MessagesSquare } from "lucide-react"; import { nodeColors } from "../../../utils"; -import { PopUpContext } from "../../../contexts/popUpContext"; +import { alertContext } from "../../../contexts/alertContext"; import { useContext } from "react"; import ChatModal from "../../../modals/chatModal"; -export default function ChatTrigger({ open, setOpen }) { - const { openPopUp } = useContext(PopUpContext); +export default function ChatTrigger({ open, setOpen, isBuilt }) { + const { setErrorData } = useContext(alertContext); + + function handleClick() { + if (isBuilt) { + setOpen(true); + } else { + setErrorData({ + title: "Flow not built", + list: ["Please build the flow before chatting"], + }); + } + } + return (
{ - setOpen(true); - }} + onClick={handleClick} >