feat: adding validate/prompt endpoint
This commit is contained in:
parent
9d626c7cbc
commit
c41957c67e
6 changed files with 130 additions and 25 deletions
|
|
@ -1,16 +1,17 @@
|
|||
from langflow.graph.utils import extract_input_variables_from_prompt
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
|
||||
class Code(BaseModel):
|
||||
code: str
|
||||
|
||||
@validator("code")
|
||||
def validate_code(cls, v):
|
||||
return v
|
||||
|
||||
class Prompt(BaseModel):
|
||||
template: str
|
||||
|
||||
|
||||
# Build ValidationResponse class for {"imports": {"errors": []}, "function": {"errors": []}}
|
||||
class ValidationResponse(BaseModel):
|
||||
class CodeValidationResponse(BaseModel):
|
||||
imports: dict
|
||||
function: dict
|
||||
|
||||
|
|
@ -21,3 +22,14 @@ class ValidationResponse(BaseModel):
|
|||
@validator("function")
|
||||
def validate_function(cls, v):
|
||||
return v or {"errors": []}
|
||||
|
||||
|
||||
class PromptValidationResponse(BaseModel):
|
||||
input_variables: list
|
||||
valid: bool
|
||||
|
||||
|
||||
def validate_prompt(template):
|
||||
# Extract the input variables from template
|
||||
input_variables = extract_input_variables_from_prompt(template)
|
||||
return input_variables, len(input_variables) > 0
|
||||
|
|
|
|||
|
|
@ -3,10 +3,8 @@ from typing import Any, Dict
|
|||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from langflow.api.base import Code, ValidationResponse
|
||||
from langflow.interface.run import process_graph
|
||||
from langflow.interface.types import build_langchain_types_dict
|
||||
from langflow.utils.validate import validate_code
|
||||
|
||||
# build router
|
||||
router = APIRouter()
|
||||
|
|
@ -26,15 +24,3 @@ def get_load(data: Dict[str, Any]):
|
|||
# Log stack trace
|
||||
logger.exception(e)
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.post("/validate", status_code=200, response_model=ValidationResponse)
|
||||
def post_validate_code(code: Code):
|
||||
try:
|
||||
errors = validate_code(code.code)
|
||||
return ValidationResponse(
|
||||
imports=errors.get("imports", {}),
|
||||
function=errors.get("function", {}),
|
||||
)
|
||||
except Exception as e:
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
|
|
|
|||
38
src/backend/langflow/api/validate.py
Normal file
38
src/backend/langflow/api/validate.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
from fastapi import HTTPException
|
||||
from langflow.api.base import (
|
||||
Code,
|
||||
CodeValidationResponse,
|
||||
Prompt,
|
||||
PromptValidationResponse,
|
||||
validate_prompt,
|
||||
)
|
||||
|
||||
from langflow.utils.validate import validate_code
|
||||
from langflow.utils.logger import logger
|
||||
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
# build router
|
||||
router = APIRouter(prefix="/validate", tags=["validate"])
|
||||
|
||||
|
||||
@router.post("/code", status_code=200, response_model=CodeValidationResponse)
|
||||
def post_validate_code(code: Code):
|
||||
try:
|
||||
errors = validate_code(code.code)
|
||||
return CodeValidationResponse(
|
||||
imports=errors.get("imports", {}),
|
||||
function=errors.get("function", {}),
|
||||
)
|
||||
except Exception as e:
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/prompt", status_code=200, response_model=PromptValidationResponse)
|
||||
def post_validate_prompt(prompt: Prompt):
|
||||
try:
|
||||
input_variables, valid = validate_prompt(prompt.template)
|
||||
return PromptValidationResponse(input_variables=input_variables, valid=valid)
|
||||
except Exception as e:
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
|
|
@ -2,6 +2,7 @@ from fastapi import FastAPI
|
|||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from langflow.api.endpoints import router as endpoints_router
|
||||
from langflow.api.validate import router as validate_router
|
||||
|
||||
|
||||
def create_app():
|
||||
|
|
@ -21,6 +22,7 @@ def create_app():
|
|||
)
|
||||
|
||||
app.include_router(endpoints_router)
|
||||
app.include_router(validate_router)
|
||||
return app
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -12,5 +12,5 @@ export async function sendAll(data:sendAllProps) {
|
|||
|
||||
export async function checkCode(code:string):Promise<AxiosResponse<errorsTypeAPI>>{
|
||||
|
||||
return await axios.post('/validate',{code})
|
||||
return await axios.post('/validate/code',{code})
|
||||
}
|
||||
|
|
@ -1,5 +1,25 @@
|
|||
import json
|
||||
from typing import Dict
|
||||
from fastapi.testclient import TestClient
|
||||
from langflow.interface.tools.constants import CUSTOM_TOOLS
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def test_post_predict(client: TestClient):
|
||||
with open(Path(__file__).parent / "data" / "Build_error.json") as f:
|
||||
data = f.read()
|
||||
json_data = json.loads(data)
|
||||
data: Dict = json_data["data"]
|
||||
data["message"] = "I'm Bob"
|
||||
response = client.post("/predict", json=data)
|
||||
assert response.status_code == 200
|
||||
data["message"] = "What is my name?"
|
||||
data["chatHistory"] = ["I'm Bob"]
|
||||
response = client.post("/predict", json=data)
|
||||
assert response.status_code == 200
|
||||
assert "Bob" in response.json()["result"]
|
||||
|
||||
|
||||
def test_get_all(client: TestClient):
|
||||
|
|
@ -20,7 +40,7 @@ import math
|
|||
def square(x):
|
||||
return x ** 2
|
||||
"""
|
||||
response1 = client.post("/validate", json={"code": code1})
|
||||
response1 = client.post("/validate/code", json={"code": code1})
|
||||
assert response1.status_code == 200
|
||||
assert response1.json() == {"imports": {"errors": []}, "function": {"errors": []}}
|
||||
|
||||
|
|
@ -31,7 +51,7 @@ import non_existent_module
|
|||
def square(x):
|
||||
return x ** 2
|
||||
"""
|
||||
response2 = client.post("/validate", json={"code": code2})
|
||||
response2 = client.post("/validate/code", json={"code": code2})
|
||||
assert response2.status_code == 200
|
||||
assert response2.json() == {
|
||||
"imports": {"errors": ["No module named 'non_existent_module'"]},
|
||||
|
|
@ -45,7 +65,7 @@ import math
|
|||
def square(x)
|
||||
return x ** 2
|
||||
"""
|
||||
response3 = client.post("/validate", json={"code": code3})
|
||||
response3 = client.post("/validate/code", json={"code": code3})
|
||||
assert response3.status_code == 200
|
||||
assert response3.json() == {
|
||||
"imports": {"errors": []},
|
||||
|
|
@ -53,11 +73,11 @@ def square(x)
|
|||
}
|
||||
|
||||
# Test case with invalid JSON payload
|
||||
response4 = client.post("/validate", json={"invalid_key": code1})
|
||||
response4 = client.post("/validate/code", json={"invalid_key": code1})
|
||||
assert response4.status_code == 422
|
||||
|
||||
# Test case with an empty code string
|
||||
response5 = client.post("/validate", json={"code": ""})
|
||||
response5 = client.post("/validate/code", json={"code": ""})
|
||||
assert response5.status_code == 200
|
||||
assert response5.json() == {"imports": {"errors": []}, "function": {"errors": []}}
|
||||
|
||||
|
|
@ -68,9 +88,56 @@ import math
|
|||
def square(x)
|
||||
return x ** 2
|
||||
"""
|
||||
response6 = client.post("/validate", json={"code": code6})
|
||||
response6 = client.post("/validate/code", json={"code": code6})
|
||||
assert response6.status_code == 200
|
||||
assert response6.json() == {
|
||||
"imports": {"errors": []},
|
||||
"function": {"errors": ["expected ':' (<unknown>, line 4)"]},
|
||||
}
|
||||
|
||||
|
||||
VALID_PROMPT = """
|
||||
I want you to act as a naming consultant for new companies.
|
||||
|
||||
Here are some examples of good company names:
|
||||
|
||||
- search engine, Google
|
||||
- social media, Facebook
|
||||
- video sharing, YouTube
|
||||
|
||||
The name should be short, catchy and easy to remember.
|
||||
|
||||
What is a good name for a company that makes {product}?
|
||||
"""
|
||||
|
||||
INVALID_PROMPT = "This is an invalid prompt without any input variable."
|
||||
|
||||
|
||||
def test_valid_prompt(client: TestClient):
|
||||
response = client.post("/validate/prompt", json={"template": VALID_PROMPT})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"input_variables": ["product"], "valid": True}
|
||||
|
||||
|
||||
def test_invalid_prompt(client: TestClient):
|
||||
response = client.post("/validate/prompt", json={"template": INVALID_PROMPT})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"input_variables": [], "valid": False}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prompt,expected_input_variables,expected_validity",
|
||||
[
|
||||
("{color} is my favorite color.", ["color"], True),
|
||||
("The weather is {weather} today.", ["weather"], True),
|
||||
("This prompt has no variables.", [], False),
|
||||
("{a}, {b}, and {c} are variables.", ["a", "b", "c"], True),
|
||||
],
|
||||
)
|
||||
def test_various_prompts(client, prompt, expected_input_variables, expected_validity):
|
||||
response = client.post("/validate/prompt", json={"template": prompt})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"input_variables": expected_input_variables,
|
||||
"valid": expected_validity,
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue