diff --git a/src/backend/langflow/interface/custom/custom_component/custom_component.py b/src/backend/langflow/interface/custom/custom_component/custom_component.py index 70087a98f..507e760f6 100644 --- a/src/backend/langflow/interface/custom/custom_component/custom_component.py +++ b/src/backend/langflow/interface/custom/custom_component/custom_component.py @@ -1,17 +1,19 @@ import operator from pathlib import Path -from typing import Any, Callable, ClassVar, List, Optional, Union +from typing import Any, Callable, ClassVar, List, Optional, Sequence, Union from uuid import UUID import yaml from cachetools import TTLCache, cachedmethod from fastapi import HTTPException +from langchain_core.documents import Document from langflow.interface.custom.code_parser.utils import ( extract_inner_type_from_generic_alias, extract_union_types_from_generic_alias, ) from langflow.interface.custom.custom_component.component import Component +from langflow.schema import Record from langflow.services.database.models.flow import Flow from langflow.services.database.utils import session_getter from langflow.services.deps import ( @@ -86,6 +88,56 @@ class CustomComponent(Component): def tree(self): return self.get_code_tree(self.code or "") + def to_records( + self, data: Any, text_key: str = "text", data_key: str = "data" + ) -> List[dict]: + """ + Convert data into a list of records. + + Args: + data (Any): The input data to be converted. + text_key (str, optional): The key to extract the text from a dictionary item. Defaults to "text". + data_key (str, optional): The key to extract the data from a dictionary item. Defaults to "data". + + Returns: + List[dict]: A list of records, where each record is a dictionary with 'text' and 'data' keys. + """ + records = [] + if not isinstance(data, Sequence): + data = [data] + for item in data: + if isinstance(item, str): + records.append(Record(text=item)) + elif isinstance(item, dict): + records.append(Record(text=item.get(text_key), data=item.get(data_key))) + elif isinstance(item, Document): + records.append(Record(text=item.page_content, data=item.metadata)) + else: + raise ValueError(f"Invalid data type: {type(item)}") + + return records + + def create_references_from_records( + self, records: List[dict], include_data: bool = False + ) -> str: + """ + Create references from a list of records. + + Args: + records (List[dict]): A list of records, where each record is a dictionary. + include_data (bool, optional): Whether to include data in the references. Defaults to False. + + Returns: + str: A string containing the references in markdown format. + """ + markdown_string = "---\n" + for record in records: + markdown_string += f"- Text: {record['text']}" + if include_data: + markdown_string += f" Data: {record['data']}" + markdown_string += "\n" + return markdown_string + @property def get_function_entrypoint_args(self) -> list: build_method = self.get_build_method() @@ -100,7 +152,8 @@ class CustomComponent(Component): detail={ "error": "Type hint Error", "traceback": ( - "Prompt type is not supported in the build method." " Try using PromptTemplate instead." + "Prompt type is not supported in the build method." + " Try using PromptTemplate instead." ), }, ) @@ -114,14 +167,20 @@ class CustomComponent(Component): if not self.code: return {} - component_classes = [cls for cls in self.tree["classes"] if self.code_class_base_inheritance in cls["bases"]] + component_classes = [ + cls + for cls in self.tree["classes"] + if self.code_class_base_inheritance in cls["bases"] + ] if not component_classes: return {} # Assume the first Component class is the one we're interested in component_class = component_classes[0] build_methods = [ - method for method in component_class["methods"] if method["name"] == self.function_entrypoint_name + method + for method in component_class["methods"] + if method["name"] == self.function_entrypoint_name ] return build_methods[0] if build_methods else {} @@ -178,7 +237,9 @@ class CustomComponent(Component): # Retrieve and decrypt the credential by name for the current user db_service = get_db_service() with session_getter(db_service) as session: - return credential_service.get_credential(user_id=self._user_id or "", name=name, session=session) + return credential_service.get_credential( + user_id=self._user_id or "", name=name, session=session + ) return get_credential @@ -188,7 +249,9 @@ class CustomComponent(Component): credential_service = get_credential_service() db_service = get_db_service() with session_getter(db_service) as session: - return credential_service.list_credentials(user_id=self._user_id, session=session) + return credential_service.list_credentials( + user_id=self._user_id, session=session + ) def index(self, value: int = 0): """Returns a function that returns the value at the given index in the iterable.""" @@ -239,7 +302,11 @@ class CustomComponent(Component): if flow_id: flow = session.query(Flow).get(flow_id) elif flow_name: - flow = (session.query(Flow).filter(Flow.name == flow_name).filter(Flow.user_id == self.user_id)).first() + flow = ( + session.query(Flow) + .filter(Flow.name == flow_name) + .filter(Flow.user_id == self.user_id) + ).first() else: raise ValueError("Either flow_name or flow_id must be provided")