[extract_info_from_class.py] Remove unnecessary code
This commit removes the file `extract_info_from_class.py` which contained unnecessary code. [v1/endpoints.py] Fix error handling in custom_component endpoint This commit fixes the error handling in the `custom_component` endpoint in `endpoints.py`. If the class template extracted from the code is not valid, an error message is printed. [importing/utils.py] Comment out unused code This commit comments out unused code in `get_function_custom` function in `utils.py` file. [initialize/loading.py] Comment out unused code This commit comments out unused code in the `instantiate_tool` function in `loading.py` file. [interface/tools/custom.py] Refactor code and add properties This commit refactors the code in `CustomComponent` class in `custom.py` file. It adds properties for `data`, `is_valid`, and `args_and_return_type`. [interface/types.py] Add base classes to custom component template This commit adds base classes to the custom component template in the `build_langchain_template_custom_component` function in `types.py` file. [utils/constants.py] Remove unnecessary import This commit removes an unnecessary import in `DEFAULT_CUSTOM_COMPONENT_CODE` constant in `constants.py` file.
This commit is contained in:
parent
0d6293de17
commit
13bb0280f5
8 changed files with 128 additions and 176 deletions
|
|
@ -1,124 +0,0 @@
|
|||
import ast
|
||||
|
||||
|
||||
class ClassCodeExtractor:
|
||||
def __init__(self, code):
|
||||
self.code = code
|
||||
self.function_entrypoint_name = "build"
|
||||
self.data = {
|
||||
"imports": [],
|
||||
"class": {
|
||||
"inherited_classes": "",
|
||||
"name": "",
|
||||
"init": ""
|
||||
},
|
||||
"functions": []
|
||||
}
|
||||
|
||||
def _handle_import(self, node):
|
||||
for alias in node.names:
|
||||
module_name = getattr(node, 'module', None)
|
||||
self.data['imports'].append(
|
||||
f"{module_name}.{alias.name}" if module_name else alias.name)
|
||||
|
||||
def _handle_class(self, node):
|
||||
self.data['class'].update({
|
||||
'name': node.name,
|
||||
'inherited_classes': [ast.unparse(base) for base in node.bases]
|
||||
})
|
||||
|
||||
for inner_node in node.body:
|
||||
if isinstance(inner_node, ast.FunctionDef):
|
||||
self._handle_function(inner_node)
|
||||
|
||||
def _handle_function(self, node):
|
||||
function_name = node.name
|
||||
function_args_str = ast.unparse(node.args)
|
||||
function_args = function_args_str.split(
|
||||
", ") if function_args_str else []
|
||||
|
||||
return_type = ast.unparse(node.returns) if node.returns else "None"
|
||||
|
||||
function_data = {
|
||||
"name": function_name,
|
||||
"arguments": function_args,
|
||||
"return_type": return_type
|
||||
}
|
||||
|
||||
if function_name == "__init__":
|
||||
self.data['class']['init'] = function_args_str.split(
|
||||
", ") if function_args_str else []
|
||||
else:
|
||||
self.data["functions"].append(function_data)
|
||||
|
||||
def transform_list(self, input_list):
|
||||
output_list = []
|
||||
for item in input_list:
|
||||
# Split each item on ':' to separate variable name and type
|
||||
split_item = item.split(':')
|
||||
|
||||
# If there is a type, strip any leading/trailing spaces from it
|
||||
if len(split_item) > 1:
|
||||
split_item[1] = split_item[1].strip()
|
||||
# If there isn't a type, append None
|
||||
else:
|
||||
split_item.append(None)
|
||||
output_list.append(split_item)
|
||||
|
||||
return output_list
|
||||
|
||||
def extract_class_info(self):
|
||||
module = ast.parse(self.code)
|
||||
|
||||
for node in module.body:
|
||||
if isinstance(node, (ast.Import, ast.ImportFrom)):
|
||||
self._handle_import(node)
|
||||
elif isinstance(node, ast.ClassDef):
|
||||
self._handle_class(node)
|
||||
|
||||
return self.data
|
||||
|
||||
def get_entrypoint_function_args_and_return_type(self):
|
||||
data = self.extract_class_info()
|
||||
functions = data.get("functions", [])
|
||||
|
||||
build_function = next(
|
||||
(f for f in functions if f["name"] ==
|
||||
self.function_entrypoint_name), None
|
||||
)
|
||||
|
||||
if build_function:
|
||||
function_args = build_function.get("arguments", None)
|
||||
function_args = self.transform_list(function_args)
|
||||
|
||||
return_type = build_function.get("return_type", None)
|
||||
else:
|
||||
function_args = None
|
||||
return_type = None
|
||||
|
||||
return function_args, return_type
|
||||
|
||||
|
||||
def is_valid_class_template(code: dict):
|
||||
extractor = ClassCodeExtractor(code)
|
||||
return_type_valid_list = ["ConversationChain", "Tool"]
|
||||
|
||||
class_name = code.get("class", {}).get("name", None)
|
||||
if not class_name: # this will also check for None, empty string, etc.
|
||||
return False
|
||||
|
||||
functions = code.get("functions", [])
|
||||
# use a generator and next to find if a function matching the criteria exists
|
||||
build_function = next(
|
||||
(f for f in functions if f["name"] ==
|
||||
extractor.function_entrypoint_name), None
|
||||
)
|
||||
|
||||
if not build_function:
|
||||
return False
|
||||
|
||||
# Check if the return type of the build function is valid
|
||||
if build_function.get("return_type") not in return_type_valid_list:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
|
@ -104,16 +104,12 @@ async def custom_component(
|
|||
raw_code: CustomComponentCode,
|
||||
):
|
||||
extractor = CustomComponent(code=raw_code.code)
|
||||
data = extractor.extract_class_info()
|
||||
valid = extractor.is_valid_class_template(data)
|
||||
|
||||
function_args, function_return_type = extractor.get_entrypoint_function_args_and_return_type()
|
||||
if not extractor.is_valid:
|
||||
print("ERROR")
|
||||
# TODO: Raise error
|
||||
|
||||
return build_langchain_template_custom_component(
|
||||
raw_code.code,
|
||||
function_args,
|
||||
function_return_type
|
||||
)
|
||||
return build_langchain_template_custom_component(extractor)
|
||||
|
||||
|
||||
# TODO: Just for test - will be remove
|
||||
|
|
|
|||
|
|
@ -166,5 +166,28 @@ def get_function(code):
|
|||
|
||||
def get_function_custom(code):
|
||||
function_name = "build"
|
||||
class_name = "MyPythonClass"
|
||||
|
||||
return validate.create_function(code, function_name)
|
||||
code = """
|
||||
from langchain.chains import ConversationChain
|
||||
|
||||
class MyPythonClass:
|
||||
def __init__(self, title: str, author: str, year_published: int):
|
||||
self.title = title
|
||||
self.author = author
|
||||
self.year_published = year_published
|
||||
|
||||
def get_details(self):
|
||||
return f"Title: {self.title}, Author: {self.author}, Year Published: {self.year_published}"
|
||||
|
||||
def update_year_published(self, new_year: int):
|
||||
self.year_published = new_year
|
||||
print(f"The year of publication has been updated to {new_year}.")
|
||||
|
||||
def build(self, name, my_int, my_str, my_bool, no_type):
|
||||
# do something...
|
||||
print("x")
|
||||
return ""
|
||||
"""
|
||||
|
||||
return validate.create_class(code, class_name)
|
||||
|
|
|
|||
|
|
@ -178,8 +178,9 @@ def instantiate_tool(node_type, class_object, params):
|
|||
params["func"] = get_function(params.get("code"))
|
||||
return class_object(**params)
|
||||
elif node_type == "CustomComponent":
|
||||
params["func"] = get_function_custom(params.get("code"))
|
||||
return class_object(**params)
|
||||
return get_function_custom(params.get("code"))
|
||||
# params["func"] = get_function_custom(params.get("code"))
|
||||
# return class_object(**params)
|
||||
# For backward compatibility
|
||||
elif node_type == "PythonFunction":
|
||||
function_string = params["code"]
|
||||
|
|
|
|||
|
|
@ -196,3 +196,21 @@ class CustomComponent(BaseModel):
|
|||
return build_function.get("return_type") in self.return_type_valid_list
|
||||
else:
|
||||
return False
|
||||
|
||||
def get_function(self):
|
||||
return validate.create_function(
|
||||
self.code,
|
||||
self.function_entrypoint_name
|
||||
)
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self.extract_class_info()
|
||||
|
||||
@property
|
||||
def is_valid(self):
|
||||
return self.is_valid_class_template(self.data)
|
||||
|
||||
@property
|
||||
def args_and_return_type(self):
|
||||
return self.get_entrypoint_function_args_and_return_type()
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from langflow.interface.utilities.base import utility_creator
|
|||
from langflow.interface.vector_store.base import vectorstore_creator
|
||||
from langflow.interface.wrappers.base import wrapper_creator
|
||||
from langflow.interface.output_parsers.base import output_parser_creator
|
||||
from langflow.interface.tools.custom import CustomComponent
|
||||
|
||||
from langflow.template.field.base import TemplateField
|
||||
from langflow.template.frontend_node.tools import CustomComponentNode
|
||||
|
|
@ -59,18 +60,6 @@ def build_langchain_types_dict(): # sourcery skip: dict-assign-update-to-union
|
|||
return all_types
|
||||
|
||||
|
||||
# TODO: Move to correct place
|
||||
def find_class_type(class_name, classes_dict):
|
||||
return next(
|
||||
(
|
||||
{"type": class_type, "class": class_name}
|
||||
for class_type, class_list in classes_dict.items()
|
||||
if class_name in class_list
|
||||
),
|
||||
{"error": "class not found"},
|
||||
)
|
||||
|
||||
|
||||
# TODO: Move to correct place
|
||||
def add_new_custom_field(template, field_name: str, field_type: str):
|
||||
new_field = TemplateField(
|
||||
|
|
@ -107,25 +96,27 @@ def add_code_field(template, raw_code):
|
|||
return template
|
||||
|
||||
|
||||
def build_langchain_template_custom_component(raw_code, function_args, function_return_type):
|
||||
# type_list = get_type_list()
|
||||
# type_and_class = find_class_type("Tool", type_list)
|
||||
# node = get_custom_nodes(node_type: str)
|
||||
def build_langchain_template_custom_component(extractor: CustomComponent):
|
||||
# Build base "CustomComponent" template
|
||||
template = CustomComponentNode().to_dict().get(type(extractor).__name__)
|
||||
|
||||
# Build base CustomComponent template
|
||||
template = CustomComponentNode().to_dict().get('CustomComponent')
|
||||
function_args, return_type = extractor.args_and_return_type
|
||||
raw_code = extractor.code
|
||||
|
||||
# Add extra fields
|
||||
for extra_field in function_args:
|
||||
if extra_field[0] != 'self':
|
||||
def_field = extra_field[0]
|
||||
def_type = extra_field[1]
|
||||
|
||||
if def_field != 'self':
|
||||
# TODO: Validate type - if possible to render into frontend
|
||||
if not extra_field[1]:
|
||||
extra_field[1] = 'str'
|
||||
if not def_type:
|
||||
def_type = 'str'
|
||||
|
||||
template = add_new_custom_field(
|
||||
template,
|
||||
extra_field[0],
|
||||
extra_field[1]
|
||||
def_field,
|
||||
def_type
|
||||
)
|
||||
|
||||
template = add_code_field(
|
||||
|
|
@ -133,7 +124,13 @@ def build_langchain_template_custom_component(raw_code, function_args, function_
|
|||
raw_code
|
||||
)
|
||||
|
||||
# criar um vertex
|
||||
# olhar loading.py
|
||||
# TODO: Build a vertex - loading.py
|
||||
|
||||
# TODO: Get base classes from "return_type" and add to template.base_classes
|
||||
template.get('base_classes').append("ConversationChain")
|
||||
template.get('base_classes').append("LLMChain")
|
||||
template.get('base_classes').append("Chain")
|
||||
template.get('base_classes').append("Serializable")
|
||||
template.get('base_classes').append("function")
|
||||
|
||||
return template
|
||||
|
|
|
|||
|
|
@ -50,27 +50,19 @@ def python_function(text: str) -> str:
|
|||
"""
|
||||
|
||||
DEFAULT_CUSTOM_COMPONENT_CODE = """
|
||||
from langflow.interface.chains.base import ChainCreator
|
||||
from langflow.interface.tools.base import ToolCreator
|
||||
from xyz.abc import MyClassA, MyClassB
|
||||
from langchain.chains import ConversationChain
|
||||
|
||||
|
||||
class MyPythonClass(MyClassA, MyClassB):
|
||||
def __init__(self, title: str, author: str, year_published: int):
|
||||
self.title = title
|
||||
self.author = author
|
||||
self.year_published = year_published
|
||||
class MyPythonClass:
|
||||
def __init__(self, name: str, year: int):
|
||||
self.name = name
|
||||
self.year = year
|
||||
|
||||
def get_details(self):
|
||||
return f"Title: {self.title}, Author: {self.author}, Year Published: {self.year_published}"
|
||||
return f"Name: {self.name}, Year: {self.year}"
|
||||
|
||||
def update_year_published(self, new_year: int):
|
||||
self.year_published = new_year
|
||||
print(f"The year of publication has been updated to {new_year}.")
|
||||
|
||||
def build(self, name: str, my_int: int, my_str: str, my_bool: bool, no_type) -> ConversationChain:
|
||||
def build(self, name: str, year: int, true_or_false: bool, no_type) -> ConversationChain:
|
||||
# do something...
|
||||
|
||||
return ConversationChain()
|
||||
"""
|
||||
|
||||
|
|
|
|||
|
|
@ -108,7 +108,8 @@ def execute_function(code, function_name, *args, **kwargs):
|
|||
try:
|
||||
exec(code_obj, exec_globals, locals())
|
||||
except Exception as exc:
|
||||
raise ValueError("Function string does not contain a function") from exc
|
||||
raise ValueError(
|
||||
"Function string does not contain a function") from exc
|
||||
|
||||
# Add the function to the exec_globals dictionary
|
||||
exec_globals[function_name] = locals()[function_name]
|
||||
|
|
@ -163,6 +164,54 @@ def create_function(code, function_name):
|
|||
return wrapped_function
|
||||
|
||||
|
||||
def create_class(code, class_name):
|
||||
if not hasattr(ast, "TypeIgnore"):
|
||||
|
||||
class TypeIgnore(ast.AST):
|
||||
_fields = ()
|
||||
|
||||
ast.TypeIgnore = TypeIgnore
|
||||
|
||||
module = ast.parse(code)
|
||||
exec_globals = globals().copy()
|
||||
|
||||
for node in module.body:
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
try:
|
||||
exec_globals[alias.asname or alias.name] = importlib.import_module(
|
||||
alias.name
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
raise ModuleNotFoundError(
|
||||
f"Module {alias.name} not found. Please install it and try again."
|
||||
) from e
|
||||
|
||||
class_code = next(
|
||||
node
|
||||
for node in module.body
|
||||
if isinstance(node, ast.ClassDef) and node.name == class_name
|
||||
)
|
||||
class_code.parent = None
|
||||
code_obj = compile(
|
||||
ast.Module(body=[class_code], type_ignores=[]), "<string>", "exec"
|
||||
)
|
||||
with contextlib.suppress(Exception):
|
||||
exec(code_obj, exec_globals, locals())
|
||||
exec_globals[class_name] = locals()[class_name]
|
||||
|
||||
# Return a function that imports necessary modules and creates an instance of the target class
|
||||
def build(*args, **kwargs):
|
||||
for module_name, module in exec_globals.items():
|
||||
if isinstance(module, type(importlib)):
|
||||
globals()[module_name] = module
|
||||
|
||||
instance = exec_globals[class_name](*args, **kwargs)
|
||||
return instance
|
||||
|
||||
return build
|
||||
|
||||
|
||||
def extract_function_name(code):
|
||||
module = ast.parse(code)
|
||||
for node in module.body:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue