diff --git a/src/backend/base/langflow/components/tools/__init__.py b/src/backend/base/langflow/components/tools/__init__.py index e405ce6f4..deda86293 100644 --- a/src/backend/base/langflow/components/tools/__init__.py +++ b/src/backend/base/langflow/components/tools/__init__.py @@ -25,6 +25,7 @@ from .tavily_search import TavilySearchToolComponent from .wikidata_api import WikidataAPIComponent from .wikipedia_api import WikipediaAPIComponent from .wolfram_alpha_api import WolframAlphaAPIComponent +from .yahoo import YfinanceComponent from .yahoo_finance import YfinanceToolComponent from .youtube_transcripts import YouTubeTranscriptsComponent @@ -59,6 +60,7 @@ __all__ = [ "WikidataAPIComponent", "WikipediaAPIComponent", "WolframAlphaAPIComponent", + "YfinanceComponent", "YfinanceToolComponent", "YouTubeTranscriptsComponent", ] diff --git a/src/backend/base/langflow/components/tools/yahoo.py b/src/backend/base/langflow/components/tools/yahoo.py new file mode 100644 index 000000000..5874d5154 --- /dev/null +++ b/src/backend/base/langflow/components/tools/yahoo.py @@ -0,0 +1,142 @@ +import ast +import pprint +from enum import Enum + +import yfinance as yf +from langchain_core.tools import ToolException +from loguru import logger +from pydantic import BaseModel, Field + +from langflow.custom import Component +from langflow.inputs import DropdownInput, IntInput, MessageTextInput +from langflow.io import Output +from langflow.schema import Data +from langflow.schema.message import Message + + +class YahooFinanceMethod(Enum): + GET_INFO = "get_info" + GET_NEWS = "get_news" + GET_ACTIONS = "get_actions" + GET_ANALYSIS = "get_analysis" + GET_BALANCE_SHEET = "get_balance_sheet" + GET_CALENDAR = "get_calendar" + GET_CASHFLOW = "get_cashflow" + GET_INSTITUTIONAL_HOLDERS = "get_institutional_holders" + GET_RECOMMENDATIONS = "get_recommendations" + GET_SUSTAINABILITY = "get_sustainability" + GET_MAJOR_HOLDERS = "get_major_holders" + GET_MUTUALFUND_HOLDERS = "get_mutualfund_holders" + GET_INSIDER_PURCHASES = "get_insider_purchases" + GET_INSIDER_TRANSACTIONS = "get_insider_transactions" + GET_INSIDER_ROSTER_HOLDERS = "get_insider_roster_holders" + GET_DIVIDENDS = "get_dividends" + GET_CAPITAL_GAINS = "get_capital_gains" + GET_SPLITS = "get_splits" + GET_SHARES = "get_shares" + GET_FAST_INFO = "get_fast_info" + GET_SEC_FILINGS = "get_sec_filings" + GET_RECOMMENDATIONS_SUMMARY = "get_recommendations_summary" + GET_UPGRADES_DOWNGRADES = "get_upgrades_downgrades" + GET_EARNINGS = "get_earnings" + GET_INCOME_STMT = "get_income_stmt" + + +class YahooFinanceSchema(BaseModel): + symbol: str = Field(..., description="The stock symbol to retrieve data for.") + method: YahooFinanceMethod = Field(YahooFinanceMethod.GET_INFO, description="The type of data to retrieve.") + num_news: int | None = Field(5, description="The number of news articles to retrieve.") + + +class YfinanceComponent(Component): + display_name = "Yahoo Finance" + description = """Uses [yfinance](https://pypi.org/project/yfinance/) (unofficial package) \ +to access financial data and market information from Yahoo Finance.""" + icon = "trending-up" + + inputs = [ + MessageTextInput( + name="symbol", + display_name="Stock Symbol", + info="The stock symbol to retrieve data for (e.g., AAPL, GOOG).", + tool_mode=True, + ), + DropdownInput( + name="method", + display_name="Data Method", + info="The type of data to retrieve.", + options=list(YahooFinanceMethod), + value="get_news", + ), + IntInput( + name="num_news", + display_name="Number of News", + info="The number of news articles to retrieve (only applicable for get_news).", + value=5, + ), + ] + + outputs = [ + Output(display_name="Data", name="data", method="fetch_content"), + Output(display_name="Text", name="text", method="fetch_content_text"), + ] + + def run_model(self) -> list[Data]: + return self.fetch_content() + + def fetch_content_text(self) -> Message: + data = self.fetch_content() + result_string = "" + for item in data: + result_string += item.text + "\n" + self.status = result_string + return Message(text=result_string) + + def _fetch_yfinance_data(self, ticker: yf.Ticker, method: YahooFinanceMethod, num_news: int | None) -> str: + try: + if method == YahooFinanceMethod.GET_INFO: + result = ticker.info + elif method == YahooFinanceMethod.GET_NEWS: + result = ticker.news[:num_news] + else: + result = getattr(ticker, method.value)() + return pprint.pformat(result) + except Exception as e: + error_message = f"Error retrieving data: {e}" + logger.debug(error_message) + self.status = error_message + raise ToolException(error_message) from e + + def fetch_content(self) -> list[Data]: + try: + return self._yahoo_finance_tool( + self.symbol, + YahooFinanceMethod(self.method), + self.num_news, + ) + except ToolException: + raise + except Exception as e: + error_message = f"Unexpected error: {e}" + logger.debug(error_message) + self.status = error_message + raise ToolException(error_message) from e + + def _yahoo_finance_tool( + self, + symbol: str, + method: YahooFinanceMethod, + num_news: int | None = 5, + ) -> list[Data]: + ticker = yf.Ticker(symbol) + result = self._fetch_yfinance_data(ticker, method, num_news) + + if method == YahooFinanceMethod.GET_NEWS: + data_list = [ + Data(text=f"{article['title']}: {article['link']}", data=article) + for article in ast.literal_eval(result) + ] + else: + data_list = [Data(text=result, data={"result": result})] + + return data_list diff --git a/src/backend/base/langflow/components/tools/yahoo_finance.py b/src/backend/base/langflow/components/tools/yahoo_finance.py index 110c890d5..7f7e4a09e 100644 --- a/src/backend/base/langflow/components/tools/yahoo_finance.py +++ b/src/backend/base/langflow/components/tools/yahoo_finance.py @@ -49,11 +49,12 @@ class YahooFinanceSchema(BaseModel): class YfinanceToolComponent(LCToolComponent): - display_name = "Yahoo Finance" + display_name = "Yahoo Finance [DEPRECATED]" description = """Uses [yfinance](https://pypi.org/project/yfinance/) (unofficial package) \ to access financial data and market information from Yahoo Finance.""" icon = "trending-up" name = "YahooFinanceTool" + legacy = True inputs = [ MessageTextInput( diff --git a/src/backend/tests/unit/components/tools/test_yfinance_tool.py b/src/backend/tests/unit/components/tools/test_yfinance_tool.py index 29133929f..1d43a2ab9 100644 --- a/src/backend/tests/unit/components/tools/test_yfinance_tool.py +++ b/src/backend/tests/unit/components/tools/test_yfinance_tool.py @@ -1,14 +1,83 @@ -from langflow.components.tools import YfinanceToolComponent -from langflow.custom import Component +from unittest.mock import MagicMock, patch + +import pytest +from langchain_core.tools import ToolException +from langflow.components.tools import YfinanceComponent +from langflow.components.tools.yahoo import YahooFinanceMethod from langflow.custom.utils import build_custom_component_template +from langflow.schema import Data -def test_yfinance_tool_template(): - yf_tool = YfinanceToolComponent() - component = Component(_code=yf_tool._code) - frontend_node, _ = build_custom_component_template(component) - assert "outputs" in frontend_node - output_names = [output["name"] for output in frontend_node["outputs"]] - assert "api_run_model" in output_names - assert "api_build_tool" in output_names - assert all(output["types"] != [] for output in frontend_node["outputs"]) +class TestYfinanceComponent: + @pytest.fixture + def component_class(self): + return YfinanceComponent + + @pytest.fixture + def default_kwargs(self): + return {"symbol": "AAPL", "method": YahooFinanceMethod.GET_INFO, "num_news": 5, "_session_id": "test-session"} + + @pytest.fixture + def file_names_mapping(self): + return [] + + def test_initialization(self, component_class): + component = component_class() + assert component.display_name == "Yahoo Finance" + assert component.icon == "trending-up" + assert "yfinance" in component.description + + def test_template_structure(self, component_class): + component = component_class() + frontend_node, _ = build_custom_component_template(component) + + assert "template" in frontend_node + input_names = [input_["name"] for input_ in frontend_node["template"].values() if isinstance(input_, dict)] + + expected_inputs = ["symbol", "method", "num_news"] + for input_name in expected_inputs: + assert input_name in input_names + + @patch("langflow.components.tools.yahoo.yf.Ticker") + def test_fetch_info(self, mock_ticker, component_class, default_kwargs): + component = component_class(**default_kwargs) + + # Setup mock + mock_instance = MagicMock() + mock_ticker.return_value = mock_instance + mock_instance.info = {"companyName": "Apple Inc."} + + result = component.fetch_content() + + assert isinstance(result, list) + assert len(result) == 1 + assert "Apple Inc." in result[0].text + + @patch("langflow.components.tools.yahoo.yf.Ticker") + def test_fetch_news(self, mock_ticker, component_class): + component = component_class(symbol="AAPL", method=YahooFinanceMethod.GET_NEWS, num_news=2) + + # Setup mock + mock_instance = MagicMock() + mock_ticker.return_value = mock_instance + mock_instance.news = [ + {"title": "News 1", "link": "http://example.com/1"}, + {"title": "News 2", "link": "http://example.com/2"}, + ] + + result = component.fetch_content() + + assert isinstance(result, list) + assert len(result) == 2 + assert all(isinstance(item, Data) for item in result) + assert "News 1" in result[0].text + assert "http://example.com/1" in result[0].text + + def test_error_handling(self, component_class, default_kwargs): + component = component_class(**default_kwargs) + + with patch.object(component, "_fetch_yfinance_data") as mock_fetch: + mock_fetch.side_effect = Exception("API Error") + + with pytest.raises(ToolException): + component.fetch_content()