From 06aea7da88c577f3c2194dbc0d06750d029adda7 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Fri, 30 Jun 2023 11:09:13 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=80=20chore(custom=5Flists.py):=20refo?= =?UTF-8?q?rmat=20import=20statements=20for=20better=20readability=20and?= =?UTF-8?q?=20maintainability=20=F0=9F=94=80=20chore(custom=5Flists.py):?= =?UTF-8?q?=20add=20ChatVertexAI=20to=20the=20import=20statements=20for=20?= =?UTF-8?q?better=20modularity=20and=20extensibility=20=F0=9F=94=80=20chor?= =?UTF-8?q?e(custom=5Flists.py):=20add=20ChatVertexAI=20to=20the=20llm=5Ft?= =?UTF-8?q?ype=5Fto=5Fcls=5Fdict=20for=20better=20compatibility=20and=20fl?= =?UTF-8?q?exibility=20=F0=9F=94=80=20chore(llms.py):=20change=20required?= =?UTF-8?q?=20field=20for=20credentials=20to=20be=20optional=20for=20bette?= =?UTF-8?q?r=20user=20experience=20=F0=9F=94=80=20chore(llms.py):=20add=20?= =?UTF-8?q?advanced=20and=20show=20fields=20for=20specific=20fields=20rela?= =?UTF-8?q?ted=20to=20VertexAI=20for=20better=20configurability?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The import statements in `custom_lists.py` have been reformatted to improve readability and maintainability. The `ChatVertexAI` class has been added to the import statements to enhance modularity and extensibility. The `ChatVertexAI` class has been added to the `llm_type_to_cls_dict` dictionary in `custom_lists.py` to improve compatibility and flexibility. In `llms.py`, the `required` field for the `credentials` field has been changed to be optional for a better user experience. The `advanced` and `show` fields have been added to specific fields related to VertexAI in `llms.py` to provide better configurability. --- .../langflow/interface/custom_lists.py | 9 +++-- .../langflow/template/frontend_node/llms.py | 34 +++++++++++++++++-- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/src/backend/langflow/interface/custom_lists.py b/src/backend/langflow/interface/custom_lists.py index 2f7f69f14..5a22d989f 100644 --- a/src/backend/langflow/interface/custom_lists.py +++ b/src/backend/langflow/interface/custom_lists.py @@ -10,8 +10,12 @@ from langchain import ( text_splitter, ) from langchain.agents import agent_toolkits -from langchain.chat_models import AzureChatOpenAI, ChatOpenAI -from langchain.chat_models import ChatAnthropic +from langchain.chat_models import ( + AzureChatOpenAI, + ChatOpenAI, + ChatVertexAI, + ChatAnthropic, +) from langflow.interface.importing.utils import import_class from langflow.interface.agents.custom import CUSTOM_AGENTS @@ -22,6 +26,7 @@ llm_type_to_cls_dict = llms.type_to_cls_dict llm_type_to_cls_dict["anthropic-chat"] = ChatAnthropic # type: ignore llm_type_to_cls_dict["azure-chat"] = AzureChatOpenAI # type: ignore llm_type_to_cls_dict["openai-chat"] = ChatOpenAI # type: ignore +llm_type_to_cls_dict["vertexai-chat"] = ChatVertexAI # type: ignore # Toolkits diff --git a/src/backend/langflow/template/frontend_node/llms.py b/src/backend/langflow/template/frontend_node/llms.py index c1080ea2d..7d2ab219d 100644 --- a/src/backend/langflow/template/frontend_node/llms.py +++ b/src/backend/langflow/template/frontend_node/llms.py @@ -7,12 +7,12 @@ from langflow.template.frontend_node.constants import OPENAI_API_BASE_INFO class LLMFrontendNode(FrontendNode): def add_extra_fields(self) -> None: - if self.template.type_name == "VertexAI": + if "VertexAI" in self.template.type_name: # Add credentials field which should of type file. self.template.add_field( TemplateField( field_type="file", - required=True, + required=False, show=True, name="credentials", value="", @@ -21,6 +21,34 @@ class LLMFrontendNode(FrontendNode): ) ) + @staticmethod + def format_vertex_field(field: TemplateField, name: str): + if "VertexAI" in name: + advanced_fields = [ + "tuned_model_name", + "verbose", + "top_p", + "top_k", + "max_output_tokens", + ] + if field.name in advanced_fields: + field.advanced = True + show_fields = [ + "tuned_model_name", + "verbose", + "project", + "location", + "credentials", + "max_output_tokens", + "model_name", + "temperature", + "top_p", + "top_k", + ] + + if field.name in show_fields: + field.show = True + @staticmethod def format_openai_field(field: TemplateField): if "openai" in field.name.lower(): @@ -61,6 +89,8 @@ class LLMFrontendNode(FrontendNode): LLMFrontendNode.format_azure_field(field) if name and "llama" in name.lower(): LLMFrontendNode.format_llama_field(field) + if name and "vertex" in name.lower(): + LLMFrontendNode.format_vertex_field(field, name) SHOW_FIELDS = ["repo_id"] if field.name in SHOW_FIELDS: field.show = True