diff --git a/src/backend/base/langflow/components/vectorstores/base/model.py b/src/backend/base/langflow/components/vectorstores/base/model.py index 1e6b86b5e..608b234b6 100644 --- a/src/backend/base/langflow/components/vectorstores/base/model.py +++ b/src/backend/base/langflow/components/vectorstores/base/model.py @@ -4,15 +4,41 @@ from langchain_core.documents import Document from langchain_core.retrievers import BaseRetriever from langchain_core.vectorstores import VectorStore -from langflow.custom import CustomComponent +from langflow.custom import Component from langflow.field_typing import Text from langflow.helpers.data import docs_to_data from langflow.schema import Data +from langflow.template import Output -class LCVectorStoreComponent(CustomComponent): - display_name: str = "LC Vector Store" - description: str = "Search a LC Vector Store for similar documents." +class LCVectorStoreComponent(Component): + outputs = [ + Output( + display_name="Vector Store", + name="vector_store", + method="build_vector_store", + ), + Output( + display_name="Base Retriever", + name="base_retriever", + method="build_base_retriever", + ), + Output( + display_name="Search Results", + name="search_results", + method="search_documents", + ), + ] + + def _validate_outputs(self): + # At least these three outputs must be defined + required_output_methods = ["build_vector_store", "build_base_retriever", "search_documents"] + output_names = [output.name for output in self.outputs] + for method_name in required_output_methods: + if method_name not in output_names: + raise ValueError(f"Output with name '{method_name}' must be defined.") + elif not hasattr(self, method_name): + raise ValueError(f"Method '{method_name}' must be defined.") def search_with_vector_store( self, @@ -45,3 +71,19 @@ class LCVectorStoreComponent(CustomComponent): data = docs_to_data(docs) self.status = data return data + + def build_vector_store(self) -> VectorStore: + """ + Builds the Vector Store object. + """ + raise NotImplementedError("build_vector_store method must be implemented.") + + def build_base_retriever(self) -> BaseRetriever: + """ + Builds the BaseRetriever object. + """ + vector_store = self.build_vector_store() + if hasattr(vector_store, "as_retriever"): + return vector_store.as_retriever() + else: + raise ValueError(f"Vector Store {vector_store.__class__.__name__} does not have an as_retriever method.") diff --git a/src/backend/base/langflow/custom/custom_component/component.py b/src/backend/base/langflow/custom/custom_component/component.py index f3b08a6a0..b9cbdd962 100644 --- a/src/backend/base/langflow/custom/custom_component/component.py +++ b/src/backend/base/langflow/custom/custom_component/component.py @@ -57,6 +57,14 @@ class Component(CustomComponent): for input_ in inputs: self._inputs[input_.name] = input_ + def validate(self, params: dict): + self._validate_inputs(params) + self._validate_outputs() + + def _validate_outputs(self): + # Raise Error if some rule isn't met + pass + def _validate_inputs(self, params: dict): # Params keys are the `name` attribute of the Input objects for key, value in params.copy().items():