From 79d2d551ff98b35da58b6801c475e9376e97b09a Mon Sep 17 00:00:00 2001 From: gustavoschaedler Date: Fri, 14 Jul 2023 04:49:42 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=80=20refactor(langflow):=20rename=20c?= =?UTF-8?q?ustom.py=20to=20custom=5Fcomponent.py=20for=20clarity=20?= =?UTF-8?q?=F0=9F=94=A5=20remove(langflow):=20delete=20custom.py=20as=20it?= =?UTF-8?q?'s=20replaced=20by=20custom=5Fcomponent.py=20=F0=9F=93=A6=20fea?= =?UTF-8?q?t(langflow):=20add=20code=5Fparser.py=20to=20parse=20Python=20s?= =?UTF-8?q?ource=20code=20=F0=9F=90=9B=20fix(langflow):=20update=20import?= =?UTF-8?q?=20paths=20due=20to=20file=20renaming=20=F0=9F=8E=A8=20style(la?= =?UTF-8?q?ngflow):=20improve=20code=20formatting=20for=20readability=20?= =?UTF-8?q?=F0=9F=90=9B=20fix(langflow):=20correct=20handling=20of=20funct?= =?UTF-8?q?ion=20arguments=20and=20return=20types=20in=20custom=20componen?= =?UTF-8?q?ts=20=F0=9F=94=A7=20chore(langflow):=20update=20function=20call?= =?UTF-8?q?s=20due=20to=20changes=20in=20custom=20components?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/langflow/api/v1/components.py | 3 +- src/backend/langflow/api/v1/endpoints.py | 2 +- .../langflow/interface/custom/__init__.py | 2 +- src/backend/langflow/interface/custom/base.py | 4 +- .../langflow/interface/custom/code_parser.py | 178 ++++++++++++++ .../langflow/interface/custom/component.py | 53 +++++ .../langflow/interface/custom/custom.py | 220 ------------------ .../interface/custom/custom_component.py | 119 ++++++++++ .../langflow/interface/importing/utils.py | 7 +- src/backend/langflow/interface/types.py | 40 ++-- 10 files changed, 379 insertions(+), 249 deletions(-) create mode 100644 src/backend/langflow/interface/custom/code_parser.py create mode 100644 src/backend/langflow/interface/custom/component.py delete mode 100644 src/backend/langflow/interface/custom/custom.py create mode 100644 src/backend/langflow/interface/custom/custom_component.py diff --git a/src/backend/langflow/api/v1/components.py b/src/backend/langflow/api/v1/components.py index 646fcb3f6..1e34da2aa 100644 --- a/src/backend/langflow/api/v1/components.py +++ b/src/backend/langflow/api/v1/components.py @@ -1,3 +1,4 @@ +from datetime import timezone from typing import List from uuid import UUID from langflow.database.models.component import Component, ComponentModel @@ -60,7 +61,7 @@ def update_component( for key, value in component_data.items(): setattr(db_component, key, value) - db_component.update_at = datetime.utcnow() + db_component.update_at = datetime.now(timezone.utc) db.commit() db.refresh(db_component) return db_component diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index e12f2076e..c51f9ce78 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -7,7 +7,7 @@ from langflow.utils.logger import logger from fastapi import APIRouter, Depends, HTTPException, UploadFile -from langflow.interface.custom.custom import CustomComponent +from langflow.interface.custom.custom_component import CustomComponent from langflow.api.v1.schemas import ( ProcessResponse, diff --git a/src/backend/langflow/interface/custom/__init__.py b/src/backend/langflow/interface/custom/__init__.py index 48672e52b..5b87e9fa3 100644 --- a/src/backend/langflow/interface/custom/__init__.py +++ b/src/backend/langflow/interface/custom/__init__.py @@ -1,4 +1,4 @@ from langflow.interface.custom.base import CustomComponentCreator -from langflow.interface.custom.custom import CustomComponent +from langflow.interface.custom.custom_component import CustomComponent __all__ = ["CustomComponentCreator", "CustomComponent"] diff --git a/src/backend/langflow/interface/custom/base.py b/src/backend/langflow/interface/custom/base.py index 8dfa127cc..06e874fa7 100644 --- a/src/backend/langflow/interface/custom/base.py +++ b/src/backend/langflow/interface/custom/base.py @@ -2,7 +2,9 @@ from typing import Any, Dict, List, Optional, Type from langflow.interface.base import LangChainTypeCreator -from langflow.interface.custom.custom import CustomComponent + +# from langflow.interface.custom.custom import CustomComponent +from langflow.interface.custom.custom_component import CustomComponent from langflow.template.frontend_node.custom_components import ( CustomComponentFrontendNode, ) diff --git a/src/backend/langflow/interface/custom/code_parser.py b/src/backend/langflow/interface/custom/code_parser.py new file mode 100644 index 000000000..8a67fa733 --- /dev/null +++ b/src/backend/langflow/interface/custom/code_parser.py @@ -0,0 +1,178 @@ +import ast +import traceback + +from typing import Dict, Any, Union +from fastapi import HTTPException + + +class CodeSyntaxError(HTTPException): + pass + + +class CodeParser: + """ + A parser for Python source code, extracting code details. + """ + + def __init__(self, code: str) -> None: + """ + Initializes the parser with the provided code. + """ + self.code = code + self.data: Dict[str, Any] = { + "imports": [], + "functions": [], + "classes": [], + "global_vars": [], + } + self.handlers = { + ast.Import: self.parse_imports, + ast.ImportFrom: self.parse_imports, + ast.FunctionDef: self.parse_functions, + ast.ClassDef: self.parse_classes, + ast.Assign: self.parse_global_vars, + } + + def __get_tree(self): + """ + Parses the provided code to validate its syntax. + It tries to parse the code into an abstract syntax tree (AST). + """ + try: + tree = ast.parse(self.code) + except SyntaxError as err: + raise CodeSyntaxError( + status_code=400, + detail={"error": err.msg, "traceback": traceback.format_exc()}, + ) from err + + return tree + + def parse_node(self, node: ast.AST) -> None: + """ + Parses an AST node and updates the data + dictionary with the relevant information. + """ + if handler := self.handlers.get(type(node)): + handler(node) + + def parse_imports(self, node: Union[ast.Import, ast.ImportFrom]) -> None: + """ + Extracts "imports" from the code. + """ + if isinstance(node, ast.Import): + module = node.names[0].name + self.data["imports"].append(module) + elif isinstance(node, ast.ImportFrom): + module = node.module + names = [alias.name for alias in node.names] + self.data["imports"].append((module, names)) + + def parse_functions(self, node: ast.FunctionDef) -> None: + """ + Extracts "functions" from the code. + """ + self.data["functions"].append(self.parse_callable_details(node)) + + def parse_arg(self, arg, default): + """ + Parses an argument and its default value. + """ + arg_dict = {"name": arg.arg, "default": default} + if arg.annotation: + arg_dict["type"] = ast.unparse(arg.annotation) + return arg_dict + + def parse_callable_details(self, node: ast.FunctionDef) -> Dict[str, Any]: + """ + Extracts details from a single function or method node. + """ + func = { + "name": node.name, + "doc": ast.get_docstring(node), + "args": [], + "body": [], + "return_type": ast.unparse(node.returns) if node.returns else None, + } + + # Handle positional arguments with default values + defaults = [None] * (len(node.args.args) - len(node.args.defaults)) + [ + ast.unparse(default) for default in node.args.defaults + ] + + for arg, default in zip(node.args.args, defaults): + func["args"].append(self.parse_arg(arg, default)) + + # Handle *args + if node.args.vararg: + func["args"].append(self.parse_arg(node.args.vararg, None)) + + # Handle keyword-only arguments with default values + kw_defaults = [None] * ( + len(node.args.kwonlyargs) - len(node.args.kw_defaults) + ) + [ + ast.unparse(default) if default else None + for default in node.args.kw_defaults + ] + + for arg, default in zip(node.args.kwonlyargs, kw_defaults): + func["args"].append(self.parse_arg(arg, default)) + + # Handle **kwargs + if node.args.kwarg: + func["args"].append(self.parse_arg(node.args.kwarg, None)) + + for line in node.body: + func["body"].append(ast.unparse(line)) + return func + + def parse_classes(self, node: ast.ClassDef) -> None: + """ + Extracts "classes" from the code, including + inheritance and init methods. + """ + class_dict = { + "name": node.name, + "doc": ast.get_docstring(node), + "bases": [ast.unparse(base) for base in node.bases], + "attributes": [], + "methods": [], + } + + for stmt in node.body: + if isinstance(stmt, ast.AnnAssign): + attr = {"name": stmt.target.id, "type": ast.unparse(stmt.annotation)} + class_dict["attributes"].append(attr) + elif isinstance(stmt, ast.Assign): + attr = {"name": stmt.targets[0].id, "value": ast.unparse(stmt.value)} + class_dict["attributes"].append(attr) + elif isinstance(stmt, ast.FunctionDef): + method = self.parse_callable_details(stmt) + if stmt.name == "__init__": + class_dict["init"] = method + else: + class_dict["methods"].append(method) + + self.data["classes"].append(class_dict) + + def parse_global_vars(self, node: ast.Assign) -> None: + """ + Extracts global variables from the code. + """ + global_var = { + "targets": [ + t.id if hasattr(t, "id") else ast.dump(t) for t in node.targets + ], + "value": ast.unparse(node.value), + } + self.data["global_vars"].append(global_var) + + def parse_code(self) -> Dict[str, Any]: + """ + Runs all parsing operations and returns the resulting data. + """ + tree = self.__get_tree() + + for node in ast.walk(tree): + self.parse_node(node) + return self.data diff --git a/src/backend/langflow/interface/custom/component.py b/src/backend/langflow/interface/custom/component.py new file mode 100644 index 000000000..a0f99fa38 --- /dev/null +++ b/src/backend/langflow/interface/custom/component.py @@ -0,0 +1,53 @@ + +from pydantic import BaseModel +from fastapi import HTTPException + +from langflow.utils import validate +from langflow.interface.custom.code_parser import CodeParser + + +class ComponentCodeNullError(HTTPException): + pass + + +class ComponentFunctionEntrypointNameNullError(HTTPException): + pass + + +class Component(BaseModel): + ERROR_CODE_NULL = "Python code must be provided." + ERROR_FUNCTION_ENTRYPOINT_NAME_NULL = ( + "The name of the entrypoint function must be provided." + ) + + code: str + function_entrypoint_name = "build" + field_config: dict = {} + + def __init__(self, **data): + super().__init__(**data) + + def get_code_tree(self, code: str): + parser = CodeParser(code) + return parser.parse_code() + + def get_function(self): + if not self.code: + raise ComponentCodeNullError( + status_code=400, + detail={"error": self.ERROR_CODE_NULL, "traceback": ""}, + ) + + if not self.function_entrypoint_name: + raise ComponentFunctionEntrypointNameNullError( + status_code=400, + detail={ + "error": self.ERROR_FUNCTION_ENTRYPOINT_NAME_NULL, + "traceback": "", + }, + ) + + return validate.create_function(self.code, self.function_entrypoint_name) + + def build(self): + raise NotImplementedError diff --git a/src/backend/langflow/interface/custom/custom.py b/src/backend/langflow/interface/custom/custom.py deleted file mode 100644 index 6d46c5d18..000000000 --- a/src/backend/langflow/interface/custom/custom.py +++ /dev/null @@ -1,220 +0,0 @@ -import re -import ast -import traceback -from typing import Callable, Optional -from fastapi import HTTPException -from langflow.interface.custom.constants import LANGCHAIN_BASE_TYPES - -from langflow.utils import validate -from pydantic import BaseModel - - -class CustomComponent(BaseModel): - field_config: dict = {} - code: str - function: Optional[Callable] = None - function_entrypoint_name = "build" - return_type_valid_list = list(LANGCHAIN_BASE_TYPES.keys()) - class_template = { - "imports": [], - "class": {"inherited_classes": "", "name": "", "init": "", "attributes": {}}, - "functions": [], - } - - def __init__(self, **data): - super().__init__(**data) - - def _handle_import(self, node): - for alias in node.names: - module_name = getattr(node, "module", None) - self.class_template["imports"].append( - f"{module_name}.{alias.name}" if module_name else alias.name - ) - - def _handle_class(self, node): - self.class_template["class"].update( - { - "name": node.name, - "inherited_classes": [ast.unparse(base) for base in node.bases], - } - ) - - attributes = {} # To store the attributes and their values - - for inner_node in node.body: - if isinstance(inner_node, ast.Assign): # An assignment - for target in inner_node.targets: # Targets of the assignment - if isinstance(target, ast.Name): # A simple variable - # Add the attribute and its value to the dictionary - attributes[target.id] = ast.unparse(inner_node.value) - elif isinstance(inner_node, ast.AnnAssign): # An annotated assignment - if isinstance(inner_node.target, ast.Name) and inner_node.value: - attributes[inner_node.target.id] = ast.unparse(inner_node.value) - - elif isinstance(inner_node, ast.FunctionDef): - self._handle_function(inner_node) - - # You can add these attributes to your class_template if you want - self.class_template["class"]["attributes"] = attributes - - def _handle_function(self, node): - function_name = node.name - function_args_str = ast.unparse(node.args) - function_args = function_args_str.split(", ") if function_args_str else [] - - return_type = ast.unparse(node.returns) if node.returns else "None" - - function_data = { - "name": function_name, - "arguments": function_args, - "return_type": return_type, - } - - if function_name == "__init__": - self.class_template["class"]["init"] = ( - function_args_str.split(", ") if function_args_str else [] - ) - else: - self.class_template["functions"].append(function_data) - - def _split_string(self, text): - """ - Split a string by ':' or '=' and append None until the resulting list has 3 items. - - Parameters: - text (str): The string to be split. - - Returns: - list: A list of strings resulting from the split operation, - padded with None until its length is 3. - """ - items = [item.strip() for item in re.split(r"[:=]", text) if item.strip()] - while len(items) < 3: - items.append(None) - - return items - - def transform_list(self, input_list): - """ - Transform a list of strings by splitting each string and padding with None. - - Parameters: - input_list (list): The list of strings to be transformed. - - Returns: - list: A list of lists, each containing the result of the split operation. - """ - return [self._split_string(item) for item in input_list] - - def extract_class_info(self): - try: - module = ast.parse(self.code) - except SyntaxError as err: - raise HTTPException( - status_code=400, - detail={"error": err.msg, "traceback": traceback.format_exc()}, - ) from err - - for node in module.body: - if isinstance(node, (ast.Import, ast.ImportFrom)): - self._handle_import(node) - elif isinstance(node, ast.ClassDef): - self._handle_class(node) - - return self.class_template - - def get_entrypoint_function_args_and_return_type(self): - data = self.extract_class_info() - attributes = data.get("class", {}).get("attributes", {}) - functions = data.get("functions", []) - template_config = self._build_template_config(attributes) - - if build_function := next( - (f for f in functions if f["name"] == self.function_entrypoint_name), - None, - ): - function_args = build_function.get("arguments", None) - function_args = self.transform_list(function_args) - - return_type = build_function.get("return_type", None) - else: - function_args = None - return_type = None - - return function_args, return_type, template_config - - def _build_template_config(self, attributes): - template_config = {} - if "field_config" in attributes: - template_config["field_config"] = ast.literal_eval( - attributes["field_config"] - ) - if "display_name" in attributes: - template_config["display_name"] = ast.literal_eval( - attributes["display_name"] - ) - if "description" in attributes: - template_config["description"] = ast.literal_eval(attributes["description"]) - - return template_config - - def _class_template_validation(self, code: dict): - class_name = code.get("class", {}).get("name", None) - if not class_name: # this will also check for None, empty string, etc. - raise HTTPException( - status_code=400, - detail={ - "error": "The main class must have a valid name.", - "traceback": "", - }, - ) - - functions = code.get("functions", []) - build_function = next( - (f for f in functions if f["name"] == self.function_entrypoint_name), - None, - ) - - if not build_function: - raise HTTPException( - status_code=400, - detail={ - "error": "Invalid entrypoint function name", - "traceback": ( - f"There needs to be at least one entrypoint function named '{self.function_entrypoint_name}'" - f" and it needs to return one of the types from this list {str(self.return_type_valid_list)}.", - ), - }, - ) - - return_type = build_function.get("return_type") - if return_type not in self.return_type_valid_list: - raise HTTPException( - status_code=400, - detail={ - "error": "Invalid entrypoint function return", - "traceback": ( - f"The entrypoint function return '{return_type}' needs to be an item " - f"from this list {str(self.return_type_valid_list)}." - ), - }, - ) - - return True - - def get_function(self): - return validate.create_function(self.code, self.function_entrypoint_name) - - def build(self): - raise NotImplementedError - - @property - def data(self): - return self.extract_class_info() - - def is_check_valid(self): - return self._class_template_validation(self.data) - - @property - def args_and_return_type(self): - return self.get_entrypoint_function_args_and_return_type() diff --git a/src/backend/langflow/interface/custom/custom_component.py b/src/backend/langflow/interface/custom/custom_component.py new file mode 100644 index 000000000..5a9ddecbb --- /dev/null +++ b/src/backend/langflow/interface/custom/custom_component.py @@ -0,0 +1,119 @@ +import ast +from typing import Callable, Optional +from fastapi import HTTPException +from langflow.interface.custom.constants import LANGCHAIN_BASE_TYPES +from langflow.interface.custom.component import Component + +from langflow.utils import validate + + +class CustomComponent(Component): + code: str + field_config: dict = {} + code_class_base_inheritance = "CustomComponent" + function_entrypoint_name = "build" + function: Optional[Callable] = None + return_type_valid_list = list(LANGCHAIN_BASE_TYPES.keys()) + + def __init__(self, **data): + super().__init__(**data) + + def _class_template_validation(self, code: str) -> bool: + if not code: + raise HTTPException( + status_code=400, + detail={ + "error": self.ERROR_CODE_NULL, + "traceback": "", + }, + ) + + # TODO: build logic + return True + + def is_check_valid(self) -> bool: + return self._class_template_validation(self.code) + + def get_code_tree(self, code: str): + return super().get_code_tree(code) + + @property + def get_function_entrypoint_args(self) -> str: + tree = self.get_code_tree(self.code) + + component_classes = [ + cls + for cls in tree["classes"] + if self.code_class_base_inheritance in cls["bases"] + ] + if not component_classes: + return "" + + # Assume the first Component class is the one we're interested in + component_class = component_classes[0] + build_methods = [ + method + for method in component_class["methods"] + if method["name"] == self.function_entrypoint_name + ] + + if not build_methods: + return "" + + build_method = build_methods[0] + + return build_method["args"] + + @property + def get_function_entrypoint_return_type(self) -> str: + tree = self.get_code_tree(self.code) + + component_classes = [ + cls + for cls in tree["classes"] + if self.code_class_base_inheritance in cls["bases"] + ] + if not component_classes: + return "" + + # Assume the first Component class is the one we're interested in + component_class = component_classes[0] + build_methods = [ + method + for method in component_class["methods"] + if method["name"] == self.function_entrypoint_name + ] + + if not build_methods: + return "" + + build_method = build_methods[0] + + return build_method["return_type"] + + @property + def get_template_config(self) -> dict: + extra_attributes = {} # self.get_extra_attributes + template_config = {} + + if "field_config" in extra_attributes: + template_config["field_config"] = ast.literal_eval( + extra_attributes["field_config"] + ) + if "display_name" in extra_attributes: + template_config["display_name"] = ast.literal_eval( + extra_attributes["display_name"] + ) + if "description" in extra_attributes: + template_config["description"] = ast.literal_eval( + extra_attributes["description"] + ) + + return template_config + + @property + def get_function(self): + return validate.create_function(self.code, self.function_entrypoint_name) + + def build(self): + raise NotImplementedError diff --git a/src/backend/langflow/interface/importing/utils.py b/src/backend/langflow/interface/importing/utils.py index 04ee8aba1..0acb2cff5 100644 --- a/src/backend/langflow/interface/importing/utils.py +++ b/src/backend/langflow/interface/importing/utils.py @@ -9,7 +9,7 @@ from langchain.base_language import BaseLanguageModel from langchain.chains.base import Chain from langchain.chat_models.base import BaseChatModel from langchain.tools import BaseTool -from langflow.interface.custom.custom import CustomComponent +from langflow.interface.custom.custom_component import CustomComponent from langflow.utils import validate from langflow.interface.wrappers.base import wrapper_creator @@ -61,7 +61,9 @@ def import_by_type(_type: str, name: str) -> Any: def import_custom_component(custom_component: str) -> CustomComponent: """Import custom component from custom component name""" - return import_class(f"langflow.interface.custom.custom.{custom_component}") + return import_class( + f"langflow.interface.custom.custom_component.{custom_component}" + ) def import_output_parser(output_parser: str) -> Any: @@ -183,5 +185,4 @@ def get_function(code): def get_function_custom(code): class_name = validate.extract_class_name(code) - return validate.create_class(code, class_name) diff --git a/src/backend/langflow/interface/types.py b/src/backend/langflow/interface/types.py index 41343089c..892b70260 100644 --- a/src/backend/langflow/interface/types.py +++ b/src/backend/langflow/interface/types.py @@ -14,7 +14,7 @@ from langflow.interface.vector_store.base import vectorstore_creator from langflow.interface.wrappers.base import wrapper_creator from langflow.interface.output_parsers.base import output_parser_creator from langflow.interface.custom.base import custom_component_creator -from langflow.interface.custom.custom import CustomComponent +from langflow.interface.custom.custom_component import CustomComponent from langflow.template.field.base import TemplateField from langflow.template.frontend_node.tools import CustomComponentNode @@ -92,9 +92,6 @@ def add_new_custom_field( field_type = field_config.pop("field_type", field_type) field_type = process_type(field_type) - if field_value is not None: - field_value = field_value.replace("'", "").replace('"', "") - if "name" in field_config: warnings.warn( "The 'name' key in field_config is used to build the object and can't be changed." @@ -158,29 +155,27 @@ def extract_type_from_optional(field_type): return match[1] if match else None -def build_langchain_template_custom_component(extractor: CustomComponent): +def build_langchain_template_custom_component(custom_component: CustomComponent): # Build base "CustomComponent" template - frontend_node = CustomComponentNode().to_dict().get(type(extractor).__name__) + frontend_node = CustomComponentNode().to_dict().get(type(custom_component).__name__) - function_args, return_type, template_config = extractor.args_and_return_type - - if "display_name" in template_config and frontend_node is not None: - frontend_node["display_name"] = template_config["display_name"] - if "description" in template_config and frontend_node is not None: - frontend_node["description"] = template_config["description"] - raw_code = extractor.code - field_config = template_config.get("field_config", {}) + function_args = custom_component.get_function_entrypoint_args + return_type = custom_component.get_function_entrypoint_return_type + # template_config = custom_component.get_template_config if function_args is not None: # Add extra fields for extra_field in function_args: - field_required = True - field_name, field_type, field_value = extra_field - - if not field_type: - field_type = "" + field_name = extra_field.get("name") if "name" in extra_field else "" if field_name != "self": + field_type = extra_field.get("type") if "type" in extra_field else "" + field_value = ( + extra_field.get("default") if "default" in extra_field else "" + ) + field_required = True + field_config = {} + # TODO: Validate type - if is possible to render into frontend if "optional" in field_type.lower(): field_type = extract_type_from_optional(field_type) @@ -189,17 +184,16 @@ def build_langchain_template_custom_component(extractor: CustomComponent): if not field_type: field_type = "str" - config = field_config.get(field_name, {}) frontend_node = add_new_custom_field( frontend_node, field_name, field_type, field_value, field_required, - config, + field_config, ) - frontend_node = add_code_field(frontend_node, raw_code) + frontend_node = add_code_field(frontend_node, custom_component.code) # Get base classes from "return_type" and add to template.base_classes try: @@ -214,8 +208,10 @@ def build_langchain_template_custom_component(extractor: CustomComponent): "traceback": traceback.format_exc(), }, ) + return_type_instance = LANGCHAIN_BASE_TYPES.get(return_type) base_classes = get_base_classes(return_type_instance) + except (KeyError, AttributeError) as err: raise HTTPException( status_code=400,