refactor: Update langflow/custom/custom_component/component.py to cache output values for improved performance
This commit is contained in:
parent
4d11c76476
commit
108514b734
2 changed files with 31 additions and 7 deletions
|
|
@ -18,7 +18,7 @@ from loguru import logger
|
|||
from pydantic import BaseModel
|
||||
|
||||
from langflow.schema.record import Record
|
||||
from langflow.template.field.base import Input, Output
|
||||
from langflow.template.field.base import UNDEFINED, Input, Output
|
||||
|
||||
from .custom_component import CustomComponent
|
||||
|
||||
|
|
@ -73,12 +73,16 @@ class Component(CustomComponent):
|
|||
# or if it's not connected to any vertex
|
||||
if not vertex.outgoing_edges or output.name in vertex.edges_source_names:
|
||||
method: Callable | Awaitable = getattr(self, output.method)
|
||||
result = method()
|
||||
# If the method is asynchronous, we need to await it
|
||||
if inspect.iscoroutinefunction(method):
|
||||
result = await result
|
||||
_results[output.name] = result
|
||||
self._results = _results
|
||||
if output.cache and not isinstance(output.value, UNDEFINED):
|
||||
_results[output.name] = output.value
|
||||
else:
|
||||
result = method()
|
||||
# If the method is asynchronous, we need to await it
|
||||
if inspect.iscoroutinefunction(method):
|
||||
result = await result
|
||||
_results[output.name] = result
|
||||
output.value = result
|
||||
|
||||
return _results
|
||||
|
||||
def custom_repr(self):
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from enum import Enum
|
||||
from types import GenericAlias
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
|
|
@ -7,6 +8,13 @@ from langflow.field_typing import Text
|
|||
from langflow.field_typing.range_spec import RangeSpec
|
||||
|
||||
|
||||
class UndefinedType(Enum):
|
||||
undefined = "__UNDEFINED__"
|
||||
|
||||
|
||||
UNDEFINED = UndefinedType.undefined
|
||||
|
||||
|
||||
class Input(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
|
@ -162,6 +170,10 @@ class Output(BaseModel):
|
|||
method: Optional[str] = Field(default=None)
|
||||
"""The method to use for the output."""
|
||||
|
||||
value: Optional[Any] = Field(default=UNDEFINED)
|
||||
|
||||
cache: bool = Field(default=True)
|
||||
|
||||
def to_dict(self):
|
||||
return self.model_dump(by_alias=True, exclude_none=True)
|
||||
|
||||
|
|
@ -185,5 +197,13 @@ class Output(BaseModel):
|
|||
@model_serializer(mode="wrap")
|
||||
def serialize_model(self, handler):
|
||||
result = handler(self)
|
||||
if self.value == UNDEFINED:
|
||||
result["value"] = UNDEFINED.value
|
||||
|
||||
return result
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_model(self):
|
||||
if self.value == UNDEFINED.value:
|
||||
self.value = UNDEFINED
|
||||
return self
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue