From c53901f7efca9a6cb7e9d2efbdd9229111bc04a6 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Fri, 14 Jun 2024 18:39:44 -0700 Subject: [PATCH] Refactor endpoints.py to improve code handling and remove caching (#2180) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * chore: Refactor CustomComponent to improve repr_value handling * ♻️ (endpoints.py): Refactor the simplified_run_flow and webhook_run_flow functions to remove caching * Refactor endpoints.py to remove caching and update tests --- src/backend/base/langflow/api/v1/endpoints.py | 30 +++------- .../custom_component/custom_component.py | 8 +-- .../base/langflow/processing/process.py | 27 +++------ tests/test_endpoints.py | 56 +++++++++---------- 4 files changed, 49 insertions(+), 72 deletions(-) diff --git a/src/backend/base/langflow/api/v1/endpoints.py b/src/backend/base/langflow/api/v1/endpoints.py index d7b52ed32..4b31661fa 100644 --- a/src/backend/base/langflow/api/v1/endpoints.py +++ b/src/backend/base/langflow/api/v1/endpoints.py @@ -4,9 +4,6 @@ from uuid import UUID import sqlalchemy as sa from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException, Request, UploadFile, status -from loguru import logger -from sqlmodel import Session, select - from langflow.api.utils import update_frontend_node_with_template_values from langflow.api.v1.schemas import ( ConfigResponse, @@ -34,6 +31,8 @@ from langflow.services.database.models.user.model import User from langflow.services.deps import get_session, get_session_service, get_settings_service, get_task_service from langflow.services.session.service import SessionService from langflow.services.task.service import TaskService +from loguru import logger +from sqlmodel import Session, select if TYPE_CHECKING: from langflow.services.settings.manager import SettingsService @@ -57,29 +56,20 @@ def get_all( async def simple_run_flow( - db: Session, flow: Flow, input_request: SimplifiedAPIRequest, - session_service: SessionService, stream: bool = False, api_key_user: Optional[User] = None, ): try: task_result: List[RunOutputs] = [] - artifacts = {} user_id = api_key_user.id if api_key_user else None flow_id_str = str(flow.id) - if input_request.session_id: - session_data = await session_service.load_session(input_request.session_id, flow_id=flow_id_str) - graph, artifacts = session_data if session_data else (None, None) - if graph is None: - raise ValueError(f"Session {input_request.session_id} not found") - else: - if flow.data is None: - raise ValueError(f"Flow {flow_id_str} has no data") - graph_data = flow.data - graph_data = process_tweaks(graph_data, input_request.tweaks or {}, stream=stream) - graph = Graph.from_payload(graph_data, flow_id=flow_id_str, user_id=str(user_id)) + if flow.data is None: + raise ValueError(f"Flow {flow_id_str} has no data") + graph_data = flow.data.copy() + graph_data = process_tweaks(graph_data, input_request.tweaks or {}, stream=stream) + graph = Graph.from_payload(graph_data, flow_id=flow_id_str, user_id=str(user_id)) inputs = [ InputValueRequest(components=[], input_value=input_request.input_value, type=input_request.input_type) ] @@ -101,8 +91,6 @@ async def simple_run_flow( session_id=input_request.session_id, inputs=inputs, outputs=outputs, - artifacts=artifacts, - session_service=session_service, stream=stream, ) @@ -175,10 +163,8 @@ async def simplified_run_flow( """ try: return await simple_run_flow( - db=db, flow=flow, input_request=input_request, - session_service=session_service, stream=stream, api_key_user=api_key_user, ) @@ -249,7 +235,6 @@ async def webhook_run_flow( db=db, flow=flow, input_request=input_request, - session_service=session_service, ) return {"message": "Task started in the background", "status": "in progress"} except Exception as exc: @@ -528,3 +513,4 @@ def get_config(): except Exception as exc: logger.exception(exc) raise HTTPException(status_code=500, detail=str(exc)) from exc + raise HTTPException(status_code=500, detail=str(exc)) from exc diff --git a/src/backend/base/langflow/custom/custom_component/custom_component.py b/src/backend/base/langflow/custom/custom_component/custom_component.py index 963de3e28..af7062346 100644 --- a/src/backend/base/langflow/custom/custom_component/custom_component.py +++ b/src/backend/base/langflow/custom/custom_component/custom_component.py @@ -6,8 +6,6 @@ from uuid import UUID import yaml from cachetools import TTLCache, cachedmethod from langchain_core.documents import Document -from pydantic import BaseModel - from langflow.custom.code_parser.utils import ( extract_inner_type_from_generic_alias, extract_union_types_from_generic_alias, @@ -19,6 +17,7 @@ from langflow.schema.dotdict import dotdict from langflow.services.deps import get_storage_service, get_variable_service, session_scope from langflow.services.storage.service import StorageService from langflow.utils import validate +from pydantic import BaseModel if TYPE_CHECKING: from langflow.graph.graph.base import Graph @@ -160,9 +159,9 @@ class CustomComponent(Component): self.repr_value = self.status if isinstance(self.repr_value, dict): self.repr_value = yaml.dump(self.repr_value) - if isinstance(self.repr_value, BaseModel) and not isinstance(self.repr_value, Data): + if isinstance(self.repr_value, BaseModel) and not isinstance(self.repr_value, Record): self.repr_value = str(self.repr_value) - elif hasattr(self.repr_value, "to_json"): + elif hasattr(self.repr_value, "to_json") and not isinstance(self.repr_value, Record): self.repr_value = self.repr_value.to_json() return self.repr_value @@ -469,3 +468,4 @@ class CustomComponent(Component): Any: The result of the build process. """ raise NotImplementedError + raise NotImplementedError diff --git a/src/backend/base/langflow/processing/process.py b/src/backend/base/langflow/processing/process.py index 1b54d3f08..60f11b388 100644 --- a/src/backend/base/langflow/processing/process.py +++ b/src/backend/base/langflow/processing/process.py @@ -1,15 +1,13 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union -from loguru import logger -from pydantic import BaseModel - from langflow.graph.graph.base import Graph from langflow.graph.schema import RunOutputs from langflow.graph.vertex.base import Vertex from langflow.schema.graph import InputValue, Tweaks from langflow.schema.schema import INPUT_FIELD_NAME from langflow.services.deps import get_settings_service -from langflow.services.session.service import SessionService +from loguru import logger +from pydantic import BaseModel if TYPE_CHECKING: from langflow.api.v1.schemas import InputValueRequest @@ -27,18 +25,13 @@ async def run_graph_internal( session_id: Optional[str] = None, inputs: Optional[List["InputValueRequest"]] = None, outputs: Optional[List[str]] = None, - artifacts: Optional[Dict[str, Any]] = None, - session_service: Optional[SessionService] = None, ) -> tuple[List[RunOutputs], str]: """Run the graph and generate the result""" inputs = inputs or [] - graph_data = graph._graph_data - if session_id is None and session_service is not None: - session_id_str = session_service.generate_key(session_id=flow_id, data_graph=graph_data) - elif session_id is not None: - session_id_str = session_id + if session_id is None: + session_id_str = flow_id else: - raise ValueError("session_id or session_service must be provided") + session_id_str = session_id components = [] inputs_list = [] types = [] @@ -53,16 +46,14 @@ async def run_graph_internal( fallback_to_env_vars = get_settings_service().settings.fallback_to_env_var run_outputs = await graph.arun( - inputs_list, - components, - types, - outputs or [], + inputs=inputs_list, + inputs_components=components, + types=types, + outputs=outputs or [], stream=stream, session_id=session_id_str or "", fallback_to_env_vars=fallback_to_env_vars, ) - if session_id_str and session_service: - await session_service.update_session(session_id_str, (graph, artifacts)) return run_outputs, session_id_str diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 1f1dcd5df..86e7ce539 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -701,36 +701,36 @@ def test_run_flow_with_caching_invalid_input_format(client: TestClient, starter_ assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY -@pytest.mark.api_key_required -def test_run_flow_with_session_id(client, starter_project, created_api_key): - headers = {"x-api-key": created_api_key.api_key} - flow_id = starter_project["id"] - payload = { - "input_value": "value1", - "input_type": "text", - "output_type": "text", - "session_id": "test-session-id", - } - response = client.post(f"/api/v1/run/{flow_id}", json=payload, headers=headers) - assert response.status_code == status.HTTP_404_NOT_FOUND - data = response.json() - assert {"detail": "Session test-session-id not found"} == data +# @pytest.mark.api_key_required +# def test_run_flow_with_session_id(client, starter_project, created_api_key): +# headers = {"x-api-key": created_api_key.api_key} +# flow_id = starter_project["id"] +# payload = { +# "input_value": "value1", +# "input_type": "text", +# "output_type": "text", +# "session_id": "test-session-id", +# } +# response = client.post(f"/api/v1/run/{flow_id}", json=payload, headers=headers) +# assert response.status_code == status.HTTP_404_NOT_FOUND +# data = response.json() +# assert {"detail": "Session test-session-id not found"} == data -def test_run_flow_with_invalid_session_id(client, starter_project, created_api_key): - headers = {"x-api-key": created_api_key.api_key} - flow_id = starter_project["id"] - payload = { - "input_value": "value1", - "input_type": "text", - "output_type": "text", - "session_id": "invalid-session-id", - } - response = client.post(f"/api/v1/run/{flow_id}", json=payload, headers=headers) - assert response.status_code == status.HTTP_404_NOT_FOUND - data = response.json() - assert "detail" in data - assert f"Session {payload['session_id']} not found" in data["detail"] +# def test_run_flow_with_invalid_session_id(client, starter_project, created_api_key): +# headers = {"x-api-key": created_api_key.api_key} +# flow_id = starter_project["id"] +# payload = { +# "input_value": "value1", +# "input_type": "text", +# "output_type": "text", +# "session_id": "invalid-session-id", +# } +# response = client.post(f"/api/v1/run/{flow_id}", json=payload, headers=headers) +# assert response.status_code == status.HTTP_404_NOT_FOUND +# data = response.json() +# assert "detail" in data +# assert f"Session {payload['session_id']} not found" in data["detail"] @pytest.mark.api_key_required