diff --git a/src/backend/langflow/graph/vertex/base.py b/src/backend/langflow/graph/vertex/base.py index ac7f72b4d..39cf79bd3 100644 --- a/src/backend/langflow/graph/vertex/base.py +++ b/src/backend/langflow/graph/vertex/base.py @@ -10,13 +10,16 @@ import inspect import types from typing import Any, Dict, List, Optional from typing import TYPE_CHECKING +from celery.result import AsyncResult if TYPE_CHECKING: from langflow.graph.edge.base import Edge class Vertex: - def __init__(self, data: Dict, base_type: Optional[str] = None) -> None: + def __init__( + self, data: Dict, base_type: Optional[str] = None, is_task: bool = False + ) -> None: self.id: str = data["id"] self._data = data self.edges: List["Edge"] = [] @@ -25,6 +28,8 @@ class Vertex: self._built_object = None self._built = False self.artifacts: Dict[str, Any] = {} + self.task_id: Optional[str] = None + self.is_task = is_task def _parse_data(self) -> None: self.data = self._data["data"] @@ -168,11 +173,32 @@ class Vertex: """ return all(self._is_node(node) for node in value) + def get_result(self, timeout=None) -> Any: + # Check if the Vertex was built already + if self._built: + return self._built_object + + # Check if there's a task_id, which means it was sent to a Celery worker + if self.is_task and self.task_id is not None: + result = AsyncResult(self.task_id).get( + timeout=timeout + ) # Blocking until result is ready or timeout + if result is not None: # If result is ready + self._update_built_object_and_artifacts(result) + return self._built_object + else: + # Handle the case when the result is not ready (retry, throw exception, etc.) + pass + + # If there's no task_id, build the vertex locally + return self.build() + def _build_node_and_update_params(self, key, node): """ Builds a given node and updates the params dictionary accordingly. """ - result = node.build() + + result = node.get_result() self._handle_func(key, result) if isinstance(result, list): self._extend_params_list_with_result(key, result) @@ -184,7 +210,7 @@ class Vertex: """ self.params[key] = [] for node in nodes: - built = node.build() + built = node.get_result() if isinstance(built, list): if key not in self.params: self.params[key] = [] diff --git a/src/backend/langflow/graph/vertex/types.py b/src/backend/langflow/graph/vertex/types.py index b7ac17983..ab3e54e0b 100644 --- a/src/backend/langflow/graph/vertex/types.py +++ b/src/backend/langflow/graph/vertex/types.py @@ -247,7 +247,7 @@ class OutputParserVertex(Vertex): class CustomComponentVertex(Vertex): def __init__(self, data: Dict): - super().__init__(data, base_type="custom_components") + super().__init__(data, base_type="custom_components", is_task=True) def _built_object_repr(self): if self.artifacts and "repr" in self.artifacts: diff --git a/src/backend/langflow/worker.py b/src/backend/langflow/worker.py index d390705f1..67eaf6b3b 100644 --- a/src/backend/langflow/worker.py +++ b/src/backend/langflow/worker.py @@ -1,5 +1,9 @@ from langflow.core.celery_app import celery_app from typing import Any, Dict, Optional +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from langflow.graph.vertex.base import Vertex @celery_app.task(acks_late=True) @@ -7,6 +11,15 @@ def test_celery(word: str) -> str: return f"test task return {word}" +@celery_app.task +def build_vertex(vertex: "Vertex") -> "Vertex": + """ + Build a vertex + """ + vertex.build() + return vertex + + @celery_app.task(acks_late=True) def process_graph_cached( data_graph: Dict[str, Any], inputs: Optional[dict] = None, clear_cache=False