refactor: Enhance tools with enums and improved error handling (#4493)
* fix: Enhance extract_class_name function to identify Component subclasses * Add TODO for improving Component inheritance check in validate.py * Add YahooFinanceMethod enum and improve error handling in Yahoo Finance tool - Introduced YahooFinanceMethod enum to standardize method options. - Updated YahooFinanceSchema to use the new enum for method selection. - Enhanced error handling by raising ToolException on data retrieval failure. - Refactored method handling in _yahoo_finance_tool to use enum values. * Enhance TavilySearchToolComponent with Enums and Improved Error Handling - Introduced `TavilySearchDepth` and `TavilySearchTopic` enums for better type safety and clarity. - Updated `TavilySearchSchema` to use enums for `search_depth` and `topic` fields. - Added validation for enum values in `run_model` and `_tavily_search` methods. - Improved error handling by raising `ToolException` for HTTP and unexpected errors. - Updated dropdown inputs to use enum options directly. * Add error handling and parameter flexibility to SerpAPI tool - Introduced `ToolException` for improved error handling in SerpAPI searches. - Added `SerpAPISchema` for structured search parameters. - Modified `_build_wrapper` to accept dynamic parameters. - Enhanced `search_func` to rebuild wrapper with new parameters and handle exceptions. * feat: Enhance Glean Search API integration Refactor the API wrapper and schema for better clarity and maintainability. Improve error handling for search results and streamline request preparation. * Add error handling to DuckDuckGo search function using ToolException --------- Co-authored-by: Eric Hare <ericrhare@gmail.com>
This commit is contained in:
parent
bbaec2b8da
commit
7dfce1dc63
5 changed files with 261 additions and 169 deletions
|
|
@ -2,6 +2,7 @@ from typing import Any
|
|||
|
||||
from langchain.tools import StructuredTool
|
||||
from langchain_community.tools import DuckDuckGoSearchRun
|
||||
from langchain_core.tools import ToolException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langflow.base.langchain_utilities.model import LCToolComponent
|
||||
|
|
@ -38,14 +39,18 @@ class DuckDuckGoSearchComponent(LCToolComponent):
|
|||
wrapper = self._build_wrapper()
|
||||
|
||||
def search_func(query: str, max_results: int = 5, max_snippet_length: int = 100) -> list[dict[str, Any]]:
|
||||
full_results = wrapper.run(f"{query} (site:*)")
|
||||
result_list = full_results.split("\n")[:max_results]
|
||||
limited_results = []
|
||||
for result in result_list:
|
||||
limited_result = {
|
||||
"snippet": result[:max_snippet_length],
|
||||
}
|
||||
limited_results.append(limited_result)
|
||||
try:
|
||||
full_results = wrapper.run(f"{query} (site:*)")
|
||||
result_list = full_results.split("\n")[:max_results]
|
||||
limited_results = []
|
||||
for result in result_list:
|
||||
limited_result = {
|
||||
"snippet": result[:max_snippet_length],
|
||||
}
|
||||
limited_results.append(limited_result)
|
||||
except Exception as e:
|
||||
msg = f"Error in DuckDuckGo Search: {e!s}"
|
||||
raise ToolException(msg) from e
|
||||
return limited_results
|
||||
|
||||
tool = StructuredTool.from_function(
|
||||
|
|
@ -67,5 +72,5 @@ class DuckDuckGoSearchComponent(LCToolComponent):
|
|||
}
|
||||
)
|
||||
data_list = [Data(data=result, text=result.get("snippet", "")) for result in results]
|
||||
self.status = data_list
|
||||
self.status = data_list # type: ignore[assignment]
|
||||
return data_list
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@ from typing import Any
|
|||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
from langchain.tools import StructuredTool
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.tools import StructuredTool, ToolException
|
||||
from pydantic import BaseModel
|
||||
from pydantic.v1 import Field
|
||||
|
||||
from langflow.base.langchain_utilities.model import LCToolComponent
|
||||
|
|
@ -13,6 +13,89 @@ from langflow.inputs import IntInput, MultilineInput, NestedDictInput, SecretStr
|
|||
from langflow.schema import Data
|
||||
|
||||
|
||||
class GleanSearchAPISchema(BaseModel):
|
||||
query: str = Field(..., description="The search query")
|
||||
page_size: int = Field(10, description="Maximum number of results to return")
|
||||
request_options: dict[str, Any] | None = Field(default_factory=dict, description="Request Options")
|
||||
|
||||
|
||||
class GleanAPIWrapper(BaseModel):
|
||||
"""Wrapper around Glean API."""
|
||||
|
||||
glean_api_url: str
|
||||
glean_access_token: str
|
||||
act_as: str = "langflow-component@datastax.com" # TODO: Detect this
|
||||
|
||||
def _prepare_request(
|
||||
self,
|
||||
query: str,
|
||||
page_size: int = 10,
|
||||
request_options: dict[str, Any] | None = None,
|
||||
) -> dict:
|
||||
# Ensure there's a trailing slash
|
||||
url = self.glean_api_url
|
||||
if not url.endswith("/"):
|
||||
url += "/"
|
||||
|
||||
return {
|
||||
"url": urljoin(url, "search"),
|
||||
"headers": {
|
||||
"Authorization": f"Bearer {self.glean_access_token}",
|
||||
"X-Scio-ActAs": self.act_as,
|
||||
},
|
||||
"payload": {
|
||||
"query": query,
|
||||
"pageSize": page_size,
|
||||
"requestOptions": request_options,
|
||||
},
|
||||
}
|
||||
|
||||
def results(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
|
||||
results = self._search_api_results(query, **kwargs)
|
||||
|
||||
if len(results) == 0:
|
||||
msg = "No good Glean Search Result was found"
|
||||
raise AssertionError(msg)
|
||||
|
||||
return results
|
||||
|
||||
def run(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
|
||||
try:
|
||||
results = self.results(query, **kwargs)
|
||||
|
||||
processed_results = []
|
||||
for result in results:
|
||||
if "title" in result:
|
||||
result["snippets"] = result.get("snippets", [{"snippet": {"text": result["title"]}}])
|
||||
if "text" not in result["snippets"][0]:
|
||||
result["snippets"][0]["text"] = result["title"]
|
||||
|
||||
processed_results.append(result)
|
||||
except Exception as e:
|
||||
error_message = f"Error in Glean Search API: {e!s}"
|
||||
raise ToolException(error_message) from e
|
||||
|
||||
return processed_results
|
||||
|
||||
def _search_api_results(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
|
||||
request_details = self._prepare_request(query, **kwargs)
|
||||
|
||||
response = httpx.post(
|
||||
request_details["url"],
|
||||
json=request_details["payload"],
|
||||
headers=request_details["headers"],
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
|
||||
return response_json.get("results", [])
|
||||
|
||||
@staticmethod
|
||||
def _result_as_string(result: dict) -> str:
|
||||
return json.dumps(result, indent=4)
|
||||
|
||||
|
||||
class GleanSearchAPIComponent(LCToolComponent):
|
||||
display_name = "Glean Search API"
|
||||
description = "Call Glean Search API"
|
||||
|
|
@ -30,83 +113,6 @@ class GleanSearchAPIComponent(LCToolComponent):
|
|||
NestedDictInput(name="request_options", display_name="Request Options", required=False),
|
||||
]
|
||||
|
||||
class GleanAPIWrapper(BaseModel):
|
||||
"""Wrapper around Glean API."""
|
||||
|
||||
glean_api_url: str
|
||||
glean_access_token: str
|
||||
act_as: str = "langflow-component@datastax.com" # TODO: Detect this
|
||||
|
||||
def _prepare_request(
|
||||
self,
|
||||
query: str,
|
||||
page_size: int = 10,
|
||||
request_options: dict[str, Any] | None = None,
|
||||
) -> dict:
|
||||
# Ensure there's a trailing slash
|
||||
url = self.glean_api_url
|
||||
if not url.endswith("/"):
|
||||
url += "/"
|
||||
|
||||
return {
|
||||
"url": urljoin(url, "search"),
|
||||
"headers": {
|
||||
"Authorization": f"Bearer {self.glean_access_token}",
|
||||
"X-Scio-ActAs": self.act_as,
|
||||
},
|
||||
"payload": {
|
||||
"query": query,
|
||||
"pageSize": page_size,
|
||||
"requestOptions": request_options,
|
||||
},
|
||||
}
|
||||
|
||||
def results(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
|
||||
results = self._search_api_results(query, **kwargs)
|
||||
|
||||
if len(results) == 0:
|
||||
msg = "No good Glean Search Result was found"
|
||||
raise AssertionError(msg)
|
||||
|
||||
return results
|
||||
|
||||
def run(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
|
||||
results = self.results(query, **kwargs)
|
||||
|
||||
processed_results = []
|
||||
for result in results:
|
||||
if "title" in result:
|
||||
result["snippets"] = result.get("snippets", [{"snippet": {"text": result["title"]}}])
|
||||
if "text" not in result["snippets"][0]:
|
||||
result["snippets"][0]["text"] = result["title"]
|
||||
|
||||
processed_results.append(result)
|
||||
|
||||
return processed_results
|
||||
|
||||
def _search_api_results(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
|
||||
request_details = self._prepare_request(query, **kwargs)
|
||||
|
||||
response = httpx.post(
|
||||
request_details["url"],
|
||||
json=request_details["payload"],
|
||||
headers=request_details["headers"],
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
|
||||
return response_json.get("results", [])
|
||||
|
||||
@staticmethod
|
||||
def _result_as_string(result: dict) -> str:
|
||||
return json.dumps(result, indent=4)
|
||||
|
||||
class GleanSearchAPISchema(BaseModel):
|
||||
query: str = Field(..., description="The search query")
|
||||
page_size: int = Field(10, description="Maximum number of results to return")
|
||||
request_options: dict[str, Any] | None = Field(default_factory=dict, description="Request Options")
|
||||
|
||||
def build_tool(self) -> Tool:
|
||||
wrapper = self._build_wrapper(
|
||||
glean_api_url=self.glean_api_url,
|
||||
|
|
@ -117,7 +123,7 @@ class GleanSearchAPIComponent(LCToolComponent):
|
|||
name="glean_search_api",
|
||||
description="Search Glean for relevant results.",
|
||||
func=wrapper.run,
|
||||
args_schema=self.GleanSearchAPISchema,
|
||||
args_schema=GleanSearchAPISchema,
|
||||
)
|
||||
|
||||
self.status = "Glean Search API Tool for Langchain"
|
||||
|
|
@ -137,7 +143,7 @@ class GleanSearchAPIComponent(LCToolComponent):
|
|||
|
||||
# Build the data
|
||||
data = [Data(data=result, text=result["snippets"][0]["text"]) for result in results]
|
||||
self.status = data
|
||||
self.status = data # type: ignore[assignment]
|
||||
|
||||
return data
|
||||
|
||||
|
|
@ -146,7 +152,7 @@ class GleanSearchAPIComponent(LCToolComponent):
|
|||
glean_api_url: str,
|
||||
glean_access_token: str,
|
||||
):
|
||||
return self.GleanAPIWrapper(
|
||||
return GleanAPIWrapper(
|
||||
glean_api_url=glean_api_url,
|
||||
glean_access_token=glean_access_token,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from typing import Any
|
|||
|
||||
from langchain.tools import StructuredTool
|
||||
from langchain_community.utilities.serpapi import SerpAPIWrapper
|
||||
from langchain_core.tools import ToolException
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -11,6 +12,23 @@ from langflow.inputs import DictInput, IntInput, MultilineInput, SecretStrInput
|
|||
from langflow.schema import Data
|
||||
|
||||
|
||||
class SerpAPISchema(BaseModel):
|
||||
"""Schema for SerpAPI search parameters."""
|
||||
|
||||
query: str = Field(..., description="The search query")
|
||||
params: dict[str, Any] | None = Field(
|
||||
default={
|
||||
"engine": "google",
|
||||
"google_domain": "google.com",
|
||||
"gl": "us",
|
||||
"hl": "en",
|
||||
},
|
||||
description="Additional search parameters",
|
||||
)
|
||||
max_results: int = Field(5, description="Maximum number of results to return")
|
||||
max_snippet_length: int = Field(100, description="Maximum length of each result snippet")
|
||||
|
||||
|
||||
class SerpAPIComponent(LCToolComponent):
|
||||
display_name = "Serp Search API"
|
||||
description = "Call Serp Search API with result limiting"
|
||||
|
|
@ -27,46 +45,50 @@ class SerpAPIComponent(LCToolComponent):
|
|||
IntInput(name="max_snippet_length", display_name="Max Snippet Length", value=100, advanced=True),
|
||||
]
|
||||
|
||||
class SerpAPISchema(BaseModel):
|
||||
query: str = Field(..., description="The search query")
|
||||
params: dict[str, Any] | None = Field(default_factory=dict, description="Additional search parameters")
|
||||
max_results: int = Field(5, description="Maximum number of results to return")
|
||||
max_snippet_length: int = Field(100, description="Maximum length of each result snippet")
|
||||
|
||||
def _build_wrapper(self) -> SerpAPIWrapper:
|
||||
if self.search_params:
|
||||
def _build_wrapper(self, params: dict[str, Any] | None = None) -> SerpAPIWrapper:
|
||||
"""Build a SerpAPIWrapper with the provided parameters."""
|
||||
params = params or {}
|
||||
if params:
|
||||
return SerpAPIWrapper(
|
||||
serpapi_api_key=self.serpapi_api_key,
|
||||
params=self.search_params,
|
||||
params=params,
|
||||
)
|
||||
return SerpAPIWrapper(serpapi_api_key=self.serpapi_api_key)
|
||||
|
||||
def build_tool(self) -> Tool:
|
||||
wrapper = self._build_wrapper()
|
||||
wrapper = self._build_wrapper(self.search_params) # noqa: F841
|
||||
|
||||
def search_func(
|
||||
query: str, params: dict[str, Any] | None = None, max_results: int = 5, max_snippet_length: int = 100
|
||||
) -> list[dict[str, Any]]:
|
||||
params = params or {}
|
||||
full_results = wrapper.results(query, **params)
|
||||
organic_results = full_results.get("organic_results", [])[:max_results]
|
||||
try:
|
||||
# rebuild the wrapper if params are provided
|
||||
if params:
|
||||
wrapper = self._build_wrapper(params)
|
||||
|
||||
limited_results = []
|
||||
for result in organic_results:
|
||||
limited_result = {
|
||||
"title": result.get("title", "")[:max_snippet_length],
|
||||
"link": result.get("link", ""),
|
||||
"snippet": result.get("snippet", "")[:max_snippet_length],
|
||||
}
|
||||
limited_results.append(limited_result)
|
||||
full_results = wrapper.results(query)
|
||||
organic_results = full_results.get("organic_results", [])[:max_results]
|
||||
|
||||
limited_results = []
|
||||
for result in organic_results:
|
||||
limited_result = {
|
||||
"title": result.get("title", "")[:max_snippet_length],
|
||||
"link": result.get("link", ""),
|
||||
"snippet": result.get("snippet", "")[:max_snippet_length],
|
||||
}
|
||||
limited_results.append(limited_result)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"Error in SerpAPI search: {e!s}"
|
||||
logger.debug(error_message)
|
||||
raise ToolException(error_message) from e
|
||||
return limited_results
|
||||
|
||||
tool = StructuredTool.from_function(
|
||||
name="serp_search_api",
|
||||
description="Search for recent results using SerpAPI with result limiting",
|
||||
func=search_func,
|
||||
args_schema=self.SerpAPISchema,
|
||||
args_schema=SerpAPISchema,
|
||||
)
|
||||
|
||||
self.status = "SerpAPI Tool created"
|
||||
|
|
@ -91,5 +113,5 @@ class SerpAPIComponent(LCToolComponent):
|
|||
self.status = f"Error: {e}"
|
||||
return [Data(data={"error": str(e)}, text=str(e))]
|
||||
|
||||
self.status = data_list
|
||||
self.status = data_list # type: ignore[assignment]
|
||||
return data_list
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from typing import Any
|
||||
from enum import Enum
|
||||
|
||||
import httpx
|
||||
from langchain.tools import StructuredTool
|
||||
from langchain_core.tools import ToolException
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -11,6 +12,25 @@ from langflow.inputs import BoolInput, DropdownInput, IntInput, MessageTextInput
|
|||
from langflow.schema import Data
|
||||
|
||||
|
||||
class TavilySearchDepth(Enum):
|
||||
BASIC = "basic"
|
||||
ADVANCED = "advanced"
|
||||
|
||||
|
||||
class TavilySearchTopic(Enum):
|
||||
GENERAL = "general"
|
||||
NEWS = "news"
|
||||
|
||||
|
||||
class TavilySearchSchema(BaseModel):
|
||||
query: str = Field(..., description="The search query you want to execute with Tavily.")
|
||||
search_depth: TavilySearchDepth = Field(TavilySearchDepth.BASIC, description="The depth of the search.")
|
||||
topic: TavilySearchTopic = Field(TavilySearchTopic.GENERAL, description="The category of the search.")
|
||||
max_results: int = Field(5, description="The maximum number of search results to return.")
|
||||
include_images: bool = Field(default=False, description="Include a list of query-related images in the response.")
|
||||
include_answer: bool = Field(default=False, description="Include a short answer to original query.")
|
||||
|
||||
|
||||
class TavilySearchToolComponent(LCToolComponent):
|
||||
display_name = "Tavily AI Search"
|
||||
description = """**Tavily AI** is a search engine optimized for LLMs and RAG, \
|
||||
|
|
@ -38,16 +58,16 @@ Note: Check 'Advanced' for all options.
|
|||
name="search_depth",
|
||||
display_name="Search Depth",
|
||||
info="The depth of the search.",
|
||||
options=["basic", "advanced"],
|
||||
value="advanced",
|
||||
options=list(TavilySearchDepth),
|
||||
value=TavilySearchDepth.ADVANCED,
|
||||
advanced=True,
|
||||
),
|
||||
DropdownInput(
|
||||
name="topic",
|
||||
display_name="Search Topic",
|
||||
info="The category of the search.",
|
||||
options=["general", "news"],
|
||||
value="general",
|
||||
options=list(TavilySearchTopic),
|
||||
value=TavilySearchTopic.GENERAL,
|
||||
advanced=True,
|
||||
),
|
||||
IntInput(
|
||||
|
|
@ -73,21 +93,32 @@ Note: Check 'Advanced' for all options.
|
|||
),
|
||||
]
|
||||
|
||||
class TavilySearchSchema(BaseModel):
|
||||
query: str = Field(..., description="The search query you want to execute with Tavily.")
|
||||
search_depth: str = Field("basic", description="The depth of the search.")
|
||||
topic: str = Field("general", description="The category of the search.")
|
||||
max_results: int = Field(5, description="The maximum number of search results to return.")
|
||||
include_images: bool = Field(
|
||||
default=False, description="Include a list of query-related images in the response."
|
||||
)
|
||||
include_answer: bool = Field(default=False, description="Include a short answer to original query.")
|
||||
|
||||
def run_model(self) -> list[Data]:
|
||||
# Convert string values to enum instances with validation
|
||||
try:
|
||||
search_depth_enum = (
|
||||
self.search_depth
|
||||
if isinstance(self.search_depth, TavilySearchDepth)
|
||||
else TavilySearchDepth(str(self.search_depth).lower())
|
||||
)
|
||||
except ValueError as e:
|
||||
error_message = f"Invalid search depth value: {e!s}"
|
||||
self.status = error_message
|
||||
return [Data(data={"error": error_message})]
|
||||
|
||||
try:
|
||||
topic_enum = (
|
||||
self.topic if isinstance(self.topic, TavilySearchTopic) else TavilySearchTopic(str(self.topic).lower())
|
||||
)
|
||||
except ValueError as e:
|
||||
error_message = f"Invalid topic value: {e!s}"
|
||||
self.status = error_message
|
||||
return [Data(data={"error": error_message})]
|
||||
|
||||
return self._tavily_search(
|
||||
self.query,
|
||||
search_depth=self.search_depth,
|
||||
topic=self.topic,
|
||||
search_depth=search_depth_enum,
|
||||
topic=topic_enum,
|
||||
max_results=self.max_results,
|
||||
include_images=self.include_images,
|
||||
include_answer=self.include_answer,
|
||||
|
|
@ -98,19 +129,27 @@ Note: Check 'Advanced' for all options.
|
|||
name="tavily_search",
|
||||
description="Perform a web search using the Tavily API.",
|
||||
func=self._tavily_search,
|
||||
args_schema=self.TavilySearchSchema,
|
||||
args_schema=TavilySearchSchema,
|
||||
)
|
||||
|
||||
def _tavily_search(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
search_depth: str = "basic",
|
||||
topic: str = "general",
|
||||
search_depth: TavilySearchDepth = TavilySearchDepth.BASIC,
|
||||
topic: TavilySearchTopic = TavilySearchTopic.GENERAL,
|
||||
max_results: int = 5,
|
||||
include_images: bool = False,
|
||||
include_answer: bool = False,
|
||||
) -> list[Data]:
|
||||
# Validate enum values
|
||||
if not isinstance(search_depth, TavilySearchDepth):
|
||||
msg = f"Invalid search_depth value: {search_depth}"
|
||||
raise TypeError(msg)
|
||||
if not isinstance(topic, TavilySearchTopic):
|
||||
msg = f"Invalid topic value: {topic}"
|
||||
raise TypeError(msg)
|
||||
|
||||
try:
|
||||
url = "https://api.tavily.com/search"
|
||||
headers = {
|
||||
|
|
@ -120,8 +159,8 @@ Note: Check 'Advanced' for all options.
|
|||
payload = {
|
||||
"api_key": self.api_key,
|
||||
"query": query,
|
||||
"search_depth": search_depth,
|
||||
"topic": topic,
|
||||
"search_depth": search_depth.value,
|
||||
"topic": topic.value,
|
||||
"max_results": max_results,
|
||||
"include_images": include_images,
|
||||
"include_answer": include_answer,
|
||||
|
|
@ -151,15 +190,16 @@ Note: Check 'Advanced' for all options.
|
|||
if include_images and search_results.get("images"):
|
||||
data_results.append(Data(data={"images": search_results["images"]}))
|
||||
|
||||
self.status = data_results # type: ignore[assignment]
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_message = f"HTTP error: {e.response.status_code} - {e.response.text}"
|
||||
logger.debug(error_message)
|
||||
self.status = error_message
|
||||
return [Data(data={"error": error_message})]
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.opt(exception=True).debug("Error running Tavily Search")
|
||||
raise ToolException(error_message) from e
|
||||
except Exception as e:
|
||||
error_message = f"Unexpected error: {e}"
|
||||
logger.opt(exception=True).debug("Error running Tavily Search")
|
||||
self.status = error_message
|
||||
return [Data(data={"error": error_message})]
|
||||
|
||||
self.status: Any = data_results
|
||||
raise ToolException(error_message) from e
|
||||
return data_results
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
import ast
|
||||
import pprint
|
||||
from enum import Enum
|
||||
|
||||
import yfinance as yf
|
||||
from langchain.tools import StructuredTool
|
||||
from langchain_core.tools import ToolException
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -12,6 +14,40 @@ from langflow.inputs import DropdownInput, IntInput, MessageTextInput
|
|||
from langflow.schema import Data
|
||||
|
||||
|
||||
class YahooFinanceMethod(Enum):
|
||||
GET_INFO = "get_info"
|
||||
GET_NEWS = "get_news"
|
||||
GET_ACTIONS = "get_actions"
|
||||
GET_ANALYSIS = "get_analysis"
|
||||
GET_BALANCE_SHEET = "get_balance_sheet"
|
||||
GET_CALENDAR = "get_calendar"
|
||||
GET_CASHFLOW = "get_cashflow"
|
||||
GET_INSTITUTIONAL_HOLDERS = "get_institutional_holders"
|
||||
GET_RECOMMENDATIONS = "get_recommendations"
|
||||
GET_SUSTAINABILITY = "get_sustainability"
|
||||
GET_MAJOR_HOLDERS = "get_major_holders"
|
||||
GET_MUTUALFUND_HOLDERS = "get_mutualfund_holders"
|
||||
GET_INSIDER_PURCHASES = "get_insider_purchases"
|
||||
GET_INSIDER_TRANSACTIONS = "get_insider_transactions"
|
||||
GET_INSIDER_ROSTER_HOLDERS = "get_insider_roster_holders"
|
||||
GET_DIVIDENDS = "get_dividends"
|
||||
GET_CAPITAL_GAINS = "get_capital_gains"
|
||||
GET_SPLITS = "get_splits"
|
||||
GET_SHARES = "get_shares"
|
||||
GET_FAST_INFO = "get_fast_info"
|
||||
GET_SEC_FILINGS = "get_sec_filings"
|
||||
GET_RECOMMENDATIONS_SUMMARY = "get_recommendations_summary"
|
||||
GET_UPGRADES_DOWNGRADES = "get_upgrades_downgrades"
|
||||
GET_EARNINGS = "get_earnings"
|
||||
GET_INCOME_STMT = "get_income_stmt"
|
||||
|
||||
|
||||
class YahooFinanceSchema(BaseModel):
|
||||
symbol: str = Field(..., description="The stock symbol to retrieve data for.")
|
||||
method: YahooFinanceMethod = Field(YahooFinanceMethod.GET_INFO, description="The type of data to retrieve.")
|
||||
num_news: int | None = Field(5, description="The number of news articles to retrieve.")
|
||||
|
||||
|
||||
class YfinanceToolComponent(LCToolComponent):
|
||||
display_name = "Yahoo Finance"
|
||||
description = "Access financial data and market information using Yahoo Finance."
|
||||
|
|
@ -23,24 +59,12 @@ class YfinanceToolComponent(LCToolComponent):
|
|||
name="symbol",
|
||||
display_name="Stock Symbol",
|
||||
info="The stock symbol to retrieve data for (e.g., AAPL, GOOG).",
|
||||
required=True,
|
||||
),
|
||||
DropdownInput(
|
||||
name="method",
|
||||
display_name="Data Method",
|
||||
info="The type of data to retrieve.",
|
||||
options=[
|
||||
"get_actions",
|
||||
"get_analysis",
|
||||
"get_balance_sheet",
|
||||
"get_calendar",
|
||||
"get_cashflow",
|
||||
"get_info",
|
||||
"get_institutional_holders",
|
||||
"get_news",
|
||||
"get_recommendations",
|
||||
"get_sustainability",
|
||||
],
|
||||
options=list(YahooFinanceMethod),
|
||||
value="get_news",
|
||||
),
|
||||
IntInput(
|
||||
|
|
@ -51,11 +75,6 @@ class YfinanceToolComponent(LCToolComponent):
|
|||
),
|
||||
]
|
||||
|
||||
class YahooFinanceSchema(BaseModel):
|
||||
symbol: str = Field(..., description="The stock symbol to retrieve data for.")
|
||||
method: str = Field("get_info", description="The type of data to retrieve.")
|
||||
num_news: int = Field(5, description="The number of news articles to retrieve.")
|
||||
|
||||
def run_model(self) -> list[Data]:
|
||||
return self._yahoo_finance_tool(
|
||||
self.symbol,
|
||||
|
|
@ -68,36 +87,36 @@ class YfinanceToolComponent(LCToolComponent):
|
|||
name="yahoo_finance",
|
||||
description="Access financial data and market information from Yahoo Finance.",
|
||||
func=self._yahoo_finance_tool,
|
||||
args_schema=self.YahooFinanceSchema,
|
||||
args_schema=YahooFinanceSchema,
|
||||
)
|
||||
|
||||
def _yahoo_finance_tool(
|
||||
self,
|
||||
symbol: str,
|
||||
method: str,
|
||||
method: YahooFinanceMethod,
|
||||
num_news: int | None = 5,
|
||||
) -> list[Data]:
|
||||
ticker = yf.Ticker(symbol)
|
||||
|
||||
try:
|
||||
if method == "get_info":
|
||||
if method == YahooFinanceMethod.GET_INFO:
|
||||
result = ticker.info
|
||||
elif method == "get_news":
|
||||
elif method == YahooFinanceMethod.GET_NEWS:
|
||||
result = ticker.news[:num_news]
|
||||
else:
|
||||
result = getattr(ticker, method)()
|
||||
result = getattr(ticker, method.value)()
|
||||
|
||||
result = pprint.pformat(result)
|
||||
|
||||
if method == "get_news":
|
||||
if method == YahooFinanceMethod.GET_NEWS:
|
||||
data_list = [Data(data=article) for article in ast.literal_eval(result)]
|
||||
else:
|
||||
data_list = [Data(data={"result": result})]
|
||||
|
||||
except Exception as e: # noqa: BLE001
|
||||
except Exception as e:
|
||||
error_message = f"Error retrieving data: {e}"
|
||||
logger.opt(exception=True).debug(error_message)
|
||||
logger.debug(error_message)
|
||||
self.status = error_message
|
||||
return [Data(data={"error": error_message})]
|
||||
raise ToolException(error_message) from e
|
||||
|
||||
return data_list
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue