From 108514b73480cb5e18f06e947eb0c5d161bf9fde Mon Sep 17 00:00:00 2001 From: ogabrielluiz Date: Mon, 10 Jun 2024 19:00:31 -0300 Subject: [PATCH] refactor: Update langflow/custom/custom_component/component.py to cache output values for improved performance --- .../custom/custom_component/component.py | 18 ++++++++++------- .../base/langflow/template/field/base.py | 20 +++++++++++++++++++ 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/src/backend/base/langflow/custom/custom_component/component.py b/src/backend/base/langflow/custom/custom_component/component.py index 6e1b2e723..c25a34d93 100644 --- a/src/backend/base/langflow/custom/custom_component/component.py +++ b/src/backend/base/langflow/custom/custom_component/component.py @@ -18,7 +18,7 @@ from loguru import logger from pydantic import BaseModel from langflow.schema.record import Record -from langflow.template.field.base import Input, Output +from langflow.template.field.base import UNDEFINED, Input, Output from .custom_component import CustomComponent @@ -73,12 +73,16 @@ class Component(CustomComponent): # or if it's not connected to any vertex if not vertex.outgoing_edges or output.name in vertex.edges_source_names: method: Callable | Awaitable = getattr(self, output.method) - result = method() - # If the method is asynchronous, we need to await it - if inspect.iscoroutinefunction(method): - result = await result - _results[output.name] = result - self._results = _results + if output.cache and not isinstance(output.value, UNDEFINED): + _results[output.name] = output.value + else: + result = method() + # If the method is asynchronous, we need to await it + if inspect.iscoroutinefunction(method): + result = await result + _results[output.name] = result + output.value = result + return _results def custom_repr(self): diff --git a/src/backend/base/langflow/template/field/base.py b/src/backend/base/langflow/template/field/base.py index 808a59dfc..d2a5d870c 100644 --- a/src/backend/base/langflow/template/field/base.py +++ b/src/backend/base/langflow/template/field/base.py @@ -1,3 +1,4 @@ +from enum import Enum from types import GenericAlias from typing import Any, Callable, Optional, Union @@ -7,6 +8,13 @@ from langflow.field_typing import Text from langflow.field_typing.range_spec import RangeSpec +class UndefinedType(Enum): + undefined = "__UNDEFINED__" + + +UNDEFINED = UndefinedType.undefined + + class Input(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) @@ -162,6 +170,10 @@ class Output(BaseModel): method: Optional[str] = Field(default=None) """The method to use for the output.""" + value: Optional[Any] = Field(default=UNDEFINED) + + cache: bool = Field(default=True) + def to_dict(self): return self.model_dump(by_alias=True, exclude_none=True) @@ -185,5 +197,13 @@ class Output(BaseModel): @model_serializer(mode="wrap") def serialize_model(self, handler): result = handler(self) + if self.value == UNDEFINED: + result["value"] = UNDEFINED.value return result + + @model_validator(mode="after") + def validate_model(self): + if self.value == UNDEFINED.value: + self.value = UNDEFINED + return self