Merge branch 'multipart_endpoint' of https://github.com/logspace-ai/langflow into multipart_endpoint

This commit is contained in:
Cristhian Zanforlin Lousa 2023-06-22 20:08:09 -03:00
commit 0a0584c7ce
4 changed files with 161 additions and 3 deletions

View file

@ -1,5 +1,6 @@
import json
from typing import Type
from langchain.vectorstores import Pinecone, Qdrant, Chroma
from langchain.vectorstores import Pinecone, Qdrant, Chroma, FAISS, Weaviate
def docs_in_params(params: dict) -> bool:
@ -10,6 +11,45 @@ def docs_in_params(params: dict) -> bool:
)
def initialize_weaviate(class_object: Type[Weaviate], params: dict):
"""Initialize weaviate and return the class object"""
if not docs_in_params(params):
import weaviate
client_kwargs_json = params.get("client_kwargs", "{}")
client_kwargs = json.loads(client_kwargs_json)
client_params = {
"url": params.get("weaviate_url"),
}
client_params.update(client_kwargs)
weaviate_client = weaviate.Client(**client_params)
new_params = {
"client": weaviate_client,
"index_name": params.get("index_name"),
"text_key": params.get("text_key"),
}
weaviate = class_object(**new_params)
# If there are docs in the params, create a new index
if "texts" in params:
params["documents"] = params.pop("texts")
return class_object.from_documents(**params)
def initialize_faiss(class_object: Type[FAISS], params: dict):
"""Initialize faiss and return the class object"""
if not docs_in_params(params):
return class_object.load_local
save_local = params.get("save_local")
faiss_index = class_object(**params)
if save_local:
faiss_index.save_local(folder_path=save_local)
return faiss_index
def initialize_pinecone(class_object: Type[Pinecone], params: dict):
"""Initialize pinecone and return the class object"""

View file

@ -18,7 +18,18 @@ class VectorStoreFrontendNode(FrontendNode):
multiline=False,
value="http://localhost:8080",
)
extra_fields.append(extra_field)
# Add client_kwargs field
extra_field2 = TemplateField(
name="client_kwargs",
field_type="code",
required=False,
placeholder="",
show=True,
advanced=True,
multiline=False,
value="{}",
)
extra_fields.extend((extra_field, extra_field2))
elif self.template.type_name == "Chroma":
# New bool field for persist parameter
@ -55,6 +66,29 @@ class VectorStoreFrontendNode(FrontendNode):
value="",
)
extra_fields.extend((extra_field, extra_field2))
elif self.template.type_name == "FAISS":
extra_field = TemplateField(
name="folder_path",
field_type="str",
required=False,
placeholder="",
show=True,
advanced=True,
multiline=False,
display_name="Local Path",
value="",
)
extra_field2 = TemplateField(
name="index_name",
field_type="str",
required=False,
show=True,
advanced=False,
value="",
display_name="Index Name",
)
extra_fields.extend((extra_field, extra_field2))
if extra_fields:
for field in extra_fields:
self.template.add_field(field)
@ -76,6 +110,7 @@ class VectorStoreFrontendNode(FrontendNode):
"weaviate_url",
"index_name",
"namespace",
"folder_path",
]
advanced_fields = [
"n_dim",
@ -94,6 +129,7 @@ class VectorStoreFrontendNode(FrontendNode):
"grpc_port",
"pinecone_api_key",
"pinecone_env",
"client_kwargs",
]
# Check and set field attributes