refactor: Enhance tools with enums and improved error handling (#4493)

* fix: Enhance extract_class_name function to identify Component subclasses

* Add TODO for improving Component inheritance check in validate.py

* Add YahooFinanceMethod enum and improve error handling in Yahoo Finance tool

- Introduced YahooFinanceMethod enum to standardize method options.
- Updated YahooFinanceSchema to use the new enum for method selection.
- Enhanced error handling by raising ToolException on data retrieval failure.
- Refactored method handling in _yahoo_finance_tool to use enum values.

* Enhance TavilySearchToolComponent with Enums and Improved Error Handling

- Introduced `TavilySearchDepth` and `TavilySearchTopic` enums for better type safety and clarity.
- Updated `TavilySearchSchema` to use enums for `search_depth` and `topic` fields.
- Added validation for enum values in `run_model` and `_tavily_search` methods.
- Improved error handling by raising `ToolException` for HTTP and unexpected errors.
- Updated dropdown inputs to use enum options directly.

* Add error handling and parameter flexibility to SerpAPI tool

- Introduced `ToolException` for improved error handling in SerpAPI searches.
- Added `SerpAPISchema` for structured search parameters.
- Modified `_build_wrapper` to accept dynamic parameters.
- Enhanced `search_func` to rebuild wrapper with new parameters and handle exceptions.

* feat: Enhance Glean Search API integration

Refactor the API wrapper and schema for better clarity and maintainability. Improve error handling for search results and streamline request preparation.

* Add error handling to DuckDuckGo search function using ToolException

---------

Co-authored-by: Eric Hare <ericrhare@gmail.com>
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-11-12 08:59:32 -03:00 committed by GitHub
commit 7dfce1dc63
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 261 additions and 169 deletions

View file

@ -2,6 +2,7 @@ from typing import Any
from langchain.tools import StructuredTool
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.tools import ToolException
from pydantic import BaseModel, Field
from langflow.base.langchain_utilities.model import LCToolComponent
@ -38,14 +39,18 @@ class DuckDuckGoSearchComponent(LCToolComponent):
wrapper = self._build_wrapper()
def search_func(query: str, max_results: int = 5, max_snippet_length: int = 100) -> list[dict[str, Any]]:
full_results = wrapper.run(f"{query} (site:*)")
result_list = full_results.split("\n")[:max_results]
limited_results = []
for result in result_list:
limited_result = {
"snippet": result[:max_snippet_length],
}
limited_results.append(limited_result)
try:
full_results = wrapper.run(f"{query} (site:*)")
result_list = full_results.split("\n")[:max_results]
limited_results = []
for result in result_list:
limited_result = {
"snippet": result[:max_snippet_length],
}
limited_results.append(limited_result)
except Exception as e:
msg = f"Error in DuckDuckGo Search: {e!s}"
raise ToolException(msg) from e
return limited_results
tool = StructuredTool.from_function(
@ -67,5 +72,5 @@ class DuckDuckGoSearchComponent(LCToolComponent):
}
)
data_list = [Data(data=result, text=result.get("snippet", "")) for result in results]
self.status = data_list
self.status = data_list # type: ignore[assignment]
return data_list

View file

@ -3,8 +3,8 @@ from typing import Any
from urllib.parse import urljoin
import httpx
from langchain.tools import StructuredTool
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.tools import StructuredTool, ToolException
from pydantic import BaseModel
from pydantic.v1 import Field
from langflow.base.langchain_utilities.model import LCToolComponent
@ -13,6 +13,89 @@ from langflow.inputs import IntInput, MultilineInput, NestedDictInput, SecretStr
from langflow.schema import Data
class GleanSearchAPISchema(BaseModel):
query: str = Field(..., description="The search query")
page_size: int = Field(10, description="Maximum number of results to return")
request_options: dict[str, Any] | None = Field(default_factory=dict, description="Request Options")
class GleanAPIWrapper(BaseModel):
"""Wrapper around Glean API."""
glean_api_url: str
glean_access_token: str
act_as: str = "langflow-component@datastax.com" # TODO: Detect this
def _prepare_request(
self,
query: str,
page_size: int = 10,
request_options: dict[str, Any] | None = None,
) -> dict:
# Ensure there's a trailing slash
url = self.glean_api_url
if not url.endswith("/"):
url += "/"
return {
"url": urljoin(url, "search"),
"headers": {
"Authorization": f"Bearer {self.glean_access_token}",
"X-Scio-ActAs": self.act_as,
},
"payload": {
"query": query,
"pageSize": page_size,
"requestOptions": request_options,
},
}
def results(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
results = self._search_api_results(query, **kwargs)
if len(results) == 0:
msg = "No good Glean Search Result was found"
raise AssertionError(msg)
return results
def run(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
try:
results = self.results(query, **kwargs)
processed_results = []
for result in results:
if "title" in result:
result["snippets"] = result.get("snippets", [{"snippet": {"text": result["title"]}}])
if "text" not in result["snippets"][0]:
result["snippets"][0]["text"] = result["title"]
processed_results.append(result)
except Exception as e:
error_message = f"Error in Glean Search API: {e!s}"
raise ToolException(error_message) from e
return processed_results
def _search_api_results(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
request_details = self._prepare_request(query, **kwargs)
response = httpx.post(
request_details["url"],
json=request_details["payload"],
headers=request_details["headers"],
)
response.raise_for_status()
response_json = response.json()
return response_json.get("results", [])
@staticmethod
def _result_as_string(result: dict) -> str:
return json.dumps(result, indent=4)
class GleanSearchAPIComponent(LCToolComponent):
display_name = "Glean Search API"
description = "Call Glean Search API"
@ -30,83 +113,6 @@ class GleanSearchAPIComponent(LCToolComponent):
NestedDictInput(name="request_options", display_name="Request Options", required=False),
]
class GleanAPIWrapper(BaseModel):
"""Wrapper around Glean API."""
glean_api_url: str
glean_access_token: str
act_as: str = "langflow-component@datastax.com" # TODO: Detect this
def _prepare_request(
self,
query: str,
page_size: int = 10,
request_options: dict[str, Any] | None = None,
) -> dict:
# Ensure there's a trailing slash
url = self.glean_api_url
if not url.endswith("/"):
url += "/"
return {
"url": urljoin(url, "search"),
"headers": {
"Authorization": f"Bearer {self.glean_access_token}",
"X-Scio-ActAs": self.act_as,
},
"payload": {
"query": query,
"pageSize": page_size,
"requestOptions": request_options,
},
}
def results(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
results = self._search_api_results(query, **kwargs)
if len(results) == 0:
msg = "No good Glean Search Result was found"
raise AssertionError(msg)
return results
def run(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
results = self.results(query, **kwargs)
processed_results = []
for result in results:
if "title" in result:
result["snippets"] = result.get("snippets", [{"snippet": {"text": result["title"]}}])
if "text" not in result["snippets"][0]:
result["snippets"][0]["text"] = result["title"]
processed_results.append(result)
return processed_results
def _search_api_results(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
request_details = self._prepare_request(query, **kwargs)
response = httpx.post(
request_details["url"],
json=request_details["payload"],
headers=request_details["headers"],
)
response.raise_for_status()
response_json = response.json()
return response_json.get("results", [])
@staticmethod
def _result_as_string(result: dict) -> str:
return json.dumps(result, indent=4)
class GleanSearchAPISchema(BaseModel):
query: str = Field(..., description="The search query")
page_size: int = Field(10, description="Maximum number of results to return")
request_options: dict[str, Any] | None = Field(default_factory=dict, description="Request Options")
def build_tool(self) -> Tool:
wrapper = self._build_wrapper(
glean_api_url=self.glean_api_url,
@ -117,7 +123,7 @@ class GleanSearchAPIComponent(LCToolComponent):
name="glean_search_api",
description="Search Glean for relevant results.",
func=wrapper.run,
args_schema=self.GleanSearchAPISchema,
args_schema=GleanSearchAPISchema,
)
self.status = "Glean Search API Tool for Langchain"
@ -137,7 +143,7 @@ class GleanSearchAPIComponent(LCToolComponent):
# Build the data
data = [Data(data=result, text=result["snippets"][0]["text"]) for result in results]
self.status = data
self.status = data # type: ignore[assignment]
return data
@ -146,7 +152,7 @@ class GleanSearchAPIComponent(LCToolComponent):
glean_api_url: str,
glean_access_token: str,
):
return self.GleanAPIWrapper(
return GleanAPIWrapper(
glean_api_url=glean_api_url,
glean_access_token=glean_access_token,
)

View file

@ -2,6 +2,7 @@ from typing import Any
from langchain.tools import StructuredTool
from langchain_community.utilities.serpapi import SerpAPIWrapper
from langchain_core.tools import ToolException
from loguru import logger
from pydantic import BaseModel, Field
@ -11,6 +12,23 @@ from langflow.inputs import DictInput, IntInput, MultilineInput, SecretStrInput
from langflow.schema import Data
class SerpAPISchema(BaseModel):
"""Schema for SerpAPI search parameters."""
query: str = Field(..., description="The search query")
params: dict[str, Any] | None = Field(
default={
"engine": "google",
"google_domain": "google.com",
"gl": "us",
"hl": "en",
},
description="Additional search parameters",
)
max_results: int = Field(5, description="Maximum number of results to return")
max_snippet_length: int = Field(100, description="Maximum length of each result snippet")
class SerpAPIComponent(LCToolComponent):
display_name = "Serp Search API"
description = "Call Serp Search API with result limiting"
@ -27,46 +45,50 @@ class SerpAPIComponent(LCToolComponent):
IntInput(name="max_snippet_length", display_name="Max Snippet Length", value=100, advanced=True),
]
class SerpAPISchema(BaseModel):
query: str = Field(..., description="The search query")
params: dict[str, Any] | None = Field(default_factory=dict, description="Additional search parameters")
max_results: int = Field(5, description="Maximum number of results to return")
max_snippet_length: int = Field(100, description="Maximum length of each result snippet")
def _build_wrapper(self) -> SerpAPIWrapper:
if self.search_params:
def _build_wrapper(self, params: dict[str, Any] | None = None) -> SerpAPIWrapper:
"""Build a SerpAPIWrapper with the provided parameters."""
params = params or {}
if params:
return SerpAPIWrapper(
serpapi_api_key=self.serpapi_api_key,
params=self.search_params,
params=params,
)
return SerpAPIWrapper(serpapi_api_key=self.serpapi_api_key)
def build_tool(self) -> Tool:
wrapper = self._build_wrapper()
wrapper = self._build_wrapper(self.search_params) # noqa: F841
def search_func(
query: str, params: dict[str, Any] | None = None, max_results: int = 5, max_snippet_length: int = 100
) -> list[dict[str, Any]]:
params = params or {}
full_results = wrapper.results(query, **params)
organic_results = full_results.get("organic_results", [])[:max_results]
try:
# rebuild the wrapper if params are provided
if params:
wrapper = self._build_wrapper(params)
limited_results = []
for result in organic_results:
limited_result = {
"title": result.get("title", "")[:max_snippet_length],
"link": result.get("link", ""),
"snippet": result.get("snippet", "")[:max_snippet_length],
}
limited_results.append(limited_result)
full_results = wrapper.results(query)
organic_results = full_results.get("organic_results", [])[:max_results]
limited_results = []
for result in organic_results:
limited_result = {
"title": result.get("title", "")[:max_snippet_length],
"link": result.get("link", ""),
"snippet": result.get("snippet", "")[:max_snippet_length],
}
limited_results.append(limited_result)
except Exception as e:
error_message = f"Error in SerpAPI search: {e!s}"
logger.debug(error_message)
raise ToolException(error_message) from e
return limited_results
tool = StructuredTool.from_function(
name="serp_search_api",
description="Search for recent results using SerpAPI with result limiting",
func=search_func,
args_schema=self.SerpAPISchema,
args_schema=SerpAPISchema,
)
self.status = "SerpAPI Tool created"
@ -91,5 +113,5 @@ class SerpAPIComponent(LCToolComponent):
self.status = f"Error: {e}"
return [Data(data={"error": str(e)}, text=str(e))]
self.status = data_list
self.status = data_list # type: ignore[assignment]
return data_list

View file

@ -1,7 +1,8 @@
from typing import Any
from enum import Enum
import httpx
from langchain.tools import StructuredTool
from langchain_core.tools import ToolException
from loguru import logger
from pydantic import BaseModel, Field
@ -11,6 +12,25 @@ from langflow.inputs import BoolInput, DropdownInput, IntInput, MessageTextInput
from langflow.schema import Data
class TavilySearchDepth(Enum):
BASIC = "basic"
ADVANCED = "advanced"
class TavilySearchTopic(Enum):
GENERAL = "general"
NEWS = "news"
class TavilySearchSchema(BaseModel):
query: str = Field(..., description="The search query you want to execute with Tavily.")
search_depth: TavilySearchDepth = Field(TavilySearchDepth.BASIC, description="The depth of the search.")
topic: TavilySearchTopic = Field(TavilySearchTopic.GENERAL, description="The category of the search.")
max_results: int = Field(5, description="The maximum number of search results to return.")
include_images: bool = Field(default=False, description="Include a list of query-related images in the response.")
include_answer: bool = Field(default=False, description="Include a short answer to original query.")
class TavilySearchToolComponent(LCToolComponent):
display_name = "Tavily AI Search"
description = """**Tavily AI** is a search engine optimized for LLMs and RAG, \
@ -38,16 +58,16 @@ Note: Check 'Advanced' for all options.
name="search_depth",
display_name="Search Depth",
info="The depth of the search.",
options=["basic", "advanced"],
value="advanced",
options=list(TavilySearchDepth),
value=TavilySearchDepth.ADVANCED,
advanced=True,
),
DropdownInput(
name="topic",
display_name="Search Topic",
info="The category of the search.",
options=["general", "news"],
value="general",
options=list(TavilySearchTopic),
value=TavilySearchTopic.GENERAL,
advanced=True,
),
IntInput(
@ -73,21 +93,32 @@ Note: Check 'Advanced' for all options.
),
]
class TavilySearchSchema(BaseModel):
query: str = Field(..., description="The search query you want to execute with Tavily.")
search_depth: str = Field("basic", description="The depth of the search.")
topic: str = Field("general", description="The category of the search.")
max_results: int = Field(5, description="The maximum number of search results to return.")
include_images: bool = Field(
default=False, description="Include a list of query-related images in the response."
)
include_answer: bool = Field(default=False, description="Include a short answer to original query.")
def run_model(self) -> list[Data]:
# Convert string values to enum instances with validation
try:
search_depth_enum = (
self.search_depth
if isinstance(self.search_depth, TavilySearchDepth)
else TavilySearchDepth(str(self.search_depth).lower())
)
except ValueError as e:
error_message = f"Invalid search depth value: {e!s}"
self.status = error_message
return [Data(data={"error": error_message})]
try:
topic_enum = (
self.topic if isinstance(self.topic, TavilySearchTopic) else TavilySearchTopic(str(self.topic).lower())
)
except ValueError as e:
error_message = f"Invalid topic value: {e!s}"
self.status = error_message
return [Data(data={"error": error_message})]
return self._tavily_search(
self.query,
search_depth=self.search_depth,
topic=self.topic,
search_depth=search_depth_enum,
topic=topic_enum,
max_results=self.max_results,
include_images=self.include_images,
include_answer=self.include_answer,
@ -98,19 +129,27 @@ Note: Check 'Advanced' for all options.
name="tavily_search",
description="Perform a web search using the Tavily API.",
func=self._tavily_search,
args_schema=self.TavilySearchSchema,
args_schema=TavilySearchSchema,
)
def _tavily_search(
self,
query: str,
*,
search_depth: str = "basic",
topic: str = "general",
search_depth: TavilySearchDepth = TavilySearchDepth.BASIC,
topic: TavilySearchTopic = TavilySearchTopic.GENERAL,
max_results: int = 5,
include_images: bool = False,
include_answer: bool = False,
) -> list[Data]:
# Validate enum values
if not isinstance(search_depth, TavilySearchDepth):
msg = f"Invalid search_depth value: {search_depth}"
raise TypeError(msg)
if not isinstance(topic, TavilySearchTopic):
msg = f"Invalid topic value: {topic}"
raise TypeError(msg)
try:
url = "https://api.tavily.com/search"
headers = {
@ -120,8 +159,8 @@ Note: Check 'Advanced' for all options.
payload = {
"api_key": self.api_key,
"query": query,
"search_depth": search_depth,
"topic": topic,
"search_depth": search_depth.value,
"topic": topic.value,
"max_results": max_results,
"include_images": include_images,
"include_answer": include_answer,
@ -151,15 +190,16 @@ Note: Check 'Advanced' for all options.
if include_images and search_results.get("images"):
data_results.append(Data(data={"images": search_results["images"]}))
self.status = data_results # type: ignore[assignment]
except httpx.HTTPStatusError as e:
error_message = f"HTTP error: {e.response.status_code} - {e.response.text}"
logger.debug(error_message)
self.status = error_message
return [Data(data={"error": error_message})]
except Exception as e: # noqa: BLE001
logger.opt(exception=True).debug("Error running Tavily Search")
raise ToolException(error_message) from e
except Exception as e:
error_message = f"Unexpected error: {e}"
logger.opt(exception=True).debug("Error running Tavily Search")
self.status = error_message
return [Data(data={"error": error_message})]
self.status: Any = data_results
raise ToolException(error_message) from e
return data_results

View file

@ -1,8 +1,10 @@
import ast
import pprint
from enum import Enum
import yfinance as yf
from langchain.tools import StructuredTool
from langchain_core.tools import ToolException
from loguru import logger
from pydantic import BaseModel, Field
@ -12,6 +14,40 @@ from langflow.inputs import DropdownInput, IntInput, MessageTextInput
from langflow.schema import Data
class YahooFinanceMethod(Enum):
GET_INFO = "get_info"
GET_NEWS = "get_news"
GET_ACTIONS = "get_actions"
GET_ANALYSIS = "get_analysis"
GET_BALANCE_SHEET = "get_balance_sheet"
GET_CALENDAR = "get_calendar"
GET_CASHFLOW = "get_cashflow"
GET_INSTITUTIONAL_HOLDERS = "get_institutional_holders"
GET_RECOMMENDATIONS = "get_recommendations"
GET_SUSTAINABILITY = "get_sustainability"
GET_MAJOR_HOLDERS = "get_major_holders"
GET_MUTUALFUND_HOLDERS = "get_mutualfund_holders"
GET_INSIDER_PURCHASES = "get_insider_purchases"
GET_INSIDER_TRANSACTIONS = "get_insider_transactions"
GET_INSIDER_ROSTER_HOLDERS = "get_insider_roster_holders"
GET_DIVIDENDS = "get_dividends"
GET_CAPITAL_GAINS = "get_capital_gains"
GET_SPLITS = "get_splits"
GET_SHARES = "get_shares"
GET_FAST_INFO = "get_fast_info"
GET_SEC_FILINGS = "get_sec_filings"
GET_RECOMMENDATIONS_SUMMARY = "get_recommendations_summary"
GET_UPGRADES_DOWNGRADES = "get_upgrades_downgrades"
GET_EARNINGS = "get_earnings"
GET_INCOME_STMT = "get_income_stmt"
class YahooFinanceSchema(BaseModel):
symbol: str = Field(..., description="The stock symbol to retrieve data for.")
method: YahooFinanceMethod = Field(YahooFinanceMethod.GET_INFO, description="The type of data to retrieve.")
num_news: int | None = Field(5, description="The number of news articles to retrieve.")
class YfinanceToolComponent(LCToolComponent):
display_name = "Yahoo Finance"
description = "Access financial data and market information using Yahoo Finance."
@ -23,24 +59,12 @@ class YfinanceToolComponent(LCToolComponent):
name="symbol",
display_name="Stock Symbol",
info="The stock symbol to retrieve data for (e.g., AAPL, GOOG).",
required=True,
),
DropdownInput(
name="method",
display_name="Data Method",
info="The type of data to retrieve.",
options=[
"get_actions",
"get_analysis",
"get_balance_sheet",
"get_calendar",
"get_cashflow",
"get_info",
"get_institutional_holders",
"get_news",
"get_recommendations",
"get_sustainability",
],
options=list(YahooFinanceMethod),
value="get_news",
),
IntInput(
@ -51,11 +75,6 @@ class YfinanceToolComponent(LCToolComponent):
),
]
class YahooFinanceSchema(BaseModel):
symbol: str = Field(..., description="The stock symbol to retrieve data for.")
method: str = Field("get_info", description="The type of data to retrieve.")
num_news: int = Field(5, description="The number of news articles to retrieve.")
def run_model(self) -> list[Data]:
return self._yahoo_finance_tool(
self.symbol,
@ -68,36 +87,36 @@ class YfinanceToolComponent(LCToolComponent):
name="yahoo_finance",
description="Access financial data and market information from Yahoo Finance.",
func=self._yahoo_finance_tool,
args_schema=self.YahooFinanceSchema,
args_schema=YahooFinanceSchema,
)
def _yahoo_finance_tool(
self,
symbol: str,
method: str,
method: YahooFinanceMethod,
num_news: int | None = 5,
) -> list[Data]:
ticker = yf.Ticker(symbol)
try:
if method == "get_info":
if method == YahooFinanceMethod.GET_INFO:
result = ticker.info
elif method == "get_news":
elif method == YahooFinanceMethod.GET_NEWS:
result = ticker.news[:num_news]
else:
result = getattr(ticker, method)()
result = getattr(ticker, method.value)()
result = pprint.pformat(result)
if method == "get_news":
if method == YahooFinanceMethod.GET_NEWS:
data_list = [Data(data=article) for article in ast.literal_eval(result)]
else:
data_list = [Data(data={"result": result})]
except Exception as e: # noqa: BLE001
except Exception as e:
error_message = f"Error retrieving data: {e}"
logger.opt(exception=True).debug(error_message)
logger.debug(error_message)
self.status = error_message
return [Data(data={"error": error_message})]
raise ToolException(error_message) from e
return data_list