diff --git a/src/backend/langflow/api/extract_info_from_class.py b/src/backend/langflow/api/extract_info_from_class.py deleted file mode 100644 index 2d76d6604..000000000 --- a/src/backend/langflow/api/extract_info_from_class.py +++ /dev/null @@ -1,124 +0,0 @@ -import ast - - -class ClassCodeExtractor: - def __init__(self, code): - self.code = code - self.function_entrypoint_name = "build" - self.data = { - "imports": [], - "class": { - "inherited_classes": "", - "name": "", - "init": "" - }, - "functions": [] - } - - def _handle_import(self, node): - for alias in node.names: - module_name = getattr(node, 'module', None) - self.data['imports'].append( - f"{module_name}.{alias.name}" if module_name else alias.name) - - def _handle_class(self, node): - self.data['class'].update({ - 'name': node.name, - 'inherited_classes': [ast.unparse(base) for base in node.bases] - }) - - for inner_node in node.body: - if isinstance(inner_node, ast.FunctionDef): - self._handle_function(inner_node) - - def _handle_function(self, node): - function_name = node.name - function_args_str = ast.unparse(node.args) - function_args = function_args_str.split( - ", ") if function_args_str else [] - - return_type = ast.unparse(node.returns) if node.returns else "None" - - function_data = { - "name": function_name, - "arguments": function_args, - "return_type": return_type - } - - if function_name == "__init__": - self.data['class']['init'] = function_args_str.split( - ", ") if function_args_str else [] - else: - self.data["functions"].append(function_data) - - def transform_list(self, input_list): - output_list = [] - for item in input_list: - # Split each item on ':' to separate variable name and type - split_item = item.split(':') - - # If there is a type, strip any leading/trailing spaces from it - if len(split_item) > 1: - split_item[1] = split_item[1].strip() - # If there isn't a type, append None - else: - split_item.append(None) - output_list.append(split_item) - - return output_list - - def extract_class_info(self): - module = ast.parse(self.code) - - for node in module.body: - if isinstance(node, (ast.Import, ast.ImportFrom)): - self._handle_import(node) - elif isinstance(node, ast.ClassDef): - self._handle_class(node) - - return self.data - - def get_entrypoint_function_args_and_return_type(self): - data = self.extract_class_info() - functions = data.get("functions", []) - - build_function = next( - (f for f in functions if f["name"] == - self.function_entrypoint_name), None - ) - - if build_function: - function_args = build_function.get("arguments", None) - function_args = self.transform_list(function_args) - - return_type = build_function.get("return_type", None) - else: - function_args = None - return_type = None - - return function_args, return_type - - -def is_valid_class_template(code: dict): - extractor = ClassCodeExtractor(code) - return_type_valid_list = ["ConversationChain", "Tool"] - - class_name = code.get("class", {}).get("name", None) - if not class_name: # this will also check for None, empty string, etc. - return False - - functions = code.get("functions", []) - # use a generator and next to find if a function matching the criteria exists - build_function = next( - (f for f in functions if f["name"] == - extractor.function_entrypoint_name), None - ) - - if not build_function: - return False - - # Check if the return type of the build function is valid - if build_function.get("return_type") not in return_type_valid_list: - return False - - return True diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index a9868d15e..58a1101b7 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -104,16 +104,12 @@ async def custom_component( raw_code: CustomComponentCode, ): extractor = CustomComponent(code=raw_code.code) - data = extractor.extract_class_info() - valid = extractor.is_valid_class_template(data) - function_args, function_return_type = extractor.get_entrypoint_function_args_and_return_type() + if not extractor.is_valid: + print("ERROR") + # TODO: Raise error - return build_langchain_template_custom_component( - raw_code.code, - function_args, - function_return_type - ) + return build_langchain_template_custom_component(extractor) # TODO: Just for test - will be remove diff --git a/src/backend/langflow/interface/importing/utils.py b/src/backend/langflow/interface/importing/utils.py index c2325378e..843b745fa 100644 --- a/src/backend/langflow/interface/importing/utils.py +++ b/src/backend/langflow/interface/importing/utils.py @@ -166,5 +166,28 @@ def get_function(code): def get_function_custom(code): function_name = "build" + class_name = "MyPythonClass" - return validate.create_function(code, function_name) + code = """ +from langchain.chains import ConversationChain + +class MyPythonClass: + def __init__(self, title: str, author: str, year_published: int): + self.title = title + self.author = author + self.year_published = year_published + + def get_details(self): + return f"Title: {self.title}, Author: {self.author}, Year Published: {self.year_published}" + + def update_year_published(self, new_year: int): + self.year_published = new_year + print(f"The year of publication has been updated to {new_year}.") + + def build(self, name, my_int, my_str, my_bool, no_type): + # do something... + print("x") + return "" + """ + + return validate.create_class(code, class_name) diff --git a/src/backend/langflow/interface/initialize/loading.py b/src/backend/langflow/interface/initialize/loading.py index 53679ea9a..33e070d94 100644 --- a/src/backend/langflow/interface/initialize/loading.py +++ b/src/backend/langflow/interface/initialize/loading.py @@ -178,8 +178,9 @@ def instantiate_tool(node_type, class_object, params): params["func"] = get_function(params.get("code")) return class_object(**params) elif node_type == "CustomComponent": - params["func"] = get_function_custom(params.get("code")) - return class_object(**params) + return get_function_custom(params.get("code")) + # params["func"] = get_function_custom(params.get("code")) + # return class_object(**params) # For backward compatibility elif node_type == "PythonFunction": function_string = params["code"] diff --git a/src/backend/langflow/interface/tools/custom.py b/src/backend/langflow/interface/tools/custom.py index 4b0f6f1ad..522325a9e 100644 --- a/src/backend/langflow/interface/tools/custom.py +++ b/src/backend/langflow/interface/tools/custom.py @@ -196,3 +196,21 @@ class CustomComponent(BaseModel): return build_function.get("return_type") in self.return_type_valid_list else: return False + + def get_function(self): + return validate.create_function( + self.code, + self.function_entrypoint_name + ) + + @property + def data(self): + return self.extract_class_info() + + @property + def is_valid(self): + return self.is_valid_class_template(self.data) + + @property + def args_and_return_type(self): + return self.get_entrypoint_function_args_and_return_type() diff --git a/src/backend/langflow/interface/types.py b/src/backend/langflow/interface/types.py index 06fa4b257..cd791de3b 100644 --- a/src/backend/langflow/interface/types.py +++ b/src/backend/langflow/interface/types.py @@ -12,6 +12,7 @@ from langflow.interface.utilities.base import utility_creator from langflow.interface.vector_store.base import vectorstore_creator from langflow.interface.wrappers.base import wrapper_creator from langflow.interface.output_parsers.base import output_parser_creator +from langflow.interface.tools.custom import CustomComponent from langflow.template.field.base import TemplateField from langflow.template.frontend_node.tools import CustomComponentNode @@ -59,18 +60,6 @@ def build_langchain_types_dict(): # sourcery skip: dict-assign-update-to-union return all_types -# TODO: Move to correct place -def find_class_type(class_name, classes_dict): - return next( - ( - {"type": class_type, "class": class_name} - for class_type, class_list in classes_dict.items() - if class_name in class_list - ), - {"error": "class not found"}, - ) - - # TODO: Move to correct place def add_new_custom_field(template, field_name: str, field_type: str): new_field = TemplateField( @@ -107,25 +96,27 @@ def add_code_field(template, raw_code): return template -def build_langchain_template_custom_component(raw_code, function_args, function_return_type): - # type_list = get_type_list() - # type_and_class = find_class_type("Tool", type_list) - # node = get_custom_nodes(node_type: str) +def build_langchain_template_custom_component(extractor: CustomComponent): + # Build base "CustomComponent" template + template = CustomComponentNode().to_dict().get(type(extractor).__name__) - # Build base CustomComponent template - template = CustomComponentNode().to_dict().get('CustomComponent') + function_args, return_type = extractor.args_and_return_type + raw_code = extractor.code # Add extra fields for extra_field in function_args: - if extra_field[0] != 'self': + def_field = extra_field[0] + def_type = extra_field[1] + + if def_field != 'self': # TODO: Validate type - if possible to render into frontend - if not extra_field[1]: - extra_field[1] = 'str' + if not def_type: + def_type = 'str' template = add_new_custom_field( template, - extra_field[0], - extra_field[1] + def_field, + def_type ) template = add_code_field( @@ -133,7 +124,13 @@ def build_langchain_template_custom_component(raw_code, function_args, function_ raw_code ) - # criar um vertex - # olhar loading.py + # TODO: Build a vertex - loading.py + + # TODO: Get base classes from "return_type" and add to template.base_classes + template.get('base_classes').append("ConversationChain") + template.get('base_classes').append("LLMChain") + template.get('base_classes').append("Chain") + template.get('base_classes').append("Serializable") + template.get('base_classes').append("function") return template diff --git a/src/backend/langflow/utils/constants.py b/src/backend/langflow/utils/constants.py index ee03f71da..b9449ecf2 100644 --- a/src/backend/langflow/utils/constants.py +++ b/src/backend/langflow/utils/constants.py @@ -50,27 +50,19 @@ def python_function(text: str) -> str: """ DEFAULT_CUSTOM_COMPONENT_CODE = """ -from langflow.interface.chains.base import ChainCreator -from langflow.interface.tools.base import ToolCreator -from xyz.abc import MyClassA, MyClassB +from langchain.chains import ConversationChain -class MyPythonClass(MyClassA, MyClassB): - def __init__(self, title: str, author: str, year_published: int): - self.title = title - self.author = author - self.year_published = year_published +class MyPythonClass: + def __init__(self, name: str, year: int): + self.name = name + self.year = year def get_details(self): - return f"Title: {self.title}, Author: {self.author}, Year Published: {self.year_published}" + return f"Name: {self.name}, Year: {self.year}" - def update_year_published(self, new_year: int): - self.year_published = new_year - print(f"The year of publication has been updated to {new_year}.") - - def build(self, name: str, my_int: int, my_str: str, my_bool: bool, no_type) -> ConversationChain: + def build(self, name: str, year: int, true_or_false: bool, no_type) -> ConversationChain: # do something... - return ConversationChain() """ diff --git a/src/backend/langflow/utils/validate.py b/src/backend/langflow/utils/validate.py index 905b9dd44..35a831d15 100644 --- a/src/backend/langflow/utils/validate.py +++ b/src/backend/langflow/utils/validate.py @@ -108,7 +108,8 @@ def execute_function(code, function_name, *args, **kwargs): try: exec(code_obj, exec_globals, locals()) except Exception as exc: - raise ValueError("Function string does not contain a function") from exc + raise ValueError( + "Function string does not contain a function") from exc # Add the function to the exec_globals dictionary exec_globals[function_name] = locals()[function_name] @@ -163,6 +164,54 @@ def create_function(code, function_name): return wrapped_function +def create_class(code, class_name): + if not hasattr(ast, "TypeIgnore"): + + class TypeIgnore(ast.AST): + _fields = () + + ast.TypeIgnore = TypeIgnore + + module = ast.parse(code) + exec_globals = globals().copy() + + for node in module.body: + if isinstance(node, ast.Import): + for alias in node.names: + try: + exec_globals[alias.asname or alias.name] = importlib.import_module( + alias.name + ) + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + f"Module {alias.name} not found. Please install it and try again." + ) from e + + class_code = next( + node + for node in module.body + if isinstance(node, ast.ClassDef) and node.name == class_name + ) + class_code.parent = None + code_obj = compile( + ast.Module(body=[class_code], type_ignores=[]), "", "exec" + ) + with contextlib.suppress(Exception): + exec(code_obj, exec_globals, locals()) + exec_globals[class_name] = locals()[class_name] + + # Return a function that imports necessary modules and creates an instance of the target class + def build(*args, **kwargs): + for module_name, module in exec_globals.items(): + if isinstance(module, type(importlib)): + globals()[module_name] = module + + instance = exec_globals[class_name](*args, **kwargs) + return instance + + return build + + def extract_function_name(code): module = ast.parse(code) for node in module.body: