refactor(yahoo-finance): Refactor Yahoo Finance API component to support tool mode (#5434)

* refactor(yahoo-finance): Refactor Yahoo Finance API component

* fix(yahoo-finance): Revert method input changes and fix enum error

* fix: update yfinance tool test and remove duplicate method

* fix: rename component class to avoid conflict with legacy version

* [autofix.ci] apply automated fixes

* test: add ToolException import to yfinance tool tests

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Edwin Jose <edwin.jose@datastax.com>
This commit is contained in:
Raphael Valdetaro 2025-01-20 12:07:37 -03:00 committed by GitHub
commit 06139ef2df
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 226 additions and 12 deletions

View file

@ -25,6 +25,7 @@ from .tavily_search import TavilySearchToolComponent
from .wikidata_api import WikidataAPIComponent
from .wikipedia_api import WikipediaAPIComponent
from .wolfram_alpha_api import WolframAlphaAPIComponent
from .yahoo import YfinanceComponent
from .yahoo_finance import YfinanceToolComponent
from .youtube_transcripts import YouTubeTranscriptsComponent
@ -59,6 +60,7 @@ __all__ = [
"WikidataAPIComponent",
"WikipediaAPIComponent",
"WolframAlphaAPIComponent",
"YfinanceComponent",
"YfinanceToolComponent",
"YouTubeTranscriptsComponent",
]

View file

@ -0,0 +1,142 @@
import ast
import pprint
from enum import Enum
import yfinance as yf
from langchain_core.tools import ToolException
from loguru import logger
from pydantic import BaseModel, Field
from langflow.custom import Component
from langflow.inputs import DropdownInput, IntInput, MessageTextInput
from langflow.io import Output
from langflow.schema import Data
from langflow.schema.message import Message
class YahooFinanceMethod(Enum):
GET_INFO = "get_info"
GET_NEWS = "get_news"
GET_ACTIONS = "get_actions"
GET_ANALYSIS = "get_analysis"
GET_BALANCE_SHEET = "get_balance_sheet"
GET_CALENDAR = "get_calendar"
GET_CASHFLOW = "get_cashflow"
GET_INSTITUTIONAL_HOLDERS = "get_institutional_holders"
GET_RECOMMENDATIONS = "get_recommendations"
GET_SUSTAINABILITY = "get_sustainability"
GET_MAJOR_HOLDERS = "get_major_holders"
GET_MUTUALFUND_HOLDERS = "get_mutualfund_holders"
GET_INSIDER_PURCHASES = "get_insider_purchases"
GET_INSIDER_TRANSACTIONS = "get_insider_transactions"
GET_INSIDER_ROSTER_HOLDERS = "get_insider_roster_holders"
GET_DIVIDENDS = "get_dividends"
GET_CAPITAL_GAINS = "get_capital_gains"
GET_SPLITS = "get_splits"
GET_SHARES = "get_shares"
GET_FAST_INFO = "get_fast_info"
GET_SEC_FILINGS = "get_sec_filings"
GET_RECOMMENDATIONS_SUMMARY = "get_recommendations_summary"
GET_UPGRADES_DOWNGRADES = "get_upgrades_downgrades"
GET_EARNINGS = "get_earnings"
GET_INCOME_STMT = "get_income_stmt"
class YahooFinanceSchema(BaseModel):
symbol: str = Field(..., description="The stock symbol to retrieve data for.")
method: YahooFinanceMethod = Field(YahooFinanceMethod.GET_INFO, description="The type of data to retrieve.")
num_news: int | None = Field(5, description="The number of news articles to retrieve.")
class YfinanceComponent(Component):
display_name = "Yahoo Finance"
description = """Uses [yfinance](https://pypi.org/project/yfinance/) (unofficial package) \
to access financial data and market information from Yahoo Finance."""
icon = "trending-up"
inputs = [
MessageTextInput(
name="symbol",
display_name="Stock Symbol",
info="The stock symbol to retrieve data for (e.g., AAPL, GOOG).",
tool_mode=True,
),
DropdownInput(
name="method",
display_name="Data Method",
info="The type of data to retrieve.",
options=list(YahooFinanceMethod),
value="get_news",
),
IntInput(
name="num_news",
display_name="Number of News",
info="The number of news articles to retrieve (only applicable for get_news).",
value=5,
),
]
outputs = [
Output(display_name="Data", name="data", method="fetch_content"),
Output(display_name="Text", name="text", method="fetch_content_text"),
]
def run_model(self) -> list[Data]:
return self.fetch_content()
def fetch_content_text(self) -> Message:
data = self.fetch_content()
result_string = ""
for item in data:
result_string += item.text + "\n"
self.status = result_string
return Message(text=result_string)
def _fetch_yfinance_data(self, ticker: yf.Ticker, method: YahooFinanceMethod, num_news: int | None) -> str:
try:
if method == YahooFinanceMethod.GET_INFO:
result = ticker.info
elif method == YahooFinanceMethod.GET_NEWS:
result = ticker.news[:num_news]
else:
result = getattr(ticker, method.value)()
return pprint.pformat(result)
except Exception as e:
error_message = f"Error retrieving data: {e}"
logger.debug(error_message)
self.status = error_message
raise ToolException(error_message) from e
def fetch_content(self) -> list[Data]:
try:
return self._yahoo_finance_tool(
self.symbol,
YahooFinanceMethod(self.method),
self.num_news,
)
except ToolException:
raise
except Exception as e:
error_message = f"Unexpected error: {e}"
logger.debug(error_message)
self.status = error_message
raise ToolException(error_message) from e
def _yahoo_finance_tool(
self,
symbol: str,
method: YahooFinanceMethod,
num_news: int | None = 5,
) -> list[Data]:
ticker = yf.Ticker(symbol)
result = self._fetch_yfinance_data(ticker, method, num_news)
if method == YahooFinanceMethod.GET_NEWS:
data_list = [
Data(text=f"{article['title']}: {article['link']}", data=article)
for article in ast.literal_eval(result)
]
else:
data_list = [Data(text=result, data={"result": result})]
return data_list

View file

@ -49,11 +49,12 @@ class YahooFinanceSchema(BaseModel):
class YfinanceToolComponent(LCToolComponent):
display_name = "Yahoo Finance"
display_name = "Yahoo Finance [DEPRECATED]"
description = """Uses [yfinance](https://pypi.org/project/yfinance/) (unofficial package) \
to access financial data and market information from Yahoo Finance."""
icon = "trending-up"
name = "YahooFinanceTool"
legacy = True
inputs = [
MessageTextInput(

View file

@ -1,14 +1,83 @@
from langflow.components.tools import YfinanceToolComponent
from langflow.custom import Component
from unittest.mock import MagicMock, patch
import pytest
from langchain_core.tools import ToolException
from langflow.components.tools import YfinanceComponent
from langflow.components.tools.yahoo import YahooFinanceMethod
from langflow.custom.utils import build_custom_component_template
from langflow.schema import Data
def test_yfinance_tool_template():
yf_tool = YfinanceToolComponent()
component = Component(_code=yf_tool._code)
frontend_node, _ = build_custom_component_template(component)
assert "outputs" in frontend_node
output_names = [output["name"] for output in frontend_node["outputs"]]
assert "api_run_model" in output_names
assert "api_build_tool" in output_names
assert all(output["types"] != [] for output in frontend_node["outputs"])
class TestYfinanceComponent:
@pytest.fixture
def component_class(self):
return YfinanceComponent
@pytest.fixture
def default_kwargs(self):
return {"symbol": "AAPL", "method": YahooFinanceMethod.GET_INFO, "num_news": 5, "_session_id": "test-session"}
@pytest.fixture
def file_names_mapping(self):
return []
def test_initialization(self, component_class):
component = component_class()
assert component.display_name == "Yahoo Finance"
assert component.icon == "trending-up"
assert "yfinance" in component.description
def test_template_structure(self, component_class):
component = component_class()
frontend_node, _ = build_custom_component_template(component)
assert "template" in frontend_node
input_names = [input_["name"] for input_ in frontend_node["template"].values() if isinstance(input_, dict)]
expected_inputs = ["symbol", "method", "num_news"]
for input_name in expected_inputs:
assert input_name in input_names
@patch("langflow.components.tools.yahoo.yf.Ticker")
def test_fetch_info(self, mock_ticker, component_class, default_kwargs):
component = component_class(**default_kwargs)
# Setup mock
mock_instance = MagicMock()
mock_ticker.return_value = mock_instance
mock_instance.info = {"companyName": "Apple Inc."}
result = component.fetch_content()
assert isinstance(result, list)
assert len(result) == 1
assert "Apple Inc." in result[0].text
@patch("langflow.components.tools.yahoo.yf.Ticker")
def test_fetch_news(self, mock_ticker, component_class):
component = component_class(symbol="AAPL", method=YahooFinanceMethod.GET_NEWS, num_news=2)
# Setup mock
mock_instance = MagicMock()
mock_ticker.return_value = mock_instance
mock_instance.news = [
{"title": "News 1", "link": "http://example.com/1"},
{"title": "News 2", "link": "http://example.com/2"},
]
result = component.fetch_content()
assert isinstance(result, list)
assert len(result) == 2
assert all(isinstance(item, Data) for item in result)
assert "News 1" in result[0].text
assert "http://example.com/1" in result[0].text
def test_error_handling(self, component_class, default_kwargs):
component = component_class(**default_kwargs)
with patch.object(component, "_fetch_yfinance_data") as mock_fetch:
mock_fetch.side_effect = Exception("API Error")
with pytest.raises(ToolException):
component.fetch_content()