refactor: Update vector store retriever types to use 'Retriever' alias

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-06-20 11:47:44 -03:00
commit 0497c8a192
14 changed files with 38 additions and 20 deletions

View file

@ -1,9 +1,9 @@
from typing import List, Union
from typing import List
from langchain_core.documents import Document
from langflow.custom import Component
from langflow.field_typing import BaseRetriever, Text, VectorStore
from langflow.field_typing import Retriever, Text, VectorStore
from langflow.helpers.data import docs_to_data
from langflow.io import Output
from langflow.schema import Data
@ -37,7 +37,7 @@ class LCVectorStoreComponent(Component):
self,
input_value: Text,
search_type: str,
vector_store: Union[VectorStore, BaseRetriever],
vector_store: VectorStore,
k=10,
**kwargs,
) -> List[Data]:
@ -71,7 +71,7 @@ class LCVectorStoreComponent(Component):
"""
raise NotImplementedError("build_vector_store method must be implemented.")
def build_base_retriever(self) -> BaseRetriever:
def build_base_retriever(self) -> Retriever:
"""
Builds the BaseRetriever object.
"""

View file

@ -1,9 +1,9 @@
from typing import Optional
from langchain_community.retrievers import AmazonKendraRetriever
from langchain_core.retrievers import BaseRetriever
from langflow.custom import CustomComponent
from langflow.field_typing import Retriever
class AmazonKendraRetrieverComponent(CustomComponent):
@ -36,7 +36,7 @@ class AmazonKendraRetrieverComponent(CustomComponent):
credentials_profile_name: Optional[str] = None,
attribute_filter: Optional[dict] = None,
user_context: Optional[dict] = None,
) -> BaseRetriever:
) -> Retriever:
try:
output = AmazonKendraRetriever(
index_id=index_id,

View file

@ -1,10 +1,10 @@
from typing import Optional
from langchain_community.retrievers import MetalRetriever
from langchain_core.retrievers import BaseRetriever
from metal_sdk.metal import Metal # type: ignore
from langflow.custom import CustomComponent
from langflow.field_typing import Retriever
class MetalRetrieverComponent(CustomComponent):
@ -20,7 +20,7 @@ class MetalRetrieverComponent(CustomComponent):
"code": {"show": False},
}
def build(self, api_key: str, client_id: str, index_id: str, params: Optional[dict] = None) -> BaseRetriever:
def build(self, api_key: str, client_id: str, index_id: str, params: Optional[dict] = None) -> Retriever:
try:
metal = Metal(api_key=api_key, client_id=client_id, index_id=index_id)
except Exception as e:

View file

@ -3,11 +3,11 @@ from typing import List
from langchain.chains.query_constructor.base import AttributeInfo
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore
from langflow.custom import CustomComponent
from langflow.field_typing.constants import LanguageModel
from langflow.field_typing import Retriever
class VectaraSelfQueryRetriverComponent(CustomComponent):
@ -40,7 +40,7 @@ class VectaraSelfQueryRetriverComponent(CustomComponent):
document_content_description: str,
llm: LanguageModel,
metadata_field_info: List[str],
) -> BaseRetriever:
) -> Retriever:
metadata_field_obj = []
for meta in metadata_field_info:

View file

@ -2,10 +2,12 @@ from typing import List
from langchain_community.vectorstores import Cassandra
from langchain_core.retrievers import BaseRetriever
from langflow.custom import Component
from langflow.helpers.data import docs_to_data
from langflow.io import BoolInput, DropdownInput, HandleInput, IntInput, Output, SecretStrInput, StrInput
from langflow.schema import Data
from langflow.field_typing import Retriever
class CassandraVectorStoreComponent(Component):
@ -97,7 +99,7 @@ class CassandraVectorStoreComponent(Component):
def build_vector_store(self) -> Cassandra:
return self._build_cassandra()
def build_base_retriever(self) -> BaseRetriever:
def build_base_retriever(self) -> Retriever:
return self._build_cassandra()
def _build_cassandra(self) -> Cassandra:

View file

@ -3,10 +3,12 @@ from typing import List
from langchain_community.vectorstores import CouchbaseVectorStore
from langchain_core.retrievers import BaseRetriever
from langflow.custom import Component
from langflow.helpers.data import docs_to_data
from langflow.io import BoolInput, HandleInput, IntInput, Output, SecretStrInput, StrInput
from langflow.schema import Data
from langflow.field_typing import Retriever
class CouchbaseVectorStoreComponent(Component):
@ -64,7 +66,7 @@ class CouchbaseVectorStoreComponent(Component):
def build_vector_store(self) -> CouchbaseVectorStore:
return self._build_couchbase()
def build_base_retriever(self) -> BaseRetriever:
def build_base_retriever(self) -> Retriever:
return self._build_couchbase()
def _build_couchbase(self) -> CouchbaseVectorStore:

View file

@ -8,6 +8,8 @@ from langflow.helpers.data import docs_to_data
from langflow.io import BoolInput, HandleInput, IntInput, Output, StrInput
from langflow.schema import Data
from langflow.field_typing import Retriever
class MongoVectorStoreComponent(Component):
display_name = "MongoDB Atlas"
@ -61,7 +63,7 @@ class MongoVectorStoreComponent(Component):
def build_vector_store(self) -> MongoDBAtlasVectorSearch:
return self._build_mongodb_atlas()
def build_base_retriever(self) -> BaseRetriever:
def build_base_retriever(self) -> Retriever:
return self._build_mongodb_atlas()
def _build_mongodb_atlas(self) -> MongoDBAtlasVectorSearch:

View file

@ -2,7 +2,9 @@ from typing import List
from langchain_core.retrievers import BaseRetriever
from langchain_pinecone import Pinecone
from langflow.custom import Component
from langflow.field_typing import Retriever
from langflow.helpers.data import docs_to_data
from langflow.io import BoolInput, DropdownInput, HandleInput, IntInput, Output, SecretStrInput, StrInput
from langflow.schema import Data
@ -68,7 +70,7 @@ class PineconeVectorStoreComponent(Component):
def build_vector_store(self) -> Pinecone:
return self._build_pinecone()
def build_base_retriever(self) -> BaseRetriever:
def build_base_retriever(self) -> Retriever:
return self._build_pinecone()
def _build_pinecone(self) -> Pinecone:

View file

@ -2,7 +2,9 @@ from typing import List
from langchain_community.vectorstores import Qdrant
from langchain_core.retrievers import BaseRetriever
from langflow.custom import Component
from langflow.field_typing import Retriever
from langflow.helpers.data import docs_to_data
from langflow.io import BoolInput, DropdownInput, HandleInput, IntInput, Output, SecretStrInput, StrInput
from langflow.schema import Data
@ -69,7 +71,7 @@ class QdrantVectorStoreComponent(Component):
def build_vector_store(self) -> Qdrant:
return self._build_qdrant()
def build_base_retriever(self) -> BaseRetriever:
def build_base_retriever(self) -> Retriever:
return self._build_qdrant()
def _build_qdrant(self) -> Qdrant:

View file

@ -5,6 +5,7 @@ from langchain_core.retrievers import BaseRetriever
from supabase.client import Client, create_client
from langflow.custom import Component
from langflow.field_typing import Retriever
from langflow.helpers.data import docs_to_data
from langflow.io import HandleInput, IntInput, Output, StrInput
from langflow.schema import Data
@ -57,7 +58,7 @@ class SupabaseVectorStoreComponent(Component):
def build_vector_store(self) -> SupabaseVectorStore:
return self._build_supabase()
def build_base_retriever(self) -> BaseRetriever:
def build_base_retriever(self) -> Retriever:
return self._build_supabase()
def _build_supabase(self) -> SupabaseVectorStore:

View file

@ -4,6 +4,7 @@ from langchain_community.vectorstores import UpstashVectorStore
from langchain_core.retrievers import BaseRetriever
from langflow.custom import Component
from langflow.field_typing import Retriever
from langflow.helpers.data import docs_to_data
from langflow.io import BoolInput, HandleInput, IntInput, Output, StrInput
from langflow.schema import Data
@ -73,7 +74,7 @@ class UpstashVectorStoreComponent(Component):
def build_vector_store(self) -> UpstashVectorStore:
return self._build_upstash()
def build_base_retriever(self) -> BaseRetriever:
def build_base_retriever(self) -> Retriever:
return self._build_upstash()
def _build_upstash(self) -> UpstashVectorStore:

View file

@ -3,10 +3,12 @@ from typing import List
from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores import Vectara
from langchain_core.retrievers import BaseRetriever
from langflow.custom import Component
from langflow.helpers.data import docs_to_data
from langflow.io import BoolInput, HandleInput, IntInput, Output, SecretStrInput, StrInput
from langflow.schema import Data
from langflow.field_typing import Retriever
class VectaraVectorStoreComponent(Component):
@ -54,7 +56,7 @@ class VectaraVectorStoreComponent(Component):
def build_vector_store(self) -> Vectara:
return self._build_vectara()
def build_base_retriever(self) -> BaseRetriever:
def build_base_retriever(self) -> Retriever:
return self._build_vectara()
def _build_vectara(self) -> Vectara:

View file

@ -3,7 +3,9 @@ from typing import List
import weaviate
from langchain_community.vectorstores import Weaviate
from langchain_core.retrievers import BaseRetriever
from langflow.custom import Component
from langflow.field_typing import Retriever
from langflow.helpers.data import docs_to_data
from langflow.io import BoolInput, HandleInput, IntInput, Output, SecretStrInput, StrInput
from langflow.schema import Data
@ -57,7 +59,7 @@ class WeaviateVectorStoreComponent(Component):
def build_vector_store(self) -> Weaviate:
return self._build_weaviate()
def build_base_retriever(self) -> BaseRetriever:
def build_base_retriever(self) -> Retriever:
return self._build_weaviate()
def _build_weaviate(self) -> Weaviate:

View file

@ -8,6 +8,8 @@ from langflow.helpers.data import docs_to_data
from langflow.io import BoolInput, HandleInput, IntInput, Output, StrInput
from langflow.schema import Data
from langflow.field_typing import Retriever
class PGVectorStoreComponent(Component):
display_name = "PGVector"
@ -54,7 +56,7 @@ class PGVectorStoreComponent(Component):
def build_vector_store(self) -> PGVector:
return self._build_pgvector()
def build_base_retriever(self) -> BaseRetriever:
def build_base_retriever(self) -> Retriever:
return self._build_pgvector()
def _build_pgvector(self) -> PGVector: