diff --git a/src/backend/langflow/interface/custom/code_parser.py b/src/backend/langflow/interface/custom/code_parser.py index 8a67fa733..e86c12cef 100644 --- a/src/backend/langflow/interface/custom/code_parser.py +++ b/src/backend/langflow/interface/custom/code_parser.py @@ -61,12 +61,11 @@ class CodeParser: Extracts "imports" from the code. """ if isinstance(node, ast.Import): - module = node.names[0].name - self.data["imports"].append(module) + for alias in node.names: + self.data["imports"].append(alias.name) elif isinstance(node, ast.ImportFrom): - module = node.module - names = [alias.name for alias in node.names] - self.data["imports"].append((module, names)) + for alias in node.names: + self.data["imports"].append((node.module, alias.name)) def parse_functions(self, node: ast.FunctionDef) -> None: """ @@ -97,7 +96,7 @@ class CodeParser: # Handle positional arguments with default values defaults = [None] * (len(node.args.args) - len(node.args.defaults)) + [ - ast.unparse(default) for default in node.args.defaults + ast.unparse(default) if default else None for default in node.args.defaults ] for arg, default in zip(node.args.args, defaults): @@ -126,10 +125,38 @@ class CodeParser: func["body"].append(ast.unparse(line)) return func + def parse_assign(self, stmt): + """ + Parses an Assign statement and returns a dictionary + with the target's name and value. + """ + for target in stmt.targets: + if isinstance(target, ast.Name): + return {"name": target.id, "value": ast.unparse(stmt.value)} + + def parse_ann_assign(self, stmt): + """ + Parses an AnnAssign statement and returns a dictionary + with the target's name, value, and annotation. + """ + if isinstance(stmt.target, ast.Name): + return { + "name": stmt.target.id, + "value": ast.unparse(stmt.value) if stmt.value else None, + "annotation": ast.unparse(stmt.annotation), + } + + def parse_function_def(self, stmt): + """ + Parses a FunctionDef statement and returns the parsed + method and a boolean indicating if it's an __init__ method. + """ + method = self.parse_callable_details(stmt) + return (method, True) if stmt.name == "__init__" else (method, False) + def parse_classes(self, node: ast.ClassDef) -> None: """ - Extracts "classes" from the code, including - inheritance and init methods. + Extracts "classes" from the code, including inheritance and init methods. """ class_dict = { "name": node.name, @@ -140,15 +167,15 @@ class CodeParser: } for stmt in node.body: - if isinstance(stmt, ast.AnnAssign): - attr = {"name": stmt.target.id, "type": ast.unparse(stmt.annotation)} - class_dict["attributes"].append(attr) - elif isinstance(stmt, ast.Assign): - attr = {"name": stmt.targets[0].id, "value": ast.unparse(stmt.value)} - class_dict["attributes"].append(attr) + if isinstance(stmt, ast.Assign): + if attr := self.parse_assign(stmt): + class_dict["attributes"].append(attr) + elif isinstance(stmt, ast.AnnAssign): + if attr := self.parse_ann_assign(stmt): + class_dict["attributes"].append(attr) elif isinstance(stmt, ast.FunctionDef): - method = self.parse_callable_details(stmt) - if stmt.name == "__init__": + method, is_init = self.parse_function_def(stmt) + if is_init: class_dict["init"] = method else: class_dict["methods"].append(method) diff --git a/src/backend/langflow/interface/custom/component.py b/src/backend/langflow/interface/custom/component.py index f6ef62802..5e84c235e 100644 --- a/src/backend/langflow/interface/custom/component.py +++ b/src/backend/langflow/interface/custom/component.py @@ -1,3 +1,4 @@ +import ast from pydantic import BaseModel from fastapi import HTTPException @@ -48,5 +49,23 @@ class Component(BaseModel): return validate.create_function(self.code, self.function_entrypoint_name) + def build_template_config(self, attributes) -> dict: + template_config = {} + + for item in attributes: + item_name = item.get("name") + + if item_value := item.get("value"): + if "langflow_display_name" in item_name: + template_config["display_name"] = ast.literal_eval(item_value) + + elif "langflow_description" in item_name: + template_config["description"] = ast.literal_eval(item_value) + + elif "langflow_field_config" in item_name: + template_config["field_config"] = ast.literal_eval(item_value) + + return template_config + def build(self): raise NotImplementedError diff --git a/src/backend/langflow/interface/custom/custom_component.py b/src/backend/langflow/interface/custom/custom_component.py index 8fb5af62c..c439b0d2a 100644 --- a/src/backend/langflow/interface/custom/custom_component.py +++ b/src/backend/langflow/interface/custom/custom_component.py @@ -1,4 +1,3 @@ -import ast from typing import Callable, Optional from fastapi import HTTPException from langflow.interface.custom.constants import LANGCHAIN_BASE_TYPES @@ -28,7 +27,7 @@ class CustomComponent(Component): }, ) - # TODO: build logic + # TODO: Create the logic to validate what the Custom Component should have as a prerequisite to be able to execute return True def is_check_valid(self) -> bool: @@ -92,24 +91,35 @@ class CustomComponent(Component): return build_method["return_type"] @property - def get_template_config(self) -> dict: - extra_attributes = {} # self.get_extra_attributes - template_config = {} + def get_main_class_name(self): + tree = self.get_code_tree(self.code) - if "field_config" in extra_attributes: - template_config["field_config"] = ast.literal_eval( - extra_attributes["field_config"] - ) - if "display_name" in extra_attributes: - template_config["display_name"] = ast.literal_eval( - extra_attributes["display_name"] - ) - if "description" in extra_attributes: - template_config["description"] = ast.literal_eval( - extra_attributes["description"] - ) + base_name = self.code_class_base_inheritance + method_name = self.function_entrypoint_name - return template_config + classes = [] + for item in tree.get("classes"): + if base_name in item["bases"]: + method_names = [method["name"] for method in item["methods"]] + if method_name in method_names: + classes.append(item["name"]) + + # Get just the first item + return next(iter(classes), "") + + @property + def build_template_config(self): + tree = self.get_code_tree(self.code) + + attributes = [ + main_class["attributes"] + for main_class in tree.get("classes") + if main_class["name"] == self.get_main_class_name + ] + # Get just the first item + attributes = next(iter(attributes), []) + + return super().build_template_config(attributes) @property def get_function(self): diff --git a/src/backend/langflow/interface/types.py b/src/backend/langflow/interface/types.py index 892b70260..b1ba8573f 100644 --- a/src/backend/langflow/interface/types.py +++ b/src/backend/langflow/interface/types.py @@ -161,7 +161,18 @@ def build_langchain_template_custom_component(custom_component: CustomComponent) function_args = custom_component.get_function_entrypoint_args return_type = custom_component.get_function_entrypoint_return_type - # template_config = custom_component.get_template_config + template_config = custom_component.build_template_config + + # Rewrite diplay_name and description values + if frontend_node: + if "display_name" in template_config: + frontend_node["display_name"] = template_config["display_name"] + + elif "description" in template_config: + frontend_node["description"] = template_config["description"] + + # Rewrite field configurations + field_config = template_config.get("field_config", {}) if function_args is not None: # Add extra fields @@ -174,7 +185,6 @@ def build_langchain_template_custom_component(custom_component: CustomComponent) extra_field.get("default") if "default" in extra_field else "" ) field_required = True - field_config = {} # TODO: Validate type - if is possible to render into frontend if "optional" in field_type.lower(): @@ -184,13 +194,14 @@ def build_langchain_template_custom_component(custom_component: CustomComponent) if not field_type: field_type = "str" + config = field_config.get(field_name, {}) frontend_node = add_new_custom_field( frontend_node, field_name, field_type, field_value, field_required, - field_config, + config, ) frontend_node = add_code_field(frontend_node, custom_component.code) diff --git a/tests/conftest.py b/tests/conftest.py index 8be738632..1773ebf23 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -252,7 +252,8 @@ class CustomChain(CustomComponent): def build(self, prompt, llm, input: str) -> Document: chain = MyCustomChain(prompt=prompt, llm=llm) - return chain(input)''' + return chain(input) +''' @pytest.fixture