feat: update sql component to support toolkit (#7652)
* update sql component * cache * Update src/backend/base/langflow/components/data/sql_executor.py Co-authored-by: Edwin Jose <edwin.jose@datastax.com> * Update src/backend/base/langflow/components/data/sql_executor.py Co-authored-by: Edwin Jose <edwin.jose@datastax.com> * [autofix.ci] apply automated fixes * fix: sql query component (#7700) * feat: Sql toolkit tests (#7842) Co-authored-by: Edwin Jose <edwin.jose@datastax.com> --------- Co-authored-by: Edwin Jose <edwin.jose@datastax.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Gustavo Costa <gsantosaero@gmail.com>
This commit is contained in:
parent
15993f6eff
commit
a04e5f56b8
3 changed files with 203 additions and 57 deletions
|
|
@ -3,7 +3,7 @@ from .csv_to_data import CSVToDataComponent
|
|||
from .directory import DirectoryComponent
|
||||
from .file import FileComponent
|
||||
from .json_to_data import JSONToDataComponent
|
||||
from .sql_executor import SQLExecutorComponent
|
||||
from .sql_executor import SQLComponent
|
||||
from .url import URLComponent
|
||||
from .webhook import WebhookComponent
|
||||
|
||||
|
|
@ -13,7 +13,7 @@ __all__ = [
|
|||
"DirectoryComponent",
|
||||
"FileComponent",
|
||||
"JSONToDataComponent",
|
||||
"SQLExecutorComponent",
|
||||
"SQLComponent",
|
||||
"URLComponent",
|
||||
"WebhookComponent",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,74 +1,105 @@
|
|||
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_community.utilities import SQLDatabase
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from langflow.custom import CustomComponent
|
||||
from langflow.field_typing import Text
|
||||
from langflow.custom.custom_component.component_with_cache import ComponentWithCache
|
||||
from langflow.io import BoolInput, MessageTextInput, Output
|
||||
from langflow.schema.data import Data
|
||||
from langflow.schema.dataframe import DataFrame
|
||||
from langflow.schema.message import Message
|
||||
from langflow.services.cache.utils import CacheMiss
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.engine import Result
|
||||
|
||||
|
||||
class SQLExecutorComponent(CustomComponent):
|
||||
class SQLComponent(ComponentWithCache):
|
||||
"""A sql component."""
|
||||
|
||||
display_name = "SQL Query"
|
||||
description = "Execute SQL query."
|
||||
name = "SQLExecutor"
|
||||
beta: bool = True
|
||||
description = "Execute SQL Query"
|
||||
icon = "database"
|
||||
name = "SQLComponent"
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"database_url": {
|
||||
"display_name": "Database URL",
|
||||
"info": "The URL of the database.",
|
||||
},
|
||||
"include_columns": {
|
||||
"display_name": "Include Columns",
|
||||
"info": "Include columns in the result.",
|
||||
},
|
||||
"passthrough": {
|
||||
"display_name": "Passthrough",
|
||||
"info": "If an error occurs, return the query instead of raising an exception.",
|
||||
},
|
||||
"add_error": {
|
||||
"display_name": "Add Error",
|
||||
"info": "Add the error to the result.",
|
||||
},
|
||||
}
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.db: SQLDatabase = None
|
||||
|
||||
def clean_up_uri(self, uri: str) -> str:
|
||||
if uri.startswith("postgresql://"):
|
||||
uri = uri.replace("postgresql://", "postgres://")
|
||||
return uri.strip()
|
||||
def maybe_create_db(self):
|
||||
if self.database_url != "":
|
||||
cached_db = self._shared_component_cache.get(self.database_url)
|
||||
if not isinstance(cached_db, CacheMiss):
|
||||
self.db = cached_db
|
||||
return
|
||||
self.log("Connecting to database")
|
||||
try:
|
||||
self.db = SQLDatabase.from_uri(self.database_url)
|
||||
except Exception as e:
|
||||
msg = f"An error occurred while connecting to the database: {e}"
|
||||
raise ValueError(msg) from e
|
||||
self._shared_component_cache.set(self.database_url, self.db)
|
||||
|
||||
def build(
|
||||
inputs = [
|
||||
MessageTextInput(name="database_url", display_name="Database URL", required=True),
|
||||
MessageTextInput(name="query", display_name="SQL Query", tool_mode=True, required=True),
|
||||
BoolInput(name="include_columns", display_name="Include Columns", value=True, tool_mode=True),
|
||||
BoolInput(
|
||||
name="add_error",
|
||||
display_name="Add Error",
|
||||
value=False,
|
||||
tool_mode=True,
|
||||
info="If True, the error will be added to the result",
|
||||
),
|
||||
]
|
||||
|
||||
outputs = [
|
||||
Output(display_name="Message", name="text", method="build_component"),
|
||||
Output(display_name="Data", name="data", method="build_data"),
|
||||
Output(display_name="DataFrame", name="dataframe", method="build_dataframe"),
|
||||
]
|
||||
|
||||
def build_component(
|
||||
self,
|
||||
query: str,
|
||||
database_url: str,
|
||||
*,
|
||||
include_columns: bool = False,
|
||||
passthrough: bool = False,
|
||||
add_error: bool = False,
|
||||
**kwargs,
|
||||
) -> Text:
|
||||
_ = kwargs
|
||||
) -> Message:
|
||||
error = None
|
||||
self.maybe_create_db()
|
||||
try:
|
||||
database = SQLDatabase.from_uri(database_url)
|
||||
except Exception as e:
|
||||
msg = f"An error occurred while connecting to the database: {e}"
|
||||
raise ValueError(msg) from e
|
||||
try:
|
||||
tool = QuerySQLDataBaseTool(db=database)
|
||||
result = tool.run(query, include_columns=include_columns)
|
||||
result = self.db.run(self.query, include_columns=self.include_columns)
|
||||
self.status = result
|
||||
except Exception as e:
|
||||
except SQLAlchemyError as e:
|
||||
msg = f"An error occurred while running the SQL Query: {e}"
|
||||
self.log(msg)
|
||||
result = str(e)
|
||||
self.status = result
|
||||
if not passthrough:
|
||||
raise
|
||||
error = repr(e)
|
||||
|
||||
if add_error and error is not None:
|
||||
result = f"{result}\n\nError: {error}\n\nQuery: {query}"
|
||||
if self.add_error and error is not None:
|
||||
result = f"{result}\n\nError: {error}\n\nQuery: {self.query}"
|
||||
elif error is not None:
|
||||
# Then we won't add the error to the result
|
||||
# but since we are in passthrough mode, we will return the query
|
||||
result = query
|
||||
result = self.query
|
||||
|
||||
return result
|
||||
return Message(text=result)
|
||||
|
||||
def __execute_query(self) -> list[dict[str, Any]]:
|
||||
self.maybe_create_db()
|
||||
try:
|
||||
cursor: Result[Any] = self.db.run(self.query, fetch="cursor")
|
||||
return [x._asdict() for x in cursor.fetchall()]
|
||||
except SQLAlchemyError as e:
|
||||
msg = f"An error occurred while running the SQL Query: {e}"
|
||||
self.log(msg)
|
||||
raise ValueError(msg) from e
|
||||
|
||||
def build_dataframe(self) -> DataFrame:
|
||||
result = self.__execute_query()
|
||||
df_result = DataFrame(result)
|
||||
self.status = df_result
|
||||
return df_result
|
||||
|
||||
def build_data(self) -> Data:
|
||||
result = self.__execute_query()
|
||||
data_result = Data(data={"result": result})
|
||||
self.status = data_result
|
||||
return data_result
|
||||
|
|
|
|||
115
src/backend/tests/unit/components/data/test_sql_executor.py
Normal file
115
src/backend/tests/unit/components/data/test_sql_executor.py
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
import sqlite3
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from langflow.components.data.sql_executor import SQLComponent
|
||||
from langflow.schema import Data, DataFrame, Message
|
||||
|
||||
from tests.base import ComponentTestBaseWithoutClient
|
||||
|
||||
|
||||
class TestSQLComponent(ComponentTestBaseWithoutClient):
|
||||
@pytest.fixture
|
||||
def test_db(self):
|
||||
"""Fixture that creates a temporary SQLite database for testing."""
|
||||
test_data_dir = Path(__file__).parent.parent.parent.parent / "data"
|
||||
db_path = test_data_dir / "test.db"
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS test (
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT
|
||||
)
|
||||
""")
|
||||
cursor.execute("""
|
||||
INSERT INTO test (id, name)
|
||||
VALUES (1, 'name_test')
|
||||
""")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
yield str(db_path)
|
||||
|
||||
Path(db_path).unlink()
|
||||
|
||||
@pytest.fixture
|
||||
def component_class(self):
|
||||
"""Return the component class to test."""
|
||||
return SQLComponent
|
||||
|
||||
@pytest.fixture
|
||||
def default_kwargs(self, test_db):
|
||||
"""Return the default kwargs for the component."""
|
||||
return {
|
||||
"database_url": f"sqlite:///{test_db}",
|
||||
"query": "SELECT * FROM test",
|
||||
"include_columns": True,
|
||||
"add_error": False,
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def file_names_mapping(self):
|
||||
"""Return an empty list since this component doesn't have version-specific files."""
|
||||
return []
|
||||
|
||||
def test_successful_query_with_columns(self, component_class: type[SQLComponent], default_kwargs):
|
||||
"""Test a successful SQL query with columns included."""
|
||||
component = component_class(**default_kwargs)
|
||||
|
||||
result = component.build_component()
|
||||
|
||||
assert isinstance(result, Message)
|
||||
assert isinstance(result.text, str)
|
||||
assert result.text == "[{'id': 1, 'name': 'name_test'}]"
|
||||
|
||||
def test_successful_query_without_columns(self, component_class: type[SQLComponent], default_kwargs):
|
||||
"""Test a successful SQL query without columns included."""
|
||||
default_kwargs["include_columns"] = False
|
||||
component = component_class(**default_kwargs)
|
||||
|
||||
result = component.build_component()
|
||||
|
||||
assert isinstance(result, Message)
|
||||
assert isinstance(result.text, str)
|
||||
assert result.text == "[(1, 'name_test')]"
|
||||
assert component.status == "[(1, 'name_test')]"
|
||||
assert component.query == "SELECT * FROM test"
|
||||
|
||||
def test_query_error_with_add_error(self, component_class: type[SQLComponent], default_kwargs):
|
||||
"""Test a SQL query that raises an error with add_error=True."""
|
||||
default_kwargs["add_error"] = True
|
||||
default_kwargs["query"] = "SELECT * FROM non_existent_table"
|
||||
component = component_class(**default_kwargs)
|
||||
|
||||
result = component.build_component()
|
||||
|
||||
assert isinstance(result, Message)
|
||||
assert isinstance(result.text, str)
|
||||
assert "no such table: non_existent_table" in result.text
|
||||
assert "Error:" in result.text
|
||||
assert "Query: SELECT * FROM non_existent_table" in result.text
|
||||
|
||||
def test_build_dataframe(self, component_class: type[SQLComponent], default_kwargs):
|
||||
"""Test building a DataFrame from a SQL query."""
|
||||
component = component_class(**default_kwargs)
|
||||
|
||||
result = component.build_dataframe()
|
||||
|
||||
assert isinstance(result, DataFrame)
|
||||
assert len(result) == 1
|
||||
assert "id" in result.columns
|
||||
assert "name" in result.columns
|
||||
assert result.iloc[0]["id"] == 1
|
||||
assert result.iloc[0]["name"] == "name_test"
|
||||
|
||||
def test_build_data(self, component_class: type[SQLComponent], default_kwargs):
|
||||
"""Test building a Data object from a SQL query."""
|
||||
component = component_class(**default_kwargs)
|
||||
|
||||
result = component.build_data()
|
||||
|
||||
assert isinstance(result, Data)
|
||||
assert "result" in result.data
|
||||
assert len(result.data["result"]) == 1
|
||||
assert result.data["result"][0]["id"] == 1
|
||||
assert result.data["result"][0]["name"] == "name_test"
|
||||
Loading…
Add table
Add a link
Reference in a new issue