refactor(yahoo-finance): Refactor Yahoo Finance API component to support tool mode (#5434)
* refactor(yahoo-finance): Refactor Yahoo Finance API component * fix(yahoo-finance): Revert method input changes and fix enum error * fix: update yfinance tool test and remove duplicate method * fix: rename component class to avoid conflict with legacy version * [autofix.ci] apply automated fixes * test: add ToolException import to yfinance tool tests * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Edwin Jose <edwin.jose@datastax.com>
This commit is contained in:
parent
18e21cfec9
commit
06139ef2df
4 changed files with 226 additions and 12 deletions
|
|
@ -25,6 +25,7 @@ from .tavily_search import TavilySearchToolComponent
|
|||
from .wikidata_api import WikidataAPIComponent
|
||||
from .wikipedia_api import WikipediaAPIComponent
|
||||
from .wolfram_alpha_api import WolframAlphaAPIComponent
|
||||
from .yahoo import YfinanceComponent
|
||||
from .yahoo_finance import YfinanceToolComponent
|
||||
from .youtube_transcripts import YouTubeTranscriptsComponent
|
||||
|
||||
|
|
@ -59,6 +60,7 @@ __all__ = [
|
|||
"WikidataAPIComponent",
|
||||
"WikipediaAPIComponent",
|
||||
"WolframAlphaAPIComponent",
|
||||
"YfinanceComponent",
|
||||
"YfinanceToolComponent",
|
||||
"YouTubeTranscriptsComponent",
|
||||
]
|
||||
|
|
|
|||
142
src/backend/base/langflow/components/tools/yahoo.py
Normal file
142
src/backend/base/langflow/components/tools/yahoo.py
Normal file
|
|
@ -0,0 +1,142 @@
|
|||
import ast
|
||||
import pprint
|
||||
from enum import Enum
|
||||
|
||||
import yfinance as yf
|
||||
from langchain_core.tools import ToolException
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langflow.custom import Component
|
||||
from langflow.inputs import DropdownInput, IntInput, MessageTextInput
|
||||
from langflow.io import Output
|
||||
from langflow.schema import Data
|
||||
from langflow.schema.message import Message
|
||||
|
||||
|
||||
class YahooFinanceMethod(Enum):
|
||||
GET_INFO = "get_info"
|
||||
GET_NEWS = "get_news"
|
||||
GET_ACTIONS = "get_actions"
|
||||
GET_ANALYSIS = "get_analysis"
|
||||
GET_BALANCE_SHEET = "get_balance_sheet"
|
||||
GET_CALENDAR = "get_calendar"
|
||||
GET_CASHFLOW = "get_cashflow"
|
||||
GET_INSTITUTIONAL_HOLDERS = "get_institutional_holders"
|
||||
GET_RECOMMENDATIONS = "get_recommendations"
|
||||
GET_SUSTAINABILITY = "get_sustainability"
|
||||
GET_MAJOR_HOLDERS = "get_major_holders"
|
||||
GET_MUTUALFUND_HOLDERS = "get_mutualfund_holders"
|
||||
GET_INSIDER_PURCHASES = "get_insider_purchases"
|
||||
GET_INSIDER_TRANSACTIONS = "get_insider_transactions"
|
||||
GET_INSIDER_ROSTER_HOLDERS = "get_insider_roster_holders"
|
||||
GET_DIVIDENDS = "get_dividends"
|
||||
GET_CAPITAL_GAINS = "get_capital_gains"
|
||||
GET_SPLITS = "get_splits"
|
||||
GET_SHARES = "get_shares"
|
||||
GET_FAST_INFO = "get_fast_info"
|
||||
GET_SEC_FILINGS = "get_sec_filings"
|
||||
GET_RECOMMENDATIONS_SUMMARY = "get_recommendations_summary"
|
||||
GET_UPGRADES_DOWNGRADES = "get_upgrades_downgrades"
|
||||
GET_EARNINGS = "get_earnings"
|
||||
GET_INCOME_STMT = "get_income_stmt"
|
||||
|
||||
|
||||
class YahooFinanceSchema(BaseModel):
|
||||
symbol: str = Field(..., description="The stock symbol to retrieve data for.")
|
||||
method: YahooFinanceMethod = Field(YahooFinanceMethod.GET_INFO, description="The type of data to retrieve.")
|
||||
num_news: int | None = Field(5, description="The number of news articles to retrieve.")
|
||||
|
||||
|
||||
class YfinanceComponent(Component):
|
||||
display_name = "Yahoo Finance"
|
||||
description = """Uses [yfinance](https://pypi.org/project/yfinance/) (unofficial package) \
|
||||
to access financial data and market information from Yahoo Finance."""
|
||||
icon = "trending-up"
|
||||
|
||||
inputs = [
|
||||
MessageTextInput(
|
||||
name="symbol",
|
||||
display_name="Stock Symbol",
|
||||
info="The stock symbol to retrieve data for (e.g., AAPL, GOOG).",
|
||||
tool_mode=True,
|
||||
),
|
||||
DropdownInput(
|
||||
name="method",
|
||||
display_name="Data Method",
|
||||
info="The type of data to retrieve.",
|
||||
options=list(YahooFinanceMethod),
|
||||
value="get_news",
|
||||
),
|
||||
IntInput(
|
||||
name="num_news",
|
||||
display_name="Number of News",
|
||||
info="The number of news articles to retrieve (only applicable for get_news).",
|
||||
value=5,
|
||||
),
|
||||
]
|
||||
|
||||
outputs = [
|
||||
Output(display_name="Data", name="data", method="fetch_content"),
|
||||
Output(display_name="Text", name="text", method="fetch_content_text"),
|
||||
]
|
||||
|
||||
def run_model(self) -> list[Data]:
|
||||
return self.fetch_content()
|
||||
|
||||
def fetch_content_text(self) -> Message:
|
||||
data = self.fetch_content()
|
||||
result_string = ""
|
||||
for item in data:
|
||||
result_string += item.text + "\n"
|
||||
self.status = result_string
|
||||
return Message(text=result_string)
|
||||
|
||||
def _fetch_yfinance_data(self, ticker: yf.Ticker, method: YahooFinanceMethod, num_news: int | None) -> str:
|
||||
try:
|
||||
if method == YahooFinanceMethod.GET_INFO:
|
||||
result = ticker.info
|
||||
elif method == YahooFinanceMethod.GET_NEWS:
|
||||
result = ticker.news[:num_news]
|
||||
else:
|
||||
result = getattr(ticker, method.value)()
|
||||
return pprint.pformat(result)
|
||||
except Exception as e:
|
||||
error_message = f"Error retrieving data: {e}"
|
||||
logger.debug(error_message)
|
||||
self.status = error_message
|
||||
raise ToolException(error_message) from e
|
||||
|
||||
def fetch_content(self) -> list[Data]:
|
||||
try:
|
||||
return self._yahoo_finance_tool(
|
||||
self.symbol,
|
||||
YahooFinanceMethod(self.method),
|
||||
self.num_news,
|
||||
)
|
||||
except ToolException:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_message = f"Unexpected error: {e}"
|
||||
logger.debug(error_message)
|
||||
self.status = error_message
|
||||
raise ToolException(error_message) from e
|
||||
|
||||
def _yahoo_finance_tool(
|
||||
self,
|
||||
symbol: str,
|
||||
method: YahooFinanceMethod,
|
||||
num_news: int | None = 5,
|
||||
) -> list[Data]:
|
||||
ticker = yf.Ticker(symbol)
|
||||
result = self._fetch_yfinance_data(ticker, method, num_news)
|
||||
|
||||
if method == YahooFinanceMethod.GET_NEWS:
|
||||
data_list = [
|
||||
Data(text=f"{article['title']}: {article['link']}", data=article)
|
||||
for article in ast.literal_eval(result)
|
||||
]
|
||||
else:
|
||||
data_list = [Data(text=result, data={"result": result})]
|
||||
|
||||
return data_list
|
||||
|
|
@ -49,11 +49,12 @@ class YahooFinanceSchema(BaseModel):
|
|||
|
||||
|
||||
class YfinanceToolComponent(LCToolComponent):
|
||||
display_name = "Yahoo Finance"
|
||||
display_name = "Yahoo Finance [DEPRECATED]"
|
||||
description = """Uses [yfinance](https://pypi.org/project/yfinance/) (unofficial package) \
|
||||
to access financial data and market information from Yahoo Finance."""
|
||||
icon = "trending-up"
|
||||
name = "YahooFinanceTool"
|
||||
legacy = True
|
||||
|
||||
inputs = [
|
||||
MessageTextInput(
|
||||
|
|
|
|||
|
|
@ -1,14 +1,83 @@
|
|||
from langflow.components.tools import YfinanceToolComponent
|
||||
from langflow.custom import Component
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.tools import ToolException
|
||||
from langflow.components.tools import YfinanceComponent
|
||||
from langflow.components.tools.yahoo import YahooFinanceMethod
|
||||
from langflow.custom.utils import build_custom_component_template
|
||||
from langflow.schema import Data
|
||||
|
||||
|
||||
def test_yfinance_tool_template():
|
||||
yf_tool = YfinanceToolComponent()
|
||||
component = Component(_code=yf_tool._code)
|
||||
frontend_node, _ = build_custom_component_template(component)
|
||||
assert "outputs" in frontend_node
|
||||
output_names = [output["name"] for output in frontend_node["outputs"]]
|
||||
assert "api_run_model" in output_names
|
||||
assert "api_build_tool" in output_names
|
||||
assert all(output["types"] != [] for output in frontend_node["outputs"])
|
||||
class TestYfinanceComponent:
|
||||
@pytest.fixture
|
||||
def component_class(self):
|
||||
return YfinanceComponent
|
||||
|
||||
@pytest.fixture
|
||||
def default_kwargs(self):
|
||||
return {"symbol": "AAPL", "method": YahooFinanceMethod.GET_INFO, "num_news": 5, "_session_id": "test-session"}
|
||||
|
||||
@pytest.fixture
|
||||
def file_names_mapping(self):
|
||||
return []
|
||||
|
||||
def test_initialization(self, component_class):
|
||||
component = component_class()
|
||||
assert component.display_name == "Yahoo Finance"
|
||||
assert component.icon == "trending-up"
|
||||
assert "yfinance" in component.description
|
||||
|
||||
def test_template_structure(self, component_class):
|
||||
component = component_class()
|
||||
frontend_node, _ = build_custom_component_template(component)
|
||||
|
||||
assert "template" in frontend_node
|
||||
input_names = [input_["name"] for input_ in frontend_node["template"].values() if isinstance(input_, dict)]
|
||||
|
||||
expected_inputs = ["symbol", "method", "num_news"]
|
||||
for input_name in expected_inputs:
|
||||
assert input_name in input_names
|
||||
|
||||
@patch("langflow.components.tools.yahoo.yf.Ticker")
|
||||
def test_fetch_info(self, mock_ticker, component_class, default_kwargs):
|
||||
component = component_class(**default_kwargs)
|
||||
|
||||
# Setup mock
|
||||
mock_instance = MagicMock()
|
||||
mock_ticker.return_value = mock_instance
|
||||
mock_instance.info = {"companyName": "Apple Inc."}
|
||||
|
||||
result = component.fetch_content()
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 1
|
||||
assert "Apple Inc." in result[0].text
|
||||
|
||||
@patch("langflow.components.tools.yahoo.yf.Ticker")
|
||||
def test_fetch_news(self, mock_ticker, component_class):
|
||||
component = component_class(symbol="AAPL", method=YahooFinanceMethod.GET_NEWS, num_news=2)
|
||||
|
||||
# Setup mock
|
||||
mock_instance = MagicMock()
|
||||
mock_ticker.return_value = mock_instance
|
||||
mock_instance.news = [
|
||||
{"title": "News 1", "link": "http://example.com/1"},
|
||||
{"title": "News 2", "link": "http://example.com/2"},
|
||||
]
|
||||
|
||||
result = component.fetch_content()
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 2
|
||||
assert all(isinstance(item, Data) for item in result)
|
||||
assert "News 1" in result[0].text
|
||||
assert "http://example.com/1" in result[0].text
|
||||
|
||||
def test_error_handling(self, component_class, default_kwargs):
|
||||
component = component_class(**default_kwargs)
|
||||
|
||||
with patch.object(component, "_fetch_yfinance_data") as mock_fetch:
|
||||
mock_fetch.side_effect = Exception("API Error")
|
||||
|
||||
with pytest.raises(ToolException):
|
||||
component.fetch_content()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue