Merge branch 'multipart_endpoint' of https://github.com/logspace-ai/langflow into multipart_endpoint
This commit is contained in:
commit
0a0584c7ce
4 changed files with 161 additions and 3 deletions
|
|
@ -1,5 +1,6 @@
|
|||
import json
|
||||
from typing import Type
|
||||
from langchain.vectorstores import Pinecone, Qdrant, Chroma
|
||||
from langchain.vectorstores import Pinecone, Qdrant, Chroma, FAISS, Weaviate
|
||||
|
||||
|
||||
def docs_in_params(params: dict) -> bool:
|
||||
|
|
@ -10,6 +11,45 @@ def docs_in_params(params: dict) -> bool:
|
|||
)
|
||||
|
||||
|
||||
def initialize_weaviate(class_object: Type[Weaviate], params: dict):
|
||||
"""Initialize weaviate and return the class object"""
|
||||
if not docs_in_params(params):
|
||||
import weaviate
|
||||
|
||||
client_kwargs_json = params.get("client_kwargs", "{}")
|
||||
client_kwargs = json.loads(client_kwargs_json)
|
||||
client_params = {
|
||||
"url": params.get("weaviate_url"),
|
||||
}
|
||||
client_params.update(client_kwargs)
|
||||
weaviate_client = weaviate.Client(**client_params)
|
||||
|
||||
new_params = {
|
||||
"client": weaviate_client,
|
||||
"index_name": params.get("index_name"),
|
||||
"text_key": params.get("text_key"),
|
||||
}
|
||||
weaviate = class_object(**new_params)
|
||||
# If there are docs in the params, create a new index
|
||||
if "texts" in params:
|
||||
params["documents"] = params.pop("texts")
|
||||
|
||||
return class_object.from_documents(**params)
|
||||
|
||||
|
||||
def initialize_faiss(class_object: Type[FAISS], params: dict):
|
||||
"""Initialize faiss and return the class object"""
|
||||
|
||||
if not docs_in_params(params):
|
||||
return class_object.load_local
|
||||
|
||||
save_local = params.get("save_local")
|
||||
faiss_index = class_object(**params)
|
||||
if save_local:
|
||||
faiss_index.save_local(folder_path=save_local)
|
||||
return faiss_index
|
||||
|
||||
|
||||
def initialize_pinecone(class_object: Type[Pinecone], params: dict):
|
||||
"""Initialize pinecone and return the class object"""
|
||||
|
||||
|
|
|
|||
|
|
@ -18,7 +18,18 @@ class VectorStoreFrontendNode(FrontendNode):
|
|||
multiline=False,
|
||||
value="http://localhost:8080",
|
||||
)
|
||||
extra_fields.append(extra_field)
|
||||
# Add client_kwargs field
|
||||
extra_field2 = TemplateField(
|
||||
name="client_kwargs",
|
||||
field_type="code",
|
||||
required=False,
|
||||
placeholder="",
|
||||
show=True,
|
||||
advanced=True,
|
||||
multiline=False,
|
||||
value="{}",
|
||||
)
|
||||
extra_fields.extend((extra_field, extra_field2))
|
||||
|
||||
elif self.template.type_name == "Chroma":
|
||||
# New bool field for persist parameter
|
||||
|
|
@ -55,6 +66,29 @@ class VectorStoreFrontendNode(FrontendNode):
|
|||
value="",
|
||||
)
|
||||
extra_fields.extend((extra_field, extra_field2))
|
||||
elif self.template.type_name == "FAISS":
|
||||
extra_field = TemplateField(
|
||||
name="folder_path",
|
||||
field_type="str",
|
||||
required=False,
|
||||
placeholder="",
|
||||
show=True,
|
||||
advanced=True,
|
||||
multiline=False,
|
||||
display_name="Local Path",
|
||||
value="",
|
||||
)
|
||||
extra_field2 = TemplateField(
|
||||
name="index_name",
|
||||
field_type="str",
|
||||
required=False,
|
||||
show=True,
|
||||
advanced=False,
|
||||
value="",
|
||||
display_name="Index Name",
|
||||
)
|
||||
extra_fields.extend((extra_field, extra_field2))
|
||||
|
||||
if extra_fields:
|
||||
for field in extra_fields:
|
||||
self.template.add_field(field)
|
||||
|
|
@ -76,6 +110,7 @@ class VectorStoreFrontendNode(FrontendNode):
|
|||
"weaviate_url",
|
||||
"index_name",
|
||||
"namespace",
|
||||
"folder_path",
|
||||
]
|
||||
advanced_fields = [
|
||||
"n_dim",
|
||||
|
|
@ -94,6 +129,7 @@ class VectorStoreFrontendNode(FrontendNode):
|
|||
"grpc_port",
|
||||
"pinecone_api_key",
|
||||
"pinecone_env",
|
||||
"client_kwargs",
|
||||
]
|
||||
|
||||
# Check and set field attributes
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue