feat(YahooFinanceTool): enhance tool with new inputs for data retrieval methods (#3738)
* feat(YahooFinanceTool): enhance tool with new inputs for data retrieval methods * test: fix test * test: fix test units * test: fix import * fix: rename component * Fix instantiation of YfinanceToolComponent in complex_agent.py --------- Co-authored-by: italojohnny <italojohnnydosanjos@gmail.com> Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
parent
b591d7105e
commit
11283655fe
2 changed files with 88 additions and 22 deletions
|
|
@ -1,35 +1,101 @@
|
|||
from typing import cast
|
||||
import ast
|
||||
import pprint
|
||||
|
||||
from langchain_community.tools.yahoo_finance_news import YahooFinanceNewsTool
|
||||
import yfinance as yf
|
||||
from langchain.tools import StructuredTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langflow.base.langchain_utilities.model import LCToolComponent
|
||||
from langflow.field_typing import Data, Tool
|
||||
from langflow.inputs.inputs import MessageTextInput
|
||||
from langflow.template.field.base import Output
|
||||
from langflow.field_typing import Tool
|
||||
from langflow.inputs import DropdownInput, IntInput, MessageTextInput
|
||||
from langflow.schema import Data
|
||||
|
||||
|
||||
class YfinanceToolComponent(LCToolComponent):
|
||||
display_name = "Yahoo Finance News Tool"
|
||||
description = "Tool for interacting with Yahoo Finance News."
|
||||
name = "YFinanceTool"
|
||||
display_name = "Yahoo Finance Tool"
|
||||
description = "Access financial data and market information using Yahoo Finance."
|
||||
icon = "trending-up"
|
||||
name = "YahooFinanceTool"
|
||||
|
||||
inputs = [
|
||||
MessageTextInput(
|
||||
name="input_value",
|
||||
display_name="Query",
|
||||
info="Input should be a company ticker. For example, AAPL for Apple, MSFT for Microsoft.",
|
||||
)
|
||||
name="symbol",
|
||||
display_name="Stock Symbol",
|
||||
info="The stock symbol to retrieve data for (e.g., AAPL, GOOG).",
|
||||
required=True,
|
||||
),
|
||||
DropdownInput(
|
||||
name="method",
|
||||
display_name="Data Method",
|
||||
info="The type of data to retrieve.",
|
||||
options=[
|
||||
"get_actions",
|
||||
"get_analysis",
|
||||
"get_balance_sheet",
|
||||
"get_calendar",
|
||||
"get_cashflow",
|
||||
"get_info",
|
||||
"get_institutional_holders",
|
||||
"get_news",
|
||||
"get_recommendations",
|
||||
"get_sustainability",
|
||||
],
|
||||
value="get_news",
|
||||
),
|
||||
IntInput(
|
||||
name="num_news",
|
||||
display_name="Number of News",
|
||||
info="The number of news articles to retrieve (only applicable for get_news).",
|
||||
value=5,
|
||||
),
|
||||
]
|
||||
|
||||
outputs = [
|
||||
Output(name="api_run_model", display_name="Data", method="run_model"),
|
||||
# Keep this for backwards compatibility
|
||||
Output(name="tool", display_name="Tool", method="build_tool"),
|
||||
]
|
||||
class YahooFinanceSchema(BaseModel):
|
||||
symbol: str = Field(..., description="The stock symbol to retrieve data for.")
|
||||
method: str = Field("get_info", description="The type of data to retrieve.")
|
||||
num_news: int | None = Field(5, description="The number of news articles to retrieve.")
|
||||
|
||||
def run_model(self) -> list[Data]:
|
||||
return self._yahoo_finance_tool(
|
||||
self.symbol,
|
||||
self.method,
|
||||
self.num_news,
|
||||
)
|
||||
|
||||
def build_tool(self) -> Tool:
|
||||
return cast(Tool, YahooFinanceNewsTool())
|
||||
return StructuredTool.from_function(
|
||||
name="yahoo_finance",
|
||||
description="Access financial data and market information from Yahoo Finance.",
|
||||
func=self._yahoo_finance_tool,
|
||||
args_schema=self.YahooFinanceSchema,
|
||||
)
|
||||
|
||||
def run_model(self) -> Data:
|
||||
tool = self.build_tool()
|
||||
return tool.run(self.input_value)
|
||||
def _yahoo_finance_tool(
|
||||
self,
|
||||
symbol: str,
|
||||
method: str,
|
||||
num_news: int | None = 5,
|
||||
) -> list[Data]:
|
||||
ticker = yf.Ticker(symbol)
|
||||
|
||||
try:
|
||||
if method == "get_info":
|
||||
result = ticker.info
|
||||
elif method == "get_news":
|
||||
result = ticker.news[:num_news]
|
||||
else:
|
||||
result = getattr(ticker, method)()
|
||||
|
||||
result = pprint.pformat(result)
|
||||
|
||||
if method == "get_news":
|
||||
data_list = [Data(data=article) for article in ast.literal_eval(result)]
|
||||
else:
|
||||
data_list = [Data(data={"result": result})]
|
||||
|
||||
return data_list
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"Error retrieving data: {str(e)}"
|
||||
self.status = error_message
|
||||
return [Data(data={"error": error_message})]
|
||||
|
|
|
|||
|
|
@ -17,5 +17,5 @@ def test_yfinance_tool_template():
|
|||
assert "outputs" in frontend_node
|
||||
output_names = [output["name"] for output in frontend_node["outputs"]]
|
||||
assert "api_run_model" in output_names
|
||||
assert "tool" in output_names
|
||||
assert "api_build_tool" in output_names
|
||||
assert all(output["types"] != [] for output in frontend_node["outputs"])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue