🚀 feat(langflow): add new API endpoints for chat, validation, and version

This commit adds new API endpoints for chat, validation, and version. The chat endpoint is a websocket endpoint for chat. The validation endpoint has three sub-endpoints for validating code, prompt, and node. The version endpoint returns the version of LangFlow.
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-06-06 10:01:44 -03:00
commit 2bfe93e0b8
4 changed files with 197 additions and 0 deletions

View file

@ -0,0 +1,26 @@
from fastapi import (
APIRouter,
WebSocket,
WebSocketDisconnect,
WebSocketException,
status,
)
from langflow.chat.manager import ChatManager
from langflow.utils.logger import logger
router = APIRouter()
chat_manager = ChatManager()
@router.websocket("/chat/{client_id}")
async def websocket_endpoint(client_id: str, websocket: WebSocket):
"""Websocket endpoint for chat."""
try:
await chat_manager.handle_websocket(client_id, websocket)
except WebSocketException as exc:
logger.error(exc)
await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=str(exc))
except WebSocketDisconnect as exc:
logger.error(exc)
await websocket.close(code=status.WS_1000_NORMAL_CLOSURE, reason=str(exc))

View file

@ -0,0 +1,44 @@
import logging
from importlib.metadata import version
from fastapi import APIRouter, HTTPException
from langflow.api.v1.schemas import (
ExportedFlow,
GraphData,
PredictRequest,
PredictResponse,
)
from langflow.interface.types import build_langchain_types_dict
# build router
router = APIRouter()
logger = logging.getLogger(__name__)
@router.get("/all")
def get_all():
return build_langchain_types_dict()
@router.post("/predict", response_model=PredictResponse)
async def get_load(predict_request: PredictRequest):
try:
from langflow.processing.process import process_graph_cached
exported_flow: ExportedFlow = predict_request.exported_flow
graph_data: GraphData = exported_flow.data
data = graph_data.dict()
response = process_graph_cached(data, predict_request.message)
return PredictResponse(result=response.get("result", ""))
except Exception as e:
# Log stack trace
logger.exception(e)
raise HTTPException(status_code=500, detail=str(e)) from e
# get endpoint to return version of langflow
@router.get("/version")
def get_version():
return {"version": version("langflow")}

View file

@ -0,0 +1,70 @@
from typing import Any, Dict, List, Union
from pydantic import BaseModel, validator
class GraphData(BaseModel):
"""Data inside the exported flow."""
nodes: List[Dict[str, Any]]
edges: List[Dict[str, Any]]
class ExportedFlow(BaseModel):
"""Exported flow from LangFlow."""
description: str
name: str
id: str
data: GraphData
class PredictRequest(BaseModel):
"""Predict request schema."""
message: str
exported_flow: ExportedFlow
class PredictResponse(BaseModel):
"""Predict response schema."""
result: str
class ChatMessage(BaseModel):
"""Chat message schema."""
is_bot: bool = False
message: Union[str, None] = None
type: str = "human"
class ChatResponse(ChatMessage):
"""Chat response schema."""
intermediate_steps: str
type: str
is_bot: bool = True
files: list = []
@validator("type")
def validate_message_type(cls, v):
if v not in ["start", "stream", "end", "error", "info", "file"]:
raise ValueError("type must be start, stream, end, error, info, or file")
return v
class FileResponse(ChatMessage):
"""File response schema."""
data: Any
data_type: str
type: str = "file"
is_bot: bool = True
@validator("data_type")
def validate_data_type(cls, v):
if v not in ["image", "csv"]:
raise ValueError("data_type must be image or csv")
return v

View file

@ -0,0 +1,57 @@
import json
from fastapi import APIRouter, HTTPException
from langflow.api.v1.base import (
Code,
CodeValidationResponse,
Prompt,
PromptValidationResponse,
validate_prompt,
)
from langflow.graph.vertex.types import VectorStoreVertex
from langflow.graph import Graph
from langflow.utils.logger import logger
from langflow.utils.validate import validate_code
# build router
router = APIRouter(prefix="/validate", tags=["validate"])
@router.post("/code", status_code=200, response_model=CodeValidationResponse)
def post_validate_code(code: Code):
try:
errors = validate_code(code.code)
return CodeValidationResponse(
imports=errors.get("imports", {}),
function=errors.get("function", {}),
)
except Exception as e:
return HTTPException(status_code=500, detail=str(e))
@router.post("/prompt", status_code=200, response_model=PromptValidationResponse)
def post_validate_prompt(prompt: Prompt):
try:
return validate_prompt(prompt.template)
except Exception as e:
logger.exception(e)
raise HTTPException(status_code=500, detail=str(e)) from e
# validate node
@router.post("/node/{node_id}", status_code=200)
def post_validate_node(node_id: str, data: dict):
try:
# build graph
graph = Graph.from_payload(data)
# validate node
node = graph.get_node(node_id)
if node is None:
raise ValueError(f"Node {node_id} not found")
if not isinstance(node, VectorStoreVertex):
node.build()
return json.dumps({"valid": True, "params": str(node._built_object_repr())})
except Exception as e:
logger.exception(e)
return json.dumps({"valid": False, "params": str(e)})