Refactor constants.py to use VectorStoreRetriever in Retriever typevar
This commit is contained in:
parent
659793fcee
commit
438946741a
6 changed files with 26 additions and 22 deletions
|
|
@ -73,7 +73,7 @@ class LCVectorStoreComponent(Component):
|
|||
"""
|
||||
raise NotImplementedError("build_vector_store method must be implemented.")
|
||||
|
||||
def build_base_retriever(self) -> Retriever:
|
||||
def build_base_retriever(self) -> Retriever: # type: ignore[type-var]
|
||||
"""
|
||||
Builds the BaseRetriever object.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional
|
||||
from typing import Optional, cast
|
||||
|
||||
from langchain_community.retrievers import AmazonKendraRetriever
|
||||
|
||||
|
|
@ -36,7 +36,7 @@ class AmazonKendraRetrieverComponent(CustomComponent):
|
|||
credentials_profile_name: Optional[str] = None,
|
||||
attribute_filter: Optional[dict] = None,
|
||||
user_context: Optional[dict] = None,
|
||||
) -> Retriever:
|
||||
) -> Retriever: # type: ignore[type-var]
|
||||
try:
|
||||
output = AmazonKendraRetriever(
|
||||
index_id=index_id,
|
||||
|
|
@ -48,4 +48,4 @@ class AmazonKendraRetrieverComponent(CustomComponent):
|
|||
) # type: ignore
|
||||
except Exception as e:
|
||||
raise ValueError("Could not connect to AmazonKendra API.") from e
|
||||
return output
|
||||
return cast(Retriever, output)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List
|
||||
from typing import List, cast
|
||||
|
||||
from langchain.retrievers import ContextualCompressionRetriever
|
||||
from langchain_cohere import CohereRerank
|
||||
|
|
@ -36,15 +36,15 @@ class CohereRerankComponent(LCVectorStoreComponent):
|
|||
HandleInput(name="retriever", display_name="Retriever", input_types=["Retriever"]),
|
||||
]
|
||||
|
||||
def build_base_retriever(self) -> Retriever:
|
||||
def build_base_retriever(self) -> Retriever: # type: ignore[type-var]
|
||||
cohere_reranker = CohereRerank(
|
||||
cohere_api_key=self.api_key, model=self.model, top_n=self.top_n, user_agent=self.user_agent
|
||||
)
|
||||
retriever = ContextualCompressionRetriever(base_compressor=cohere_reranker, base_retriever=self.retriever)
|
||||
return retriever
|
||||
return cast(Retriever, retriever)
|
||||
|
||||
async def search_documents(self) -> List[Data]:
|
||||
retriever: ContextualCompressionRetriever = self.build_base_retriever()
|
||||
async def search_documents(self) -> List[Data]: # type: ignore
|
||||
retriever = self.build_base_retriever()
|
||||
documents = await retriever.ainvoke(self.search_query)
|
||||
data = self.to_data(documents)
|
||||
self.status = data
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional
|
||||
from typing import Optional, cast
|
||||
|
||||
from langchain_community.retrievers import MetalRetriever
|
||||
from metal_sdk.metal import Metal # type: ignore
|
||||
|
|
@ -20,9 +20,9 @@ class MetalRetrieverComponent(CustomComponent):
|
|||
"code": {"show": False},
|
||||
}
|
||||
|
||||
def build(self, api_key: str, client_id: str, index_id: str, params: Optional[dict] = None) -> Retriever:
|
||||
def build(self, api_key: str, client_id: str, index_id: str, params: Optional[dict] = None) -> Retriever: # type: ignore[type-var]
|
||||
try:
|
||||
metal = Metal(api_key=api_key, client_id=client_id, index_id=index_id)
|
||||
except Exception as e:
|
||||
raise ValueError("Could not connect to Metal API.") from e
|
||||
return MetalRetriever(client=metal, params=params or {})
|
||||
return cast(Retriever, MetalRetriever(client=metal, params=params or {}))
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
import json
|
||||
from typing import List
|
||||
from typing import List, cast
|
||||
|
||||
from langchain.chains.query_constructor.base import AttributeInfo
|
||||
from langchain.retrievers.self_query.base import SelfQueryRetriever
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
from langflow.custom import CustomComponent
|
||||
from langflow.field_typing.constants import LanguageModel
|
||||
from langflow.field_typing import Retriever
|
||||
from langflow.field_typing.constants import LanguageModel
|
||||
|
||||
|
||||
class VectaraSelfQueryRetriverComponent(CustomComponent):
|
||||
|
|
@ -40,7 +40,7 @@ class VectaraSelfQueryRetriverComponent(CustomComponent):
|
|||
document_content_description: str,
|
||||
llm: LanguageModel,
|
||||
metadata_field_info: List[str],
|
||||
) -> Retriever:
|
||||
) -> Retriever: # type: ignore
|
||||
metadata_field_obj = []
|
||||
|
||||
for meta in metadata_field_info:
|
||||
|
|
@ -54,6 +54,9 @@ class VectaraSelfQueryRetriverComponent(CustomComponent):
|
|||
)
|
||||
metadata_field_obj.append(attribute_info)
|
||||
|
||||
return SelfQueryRetriever.from_llm(
|
||||
llm, vectorstore, document_content_description, metadata_field_obj, verbose=True
|
||||
return cast(
|
||||
Retriever,
|
||||
SelfQueryRetriever.from_llm(
|
||||
llm, vectorstore, document_content_description, metadata_field_obj, verbose=True
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -13,15 +13,16 @@ from langchain_core.output_parsers import BaseOutputParser
|
|||
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate, PromptTemplate
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.tools import Tool
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
|
||||
from langchain_text_splitters import TextSplitter
|
||||
|
||||
NestedDict: TypeAlias = Dict[str, Union[str, Dict]]
|
||||
LanguageModel = TypeVar("LanguageModel", BaseLanguageModel, BaseLLM, BaseChatModel)
|
||||
|
||||
|
||||
class Retriever(BaseRetriever):
|
||||
pass
|
||||
Retriever = TypeVar(
|
||||
"Retriever",
|
||||
BaseRetriever,
|
||||
VectorStoreRetriever,
|
||||
)
|
||||
|
||||
|
||||
class Object:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue