From 2bfe93e0b8e44ff82785889e7589928d2cb8799b Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 6 Jun 2023 10:01:44 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=80=20feat(langflow):=20add=20new=20AP?= =?UTF-8?q?I=20endpoints=20for=20chat,=20validation,=20and=20version=20Thi?= =?UTF-8?q?s=20commit=20adds=20new=20API=20endpoints=20for=20chat,=20valid?= =?UTF-8?q?ation,=20and=20version.=20The=20chat=20endpoint=20is=20a=20webs?= =?UTF-8?q?ocket=20endpoint=20for=20chat.=20The=20validation=20endpoint=20?= =?UTF-8?q?has=20three=20sub-endpoints=20for=20validating=20code,=20prompt?= =?UTF-8?q?,=20and=20node.=20The=20version=20endpoint=20returns=20the=20ve?= =?UTF-8?q?rsion=20of=20LangFlow.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/langflow/api/v1/chat.py | 26 +++++++++ src/backend/langflow/api/v1/endpoints.py | 44 +++++++++++++++ src/backend/langflow/api/v1/schemas.py | 70 ++++++++++++++++++++++++ src/backend/langflow/api/v1/validate.py | 57 +++++++++++++++++++ 4 files changed, 197 insertions(+) create mode 100644 src/backend/langflow/api/v1/chat.py create mode 100644 src/backend/langflow/api/v1/endpoints.py create mode 100644 src/backend/langflow/api/v1/schemas.py create mode 100644 src/backend/langflow/api/v1/validate.py diff --git a/src/backend/langflow/api/v1/chat.py b/src/backend/langflow/api/v1/chat.py new file mode 100644 index 000000000..7df4c65ed --- /dev/null +++ b/src/backend/langflow/api/v1/chat.py @@ -0,0 +1,26 @@ +from fastapi import ( + APIRouter, + WebSocket, + WebSocketDisconnect, + WebSocketException, + status, +) + +from langflow.chat.manager import ChatManager +from langflow.utils.logger import logger + +router = APIRouter() +chat_manager = ChatManager() + + +@router.websocket("/chat/{client_id}") +async def websocket_endpoint(client_id: str, websocket: WebSocket): + """Websocket endpoint for chat.""" + try: + await chat_manager.handle_websocket(client_id, websocket) + 
except WebSocketException as exc: + logger.error(exc) + await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=str(exc)) + except WebSocketDisconnect as exc: + logger.error(exc) + await websocket.close(code=status.WS_1000_NORMAL_CLOSURE, reason=str(exc)) diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py new file mode 100644 index 000000000..1e9b0deb1 --- /dev/null +++ b/src/backend/langflow/api/v1/endpoints.py @@ -0,0 +1,44 @@ +import logging +from importlib.metadata import version + +from fastapi import APIRouter, HTTPException + +from langflow.api.v1.schemas import ( + ExportedFlow, + GraphData, + PredictRequest, + PredictResponse, +) + +from langflow.interface.types import build_langchain_types_dict + +# build router +router = APIRouter() +logger = logging.getLogger(__name__) + + +@router.get("/all") +def get_all(): + return build_langchain_types_dict() + + +@router.post("/predict", response_model=PredictResponse) +async def get_load(predict_request: PredictRequest): + try: + from langflow.processing.process import process_graph_cached + + exported_flow: ExportedFlow = predict_request.exported_flow + graph_data: GraphData = exported_flow.data + data = graph_data.dict() + response = process_graph_cached(data, predict_request.message) + return PredictResponse(result=response.get("result", "")) + except Exception as e: + # Log stack trace + logger.exception(e) + raise HTTPException(status_code=500, detail=str(e)) from e + + +# get endpoint to return version of langflow +@router.get("/version") +def get_version(): + return {"version": version("langflow")} diff --git a/src/backend/langflow/api/v1/schemas.py b/src/backend/langflow/api/v1/schemas.py new file mode 100644 index 000000000..f73b0642d --- /dev/null +++ b/src/backend/langflow/api/v1/schemas.py @@ -0,0 +1,70 @@ +from typing import Any, Dict, List, Union + +from pydantic import BaseModel, validator + + +class GraphData(BaseModel): + """Data inside the 
exported flow.""" + + nodes: List[Dict[str, Any]] + edges: List[Dict[str, Any]] + + +class ExportedFlow(BaseModel): + """Exported flow from LangFlow.""" + + description: str + name: str + id: str + data: GraphData + + +class PredictRequest(BaseModel): + """Predict request schema.""" + + message: str + exported_flow: ExportedFlow + + +class PredictResponse(BaseModel): + """Predict response schema.""" + + result: str + + +class ChatMessage(BaseModel): + """Chat message schema.""" + + is_bot: bool = False + message: Union[str, None] = None + type: str = "human" + + +class ChatResponse(ChatMessage): + """Chat response schema.""" + + intermediate_steps: str + type: str + is_bot: bool = True + files: list = [] + + @validator("type") + def validate_message_type(cls, v): + if v not in ["start", "stream", "end", "error", "info", "file"]: + raise ValueError("type must be start, stream, end, error, info, or file") + return v + + +class FileResponse(ChatMessage): + """File response schema.""" + + data: Any + data_type: str + type: str = "file" + is_bot: bool = True + + @validator("data_type") + def validate_data_type(cls, v): + if v not in ["image", "csv"]: + raise ValueError("data_type must be image or csv") + return v diff --git a/src/backend/langflow/api/v1/validate.py b/src/backend/langflow/api/v1/validate.py new file mode 100644 index 000000000..009cb9a30 --- /dev/null +++ b/src/backend/langflow/api/v1/validate.py @@ -0,0 +1,57 @@ +import json + +from fastapi import APIRouter, HTTPException + +from langflow.api.v1.base import ( + Code, + CodeValidationResponse, + Prompt, + PromptValidationResponse, + validate_prompt, +) +from langflow.graph.vertex.types import VectorStoreVertex +from langflow.graph import Graph +from langflow.utils.logger import logger +from langflow.utils.validate import validate_code + +# build router +router = APIRouter(prefix="/validate", tags=["validate"]) + + +@router.post("/code", status_code=200, response_model=CodeValidationResponse) +def 
post_validate_code(code: Code): + try: + errors = validate_code(code.code) + return CodeValidationResponse( + imports=errors.get("imports", {}), + function=errors.get("function", {}), + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.post("/prompt", status_code=200, response_model=PromptValidationResponse) +def post_validate_prompt(prompt: Prompt): + try: + return validate_prompt(prompt.template) + except Exception as e: + logger.exception(e) + raise HTTPException(status_code=500, detail=str(e)) from e + + +# validate node +@router.post("/node/{node_id}", status_code=200) +def post_validate_node(node_id: str, data: dict): + try: + # build graph + graph = Graph.from_payload(data) + # validate node + node = graph.get_node(node_id) + if node is None: + raise ValueError(f"Node {node_id} not found") + if not isinstance(node, VectorStoreVertex): + node.build() + return json.dumps({"valid": True, "params": str(node._built_object_repr())}) + except Exception as e: + logger.exception(e) + return json.dumps({"valid": False, "params": str(e)})