From bd8dbdfab61f0cc2ad157fed3754fee44f6ac7a9 Mon Sep 17 00:00:00 2001
From: Raphael Valdetaro <79842132+raphaelchristi@users.noreply.github.com>
Date: Tue, 21 Jan 2025 17:20:26 -0300
Subject: [PATCH] feat: add arxiv component (#5634)
* feat: add arxiv component
* [autofix.ci] apply automated fixes
* test: add initial test suite for arxiv component
* fix: correct test formatting for ArXiv component
* fix: implement tests for ArXivComponent following TestBatchRunComponent pattern
* fix: ArXivComponent test formatting
* [autofix.ci] apply automated fixes
* refactor: update imports and skip version tests for new component
* fix: fix line breaks in test file
* [autofix.ci] apply automated fixes
---------
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
---
.../langflow/components/tools/__init__.py | 2 +
.../base/langflow/components/tools/arxiv.py | 150 ++++++++++++++++++
.../starter_projects/Blog Writer.json | 1 +
.../components/tools/test_arxiv_component.py | 124 +++++++++++++++
src/frontend/src/icons/ArXiv/ArXivIcon.jsx | 122 ++++++++++++++
src/frontend/src/icons/ArXiv/arxiv.svg | 24 +++
src/frontend/src/icons/ArXiv/index.tsx | 10 ++
src/frontend/src/utils/styleUtils.ts | 2 +
8 files changed, 435 insertions(+)
create mode 100644 src/backend/base/langflow/components/tools/arxiv.py
create mode 100644 src/backend/tests/unit/components/tools/test_arxiv_component.py
create mode 100644 src/frontend/src/icons/ArXiv/ArXivIcon.jsx
create mode 100644 src/frontend/src/icons/ArXiv/arxiv.svg
create mode 100644 src/frontend/src/icons/ArXiv/index.tsx
diff --git a/src/backend/base/langflow/components/tools/__init__.py b/src/backend/base/langflow/components/tools/__init__.py
index 3255eaf33..4d30e140a 100644
--- a/src/backend/base/langflow/components/tools/__init__.py
+++ b/src/backend/base/langflow/components/tools/__init__.py
@@ -2,6 +2,7 @@ import warnings
from langchain_core._api.deprecation import LangChainDeprecationWarning
+from .arxiv import ArXivComponent
from .bing_search_api import BingSearchAPIComponent
from .calculator import CalculatorToolComponent
from .calculator_core import CalculatorComponent
@@ -35,6 +36,7 @@ with warnings.catch_warnings():
from .astradb_cql import AstraDBCQLToolComponent
__all__ = [
+ "ArXivComponent",
"AstraDBCQLToolComponent",
"AstraDBToolComponent",
"BingSearchAPIComponent",
diff --git a/src/backend/base/langflow/components/tools/arxiv.py b/src/backend/base/langflow/components/tools/arxiv.py
new file mode 100644
index 000000000..8f5cdf63d
--- /dev/null
+++ b/src/backend/base/langflow/components/tools/arxiv.py
@@ -0,0 +1,150 @@
+import urllib.error
+import urllib.parse
+import urllib.request
+from urllib.parse import urlparse
+from xml.etree.ElementTree import Element
+
+from defusedxml.ElementTree import fromstring
+
+from langflow.custom import Component
+from langflow.io import DropdownInput, IntInput, MessageTextInput, Output
+from langflow.schema import Data
+
+
+class ArXivComponent(Component):
+ display_name = "arXiv"
+ description = "Search and retrieve papers from arXiv.org"
+ icon = "arXiv"
+
+ inputs = [
+ MessageTextInput(
+ name="search_query",
+ display_name="Search Query",
+ info="The search query for arXiv papers (e.g., 'quantum computing')",
+ tool_mode=True,
+ ),
+ DropdownInput(
+ name="search_type",
+ display_name="Search Field",
+ info="The field to search in",
+ options=["all", "title", "abstract", "author", "cat"], # cat is for category
+ value="all",
+ ),
+ IntInput(
+ name="max_results",
+ display_name="Max Results",
+ info="Maximum number of results to return",
+ value=10,
+ ),
+ ]
+
+ outputs = [
+ Output(display_name="Papers", name="papers", method="search_papers"),
+ ]
+
+ def build_query_url(self) -> str:
+ """Build the arXiv API query URL."""
+ base_url = "http://export.arxiv.org/api/query?"
+
+ # Build the search query
+ search_query = f"{self.search_type}:{self.search_query}"
+
+ # URL parameters
+ params = {
+ "search_query": search_query,
+ "max_results": str(self.max_results),
+ }
+
+ # Convert params to URL query string
+ query_string = "&".join([f"{k}={urllib.parse.quote(str(v))}" for k, v in params.items()])
+
+ return base_url + query_string
+
+ def parse_atom_response(self, response_text: str) -> list[dict]:
+ """Parse the Atom XML response from arXiv."""
+ # Parse XML safely using defusedxml
+ root = fromstring(response_text)
+
+ # Define namespace dictionary for XML parsing
+ ns = {"atom": "http://www.w3.org/2005/Atom", "arxiv": "http://arxiv.org/schemas/atom"}
+
+ papers = []
+ # Process each entry (paper)
+ for entry in root.findall("atom:entry", ns):
+ paper = {
+ "id": self._get_text(entry, "atom:id", ns),
+ "title": self._get_text(entry, "atom:title", ns),
+ "summary": self._get_text(entry, "atom:summary", ns),
+ "published": self._get_text(entry, "atom:published", ns),
+ "updated": self._get_text(entry, "atom:updated", ns),
+ "authors": [author.find("atom:name", ns).text for author in entry.findall("atom:author", ns)],
+ "arxiv_url": self._get_link(entry, "alternate", ns),
+ "pdf_url": self._get_link(entry, "related", ns),
+ "comment": self._get_text(entry, "arxiv:comment", ns),
+ "journal_ref": self._get_text(entry, "arxiv:journal_ref", ns),
+ "primary_category": self._get_category(entry, ns),
+ "categories": [cat.get("term") for cat in entry.findall("atom:category", ns)],
+ }
+ papers.append(paper)
+
+ return papers
+
+ def _get_text(self, element: Element, path: str, ns: dict) -> str | None:
+ """Safely extract text from an XML element."""
+ el = element.find(path, ns)
+ return el.text.strip() if el is not None and el.text else None
+
+ def _get_link(self, element: Element, rel: str, ns: dict) -> str | None:
+ """Get link URL based on relation type."""
+ for link in element.findall("atom:link", ns):
+ if link.get("rel") == rel:
+ return link.get("href")
+ return None
+
+ def _get_category(self, element: Element, ns: dict) -> str | None:
+ """Get primary category."""
+ cat = element.find("arxiv:primary_category", ns)
+ return cat.get("term") if cat is not None else None
+
+ def search_papers(self) -> list[Data]:
+ """Search arXiv and return results."""
+ try:
+ # Build the query URL
+ url = self.build_query_url()
+
+ # Validate URL scheme and host
+ parsed_url = urlparse(url)
+ if parsed_url.scheme not in ("http", "https"):
+ error_msg = f"Invalid URL scheme: {parsed_url.scheme}"
+ raise ValueError(error_msg)
+ if parsed_url.hostname != "export.arxiv.org":
+ error_msg = f"Invalid host: {parsed_url.hostname}"
+ raise ValueError(error_msg)
+
+ # Create a custom opener that only allows http/https schemes
+ class RestrictedHTTPHandler(urllib.request.HTTPHandler):
+ def http_open(self, req):
+ return super().http_open(req)
+
+ class RestrictedHTTPSHandler(urllib.request.HTTPSHandler):
+ def https_open(self, req):
+ return super().https_open(req)
+
+ # Build opener with restricted handlers
+ opener = urllib.request.build_opener(RestrictedHTTPHandler, RestrictedHTTPSHandler)
+ urllib.request.install_opener(opener)
+
+ # Make the request with validated URL using restricted opener
+ response = opener.open(url)
+ response_text = response.read().decode("utf-8")
+
+ # Parse the response
+ papers = self.parse_atom_response(response_text)
+
+ # Convert to Data objects
+ results = [Data(data=paper) for paper in papers]
+ self.status = results
+ except (urllib.error.URLError, ValueError) as e:
+ error_data = Data(data={"error": f"Request error: {e!s}"})
+ self.status = error_data
+ return [error_data]
+ else:
+ return results
diff --git a/src/backend/base/langflow/initial_setup/starter_projects/Blog Writer.json b/src/backend/base/langflow/initial_setup/starter_projects/Blog Writer.json
index 31c0c9a9a..923b86f95 100644
--- a/src/backend/base/langflow/initial_setup/starter_projects/Blog Writer.json
+++ b/src/backend/base/langflow/initial_setup/starter_projects/Blog Writer.json
@@ -222,6 +222,7 @@
"show": true,
"title_case": false,
"type": "code",
+
"value": "import re\n\nfrom langchain_community.document_loaders import AsyncHtmlLoader, WebBaseLoader\n\nfrom langflow.custom import Component\nfrom langflow.helpers.data import data_to_text\nfrom langflow.io import DropdownInput, MessageTextInput, Output\nfrom langflow.schema import Data\nfrom langflow.schema.dataframe import DataFrame\nfrom langflow.schema.message import Message\n\n\nclass URLComponent(Component):\n display_name = \"URL\"\n description = \"Load and retrive data from specified URLs.\"\n icon = \"layout-template\"\n name = \"URL\"\n\n inputs = [\n MessageTextInput(\n name=\"urls\",\n display_name=\"URLs\",\n is_list=True,\n tool_mode=True,\n placeholder=\"Enter a URL...\",\n list_add_label=\"Add URL\",\n ),\n DropdownInput(\n name=\"format\",\n display_name=\"Output Format\",\n info=\"Output Format. Use 'Text' to extract the text from the HTML or 'Raw HTML' for the raw HTML content.\",\n options=[\"Text\", \"Raw HTML\"],\n value=\"Text\",\n ),\n ]\n\n outputs = [\n Output(display_name=\"Data\", name=\"data\", method=\"fetch_content\"),\n Output(display_name=\"Message\", name=\"text\", method=\"fetch_content_text\"),\n Output(display_name=\"DataFrame\", name=\"dataframe\", method=\"as_dataframe\"),\n ]\n\n def ensure_url(self, string: str) -> str:\n \"\"\"Ensures the given string is a URL by adding 'http://' if it doesn't start with 'http://' or 'https://'.\n\n Raises an error if the string is not a valid URL.\n\n Parameters:\n string (str): The string to be checked and possibly modified.\n\n Returns:\n str: The modified string that is ensured to be a URL.\n\n Raises:\n ValueError: If the string is not a valid URL.\n \"\"\"\n if not string.startswith((\"http://\", \"https://\")):\n string = \"http://\" + string\n\n # Basic URL validation regex\n url_regex = re.compile(\n r\"^(https?:\\/\\/)?\" # optional protocol\n r\"(www\\.)?\" # optional www\n r\"([a-zA-Z0-9.-]+)\" # domain\n r\"(\\.[a-zA-Z]{2,})?\" # top-level domain\n r\"(:\\d+)?\" # optional 
port\n r\"(\\/[^\\s]*)?$\", # optional path\n re.IGNORECASE,\n )\n\n if not url_regex.match(string):\n msg = f\"Invalid URL: {string}\"\n raise ValueError(msg)\n\n return string\n\n def fetch_content(self) -> list[Data]:\n urls = [self.ensure_url(url.strip()) for url in self.urls if url.strip()]\n if self.format == \"Raw HTML\":\n loader = AsyncHtmlLoader(web_path=urls, encoding=\"utf-8\")\n else:\n loader = WebBaseLoader(web_paths=urls, encoding=\"utf-8\")\n docs = loader.load()\n data = [Data(text=doc.page_content, **doc.metadata) for doc in docs]\n self.status = data\n return data\n\n def fetch_content_text(self) -> Message:\n data = self.fetch_content()\n\n result_string = data_to_text(\"{text}\", data)\n self.status = result_string\n return Message(text=result_string)\n\n def as_dataframe(self) -> DataFrame:\n return DataFrame(self.fetch_content())\n"
},
"format": {
diff --git a/src/backend/tests/unit/components/tools/test_arxiv_component.py b/src/backend/tests/unit/components/tools/test_arxiv_component.py
new file mode 100644
index 000000000..4ecfb958c
--- /dev/null
+++ b/src/backend/tests/unit/components/tools/test_arxiv_component.py
@@ -0,0 +1,124 @@
+from unittest.mock import patch
+
+import pytest
+
+from tests.base import ComponentTestBaseWithClient
+
+
+class TestArXivComponent(ComponentTestBaseWithClient):
+ def test_component_versions(self, default_kwargs, file_names_mapping):
+ """Test component compatibility across versions."""
+ from langflow.components.tools.arxiv import ArXivComponent
+
+ # Test current version
+ component = ArXivComponent(**default_kwargs)
+ frontend_node = component.to_frontend_node()
+ assert frontend_node is not None
+
+ # Test backward compatibility
+ for mapping in file_names_mapping:
+ try:
+ module = __import__(
+ f"langflow.components.{mapping['module']}",
+ fromlist=[mapping["file_name"]],
+ )
+ component_class = getattr(module, mapping["file_name"])
+ component = component_class(**default_kwargs)
+ frontend_node = component.to_frontend_node()
+ assert frontend_node is not None
+ except (ImportError, AttributeError) as e:
+ pytest.fail(f"Failed to load component version {mapping['version']}: {e!s}")
+
+ @pytest.fixture
+ def component_class(self):
+ from langflow.components.tools.arxiv import ArXivComponent
+
+ return ArXivComponent
+
+ @pytest.fixture
+ def default_kwargs(self):
+ return {
+ "search_query": "quantum computing",
+ "search_type": "all",
+ "max_results": 10,
+ "_session_id": "test-session",
+ }
+
+ @pytest.fixture
+ def file_names_mapping(self):
+ return []
+
+ def test_component_initialization(self, component_class, default_kwargs):
+ # Arrange
+ component = component_class(**default_kwargs)
+
+ # Act
+ frontend_node = component.to_frontend_node()
+
+ # Assert
+ node_data = frontend_node["data"]["node"]
+ assert node_data["template"]["search_query"]["value"] == "quantum computing"
+ assert node_data["template"]["search_type"]["value"] == "all"
+ assert node_data["template"]["max_results"]["value"] == 10
+
+ def test_build_query_url(self, component_class, default_kwargs):
+ # Arrange
+ component = component_class(**default_kwargs)
+
+ # Act
+ url = component.build_query_url()
+
+ # Assert
+ assert "http://export.arxiv.org/api/query?" in url
+ assert "search_query=all%3Aquantum%20computing" in url
+ assert "max_results=10" in url
+
+ def test_parse_atom_response(self, component_class, default_kwargs):
+ # Arrange
+ component = component_class(**default_kwargs)
+ sample_xml = """
+
+ http://arxiv.org/abs/quant-ph/0000001
+ Test Paper
+ Test summary
+ 2023-01-01
+ 2023-01-01
+ Test Author
+
+
+
+ Test comment
+ Test Journal
+
+
+ """.replace("<", "<").replace(">", ">")
+
+ # Act
+ papers = component.parse_atom_response(sample_xml)
+
+ # Assert
+ assert len(papers) == 1
+ paper = papers[0]
+ assert paper["title"] == "Test Paper"
+ assert paper["summary"] == "Test summary"
+ assert paper["authors"] == ["Test Author"]
+ assert paper["arxiv_url"] == "http://arxiv.org/abs/quant-ph/0000001"
+ assert paper["pdf_url"] == "http://arxiv.org/pdf/quant-ph/0000001"
+ assert paper["comment"] == "Test comment"
+ assert paper["journal_ref"] == "Test Journal"
+ assert paper["primary_category"] == "quant-ph"
+
+ @patch("urllib.request.build_opener")
+ def test_invalid_url_handling(self, mock_build_opener, component_class, default_kwargs):
+ # Arrange
+ component = component_class(**default_kwargs)
+ mock_build_opener.return_value.open.side_effect = ValueError("Invalid URL")
+
+ # Act
+ results = component.search_papers()
+
+ # Assert
+ assert len(results) == 1
+ assert hasattr(results[0], "error")
+ assert "Invalid URL" in results[0].error
diff --git a/src/frontend/src/icons/ArXiv/ArXivIcon.jsx b/src/frontend/src/icons/ArXiv/ArXivIcon.jsx
new file mode 100644
index 000000000..774bcd6b0
--- /dev/null
+++ b/src/frontend/src/icons/ArXiv/ArXivIcon.jsx
@@ -0,0 +1,122 @@
+import React from "react";
+
+const ArXivIcon = (props) => {
+  // NOTE(review): the SVG markup that belongs inside this return appears to
+  // have been stripped when this patch was extracted (the diff header counts
+  // 122 added lines, but the body is empty, so this renders nothing).
+  // Restore the full arXiv logo JSX from the upstream commit.
+ return (
+
+ );
+};
+
+export default ArXivIcon;
diff --git a/src/frontend/src/icons/ArXiv/arxiv.svg b/src/frontend/src/icons/ArXiv/arxiv.svg
new file mode 100644
index 000000000..e1c03366b
--- /dev/null
+++ b/src/frontend/src/icons/ArXiv/arxiv.svg
@@ -0,0 +1,24 @@
+
+
diff --git a/src/frontend/src/icons/ArXiv/index.tsx b/src/frontend/src/icons/ArXiv/index.tsx
new file mode 100644
index 000000000..2d004361c
--- /dev/null
+++ b/src/frontend/src/icons/ArXiv/index.tsx
@@ -0,0 +1,10 @@
+import React, { forwardRef } from "react";
+import SvgArXivIcon from "./ArXivIcon";
+
+export const ArXivIcon = forwardRef>(
+ (props, ref) => {
+ return ;
+ },
+);
+
+export default ArXivIcon;
diff --git a/src/frontend/src/utils/styleUtils.ts b/src/frontend/src/utils/styleUtils.ts
index b82ea6903..8eeef4562 100644
--- a/src/frontend/src/utils/styleUtils.ts
+++ b/src/frontend/src/utils/styleUtils.ts
@@ -238,6 +238,7 @@ import { AWSIcon } from "../icons/AWS";
import { AgentQLIcon } from "../icons/AgentQL";
import { AirbyteIcon } from "../icons/Airbyte";
import { AnthropicIcon } from "../icons/Anthropic";
+import { ArXivIcon } from "../icons/ArXiv";
import { ArizeIcon } from "../icons/Arize";
import { AssemblyAIIcon } from "../icons/AssemblyAI";
import { AstraDBIcon } from "../icons/AstraDB";
@@ -625,6 +626,7 @@ export const nodeIconsLucide: iconsType = {
AmazonBedrockEmbeddings: AWSIcon,
Amazon: AWSIcon,
Anthropic: AnthropicIcon,
+ ArXiv: ArXivIcon,
ChatAnthropic: AnthropicIcon,
assemblyai: AssemblyAIIcon,
AgentQL: AgentQLIcon,