Refactor CustomComponent class and add resolve_path method

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-16 08:57:33 -03:00
commit 24a602f66e

View file

@ -1,11 +1,11 @@
import operator
from pathlib import Path
from typing import Any, Callable, ClassVar, List, Optional, Union
from uuid import UUID
import yaml
from cachetools import TTLCache, cachedmethod
from fastapi import HTTPException
from langflow.interface.custom.code_parser.utils import (
extract_inner_type_from_generic_alias,
extract_union_types_from_generic_alias,
@ -13,7 +13,11 @@ from langflow.interface.custom.code_parser.utils import (
from langflow.interface.custom.custom_component.component import Component
from langflow.services.database.models.flow import Flow
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_credential_service, get_db_service, get_storage_service
from langflow.services.deps import (
get_credential_service,
get_db_service,
get_storage_service,
)
from langflow.services.storage.service import StorageService
from langflow.utils import validate
@ -42,6 +46,16 @@ class CustomComponent(Component):
self.cache = TTLCache(maxsize=1024, ttl=60)
super().__init__(**data)
@staticmethod
def resolve_path(path: str) -> str:
"""Resolves the path to an absolute path."""
path_object = Path(path)
if path_object.parts[0] == "~":
path_object = path_object.expanduser()
elif path_object.is_relative_to("."):
path_object = path_object.resolve()
return str(path_object)
def get_full_path(self, path: str) -> str:
storage_svc: "StorageService" = get_storage_service()
@ -78,7 +92,8 @@ class CustomComponent(Component):
detail={
"error": "Type hint Error",
"traceback": (
"Prompt type is not supported in the build method." " Try using PromptTemplate instead."
"Prompt type is not supported in the build method."
" Try using PromptTemplate instead."
),
},
)
@ -92,14 +107,20 @@ class CustomComponent(Component):
if not self.code:
return {}
component_classes = [cls for cls in self.tree["classes"] if self.code_class_base_inheritance in cls["bases"]]
component_classes = [
cls
for cls in self.tree["classes"]
if self.code_class_base_inheritance in cls["bases"]
]
if not component_classes:
return {}
# Assume the first Component class is the one we're interested in
component_class = component_classes[0]
build_methods = [
method for method in component_class["methods"] if method["name"] == self.function_entrypoint_name
method
for method in component_class["methods"]
if method["name"] == self.function_entrypoint_name
]
return build_methods[0] if build_methods else {}
@ -112,7 +133,10 @@ class CustomComponent(Component):
return_type = build_method["return_type"]
# If list or List is in the return type, then we remove it and return the inner type
if hasattr(return_type, "__origin__") and return_type.__origin__ in [list, List]:
if hasattr(return_type, "__origin__") and return_type.__origin__ in [
list,
List,
]:
return_type = extract_inner_type_from_generic_alias(return_type)
# If the return type is not a Union, then we just return it as a list
@ -153,7 +177,9 @@ class CustomComponent(Component):
# Retrieve and decrypt the credential by name for the current user
db_service = get_db_service()
with session_getter(db_service) as session:
return credential_service.get_credential(user_id=self._user_id or "", name=name, session=session)
return credential_service.get_credential(
user_id=self._user_id or "", name=name, session=session
)
return get_credential
@ -163,7 +189,9 @@ class CustomComponent(Component):
credential_service = get_credential_service()
db_service = get_db_service()
with session_getter(db_service) as session:
return credential_service.list_credentials(user_id=self._user_id, session=session)
return credential_service.list_credentials(
user_id=self._user_id, session=session
)
def index(self, value: int = 0):
"""Returns a function that returns the value at the given index in the iterable."""
@ -214,7 +242,11 @@ class CustomComponent(Component):
if flow_id:
flow = session.query(Flow).get(flow_id)
elif flow_name:
flow = (session.query(Flow).filter(Flow.name == flow_name).filter(Flow.user_id == self.user_id)).first()
flow = (
session.query(Flow)
.filter(Flow.name == flow_name)
.filter(Flow.user_id == self.user_id)
).first()
else:
raise ValueError("Either flow_name or flow_id must be provided")