From 682e611947e5de0c736943fe5a372a433e01ff7c Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Fri, 18 Aug 2023 08:57:12 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A7=20fix(endpoints.py):=20remove=20un?= =?UTF-8?q?used=20import=20and=20function=20call=20to=20process=5Fgraph=5F?= =?UTF-8?q?cached=20=F0=9F=94=A7=20fix(worker.py):=20refactor=20process=5F?= =?UTF-8?q?graph=5Fcached=20into=20process=5Fgraph=5Fcached=5Ftask=20and?= =?UTF-8?q?=20update=20function=20signature=20and=20implementation=20to=20?= =?UTF-8?q?use=20SessionManager=20for=20loading=20and=20updating=20the=20g?= =?UTF-8?q?raph?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/langflow/api/v1/endpoints.py | 7 ++- src/backend/langflow/worker.py | 79 ++++++++---------------- 2 files changed, 31 insertions(+), 55 deletions(-) diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index a6b85dde2..999f751d5 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -3,7 +3,7 @@ from typing import Annotated, Optional, Union from langflow.services.cache.utils import save_uploaded_file from langflow.services.database.models.flow import Flow -from langflow.processing.process import process_graph_cached, process_tweaks +from langflow.processing.process import process_tweaks from langflow.services.utils import get_settings_manager from langflow.utils.logger import logger from fastapi import APIRouter, Depends, HTTPException, UploadFile, Body @@ -26,6 +26,7 @@ from langflow.interface.types import ( ) from langflow.services.utils import get_session +from langflow.worker import process_graph_cached_task from sqlmodel import Session # build router @@ -95,9 +96,9 @@ async def process_flow( graph_data = process_tweaks(graph_data, tweaks) except Exception as exc: logger.error(f"Error processing tweaks: {exc}") - response, session_id = process_graph_cached( + response, session_id = process_graph_cached_task.delay( graph_data, inputs, clear_cache, session_id - ) + ).get() return ProcessResponse(result=response, session_id=session_id) except Exception as e: # Log stack trace diff --git a/src/backend/langflow/worker.py b/src/backend/langflow/worker.py index f8831df76..4790090f0 100644 --- a/src/backend/langflow/worker.py +++ b/src/backend/langflow/worker.py @@ -1,7 +1,14 @@ from langflow.core.celery_app import celery_app -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple from typing import TYPE_CHECKING from celery.exceptions import SoftTimeLimitExceeded +from langflow.processing.process import ( + clear_caches_if_needed, + generate_result, + process_inputs, +) +from langflow.services.manager import initialize_session_manager +from langflow.services.utils import get_session_manager if TYPE_CHECKING: from langflow.graph.vertex.base import Vertex @@ -28,55 +35,23 @@ def build_vertex(self, vertex: "Vertex") -> "Vertex": @celery_app.task(acks_late=True) -def process_graph_cached( - data_graph: Dict[str, Any], inputs: Optional[dict] = None, clear_cache=False -): - """ - Process graph by extracting input variables and replacing ZeroShotPrompt - with PromptTemplate,then run the graph and return the result and thought. - """ - from langflow.interface.run import build_sorted_vertices_with_caching - from langflow.processing.process import get_result_and_thought - from langchain.chains.base import Chain - from langchain.vectorstores.base import VectorStore - from langflow.utils.logger import logger +def process_graph_cached_task( + data_graph: Dict[str, Any], + inputs: Optional[dict] = None, + clear_cache=False, + session_id=None, +) -> Tuple[Any, str]: + initialize_session_manager() + clear_caches_if_needed(clear_cache) + session_manager = get_session_manager() + # Load the graph using SessionManager + langchain_object, artifacts = session_manager.load_session(session_id, data_graph) + processed_inputs = process_inputs(inputs, artifacts) + result = generate_result(langchain_object, processed_inputs) + # langchain_object is now updated with the new memory + # we need to update the cache with the updated langchain_object + session_manager.update_session( + session_id, data_graph, (langchain_object, artifacts) + ) - # Load langchain object - if clear_cache: - build_sorted_vertices_with_caching.clear_cache() - logger.debug("Cleared cache") - langchain_object, artifacts = build_sorted_vertices_with_caching(data_graph) - logger.debug("Loaded LangChain object") - if inputs is None: - inputs = {} - - # Add artifacts to inputs - # artifacts can be documents loaded when building - # the flow - for ( - key, - value, - ) in artifacts.items(): - if key not in inputs or not inputs[key]: - inputs[key] = value - - if langchain_object is None: - # Raise user facing error - raise ValueError( - "There was an error loading the langchain_object. Please, check all the nodes and try again." - ) - - # Generate result and thought - if isinstance(langchain_object, Chain): - if inputs is None: - raise ValueError("Inputs must be provided for a Chain") - logger.debug("Generating result and thought") - result = get_result_and_thought(langchain_object, inputs) - logger.debug("Generated result and thought") - elif isinstance(langchain_object, VectorStore): - result = langchain_object.search(**inputs) - else: - raise ValueError( - f"Unknown langchain_object type: {type(langchain_object).__name__}" - ) - return result + return result, session_id