diff --git a/src/backend/langflow/interface/custom_lists.py b/src/backend/langflow/interface/custom_lists.py index 2f7f69f14..5a22d989f 100644 --- a/src/backend/langflow/interface/custom_lists.py +++ b/src/backend/langflow/interface/custom_lists.py @@ -10,8 +10,12 @@ from langchain import ( text_splitter, ) from langchain.agents import agent_toolkits -from langchain.chat_models import AzureChatOpenAI, ChatOpenAI -from langchain.chat_models import ChatAnthropic +from langchain.chat_models import ( + AzureChatOpenAI, + ChatOpenAI, + ChatVertexAI, + ChatAnthropic, +) from langflow.interface.importing.utils import import_class from langflow.interface.agents.custom import CUSTOM_AGENTS @@ -22,6 +26,7 @@ llm_type_to_cls_dict = llms.type_to_cls_dict llm_type_to_cls_dict["anthropic-chat"] = ChatAnthropic # type: ignore llm_type_to_cls_dict["azure-chat"] = AzureChatOpenAI # type: ignore llm_type_to_cls_dict["openai-chat"] = ChatOpenAI # type: ignore +llm_type_to_cls_dict["vertexai-chat"] = ChatVertexAI # type: ignore # Toolkits diff --git a/src/backend/langflow/template/frontend_node/llms.py b/src/backend/langflow/template/frontend_node/llms.py index c1080ea2d..7d2ab219d 100644 --- a/src/backend/langflow/template/frontend_node/llms.py +++ b/src/backend/langflow/template/frontend_node/llms.py @@ -7,12 +7,12 @@ from langflow.template.frontend_node.constants import OPENAI_API_BASE_INFO class LLMFrontendNode(FrontendNode): def add_extra_fields(self) -> None: - if self.template.type_name == "VertexAI": + if "VertexAI" in self.template.type_name: # Add credentials field which should of type file. self.template.add_field( TemplateField( field_type="file", - required=True, + required=False, show=True, name="credentials", value="", @@ -21,6 +21,34 @@ class LLMFrontendNode(FrontendNode): ) ) + @staticmethod + def format_vertex_field(field: TemplateField, name: str): + if "VertexAI" in name: + advanced_fields = [ + "tuned_model_name", + "verbose", + "top_p", + "top_k", + "max_output_tokens", + ] + if field.name in advanced_fields: + field.advanced = True + show_fields = [ + "tuned_model_name", + "verbose", + "project", + "location", + "credentials", + "max_output_tokens", + "model_name", + "temperature", + "top_p", + "top_k", + ] + + if field.name in show_fields: + field.show = True + @staticmethod def format_openai_field(field: TemplateField): if "openai" in field.name.lower(): @@ -61,6 +89,8 @@ class LLMFrontendNode(FrontendNode): LLMFrontendNode.format_azure_field(field) if name and "llama" in name.lower(): LLMFrontendNode.format_llama_field(field) + if name and "vertex" in name.lower(): + LLMFrontendNode.format_vertex_field(field, name) SHOW_FIELDS = ["repo_id"] if field.name in SHOW_FIELDS: field.show = True