Refactor constants.py to use VectorStoreRetriever in Retriever typevar

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-06-22 21:37:01 -03:00
commit 438946741a
6 changed files with 26 additions and 22 deletions

View file

@ -73,7 +73,7 @@ class LCVectorStoreComponent(Component):
"""
raise NotImplementedError("build_vector_store method must be implemented.")
def build_base_retriever(self) -> Retriever:
def build_base_retriever(self) -> Retriever: # type: ignore[type-var]
"""
Builds the BaseRetriever object.
"""

View file

@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, cast
from langchain_community.retrievers import AmazonKendraRetriever
@ -36,7 +36,7 @@ class AmazonKendraRetrieverComponent(CustomComponent):
credentials_profile_name: Optional[str] = None,
attribute_filter: Optional[dict] = None,
user_context: Optional[dict] = None,
) -> Retriever:
) -> Retriever: # type: ignore[type-var]
try:
output = AmazonKendraRetriever(
index_id=index_id,
@ -48,4 +48,4 @@ class AmazonKendraRetrieverComponent(CustomComponent):
) # type: ignore
except Exception as e:
raise ValueError("Could not connect to AmazonKendra API.") from e
return output
return cast(Retriever, output)

View file

@ -1,4 +1,4 @@
from typing import List
from typing import List, cast
from langchain.retrievers import ContextualCompressionRetriever
from langchain_cohere import CohereRerank
@ -36,15 +36,15 @@ class CohereRerankComponent(LCVectorStoreComponent):
HandleInput(name="retriever", display_name="Retriever", input_types=["Retriever"]),
]
def build_base_retriever(self) -> Retriever:
def build_base_retriever(self) -> Retriever: # type: ignore[type-var]
cohere_reranker = CohereRerank(
cohere_api_key=self.api_key, model=self.model, top_n=self.top_n, user_agent=self.user_agent
)
retriever = ContextualCompressionRetriever(base_compressor=cohere_reranker, base_retriever=self.retriever)
return retriever
return cast(Retriever, retriever)
async def search_documents(self) -> List[Data]:
retriever: ContextualCompressionRetriever = self.build_base_retriever()
async def search_documents(self) -> List[Data]: # type: ignore
retriever = self.build_base_retriever()
documents = await retriever.ainvoke(self.search_query)
data = self.to_data(documents)
self.status = data

View file

@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, cast
from langchain_community.retrievers import MetalRetriever
from metal_sdk.metal import Metal # type: ignore
@ -20,9 +20,9 @@ class MetalRetrieverComponent(CustomComponent):
"code": {"show": False},
}
def build(self, api_key: str, client_id: str, index_id: str, params: Optional[dict] = None) -> Retriever:
def build(self, api_key: str, client_id: str, index_id: str, params: Optional[dict] = None) -> Retriever: # type: ignore[type-var]
try:
metal = Metal(api_key=api_key, client_id=client_id, index_id=index_id)
except Exception as e:
raise ValueError("Could not connect to Metal API.") from e
return MetalRetriever(client=metal, params=params or {})
return cast(Retriever, MetalRetriever(client=metal, params=params or {}))

View file

@ -1,13 +1,13 @@
import json
from typing import List
from typing import List, cast
from langchain.chains.query_constructor.base import AttributeInfo
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain_core.vectorstores import VectorStore
from langflow.custom import CustomComponent
from langflow.field_typing.constants import LanguageModel
from langflow.field_typing import Retriever
from langflow.field_typing.constants import LanguageModel
class VectaraSelfQueryRetriverComponent(CustomComponent):
@ -40,7 +40,7 @@ class VectaraSelfQueryRetriverComponent(CustomComponent):
document_content_description: str,
llm: LanguageModel,
metadata_field_info: List[str],
) -> Retriever:
) -> Retriever: # type: ignore
metadata_field_obj = []
for meta in metadata_field_info:
@ -54,6 +54,9 @@ class VectaraSelfQueryRetriverComponent(CustomComponent):
)
metadata_field_obj.append(attribute_info)
return SelfQueryRetriever.from_llm(
llm, vectorstore, document_content_description, metadata_field_obj, verbose=True
return cast(
Retriever,
SelfQueryRetriever.from_llm(
llm, vectorstore, document_content_description, metadata_field_obj, verbose=True
),
)

View file

@ -13,15 +13,16 @@ from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate, PromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_core.tools import Tool
from langchain_core.vectorstores import VectorStore
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
from langchain_text_splitters import TextSplitter
NestedDict: TypeAlias = Dict[str, Union[str, Dict]]
LanguageModel = TypeVar("LanguageModel", BaseLanguageModel, BaseLLM, BaseChatModel)
class Retriever(BaseRetriever):
pass
Retriever = TypeVar(
"Retriever",
BaseRetriever,
VectorStoreRetriever,
)
class Object: