From d89f8e8e186b8e47c0d494d68e4c1ed696b63119 Mon Sep 17 00:00:00 2001 From: Raphael Valdetaro <79842132+raphaelchristi@users.noreply.github.com> Date: Thu, 16 Jan 2025 15:37:26 -0300 Subject: [PATCH] refactor: Refactor Wikipedia API component (#5432) * refactor(wikipedia): Refactor Wikipedia API component * test: add unit tests for WikipediaAPIComponent * [autofix.ci] apply automated fixes * refactor: improve WikipediaAPIComponent tests and fix lint issues * [autofix.ci] apply automated fixes * fix: resolve lint issues in WikipediaAPIComponent tests --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Edwin Jose --- .../components/tools/wikipedia_api.py | 28 +++--- .../components/tools/test_wikipedia_api.py | 85 +++++++++++++++++++ 2 files changed, 103 insertions(+), 10 deletions(-) create mode 100644 src/backend/tests/unit/components/tools/test_wikipedia_api.py diff --git a/src/backend/base/langflow/components/tools/wikipedia_api.py b/src/backend/base/langflow/components/tools/wikipedia_api.py index 623378d24..32adaa6b9 100644 --- a/src/backend/base/langflow/components/tools/wikipedia_api.py +++ b/src/backend/base/langflow/components/tools/wikipedia_api.py @@ -1,15 +1,13 @@ -from typing import cast - -from langchain_community.tools import WikipediaQueryRun from langchain_community.utilities.wikipedia import WikipediaAPIWrapper -from langflow.base.langchain_utilities.model import LCToolComponent -from langflow.field_typing import Tool +from langflow.custom import Component from langflow.inputs import BoolInput, IntInput, MessageTextInput, MultilineInput +from langflow.io import Output from langflow.schema import Data +from langflow.schema.message import Message -class WikipediaAPIComponent(LCToolComponent): +class WikipediaAPIComponent(Component): display_name = "Wikipedia API" description = "Call Wikipedia API." name = "WikipediaAPI" @@ -19,6 +17,7 @@ class WikipediaAPIComponent(LCToolComponent): MultilineInput( name="input_value", display_name="Input", + tool_mode=True, ), MessageTextInput(name="lang", display_name="Language", value="en"), IntInput(name="k", display_name="Number of results", value=4, required=True), @@ -28,16 +27,25 @@ class WikipediaAPIComponent(LCToolComponent): ), ] - def run_model(self) -> list[Data]: + outputs = [ + Output(display_name="Data", name="data", method="fetch_content"), + Output(display_name="Text", name="text", method="fetch_content_text"), + ] + + def fetch_content(self) -> list[Data]: wrapper = self._build_wrapper() docs = wrapper.load(self.input_value) data = [Data.from_document(doc) for doc in docs] self.status = data return data - def build_tool(self) -> Tool: - wrapper = self._build_wrapper() - return cast("Tool", WikipediaQueryRun(api_wrapper=wrapper)) + def fetch_content_text(self) -> Message: + data = self.fetch_content() + result_string = "" + for item in data: + result_string += item.text + "\n" + self.status = result_string + return Message(text=result_string) def _build_wrapper(self) -> WikipediaAPIWrapper: return WikipediaAPIWrapper( diff --git a/src/backend/tests/unit/components/tools/test_wikipedia_api.py b/src/backend/tests/unit/components/tools/test_wikipedia_api.py new file mode 100644 index 000000000..fc331caba --- /dev/null +++ b/src/backend/tests/unit/components/tools/test_wikipedia_api.py @@ -0,0 +1,85 @@ +from unittest.mock import MagicMock + +import pytest +from langflow.components.tools import WikipediaAPIComponent +from langflow.custom import Component +from langflow.custom.utils import build_custom_component_template +from langflow.schema import Data +from langflow.schema.message import Message + + +def test_wikipedia_initialization(): + component = WikipediaAPIComponent() + assert component.display_name == "Wikipedia API" + assert component.description == "Call Wikipedia API." + assert component.icon == "Wikipedia" + + +def test_wikipedia_template(): + wikipedia = WikipediaAPIComponent() + component = Component(_code=wikipedia._code) + frontend_node, _ = build_custom_component_template(component) + + # Verify basic structure + assert isinstance(frontend_node, dict) + + # Verify inputs + assert "template" in frontend_node + input_names = [input_["name"] for input_ in frontend_node["template"].values() if isinstance(input_, dict)] + + expected_inputs = ["input_value", "lang", "k", "load_all_available_meta", "doc_content_chars_max"] + + for input_name in expected_inputs: + assert input_name in input_names + + +@pytest.fixture +def mock_wikipedia_wrapper(mocker): + return mocker.patch("langchain_community.utilities.wikipedia.WikipediaAPIWrapper") + + +def test_fetch_content(mock_wikipedia_wrapper): + component = WikipediaAPIComponent() + component.input_value = "test query" + component.k = 3 + component.lang = "en" + + # Mock the WikipediaAPIWrapper and its load method + mock_instance = MagicMock() + mock_wikipedia_wrapper.return_value = mock_instance + mock_doc = MagicMock() + mock_doc.page_content = "Test content" + mock_doc.metadata = {"source": "wikipedia", "title": "Test Page"} + mock_instance.load.return_value = [mock_doc] + + # Mock the _build_wrapper method to return our mock instance + component._build_wrapper = MagicMock(return_value=mock_instance) + + result = component.fetch_content() + + # Verify wrapper was built with correct params + component._build_wrapper.assert_called_once() + mock_instance.load.assert_called_once_with("test query") + assert isinstance(result, list) + assert len(result) == 1 + assert result[0].text == "Test content" + + +def test_fetch_content_text(): + component = WikipediaAPIComponent() + component.fetch_content = MagicMock(return_value=[Data(text="First result"), Data(text="Second result")]) + + result = component.fetch_content_text() + + assert isinstance(result, Message) + assert result.text == "First result\nSecond result\n" + + +def test_wikipedia_error_handling(): + component = WikipediaAPIComponent() + + # Mock _build_wrapper to raise exception + component._build_wrapper = MagicMock(side_effect=Exception("API Error")) + + with pytest.raises(Exception, match="API Error"): + component.fetch_content()