From 7818e55146cc8019269c2cc45ca212f510fa5975 Mon Sep 17 00:00:00 2001 From: ogabrielluiz Date: Mon, 17 Jun 2024 14:23:49 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=93=9D=20(model.py):=20Refactor=20LCVecto?= =?UTF-8?q?rStoreComponent=20to=20use=20Component=20class=20instead=20of?= =?UTF-8?q?=20CustomComponent?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 📝 (model.py): Add outputs attribute to LCVectorStoreComponent to define available outputs and their methods 📝 (model.py): Implement _validate_outputs method in LCVectorStoreComponent to ensure required outputs are defined 📝 (model.py): Add build_vector_store and build_base_retriever methods to LCVectorStoreComponent for building Vector Store and Base Retriever objects 📝 (model.py): Update search_with_vector_store method in LCVectorStoreComponent to return data 📝 (model.py): Add NotImplementedError and ValueError handling in build_vector_store and build_base_retriever methods 📝 (component.py): Implement validate method in Component class to validate inputs and outputs 📝 (component.py): Implement _validate_outputs method in Component class to be extended by subclasses for output validation --- .../components/vectorstores/base/model.py | 50 +++++++++++++++++-- .../custom/custom_component/component.py | 8 +++ 2 files changed, 54 insertions(+), 4 deletions(-) diff --git a/src/backend/base/langflow/components/vectorstores/base/model.py b/src/backend/base/langflow/components/vectorstores/base/model.py index 1e6b86b5e..608b234b6 100644 --- a/src/backend/base/langflow/components/vectorstores/base/model.py +++ b/src/backend/base/langflow/components/vectorstores/base/model.py @@ -4,15 +4,41 @@ from langchain_core.documents import Document from langchain_core.retrievers import BaseRetriever from langchain_core.vectorstores import VectorStore -from langflow.custom import CustomComponent +from langflow.custom import Component from langflow.field_typing import Text from langflow.helpers.data import docs_to_data from langflow.schema import Data +from langflow.template import Output -class LCVectorStoreComponent(CustomComponent): - display_name: str = "LC Vector Store" - description: str = "Search a LC Vector Store for similar documents." +class LCVectorStoreComponent(Component): + outputs = [ + Output( + display_name="Vector Store", + name="vector_store", + method="build_vector_store", + ), + Output( + display_name="Base Retriever", + name="base_retriever", + method="build_base_retriever", + ), + Output( + display_name="Search Results", + name="search_results", + method="search_documents", + ), + ] + + def _validate_outputs(self): + # At least these three outputs must be defined + required_output_methods = ["build_vector_store", "build_base_retriever", "search_documents"] + output_names = [output.name for output in self.outputs] + for method_name in required_output_methods: + if method_name not in output_names: + raise ValueError(f"Output with name '{method_name}' must be defined.") + elif not hasattr(self, method_name): + raise ValueError(f"Method '{method_name}' must be defined.") def search_with_vector_store( self, @@ -45,3 +71,19 @@ class LCVectorStoreComponent(CustomComponent): data = docs_to_data(docs) self.status = data return data + + def build_vector_store(self) -> VectorStore: + """ + Builds the Vector Store object. + """ + raise NotImplementedError("build_vector_store method must be implemented.") + + def build_base_retriever(self) -> BaseRetriever: + """ + Builds the BaseRetriever object. + """ + vector_store = self.build_vector_store() + if hasattr(vector_store, "as_retriever"): + return vector_store.as_retriever() + else: + raise ValueError(f"Vector Store {vector_store.__class__.__name__} does not have an as_retriever method.") diff --git a/src/backend/base/langflow/custom/custom_component/component.py b/src/backend/base/langflow/custom/custom_component/component.py index f3b08a6a0..b9cbdd962 100644 --- a/src/backend/base/langflow/custom/custom_component/component.py +++ b/src/backend/base/langflow/custom/custom_component/component.py @@ -57,6 +57,14 @@ class Component(CustomComponent): for input_ in inputs: self._inputs[input_.name] = input_ + def validate(self, params: dict): + self._validate_inputs(params) + self._validate_outputs() + + def _validate_outputs(self): + # Raise Error if some rule isn't met + pass + def _validate_inputs(self, params: dict): # Params keys are the `name` attribute of the Input objects for key, value in params.copy().items():