Initial support for Azure LLM nodes. (#443)

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-06-06 10:14:31 -03:00 committed by GitHub
commit 94b346196b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 98 additions and 2 deletions

View file

@ -51,6 +51,7 @@ embeddings:
llms:
- OpenAI
# - AzureOpenAI
# - AzureChatOpenAI
- ChatOpenAI
- LlamaCpp
- CTransformers

View file

@ -11,14 +11,15 @@ from langchain import (
text_splitter,
)
from langchain.agents import agent_toolkits
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain.chat_models import ChatAnthropic
from langchain.chat_models import ChatOpenAI
from langflow.interface.importing.utils import import_class
## LLMs
llm_type_to_cls_dict = llms.type_to_cls_dict
llm_type_to_cls_dict["anthropic-chat"] = ChatAnthropic # type: ignore
llm_type_to_cls_dict["azure-chat"] = AzureChatOpenAI # type: ignore
llm_type_to_cls_dict["openai-chat"] = ChatOpenAI # type: ignore
## Chains

View file

@ -12,6 +12,18 @@ class LLMFrontendNode(FrontendNode):
field.name.title().replace("Openai", "OpenAI").replace("_", " ")
).replace("Api", "API")
@staticmethod
def format_azure_field(field: TemplateField):
if field.name == "model_name":
field.show = False # Azure uses deployment_name instead of model_name.
if field.name == "openai_api_type":
field.show = False
field.password = False
field.value = "azure"
if field.name == "openai_api_version":
field.password = False
field.value = "2023-03-15-preview"
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
display_names_dict = {
@ -43,8 +55,16 @@ class LLMFrontendNode(FrontendNode):
field.field_type = "code"
field.advanced = True
field.show = True
elif field.name in ["model_name", "temperature", "model_file", "model_type"]:
elif field.name in [
"model_name",
"temperature",
"model_file",
"model_type",
"deployment_name",
]:
field.advanced = False
field.show = True
LLMFrontendNode.format_openai_field(field)
if "azure" in name.lower():
LLMFrontendNode.format_azure_field(field)

View file

@ -482,3 +482,77 @@ def test_chat_open_ai(client: TestClient):
"ChatOpenAI",
"BaseLanguageModel",
}
def test_azure_open_ai(client: TestClient):
response = client.get("/all")
assert response.status_code == 200
json_response = response.json()
language_models = json_response["llms"]
model = language_models["AzureOpenAI"]
template = model["template"]
assert template["model_name"].show is False
assert template["deployment_name"] == {
"required": False,
"placeholder": "",
"show": True,
"multiline": False,
"value": "",
"password": False,
"name": "deployment_name",
"advanced": False,
"type": "str",
"list": False,
}
def test_azure_chat_open_ai(client: TestClient):
response = client.get("/all")
assert response.status_code == 200
json_response = response.json()
language_models = json_response["llms"]
model = language_models["AzureChatOpenAI"]
template = model["template"]
assert template["model_name"].show is False
assert template["deployment_name"] == {
"required": False,
"placeholder": "",
"show": True,
"multiline": False,
"value": "",
"password": False,
"name": "deployment_name",
"advanced": False,
"type": "str",
"list": False,
}
assert template["openai_api_type"] == {
"required": False,
"placeholder": "",
"show": False,
"multiline": False,
"value": "azure",
"password": False,
"name": "openai_api_type",
"display_name": "OpenAI API Type",
"advanced": False,
"type": "str",
"list": False,
}
assert template["openai_api_version"] == {
"required": False,
"placeholder": "",
"show": True,
"multiline": False,
"value": "2023-03-15-preview",
"password": False,
"name": "openai_api_version",
"display_name": "OpenAI API Version",
"advanced": False,
"type": "str",
"list": False,
}