Initial support for Azure LLM nodes.

There are still some rough edges due to underlying langchain and
openai API limitations, e.g. hwchase17/langchain#3769 and
openai/openai-python#411. Notably, you can't use the Azure and
non-Azure node types in the same server, since there's global openai
configuration needed to choose between the two. So it's probably best
to still leave the Azure node types commented out in the default
config. But with this PR, if you uncomment those nodes and start the
server with OPENAI_API_TYPE=azure, you will have working Azure nodes.
This commit is contained in:
Jacob Lee 2023-06-05 08:56:44 -05:00
commit 5b28bbb795
4 changed files with 98 additions and 2 deletions

View file

@ -51,6 +51,7 @@ embeddings:
llms:
- OpenAI
# - AzureOpenAI
# - AzureChatOpenAI
- ChatOpenAI
- LlamaCpp
- CTransformers

View file

@ -11,14 +11,15 @@ from langchain import (
text_splitter,
)
from langchain.agents import agent_toolkits
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain.chat_models import ChatAnthropic
from langchain.chat_models import ChatOpenAI
from langflow.interface.importing.utils import import_class
## LLMs
llm_type_to_cls_dict = llms.type_to_cls_dict
llm_type_to_cls_dict["anthropic-chat"] = ChatAnthropic # type: ignore
llm_type_to_cls_dict["azure-chat"] = AzureChatOpenAI # type: ignore
llm_type_to_cls_dict["openai-chat"] = ChatOpenAI # type: ignore
## Chains

View file

@ -12,6 +12,18 @@ class LLMFrontendNode(FrontendNode):
field.name.title().replace("Openai", "OpenAI").replace("_", " ")
).replace("Api", "API")
@staticmethod
def format_azure_field(field: TemplateField):
if field.name == "model_name":
field.show = False # Azure uses deployment_name instead of model_name.
if field.name == "openai_api_type":
field.show = False
field.password = False
field.value = "azure"
if field.name == "openai_api_version":
field.password = False
field.value = "2023-03-15-preview"
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
display_names_dict = {
@ -43,8 +55,16 @@ class LLMFrontendNode(FrontendNode):
field.field_type = "code"
field.advanced = True
field.show = True
elif field.name in ["model_name", "temperature", "model_file", "model_type"]:
elif field.name in [
"model_name",
"temperature",
"model_file",
"model_type",
"deployment_name",
]:
field.advanced = False
field.show = True
LLMFrontendNode.format_openai_field(field)
if "azure" in name.lower():
LLMFrontendNode.format_azure_field(field)

View file

@ -482,3 +482,77 @@ def test_chat_open_ai(client: TestClient):
"ChatOpenAI",
"BaseLanguageModel",
}
def test_azure_open_ai(client: TestClient):
response = client.get("/all")
assert response.status_code == 200
json_response = response.json()
language_models = json_response["llms"]
model = language_models["AzureOpenAI"]
template = model["template"]
assert template["model_name"].show is False
assert template["deployment_name"] == {
"required": False,
"placeholder": "",
"show": True,
"multiline": False,
"value": "",
"password": False,
"name": "deployment_name",
"advanced": False,
"type": "str",
"list": False,
}
def test_azure_chat_open_ai(client: TestClient):
response = client.get("/all")
assert response.status_code == 200
json_response = response.json()
language_models = json_response["llms"]
model = language_models["AzureChatOpenAI"]
template = model["template"]
assert template["model_name"].show is False
assert template["deployment_name"] == {
"required": False,
"placeholder": "",
"show": True,
"multiline": False,
"value": "",
"password": False,
"name": "deployment_name",
"advanced": False,
"type": "str",
"list": False,
}
assert template["openai_api_type"] == {
"required": False,
"placeholder": "",
"show": False,
"multiline": False,
"value": "azure",
"password": False,
"name": "openai_api_type",
"display_name": "OpenAI API Type",
"advanced": False,
"type": "str",
"list": False,
}
assert template["openai_api_version"] == {
"required": False,
"placeholder": "",
"show": True,
"multiline": False,
"value": "2023-03-15-preview",
"password": False,
"name": "openai_api_version",
"display_name": "OpenAI API Version",
"advanced": False,
"type": "str",
"list": False,
}