🔧 fix(endpoints.py): remove the unused process_graph_cached import and replace its call with the process_graph_cached_task Celery task

🔧 fix(worker.py): refactor process_graph_cached into process_graph_cached_task; update its signature and implementation to load and update the graph through SessionManager
commit 682e611947
Author: Gabriel Luiz Freitas Almeida
Date: 2023-08-18 08:57:12 -03:00
2 changed files with 31 additions and 55 deletions
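Taken together, the two changes move graph execution out of the request handler and into a Celery worker. A minimal sketch of the resulting dispatch pattern, assuming standard Celery semantics; the route path and argument values below are illustrative and not taken from langflow:

# Sketch only: the endpoint enqueues the task and blocks on its result.
from typing import Optional

from fastapi import FastAPI

from langflow.worker import process_graph_cached_task

app = FastAPI()


@app.post("/example/process")
async def example_process(graph_data: dict, inputs: Optional[dict] = None):
    # .delay() enqueues the task on the broker; .get() waits for the worker
    # and returns the (result, session_id) tuple produced by the task.
    response, session_id = process_graph_cached_task.delay(
        graph_data, inputs, False, None
    ).get()
    return {"result": response, "session_id": session_id}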

endpoints.py

@@ -3,7 +3,7 @@ from typing import Annotated, Optional, Union
 from langflow.services.cache.utils import save_uploaded_file
 from langflow.services.database.models.flow import Flow
-from langflow.processing.process import process_graph_cached, process_tweaks
+from langflow.processing.process import process_tweaks
 from langflow.services.utils import get_settings_manager
 from langflow.utils.logger import logger
 from fastapi import APIRouter, Depends, HTTPException, UploadFile, Body
@@ -26,6 +26,7 @@ from langflow.interface.types import (
 )
 from langflow.services.utils import get_session
+from langflow.worker import process_graph_cached_task
 from sqlmodel import Session
 # build router
@@ -95,9 +96,9 @@ async def process_flow(
                 graph_data = process_tweaks(graph_data, tweaks)
             except Exception as exc:
                 logger.error(f"Error processing tweaks: {exc}")
-        response, session_id = process_graph_cached(
+        response, session_id = process_graph_cached_task.delay(
             graph_data, inputs, clear_cache, session_id
-        )
+        ).get()
         return ProcessResponse(result=response, session_id=session_id)
     except Exception as e:
         # Log stack trace
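Note that the new call site uses .get() with no timeout, so the HTTP request blocks until the worker replies. One way to bound that wait, using Celery's AsyncResult.get(timeout=...); the helper name and timeout value are hypothetical and not part of this commit:

# Hypothetical helper, not part of this commit: bound the wait on the Celery
# result so a stuck worker cannot hang the HTTP request indefinitely.
from celery.exceptions import TimeoutError as CeleryTimeoutError

from langflow.worker import process_graph_cached_task


def run_graph_with_timeout(graph_data, inputs, clear_cache, session_id, timeout=60):
    async_result = process_graph_cached_task.delay(
        graph_data, inputs, clear_cache, session_id
    )
    try:
        # get() raises celery.exceptions.TimeoutError if no result arrives
        # within `timeout` seconds.
        return async_result.get(timeout=timeout)
    except CeleryTimeoutError as exc:
        raise RuntimeError("Graph processing timed out") from exc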

worker.py

@@ -1,7 +1,14 @@
 from langflow.core.celery_app import celery_app
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple
 from typing import TYPE_CHECKING
 
 from celery.exceptions import SoftTimeLimitExceeded
+from langflow.processing.process import (
+    clear_caches_if_needed,
+    generate_result,
+    process_inputs,
+)
+from langflow.services.manager import initialize_session_manager
+from langflow.services.utils import get_session_manager
 
 if TYPE_CHECKING:
     from langflow.graph.vertex.base import Vertex
@@ -28,55 +35,23 @@ def build_vertex(self, vertex: "Vertex") -> "Vertex":
 @celery_app.task(acks_late=True)
-def process_graph_cached(
-    data_graph: Dict[str, Any], inputs: Optional[dict] = None, clear_cache=False
-):
-    """
-    Process graph by extracting input variables and replacing ZeroShotPrompt
-    with PromptTemplate,then run the graph and return the result and thought.
-    """
-    from langflow.interface.run import build_sorted_vertices_with_caching
-    from langflow.processing.process import get_result_and_thought
-    from langchain.chains.base import Chain
-    from langchain.vectorstores.base import VectorStore
-    from langflow.utils.logger import logger
+def process_graph_cached_task(
+    data_graph: Dict[str, Any],
+    inputs: Optional[dict] = None,
+    clear_cache=False,
+    session_id=None,
+) -> Tuple[Any, str]:
+    initialize_session_manager()
+    clear_caches_if_needed(clear_cache)
+    session_manager = get_session_manager()
+    # Load the graph using SessionManager
+    langchain_object, artifacts = session_manager.load_session(session_id, data_graph)
+    processed_inputs = process_inputs(inputs, artifacts)
+    result = generate_result(langchain_object, processed_inputs)
+    # langchain_object is now updated with the new memory
+    # we need to update the cache with the updated langchain_object
+    session_manager.update_session(
+        session_id, data_graph, (langchain_object, artifacts)
+    )
-    # Load langchain object
-    if clear_cache:
-        build_sorted_vertices_with_caching.clear_cache()
-        logger.debug("Cleared cache")
-    langchain_object, artifacts = build_sorted_vertices_with_caching(data_graph)
-    logger.debug("Loaded LangChain object")
-    if inputs is None:
-        inputs = {}
-    # Add artifacts to inputs
-    # artifacts can be documents loaded when building
-    # the flow
-    for (
-        key,
-        value,
-    ) in artifacts.items():
-        if key not in inputs or not inputs[key]:
-            inputs[key] = value
-    if langchain_object is None:
-        # Raise user facing error
-        raise ValueError(
-            "There was an error loading the langchain_object. Please, check all the nodes and try again."
-        )
-    # Generate result and thought
-    if isinstance(langchain_object, Chain):
-        if inputs is None:
-            raise ValueError("Inputs must be provided for a Chain")
-        logger.debug("Generating result and thought")
-        result = get_result_and_thought(langchain_object, inputs)
-        logger.debug("Generated result and thought")
-    elif isinstance(langchain_object, VectorStore):
-        result = langchain_object.search(**inputs)
-    else:
-        raise ValueError(
-            f"Unknown langchain_object type: {type(langchain_object).__name__}"
-        )
-    return result
+    return result, session_id
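The refactored task now delegates caching to SessionManager: it loads (or builds) the LangChain object for the session, runs it on the processed inputs, and writes the updated object back before returning. A minimal usage sketch; the graph_data dict is a placeholder, and calling the task object directly runs it in-process without a broker, which is convenient for tests:

# Usage sketch, not part of this commit. Calling a Celery task object as a
# plain function executes it synchronously in the current process.
from langflow.worker import process_graph_cached_task

graph_data = {"nodes": [], "edges": []}  # placeholder for a flow exported from the UI
result, session_id = process_graph_cached_task(
    graph_data,
    inputs={"input": "Hello"},
    clear_cache=False,
    session_id=None,
)
print(session_id, result)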