diff --git a/src/backend/langflow/interface/custom/custom_component/custom_component.py b/src/backend/langflow/interface/custom/custom_component/custom_component.py index 945d7db2c..999782ffc 100644 --- a/src/backend/langflow/interface/custom/custom_component/custom_component.py +++ b/src/backend/langflow/interface/custom/custom_component/custom_component.py @@ -1,11 +1,11 @@ import operator +from pathlib import Path from typing import Any, Callable, ClassVar, List, Optional, Union from uuid import UUID import yaml from cachetools import TTLCache, cachedmethod from fastapi import HTTPException - from langflow.interface.custom.code_parser.utils import ( extract_inner_type_from_generic_alias, extract_union_types_from_generic_alias, @@ -13,7 +13,11 @@ from langflow.interface.custom.code_parser.utils import ( from langflow.interface.custom.custom_component.component import Component from langflow.services.database.models.flow import Flow from langflow.services.database.utils import session_getter -from langflow.services.deps import get_credential_service, get_db_service, get_storage_service +from langflow.services.deps import ( + get_credential_service, + get_db_service, + get_storage_service, +) from langflow.services.storage.service import StorageService from langflow.utils import validate @@ -42,6 +46,16 @@ class CustomComponent(Component): self.cache = TTLCache(maxsize=1024, ttl=60) super().__init__(**data) + @staticmethod + def resolve_path(path: str) -> str: + """Resolves the path to an absolute path.""" + path_object = Path(path) + if path_object.parts[0] == "~": + path_object = path_object.expanduser() + elif path_object.is_relative_to("."): + path_object = path_object.resolve() + return str(path_object) + def get_full_path(self, path: str) -> str: storage_svc: "StorageService" = get_storage_service() @@ -78,7 +92,8 @@ class CustomComponent(Component): detail={ "error": "Type hint Error", "traceback": ( - "Prompt type is not supported in the build method." " Try using PromptTemplate instead." + "Prompt type is not supported in the build method." + " Try using PromptTemplate instead." ), }, ) @@ -92,14 +107,20 @@ class CustomComponent(Component): if not self.code: return {} - component_classes = [cls for cls in self.tree["classes"] if self.code_class_base_inheritance in cls["bases"]] + component_classes = [ + cls + for cls in self.tree["classes"] + if self.code_class_base_inheritance in cls["bases"] + ] if not component_classes: return {} # Assume the first Component class is the one we're interested in component_class = component_classes[0] build_methods = [ - method for method in component_class["methods"] if method["name"] == self.function_entrypoint_name + method + for method in component_class["methods"] + if method["name"] == self.function_entrypoint_name ] return build_methods[0] if build_methods else {} @@ -112,7 +133,10 @@ class CustomComponent(Component): return_type = build_method["return_type"] # If list or List is in the return type, then we remove it and return the inner type - if hasattr(return_type, "__origin__") and return_type.__origin__ in [list, List]: + if hasattr(return_type, "__origin__") and return_type.__origin__ in [ + list, + List, + ]: return_type = extract_inner_type_from_generic_alias(return_type) # If the return type is not a Union, then we just return it as a list @@ -153,7 +177,9 @@ class CustomComponent(Component): # Retrieve and decrypt the credential by name for the current user db_service = get_db_service() with session_getter(db_service) as session: - return credential_service.get_credential(user_id=self._user_id or "", name=name, session=session) + return credential_service.get_credential( + user_id=self._user_id or "", name=name, session=session + ) return get_credential @@ -163,7 +189,9 @@ class CustomComponent(Component): credential_service = get_credential_service() db_service = get_db_service() with session_getter(db_service) as session: - return credential_service.list_credentials(user_id=self._user_id, session=session) + return credential_service.list_credentials( + user_id=self._user_id, session=session + ) def index(self, value: int = 0): """Returns a function that returns the value at the given index in the iterable.""" @@ -214,7 +242,11 @@ class CustomComponent(Component): if flow_id: flow = session.query(Flow).get(flow_id) elif flow_name: - flow = (session.query(Flow).filter(Flow.name == flow_name).filter(Flow.user_id == self.user_id)).first() + flow = ( + session.query(Flow) + .filter(Flow.name == flow_name) + .filter(Flow.user_id == self.user_id) + ).first() else: raise ValueError("Either flow_name or flow_id must be provided")