Refactored and improved PythonCodeStructuredTool, SearXNGTool, and RunnableExecutor (#3239)

* enhancement: Update PythonCodeStructuredTool to create inputs automatically and accept global variables

* [autofix.ci] apply automated fixes

* feat: Create a tool to search using SearXNG

* [autofix.ci] apply automated fixes

* refactor: reorganize imports and type annotations in PythonCodeStructuredTool.py for clarity and consistency

* refactor: clean up imports and enhance type annotations in SearXNGTool.py for improved readability and type safety

* refactor: Improved PythonCodeStructuredTool to allow arguments to have any types

* refactor: Formatted and refactored SearXNGTool

* refactor: Allowed RunnableExecutor to stream output and made its build method asynchronous.

---------

Co-authored-by: Haseong Kim <dynaferkim@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
This commit is contained in:
goliath-yamon 2024-08-09 16:30:23 +01:00 committed by GitHub
commit 5ce8cbda9b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 448 additions and 74 deletions

View file

@ -1,7 +1,8 @@
from langflow.custom import Component
from langflow.inputs import HandleInput, MessageTextInput
from langflow.inputs import HandleInput, MessageTextInput, BoolInput
from langflow.schema.message import Message
from langflow.template import Output
from langchain.agents import AgentExecutor
class RunnableExecComponent(Component):
@ -30,6 +31,11 @@ class RunnableExecComponent(Component):
value="output",
advanced=True,
),
BoolInput(
name="use_stream",
display_name="Stream",
value=False,
),
]
outputs = [
@ -108,11 +114,24 @@ class RunnableExecComponent(Component):
status = f"Warning: The input key is not '{input_key}'. The input key is '{runnable.input_keys}'."
return input_dict, status
# Run the configured runnable on the component's input and return its output.
# NOTE(review): this view interleaves pre- and post-change diff lines with the
# indentation stripped. The synchronous `def` and the `invoke(...)` call below
# look like the superseded versions of the `async def` / `ainvoke(...)` lines;
# confirm against the commit before relying on either.
def build_executor(self) -> Message:
async def build_executor(self) -> Message:
# get_input_dict returns the payload for the runnable plus a status message.
input_dict, status = self.get_input_dict(self.runnable, self.input_key, self.input_value)
result = self.runnable.invoke(input_dict)
# Anything that is not an AgentExecutor is rejected with ValueError.
if not isinstance(self.runnable, AgentExecutor):
raise ValueError("The runnable must be an AgentExecutor")
# Streaming path: the object produced by astream_events(...) is returned as-is,
# so none of the status bookkeeping further down runs for it.
# NOTE(review): the annotation says Message, but this path returns whatever
# astream_events(...) produces; confirm the intended return type.
if self.use_stream:
return self.astream_events(input_dict)
else:
result = await self.runnable.ainvoke(input_dict)
# Extract the requested output value and append both it and the raw result to
# the status text shown on the component.
result_value, _status = self.get_output(result, self.input_key, self.output_key)
status += _status
status += f"\n\nOutput: {result_value}\n\nRaw Output: {result}"
self.status = status
return result_value
async def astream_events(self, input):
    """Yield chat-model output chunks produced while the runnable executes.

    Iterates ``self.runnable.astream_events(input, version="v1")`` and forwards
    only the payload of ``on_chat_model_stream`` events; every other event type
    is dropped.

    Args:
        input: Value passed through unchanged to the runnable's
            ``astream_events`` call.

    Yields:
        The ``chunk`` entry of each ``on_chat_model_stream`` event's ``data``.
        Events that carry no ``data`` or no ``chunk`` are skipped.
    """
    async for event in self.runnable.astream_events(input, version="v1"):
        if event.get("event") != "on_chat_model_stream":
            continue
        # The event is read with .get() throughout, so a missing "data" entry
        # must not blow up with AttributeError on the chained lookup.
        data = event.get("data") or {}
        chunk = data.get("chunk")
        if chunk is None:
            continue
        yield chunk

View file

@ -1,95 +1,307 @@
import ast
from typing import Any, Dict, List, Optional
import json
from typing import Any
from langchain.agents import Tool
from langflow.base.langchain_utilities.model import LCToolComponent
from langflow.inputs.inputs import MultilineInput, MessageTextInput, BoolInput, DropdownInput, HandleInput, FieldTypes
from langchain_core.tools import StructuredTool
from langflow.io import Output
from langflow.custom import CustomComponent
from langflow.schema.dotdict import dotdict
from langflow.schema import Data
from pydantic.v1 import Field, create_model
from pydantic.v1.fields import Undefined
class PythonCodeStructuredTool(CustomComponent):
display_name = "PythonCodeTool"
class PythonCodeStructuredTool(LCToolComponent):
# Template/attribute keys that belong to the component itself.
# update_build_config and build_tool treat every key that is NOT listed here
# as a generated per-argument description field named "<function>|<argument>".
DEFAULT_KEYS = [
"code",
"_type",
"text_key",
"tool_code",
"tool_name",
"tool_description",
"return_direct",
"tool_function",
"global_variables",
"_classes",
"_functions",
]
display_name = "Python Code Structured Tool"
description = "structuredtool dataclass code to tool"
documentation = "https://python.langchain.com/docs/modules/tools/custom_tools/#structuredtool-dataclass"
name = "PythonCodeStructuredTool"
icon = "🐍"
field_order = ["name", "description", "tool_code", "return_direct", "tool_function", "tool_class"]
field_order = ["name", "description", "tool_code", "return_direct", "tool_function"]
# Return the UI settings for each field of the component, keyed by field name:
# display label, help text and, where present, option lists / refresh buttons.
# NOTE(review): this appears to be the pre-change version that the `inputs`
# list further down replaces (it still mentions "tool_class"); confirm
# against the commit.
def build_config(self) -> Dict[str, Any]:
return {
"tool_code": {
"display_name": "Tool Code",
"info": "Enter the dataclass code.",
"placeholder": "def my_function(args):\n pass",
"multiline": True,
"refresh_button": True,
"field_type": "code",
},
"name": {
"display_name": "Tool Name",
"info": "Enter the name of the tool.",
},
"description": {
"display_name": "Description",
"info": "Provide a brief description of what the tool does.",
},
"return_direct": {
"display_name": "Return Directly",
"info": "Should the tool return the function output directly?",
},
# Options start empty; update_build_config fills them from the parsed code.
"tool_function": {
"display_name": "Tool Function",
"info": "Select the function for additional expressions.",
"options": [],
"refresh_button": True,
},
"tool_class": {
"display_name": "Tool Class",
"info": "Select the class for additional expressions.",
"options": [],
"refresh_button": True,
"required": False,
},
}
inputs = [
MultilineInput(
name="tool_code",
display_name="Tool Code",
info="Enter the dataclass code.",
placeholder="def my_function(args):\n pass",
required=True,
real_time_refresh=True,
refresh_button=True,
),
MessageTextInput(name="tool_name", display_name="Tool Name", info="Enter the name of the tool.", required=True),
MessageTextInput(
name="tool_description",
display_name="Description",
info="Enter the description of the tool.",
required=True,
),
BoolInput(
name="return_direct",
display_name="Return Directly",
info="Should the tool return the function output directly?",
),
DropdownInput(
name="tool_function",
display_name="Tool Function",
info="Select the function for additional expressions.",
options=[],
required=True,
real_time_refresh=True,
refresh_button=True,
),
HandleInput(
name="global_variables",
display_name="Global Variables",
info="Enter the global variables or Create Data Component.",
input_types=["Data"],
field_type=FieldTypes.DICT,
is_list=True,
),
MessageTextInput(name="_classes", display_name="Classes", advanced=True),
MessageTextInput(name="_functions", display_name="Functions", advanced=True),
]
def parse_source_name(self, code: str) -> Dict:
    """List the top-level class and function names defined in ``code``.

    Only statements directly in the module body are considered; nested
    definitions and ``async def`` functions are not reported.

    Returns:
        ``{"class": [...], "function": [...]}`` with names in source order.

    Raises:
        SyntaxError: if ``code`` is not valid Python.
    """
    found: dict = {"class": [], "function": []}
    for statement in ast.parse(code).body:
        if isinstance(statement, ast.ClassDef):
            found["class"].append(statement.name)
        elif isinstance(statement, ast.FunctionDef):
            found["function"].append(statement.name)
    return found
outputs = [
Output(display_name="Tool", name="result_tool", method="build_tool"),
]
# Refresh the component template after the tool code or the selected function
# changes: re-parse the code, offer the parsed function names as options and
# keep one description field per function argument.
# NOTE(review): this view interleaves pre- and post-change diff lines. The
# first `if ... "tool_class"` / try / except group (which calls
# parse_source_name) looks like the superseded body; confirm against the commit.
def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None) -> dotdict:
if field_name == "tool_code" or field_name == "tool_function" or field_name == "tool_class":
try:
names = self.parse_source_name(build_config.tool_code.value)
build_config.tool_class.options = names["class"]
build_config.tool_function.options = names["function"]
except Exception as e:
self.status = f"Failed to extract class names: {str(e)}"
build_config.tool_class.options = ["Failed to parse", str(e)]
build_config.tool_function.options = []
# Only changes to "tool_code" or "tool_function" trigger a rebuild; any other
# field (or no field) returns the config untouched.
if field_name is None:
return build_config
if field_name != "tool_code" and field_name != "tool_function":
return build_config
try:
named_functions = {}
[classes, functions] = self._parse_code(build_config["tool_code"]["value"])
# Pull every non-default key out of the config so fields for arguments that no
# longer exist disappear, while remembering them for reuse below.
existing_fields = {}
if len(build_config) > len(self.DEFAULT_KEYS):
for key in build_config.copy():
if key not in self.DEFAULT_KEYS:
existing_fields[key] = build_config.pop(key)
names = []
for func in functions:
named_functions[func["name"]] = func
names.append(func["name"])
for arg in func["args"]:
# NOTE(review): this rebinds the `field_name` parameter; in the lines shown
# nothing reads the original value after this point.
field_name = f"{func['name']}|{arg['name']}"
# Re-attach a previously generated field as-is instead of recreating it.
if field_name in existing_fields:
build_config[field_name] = existing_fields[field_name]
continue
field = MessageTextInput(
display_name=f"{arg['name']}: Description",
name=field_name,
info=f"Enter the description for {arg['name']}",
required=True,
)
build_config[field_name] = field.to_dict()
# The parsed functions and classes are stored as JSON strings in hidden fields;
# build_tool reads them back with json.loads.
build_config["_functions"]["value"] = json.dumps(named_functions)
build_config["_classes"]["value"] = json.dumps(classes)
build_config["tool_function"]["options"] = names
except Exception as e:
# Parsing problems are surfaced through the status text and the options list
# rather than raised.
self.status = f"Failed to extract names: {str(e)}"
build_config["tool_function"]["options"] = ["Failed to parse", str(e)]
return build_config
# Build a StructuredTool that runs the selected function from the user's code.
# NOTE(review): this view interleaves pre- and post-change diff lines with the
# indentation stripped. The `async def build(...)` signature and the lines that
# use `local_namespace`, `tool_function`, `tool_class`, `func` and `_class`
# look like the superseded implementation; confirm against the commit.
async def build(
self,
tool_code: str,
name: str,
description: str,
tool_function: List[str],
return_direct: bool,
tool_class: Optional[List[str]] = None,
) -> Tool:
local_namespace = {} # type: ignore
exec(tool_code, globals(), local_namespace)
async def build_tool(self) -> Tool:
_local_namespace = {} # type: ignore
# Re-issue the code's top-level imports with `global` declarations so the
# imported names land in this module's globals.
# NOTE(review): only alias.name is used, so `import x as y` / `from m import x
# as y` lose their alias, and a dotted module (e.g. `import os.path`) would
# produce `global os.path`, which is not valid Python; confirm.
modules = self._find_imports(self.tool_code)
import_code = ""
for module in modules["imports"]:
import_code += f"global {module}\nimport {module}\n"
for from_module in modules["from_imports"]:
for alias in from_module.names:
import_code += f"global {alias.name}\n"
import_code += (
f"from {from_module.module} import {', '.join([alias.name for alias in from_module.names])}\n"
)
# NOTE(review): the user-supplied code is executed with exec() against this
# module's globals; only safe if the code is trusted.
exec(import_code, globals())
exec(self.tool_code, globals(), _local_namespace)
func = local_namespace[tool_function]
_class = None
# Wrapper whose `run` forwards keyword arguments to the selected function.
class PythonCodeToolFunc:
# NOTE(review): `params` is a class-level dict and `run` only stores keys that
# are not already present, so a value passed on an earlier call is reused on
# later calls that pass the same key; confirm this is intended.
params: dict = {}
if tool_class:
_class = local_namespace[tool_class]
def run(**kwargs):
for key in kwargs:
if key not in PythonCodeToolFunc.params:
PythonCodeToolFunc.params[key] = kwargs[key]
return _local_namespace[self.tool_function](**PythonCodeToolFunc.params)
# The wrapper class is published in module globals under the tool function's
# name, followed by any user-provided global variables.
_globals = globals()
_local = {} # type: ignore
_local[self.tool_function] = PythonCodeToolFunc
_globals.update(_local)
if isinstance(self.global_variables, list):
for data in self.global_variables:
if isinstance(data, Data):
_globals.update(data.data)
elif isinstance(self.global_variables, dict):
_globals.update(self.global_variables)
# Classes captured by update_build_config are re-executed so that argument
# annotations referring to them can be evaluated below.
classes = json.loads(self._attributes["_classes"])
for class_dict in classes:
exec("\n".join(class_dict["code"]), _globals)
named_functions = json.loads(self._attributes["_functions"])
# Every non-default attribute is a "<function>|<argument>" description field;
# each one becomes a field of the tool's argument schema.
# NOTE(review): the loop does not filter on the selected tool function, so
# arguments of every parsed function end up in the schema; confirm.
schema_fields = {}
for attr in self._attributes:
if attr in self.DEFAULT_KEYS:
continue
func_name = attr.split("|")[0]
field_name = attr.split("|")[1]
func_arg = self._find_arg(named_functions, func_name, field_name)
if func_arg is None:
raise Exception(f"Failed to find arg: {field_name}")
field_annotation = func_arg["annotation"]
field_description = self._get_value(self._attributes[attr], str)
# The annotation is stored as source text, so it is evaluated via exec() into
# a temporary global; unannotated arguments fall back to Any.
if field_annotation:
exec(f"temp_annotation_type = {field_annotation}", _globals)
schema_annotation = _globals["temp_annotation_type"]
else:
schema_annotation = Any
# Arguments without a recorded default use pydantic's Undefined marker.
schema_fields[field_name] = (
schema_annotation,
Field(
default=func_arg["default"] if "default" in func_arg else Undefined, description=field_description
),
)
if "temp_annotation_type" in _globals:
_globals.pop("temp_annotation_type")
# With no argument fields the tool is created without an explicit schema.
PythonCodeToolSchema = None
if schema_fields:
PythonCodeToolSchema = create_model("PythonCodeToolSchema", **schema_fields) # type: ignore
tool = StructuredTool.from_function(
func=func, args_schema=_class, name=name, description=description, return_direct=return_direct
func=_local[self.tool_function].run,
args_schema=PythonCodeToolSchema,
name=self.tool_name,
description=self.tool_description,
return_direct=self.return_direct,
)
return tool # type: ignore
# Hook run after code validation: lets the base class build the frontend node,
# then refreshes the template from the current tool code.
def post_code_processing(self, new_frontend_node: dict, current_frontend_node: dict):
"""
This function is called after the code validation is done.
"""
frontend_node = super().post_code_processing(new_frontend_node, current_frontend_node)
# Regenerate the function options and per-argument fields from "tool_code".
frontend_node["template"] = self.update_build_config(
frontend_node["template"], frontend_node["template"]["tool_code"]["value"], "tool_code"
)
# NOTE(review): the base hook is called again after each update and its result
# replaces `frontend_node`; whether the template changes made above survive
# depends on what the base implementation returns. Confirm.
frontend_node = super().post_code_processing(new_frontend_node, current_frontend_node)
# Re-run the update once for every generated (non-default) template key.
# NOTE(review): the new update_build_config body returns early for any field
# name other than "tool_code"/"tool_function", so these calls look like no-ops;
# confirm.
for key in frontend_node["template"]:
if key in self.DEFAULT_KEYS:
continue
frontend_node["template"] = self.update_build_config(
frontend_node["template"], frontend_node["template"][key]["value"], key
)
frontend_node = super().post_code_processing(new_frontend_node, current_frontend_node)
return frontend_node
def _parse_code(self, code: str) -> tuple[list[dict], list[dict]]:
parsed_code = ast.parse(code)
lines = code.split("\n")
classes = []
functions = []
for node in parsed_code.body:
if isinstance(node, ast.ClassDef):
class_lines = lines[node.lineno - 1 : node.end_lineno]
class_lines[-1] = class_lines[-1][: node.end_col_offset]
class_lines[0] = class_lines[0][node.col_offset :]
classes.append(
{
"name": node.name,
"code": class_lines,
}
)
continue
if not isinstance(node, ast.FunctionDef):
continue
func = {"name": node.name, "args": []}
for arg in node.args.args:
if arg.lineno != arg.end_lineno:
raise Exception("Multiline arguments are not supported")
func_arg = {
"name": arg.arg,
"annotation": None,
}
for default in node.args.defaults:
if (
arg.lineno > default.lineno
or arg.col_offset > default.col_offset
or arg.end_lineno < default.end_lineno
or arg.end_col_offset < default.end_col_offset
):
continue
if isinstance(default, ast.Name):
func_arg["default"] = default.id
elif isinstance(default, ast.Constant):
func_arg["default"] = default.value
if arg.annotation:
annotation_line = lines[arg.annotation.lineno - 1]
annotation_line = annotation_line[: arg.annotation.end_col_offset]
annotation_line = annotation_line[arg.annotation.col_offset :]
func_arg["annotation"] = annotation_line
if func_arg["annotation"].count("=") > 0:
func_arg["annotation"] = "=".join(func_arg["annotation"].split("=")[:-1]).strip()
func["args"].append(func_arg)
functions.append(func)
return classes, functions
def _find_imports(self, code: str) -> dotdict:
imports = []
from_imports = []
parsed_code = ast.parse(code)
for node in parsed_code.body:
if isinstance(node, ast.Import):
for alias in node.names:
imports.append(alias.name)
elif isinstance(node, ast.ImportFrom):
from_imports.append(node)
return {"imports": imports, "from_imports": from_imports}
def _get_value(self, value: Any, annotation: Any) -> Any:
return value if isinstance(value, annotation) else value["value"]
def _find_arg(self, named_functions: dict, func_name: str, arg_name: str) -> dict | None:
for arg in named_functions[func_name]["args"]:
if arg["name"] == arg_name:
return arg
return None

View file

@ -0,0 +1,141 @@
from typing import Any
import requests
import json
from pydantic.v1 import Field, create_model
from langchain.agents import Tool
from langflow.base.langchain_utilities.model import LCToolComponent
from langflow.inputs import MessageTextInput, MultiselectInput, DropdownInput, IntInput
from langflow.schema.dotdict import dotdict
from langflow.io import Output
class SearXNGToolComponent(LCToolComponent):
    """Component that exposes a SearXNG instance as a LangChain search tool.

    The category and language choices are filled in from the instance's
    ``/config`` endpoint whenever the URL field is refreshed; ``build_tool``
    wraps a JSON search request against the instance in a ``Tool``.
    """

    # Headers sent with every request to the SearXNG instance (copied before use).
    search_headers: dict = {}
    # Seconds to wait for the SearXNG instance; without a timeout `requests`
    # blocks indefinitely on an unresponsive host.
    request_timeout: float = 30
    display_name = "SearXNG Search Tool"
    description = "A component that searches for tools using SearXNG."
    name = "SearXNGTool"

    inputs = [
        MessageTextInput(
            name="url",
            display_name="URL",
            value="http://localhost",
            required=True,
            refresh_button=True,
        ),
        IntInput(
            name="max_results",
            display_name="Max Results",
            value=10,
            required=True,
        ),
        MultiselectInput(
            name="categories",
            display_name="Categories",
            options=[],
            value=[],
        ),
        DropdownInput(
            name="language",
            display_name="Language",
            options=[],
        ),
    ]

    outputs = [
        Output(display_name="Tool", name="result_tool", method="build_tool"),
    ]

    def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None) -> dotdict:
        """Reload the category and language options when the URL field changes.

        Args:
            build_config: Template of the component; updated in place.
            field_value: New value of the changed field (the instance URL).
            field_name: Name of the changed field; anything but ``"url"``
                leaves the config untouched.

        Returns:
            The same ``build_config``. On any failure the error text is put
            into ``self.status`` and into the category options instead of
            being raised.
        """
        if field_name != "url":
            return build_config

        try:
            url = f"{field_value}/config"
            response = requests.get(url=url, headers=self.search_headers.copy(), timeout=self.request_timeout)
            if response.headers.get("Content-Encoding") == "zstd":
                data = json.loads(response.content)
            else:
                data = response.json()

            available = data["categories"].copy()
            build_config["categories"]["options"] = available
            # Filter in one go: calling .remove() on the list while iterating
            # over it skips the element that follows each removal. The slice
            # assignment keeps the same list object.
            selected = build_config["categories"]["value"]
            selected[:] = [category for category in selected if category in available]

            build_config["language"]["options"] = list(data["locales"])
        except Exception as e:
            self.status = f"Failed to extract names: {str(e)}"
            build_config["categories"]["options"] = ["Failed to parse", str(e)]
        return build_config

    def build_tool(self) -> Tool:
        """Create the search tool bound to the configured SearXNG instance.

        Returns:
            A ``Tool`` named ``searxng_search_tool`` whose function sends the
            query to the instance and returns at most ``max_results`` result
            entries, or a single-element list with an error message when the
            request fails.
        """

        class SearxSearch:
            # Configuration snapshot, filled in right after the class is defined.
            _url: str = ""
            _categories: list[str] = []
            _language: str = ""
            _headers: dict = {}
            _max_results: int = 10
            _timeout: float = 30

            @staticmethod
            def search(query: str, categories: list[str] | None = None) -> list:
                """Search for ``query`` in the configured plus the given categories."""
                requested = categories or []
                if not SearxSearch._categories and not requested:
                    raise ValueError("No categories provided.")
                # Configured categories first, then new ones in the order given.
                extra = [c for c in dict.fromkeys(requested) if c not in SearxSearch._categories]
                all_categories = SearxSearch._categories + extra
                try:
                    url = f"{SearxSearch._url}/"
                    headers = SearxSearch._headers.copy()
                    response = requests.get(
                        url=url,
                        headers=headers,
                        params={
                            "q": query,
                            "categories": ",".join(all_categories),
                            "language": SearxSearch._language,
                            "format": "json",
                        },
                        timeout=SearxSearch._timeout,
                    ).json()
                    # max(..., 0) keeps a negative limit meaning "no results".
                    return response["results"][: max(SearxSearch._max_results, 0)]
                except Exception as e:
                    return [f"Failed to search: {str(e)}"]

        SearxSearch._url = self.url
        SearxSearch._categories = self.categories.copy()
        SearxSearch._language = self.language
        SearxSearch._headers = self.search_headers.copy()
        SearxSearch._max_results = self.max_results
        SearxSearch._timeout = self.request_timeout

        # The helper class is also published in this module's globals, as before.
        _globals = globals()
        _local = {}
        _local["SearxSearch"] = SearxSearch
        _globals.update(_local)

        schema_fields = {
            "query": (str, Field(..., description="The query to search for.")),
            "categories": (list[str], Field(default=[], description="The categories to search in.")),
        }

        SearxSearchSchema = create_model("SearxSearchSchema", **schema_fields)  # type: ignore

        # NOTE(review): `Tool` is LangChain's single-input tool while the schema
        # declares two fields; confirm the optional `categories` argument can
        # actually reach `search` (a `StructuredTool` may be what is wanted).
        tool = Tool.from_function(
            func=_local["SearxSearch"].search,
            args_schema=SearxSearchSchema,
            name="searxng_search_tool",
            description="A tool that searches for tools using SearXNG.\nThe available categories are: "
            + ", ".join(self.categories),
        )
        self.status = tool
        return tool

View file

@ -5,6 +5,7 @@ from .GoogleSearchAPI import GoogleSearchAPIComponent
from .GoogleSerperAPI import GoogleSerperAPIComponent
from .PythonCodeStructuredTool import PythonCodeStructuredTool
from .SearchAPI import SearchAPIComponent
from .SearXNGTool import SearXNGToolComponent
from .SerpAPI import SerpAPIComponent
from .WikipediaAPI import WikipediaAPIComponent
from .WolframAlphaAPI import WolframAlphaAPIComponent
@ -18,6 +19,7 @@ __all__ = [
"PythonCodeStructuredTool",
"PythonREPLToolComponent",
"SearchAPIComponent",
"SearXNGToolComponent",
"SerpAPIComponent",
"WikipediaAPIComponent",
"WolframAlphaAPIComponent",