From c41957c67e77049d38717d90c8232d4ad3d42d8c Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Tue, 4 Apr 2023 20:43:41 -0300 Subject: [PATCH] feat: adding validate/prompt endpoint --- src/backend/langflow/api/base.py | 20 ++++-- src/backend/langflow/api/endpoints.py | 14 ---- src/backend/langflow/api/validate.py | 38 +++++++++++ src/backend/langflow/main.py | 2 + src/frontend/src/controllers/API/index.ts | 2 +- tests/test_endpoints.py | 79 +++++++++++++++++++++-- 6 files changed, 130 insertions(+), 25 deletions(-) create mode 100644 src/backend/langflow/api/validate.py diff --git a/src/backend/langflow/api/base.py b/src/backend/langflow/api/base.py index 3c3e0d8eb..e3c749aab 100644 --- a/src/backend/langflow/api/base.py +++ b/src/backend/langflow/api/base.py @@ -1,16 +1,17 @@ +from langflow.graph.utils import extract_input_variables_from_prompt from pydantic import BaseModel, validator class Code(BaseModel): code: str - @validator("code") - def validate_code(cls, v): - return v + +class Prompt(BaseModel): + template: str # Build ValidationResponse class for {"imports": {"errors": []}, "function": {"errors": []}} -class ValidationResponse(BaseModel): +class CodeValidationResponse(BaseModel): imports: dict function: dict @@ -21,3 +22,14 @@ class ValidationResponse(BaseModel): @validator("function") def validate_function(cls, v): return v or {"errors": []} + + +class PromptValidationResponse(BaseModel): + input_variables: list + valid: bool + + +def validate_prompt(template): + # Extract the input variables from template + input_variables = extract_input_variables_from_prompt(template) + return input_variables, len(input_variables) > 0 diff --git a/src/backend/langflow/api/endpoints.py b/src/backend/langflow/api/endpoints.py index c82616d3e..22f548156 100644 --- a/src/backend/langflow/api/endpoints.py +++ b/src/backend/langflow/api/endpoints.py @@ -3,10 +3,8 @@ from typing import Any, Dict from fastapi import APIRouter, HTTPException -from langflow.api.base import Code, ValidationResponse from langflow.interface.run import process_graph from langflow.interface.types import build_langchain_types_dict -from langflow.utils.validate import validate_code # build router router = APIRouter() @@ -26,15 +24,3 @@ def get_load(data: Dict[str, Any]): # Log stack trace logger.exception(e) raise HTTPException(status_code=500, detail=str(e)) from e - - -@router.post("/validate", status_code=200, response_model=ValidationResponse) -def post_validate_code(code: Code): - try: - errors = validate_code(code.code) - return ValidationResponse( - imports=errors.get("imports", {}), - function=errors.get("function", {}), - ) - except Exception as e: - return HTTPException(status_code=500, detail=str(e)) diff --git a/src/backend/langflow/api/validate.py b/src/backend/langflow/api/validate.py new file mode 100644 index 000000000..7ea5d6eb7 --- /dev/null +++ b/src/backend/langflow/api/validate.py @@ -0,0 +1,38 @@ +from fastapi import HTTPException +from langflow.api.base import ( + Code, + CodeValidationResponse, + Prompt, + PromptValidationResponse, + validate_prompt, +) + +from langflow.utils.validate import validate_code +from langflow.utils.logger import logger + + +from fastapi import APIRouter, HTTPException + +# build router +router = APIRouter(prefix="/validate", tags=["validate"]) + + +@router.post("/code", status_code=200, response_model=CodeValidationResponse) +def post_validate_code(code: Code): + try: + errors = validate_code(code.code) + return CodeValidationResponse( + imports=errors.get("imports", {}), + function=errors.get("function", {}), + ) + except Exception as e: + return HTTPException(status_code=500, detail=str(e)) + + +@router.post("/prompt", status_code=200, response_model=PromptValidationResponse) +def post_validate_prompt(prompt: Prompt): + try: + input_variables, valid = validate_prompt(prompt.template) + return PromptValidationResponse(input_variables=input_variables, valid=valid) + except Exception as e: + return HTTPException(status_code=500, detail=str(e)) diff --git a/src/backend/langflow/main.py b/src/backend/langflow/main.py index 21d17690a..176e46236 100644 --- a/src/backend/langflow/main.py +++ b/src/backend/langflow/main.py @@ -2,6 +2,7 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from langflow.api.endpoints import router as endpoints_router +from langflow.api.validate import router as validate_router def create_app(): @@ -21,6 +22,7 @@ def create_app(): ) app.include_router(endpoints_router) + app.include_router(validate_router) return app diff --git a/src/frontend/src/controllers/API/index.ts b/src/frontend/src/controllers/API/index.ts index c6315d1b3..bad966ea9 100644 --- a/src/frontend/src/controllers/API/index.ts +++ b/src/frontend/src/controllers/API/index.ts @@ -12,5 +12,5 @@ export async function sendAll(data:sendAllProps) { export async function checkCode(code:string):Promise>{ - return await axios.post('/validate',{code}) + return await axios.post('/validate/code',{code}) } \ No newline at end of file diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 570bf554e..55f002dad 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -1,5 +1,25 @@ +import json +from typing import Dict from fastapi.testclient import TestClient from langflow.interface.tools.constants import CUSTOM_TOOLS +from pathlib import Path + +import pytest + + +def test_post_predict(client: TestClient): + with open(Path(__file__).parent / "data" / "Build_error.json") as f: + data = f.read() + json_data = json.loads(data) + data: Dict = json_data["data"] + data["message"] = "I'm Bob" + response = client.post("/predict", json=data) + assert response.status_code == 200 + data["message"] = "What is my name?" + data["chatHistory"] = ["I'm Bob"] + response = client.post("/predict", json=data) + assert response.status_code == 200 + assert "Bob" in response.json()["result"] def test_get_all(client: TestClient): @@ -20,7 +40,7 @@ import math def square(x): return x ** 2 """ - response1 = client.post("/validate", json={"code": code1}) + response1 = client.post("/validate/code", json={"code": code1}) assert response1.status_code == 200 assert response1.json() == {"imports": {"errors": []}, "function": {"errors": []}} @@ -31,7 +51,7 @@ import non_existent_module def square(x): return x ** 2 """ - response2 = client.post("/validate", json={"code": code2}) + response2 = client.post("/validate/code", json={"code": code2}) assert response2.status_code == 200 assert response2.json() == { "imports": {"errors": ["No module named 'non_existent_module'"]}, @@ -45,7 +65,7 @@ import math def square(x) return x ** 2 """ - response3 = client.post("/validate", json={"code": code3}) + response3 = client.post("/validate/code", json={"code": code3}) assert response3.status_code == 200 assert response3.json() == { "imports": {"errors": []}, @@ -53,11 +73,11 @@ def square(x) } # Test case with invalid JSON payload - response4 = client.post("/validate", json={"invalid_key": code1}) + response4 = client.post("/validate/code", json={"invalid_key": code1}) assert response4.status_code == 422 # Test case with an empty code string - response5 = client.post("/validate", json={"code": ""}) + response5 = client.post("/validate/code", json={"code": ""}) assert response5.status_code == 200 assert response5.json() == {"imports": {"errors": []}, "function": {"errors": []}} @@ -68,9 +88,56 @@ import math def square(x) return x ** 2 """ - response6 = client.post("/validate", json={"code": code6}) + response6 = client.post("/validate/code", json={"code": code6}) assert response6.status_code == 200 assert response6.json() == { "imports": {"errors": []}, "function": {"errors": ["expected ':' (, line 4)"]}, } + + +VALID_PROMPT = """ +I want you to act as a naming consultant for new companies. + +Here are some examples of good company names: + +- search engine, Google +- social media, Facebook +- video sharing, YouTube + +The name should be short, catchy and easy to remember. + +What is a good name for a company that makes {product}? +""" + +INVALID_PROMPT = "This is an invalid prompt without any input variable." + + +def test_valid_prompt(client: TestClient): + response = client.post("/validate/prompt", json={"template": VALID_PROMPT}) + assert response.status_code == 200 + assert response.json() == {"input_variables": ["product"], "valid": True} + + +def test_invalid_prompt(client: TestClient): + response = client.post("/validate/prompt", json={"template": INVALID_PROMPT}) + assert response.status_code == 200 + assert response.json() == {"input_variables": [], "valid": False} + + +@pytest.mark.parametrize( + "prompt,expected_input_variables,expected_validity", + [ + ("{color} is my favorite color.", ["color"], True), + ("The weather is {weather} today.", ["weather"], True), + ("This prompt has no variables.", [], False), + ("{a}, {b}, and {c} are variables.", ["a", "b", "c"], True), + ], +) +def test_various_prompts(client, prompt, expected_input_variables, expected_validity): + response = client.post("/validate/prompt", json={"template": prompt}) + assert response.status_code == 200 + assert response.json() == { + "input_variables": expected_input_variables, + "valid": expected_validity, + }