add support for VertexAIEmbeddings node
This commit is contained in:
parent
35af73822a
commit
bb2b8fbb3d
8 changed files with 121 additions and 4 deletions
|
|
@ -12,4 +12,3 @@ RUN rm *.whl
|
|||
EXPOSE 80
|
||||
|
||||
CMD [ "uvicorn", "--host", "0.0.0.0", "--port", "7860", "--factory", "langflow.main:create_app" ]
|
||||
|
||||
|
|
|
|||
|
|
@ -88,7 +88,7 @@ def instantiate_based_on_type(class_object, base_type, node_type, params):
|
|||
elif base_type == "toolkits":
|
||||
return instantiate_toolkit(node_type, class_object, params)
|
||||
elif base_type == "embeddings":
|
||||
return instantiate_embedding(class_object, params)
|
||||
return instantiate_embedding(node_type, class_object, params)
|
||||
elif base_type == "vectorstores":
|
||||
return instantiate_vectorstore(class_object, params)
|
||||
elif base_type == "documentloaders":
|
||||
|
|
@ -258,9 +258,13 @@ def instantiate_toolkit(node_type, class_object: Type[BaseToolkit], params: Dict
|
|||
return loaded_toolkit
|
||||
|
||||
|
||||
def instantiate_embedding(class_object, params: Dict):
|
||||
def instantiate_embedding(node_type, class_object, params: Dict):
|
||||
params.pop("model", None)
|
||||
params.pop("headers", None)
|
||||
|
||||
if "VertexAI" in node_type:
|
||||
return initialize_vertexai(class_object=class_object, params=params)
|
||||
|
||||
try:
|
||||
return class_object(**params)
|
||||
except ValidationError:
|
||||
|
|
|
|||
|
|
@ -5,6 +5,47 @@ from langflow.template.frontend_node.base import FrontendNode
|
|||
|
||||
|
||||
class EmbeddingFrontendNode(FrontendNode):
|
||||
def add_extra_fields(self) -> None:
|
||||
if "VertexAI" in self.template.type_name:
|
||||
# Add credentials field which should of type file.
|
||||
self.template.add_field(
|
||||
TemplateField(
|
||||
field_type="file",
|
||||
required=False,
|
||||
show=True,
|
||||
name="credentials",
|
||||
value="",
|
||||
suffixes=[".json"],
|
||||
file_types=["json"],
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def format_vertex_field(field: TemplateField, name: str):
|
||||
if "VertexAI" in name:
|
||||
advanced_fields = [
|
||||
"verbose",
|
||||
"top_p",
|
||||
"top_k",
|
||||
"max_output_tokens",
|
||||
]
|
||||
if field.name in advanced_fields:
|
||||
field.advanced = True
|
||||
show_fields = [
|
||||
"verbose",
|
||||
"project",
|
||||
"location",
|
||||
"credentials",
|
||||
"max_output_tokens",
|
||||
"model_name",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
]
|
||||
|
||||
if field.name in show_fields:
|
||||
field.show = True
|
||||
|
||||
@staticmethod
|
||||
def format_jina_fields(field: TemplateField):
|
||||
if "jina" in field.name:
|
||||
|
|
@ -41,10 +82,36 @@ class EmbeddingFrontendNode(FrontendNode):
|
|||
@staticmethod
|
||||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
FrontendNode.format_field(field, name)
|
||||
if name and "vertex" in name.lower():
|
||||
EmbeddingFrontendNode.format_vertex_field(field, name)
|
||||
field.advanced = not field.required
|
||||
field.show = True
|
||||
if field.name == "headers":
|
||||
field.show = False
|
||||
if field.name == "model_kwargs":
|
||||
field.field_type = "code"
|
||||
field.advanced = True
|
||||
field.show = True
|
||||
elif field.name in [
|
||||
"model_name",
|
||||
"temperature",
|
||||
"model_file",
|
||||
"model_type",
|
||||
"deployment_name",
|
||||
"credentials",
|
||||
]:
|
||||
field.advanced = False
|
||||
field.show = True
|
||||
if field.name == "credentials":
|
||||
field.field_type = "file"
|
||||
if name == "VertexAI" and field.name not in [
|
||||
"callbacks",
|
||||
"client",
|
||||
"stop",
|
||||
"tags",
|
||||
"cache",
|
||||
]:
|
||||
field.show = True
|
||||
|
||||
# Format Jina fields
|
||||
EmbeddingFrontendNode.format_jina_fields(field)
|
||||
|
|
|
|||
|
|
@ -205,6 +205,7 @@ export const nodeIconsLucide = {
|
|||
SupabaseVectorStore: SupabaseIcon,
|
||||
VertexAI: VertexAIIcon,
|
||||
ChatVertexAI: VertexAIIcon,
|
||||
VertexAIEmbeddings: VertexAIIcon,
|
||||
agents: Rocket,
|
||||
WikipediaAPIWrapper: SvgWikipedia,
|
||||
chains: Link,
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ const apiRoutes = ["^/api/v1/", "/health"];
|
|||
|
||||
// Use environment variable to determine the target.
|
||||
const target = process.env.VITE_PROXY_TARGET || "http://127.0.0.1:7860";
|
||||
const port = process.env.VITE_PROXY_PORT || 3000;
|
||||
|
||||
const proxyTargets = apiRoutes.reduce((proxyObj, route) => {
|
||||
proxyObj[route] = {
|
||||
|
|
@ -22,7 +23,7 @@ export default defineConfig(() => {
|
|||
},
|
||||
plugins: [react(), svgr()],
|
||||
server: {
|
||||
port: 3000,
|
||||
port: port,
|
||||
proxy: {
|
||||
...proxyTargets,
|
||||
},
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue