langflow/src/backend/langflow/interface/custom/custom_component/component.py
2024-02-22 11:57:03 -03:00

86 lines
2.7 KiB
Python

import operator
import warnings
from typing import Any, ClassVar, Optional
from cachetools import TTLCache, cachedmethod
from fastapi import HTTPException
from langflow.interface.custom.attributes import ATTR_FUNC_MAPPING
from langflow.interface.custom.code_parser import CodeParser
from langflow.interface.custom.eval import eval_custom_component_code
from langflow.utils import validate
class ComponentCodeNullError(HTTPException):
    """HTTP error raised by ``Component.get_function`` when ``code`` is empty."""
class ComponentFunctionEntrypointNameNullError(HTTPException):
    """HTTP error raised by ``Component.get_function`` when no entrypoint name is set."""
class Component:
    """Holder for a piece of component source code and its build metadata.

    Every keyword argument passed to the constructor is stored as an
    instance attribute; ``user_id`` is stored under the private name
    ``_user_id``. Subclasses are expected to override :meth:`build`.
    """

    ERROR_CODE_NULL: ClassVar[str] = "Python code must be provided."
    ERROR_FUNCTION_ENTRYPOINT_NAME_NULL: ClassVar[str] = (
        "The name of the entrypoint function must be provided."
    )

    # Source code of the component; ``None`` / empty means "no code yet".
    code: Optional[str] = None
    # Name of the function looked up in ``code`` by ``get_function``.
    _function_entrypoint_name: str = "build"
    # NOTE: this dict is a class-level default and is therefore shared by
    # every instance that does not assign its own ``field_config``; mutate a
    # copy rather than this object.
    field_config: dict = {}
    # Deliberately left without a class-level value: ``hasattr(self,
    # "_user_id")`` is False until the first assignment, which is what
    # ``__setattr__`` relies on to detect a re-assignment.
    _user_id: Optional[str]

    def __init__(self, **data):
        """Create the per-instance cache and copy ``data`` onto the instance.

        Args:
            **data: Attribute names and values. The key ``user_id`` is
                stored as ``_user_id``; every other key is stored verbatim.
        """
        # Backing store for ``get_code_tree``'s ``cachedmethod`` decorator.
        self.cache = TTLCache(maxsize=1024, ttl=60)
        for key, value in data.items():
            if key == "user_id":
                self._user_id = value
            else:
                setattr(self, key, value)

    def __setattr__(self, key, value):
        """Set an attribute, warning when ``_user_id`` is assigned again.

        The warning is advisory only: the new value is still stored, exactly
        as for any other attribute.
        """
        if key == "_user_id" and hasattr(self, "_user_id"):
            # stacklevel=2 attributes the warning to the assignment that
            # triggered it instead of to this line.
            warnings.warn(
                "user_id is immutable and cannot be changed.",
                stacklevel=2,
            )
        super().__setattr__(key, value)

    @cachedmethod(cache=operator.attrgetter("cache"))
    def get_code_tree(self, code: str):
        """Parse ``code`` with ``CodeParser`` and return ``parse_code()``'s result.

        Results are memoised in ``self.cache`` (the ``TTLCache`` created in
        ``__init__``), so ``code`` must be hashable.
        """
        return CodeParser(code).parse_code()

    def get_function(self):
        """Return ``validate.create_function`` applied to the stored code.

        Raises:
            ComponentCodeNullError: ``self.code`` is empty or ``None``
                (HTTP status 400).
            ComponentFunctionEntrypointNameNullError:
                ``self._function_entrypoint_name`` is empty (HTTP status 400).
        """
        if not self.code:
            raise ComponentCodeNullError(
                status_code=400,
                detail={"error": self.ERROR_CODE_NULL, "traceback": ""},
            )

        if not self._function_entrypoint_name:
            raise ComponentFunctionEntrypointNameNullError(
                status_code=400,
                detail={
                    "error": self.ERROR_FUNCTION_ENTRYPOINT_NAME_NULL,
                    "traceback": "",
                },
            )

        return validate.create_function(self.code, self._function_entrypoint_name)

    def build_template_config(self) -> dict:
        """Collect template settings from the class defined by ``self.code``.

        The class returned by ``eval_custom_component_code`` is instantiated
        with no arguments. For each attribute name in ``ATTR_FUNC_MAPPING``
        whose value on that instance is not ``None``, the mapped function is
        called with ``value=<attribute value>`` and its result is stored
        under the attribute name.

        Returns:
            The collected mapping, or an empty dict when ``self.code`` is
            empty or ``None``.
        """
        if not self.code:
            return {}

        # NOTE(review): this evaluates ``self.code``; presumably that code is
        # user-supplied -- confirm callers only pass trusted input.
        cc_class = eval_custom_component_code(self.code)
        component_instance = cc_class()

        template_config = {}
        for attribute, func in ATTR_FUNC_MAPPING.items():
            # A missing attribute and an attribute set to None are both skipped.
            value = getattr(component_instance, attribute, None)
            if value is not None:
                template_config[attribute] = func(value=value)
        return template_config

    def build(self, *args: Any, **kwargs: Any) -> Any:
        """Build the component; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError