🐛 fix(base.py): add is_task parameter to Vertex constructor to indicate if the vertex is a task

 feat(base.py): add `get_result` method to `Vertex` to retrieve the result of a built vertex
🐛 fix(types.py): pass `is_task=True` to `super().__init__` in `CustomComponentVertex` constructor
 feat(worker.py): add `build_vertex` task to build a vertex asynchronously
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-08-14 17:05:40 -03:00
commit a9bb04ee24
3 changed files with 43 additions and 4 deletions

View file

@ -10,13 +10,16 @@ import inspect
import types
from typing import Any, Dict, List, Optional
from typing import TYPE_CHECKING
from celery.result import AsyncResult
if TYPE_CHECKING:
from langflow.graph.edge.base import Edge
class Vertex:
def __init__(self, data: Dict, base_type: Optional[str] = None) -> None:
def __init__(
self, data: Dict, base_type: Optional[str] = None, is_task: bool = False
) -> None:
self.id: str = data["id"]
self._data = data
self.edges: List["Edge"] = []
@ -25,6 +28,8 @@ class Vertex:
self._built_object = None
self._built = False
self.artifacts: Dict[str, Any] = {}
self.task_id: Optional[str] = None
self.is_task = is_task
def _parse_data(self) -> None:
self.data = self._data["data"]
@ -168,11 +173,32 @@ class Vertex:
"""
return all(self._is_node(node) for node in value)
def get_result(self, timeout=None) -> Any:
# Check if the Vertex was built already
if self._built:
return self._built_object
# Check if there's a task_id, which means it was sent to a Celery worker
if self.is_task and self.task_id is not None:
result = AsyncResult(self.task_id).get(
timeout=timeout
) # Blocking until result is ready or timeout
if result is not None: # If result is ready
self._update_built_object_and_artifacts(result)
return self._built_object
else:
# Handle the case when the result is not ready (retry, throw exception, etc.)
pass
# If there's no task_id, build the vertex locally
return self.build()
def _build_node_and_update_params(self, key, node):
"""
Builds a given node and updates the params dictionary accordingly.
"""
result = node.build()
result = node.get_result()
self._handle_func(key, result)
if isinstance(result, list):
self._extend_params_list_with_result(key, result)
@ -184,7 +210,7 @@ class Vertex:
"""
self.params[key] = []
for node in nodes:
built = node.build()
built = node.get_result()
if isinstance(built, list):
if key not in self.params:
self.params[key] = []

View file

@ -247,7 +247,7 @@ class OutputParserVertex(Vertex):
class CustomComponentVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="custom_components")
super().__init__(data, base_type="custom_components", is_task=True)
def _built_object_repr(self):
if self.artifacts and "repr" in self.artifacts:

View file

@ -1,5 +1,9 @@
from langflow.core.celery_app import celery_app
from typing import Any, Dict, Optional
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from langflow.graph.vertex.base import Vertex
@celery_app.task(acks_late=True)
@ -7,6 +11,15 @@ def test_celery(word: str) -> str:
return f"test task return {word}"
@celery_app.task
def build_vertex(vertex: "Vertex") -> "Vertex":
"""
Build a vertex
"""
vertex.build()
return vertex
@celery_app.task(acks_late=True)
def process_graph_cached(
data_graph: Dict[str, Any], inputs: Optional[dict] = None, clear_cache=False