🔨 refactor(langflow): improve code parsing and custom component handling
- Refactor code parsing in `code_parser.py` to handle imports, function definitions, and class attributes more robustly and cleanly. - Add new methods in `component.py` to parse Assign and AnnAssign statements, and FunctionDef statements. - Refactor `custom_component.py` to improve the handling of custom components, including better extraction of main class name and template configuration. - Update `types.py` to better handle the building of custom component templates, including handling of field configurations and error handling. - Minor formatting fix in `conftest.py` test fixture. These changes improve the robustness and readability of the code, and provide better handling and validation of custom components.
This commit is contained in:
parent
cd94c47b0e
commit
0aab360629
5 changed files with 106 additions and 38 deletions
|
|
@ -61,12 +61,11 @@ class CodeParser:
|
|||
Extracts "imports" from the code.
|
||||
"""
|
||||
if isinstance(node, ast.Import):
|
||||
module = node.names[0].name
|
||||
self.data["imports"].append(module)
|
||||
for alias in node.names:
|
||||
self.data["imports"].append(alias.name)
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
module = node.module
|
||||
names = [alias.name for alias in node.names]
|
||||
self.data["imports"].append((module, names))
|
||||
for alias in node.names:
|
||||
self.data["imports"].append((node.module, alias.name))
|
||||
|
||||
def parse_functions(self, node: ast.FunctionDef) -> None:
|
||||
"""
|
||||
|
|
@ -97,7 +96,7 @@ class CodeParser:
|
|||
|
||||
# Handle positional arguments with default values
|
||||
defaults = [None] * (len(node.args.args) - len(node.args.defaults)) + [
|
||||
ast.unparse(default) for default in node.args.defaults
|
||||
ast.unparse(default) if default else None for default in node.args.defaults
|
||||
]
|
||||
|
||||
for arg, default in zip(node.args.args, defaults):
|
||||
|
|
@ -126,10 +125,38 @@ class CodeParser:
|
|||
func["body"].append(ast.unparse(line))
|
||||
return func
|
||||
|
||||
def parse_assign(self, stmt):
|
||||
"""
|
||||
Parses an Assign statement and returns a dictionary
|
||||
with the target's name and value.
|
||||
"""
|
||||
for target in stmt.targets:
|
||||
if isinstance(target, ast.Name):
|
||||
return {"name": target.id, "value": ast.unparse(stmt.value)}
|
||||
|
||||
def parse_ann_assign(self, stmt):
|
||||
"""
|
||||
Parses an AnnAssign statement and returns a dictionary
|
||||
with the target's name, value, and annotation.
|
||||
"""
|
||||
if isinstance(stmt.target, ast.Name):
|
||||
return {
|
||||
"name": stmt.target.id,
|
||||
"value": ast.unparse(stmt.value) if stmt.value else None,
|
||||
"annotation": ast.unparse(stmt.annotation),
|
||||
}
|
||||
|
||||
def parse_function_def(self, stmt):
|
||||
"""
|
||||
Parses a FunctionDef statement and returns the parsed
|
||||
method and a boolean indicating if it's an __init__ method.
|
||||
"""
|
||||
method = self.parse_callable_details(stmt)
|
||||
return (method, True) if stmt.name == "__init__" else (method, False)
|
||||
|
||||
def parse_classes(self, node: ast.ClassDef) -> None:
|
||||
"""
|
||||
Extracts "classes" from the code, including
|
||||
inheritance and init methods.
|
||||
Extracts "classes" from the code, including inheritance and init methods.
|
||||
"""
|
||||
class_dict = {
|
||||
"name": node.name,
|
||||
|
|
@ -140,15 +167,15 @@ class CodeParser:
|
|||
}
|
||||
|
||||
for stmt in node.body:
|
||||
if isinstance(stmt, ast.AnnAssign):
|
||||
attr = {"name": stmt.target.id, "type": ast.unparse(stmt.annotation)}
|
||||
class_dict["attributes"].append(attr)
|
||||
elif isinstance(stmt, ast.Assign):
|
||||
attr = {"name": stmt.targets[0].id, "value": ast.unparse(stmt.value)}
|
||||
class_dict["attributes"].append(attr)
|
||||
if isinstance(stmt, ast.Assign):
|
||||
if attr := self.parse_assign(stmt):
|
||||
class_dict["attributes"].append(attr)
|
||||
elif isinstance(stmt, ast.AnnAssign):
|
||||
if attr := self.parse_ann_assign(stmt):
|
||||
class_dict["attributes"].append(attr)
|
||||
elif isinstance(stmt, ast.FunctionDef):
|
||||
method = self.parse_callable_details(stmt)
|
||||
if stmt.name == "__init__":
|
||||
method, is_init = self.parse_function_def(stmt)
|
||||
if is_init:
|
||||
class_dict["init"] = method
|
||||
else:
|
||||
class_dict["methods"].append(method)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import ast
|
||||
from pydantic import BaseModel
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
|
@ -48,5 +49,23 @@ class Component(BaseModel):
|
|||
|
||||
return validate.create_function(self.code, self.function_entrypoint_name)
|
||||
|
||||
def build_template_config(self, attributes) -> dict:
|
||||
template_config = {}
|
||||
|
||||
for item in attributes:
|
||||
item_name = item.get("name")
|
||||
|
||||
if item_value := item.get("value"):
|
||||
if "langflow_display_name" in item_name:
|
||||
template_config["display_name"] = ast.literal_eval(item_value)
|
||||
|
||||
elif "langflow_description" in item_name:
|
||||
template_config["description"] = ast.literal_eval(item_value)
|
||||
|
||||
elif "langflow_field_config" in item_name:
|
||||
template_config["field_config"] = ast.literal_eval(item_value)
|
||||
|
||||
return template_config
|
||||
|
||||
def build(self):
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import ast
|
||||
from typing import Callable, Optional
|
||||
from fastapi import HTTPException
|
||||
from langflow.interface.custom.constants import LANGCHAIN_BASE_TYPES
|
||||
|
|
@ -28,7 +27,7 @@ class CustomComponent(Component):
|
|||
},
|
||||
)
|
||||
|
||||
# TODO: build logic
|
||||
# TODO: Create the logic to validate what the Custom Component should have as a prerequisite to be able to execute
|
||||
return True
|
||||
|
||||
def is_check_valid(self) -> bool:
|
||||
|
|
@ -92,24 +91,35 @@ class CustomComponent(Component):
|
|||
return build_method["return_type"]
|
||||
|
||||
@property
|
||||
def get_template_config(self) -> dict:
|
||||
extra_attributes = {} # self.get_extra_attributes
|
||||
template_config = {}
|
||||
def get_main_class_name(self):
|
||||
tree = self.get_code_tree(self.code)
|
||||
|
||||
if "field_config" in extra_attributes:
|
||||
template_config["field_config"] = ast.literal_eval(
|
||||
extra_attributes["field_config"]
|
||||
)
|
||||
if "display_name" in extra_attributes:
|
||||
template_config["display_name"] = ast.literal_eval(
|
||||
extra_attributes["display_name"]
|
||||
)
|
||||
if "description" in extra_attributes:
|
||||
template_config["description"] = ast.literal_eval(
|
||||
extra_attributes["description"]
|
||||
)
|
||||
base_name = self.code_class_base_inheritance
|
||||
method_name = self.function_entrypoint_name
|
||||
|
||||
return template_config
|
||||
classes = []
|
||||
for item in tree.get("classes"):
|
||||
if base_name in item["bases"]:
|
||||
method_names = [method["name"] for method in item["methods"]]
|
||||
if method_name in method_names:
|
||||
classes.append(item["name"])
|
||||
|
||||
# Get just the first item
|
||||
return next(iter(classes), "")
|
||||
|
||||
@property
|
||||
def build_template_config(self):
|
||||
tree = self.get_code_tree(self.code)
|
||||
|
||||
attributes = [
|
||||
main_class["attributes"]
|
||||
for main_class in tree.get("classes")
|
||||
if main_class["name"] == self.get_main_class_name
|
||||
]
|
||||
# Get just the first item
|
||||
attributes = next(iter(attributes), [])
|
||||
|
||||
return super().build_template_config(attributes)
|
||||
|
||||
@property
|
||||
def get_function(self):
|
||||
|
|
|
|||
|
|
@ -161,7 +161,18 @@ def build_langchain_template_custom_component(custom_component: CustomComponent)
|
|||
|
||||
function_args = custom_component.get_function_entrypoint_args
|
||||
return_type = custom_component.get_function_entrypoint_return_type
|
||||
# template_config = custom_component.get_template_config
|
||||
template_config = custom_component.build_template_config
|
||||
|
||||
# Rewrite diplay_name and description values
|
||||
if frontend_node:
|
||||
if "display_name" in template_config:
|
||||
frontend_node["display_name"] = template_config["display_name"]
|
||||
|
||||
elif "description" in template_config:
|
||||
frontend_node["description"] = template_config["description"]
|
||||
|
||||
# Rewrite field configurations
|
||||
field_config = template_config.get("field_config", {})
|
||||
|
||||
if function_args is not None:
|
||||
# Add extra fields
|
||||
|
|
@ -174,7 +185,6 @@ def build_langchain_template_custom_component(custom_component: CustomComponent)
|
|||
extra_field.get("default") if "default" in extra_field else ""
|
||||
)
|
||||
field_required = True
|
||||
field_config = {}
|
||||
|
||||
# TODO: Validate type - if is possible to render into frontend
|
||||
if "optional" in field_type.lower():
|
||||
|
|
@ -184,13 +194,14 @@ def build_langchain_template_custom_component(custom_component: CustomComponent)
|
|||
if not field_type:
|
||||
field_type = "str"
|
||||
|
||||
config = field_config.get(field_name, {})
|
||||
frontend_node = add_new_custom_field(
|
||||
frontend_node,
|
||||
field_name,
|
||||
field_type,
|
||||
field_value,
|
||||
field_required,
|
||||
field_config,
|
||||
config,
|
||||
)
|
||||
|
||||
frontend_node = add_code_field(frontend_node, custom_component.code)
|
||||
|
|
|
|||
|
|
@ -252,7 +252,8 @@ class CustomChain(CustomComponent):
|
|||
|
||||
def build(self, prompt, llm, input: str) -> Document:
|
||||
chain = MyCustomChain(prompt=prompt, llm=llm)
|
||||
return chain(input)'''
|
||||
return chain(input)
|
||||
'''
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue