feat(YahooFinanceTool): enhance tool with new inputs for data retrieval methods (#3738)

* feat(YahooFinanceTool): enhance tool with new inputs for data retrieval methods

* test: fix test

* test: fix test units

* test: fix import

* fix: rename component

* Fix instantiation of YfinanceToolComponent in complex_agent.py

---------

Co-authored-by: italojohnny <italojohnnydosanjos@gmail.com>
Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
namastex888 2024-10-03 12:33:28 -03:00 committed by GitHub
commit 11283655fe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 88 additions and 22 deletions

View file

@ -1,35 +1,101 @@
from typing import cast
import ast
import pprint
from langchain_community.tools.yahoo_finance_news import YahooFinanceNewsTool
import yfinance as yf
from langchain.tools import StructuredTool
from pydantic import BaseModel, Field
from langflow.base.langchain_utilities.model import LCToolComponent
from langflow.field_typing import Data, Tool
from langflow.inputs.inputs import MessageTextInput
from langflow.template.field.base import Output
from langflow.field_typing import Tool
from langflow.inputs import DropdownInput, IntInput, MessageTextInput
from langflow.schema import Data
class YfinanceToolComponent(LCToolComponent):
display_name = "Yahoo Finance News Tool"
description = "Tool for interacting with Yahoo Finance News."
name = "YFinanceTool"
display_name = "Yahoo Finance Tool"
description = "Access financial data and market information using Yahoo Finance."
icon = "trending-up"
name = "YahooFinanceTool"
inputs = [
MessageTextInput(
name="input_value",
display_name="Query",
info="Input should be a company ticker. For example, AAPL for Apple, MSFT for Microsoft.",
)
name="symbol",
display_name="Stock Symbol",
info="The stock symbol to retrieve data for (e.g., AAPL, GOOG).",
required=True,
),
DropdownInput(
name="method",
display_name="Data Method",
info="The type of data to retrieve.",
options=[
"get_actions",
"get_analysis",
"get_balance_sheet",
"get_calendar",
"get_cashflow",
"get_info",
"get_institutional_holders",
"get_news",
"get_recommendations",
"get_sustainability",
],
value="get_news",
),
IntInput(
name="num_news",
display_name="Number of News",
info="The number of news articles to retrieve (only applicable for get_news).",
value=5,
),
]
outputs = [
Output(name="api_run_model", display_name="Data", method="run_model"),
# Keep this for backwards compatibility
Output(name="tool", display_name="Tool", method="build_tool"),
]
class YahooFinanceSchema(BaseModel):
symbol: str = Field(..., description="The stock symbol to retrieve data for.")
method: str = Field("get_info", description="The type of data to retrieve.")
num_news: int | None = Field(5, description="The number of news articles to retrieve.")
def run_model(self) -> list[Data]:
return self._yahoo_finance_tool(
self.symbol,
self.method,
self.num_news,
)
def build_tool(self) -> Tool:
return cast(Tool, YahooFinanceNewsTool())
return StructuredTool.from_function(
name="yahoo_finance",
description="Access financial data and market information from Yahoo Finance.",
func=self._yahoo_finance_tool,
args_schema=self.YahooFinanceSchema,
)
def run_model(self) -> Data:
tool = self.build_tool()
return tool.run(self.input_value)
def _yahoo_finance_tool(
self,
symbol: str,
method: str,
num_news: int | None = 5,
) -> list[Data]:
ticker = yf.Ticker(symbol)
try:
if method == "get_info":
result = ticker.info
elif method == "get_news":
result = ticker.news[:num_news]
else:
result = getattr(ticker, method)()
result = pprint.pformat(result)
if method == "get_news":
data_list = [Data(data=article) for article in ast.literal_eval(result)]
else:
data_list = [Data(data={"result": result})]
return data_list
except Exception as e:
error_message = f"Error retrieving data: {str(e)}"
self.status = error_message
return [Data(data={"error": error_message})]

View file

@ -17,5 +17,5 @@ def test_yfinance_tool_template():
assert "outputs" in frontend_node
output_names = [output["name"] for output in frontend_node["outputs"]]
assert "api_run_model" in output_names
assert "tool" in output_names
assert "api_build_tool" in output_names
assert all(output["types"] != [] for output in frontend_node["outputs"])