fix: update to validate prompt endpoint
This commit is contained in:
parent
9c62adb8a6
commit
41f05b2e85
3 changed files with 13 additions and 39 deletions
|
|
@ -27,10 +27,3 @@ class CodeValidationResponse(BaseModel):
|
|||
|
||||
class PromptValidationResponse(BaseModel):
|
||||
input_variables: list
|
||||
valid: bool
|
||||
|
||||
|
||||
def validate_prompt(template):
|
||||
# Extract the input variables from template
|
||||
input_variables = extract_input_variables_from_prompt(template)
|
||||
return input_variables, len(input_variables) > 0
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@ from langflow.api.base import (
|
|||
CodeValidationResponse,
|
||||
Prompt,
|
||||
PromptValidationResponse,
|
||||
validate_prompt,
|
||||
)
|
||||
from langflow.graph.utils import extract_input_variables_from_prompt
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.utils.validate import validate_code
|
||||
|
||||
|
|
@ -29,7 +29,8 @@ def post_validate_code(code: Code):
|
|||
@router.post("/prompt", status_code=200, response_model=PromptValidationResponse)
|
||||
def post_validate_prompt(prompt: Prompt):
|
||||
try:
|
||||
input_variables, valid = validate_prompt(prompt.template)
|
||||
return PromptValidationResponse(input_variables=input_variables, valid=valid)
|
||||
input_variables = extract_input_variables_from_prompt(prompt.template)
|
||||
return PromptValidationResponse(input_variables=input_variables)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
|
|
|
|||
|
|
@ -1,25 +1,6 @@
|
|||
import json
|
||||
from typing import Dict
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from langflow.interface.tools.constants import CUSTOM_TOOLS
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def test_post_predict(client: TestClient):
|
||||
with open(Path(__file__).parent / "data" / "Build_error.json") as f:
|
||||
data = f.read()
|
||||
json_data = json.loads(data)
|
||||
data: Dict = json_data["data"]
|
||||
data["message"] = "I'm Bob"
|
||||
response = client.post("/predict", json=data)
|
||||
assert response.status_code == 200
|
||||
data["message"] = "What is my name?"
|
||||
data["chatHistory"] = ["I'm Bob"]
|
||||
response = client.post("/predict", json=data)
|
||||
assert response.status_code == 200
|
||||
assert "Bob" in response.json()["result"]
|
||||
|
||||
|
||||
def test_get_all(client: TestClient):
|
||||
|
|
@ -116,28 +97,27 @@ INVALID_PROMPT = "This is an invalid prompt without any input variable."
|
|||
def test_valid_prompt(client: TestClient):
|
||||
response = client.post("/validate/prompt", json={"template": VALID_PROMPT})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"input_variables": ["product"], "valid": True}
|
||||
assert response.json() == {"input_variables": ["product"]}
|
||||
|
||||
|
||||
def test_invalid_prompt(client: TestClient):
|
||||
response = client.post("/validate/prompt", json={"template": INVALID_PROMPT})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"input_variables": [], "valid": False}
|
||||
assert response.json() == {"input_variables": []}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prompt,expected_input_variables,expected_validity",
|
||||
"prompt,expected_input_variables",
|
||||
[
|
||||
("{color} is my favorite color.", ["color"], True),
|
||||
("The weather is {weather} today.", ["weather"], True),
|
||||
("This prompt has no variables.", [], False),
|
||||
("{a}, {b}, and {c} are variables.", ["a", "b", "c"], True),
|
||||
("{color} is my favorite color.", ["color"]),
|
||||
("The weather is {weather} today.", ["weather"]),
|
||||
("This prompt has no variables.", []),
|
||||
("{a}, {b}, and {c} are variables.", ["a", "b", "c"]),
|
||||
],
|
||||
)
|
||||
def test_various_prompts(client, prompt, expected_input_variables, expected_validity):
|
||||
def test_various_prompts(client, prompt, expected_input_variables):
|
||||
response = client.post("/validate/prompt", json={"template": prompt})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"input_variables": expected_input_variables,
|
||||
"valid": expected_validity,
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue