From a9bb04ee24594079e004cc37efe60de9d9dee1bc Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Mon, 14 Aug 2023 17:05:40 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(base.py):=20add=20`is=5Ftask?= =?UTF-8?q?`=20parameter=20to=20`Vertex`=20constructor=20to=20indicate=20i?= =?UTF-8?q?f=20the=20vertex=20is=20a=20task=20=E2=9C=A8=20feat(base.py):?= =?UTF-8?q?=20add=20`get=5Fresult`=20method=20to=20`Vertex`=20to=20retriev?= =?UTF-8?q?e=20the=20result=20of=20a=20built=20vertex=20=F0=9F=90=9B=20fix?= =?UTF-8?q?(types.py):=20pass=20`is=5Ftask=3DTrue`=20to=20`super().=5F=5Fi?= =?UTF-8?q?nit=5F=5F`=20in=20`CustomComponentVertex`=20constructor=20?= =?UTF-8?q?=E2=9C=A8=20feat(worker.py):=20add=20`build=5Fvertex`=20task=20?= =?UTF-8?q?to=20build=20a=20vertex=20asynchronously?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/langflow/graph/vertex/base.py | 32 ++++++++++++++++++++-- src/backend/langflow/graph/vertex/types.py | 2 +- src/backend/langflow/worker.py | 13 +++++++++ 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/src/backend/langflow/graph/vertex/base.py b/src/backend/langflow/graph/vertex/base.py index ac7f72b4d..39cf79bd3 100644 --- a/src/backend/langflow/graph/vertex/base.py +++ b/src/backend/langflow/graph/vertex/base.py @@ -10,13 +10,16 @@ import inspect import types from typing import Any, Dict, List, Optional from typing import TYPE_CHECKING +from celery.result import AsyncResult if TYPE_CHECKING: from langflow.graph.edge.base import Edge class Vertex: - def __init__(self, data: Dict, base_type: Optional[str] = None) -> None: + def __init__( + self, data: Dict, base_type: Optional[str] = None, is_task: bool = False + ) -> None: self.id: str = data["id"] self._data = data self.edges: List["Edge"] = [] @@ -25,6 +28,8 @@ class Vertex: self._built_object = None self._built = False self.artifacts: Dict[str, Any] = {} + self.task_id: Optional[str] = None + self.is_task = is_task def _parse_data(self) -> None: self.data = self._data["data"] @@ -168,11 +173,32 @@ class Vertex: """ return all(self._is_node(node) for node in value) + def get_result(self, timeout=None) -> Any: + # Check if the Vertex was built already + if self._built: + return self._built_object + + # Check if there's a task_id, which means it was sent to a Celery worker + if self.is_task and self.task_id is not None: + result = AsyncResult(self.task_id).get( + timeout=timeout + ) # Blocking until result is ready or timeout + if result is not None: # If result is ready + self._update_built_object_and_artifacts(result) + return self._built_object + else: + # Handle the case when the result is not ready (retry, throw exception, etc.) + pass + + # If there's no task_id, build the vertex locally + return self.build() + def _build_node_and_update_params(self, key, node): """ Builds a given node and updates the params dictionary accordingly. """ - result = node.build() + + result = node.get_result() self._handle_func(key, result) if isinstance(result, list): self._extend_params_list_with_result(key, result) @@ -184,7 +210,7 @@ class Vertex: """ self.params[key] = [] for node in nodes: - built = node.build() + built = node.get_result() if isinstance(built, list): if key not in self.params: self.params[key] = [] diff --git a/src/backend/langflow/graph/vertex/types.py b/src/backend/langflow/graph/vertex/types.py index b7ac17983..ab3e54e0b 100644 --- a/src/backend/langflow/graph/vertex/types.py +++ b/src/backend/langflow/graph/vertex/types.py @@ -247,7 +247,7 @@ class OutputParserVertex(Vertex): class CustomComponentVertex(Vertex): def __init__(self, data: Dict): - super().__init__(data, base_type="custom_components") + super().__init__(data, base_type="custom_components", is_task=True) def _built_object_repr(self): if self.artifacts and "repr" in self.artifacts: diff --git a/src/backend/langflow/worker.py b/src/backend/langflow/worker.py index d390705f1..67eaf6b3b 100644 --- a/src/backend/langflow/worker.py +++ b/src/backend/langflow/worker.py @@ -1,5 +1,9 @@ from langflow.core.celery_app import celery_app from typing import Any, Dict, Optional +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from langflow.graph.vertex.base import Vertex @celery_app.task(acks_late=True) @@ -7,6 +11,15 @@ def test_celery(word: str) -> str: return f"test task return {word}" +@celery_app.task +def build_vertex(vertex: "Vertex") -> "Vertex": + """ + Build a vertex + """ + vertex.build() + return vertex + + @celery_app.task(acks_late=True) def process_graph_cached( data_graph: Dict[str, Any], inputs: Optional[dict] = None, clear_cache=False