add support for VertexAIEmbeddings node

This commit is contained in:
Dave Morris 2023-08-08 17:35:50 -05:00
commit bb2b8fbb3d
8 changed files with 121 additions and 4 deletions

View file

@ -12,4 +12,3 @@ RUN rm *.whl
EXPOSE 80
CMD [ "uvicorn", "--host", "0.0.0.0", "--port", "7860", "--factory", "langflow.main:create_app" ]

View file

@ -88,7 +88,7 @@ def instantiate_based_on_type(class_object, base_type, node_type, params):
elif base_type == "toolkits":
return instantiate_toolkit(node_type, class_object, params)
elif base_type == "embeddings":
return instantiate_embedding(class_object, params)
return instantiate_embedding(node_type, class_object, params)
elif base_type == "vectorstores":
return instantiate_vectorstore(class_object, params)
elif base_type == "documentloaders":
@ -258,9 +258,13 @@ def instantiate_toolkit(node_type, class_object: Type[BaseToolkit], params: Dict
return loaded_toolkit
def instantiate_embedding(class_object, params: Dict):
def instantiate_embedding(node_type, class_object, params: Dict):
params.pop("model", None)
params.pop("headers", None)
if "VertexAI" in node_type:
return initialize_vertexai(class_object=class_object, params=params)
try:
return class_object(**params)
except ValidationError:

View file

@ -5,6 +5,47 @@ from langflow.template.frontend_node.base import FrontendNode
class EmbeddingFrontendNode(FrontendNode):
def add_extra_fields(self) -> None:
if "VertexAI" in self.template.type_name:
# Add credentials field which should of type file.
self.template.add_field(
TemplateField(
field_type="file",
required=False,
show=True,
name="credentials",
value="",
suffixes=[".json"],
file_types=["json"],
)
)
@staticmethod
def format_vertex_field(field: TemplateField, name: str):
if "VertexAI" in name:
advanced_fields = [
"verbose",
"top_p",
"top_k",
"max_output_tokens",
]
if field.name in advanced_fields:
field.advanced = True
show_fields = [
"verbose",
"project",
"location",
"credentials",
"max_output_tokens",
"model_name",
"temperature",
"top_p",
"top_k",
]
if field.name in show_fields:
field.show = True
@staticmethod
def format_jina_fields(field: TemplateField):
if "jina" in field.name:
@ -41,10 +82,36 @@ class EmbeddingFrontendNode(FrontendNode):
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
FrontendNode.format_field(field, name)
if name and "vertex" in name.lower():
EmbeddingFrontendNode.format_vertex_field(field, name)
field.advanced = not field.required
field.show = True
if field.name == "headers":
field.show = False
if field.name == "model_kwargs":
field.field_type = "code"
field.advanced = True
field.show = True
elif field.name in [
"model_name",
"temperature",
"model_file",
"model_type",
"deployment_name",
"credentials",
]:
field.advanced = False
field.show = True
if field.name == "credentials":
field.field_type = "file"
if name == "VertexAI" and field.name not in [
"callbacks",
"client",
"stop",
"tags",
"cache",
]:
field.show = True
# Format Jina fields
EmbeddingFrontendNode.format_jina_fields(field)

View file

@ -205,6 +205,7 @@ export const nodeIconsLucide = {
SupabaseVectorStore: SupabaseIcon,
VertexAI: VertexAIIcon,
ChatVertexAI: VertexAIIcon,
VertexAIEmbeddings: VertexAIIcon,
agents: Rocket,
WikipediaAPIWrapper: SvgWikipedia,
chains: Link,

View file

@ -5,6 +5,7 @@ const apiRoutes = ["^/api/v1/", "/health"];
// Use environment variable to determine the target.
const target = process.env.VITE_PROXY_TARGET || "http://127.0.0.1:7860";
const port = process.env.VITE_PROXY_PORT || 3000;
const proxyTargets = apiRoutes.reduce((proxyObj, route) => {
proxyObj[route] = {
@ -22,7 +23,7 @@ export default defineConfig(() => {
},
plugins: [react(), svgr()],
server: {
port: 3000,
port: port,
proxy: {
...proxyTargets,
},