diff --git a/src/backend/langflow/config.yaml b/src/backend/langflow/config.yaml index 963c29549..effeeb001 100644 --- a/src/backend/langflow/config.yaml +++ b/src/backend/langflow/config.yaml @@ -51,6 +51,7 @@ embeddings: llms: - OpenAI # - AzureOpenAI + # - AzureChatOpenAI - ChatOpenAI - LlamaCpp - CTransformers diff --git a/src/backend/langflow/interface/custom_lists.py b/src/backend/langflow/interface/custom_lists.py index a944363ae..34bc0103e 100644 --- a/src/backend/langflow/interface/custom_lists.py +++ b/src/backend/langflow/interface/custom_lists.py @@ -11,14 +11,15 @@ from langchain import ( text_splitter, ) from langchain.agents import agent_toolkits +from langchain.chat_models import AzureChatOpenAI, ChatOpenAI from langchain.chat_models import ChatAnthropic -from langchain.chat_models import ChatOpenAI from langflow.interface.importing.utils import import_class ## LLMs llm_type_to_cls_dict = llms.type_to_cls_dict llm_type_to_cls_dict["anthropic-chat"] = ChatAnthropic # type: ignore +llm_type_to_cls_dict["azure-chat"] = AzureChatOpenAI # type: ignore llm_type_to_cls_dict["openai-chat"] = ChatOpenAI # type: ignore ## Chains diff --git a/src/backend/langflow/template/frontend_node/llms.py b/src/backend/langflow/template/frontend_node/llms.py index 272770e2e..39e82422f 100644 --- a/src/backend/langflow/template/frontend_node/llms.py +++ b/src/backend/langflow/template/frontend_node/llms.py @@ -12,6 +12,18 @@ class LLMFrontendNode(FrontendNode): field.name.title().replace("Openai", "OpenAI").replace("_", " ") ).replace("Api", "API") + @staticmethod + def format_azure_field(field: TemplateField): + if field.name == "model_name": + field.show = False # Azure uses deployment_name instead of model_name. + if field.name == "openai_api_type": + field.show = False + field.password = False + field.value = "azure" + if field.name == "openai_api_version": + field.password = False + field.value = "2023-03-15-preview" + @staticmethod def format_field(field: TemplateField, name: Optional[str] = None) -> None: display_names_dict = { @@ -43,8 +55,16 @@ class LLMFrontendNode(FrontendNode): field.field_type = "code" field.advanced = True field.show = True - elif field.name in ["model_name", "temperature", "model_file", "model_type"]: + elif field.name in [ + "model_name", + "temperature", + "model_file", + "model_type", + "deployment_name", + ]: field.advanced = False field.show = True LLMFrontendNode.format_openai_field(field) + if "azure" in name.lower(): + LLMFrontendNode.format_azure_field(field) diff --git a/tests/test_llms_template.py b/tests/test_llms_template.py index ccf2f6388..f54b452f1 100644 --- a/tests/test_llms_template.py +++ b/tests/test_llms_template.py @@ -482,3 +482,77 @@ def test_chat_open_ai(client: TestClient): "ChatOpenAI", "BaseLanguageModel", } + + +def test_azure_open_ai(client: TestClient): + response = client.get("/all") + assert response.status_code == 200 + json_response = response.json() + language_models = json_response["llms"] + + model = language_models["AzureOpenAI"] + template = model["template"] + + assert template["model_name"].show is False + assert template["deployment_name"] == { + "required": False, + "placeholder": "", + "show": True, + "multiline": False, + "value": "", + "password": False, + "name": "deployment_name", + "advanced": False, + "type": "str", + "list": False, + } + + +def test_azure_chat_open_ai(client: TestClient): + response = client.get("/all") + assert response.status_code == 200 + json_response = response.json() + language_models = json_response["llms"] + + model = language_models["AzureChatOpenAI"] + template = model["template"] + + assert template["model_name"].show is False + assert template["deployment_name"] == { + "required": False, + "placeholder": "", + "show": True, + "multiline": False, + "value": "", + "password": False, + "name": "deployment_name", + "advanced": False, + "type": "str", + "list": False, + } + assert template["openai_api_type"] == { + "required": False, + "placeholder": "", + "show": False, + "multiline": False, + "value": "azure", + "password": False, + "name": "openai_api_type", + "display_name": "OpenAI API Type", + "advanced": False, + "type": "str", + "list": False, + } + assert template["openai_api_version"] == { + "required": False, + "placeholder": "", + "show": True, + "multiline": False, + "value": "2023-03-15-preview", + "password": False, + "name": "openai_api_version", + "display_name": "OpenAI API Version", + "advanced": False, + "type": "str", + "list": False, + }