fix(ollama): resolve model list loading issue and add Pytest for component testing (#3575)

* Commit to solve Model not loading issue

The issue was that the url of the models: api/tags was not parsed correctly.
It was having a // hence used urlencode to parse it properly.

Th e correct apporach works only if the base_url is correct,i.e a valid ollama URL:
for DS LF this must be a public ollama Server URL.

* updated the component Ollama Component

changed the get model to take in base url and the function will make the expected url for the model names. This makes the function better, than providing the model url as paramter.

Added Pytest, 7 tests, 1 test excluded for future implememtstion: test_build_model_failure

Make lint and Make format had touched multiple files

* removed unwanted print statements

removed unwanted print statements.

make format, formatted a lot of .tsx files also

* removed skipped tests

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Edwin Jose 2024-08-27 19:21:06 -04:00 committed by GitHub
commit 46a9789028
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 3122 additions and 1371 deletions

View file

@ -1,4 +1,5 @@
from typing import Any
from urllib.parse import urljoin
import httpx
from langchain_community.chat_models import ChatOllama
@ -41,8 +42,7 @@ class ChatOllamaComponent(LCModelComponent):
base_url_value = self.variables(base_url_value)
elif not base_url_value:
base_url_value = "http://localhost:11434"
build_config["model_name"]["options"] = self.get_model(base_url_value + "/api/tags")
build_config["model_name"]["options"] = self.get_model(base_url_value)
if field_name == "keep_alive_flag":
if field_value == "Keep":
build_config["keep_alive"]["value"] = "-1"
@ -55,8 +55,9 @@ class ChatOllamaComponent(LCModelComponent):
return build_config
def get_model(self, url: str) -> list[str]:
def get_model(self, base_url_value: str) -> list[str]:
try:
url = urljoin(base_url_value, "/api/tags")
with httpx.Client() as client:
response = client.get(url)
response.raise_for_status()

View file

@ -0,0 +1,125 @@
import pytest
from unittest.mock import patch, MagicMock
from langflow.components.models.OllamaModel import ChatOllamaComponent
from langchain_community.chat_models.ollama import ChatOllama
from urllib.parse import urljoin
@pytest.fixture
def component():
return ChatOllamaComponent()
@patch("httpx.Client.get")
def test_get_model_success(mock_get, component):
mock_response = MagicMock()
mock_response.json.return_value = {"models": [{"name": "model1"}, {"name": "model2"}]}
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response
base_url = "http://localhost:11434"
model_names = component.get_model(base_url)
expected_url = urljoin(base_url, "/api/tags")
mock_get.assert_called_once_with(expected_url)
assert model_names == ["model1", "model2"]
@patch("httpx.Client.get")
def test_get_model_failure(mock_get, component):
# Mock the response for the HTTP GET request to raise an exception
mock_get.side_effect = Exception("HTTP request failed")
url = "http://localhost:11434/api/tags"
# Assert that the ValueError is raised when an exception occurs
with pytest.raises(ValueError, match="Could not retrieve models"):
component.get_model(url)
def test_update_build_config_mirostat_disabled(component):
build_config = {
"mirostat_eta": {"advanced": False, "value": 0.1},
"mirostat_tau": {"advanced": False, "value": 5},
}
field_value = "Disabled"
field_name = "mirostat"
updated_config = component.update_build_config(build_config, field_value, field_name)
assert updated_config["mirostat_eta"]["advanced"] is True
assert updated_config["mirostat_tau"]["advanced"] is True
assert updated_config["mirostat_eta"]["value"] is None
assert updated_config["mirostat_tau"]["value"] is None
def test_update_build_config_mirostat_enabled(component):
build_config = {
"mirostat_eta": {"advanced": False, "value": None},
"mirostat_tau": {"advanced": False, "value": None},
}
field_value = "Mirostat 2.0"
field_name = "mirostat"
updated_config = component.update_build_config(build_config, field_value, field_name)
assert updated_config["mirostat_eta"]["advanced"] is False
assert updated_config["mirostat_tau"]["advanced"] is False
assert updated_config["mirostat_eta"]["value"] == 0.2
assert updated_config["mirostat_tau"]["value"] == 10
@patch("httpx.Client.get")
def test_update_build_config_model_name(mock_get, component):
# Mock the response for the HTTP GET request
mock_response = MagicMock()
mock_response.json.return_value = {"models": [{"name": "model1"}, {"name": "model2"}]}
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response
build_config = {
"base_url": {"load_from_db": False, "value": None},
"model_name": {"options": []},
}
field_value = None
field_name = "model_name"
updated_config = component.update_build_config(build_config, field_value, field_name)
assert updated_config["model_name"]["options"] == ["model1", "model2"]
def test_update_build_config_keep_alive(component):
build_config = {"keep_alive": {"value": None, "advanced": False}}
field_value = "Keep"
field_name = "keep_alive_flag"
updated_config = component.update_build_config(build_config, field_value, field_name)
assert updated_config["keep_alive"]["value"] == "-1"
assert updated_config["keep_alive"]["advanced"] is True
field_value = "Immediately"
updated_config = component.update_build_config(build_config, field_value, field_name)
assert updated_config["keep_alive"]["value"] == "0"
assert updated_config["keep_alive"]["advanced"] is True
@patch(
"langchain_community.chat_models.ChatOllama",
return_value=ChatOllama(base_url="http://localhost:11434", model="llama3.1"),
)
def test_build_model(mock_chat_ollama, component):
component.base_url = "http://localhost:11434"
component.model_name = "llama3.1"
component.mirostat = "Mirostat 2.0"
component.mirostat_eta = 0.2 # Ensure this is set as a float
component.mirostat_tau = 10.0 # Ensure this is set as a float
component.temperature = 0.2
component.verbose = True
model = component.build_model()
assert isinstance(model, ChatOllama)
assert model.base_url == "http://localhost:11434"
assert model.model == "llama3.1"

File diff suppressed because it is too large Load diff

View file

@ -6,5 +6,5 @@ export default function getFieldTitle(
): string {
return template[templateField].display_name
? template[templateField].display_name!
: template[templateField].name ?? templateField;
: (template[templateField].name ?? templateField);
}

View file

@ -48,4 +48,4 @@ const CheckBoxDiv = ({
</div>
);
export { CheckBoxDiv, Checkbox };
export { Checkbox, CheckBoxDiv };

View file

@ -296,4 +296,4 @@ async function performStreamingRequest({
}
}
export { ApiInterceptor, api, performStreamingRequest };
export { api, ApiInterceptor, performStreamingRequest };

View file

@ -256,9 +256,9 @@ export default function IOFieldView({
pagination={!left}
rows={
Array.isArray(flowPoolNode?.data?.artifacts)
? flowPoolNode?.data?.artifacts?.map(
? (flowPoolNode?.data?.artifacts?.map(
(artifact) => artifact.data,
) ?? []
) ?? [])
: [flowPoolNode?.data?.artifacts]
}
columnMode="union"

View file

@ -131,8 +131,8 @@ export default function PromptModal({
field_name = Array.isArray(
apiReturn?.frontend_node?.custom_fields?.[""],
)
? apiReturn?.frontend_node?.custom_fields?.[""][0] ?? ""
: apiReturn?.frontend_node?.custom_fields?.[""] ?? "";
? (apiReturn?.frontend_node?.custom_fields?.[""][0] ?? "")
: (apiReturn?.frontend_node?.custom_fields?.[""] ?? "");
}
if (apiReturn) {
let inputVariables = apiReturn.input_variables ?? [];

View file

@ -81,7 +81,7 @@ export default function NodeToolbarComponent({
function minimize() {
if (isMinimal) {
setShowState((show) => !show);
setShowNode(data.showNode ?? true ? false : true);
setShowNode((data.showNode ?? true) ? false : true);
return;
}
setNoticeData({

View file

@ -167,10 +167,10 @@ export default function ProfileSettingsPage(): JSX.Element {
<GradientChooserComponent
value={
gradient == ""
? userData?.profile_image ??
? (userData?.profile_image ??
gradients[
parseInt(userData?.id ?? "", 30) % gradients.length
]
])
: gradient
}
onChange={(value) => {

View file

@ -43,10 +43,10 @@ const ProfileGradientFormComponent = ({
<GradientChooserComponent
value={
gradient == ""
? userData?.profile_image ??
? (userData?.profile_image ??
gradients[
parseInt(userData?.id ?? "", 30) % gradients.length
]
])
: gradient
}
onChange={(value) => {

View file

@ -53,10 +53,10 @@ const ProfilePictureFormComponent = ({
loading={isLoading || isFetching}
value={
profilePicture == ""
? userData?.profile_image ??
? (userData?.profile_image ??
gradients[
parseInt(userData?.id ?? "", 30) % gradients.length
]
])
: profilePicture
}
onChange={(value) => {

View file

@ -127,7 +127,7 @@ export default function StorePage(): JSX.Element {
setTotalRowsCount(
filteredCategories?.length === 0
? Number(res?.count ?? 0)
: res?.results?.length ?? 0,
: (res?.results?.length ?? 0),
);
}
})