fix: update to validate prompt endpoint

This commit is contained in:
Gabriel Almeida 2023-04-04 22:07:40 -03:00
commit 41f05b2e85
3 changed files with 13 additions and 39 deletions

View file

@@ -27,10 +27,3 @@ class CodeValidationResponse(BaseModel):
class PromptValidationResponse(BaseModel):
    """Response schema returned by the /validate/prompt endpoint."""

    # Names of the {placeholder} variables extracted from the template.
    input_variables: list
    # NOTE(review): this rendered diff suggests `valid` is removed by this
    # commit — the post-change tests expect only "input_variables". Confirm
    # against the final file before relying on this field.
    valid: bool
def validate_prompt(template):
    """Extract a template's input variables and report whether any exist.

    Returns a (variables, is_valid) pair where is_valid is True when the
    template contains at least one {placeholder}.
    """
    variables = extract_input_variables_from_prompt(template)
    has_variables = len(variables) > 0
    return variables, has_variables

View file

@@ -5,8 +5,8 @@ from langflow.api.base import (
CodeValidationResponse,
Prompt,
PromptValidationResponse,
validate_prompt,
)
from langflow.graph.utils import extract_input_variables_from_prompt
from langflow.utils.logger import logger
from langflow.utils.validate import validate_code
@@ -29,7 +29,8 @@ def post_validate_code(code: Code):
@router.post("/prompt", status_code=200, response_model=PromptValidationResponse)
def post_validate_prompt(prompt: Prompt):
    """Validate a prompt template.

    Extracts the {placeholder} input variables from ``prompt.template`` and
    returns them in a PromptValidationResponse; any failure is reported as
    an HTTP 500 with the exception message as the detail.
    """
    try:
        input_variables = extract_input_variables_from_prompt(prompt.template)
        return PromptValidationResponse(input_variables=input_variables)
    except Exception as e:
        logger.exception(e)
        # Raise (not return) so FastAPI turns this into a real 500 response;
        # returning the HTTPException object would be serialized in a 200 body.
        raise HTTPException(status_code=500, detail=str(e)) from e

View file

@@ -1,25 +1,6 @@
import json
from typing import Dict
import pytest
from fastapi.testclient import TestClient
from langflow.interface.tools.constants import CUSTOM_TOOLS
from pathlib import Path
import pytest
def test_post_predict(client: TestClient):
    """End-to-end /predict check: the bot should recall the user's name
    when the earlier turn is supplied as chat history."""
    payload_path = Path(__file__).parent / "data" / "Build_error.json"
    payload: Dict = json.loads(payload_path.read_text())["data"]

    # First turn: introduce the name.
    payload["message"] = "I'm Bob"
    first = client.post("/predict", json=payload)
    assert first.status_code == 200

    # Second turn: ask it back, with the prior message as history.
    payload["message"] = "What is my name?"
    payload["chatHistory"] = ["I'm Bob"]
    second = client.post("/predict", json=payload)
    assert second.status_code == 200
    assert "Bob" in second.json()["result"]
def test_get_all(client: TestClient):
@@ -116,28 +97,27 @@ INVALID_PROMPT = "This is an invalid prompt without any input variable."
def test_valid_prompt(client: TestClient):
    """A template with a {product} placeholder should report that variable."""
    response = client.post("/validate/prompt", json={"template": VALID_PROMPT})
    assert response.status_code == 200
    # Post-change response model carries only the extracted variables
    # (the stale pre-change assertion expecting "valid" is dropped).
    assert response.json() == {"input_variables": ["product"]}
def test_invalid_prompt(client: TestClient):
    """A template with no placeholders should yield an empty variable list."""
    response = client.post("/validate/prompt", json={"template": INVALID_PROMPT})
    assert response.status_code == 200
    # Post-change response model carries only the extracted variables
    # (the stale pre-change assertion expecting "valid" is dropped).
    assert response.json() == {"input_variables": []}
@pytest.mark.parametrize(
    "prompt,expected_input_variables",
    [
        ("{color} is my favorite color.", ["color"]),
        ("The weather is {weather} today.", ["weather"]),
        ("This prompt has no variables.", []),
        ("{a}, {b}, and {c} are variables.", ["a", "b", "c"]),
    ],
)
def test_various_prompts(client, prompt, expected_input_variables):
    """Each template's extracted variables should match expectations.

    Reconstructed post-change version: the diff rendering interleaved the
    pre-change parametrization (with an expected_validity column) and the
    new one; only the two-column form matches the updated response model.
    """
    response = client.post("/validate/prompt", json={"template": prompt})
    assert response.status_code == 200
    assert response.json() == {"input_variables": expected_input_variables}