Refactor CustomComponent class and add new methods

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-26 09:52:06 -03:00
commit ec6c00838d

View file

@ -1,17 +1,19 @@
import operator
from pathlib import Path
from typing import Any, Callable, ClassVar, List, Optional, Union
from typing import Any, Callable, ClassVar, List, Optional, Sequence, Union
from uuid import UUID
import yaml
from cachetools import TTLCache, cachedmethod
from fastapi import HTTPException
from langchain_core.documents import Document
from langflow.interface.custom.code_parser.utils import (
extract_inner_type_from_generic_alias,
extract_union_types_from_generic_alias,
)
from langflow.interface.custom.custom_component.component import Component
from langflow.schema import Record
from langflow.services.database.models.flow import Flow
from langflow.services.database.utils import session_getter
from langflow.services.deps import (
@ -86,6 +88,56 @@ class CustomComponent(Component):
def tree(self):
return self.get_code_tree(self.code or "")
def to_records(
self, data: Any, text_key: str = "text", data_key: str = "data"
) -> List[dict]:
"""
Convert data into a list of records.
Args:
data (Any): The input data to be converted.
text_key (str, optional): The key to extract the text from a dictionary item. Defaults to "text".
data_key (str, optional): The key to extract the data from a dictionary item. Defaults to "data".
Returns:
List[dict]: A list of records, where each record is a dictionary with 'text' and 'data' keys.
"""
records = []
if not isinstance(data, Sequence):
data = [data]
for item in data:
if isinstance(item, str):
records.append(Record(text=item))
elif isinstance(item, dict):
records.append(Record(text=item.get(text_key), data=item.get(data_key)))
elif isinstance(item, Document):
records.append(Record(text=item.page_content, data=item.metadata))
else:
raise ValueError(f"Invalid data type: {type(item)}")
return records
def create_references_from_records(
self, records: List[dict], include_data: bool = False
) -> str:
"""
Create references from a list of records.
Args:
records (List[dict]): A list of records, where each record is a dictionary.
include_data (bool, optional): Whether to include data in the references. Defaults to False.
Returns:
str: A string containing the references in markdown format.
"""
markdown_string = "---\n"
for record in records:
markdown_string += f"- Text: {record['text']}"
if include_data:
markdown_string += f" Data: {record['data']}"
markdown_string += "\n"
return markdown_string
@property
def get_function_entrypoint_args(self) -> list:
build_method = self.get_build_method()
@ -100,7 +152,8 @@ class CustomComponent(Component):
detail={
"error": "Type hint Error",
"traceback": (
"Prompt type is not supported in the build method." " Try using PromptTemplate instead."
"Prompt type is not supported in the build method."
" Try using PromptTemplate instead."
),
},
)
@ -114,14 +167,20 @@ class CustomComponent(Component):
if not self.code:
return {}
component_classes = [cls for cls in self.tree["classes"] if self.code_class_base_inheritance in cls["bases"]]
component_classes = [
cls
for cls in self.tree["classes"]
if self.code_class_base_inheritance in cls["bases"]
]
if not component_classes:
return {}
# Assume the first Component class is the one we're interested in
component_class = component_classes[0]
build_methods = [
method for method in component_class["methods"] if method["name"] == self.function_entrypoint_name
method
for method in component_class["methods"]
if method["name"] == self.function_entrypoint_name
]
return build_methods[0] if build_methods else {}
@ -178,7 +237,9 @@ class CustomComponent(Component):
# Retrieve and decrypt the credential by name for the current user
db_service = get_db_service()
with session_getter(db_service) as session:
return credential_service.get_credential(user_id=self._user_id or "", name=name, session=session)
return credential_service.get_credential(
user_id=self._user_id or "", name=name, session=session
)
return get_credential
@ -188,7 +249,9 @@ class CustomComponent(Component):
credential_service = get_credential_service()
db_service = get_db_service()
with session_getter(db_service) as session:
return credential_service.list_credentials(user_id=self._user_id, session=session)
return credential_service.list_credentials(
user_id=self._user_id, session=session
)
def index(self, value: int = 0):
"""Returns a function that returns the value at the given index in the iterable."""
@ -239,7 +302,11 @@ class CustomComponent(Component):
if flow_id:
flow = session.query(Flow).get(flow_id)
elif flow_name:
flow = (session.query(Flow).filter(Flow.name == flow_name).filter(Flow.user_id == self.user_id)).first()
flow = (
session.query(Flow)
.filter(Flow.name == flow_name)
.filter(Flow.user_id == self.user_id)
).first()
else:
raise ValueError("Either flow_name or flow_id must be provided")