Initial support for Azure LLM nodes. (#443)
This commit is contained in:
commit
94b346196b
4 changed files with 98 additions and 2 deletions
|
|
@ -51,6 +51,7 @@ embeddings:
|
|||
llms:
|
||||
- OpenAI
|
||||
# - AzureOpenAI
|
||||
# - AzureChatOpenAI
|
||||
- ChatOpenAI
|
||||
- LlamaCpp
|
||||
- CTransformers
|
||||
|
|
|
|||
|
|
@ -11,14 +11,15 @@ from langchain import (
|
|||
text_splitter,
|
||||
)
|
||||
from langchain.agents import agent_toolkits
|
||||
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
|
||||
from langchain.chat_models import ChatAnthropic
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
|
||||
from langflow.interface.importing.utils import import_class
|
||||
|
||||
## LLMs
|
||||
llm_type_to_cls_dict = llms.type_to_cls_dict
|
||||
llm_type_to_cls_dict["anthropic-chat"] = ChatAnthropic # type: ignore
|
||||
llm_type_to_cls_dict["azure-chat"] = AzureChatOpenAI # type: ignore
|
||||
llm_type_to_cls_dict["openai-chat"] = ChatOpenAI # type: ignore
|
||||
|
||||
## Chains
|
||||
|
|
|
|||
|
|
@ -12,6 +12,18 @@ class LLMFrontendNode(FrontendNode):
|
|||
field.name.title().replace("Openai", "OpenAI").replace("_", " ")
|
||||
).replace("Api", "API")
|
||||
|
||||
@staticmethod
|
||||
def format_azure_field(field: TemplateField):
|
||||
if field.name == "model_name":
|
||||
field.show = False # Azure uses deployment_name instead of model_name.
|
||||
if field.name == "openai_api_type":
|
||||
field.show = False
|
||||
field.password = False
|
||||
field.value = "azure"
|
||||
if field.name == "openai_api_version":
|
||||
field.password = False
|
||||
field.value = "2023-03-15-preview"
|
||||
|
||||
@staticmethod
|
||||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
display_names_dict = {
|
||||
|
|
@ -43,8 +55,16 @@ class LLMFrontendNode(FrontendNode):
|
|||
field.field_type = "code"
|
||||
field.advanced = True
|
||||
field.show = True
|
||||
elif field.name in ["model_name", "temperature", "model_file", "model_type"]:
|
||||
elif field.name in [
|
||||
"model_name",
|
||||
"temperature",
|
||||
"model_file",
|
||||
"model_type",
|
||||
"deployment_name",
|
||||
]:
|
||||
field.advanced = False
|
||||
field.show = True
|
||||
|
||||
LLMFrontendNode.format_openai_field(field)
|
||||
if "azure" in name.lower():
|
||||
LLMFrontendNode.format_azure_field(field)
|
||||
|
|
|
|||
|
|
@ -482,3 +482,77 @@ def test_chat_open_ai(client: TestClient):
|
|||
"ChatOpenAI",
|
||||
"BaseLanguageModel",
|
||||
}
|
||||
|
||||
|
||||
def test_azure_open_ai(client: TestClient):
|
||||
response = client.get("/all")
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
language_models = json_response["llms"]
|
||||
|
||||
model = language_models["AzureOpenAI"]
|
||||
template = model["template"]
|
||||
|
||||
assert template["model_name"].show is False
|
||||
assert template["deployment_name"] == {
|
||||
"required": False,
|
||||
"placeholder": "",
|
||||
"show": True,
|
||||
"multiline": False,
|
||||
"value": "",
|
||||
"password": False,
|
||||
"name": "deployment_name",
|
||||
"advanced": False,
|
||||
"type": "str",
|
||||
"list": False,
|
||||
}
|
||||
|
||||
|
||||
def test_azure_chat_open_ai(client: TestClient):
|
||||
response = client.get("/all")
|
||||
assert response.status_code == 200
|
||||
json_response = response.json()
|
||||
language_models = json_response["llms"]
|
||||
|
||||
model = language_models["AzureChatOpenAI"]
|
||||
template = model["template"]
|
||||
|
||||
assert template["model_name"].show is False
|
||||
assert template["deployment_name"] == {
|
||||
"required": False,
|
||||
"placeholder": "",
|
||||
"show": True,
|
||||
"multiline": False,
|
||||
"value": "",
|
||||
"password": False,
|
||||
"name": "deployment_name",
|
||||
"advanced": False,
|
||||
"type": "str",
|
||||
"list": False,
|
||||
}
|
||||
assert template["openai_api_type"] == {
|
||||
"required": False,
|
||||
"placeholder": "",
|
||||
"show": False,
|
||||
"multiline": False,
|
||||
"value": "azure",
|
||||
"password": False,
|
||||
"name": "openai_api_type",
|
||||
"display_name": "OpenAI API Type",
|
||||
"advanced": False,
|
||||
"type": "str",
|
||||
"list": False,
|
||||
}
|
||||
assert template["openai_api_version"] == {
|
||||
"required": False,
|
||||
"placeholder": "",
|
||||
"show": True,
|
||||
"multiline": False,
|
||||
"value": "2023-03-15-preview",
|
||||
"password": False,
|
||||
"name": "openai_api_version",
|
||||
"display_name": "OpenAI API Version",
|
||||
"advanced": False,
|
||||
"type": "str",
|
||||
"list": False,
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue