langflow/src/backend/langflow/interface/custom/custom_component/custom_component.py
2024-01-25 19:44:33 -03:00

226 lines
8.4 KiB
Python

import operator
from typing import Any, Callable, ClassVar, List, Optional, Union
from uuid import UUID
import yaml
from cachetools import TTLCache, cachedmethod
from fastapi import HTTPException
from langflow.interface.custom.code_parser.utils import (
extract_inner_type_from_generic_alias,
extract_union_types_from_generic_alias,
)
from langflow.services.database.models.flow import Flow
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_credential_service, get_db_service
from langflow.utils import validate
from .component import Component
class CustomComponent(Component):
display_name: Optional[str] = None
description: Optional[str] = None
code: Optional[str] = None
field_config: dict = {}
code_class_base_inheritance: ClassVar[str] = "CustomComponent"
function_entrypoint_name: ClassVar[str] = "build"
function: Optional[Callable] = None
repr_value: Optional[Any] = ""
user_id: Optional[Union[UUID, str]] = None
status: Optional[Any] = None
_tree: Optional[dict] = None
def __init__(self, **data):
self.cache = TTLCache(maxsize=1024, ttl=60)
super().__init__(**data)
def custom_repr(self):
if self.repr_value == "":
self.repr_value = self.status
if isinstance(self.repr_value, dict):
return yaml.dump(self.repr_value)
if isinstance(self.repr_value, str):
return self.repr_value
return str(self.repr_value)
def build_config(self):
return self.field_config
@property
def tree(self):
return self.get_code_tree(self.code or "")
@property
def get_function_entrypoint_args(self) -> list:
build_method = self.get_build_method()
if not build_method:
return []
args = build_method["args"]
for arg in args:
if arg.get("type") == "prompt":
raise HTTPException(
status_code=400,
detail={
"error": "Type hint Error",
"traceback": (
"Prompt type is not supported in the build method." " Try using PromptTemplate instead."
),
},
)
elif not arg.get("type") and arg.get("name") != "self":
# Set the type to Data
arg["type"] = "Data"
return args
@cachedmethod(operator.attrgetter("cache"))
def get_build_method(self):
if not self.code:
return {}
component_classes = [cls for cls in self.tree["classes"] if self.code_class_base_inheritance in cls["bases"]]
if not component_classes:
return {}
# Assume the first Component class is the one we're interested in
component_class = component_classes[0]
build_methods = [
method for method in component_class["methods"] if method["name"] == self.function_entrypoint_name
]
return build_methods[0] if build_methods else {}
@property
def get_function_entrypoint_return_type(self) -> List[Any]:
build_method = self.get_build_method()
if not build_method or not build_method.get("has_return"):
return []
return_type = build_method["return_type"]
# If list or List is in the return type, then we remove it and return the inner type
if hasattr(return_type, "__origin__") and return_type.__origin__ in [list, List]:
return_type = extract_inner_type_from_generic_alias(return_type)
# If the return type is not a Union, then we just return it as a list
if not hasattr(return_type, "__origin__") or return_type.__origin__ != Union:
return return_type if isinstance(return_type, list) else [return_type]
# If the return type is a Union, then we need to parse itx
return_type = extract_union_types_from_generic_alias(return_type)
return return_type
@property
def get_main_class_name(self):
if not self.code:
return ""
base_name = self.code_class_base_inheritance
method_name = self.function_entrypoint_name
classes = []
for item in self.tree.get("classes", []):
if base_name in item["bases"]:
method_names = [method["name"] for method in item["methods"]]
if method_name in method_names:
classes.append(item["name"])
# Get just the first item
return next(iter(classes), "")
@property
def template_config(self):
return self.build_template_config()
def build_template_config(self):
if not self.code:
return {}
attributes = [
main_class["attributes"]
for main_class in self.tree.get("classes", [])
if main_class["name"] == self.get_main_class_name
]
# Get just the first item
attributes = next(iter(attributes), [])
return super().build_template_config(attributes)
@property
def keys(self):
def get_credential(name: str):
if hasattr(self, "_user_id") and not self._user_id:
raise ValueError(f"User id is not set for {self.__class__.__name__}")
credential_service = get_credential_service() # Get service instance
# Retrieve and decrypt the credential by name for the current user
db_service = get_db_service()
with session_getter(db_service) as session:
return credential_service.get_credential(user_id=self._user_id or "", name=name, session=session)
return get_credential
def list_key_names(self):
if hasattr(self, "_user_id") and not self._user_id:
raise ValueError(f"User id is not set for {self.__class__.__name__}")
credential_service = get_credential_service()
db_service = get_db_service()
with session_getter(db_service) as session:
return credential_service.list_credentials(user_id=self._user_id, session=session)
def index(self, value: int = 0):
"""Returns a function that returns the value at the given index in the iterable."""
def get_index(iterable: List[Any]):
return iterable[value] if iterable else iterable
return get_index
def get_function(self):
return validate.create_function(self.code, self.function_entrypoint_name)
async def load_flow(self, flow_id: str, tweaks: Optional[dict] = None) -> Any:
from langflow.processing.process import build_sorted_vertices, process_tweaks
db_service = get_db_service()
with session_getter(db_service) as session:
graph_data = flow.data if (flow := session.get(Flow, flow_id)) else None
if not graph_data:
raise ValueError(f"Flow {flow_id} not found")
if tweaks:
graph_data = process_tweaks(graph_data=graph_data, tweaks=tweaks)
return await build_sorted_vertices(graph_data, self.user_id)
def list_flows(self, *, get_session: Optional[Callable] = None) -> List[Flow]:
if not self._user_id:
raise ValueError("Session is invalid")
try:
get_session = get_session or session_getter
db_service = get_db_service()
with get_session(db_service) as session:
flows = session.query(Flow).filter(Flow.user_id == self.user_id).all()
return flows
except Exception as e:
raise ValueError("Session is invalid") from e
async def get_flow(
self,
*,
flow_name: Optional[str] = None,
flow_id: Optional[str] = None,
tweaks: Optional[dict] = None,
get_session: Optional[Callable] = None,
) -> Flow:
get_session = get_session or session_getter
db_service = get_db_service()
with get_session(db_service) as session:
if flow_id:
flow = session.query(Flow).get(flow_id)
elif flow_name:
flow = (session.query(Flow).filter(Flow.name == flow_name).filter(Flow.user_id == self.user_id)).first()
else:
raise ValueError("Either flow_name or flow_id must be provided")
if not flow:
raise ValueError(f"Flow {flow_name or flow_id} not found")
return await self.load_flow(flow.id, tweaks)
def build(self, *args: Any, **kwargs: Any) -> Any:
raise NotImplementedError