diff --git a/src/backend/langflow/__main__.py b/src/backend/langflow/__main__.py index af57ed381..543c59f32 100644 --- a/src/backend/langflow/__main__.py +++ b/src/backend/langflow/__main__.py @@ -18,10 +18,12 @@ def get_number_of_workers(workers=None): return workers -def update_settings(config: str, dev: bool = False): +def update_settings(config: str, dev: bool = False, database_url: str = None): """Update the settings from a config file.""" if config: settings.update_from_yaml(config, dev=dev) + if database_url: + settings.update_database_url(database_url) def serve_on_jcloud(): @@ -81,6 +83,10 @@ def serve( log_file: Path = typer.Option("logs/langflow.log", help="Path to the log file."), jcloud: bool = typer.Option(False, help="Deploy on Jina AI Cloud"), dev: bool = typer.Option(False, help="Run in development mode (may contain bugs)"), + database_url: str = typer.Option( + None, + help="Database URL to connect to. If not provided, a local SQLite database will be used.", + ), ): """ Run the Langflow server. @@ -90,7 +96,7 @@ def serve( return serve_on_jcloud() configure(log_level=log_level, log_file=log_file) - update_settings(config, dev=dev) + update_settings(config, dev=dev, database_url=database_url) app = create_app() # get the directory of the current file path = Path(__file__).parent diff --git a/src/backend/langflow/api/endpoints.py b/src/backend/langflow/api/endpoints.py index 19d3d5f46..610d86f18 100644 --- a/src/backend/langflow/api/endpoints.py +++ b/src/backend/langflow/api/endpoints.py @@ -3,7 +3,7 @@ from langflow.database.models.flow import Flow from langflow.utils.logger import logger from importlib.metadata import version -from fastapi import APIRouter, Depends, File, HTTPException, UploadFile +from fastapi import APIRouter, Depends, HTTPException from langflow.api.schemas import ( GraphData, @@ -12,10 +12,9 @@ from langflow.api.schemas import ( ) from langflow.interface.run import process_graph_cached from langflow.interface.types import build_langchain_types_dict -from langflow.cache import cache_manager from langflow.database.base import get_session from sqlmodel import Session -from sqlmodel import select + # build router router = APIRouter(tags=["Base"]) @@ -26,8 +25,11 @@ def get_all(): @router.post("/predict/{flow_id}", status_code=200, response_model=PredictResponse) -async def get_load(predict_request: PredictRequest, flow_id: str, session: Session= Depends(get_session)): - +async def get_load( + predict_request: PredictRequest, + flow_id: str, + session: Session = Depends(get_session), +): try: flow_obj = session.get(Flow, flow_id) if flow_obj is None: @@ -35,7 +37,10 @@ async def get_load(predict_request: PredictRequest, flow_id: str, session: Sessi graph_data: GraphData = json.loads(flow_obj.flow) data = graph_data.dict() response = process_graph_cached(data, predict_request.message) - return PredictResponse(result=response.get("result", ""), intermediate_steps=response.get("thought", "")) + return PredictResponse( + result=response.get("result", ""), + intermediate_steps=response.get("thought", ""), + ) except Exception as e: # Log stack trace logger.exception(e) @@ -51,11 +56,3 @@ def get_version(): @router.get("/health") def get_health(): return {"status": "OK"} - - -# Make an endpoint to upload a file using the client_id and -# cache the file in the backend -@router.post("/uploadfile/{client_id}") -async def create_upload_file(client_id: str, file: UploadFile = File(...)): - - # TODO: Implement this endpoint diff --git a/src/backend/langflow/database/base.py b/src/backend/langflow/database/base.py index 6aa36baab..69f8b993b 100644 --- a/src/backend/langflow/database/base.py +++ b/src/backend/langflow/database/base.py @@ -1,11 +1,12 @@ from langflow.settings import settings from sqlmodel import SQLModel, Session, create_engine -sqlite_file_name = "database.db" -sqlite_url = f"sqlite:///{sqlite_file_name}" -connect_args = {"check_same_thread": False} -engine = create_engine(sqlite_url, connect_args=connect_args) +if settings.database_url.startswith("sqlite"): + connect_args = {"check_same_thread": False} +else: + connect_args = {} +engine = create_engine(settings.database_url, connect_args=connect_args) def create_db_and_tables(): diff --git a/src/backend/langflow/database/models/flow.py b/src/backend/langflow/database/models/flow.py index d7c0ea698..dc42da928 100644 --- a/src/backend/langflow/database/models/flow.py +++ b/src/backend/langflow/database/models/flow.py @@ -1,5 +1,3 @@ -from typing import List -from pydantic import BaseModel from sqlmodel import Field, SQLModel from uuid import UUID, uuid4 diff --git a/src/backend/langflow/main.py b/src/backend/langflow/main.py index 75c19bd25..152345860 100644 --- a/src/backend/langflow/main.py +++ b/src/backend/langflow/main.py @@ -5,9 +5,8 @@ from langflow.api.chat import router as chat_router from langflow.api.endpoints import router as endpoints_router from langflow.api.validate import router as validate_router from langflow.api.database import router as database_router -from langflow.utils.logger import logger from langflow.database.base import create_db_and_tables -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter # build router diff --git a/src/backend/langflow/settings.py b/src/backend/langflow/settings.py index 5ef61ab7b..5b66b8d7c 100644 --- a/src/backend/langflow/settings.py +++ b/src/backend/langflow/settings.py @@ -1,5 +1,5 @@ import os -from typing import List +from typing import List, Optional import yaml from pydantic import BaseSettings, root_validator @@ -20,10 +20,12 @@ class Settings(BaseSettings): textsplitters: List[str] = [] utilities: List[str] = [] dev: bool = False + dabatabase_url: str = "sqlite:///./langflow.db" class Config: validate_assignment = True extra = "ignore" + env_prefix = "LANGFLOW_" @root_validator(allow_reuse=True) def validate_lists(cls, values): @@ -46,6 +48,10 @@ class Settings(BaseSettings): self.utilities = new_settings.utilities or [] self.dev = dev + def update_database_url(self, database_url: Optional[str] = None): + if database_url: + self.database_url = database_url + def save_settings_to_yaml(settings: Settings, file_path: str): with open(file_path, "w") as f: