Refactor CustomComponent class and add resolve_path method
This commit is contained in:
parent
3d4821cf98
commit
24a602f66e
1 changed files with 41 additions and 9 deletions
|
|
@ -1,11 +1,11 @@
|
|||
import operator
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, ClassVar, List, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
import yaml
|
||||
from cachetools import TTLCache, cachedmethod
|
||||
from fastapi import HTTPException
|
||||
|
||||
from langflow.interface.custom.code_parser.utils import (
|
||||
extract_inner_type_from_generic_alias,
|
||||
extract_union_types_from_generic_alias,
|
||||
|
|
@ -13,7 +13,11 @@ from langflow.interface.custom.code_parser.utils import (
|
|||
from langflow.interface.custom.custom_component.component import Component
|
||||
from langflow.services.database.models.flow import Flow
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_credential_service, get_db_service, get_storage_service
|
||||
from langflow.services.deps import (
|
||||
get_credential_service,
|
||||
get_db_service,
|
||||
get_storage_service,
|
||||
)
|
||||
from langflow.services.storage.service import StorageService
|
||||
from langflow.utils import validate
|
||||
|
||||
|
|
@ -42,6 +46,16 @@ class CustomComponent(Component):
|
|||
self.cache = TTLCache(maxsize=1024, ttl=60)
|
||||
super().__init__(**data)
|
||||
|
||||
@staticmethod
|
||||
def resolve_path(path: str) -> str:
|
||||
"""Resolves the path to an absolute path."""
|
||||
path_object = Path(path)
|
||||
if path_object.parts[0] == "~":
|
||||
path_object = path_object.expanduser()
|
||||
elif path_object.is_relative_to("."):
|
||||
path_object = path_object.resolve()
|
||||
return str(path_object)
|
||||
|
||||
def get_full_path(self, path: str) -> str:
|
||||
storage_svc: "StorageService" = get_storage_service()
|
||||
|
||||
|
|
@ -78,7 +92,8 @@ class CustomComponent(Component):
|
|||
detail={
|
||||
"error": "Type hint Error",
|
||||
"traceback": (
|
||||
"Prompt type is not supported in the build method." " Try using PromptTemplate instead."
|
||||
"Prompt type is not supported in the build method."
|
||||
" Try using PromptTemplate instead."
|
||||
),
|
||||
},
|
||||
)
|
||||
|
|
@ -92,14 +107,20 @@ class CustomComponent(Component):
|
|||
if not self.code:
|
||||
return {}
|
||||
|
||||
component_classes = [cls for cls in self.tree["classes"] if self.code_class_base_inheritance in cls["bases"]]
|
||||
component_classes = [
|
||||
cls
|
||||
for cls in self.tree["classes"]
|
||||
if self.code_class_base_inheritance in cls["bases"]
|
||||
]
|
||||
if not component_classes:
|
||||
return {}
|
||||
|
||||
# Assume the first Component class is the one we're interested in
|
||||
component_class = component_classes[0]
|
||||
build_methods = [
|
||||
method for method in component_class["methods"] if method["name"] == self.function_entrypoint_name
|
||||
method
|
||||
for method in component_class["methods"]
|
||||
if method["name"] == self.function_entrypoint_name
|
||||
]
|
||||
|
||||
return build_methods[0] if build_methods else {}
|
||||
|
|
@ -112,7 +133,10 @@ class CustomComponent(Component):
|
|||
return_type = build_method["return_type"]
|
||||
|
||||
# If list or List is in the return type, then we remove it and return the inner type
|
||||
if hasattr(return_type, "__origin__") and return_type.__origin__ in [list, List]:
|
||||
if hasattr(return_type, "__origin__") and return_type.__origin__ in [
|
||||
list,
|
||||
List,
|
||||
]:
|
||||
return_type = extract_inner_type_from_generic_alias(return_type)
|
||||
|
||||
# If the return type is not a Union, then we just return it as a list
|
||||
|
|
@ -153,7 +177,9 @@ class CustomComponent(Component):
|
|||
# Retrieve and decrypt the credential by name for the current user
|
||||
db_service = get_db_service()
|
||||
with session_getter(db_service) as session:
|
||||
return credential_service.get_credential(user_id=self._user_id or "", name=name, session=session)
|
||||
return credential_service.get_credential(
|
||||
user_id=self._user_id or "", name=name, session=session
|
||||
)
|
||||
|
||||
return get_credential
|
||||
|
||||
|
|
@ -163,7 +189,9 @@ class CustomComponent(Component):
|
|||
credential_service = get_credential_service()
|
||||
db_service = get_db_service()
|
||||
with session_getter(db_service) as session:
|
||||
return credential_service.list_credentials(user_id=self._user_id, session=session)
|
||||
return credential_service.list_credentials(
|
||||
user_id=self._user_id, session=session
|
||||
)
|
||||
|
||||
def index(self, value: int = 0):
|
||||
"""Returns a function that returns the value at the given index in the iterable."""
|
||||
|
|
@ -214,7 +242,11 @@ class CustomComponent(Component):
|
|||
if flow_id:
|
||||
flow = session.query(Flow).get(flow_id)
|
||||
elif flow_name:
|
||||
flow = (session.query(Flow).filter(Flow.name == flow_name).filter(Flow.user_id == self.user_id)).first()
|
||||
flow = (
|
||||
session.query(Flow)
|
||||
.filter(Flow.name == flow_name)
|
||||
.filter(Flow.user_id == self.user_id)
|
||||
).first()
|
||||
else:
|
||||
raise ValueError("Either flow_name or flow_id must be provided")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue