From a04e5f56b8fad6436e2acdb57d83dfb0d4d26bdb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Est=C3=A9vez?= Date: Tue, 6 May 2025 12:08:33 -0400 Subject: [PATCH] feat: update sql component to support toolkit (#7652) * update sql component * cache * Update src/backend/base/langflow/components/data/sql_executor.py Co-authored-by: Edwin Jose * Update src/backend/base/langflow/components/data/sql_executor.py Co-authored-by: Edwin Jose * [autofix.ci] apply automated fixes * fix: sql query component (#7700) * feat: Sql toolkit tests (#7842) Co-authored-by: Edwin Jose --------- Co-authored-by: Edwin Jose Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Gustavo Costa --- .../base/langflow/components/data/__init__.py | 4 +- .../langflow/components/data/sql_executor.py | 141 +++++++++++------- .../unit/components/data/test_sql_executor.py | 115 ++++++++++++++ 3 files changed, 203 insertions(+), 57 deletions(-) create mode 100644 src/backend/tests/unit/components/data/test_sql_executor.py diff --git a/src/backend/base/langflow/components/data/__init__.py b/src/backend/base/langflow/components/data/__init__.py index 770713cca..acf3d8585 100644 --- a/src/backend/base/langflow/components/data/__init__.py +++ b/src/backend/base/langflow/components/data/__init__.py @@ -3,7 +3,7 @@ from .csv_to_data import CSVToDataComponent from .directory import DirectoryComponent from .file import FileComponent from .json_to_data import JSONToDataComponent -from .sql_executor import SQLExecutorComponent +from .sql_executor import SQLComponent from .url import URLComponent from .webhook import WebhookComponent @@ -13,7 +13,7 @@ __all__ = [ "DirectoryComponent", "FileComponent", "JSONToDataComponent", - "SQLExecutorComponent", + "SQLComponent", "URLComponent", "WebhookComponent", ] diff --git a/src/backend/base/langflow/components/data/sql_executor.py b/src/backend/base/langflow/components/data/sql_executor.py index add27a8b9..582d61cf5 100644 --- a/src/backend/base/langflow/components/data/sql_executor.py +++ b/src/backend/base/langflow/components/data/sql_executor.py @@ -1,74 +1,105 @@ -from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool +from typing import TYPE_CHECKING, Any + from langchain_community.utilities import SQLDatabase +from sqlalchemy.exc import SQLAlchemyError -from langflow.custom import CustomComponent -from langflow.field_typing import Text +from langflow.custom.custom_component.component_with_cache import ComponentWithCache +from langflow.io import BoolInput, MessageTextInput, Output +from langflow.schema.data import Data +from langflow.schema.dataframe import DataFrame +from langflow.schema.message import Message +from langflow.services.cache.utils import CacheMiss + +if TYPE_CHECKING: + from sqlalchemy.engine import Result -class SQLExecutorComponent(CustomComponent): +class SQLComponent(ComponentWithCache): + """A sql component.""" + display_name = "SQL Query" - description = "Execute SQL query." - name = "SQLExecutor" - beta: bool = True + description = "Execute SQL Query" + icon = "database" + name = "SQLComponent" - def build_config(self): - return { - "database_url": { - "display_name": "Database URL", - "info": "The URL of the database.", - }, - "include_columns": { - "display_name": "Include Columns", - "info": "Include columns in the result.", - }, - "passthrough": { - "display_name": "Passthrough", - "info": "If an error occurs, return the query instead of raising an exception.", - }, - "add_error": { - "display_name": "Add Error", - "info": "Add the error to the result.", - }, - } + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.db: SQLDatabase = None - def clean_up_uri(self, uri: str) -> str: - if uri.startswith("postgresql://"): - uri = uri.replace("postgresql://", "postgres://") - return uri.strip() + def maybe_create_db(self): + if self.database_url != "": + cached_db = self._shared_component_cache.get(self.database_url) + if not isinstance(cached_db, CacheMiss): + self.db = cached_db + return + self.log("Connecting to database") + try: + self.db = SQLDatabase.from_uri(self.database_url) + except Exception as e: + msg = f"An error occurred while connecting to the database: {e}" + raise ValueError(msg) from e + self._shared_component_cache.set(self.database_url, self.db) - def build( + inputs = [ + MessageTextInput(name="database_url", display_name="Database URL", required=True), + MessageTextInput(name="query", display_name="SQL Query", tool_mode=True, required=True), + BoolInput(name="include_columns", display_name="Include Columns", value=True, tool_mode=True), + BoolInput( + name="add_error", + display_name="Add Error", + value=False, + tool_mode=True, + info="If True, the error will be added to the result", + ), + ] + + outputs = [ + Output(display_name="Message", name="text", method="build_component"), + Output(display_name="Data", name="data", method="build_data"), + Output(display_name="DataFrame", name="dataframe", method="build_dataframe"), + ] + + def build_component( self, - query: str, - database_url: str, - *, - include_columns: bool = False, - passthrough: bool = False, - add_error: bool = False, - **kwargs, - ) -> Text: - _ = kwargs + ) -> Message: error = None + self.maybe_create_db() try: - database = SQLDatabase.from_uri(database_url) - except Exception as e: - msg = f"An error occurred while connecting to the database: {e}" - raise ValueError(msg) from e - try: - tool = QuerySQLDataBaseTool(db=database) - result = tool.run(query, include_columns=include_columns) + result = self.db.run(self.query, include_columns=self.include_columns) self.status = result - except Exception as e: + except SQLAlchemyError as e: + msg = f"An error occurred while running the SQL Query: {e}" + self.log(msg) result = str(e) self.status = result - if not passthrough: - raise error = repr(e) - if add_error and error is not None: - result = f"{result}\n\nError: {error}\n\nQuery: {query}" + if self.add_error and error is not None: + result = f"{result}\n\nError: {error}\n\nQuery: {self.query}" elif error is not None: # Then we won't add the error to the result - # but since we are in passthrough mode, we will return the query - result = query + result = self.query - return result + return Message(text=result) + + def __execute_query(self) -> list[dict[str, Any]]: + self.maybe_create_db() + try: + cursor: Result[Any] = self.db.run(self.query, fetch="cursor") + return [x._asdict() for x in cursor.fetchall()] + except SQLAlchemyError as e: + msg = f"An error occurred while running the SQL Query: {e}" + self.log(msg) + raise ValueError(msg) from e + + def build_dataframe(self) -> DataFrame: + result = self.__execute_query() + df_result = DataFrame(result) + self.status = df_result + return df_result + + def build_data(self) -> Data: + result = self.__execute_query() + data_result = Data(data={"result": result}) + self.status = data_result + return data_result diff --git a/src/backend/tests/unit/components/data/test_sql_executor.py b/src/backend/tests/unit/components/data/test_sql_executor.py new file mode 100644 index 000000000..14ee99020 --- /dev/null +++ b/src/backend/tests/unit/components/data/test_sql_executor.py @@ -0,0 +1,115 @@ +import sqlite3 +from pathlib import Path + +import pytest +from langflow.components.data.sql_executor import SQLComponent +from langflow.schema import Data, DataFrame, Message + +from tests.base import ComponentTestBaseWithoutClient + + +class TestSQLComponent(ComponentTestBaseWithoutClient): + @pytest.fixture + def test_db(self): + """Fixture that creates a temporary SQLite database for testing.""" + test_data_dir = Path(__file__).parent.parent.parent.parent / "data" + db_path = test_data_dir / "test.db" + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + cursor.execute(""" + CREATE TABLE IF NOT EXISTS test ( + id INTEGER PRIMARY KEY, + name TEXT + ) + """) + cursor.execute(""" + INSERT INTO test (id, name) + VALUES (1, 'name_test') + """) + conn.commit() + conn.close() + yield str(db_path) + + Path(db_path).unlink() + + @pytest.fixture + def component_class(self): + """Return the component class to test.""" + return SQLComponent + + @pytest.fixture + def default_kwargs(self, test_db): + """Return the default kwargs for the component.""" + return { + "database_url": f"sqlite:///{test_db}", + "query": "SELECT * FROM test", + "include_columns": True, + "add_error": False, + } + + @pytest.fixture + def file_names_mapping(self): + """Return an empty list since this component doesn't have version-specific files.""" + return [] + + def test_successful_query_with_columns(self, component_class: type[SQLComponent], default_kwargs): + """Test a successful SQL query with columns included.""" + component = component_class(**default_kwargs) + + result = component.build_component() + + assert isinstance(result, Message) + assert isinstance(result.text, str) + assert result.text == "[{'id': 1, 'name': 'name_test'}]" + + def test_successful_query_without_columns(self, component_class: type[SQLComponent], default_kwargs): + """Test a successful SQL query without columns included.""" + default_kwargs["include_columns"] = False + component = component_class(**default_kwargs) + + result = component.build_component() + + assert isinstance(result, Message) + assert isinstance(result.text, str) + assert result.text == "[(1, 'name_test')]" + assert component.status == "[(1, 'name_test')]" + assert component.query == "SELECT * FROM test" + + def test_query_error_with_add_error(self, component_class: type[SQLComponent], default_kwargs): + """Test a SQL query that raises an error with add_error=True.""" + default_kwargs["add_error"] = True + default_kwargs["query"] = "SELECT * FROM non_existent_table" + component = component_class(**default_kwargs) + + result = component.build_component() + + assert isinstance(result, Message) + assert isinstance(result.text, str) + assert "no such table: non_existent_table" in result.text + assert "Error:" in result.text + assert "Query: SELECT * FROM non_existent_table" in result.text + + def test_build_dataframe(self, component_class: type[SQLComponent], default_kwargs): + """Test building a DataFrame from a SQL query.""" + component = component_class(**default_kwargs) + + result = component.build_dataframe() + + assert isinstance(result, DataFrame) + assert len(result) == 1 + assert "id" in result.columns + assert "name" in result.columns + assert result.iloc[0]["id"] == 1 + assert result.iloc[0]["name"] == "name_test" + + def test_build_data(self, component_class: type[SQLComponent], default_kwargs): + """Test building a Data object from a SQL query.""" + component = component_class(**default_kwargs) + + result = component.build_data() + + assert isinstance(result, Data) + assert "result" in result.data + assert len(result.data["result"]) == 1 + assert result.data["result"][0]["id"] == 1 + assert result.data["result"][0]["name"] == "name_test"