fix(ollama): resolve model list loading issue and add Pytest for component testing (#3575)
* Commit to solve Model not loading issue The issue was that the url of the models: api/tags was not parsed correctly. It was having a // hence used urlencode to parse it properly. Th e correct apporach works only if the base_url is correct,i.e a valid ollama URL: for DS LF this must be a public ollama Server URL. * updated the component Ollama Component changed the get model to take in base url and the function will make the expected url for the model names. This makes the function better, than providing the model url as paramter. Added Pytest, 7 tests, 1 test excluded for future implememtstion: test_build_model_failure Make lint and Make format had touched multiple files * removed unwanted print statements removed unwanted print statements. make format, formatted a lot of .tsx files also * removed skipped tests * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
3222153938
commit
46a9789028
13 changed files with 3122 additions and 1371 deletions
|
|
@ -1,4 +1,5 @@
|
|||
from typing import Any
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
from langchain_community.chat_models import ChatOllama
|
||||
|
|
@ -41,8 +42,7 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
base_url_value = self.variables(base_url_value)
|
||||
elif not base_url_value:
|
||||
base_url_value = "http://localhost:11434"
|
||||
build_config["model_name"]["options"] = self.get_model(base_url_value + "/api/tags")
|
||||
|
||||
build_config["model_name"]["options"] = self.get_model(base_url_value)
|
||||
if field_name == "keep_alive_flag":
|
||||
if field_value == "Keep":
|
||||
build_config["keep_alive"]["value"] = "-1"
|
||||
|
|
@ -55,8 +55,9 @@ class ChatOllamaComponent(LCModelComponent):
|
|||
|
||||
return build_config
|
||||
|
||||
def get_model(self, url: str) -> list[str]:
|
||||
def get_model(self, base_url_value: str) -> list[str]:
|
||||
try:
|
||||
url = urljoin(base_url_value, "/api/tags")
|
||||
with httpx.Client() as client:
|
||||
response = client.get(url)
|
||||
response.raise_for_status()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,125 @@
|
|||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from langflow.components.models.OllamaModel import ChatOllamaComponent
|
||||
from langchain_community.chat_models.ollama import ChatOllama
|
||||
from urllib.parse import urljoin
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def component():
|
||||
return ChatOllamaComponent()
|
||||
|
||||
|
||||
@patch("httpx.Client.get")
|
||||
def test_get_model_success(mock_get, component):
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"models": [{"name": "model1"}, {"name": "model2"}]}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
base_url = "http://localhost:11434"
|
||||
|
||||
model_names = component.get_model(base_url)
|
||||
|
||||
expected_url = urljoin(base_url, "/api/tags")
|
||||
|
||||
mock_get.assert_called_once_with(expected_url)
|
||||
|
||||
assert model_names == ["model1", "model2"]
|
||||
|
||||
|
||||
@patch("httpx.Client.get")
|
||||
def test_get_model_failure(mock_get, component):
|
||||
# Mock the response for the HTTP GET request to raise an exception
|
||||
mock_get.side_effect = Exception("HTTP request failed")
|
||||
|
||||
url = "http://localhost:11434/api/tags"
|
||||
|
||||
# Assert that the ValueError is raised when an exception occurs
|
||||
with pytest.raises(ValueError, match="Could not retrieve models"):
|
||||
component.get_model(url)
|
||||
|
||||
|
||||
def test_update_build_config_mirostat_disabled(component):
|
||||
build_config = {
|
||||
"mirostat_eta": {"advanced": False, "value": 0.1},
|
||||
"mirostat_tau": {"advanced": False, "value": 5},
|
||||
}
|
||||
field_value = "Disabled"
|
||||
field_name = "mirostat"
|
||||
|
||||
updated_config = component.update_build_config(build_config, field_value, field_name)
|
||||
|
||||
assert updated_config["mirostat_eta"]["advanced"] is True
|
||||
assert updated_config["mirostat_tau"]["advanced"] is True
|
||||
assert updated_config["mirostat_eta"]["value"] is None
|
||||
assert updated_config["mirostat_tau"]["value"] is None
|
||||
|
||||
|
||||
def test_update_build_config_mirostat_enabled(component):
|
||||
build_config = {
|
||||
"mirostat_eta": {"advanced": False, "value": None},
|
||||
"mirostat_tau": {"advanced": False, "value": None},
|
||||
}
|
||||
field_value = "Mirostat 2.0"
|
||||
field_name = "mirostat"
|
||||
|
||||
updated_config = component.update_build_config(build_config, field_value, field_name)
|
||||
|
||||
assert updated_config["mirostat_eta"]["advanced"] is False
|
||||
assert updated_config["mirostat_tau"]["advanced"] is False
|
||||
assert updated_config["mirostat_eta"]["value"] == 0.2
|
||||
assert updated_config["mirostat_tau"]["value"] == 10
|
||||
|
||||
|
||||
@patch("httpx.Client.get")
|
||||
def test_update_build_config_model_name(mock_get, component):
|
||||
# Mock the response for the HTTP GET request
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"models": [{"name": "model1"}, {"name": "model2"}]}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
build_config = {
|
||||
"base_url": {"load_from_db": False, "value": None},
|
||||
"model_name": {"options": []},
|
||||
}
|
||||
field_value = None
|
||||
field_name = "model_name"
|
||||
|
||||
updated_config = component.update_build_config(build_config, field_value, field_name)
|
||||
|
||||
assert updated_config["model_name"]["options"] == ["model1", "model2"]
|
||||
|
||||
|
||||
def test_update_build_config_keep_alive(component):
|
||||
build_config = {"keep_alive": {"value": None, "advanced": False}}
|
||||
field_value = "Keep"
|
||||
field_name = "keep_alive_flag"
|
||||
|
||||
updated_config = component.update_build_config(build_config, field_value, field_name)
|
||||
assert updated_config["keep_alive"]["value"] == "-1"
|
||||
assert updated_config["keep_alive"]["advanced"] is True
|
||||
|
||||
field_value = "Immediately"
|
||||
updated_config = component.update_build_config(build_config, field_value, field_name)
|
||||
assert updated_config["keep_alive"]["value"] == "0"
|
||||
assert updated_config["keep_alive"]["advanced"] is True
|
||||
|
||||
|
||||
@patch(
|
||||
"langchain_community.chat_models.ChatOllama",
|
||||
return_value=ChatOllama(base_url="http://localhost:11434", model="llama3.1"),
|
||||
)
|
||||
def test_build_model(mock_chat_ollama, component):
|
||||
component.base_url = "http://localhost:11434"
|
||||
component.model_name = "llama3.1"
|
||||
component.mirostat = "Mirostat 2.0"
|
||||
component.mirostat_eta = 0.2 # Ensure this is set as a float
|
||||
component.mirostat_tau = 10.0 # Ensure this is set as a float
|
||||
component.temperature = 0.2
|
||||
component.verbose = True
|
||||
model = component.build_model()
|
||||
assert isinstance(model, ChatOllama)
|
||||
assert model.base_url == "http://localhost:11434"
|
||||
assert model.model == "llama3.1"
|
||||
4331
src/frontend/package-lock.json
generated
4331
src/frontend/package-lock.json
generated
File diff suppressed because it is too large
Load diff
|
|
@ -6,5 +6,5 @@ export default function getFieldTitle(
|
|||
): string {
|
||||
return template[templateField].display_name
|
||||
? template[templateField].display_name!
|
||||
: template[templateField].name ?? templateField;
|
||||
: (template[templateField].name ?? templateField);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -48,4 +48,4 @@ const CheckBoxDiv = ({
|
|||
</div>
|
||||
);
|
||||
|
||||
export { CheckBoxDiv, Checkbox };
|
||||
export { Checkbox, CheckBoxDiv };
|
||||
|
|
|
|||
|
|
@ -296,4 +296,4 @@ async function performStreamingRequest({
|
|||
}
|
||||
}
|
||||
|
||||
export { ApiInterceptor, api, performStreamingRequest };
|
||||
export { api, ApiInterceptor, performStreamingRequest };
|
||||
|
|
|
|||
|
|
@ -256,9 +256,9 @@ export default function IOFieldView({
|
|||
pagination={!left}
|
||||
rows={
|
||||
Array.isArray(flowPoolNode?.data?.artifacts)
|
||||
? flowPoolNode?.data?.artifacts?.map(
|
||||
? (flowPoolNode?.data?.artifacts?.map(
|
||||
(artifact) => artifact.data,
|
||||
) ?? []
|
||||
) ?? [])
|
||||
: [flowPoolNode?.data?.artifacts]
|
||||
}
|
||||
columnMode="union"
|
||||
|
|
|
|||
|
|
@ -131,8 +131,8 @@ export default function PromptModal({
|
|||
field_name = Array.isArray(
|
||||
apiReturn?.frontend_node?.custom_fields?.[""],
|
||||
)
|
||||
? apiReturn?.frontend_node?.custom_fields?.[""][0] ?? ""
|
||||
: apiReturn?.frontend_node?.custom_fields?.[""] ?? "";
|
||||
? (apiReturn?.frontend_node?.custom_fields?.[""][0] ?? "")
|
||||
: (apiReturn?.frontend_node?.custom_fields?.[""] ?? "");
|
||||
}
|
||||
if (apiReturn) {
|
||||
let inputVariables = apiReturn.input_variables ?? [];
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ export default function NodeToolbarComponent({
|
|||
function minimize() {
|
||||
if (isMinimal) {
|
||||
setShowState((show) => !show);
|
||||
setShowNode(data.showNode ?? true ? false : true);
|
||||
setShowNode((data.showNode ?? true) ? false : true);
|
||||
return;
|
||||
}
|
||||
setNoticeData({
|
||||
|
|
|
|||
|
|
@ -167,10 +167,10 @@ export default function ProfileSettingsPage(): JSX.Element {
|
|||
<GradientChooserComponent
|
||||
value={
|
||||
gradient == ""
|
||||
? userData?.profile_image ??
|
||||
? (userData?.profile_image ??
|
||||
gradients[
|
||||
parseInt(userData?.id ?? "", 30) % gradients.length
|
||||
]
|
||||
])
|
||||
: gradient
|
||||
}
|
||||
onChange={(value) => {
|
||||
|
|
|
|||
|
|
@ -43,10 +43,10 @@ const ProfileGradientFormComponent = ({
|
|||
<GradientChooserComponent
|
||||
value={
|
||||
gradient == ""
|
||||
? userData?.profile_image ??
|
||||
? (userData?.profile_image ??
|
||||
gradients[
|
||||
parseInt(userData?.id ?? "", 30) % gradients.length
|
||||
]
|
||||
])
|
||||
: gradient
|
||||
}
|
||||
onChange={(value) => {
|
||||
|
|
|
|||
|
|
@ -53,10 +53,10 @@ const ProfilePictureFormComponent = ({
|
|||
loading={isLoading || isFetching}
|
||||
value={
|
||||
profilePicture == ""
|
||||
? userData?.profile_image ??
|
||||
? (userData?.profile_image ??
|
||||
gradients[
|
||||
parseInt(userData?.id ?? "", 30) % gradients.length
|
||||
]
|
||||
])
|
||||
: profilePicture
|
||||
}
|
||||
onChange={(value) => {
|
||||
|
|
|
|||
|
|
@ -127,7 +127,7 @@ export default function StorePage(): JSX.Element {
|
|||
setTotalRowsCount(
|
||||
filteredCategories?.length === 0
|
||||
? Number(res?.count ?? 0)
|
||||
: res?.results?.length ?? 0,
|
||||
: (res?.results?.length ?? 0),
|
||||
);
|
||||
}
|
||||
})
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue