refactor: use utility for BaseModel serialization and add SQL component tests (#8437)
* Update component_tool.py * Update test_component_toolkit.py * [autofix.ci] apply automated fixes * Update component_tool.py * [autofix.ci] apply automated fixes * Update component_tool.py * Update component_tool.py * [autofix.ci] apply automated fixes * fix tests * [autofix.ci] apply automated fixes * Update test_component_toolkit.py * Update test_component_toolkit.py * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
ca23dc4c53
commit
7aee1bc1c3
2 changed files with 123 additions and 15 deletions
|
|
@ -7,12 +7,12 @@ from typing import TYPE_CHECKING, Literal
|
|||
import pandas as pd
|
||||
from langchain_core.tools import BaseTool, ToolException
|
||||
from langchain_core.tools.structured import StructuredTool
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langflow.base.tools.constants import TOOL_OUTPUT_NAME
|
||||
from langflow.io.schema import create_input_schema, create_input_schema_from_dict
|
||||
from langflow.schema.data import Data
|
||||
from langflow.schema.message import Message
|
||||
from langflow.serialization.serialization import serialize
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
|
@ -26,7 +26,6 @@ if TYPE_CHECKING:
|
|||
from langflow.schema.content_block import ContentBlock
|
||||
from langflow.schema.dotdict import dotdict
|
||||
|
||||
|
||||
TOOL_TYPES_SET = {"Tool", "BaseTool", "StructuredTool"}
|
||||
|
||||
|
||||
|
|
@ -108,9 +107,8 @@ def _build_output_function(component: Component, output_method: Callable, event_
|
|||
return result.get_text()
|
||||
if isinstance(result, Data):
|
||||
return result.data
|
||||
if isinstance(result, BaseModel):
|
||||
return result.model_dump()
|
||||
return result
|
||||
# removing the model_dump() call here because it is not serializable
|
||||
return serialize(result)
|
||||
|
||||
return _patch_send_message_decorator(component, output_function)
|
||||
|
||||
|
|
@ -132,9 +130,8 @@ def _build_output_async_function(
|
|||
return result.get_text()
|
||||
if isinstance(result, Data):
|
||||
return result.data
|
||||
if isinstance(result, BaseModel):
|
||||
return result.model_dump()
|
||||
return result
|
||||
# removing the model_dump() call here because it is not serializable
|
||||
return serialize(result)
|
||||
|
||||
return _patch_send_message_decorator(component, output_function)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,16 +1,103 @@
|
|||
import os
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from langflow.base.tools.component_tool import ComponentToolkit
|
||||
from langflow.components.data.sql_executor import SQLComponent
|
||||
from langflow.components.input_output.chat_output import ChatOutput
|
||||
from langflow.components.langchain_utilities import ToolCallingAgentComponent
|
||||
from langflow.components.languagemodels import OpenAIModelComponent
|
||||
from langflow.components.tools.calculator import CalculatorToolComponent
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.schema.data import Data
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_db():
|
||||
"""Fixture that creates a temporary SQLite database for testing."""
|
||||
test_data_dir = Path(__file__).parent.parent.parent.parent / "data"
|
||||
db_path = test_data_dir / "test.db"
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
# Create students table
|
||||
cursor.execute("""
|
||||
CREATE TABLE students (
|
||||
id INTEGER PRIMARY KEY,
|
||||
first_name TEXT NOT NULL,
|
||||
last_name TEXT NOT NULL,
|
||||
age INTEGER,
|
||||
gpa REAL,
|
||||
major TEXT
|
||||
)
|
||||
""")
|
||||
|
||||
# Create courses table
|
||||
cursor.execute("""
|
||||
CREATE TABLE courses (
|
||||
id INTEGER PRIMARY KEY,
|
||||
course_name TEXT NOT NULL,
|
||||
instructor TEXT,
|
||||
credits INTEGER
|
||||
)
|
||||
""")
|
||||
|
||||
# Create enrollment junction table
|
||||
cursor.execute("""
|
||||
CREATE TABLE enrollments (
|
||||
student_id INTEGER,
|
||||
course_id INTEGER,
|
||||
grade TEXT,
|
||||
PRIMARY KEY (student_id, course_id),
|
||||
FOREIGN KEY (student_id) REFERENCES students (id),
|
||||
FOREIGN KEY (course_id) REFERENCES courses (id)
|
||||
)
|
||||
""")
|
||||
|
||||
# Insert sample student data
|
||||
students = [
|
||||
(1, "John", "Smith", 20, 3.5, "Computer Science"),
|
||||
(2, "Emma", "Johnson", 21, 3.8, "Mathematics"),
|
||||
(3, "Michael", "Williams", 19, 3.2, "Physics"),
|
||||
(4, "Olivia", "Brown", 22, 3.9, "Biology"),
|
||||
(5, "James", "Davis", 20, 3.1, "Chemistry"),
|
||||
]
|
||||
|
||||
cursor.executemany("INSERT INTO students VALUES (?, ?, ?, ?, ?, ?)", students)
|
||||
|
||||
# Insert sample course data
|
||||
courses = [
|
||||
(101, "Introduction to Programming", "Dr. Jones", 3),
|
||||
(102, "Calculus I", "Dr. Smith", 4),
|
||||
(103, "Physics 101", "Dr. Brown", 4),
|
||||
(104, "Biology Fundamentals", "Dr. Wilson", 3),
|
||||
(105, "Chemistry Basics", "Dr. Miller", 3),
|
||||
]
|
||||
|
||||
cursor.executemany("INSERT INTO courses VALUES (?, ?, ?, ?)", courses)
|
||||
|
||||
# Insert sample enrollment data
|
||||
enrollments = [
|
||||
(1, 101, "A"),
|
||||
(1, 102, "B+"),
|
||||
(2, 102, "A"),
|
||||
(2, 103, "A-"),
|
||||
(3, 103, "B"),
|
||||
(3, 105, "C+"),
|
||||
(4, 104, "A"),
|
||||
(5, 105, "B+"),
|
||||
]
|
||||
|
||||
cursor.executemany("INSERT INTO enrollments VALUES (?, ?, ?)", enrollments)
|
||||
|
||||
# Commit changes and close connection
|
||||
conn.commit()
|
||||
conn.close()
|
||||
yield str(db_path)
|
||||
|
||||
Path(db_path).unlink()
|
||||
|
||||
|
||||
def test_component_tool():
|
||||
calculator_component = CalculatorToolComponent()
|
||||
component_toolkit = ComponentToolkit(component=calculator_component)
|
||||
|
|
@ -29,9 +116,9 @@ def test_component_tool():
|
|||
assert component_toolkit.component == calculator_component
|
||||
|
||||
result = component_tool.invoke(input={"expression": "1+1"})
|
||||
assert isinstance(result[0], Data)
|
||||
assert "result" in result[0].data
|
||||
assert result[0].result == "2"
|
||||
assert isinstance(result[0], dict)
|
||||
assert "result" in result[0]["data"]
|
||||
assert result[0]["data"]["result"] == "2"
|
||||
|
||||
|
||||
@pytest.mark.api_key_required
|
||||
|
|
@ -41,10 +128,10 @@ async def test_component_tool_with_api_key():
|
|||
openai_llm = OpenAIModelComponent()
|
||||
openai_llm.set(api_key=os.environ["OPENAI_API_KEY"])
|
||||
tool_calling_agent = ToolCallingAgentComponent()
|
||||
|
||||
tools = await chat_output.to_toolkit()
|
||||
tool_calling_agent.set(
|
||||
llm=openai_llm.build_model,
|
||||
tools=[chat_output.to_toolkit],
|
||||
tools=list(tools),
|
||||
input_value="Which tools are available? Please tell its name.",
|
||||
)
|
||||
|
||||
|
|
@ -52,5 +139,29 @@ async def test_component_tool_with_api_key():
|
|||
g.session_id = "test"
|
||||
assert g is not None
|
||||
results = [result async for result in g.async_start()]
|
||||
assert len(results) == 4
|
||||
assert len(results) == 3
|
||||
assert "message_response" in tool_calling_agent._outputs_map["response"].value.get_text()
|
||||
|
||||
|
||||
@pytest.mark.api_key_required
|
||||
@pytest.mark.usefixtures("client")
|
||||
async def test_sql_component_to_toolkit(test_db):
|
||||
sql_component = SQLComponent()
|
||||
sql_component.set(database_url=f"sqlite:///{test_db}")
|
||||
tool = await sql_component.to_toolkit()
|
||||
openai_llm = OpenAIModelComponent()
|
||||
openai_llm.set(api_key=os.environ["OPENAI_API_KEY"])
|
||||
tool_calling_agent = ToolCallingAgentComponent()
|
||||
|
||||
tool_calling_agent.set(
|
||||
llm=openai_llm.build_model,
|
||||
tools=list(tool),
|
||||
input_value="run SELECT * FROM courses to get course details.",
|
||||
)
|
||||
|
||||
g = Graph(start=tool_calling_agent, end=tool_calling_agent)
|
||||
g.session_id = "test"
|
||||
assert g is not None
|
||||
results = [result async for result in g.async_start()]
|
||||
assert len(results) > 0
|
||||
assert "Physics 101" in tool_calling_agent._outputs_map["response"].value.get_text()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue