diff --git a/src/backend/langflow/api/base.py b/src/backend/langflow/api/base.py index 04ae0c3bb..11b31b004 100644 --- a/src/backend/langflow/api/base.py +++ b/src/backend/langflow/api/base.py @@ -27,10 +27,3 @@ class CodeValidationResponse(BaseModel): class PromptValidationResponse(BaseModel): input_variables: list - valid: bool - - -def validate_prompt(template): - # Extract the input variables from template - input_variables = extract_input_variables_from_prompt(template) - return input_variables, len(input_variables) > 0 diff --git a/src/backend/langflow/api/validate.py b/src/backend/langflow/api/validate.py index 81d7f8500..6dea45df0 100644 --- a/src/backend/langflow/api/validate.py +++ b/src/backend/langflow/api/validate.py @@ -5,8 +5,8 @@ from langflow.api.base import ( CodeValidationResponse, Prompt, PromptValidationResponse, - validate_prompt, ) +from langflow.graph.utils import extract_input_variables_from_prompt from langflow.utils.logger import logger from langflow.utils.validate import validate_code @@ -29,7 +29,8 @@ def post_validate_code(code: Code): @router.post("/prompt", status_code=200, response_model=PromptValidationResponse) def post_validate_prompt(prompt: Prompt): try: - input_variables, valid = validate_prompt(prompt.template) - return PromptValidationResponse(input_variables=input_variables, valid=valid) + input_variables = extract_input_variables_from_prompt(prompt.template) + return PromptValidationResponse(input_variables=input_variables) except Exception as e: + logger.exception(e) return HTTPException(status_code=500, detail=str(e)) diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 55f002dad..83f6c62b1 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -1,25 +1,6 @@ -import json -from typing import Dict +import pytest from fastapi.testclient import TestClient from langflow.interface.tools.constants import CUSTOM_TOOLS -from pathlib import Path - -import pytest - - -def test_post_predict(client: TestClient): - with open(Path(__file__).parent / "data" / "Build_error.json") as f: - data = f.read() - json_data = json.loads(data) - data: Dict = json_data["data"] - data["message"] = "I'm Bob" - response = client.post("/predict", json=data) - assert response.status_code == 200 - data["message"] = "What is my name?" - data["chatHistory"] = ["I'm Bob"] - response = client.post("/predict", json=data) - assert response.status_code == 200 - assert "Bob" in response.json()["result"] def test_get_all(client: TestClient): @@ -116,28 +97,27 @@ INVALID_PROMPT = "This is an invalid prompt without any input variable." def test_valid_prompt(client: TestClient): response = client.post("/validate/prompt", json={"template": VALID_PROMPT}) assert response.status_code == 200 - assert response.json() == {"input_variables": ["product"], "valid": True} + assert response.json() == {"input_variables": ["product"]} def test_invalid_prompt(client: TestClient): response = client.post("/validate/prompt", json={"template": INVALID_PROMPT}) assert response.status_code == 200 - assert response.json() == {"input_variables": [], "valid": False} + assert response.json() == {"input_variables": []} @pytest.mark.parametrize( - "prompt,expected_input_variables,expected_validity", + "prompt,expected_input_variables", [ - ("{color} is my favorite color.", ["color"], True), - ("The weather is {weather} today.", ["weather"], True), - ("This prompt has no variables.", [], False), - ("{a}, {b}, and {c} are variables.", ["a", "b", "c"], True), + ("{color} is my favorite color.", ["color"]), + ("The weather is {weather} today.", ["weather"]), + ("This prompt has no variables.", []), + ("{a}, {b}, and {c} are variables.", ["a", "b", "c"]), ], ) -def test_various_prompts(client, prompt, expected_input_variables, expected_validity): +def test_various_prompts(client, prompt, expected_input_variables): response = client.post("/validate/prompt", json={"template": prompt}) assert response.status_code == 200 assert response.json() == { "input_variables": expected_input_variables, - "valid": expected_validity, }