diff --git a/src/backend/base/langflow/api/v1/chat.py b/src/backend/base/langflow/api/v1/chat.py index 50f3b06b6..a4b17c15e 100644 --- a/src/backend/base/langflow/api/v1/chat.py +++ b/src/backend/base/langflow/api/v1/chat.py @@ -214,19 +214,18 @@ async def build_flow( vertex = graph.get_vertex(vertex_id) try: lock = chat_service._async_cache_locks[flow_id_str] - ( - result_dict, - params, - valid, - artifacts, - vertex, - ) = await graph.build_vertex( - chat_service=chat_service, + vertex_build_result = await graph.build_vertex( vertex_id=vertex_id, user_id=current_user.id, inputs_dict=inputs.model_dump() if inputs else {}, files=files, + get_cache=chat_service.get_cache, + set_cache=chat_service.set_cache, ) + result_dict = vertex_build_result.result_dict + params = vertex_build_result.params + valid = vertex_build_result.valid + artifacts = vertex_build_result.artifacts next_runnable_vertices = await graph.get_next_runnable_vertices(lock, vertex=vertex, cache=False) top_level_vertices = graph.get_top_level_vertices(next_runnable_vertices) @@ -476,19 +475,18 @@ async def build_vertex( try: lock = chat_service._async_cache_locks[flow_id_str] - ( - result_dict, - params, - valid, - artifacts, - vertex, - ) = await graph.build_vertex( - chat_service=chat_service, + vertex_build_result = await graph.build_vertex( vertex_id=vertex_id, user_id=current_user.id, inputs_dict=inputs.model_dump() if inputs else {}, files=files, + get_cache=chat_service.get_cache, + set_cache=chat_service.set_cache, ) + result_dict = vertex_build_result.result_dict + params = vertex_build_result.params + valid = vertex_build_result.valid + artifacts = vertex_build_result.artifacts next_runnable_vertices = await graph.get_next_runnable_vertices(lock, vertex=vertex, cache=False) top_level_vertices = graph.get_top_level_vertices(next_runnable_vertices) diff --git a/src/backend/base/langflow/graph/graph/base.py b/src/backend/base/langflow/graph/graph/base.py index 42da6a41a..54c525137 100644 --- a/src/backend/base/langflow/graph/graph/base.py +++ b/src/backend/base/langflow/graph/graph/base.py @@ -13,6 +13,7 @@ from langflow.graph.edge.base import ContractEdge from langflow.graph.edge.schema import EdgeData from langflow.graph.graph.constants import lazy_load_vertex_dict from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager +from langflow.graph.graph.schema import VertexBuildResult from langflow.graph.graph.state_manager import GraphStateManager from langflow.graph.graph.utils import find_start_component_id, process_flow, sort_up_to_vertex from langflow.graph.schema import InterfaceComponentTypes, RunOutputs @@ -21,7 +22,7 @@ from langflow.graph.vertex.types import InterfaceVertex, StateVertex from langflow.schema import Data from langflow.schema.schema import INPUT_FIELD_NAME, InputType from langflow.services.cache.utils import CacheMiss -from langflow.services.chat.service import ChatService +from langflow.services.chat.schema import GetCache, SetCache from langflow.services.deps import get_chat_service, get_tracing_service if TYPE_CHECKING: @@ -858,13 +859,14 @@ class Graph: async def build_vertex( self, - chat_service: Optional[ChatService], vertex_id: str, + get_cache: GetCache | None = None, + set_cache: SetCache | None = None, inputs_dict: Optional[Dict[str, str]] = None, files: Optional[list[str]] = None, user_id: Optional[str] = None, fallback_to_env_vars: bool = False, - ): + ) -> VertexBuildResult: """ Builds a vertex in the graph. @@ -887,10 +889,17 @@ class Graph: try: params = "" if vertex.frozen: - if chat_service: - cached_result = await chat_service.get_cache(key=vertex.id) + # Check the cache for the vertex + if get_cache is not None: + cached_result = await get_cache(key=vertex.id) else: cached_result = None + if isinstance(cached_result, CacheMiss): + await vertex.build( + user_id=user_id, inputs=inputs_dict, fallback_to_env_vars=fallback_to_env_vars, files=files + ) + if set_cache is not None: + await set_cache(key=vertex.id, data=vertex) if cached_result and not isinstance(cached_result, CacheMiss): cached_vertex = cached_result["result"] # Now set update the vertex with the cached vertex @@ -906,14 +915,14 @@ class Graph: await vertex.build( user_id=user_id, inputs=inputs_dict, fallback_to_env_vars=fallback_to_env_vars, files=files ) - if chat_service: - await chat_service.set_cache(key=vertex.id, data=vertex) + if set_cache is not None: + await set_cache(key=vertex.id, data=vertex) else: await vertex.build( user_id=user_id, inputs=inputs_dict, fallback_to_env_vars=fallback_to_env_vars, files=files ) - if chat_service: - await chat_service.set_cache(key=vertex.id, data=vertex) + if set_cache is not None: + await set_cache(key=vertex.id, data=vertex) if vertex.result is not None: params = f"{vertex._built_object_repr()}{params}" @@ -922,7 +931,11 @@ class Graph: artifacts = vertex.artifacts else: raise ValueError(f"No result found for vertex {vertex_id}") - return result_dict, params, valid, artifacts, vertex + + vertex_build_result = VertexBuildResult( + result_dict=result_dict, params=params, valid=valid, artifacts=artifacts, vertex=vertex + ) + return vertex_build_result except Exception as exc: if not isinstance(exc, ComponentBuildException): logger.exception(f"Error building Component: \n\n{exc}") @@ -976,11 +989,12 @@ class Graph: vertex = self.get_vertex(vertex_id) task = asyncio.create_task( self.build_vertex( - chat_service=chat_service, vertex_id=vertex_id, user_id=self.user_id, inputs_dict={}, fallback_to_env_vars=fallback_to_env_vars, + get_cache=chat_service.get_cache, + set_cache=chat_service.set_cache, ), name=f"{vertex.display_name} Run {vertex_task_run_count.get(vertex_id, 0)}", ) diff --git a/src/backend/base/langflow/graph/graph/schema.py b/src/backend/base/langflow/graph/graph/schema.py new file mode 100644 index 000000000..30d67255f --- /dev/null +++ b/src/backend/base/langflow/graph/graph/schema.py @@ -0,0 +1,13 @@ +from typing import TYPE_CHECKING, NamedTuple + +if TYPE_CHECKING: + from langflow.graph.schema import ResultData + from langflow.graph.vertex.base import Vertex + + +class VertexBuildResult(NamedTuple): + result_dict: "ResultData" + params: str + valid: bool + artifacts: dict + vertex: "Vertex" diff --git a/src/backend/base/langflow/services/chat/schema.py b/src/backend/base/langflow/services/chat/schema.py new file mode 100644 index 000000000..51cf32e22 --- /dev/null +++ b/src/backend/base/langflow/services/chat/schema.py @@ -0,0 +1,10 @@ +import asyncio +from typing import Any, Protocol + + +class GetCache(Protocol): + async def __call__(self, key: str, lock: asyncio.Lock | None = None) -> Any: ... + + +class SetCache(Protocol): + async def __call__(self, key: str, data: Any, lock: asyncio.Lock | None = None) -> bool: ...