🔧 fix(endpoints.py): remove unused import and function call to process_graph_cached
🔧 fix(worker.py): refactor process_graph_cached into process_graph_cached_task and update function signature and implementation to use SessionManager for loading and updating the graph
This commit is contained in:
parent
c195456c49
commit
682e611947
2 changed files with 31 additions and 55 deletions
|
|
@ -3,7 +3,7 @@ from typing import Annotated, Optional, Union
|
|||
|
||||
from langflow.services.cache.utils import save_uploaded_file
|
||||
from langflow.services.database.models.flow import Flow
|
||||
from langflow.processing.process import process_graph_cached, process_tweaks
|
||||
from langflow.processing.process import process_tweaks
|
||||
from langflow.services.utils import get_settings_manager
|
||||
from langflow.utils.logger import logger
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, Body
|
||||
|
|
@ -26,6 +26,7 @@ from langflow.interface.types import (
|
|||
)
|
||||
|
||||
from langflow.services.utils import get_session
|
||||
from langflow.worker import process_graph_cached_task
|
||||
from sqlmodel import Session
|
||||
|
||||
# build router
|
||||
|
|
@ -95,9 +96,9 @@ async def process_flow(
|
|||
graph_data = process_tweaks(graph_data, tweaks)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error processing tweaks: {exc}")
|
||||
response, session_id = process_graph_cached(
|
||||
response, session_id = process_graph_cached_task.delay(
|
||||
graph_data, inputs, clear_cache, session_id
|
||||
)
|
||||
).get()
|
||||
return ProcessResponse(result=response, session_id=session_id)
|
||||
except Exception as e:
|
||||
# Log stack trace
|
||||
|
|
|
|||
|
|
@ -1,7 +1,14 @@
|
|||
from langflow.core.celery_app import celery_app
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import TYPE_CHECKING
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from langflow.processing.process import (
|
||||
clear_caches_if_needed,
|
||||
generate_result,
|
||||
process_inputs,
|
||||
)
|
||||
from langflow.services.manager import initialize_session_manager
|
||||
from langflow.services.utils import get_session_manager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
|
|
@ -28,55 +35,23 @@ def build_vertex(self, vertex: "Vertex") -> "Vertex":
|
|||
|
||||
|
||||
@celery_app.task(acks_late=True)
|
||||
def process_graph_cached(
|
||||
data_graph: Dict[str, Any], inputs: Optional[dict] = None, clear_cache=False
|
||||
):
|
||||
"""
|
||||
Process graph by extracting input variables and replacing ZeroShotPrompt
|
||||
with PromptTemplate,then run the graph and return the result and thought.
|
||||
"""
|
||||
from langflow.interface.run import build_sorted_vertices_with_caching
|
||||
from langflow.processing.process import get_result_and_thought
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langflow.utils.logger import logger
|
||||
def process_graph_cached_task(
|
||||
data_graph: Dict[str, Any],
|
||||
inputs: Optional[dict] = None,
|
||||
clear_cache=False,
|
||||
session_id=None,
|
||||
) -> Tuple[Any, str]:
|
||||
initialize_session_manager()
|
||||
clear_caches_if_needed(clear_cache)
|
||||
session_manager = get_session_manager()
|
||||
# Load the graph using SessionManager
|
||||
langchain_object, artifacts = session_manager.load_session(session_id, data_graph)
|
||||
processed_inputs = process_inputs(inputs, artifacts)
|
||||
result = generate_result(langchain_object, processed_inputs)
|
||||
# langchain_object is now updated with the new memory
|
||||
# we need to update the cache with the updated langchain_object
|
||||
session_manager.update_session(
|
||||
session_id, data_graph, (langchain_object, artifacts)
|
||||
)
|
||||
|
||||
# Load langchain object
|
||||
if clear_cache:
|
||||
build_sorted_vertices_with_caching.clear_cache()
|
||||
logger.debug("Cleared cache")
|
||||
langchain_object, artifacts = build_sorted_vertices_with_caching(data_graph)
|
||||
logger.debug("Loaded LangChain object")
|
||||
if inputs is None:
|
||||
inputs = {}
|
||||
|
||||
# Add artifacts to inputs
|
||||
# artifacts can be documents loaded when building
|
||||
# the flow
|
||||
for (
|
||||
key,
|
||||
value,
|
||||
) in artifacts.items():
|
||||
if key not in inputs or not inputs[key]:
|
||||
inputs[key] = value
|
||||
|
||||
if langchain_object is None:
|
||||
# Raise user facing error
|
||||
raise ValueError(
|
||||
"There was an error loading the langchain_object. Please, check all the nodes and try again."
|
||||
)
|
||||
|
||||
# Generate result and thought
|
||||
if isinstance(langchain_object, Chain):
|
||||
if inputs is None:
|
||||
raise ValueError("Inputs must be provided for a Chain")
|
||||
logger.debug("Generating result and thought")
|
||||
result = get_result_and_thought(langchain_object, inputs)
|
||||
logger.debug("Generated result and thought")
|
||||
elif isinstance(langchain_object, VectorStore):
|
||||
result = langchain_object.search(**inputs)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown langchain_object type: {type(langchain_object).__name__}"
|
||||
)
|
||||
return result
|
||||
return result, session_id
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue