Refactor CustomComponent class and add new methods
This commit is contained in:
parent
c1a228fe6c
commit
ec6c00838d
1 changed files with 74 additions and 7 deletions
|
|
@ -1,17 +1,19 @@
|
|||
import operator
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, ClassVar, List, Optional, Union
|
||||
from typing import Any, Callable, ClassVar, List, Optional, Sequence, Union
|
||||
from uuid import UUID
|
||||
|
||||
import yaml
|
||||
from cachetools import TTLCache, cachedmethod
|
||||
from fastapi import HTTPException
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langflow.interface.custom.code_parser.utils import (
|
||||
extract_inner_type_from_generic_alias,
|
||||
extract_union_types_from_generic_alias,
|
||||
)
|
||||
from langflow.interface.custom.custom_component.component import Component
|
||||
from langflow.schema import Record
|
||||
from langflow.services.database.models.flow import Flow
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import (
|
||||
|
|
@ -86,6 +88,56 @@ class CustomComponent(Component):
|
|||
def tree(self):
|
||||
return self.get_code_tree(self.code or "")
|
||||
|
||||
def to_records(
|
||||
self, data: Any, text_key: str = "text", data_key: str = "data"
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Convert data into a list of records.
|
||||
|
||||
Args:
|
||||
data (Any): The input data to be converted.
|
||||
text_key (str, optional): The key to extract the text from a dictionary item. Defaults to "text".
|
||||
data_key (str, optional): The key to extract the data from a dictionary item. Defaults to "data".
|
||||
|
||||
Returns:
|
||||
List[dict]: A list of records, where each record is a dictionary with 'text' and 'data' keys.
|
||||
"""
|
||||
records = []
|
||||
if not isinstance(data, Sequence):
|
||||
data = [data]
|
||||
for item in data:
|
||||
if isinstance(item, str):
|
||||
records.append(Record(text=item))
|
||||
elif isinstance(item, dict):
|
||||
records.append(Record(text=item.get(text_key), data=item.get(data_key)))
|
||||
elif isinstance(item, Document):
|
||||
records.append(Record(text=item.page_content, data=item.metadata))
|
||||
else:
|
||||
raise ValueError(f"Invalid data type: {type(item)}")
|
||||
|
||||
return records
|
||||
|
||||
def create_references_from_records(
|
||||
self, records: List[dict], include_data: bool = False
|
||||
) -> str:
|
||||
"""
|
||||
Create references from a list of records.
|
||||
|
||||
Args:
|
||||
records (List[dict]): A list of records, where each record is a dictionary.
|
||||
include_data (bool, optional): Whether to include data in the references. Defaults to False.
|
||||
|
||||
Returns:
|
||||
str: A string containing the references in markdown format.
|
||||
"""
|
||||
markdown_string = "---\n"
|
||||
for record in records:
|
||||
markdown_string += f"- Text: {record['text']}"
|
||||
if include_data:
|
||||
markdown_string += f" Data: {record['data']}"
|
||||
markdown_string += "\n"
|
||||
return markdown_string
|
||||
|
||||
@property
|
||||
def get_function_entrypoint_args(self) -> list:
|
||||
build_method = self.get_build_method()
|
||||
|
|
@ -100,7 +152,8 @@ class CustomComponent(Component):
|
|||
detail={
|
||||
"error": "Type hint Error",
|
||||
"traceback": (
|
||||
"Prompt type is not supported in the build method." " Try using PromptTemplate instead."
|
||||
"Prompt type is not supported in the build method."
|
||||
" Try using PromptTemplate instead."
|
||||
),
|
||||
},
|
||||
)
|
||||
|
|
@ -114,14 +167,20 @@ class CustomComponent(Component):
|
|||
if not self.code:
|
||||
return {}
|
||||
|
||||
component_classes = [cls for cls in self.tree["classes"] if self.code_class_base_inheritance in cls["bases"]]
|
||||
component_classes = [
|
||||
cls
|
||||
for cls in self.tree["classes"]
|
||||
if self.code_class_base_inheritance in cls["bases"]
|
||||
]
|
||||
if not component_classes:
|
||||
return {}
|
||||
|
||||
# Assume the first Component class is the one we're interested in
|
||||
component_class = component_classes[0]
|
||||
build_methods = [
|
||||
method for method in component_class["methods"] if method["name"] == self.function_entrypoint_name
|
||||
method
|
||||
for method in component_class["methods"]
|
||||
if method["name"] == self.function_entrypoint_name
|
||||
]
|
||||
|
||||
return build_methods[0] if build_methods else {}
|
||||
|
|
@ -178,7 +237,9 @@ class CustomComponent(Component):
|
|||
# Retrieve and decrypt the credential by name for the current user
|
||||
db_service = get_db_service()
|
||||
with session_getter(db_service) as session:
|
||||
return credential_service.get_credential(user_id=self._user_id or "", name=name, session=session)
|
||||
return credential_service.get_credential(
|
||||
user_id=self._user_id or "", name=name, session=session
|
||||
)
|
||||
|
||||
return get_credential
|
||||
|
||||
|
|
@ -188,7 +249,9 @@ class CustomComponent(Component):
|
|||
credential_service = get_credential_service()
|
||||
db_service = get_db_service()
|
||||
with session_getter(db_service) as session:
|
||||
return credential_service.list_credentials(user_id=self._user_id, session=session)
|
||||
return credential_service.list_credentials(
|
||||
user_id=self._user_id, session=session
|
||||
)
|
||||
|
||||
def index(self, value: int = 0):
|
||||
"""Returns a function that returns the value at the given index in the iterable."""
|
||||
|
|
@ -239,7 +302,11 @@ class CustomComponent(Component):
|
|||
if flow_id:
|
||||
flow = session.query(Flow).get(flow_id)
|
||||
elif flow_name:
|
||||
flow = (session.query(Flow).filter(Flow.name == flow_name).filter(Flow.user_id == self.user_id)).first()
|
||||
flow = (
|
||||
session.query(Flow)
|
||||
.filter(Flow.name == flow_name)
|
||||
.filter(Flow.user_id == self.user_id)
|
||||
).first()
|
||||
else:
|
||||
raise ValueError("Either flow_name or flow_id must be provided")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue