Refactor endpoints.py to improve code handling and remove caching (#2180)
* chore: Refactor CustomComponent to improve repr_value handling
* ♻️ (endpoints.py): Refactor the simplified_run_flow and webhook_run_flow functions to remove caching
* Refactor endpoints.py to remove caching and update tests
This commit is contained in:
parent
8e74b16e64
commit
c53901f7ef
4 changed files with 49 additions and 72 deletions
|
|
@ -4,9 +4,6 @@ from uuid import UUID
|
|||
|
||||
import sqlalchemy as sa
|
||||
from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException, Request, UploadFile, status
|
||||
from loguru import logger
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from langflow.api.utils import update_frontend_node_with_template_values
|
||||
from langflow.api.v1.schemas import (
|
||||
ConfigResponse,
|
||||
|
|
@ -34,6 +31,8 @@ from langflow.services.database.models.user.model import User
|
|||
from langflow.services.deps import get_session, get_session_service, get_settings_service, get_task_service
|
||||
from langflow.services.session.service import SessionService
|
||||
from langflow.services.task.service import TaskService
|
||||
from loguru import logger
|
||||
from sqlmodel import Session, select
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.settings.manager import SettingsService
|
||||
|
|
@ -57,29 +56,20 @@ def get_all(
|
|||
|
||||
|
||||
async def simple_run_flow(
|
||||
db: Session,
|
||||
flow: Flow,
|
||||
input_request: SimplifiedAPIRequest,
|
||||
session_service: SessionService,
|
||||
stream: bool = False,
|
||||
api_key_user: Optional[User] = None,
|
||||
):
|
||||
try:
|
||||
task_result: List[RunOutputs] = []
|
||||
artifacts = {}
|
||||
user_id = api_key_user.id if api_key_user else None
|
||||
flow_id_str = str(flow.id)
|
||||
if input_request.session_id:
|
||||
session_data = await session_service.load_session(input_request.session_id, flow_id=flow_id_str)
|
||||
graph, artifacts = session_data if session_data else (None, None)
|
||||
if graph is None:
|
||||
raise ValueError(f"Session {input_request.session_id} not found")
|
||||
else:
|
||||
if flow.data is None:
|
||||
raise ValueError(f"Flow {flow_id_str} has no data")
|
||||
graph_data = flow.data
|
||||
graph_data = process_tweaks(graph_data, input_request.tweaks or {}, stream=stream)
|
||||
graph = Graph.from_payload(graph_data, flow_id=flow_id_str, user_id=str(user_id))
|
||||
if flow.data is None:
|
||||
raise ValueError(f"Flow {flow_id_str} has no data")
|
||||
graph_data = flow.data.copy()
|
||||
graph_data = process_tweaks(graph_data, input_request.tweaks or {}, stream=stream)
|
||||
graph = Graph.from_payload(graph_data, flow_id=flow_id_str, user_id=str(user_id))
|
||||
inputs = [
|
||||
InputValueRequest(components=[], input_value=input_request.input_value, type=input_request.input_type)
|
||||
]
|
||||
|
|
@ -101,8 +91,6 @@ async def simple_run_flow(
|
|||
session_id=input_request.session_id,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
artifacts=artifacts,
|
||||
session_service=session_service,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
|
|
@ -175,10 +163,8 @@ async def simplified_run_flow(
|
|||
"""
|
||||
try:
|
||||
return await simple_run_flow(
|
||||
db=db,
|
||||
flow=flow,
|
||||
input_request=input_request,
|
||||
session_service=session_service,
|
||||
stream=stream,
|
||||
api_key_user=api_key_user,
|
||||
)
|
||||
|
|
@ -249,7 +235,6 @@ async def webhook_run_flow(
|
|||
db=db,
|
||||
flow=flow,
|
||||
input_request=input_request,
|
||||
session_service=session_service,
|
||||
)
|
||||
return {"message": "Task started in the background", "status": "in progress"}
|
||||
except Exception as exc:
|
||||
|
|
@ -528,3 +513,4 @@ def get_config():
|
|||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
|
|
|||
|
|
@ -6,8 +6,6 @@ from uuid import UUID
|
|||
import yaml
|
||||
from cachetools import TTLCache, cachedmethod
|
||||
from langchain_core.documents import Document
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langflow.custom.code_parser.utils import (
|
||||
extract_inner_type_from_generic_alias,
|
||||
extract_union_types_from_generic_alias,
|
||||
|
|
@ -19,6 +17,7 @@ from langflow.schema.dotdict import dotdict
|
|||
from langflow.services.deps import get_storage_service, get_variable_service, session_scope
|
||||
from langflow.services.storage.service import StorageService
|
||||
from langflow.utils import validate
|
||||
from pydantic import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.graph.graph.base import Graph
|
||||
|
|
@ -160,9 +159,9 @@ class CustomComponent(Component):
|
|||
self.repr_value = self.status
|
||||
if isinstance(self.repr_value, dict):
|
||||
self.repr_value = yaml.dump(self.repr_value)
|
||||
if isinstance(self.repr_value, BaseModel) and not isinstance(self.repr_value, Data):
|
||||
if isinstance(self.repr_value, BaseModel) and not isinstance(self.repr_value, Record):
|
||||
self.repr_value = str(self.repr_value)
|
||||
elif hasattr(self.repr_value, "to_json"):
|
||||
elif hasattr(self.repr_value, "to_json") and not isinstance(self.repr_value, Record):
|
||||
self.repr_value = self.repr_value.to_json()
|
||||
return self.repr_value
|
||||
|
||||
|
|
@ -469,3 +468,4 @@ class CustomComponent(Component):
|
|||
Any: The result of the build process.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
|
|
@ -1,15 +1,13 @@
|
|||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.graph.schema import RunOutputs
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.schema.graph import InputValue, Tweaks
|
||||
from langflow.schema.schema import INPUT_FIELD_NAME
|
||||
from langflow.services.deps import get_settings_service
|
||||
from langflow.services.session.service import SessionService
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.api.v1.schemas import InputValueRequest
|
||||
|
|
@ -27,18 +25,13 @@ async def run_graph_internal(
|
|||
session_id: Optional[str] = None,
|
||||
inputs: Optional[List["InputValueRequest"]] = None,
|
||||
outputs: Optional[List[str]] = None,
|
||||
artifacts: Optional[Dict[str, Any]] = None,
|
||||
session_service: Optional[SessionService] = None,
|
||||
) -> tuple[List[RunOutputs], str]:
|
||||
"""Run the graph and generate the result"""
|
||||
inputs = inputs or []
|
||||
graph_data = graph._graph_data
|
||||
if session_id is None and session_service is not None:
|
||||
session_id_str = session_service.generate_key(session_id=flow_id, data_graph=graph_data)
|
||||
elif session_id is not None:
|
||||
session_id_str = session_id
|
||||
if session_id is None:
|
||||
session_id_str = flow_id
|
||||
else:
|
||||
raise ValueError("session_id or session_service must be provided")
|
||||
session_id_str = session_id
|
||||
components = []
|
||||
inputs_list = []
|
||||
types = []
|
||||
|
|
@ -53,16 +46,14 @@ async def run_graph_internal(
|
|||
fallback_to_env_vars = get_settings_service().settings.fallback_to_env_var
|
||||
|
||||
run_outputs = await graph.arun(
|
||||
inputs_list,
|
||||
components,
|
||||
types,
|
||||
outputs or [],
|
||||
inputs=inputs_list,
|
||||
inputs_components=components,
|
||||
types=types,
|
||||
outputs=outputs or [],
|
||||
stream=stream,
|
||||
session_id=session_id_str or "",
|
||||
fallback_to_env_vars=fallback_to_env_vars,
|
||||
)
|
||||
if session_id_str and session_service:
|
||||
await session_service.update_session(session_id_str, (graph, artifacts))
|
||||
return run_outputs, session_id_str
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -701,36 +701,36 @@ def test_run_flow_with_caching_invalid_input_format(client: TestClient, starter_
|
|||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
|
||||
@pytest.mark.api_key_required
|
||||
def test_run_flow_with_session_id(client, starter_project, created_api_key):
|
||||
headers = {"x-api-key": created_api_key.api_key}
|
||||
flow_id = starter_project["id"]
|
||||
payload = {
|
||||
"input_value": "value1",
|
||||
"input_type": "text",
|
||||
"output_type": "text",
|
||||
"session_id": "test-session-id",
|
||||
}
|
||||
response = client.post(f"/api/v1/run/{flow_id}", json=payload, headers=headers)
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
data = response.json()
|
||||
assert {"detail": "Session test-session-id not found"} == data
|
||||
# @pytest.mark.api_key_required
|
||||
# def test_run_flow_with_session_id(client, starter_project, created_api_key):
|
||||
# headers = {"x-api-key": created_api_key.api_key}
|
||||
# flow_id = starter_project["id"]
|
||||
# payload = {
|
||||
# "input_value": "value1",
|
||||
# "input_type": "text",
|
||||
# "output_type": "text",
|
||||
# "session_id": "test-session-id",
|
||||
# }
|
||||
# response = client.post(f"/api/v1/run/{flow_id}", json=payload, headers=headers)
|
||||
# assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
# data = response.json()
|
||||
# assert {"detail": "Session test-session-id not found"} == data
|
||||
|
||||
|
||||
def test_run_flow_with_invalid_session_id(client, starter_project, created_api_key):
|
||||
headers = {"x-api-key": created_api_key.api_key}
|
||||
flow_id = starter_project["id"]
|
||||
payload = {
|
||||
"input_value": "value1",
|
||||
"input_type": "text",
|
||||
"output_type": "text",
|
||||
"session_id": "invalid-session-id",
|
||||
}
|
||||
response = client.post(f"/api/v1/run/{flow_id}", json=payload, headers=headers)
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
data = response.json()
|
||||
assert "detail" in data
|
||||
assert f"Session {payload['session_id']} not found" in data["detail"]
|
||||
# def test_run_flow_with_invalid_session_id(client, starter_project, created_api_key):
|
||||
# headers = {"x-api-key": created_api_key.api_key}
|
||||
# flow_id = starter_project["id"]
|
||||
# payload = {
|
||||
# "input_value": "value1",
|
||||
# "input_type": "text",
|
||||
# "output_type": "text",
|
||||
# "session_id": "invalid-session-id",
|
||||
# }
|
||||
# response = client.post(f"/api/v1/run/{flow_id}", json=payload, headers=headers)
|
||||
# assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
# data = response.json()
|
||||
# assert "detail" in data
|
||||
# assert f"Session {payload['session_id']} not found" in data["detail"]
|
||||
|
||||
|
||||
@pytest.mark.api_key_required
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue