🔨 refactor(langflow): improve code parsing and custom component handling

- Refactor code parsing in `code_parser.py` to handle imports, function definitions, and class attributes more robustly and cleanly.
- Add new methods in `component.py` to parse Assign and AnnAssign statements, and FunctionDef statements.
- Refactor `custom_component.py` to improve the handling of custom components, including better extraction of main class name and template configuration.
- Update `types.py` to better handle the building of custom component templates, including handling of field configurations and error handling.
- Minor formatting fix in `conftest.py` test fixture.

These changes improve the robustness and readability of the code, and provide better handling and validation of custom components.
This commit is contained in:
gustavoschaedler 2023-07-15 00:41:31 +01:00
commit 0aab360629
5 changed files with 106 additions and 38 deletions

View file

@ -61,12 +61,11 @@ class CodeParser:
Extracts "imports" from the code.
"""
if isinstance(node, ast.Import):
module = node.names[0].name
self.data["imports"].append(module)
for alias in node.names:
self.data["imports"].append(alias.name)
elif isinstance(node, ast.ImportFrom):
module = node.module
names = [alias.name for alias in node.names]
self.data["imports"].append((module, names))
for alias in node.names:
self.data["imports"].append((node.module, alias.name))
def parse_functions(self, node: ast.FunctionDef) -> None:
"""
@ -97,7 +96,7 @@ class CodeParser:
# Handle positional arguments with default values
defaults = [None] * (len(node.args.args) - len(node.args.defaults)) + [
ast.unparse(default) for default in node.args.defaults
ast.unparse(default) if default else None for default in node.args.defaults
]
for arg, default in zip(node.args.args, defaults):
@ -126,10 +125,38 @@ class CodeParser:
func["body"].append(ast.unparse(line))
return func
def parse_assign(self, stmt):
"""
Parses an Assign statement and returns a dictionary
with the target's name and value.
"""
for target in stmt.targets:
if isinstance(target, ast.Name):
return {"name": target.id, "value": ast.unparse(stmt.value)}
def parse_ann_assign(self, stmt):
"""
Parses an AnnAssign statement and returns a dictionary
with the target's name, value, and annotation.
"""
if isinstance(stmt.target, ast.Name):
return {
"name": stmt.target.id,
"value": ast.unparse(stmt.value) if stmt.value else None,
"annotation": ast.unparse(stmt.annotation),
}
def parse_function_def(self, stmt):
"""
Parses a FunctionDef statement and returns the parsed
method and a boolean indicating if it's an __init__ method.
"""
method = self.parse_callable_details(stmt)
return (method, True) if stmt.name == "__init__" else (method, False)
def parse_classes(self, node: ast.ClassDef) -> None:
"""
Extracts "classes" from the code, including
inheritance and init methods.
Extracts "classes" from the code, including inheritance and init methods.
"""
class_dict = {
"name": node.name,
@ -140,15 +167,15 @@ class CodeParser:
}
for stmt in node.body:
if isinstance(stmt, ast.AnnAssign):
attr = {"name": stmt.target.id, "type": ast.unparse(stmt.annotation)}
class_dict["attributes"].append(attr)
elif isinstance(stmt, ast.Assign):
attr = {"name": stmt.targets[0].id, "value": ast.unparse(stmt.value)}
class_dict["attributes"].append(attr)
if isinstance(stmt, ast.Assign):
if attr := self.parse_assign(stmt):
class_dict["attributes"].append(attr)
elif isinstance(stmt, ast.AnnAssign):
if attr := self.parse_ann_assign(stmt):
class_dict["attributes"].append(attr)
elif isinstance(stmt, ast.FunctionDef):
method = self.parse_callable_details(stmt)
if stmt.name == "__init__":
method, is_init = self.parse_function_def(stmt)
if is_init:
class_dict["init"] = method
else:
class_dict["methods"].append(method)

View file

@ -1,3 +1,4 @@
import ast
from pydantic import BaseModel
from fastapi import HTTPException
@ -48,5 +49,23 @@ class Component(BaseModel):
return validate.create_function(self.code, self.function_entrypoint_name)
def build_template_config(self, attributes) -> dict:
template_config = {}
for item in attributes:
item_name = item.get("name")
if item_value := item.get("value"):
if "langflow_display_name" in item_name:
template_config["display_name"] = ast.literal_eval(item_value)
elif "langflow_description" in item_name:
template_config["description"] = ast.literal_eval(item_value)
elif "langflow_field_config" in item_name:
template_config["field_config"] = ast.literal_eval(item_value)
return template_config
def build(self):
raise NotImplementedError

View file

@ -1,4 +1,3 @@
import ast
from typing import Callable, Optional
from fastapi import HTTPException
from langflow.interface.custom.constants import LANGCHAIN_BASE_TYPES
@ -28,7 +27,7 @@ class CustomComponent(Component):
},
)
# TODO: build logic
# TODO: Create the logic to validate what the Custom Component should have as a prerequisite to be able to execute
return True
def is_check_valid(self) -> bool:
@ -92,24 +91,35 @@ class CustomComponent(Component):
return build_method["return_type"]
@property
def get_template_config(self) -> dict:
extra_attributes = {} # self.get_extra_attributes
template_config = {}
def get_main_class_name(self):
tree = self.get_code_tree(self.code)
if "field_config" in extra_attributes:
template_config["field_config"] = ast.literal_eval(
extra_attributes["field_config"]
)
if "display_name" in extra_attributes:
template_config["display_name"] = ast.literal_eval(
extra_attributes["display_name"]
)
if "description" in extra_attributes:
template_config["description"] = ast.literal_eval(
extra_attributes["description"]
)
base_name = self.code_class_base_inheritance
method_name = self.function_entrypoint_name
return template_config
classes = []
for item in tree.get("classes"):
if base_name in item["bases"]:
method_names = [method["name"] for method in item["methods"]]
if method_name in method_names:
classes.append(item["name"])
# Get just the first item
return next(iter(classes), "")
@property
def build_template_config(self):
tree = self.get_code_tree(self.code)
attributes = [
main_class["attributes"]
for main_class in tree.get("classes")
if main_class["name"] == self.get_main_class_name
]
# Get just the first item
attributes = next(iter(attributes), [])
return super().build_template_config(attributes)
@property
def get_function(self):

View file

@ -161,7 +161,18 @@ def build_langchain_template_custom_component(custom_component: CustomComponent)
function_args = custom_component.get_function_entrypoint_args
return_type = custom_component.get_function_entrypoint_return_type
# template_config = custom_component.get_template_config
template_config = custom_component.build_template_config
# Rewrite diplay_name and description values
if frontend_node:
if "display_name" in template_config:
frontend_node["display_name"] = template_config["display_name"]
elif "description" in template_config:
frontend_node["description"] = template_config["description"]
# Rewrite field configurations
field_config = template_config.get("field_config", {})
if function_args is not None:
# Add extra fields
@ -174,7 +185,6 @@ def build_langchain_template_custom_component(custom_component: CustomComponent)
extra_field.get("default") if "default" in extra_field else ""
)
field_required = True
field_config = {}
# TODO: Validate type - if is possible to render into frontend
if "optional" in field_type.lower():
@ -184,13 +194,14 @@ def build_langchain_template_custom_component(custom_component: CustomComponent)
if not field_type:
field_type = "str"
config = field_config.get(field_name, {})
frontend_node = add_new_custom_field(
frontend_node,
field_name,
field_type,
field_value,
field_required,
field_config,
config,
)
frontend_node = add_code_field(frontend_node, custom_component.code)

View file

@ -252,7 +252,8 @@ class CustomChain(CustomComponent):
def build(self, prompt, llm, input: str) -> Document:
chain = MyCustomChain(prompt=prompt, llm=llm)
return chain(input)'''
return chain(input)
'''
@pytest.fixture