diff --git a/src/backend/langflow/api/v1/chat.py b/src/backend/langflow/api/v1/chat.py index be9d6802c..fd0232189 100644 --- a/src/backend/langflow/api/v1/chat.py +++ b/src/backend/langflow/api/v1/chat.py @@ -1,12 +1,6 @@ -from fastapi import ( - APIRouter, - HTTPException, - WebSocket, - WebSocketException, - status, -) +from fastapi import APIRouter, HTTPException, WebSocket, WebSocketException, status from fastapi.responses import StreamingResponse -from langflow.api.v1.schemas import BuiltResponse, InitResponse, StreamData +from langflow.api.v1.schemas import BuildStatus, BuiltResponse, InitResponse, StreamData from langflow.chat.manager import ChatManager from langflow.graph.graph.base import Graph @@ -32,15 +26,29 @@ async def chat(client_id: str, websocket: WebSocket): await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=str(exc)) -@router.post("/build/init", response_model=InitResponse, status_code=201) -async def init_build(graph_data: dict): +@router.post("/build/init/{flow_id}", response_model=InitResponse, status_code=201) +async def init_build(graph_data: dict, flow_id: str): """Initialize the build by storing graph data and returning a unique session ID.""" try: - flow_id = graph_data.get("id") if flow_id is None: raise ValueError("No ID provided") - flow_data_store[flow_id] = graph_data + # Check if already building + if ( + flow_id in flow_data_store + and flow_data_store[flow_id]["status"] == BuildStatus.IN_PROGRESS + ): + return InitResponse(flowId=flow_id) + + # Delete from cache if already exists + if flow_id in chat_manager.in_memory_cache: + with chat_manager.in_memory_cache._lock: + chat_manager.in_memory_cache.delete(flow_id) + logger.debug(f"Deleted flow {flow_id} from cache") + flow_data_store[flow_id] = { + "graph_data": graph_data, + "status": BuildStatus.IN_PROGRESS, + } return InitResponse(flowId=flow_id) except Exception as exc: @@ -52,8 +60,9 @@ async def init_build(graph_data: dict): async def build_status(flow_id: str): """Check the flow_id is in the flow_data_store.""" try: - built = flow_id in flow_data_store and not isinstance( - flow_data_store[flow_id], dict + built = ( + flow_id in flow_data_store + and flow_data_store[flow_id]["status"] == BuildStatus.SUCCESS ) return BuiltResponse( @@ -77,6 +86,11 @@ async def stream_build(flow_id: str): yield str(StreamData(event="error", data={"error": error_message})) return + if flow_data_store[flow_id].get("status") == BuildStatus.IN_PROGRESS: + error_message = "Already building" + yield str(StreamData(event="error", data={"error": error_message})) + return + graph_data = flow_data_store[flow_id].get("data") if not graph_data: @@ -110,6 +124,7 @@ async def stream_build(flow_id: str): except Exception as exc: params = str(exc) valid = False + flow_data_store[flow_id]["status"] = BuildStatus.FAILURE response = { "valid": valid, @@ -121,8 +136,10 @@ async def stream_build(flow_id: str): yield str(StreamData(event="message", data=response)) chat_manager.set_cache(flow_id, graph.build()) + flow_data_store[flow_id]["status"] = BuildStatus.SUCCESS except Exception as exc: logger.error("Error while building the flow: %s", exc) + flow_data_store[flow_id]["status"] = BuildStatus.FAILURE yield str(StreamData(event="error", data={"error": str(exc)})) finally: yield str(StreamData(event="message", data=final_response)) diff --git a/src/backend/langflow/api/v1/schemas.py b/src/backend/langflow/api/v1/schemas.py index ed5bf8b3b..2cf62a504 100644 --- a/src/backend/langflow/api/v1/schemas.py +++ b/src/backend/langflow/api/v1/schemas.py @@ -1,3 +1,4 @@ +from enum import Enum from pathlib import Path from typing import Any, Dict, List, Optional, Union from langflow.database.models.flow import FlowCreate, FlowRead @@ -5,6 +6,14 @@ from pydantic import BaseModel, Field, validator import json +class BuildStatus(Enum): + """Status of the build.""" + + SUCCESS = "success" + FAILURE = "failure" + IN_PROGRESS = "in_progress" + + class GraphData(BaseModel): """Data inside the exported flow.""" diff --git a/src/backend/langflow/interface/initialize/loading.py b/src/backend/langflow/interface/initialize/loading.py index 3ddb64c24..29730cc9f 100644 --- a/src/backend/langflow/interface/initialize/loading.py +++ b/src/backend/langflow/interface/initialize/loading.py @@ -21,7 +21,6 @@ from langchain.chains.base import Chain from langchain.vectorstores.base import VectorStore from langchain.document_loaders.base import BaseLoader from langchain.prompts.base import BasePromptTemplate -from langflow.chat.config import ChatConfig def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any: @@ -49,8 +48,8 @@ def convert_params_to_sets(params): def convert_kwargs(params): # if *kwargs are passed as a string, convert to dict - # first find any key that has kwargs in it - kwargs_keys = [key for key in params.keys() if "kwargs" in key] + # first find any key that has kwargs or config in it + kwargs_keys = [key for key in params.keys() if "kwargs" in key or "config" in key] for key in kwargs_keys: if isinstance(params[key], str): params[key] = json.loads(params[key]) @@ -85,11 +84,6 @@ def instantiate_based_on_type(class_object, base_type, node_type, params): def instantiate_llm(node_type, class_object, params: Dict): - # This is a workaround so JinaChat works until streaming is implemented - # if "openai_api_base" in params and "jina" in params["openai_api_base"]: - # False if condition is True - ChatConfig.streaming = "jina" not in params.get("openai_api_base", "") - return class_object(**params) diff --git a/src/backend/langflow/template/frontend_node/constants.py b/src/backend/langflow/template/frontend_node/constants.py index d30239ff7..90cdbf280 100644 --- a/src/backend/langflow/template/frontend_node/constants.py +++ b/src/backend/langflow/template/frontend_node/constants.py @@ -33,6 +33,22 @@ HUMAN_PROMPT = "{input}" QA_CHAIN_TYPES = ["stuff", "map_reduce", "map_rerank", "refine"] +CTRANSFORMERS_DEFAULT_CONFIG = { + "top_k": 40, + "top_p": 0.95, + "temperature": 0.8, + "repetition_penalty": 1.1, + "last_n_tokens": 64, + "seed": -1, + "max_new_tokens": 256, + "stop": None, + "stream": False, + "reset": True, + "batch_size": 8, + "threads": -1, + "context_length": -1, + "gpu_layers": 0, +} # This variable is used to tell the user # that it can be changed to use other APIs diff --git a/src/backend/langflow/template/frontend_node/llms.py b/src/backend/langflow/template/frontend_node/llms.py index 65bffc303..703dbacc6 100644 --- a/src/backend/langflow/template/frontend_node/llms.py +++ b/src/backend/langflow/template/frontend_node/llms.py @@ -1,7 +1,9 @@ +import json from typing import Optional from langflow.template.field.base import TemplateField from langflow.template.frontend_node.base import FrontendNode +from langflow.template.frontend_node.constants import CTRANSFORMERS_DEFAULT_CONFIG from langflow.template.frontend_node.constants import OPENAI_API_BASE_INFO @@ -35,6 +37,13 @@ class LLMFrontendNode(FrontendNode): field.show = True field.advanced = not field.required + @staticmethod + def format_ctransformers_field(field: TemplateField): + if field.name == "config": + field.show = True + field.advanced = True + field.value = json.dumps(CTRANSFORMERS_DEFAULT_CONFIG, indent=2) + @staticmethod def format_field(field: TemplateField, name: Optional[str] = None) -> None: display_names_dict = { @@ -42,6 +51,7 @@ class LLMFrontendNode(FrontendNode): } FrontendNode.format_field(field, name) LLMFrontendNode.format_openai_field(field) + LLMFrontendNode.format_ctransformers_field(field) if name and "azure" in name.lower(): LLMFrontendNode.format_azure_field(field) if name and "llama" in name.lower(): diff --git a/src/frontend/src/controllers/API/index.ts b/src/frontend/src/controllers/API/index.ts index 2651a0058..cfae748d8 100644 --- a/src/frontend/src/controllers/API/index.ts +++ b/src/frontend/src/controllers/API/index.ts @@ -311,7 +311,7 @@ export async function getBuildStatus( export async function postBuildInit( flow: FlowType ): Promise> { - return await axios.post(`/api/v1/build/init`, flow); + return await axios.post(`/api/v1/build/init/${flow.id}`, flow); } // fetch(`/upload/${id}`, { diff --git a/tests/test_websocket.py b/tests/test_websocket.py index f571671e8..0199ff14b 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -7,7 +7,7 @@ import pytest def test_init_build(client): response = client.post( - "api/v1/build/init", json={"id": "test", "data": {"key": "value"}} + "api/v1/build/init/test", json={"id": "test", "data": {"key": "value"}} ) assert response.status_code == 201 assert response.json() == {"flowId": "test"}