Update QdrantComponent build method to handle pre-existing vector-stores

This commit is contained in:
Ricardo Henriques 2024-02-14 15:58:30 +00:00
commit 79e472772e

View file

@ -15,7 +15,7 @@ class QdrantComponent(CustomComponent):
return {
"documents": {"display_name": "Documents"},
"embedding": {"display_name": "Embedding"},
"api_key": {"display_name": "API Key", "password": True},
"api_key": {"display_name": "API Key", "password": True, "advanced": True},
"collection_name": {"display_name": "Collection Name"},
"content_payload_key": {"display_name": "Content Payload Key", "advanced": True},
"distance_func": {"display_name": "Distance Function", "advanced": True},
@ -36,7 +36,7 @@ class QdrantComponent(CustomComponent):
def build(
self,
embedding: Embeddings,
documents: List[Document],
documents: Optional[Document] = None,
api_key: Optional[str] = None,
collection_name: Optional[str] = None,
content_payload_key: str = "page_content",
@ -44,7 +44,7 @@ class QdrantComponent(CustomComponent):
grpc_port: Optional[int] = 6334,
host: Optional[str] = None,
https: bool = False,
location: str = ":memory:",
location: Optional[str] = None,
metadata_payload_key: str = "metadata",
path: Optional[str] = None,
port: Optional[int] = 6333,
@ -54,23 +54,50 @@ class QdrantComponent(CustomComponent):
timeout: Optional[float] = None,
url: Optional[str] = None,
) -> Union[VectorStore, Qdrant, BaseRetriever]:
return Qdrant.from_documents(
documents=documents,
embedding=embedding,
api_key=api_key,
collection_name=collection_name,
content_payload_key=content_payload_key,
distance_func=distance_func,
grpc_port=grpc_port,
host=host,
https=https,
location=location,
metadata_payload_key=metadata_payload_key,
path=path,
port=port,
prefer_grpc=prefer_grpc,
prefix=prefix,
search_kwargs=search_kwargs,
timeout=timeout,
url=url,
)
if documents is None:
from qdrant_client import QdrantClient
client = QdrantClient(
location=location,
url=host,
port=port,
grpc_port=grpc_port,
https=https,
prefix=prefix,
timeout=timeout,
prefer_grpc=prefer_grpc,
metadata_payload_key=metadata_payload_key,
content_payload_key=content_payload_key,
api_key=api_key,
collection_name=collection_name,
host=host,
path=path,
)
vs = Qdrant(client=client,
collection_name=collection_name,
embeddings=embedding,
search_kwargs=search_kwargs,
distance_func=distance_func,
)
return vs
else:
vs = Qdrant.from_documents(
documents=documents,
embedding=embedding,
api_key=api_key,
collection_name=collection_name,
content_payload_key=content_payload_key,
distance_func=distance_func,
grpc_port=grpc_port,
host=host,
https=https,
location=location,
metadata_payload_key=metadata_payload_key,
path=path,
port=port,
prefer_grpc=prefer_grpc,
prefix=prefix,
search_kwargs=search_kwargs,
timeout=timeout,
url=url,
)
return vs