Refactor endpoints.py to improve code handling and remove caching (#2180)

* chore: Refactor CustomComponent to improve repr_value handling

* ♻️ (endpoints.py): Refactor the simplified_run_flow and webhook_run_flow functions to remove caching

* Refactor endpoints.py to remove caching and update tests
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-06-14 18:39:44 -07:00 committed by GitHub
commit c53901f7ef
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 49 additions and 72 deletions

View file

@ -4,9 +4,6 @@ from uuid import UUID
import sqlalchemy as sa
from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException, Request, UploadFile, status
from loguru import logger
from sqlmodel import Session, select
from langflow.api.utils import update_frontend_node_with_template_values
from langflow.api.v1.schemas import (
ConfigResponse,
@ -34,6 +31,8 @@ from langflow.services.database.models.user.model import User
from langflow.services.deps import get_session, get_session_service, get_settings_service, get_task_service
from langflow.services.session.service import SessionService
from langflow.services.task.service import TaskService
from loguru import logger
from sqlmodel import Session, select
if TYPE_CHECKING:
from langflow.services.settings.manager import SettingsService
@ -57,29 +56,20 @@ def get_all(
async def simple_run_flow(
db: Session,
flow: Flow,
input_request: SimplifiedAPIRequest,
session_service: SessionService,
stream: bool = False,
api_key_user: Optional[User] = None,
):
try:
task_result: List[RunOutputs] = []
artifacts = {}
user_id = api_key_user.id if api_key_user else None
flow_id_str = str(flow.id)
if input_request.session_id:
session_data = await session_service.load_session(input_request.session_id, flow_id=flow_id_str)
graph, artifacts = session_data if session_data else (None, None)
if graph is None:
raise ValueError(f"Session {input_request.session_id} not found")
else:
if flow.data is None:
raise ValueError(f"Flow {flow_id_str} has no data")
graph_data = flow.data
graph_data = process_tweaks(graph_data, input_request.tweaks or {}, stream=stream)
graph = Graph.from_payload(graph_data, flow_id=flow_id_str, user_id=str(user_id))
if flow.data is None:
raise ValueError(f"Flow {flow_id_str} has no data")
graph_data = flow.data.copy()
graph_data = process_tweaks(graph_data, input_request.tweaks or {}, stream=stream)
graph = Graph.from_payload(graph_data, flow_id=flow_id_str, user_id=str(user_id))
inputs = [
InputValueRequest(components=[], input_value=input_request.input_value, type=input_request.input_type)
]
@ -101,8 +91,6 @@ async def simple_run_flow(
session_id=input_request.session_id,
inputs=inputs,
outputs=outputs,
artifacts=artifacts,
session_service=session_service,
stream=stream,
)
@ -175,10 +163,8 @@ async def simplified_run_flow(
"""
try:
return await simple_run_flow(
db=db,
flow=flow,
input_request=input_request,
session_service=session_service,
stream=stream,
api_key_user=api_key_user,
)
@ -249,7 +235,6 @@ async def webhook_run_flow(
db=db,
flow=flow,
input_request=input_request,
session_service=session_service,
)
return {"message": "Task started in the background", "status": "in progress"}
except Exception as exc:
@ -528,3 +513,4 @@ def get_config():
except Exception as exc:
logger.exception(exc)
raise HTTPException(status_code=500, detail=str(exc)) from exc
raise HTTPException(status_code=500, detail=str(exc)) from exc

View file

@ -6,8 +6,6 @@ from uuid import UUID
import yaml
from cachetools import TTLCache, cachedmethod
from langchain_core.documents import Document
from pydantic import BaseModel
from langflow.custom.code_parser.utils import (
extract_inner_type_from_generic_alias,
extract_union_types_from_generic_alias,
@ -19,6 +17,7 @@ from langflow.schema.dotdict import dotdict
from langflow.services.deps import get_storage_service, get_variable_service, session_scope
from langflow.services.storage.service import StorageService
from langflow.utils import validate
from pydantic import BaseModel
if TYPE_CHECKING:
from langflow.graph.graph.base import Graph
@ -160,9 +159,9 @@ class CustomComponent(Component):
self.repr_value = self.status
if isinstance(self.repr_value, dict):
self.repr_value = yaml.dump(self.repr_value)
if isinstance(self.repr_value, BaseModel) and not isinstance(self.repr_value, Data):
if isinstance(self.repr_value, BaseModel) and not isinstance(self.repr_value, Record):
self.repr_value = str(self.repr_value)
elif hasattr(self.repr_value, "to_json"):
elif hasattr(self.repr_value, "to_json") and not isinstance(self.repr_value, Record):
self.repr_value = self.repr_value.to_json()
return self.repr_value
@ -469,3 +468,4 @@ class CustomComponent(Component):
Any: The result of the build process.
"""
raise NotImplementedError
raise NotImplementedError

View file

@ -1,15 +1,13 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from loguru import logger
from pydantic import BaseModel
from langflow.graph.graph.base import Graph
from langflow.graph.schema import RunOutputs
from langflow.graph.vertex.base import Vertex
from langflow.schema.graph import InputValue, Tweaks
from langflow.schema.schema import INPUT_FIELD_NAME
from langflow.services.deps import get_settings_service
from langflow.services.session.service import SessionService
from loguru import logger
from pydantic import BaseModel
if TYPE_CHECKING:
from langflow.api.v1.schemas import InputValueRequest
@ -27,18 +25,13 @@ async def run_graph_internal(
session_id: Optional[str] = None,
inputs: Optional[List["InputValueRequest"]] = None,
outputs: Optional[List[str]] = None,
artifacts: Optional[Dict[str, Any]] = None,
session_service: Optional[SessionService] = None,
) -> tuple[List[RunOutputs], str]:
"""Run the graph and generate the result"""
inputs = inputs or []
graph_data = graph._graph_data
if session_id is None and session_service is not None:
session_id_str = session_service.generate_key(session_id=flow_id, data_graph=graph_data)
elif session_id is not None:
session_id_str = session_id
if session_id is None:
session_id_str = flow_id
else:
raise ValueError("session_id or session_service must be provided")
session_id_str = session_id
components = []
inputs_list = []
types = []
@ -53,16 +46,14 @@ async def run_graph_internal(
fallback_to_env_vars = get_settings_service().settings.fallback_to_env_var
run_outputs = await graph.arun(
inputs_list,
components,
types,
outputs or [],
inputs=inputs_list,
inputs_components=components,
types=types,
outputs=outputs or [],
stream=stream,
session_id=session_id_str or "",
fallback_to_env_vars=fallback_to_env_vars,
)
if session_id_str and session_service:
await session_service.update_session(session_id_str, (graph, artifacts))
return run_outputs, session_id_str

View file

@ -701,36 +701,36 @@ def test_run_flow_with_caching_invalid_input_format(client: TestClient, starter_
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@pytest.mark.api_key_required
def test_run_flow_with_session_id(client, starter_project, created_api_key):
headers = {"x-api-key": created_api_key.api_key}
flow_id = starter_project["id"]
payload = {
"input_value": "value1",
"input_type": "text",
"output_type": "text",
"session_id": "test-session-id",
}
response = client.post(f"/api/v1/run/{flow_id}", json=payload, headers=headers)
assert response.status_code == status.HTTP_404_NOT_FOUND
data = response.json()
assert {"detail": "Session test-session-id not found"} == data
# @pytest.mark.api_key_required
# def test_run_flow_with_session_id(client, starter_project, created_api_key):
# headers = {"x-api-key": created_api_key.api_key}
# flow_id = starter_project["id"]
# payload = {
# "input_value": "value1",
# "input_type": "text",
# "output_type": "text",
# "session_id": "test-session-id",
# }
# response = client.post(f"/api/v1/run/{flow_id}", json=payload, headers=headers)
# assert response.status_code == status.HTTP_404_NOT_FOUND
# data = response.json()
# assert {"detail": "Session test-session-id not found"} == data
def test_run_flow_with_invalid_session_id(client, starter_project, created_api_key):
headers = {"x-api-key": created_api_key.api_key}
flow_id = starter_project["id"]
payload = {
"input_value": "value1",
"input_type": "text",
"output_type": "text",
"session_id": "invalid-session-id",
}
response = client.post(f"/api/v1/run/{flow_id}", json=payload, headers=headers)
assert response.status_code == status.HTTP_404_NOT_FOUND
data = response.json()
assert "detail" in data
assert f"Session {payload['session_id']} not found" in data["detail"]
# def test_run_flow_with_invalid_session_id(client, starter_project, created_api_key):
# headers = {"x-api-key": created_api_key.api_key}
# flow_id = starter_project["id"]
# payload = {
# "input_value": "value1",
# "input_type": "text",
# "output_type": "text",
# "session_id": "invalid-session-id",
# }
# response = client.post(f"/api/v1/run/{flow_id}", json=payload, headers=headers)
# assert response.status_code == status.HTTP_404_NOT_FOUND
# data = response.json()
# assert "detail" in data
# assert f"Session {payload['session_id']} not found" in data["detail"]
@pytest.mark.api_key_required