🔀 refactor(langflow): rename custom.py to custom_component.py for clarity

🔥 remove(langflow): delete custom.py as it's replaced by custom_component.py
📦 feat(langflow): add code_parser.py to parse Python source code
🐛 fix(langflow): update import paths due to file renaming
🎨 style(langflow): improve code formatting for readability
🐛 fix(langflow): correct handling of function arguments and return types in custom components
🔧 chore(langflow): update function calls due to changes in custom components
This commit is contained in:
gustavoschaedler 2023-07-14 04:49:42 +01:00
commit 79d2d551ff
10 changed files with 379 additions and 249 deletions

View file

@ -1,3 +1,4 @@
from datetime import timezone
from typing import List
from uuid import UUID
from langflow.database.models.component import Component, ComponentModel
@ -60,7 +61,7 @@ def update_component(
for key, value in component_data.items():
setattr(db_component, key, value)
db_component.update_at = datetime.utcnow()
db_component.update_at = datetime.now(timezone.utc)
db.commit()
db.refresh(db_component)
return db_component

View file

@ -7,7 +7,7 @@ from langflow.utils.logger import logger
from fastapi import APIRouter, Depends, HTTPException, UploadFile
from langflow.interface.custom.custom import CustomComponent
from langflow.interface.custom.custom_component import CustomComponent
from langflow.api.v1.schemas import (
ProcessResponse,

View file

@ -1,4 +1,4 @@
from langflow.interface.custom.base import CustomComponentCreator
from langflow.interface.custom.custom import CustomComponent
from langflow.interface.custom.custom_component import CustomComponent
__all__ = ["CustomComponentCreator", "CustomComponent"]

View file

@ -2,7 +2,9 @@ from typing import Any, Dict, List, Optional, Type
from langflow.interface.base import LangChainTypeCreator
from langflow.interface.custom.custom import CustomComponent
# from langflow.interface.custom.custom import CustomComponent
from langflow.interface.custom.custom_component import CustomComponent
from langflow.template.frontend_node.custom_components import (
CustomComponentFrontendNode,
)

View file

@ -0,0 +1,178 @@
import ast
import traceback
from typing import Dict, Any, Union
from fastapi import HTTPException
class CodeSyntaxError(HTTPException):
    """HTTP error raised by ``CodeParser`` when the source cannot be parsed."""
class CodeParser:
    """
    A parser for Python source code, extracting code details.

    The source is parsed with ``ast`` and every node of the resulting tree is
    dispatched to a handler that records imports, functions, classes and
    assignments in ``self.data``.
    """

    def __init__(self, code: str) -> None:
        """
        Initializes the parser with the provided code.
        """
        self.code = code
        self.data: Dict[str, Any] = {
            "imports": [],
            "functions": [],
            "classes": [],
            "global_vars": [],
        }
        # Dispatch table: AST node type -> method that records it.
        self.handlers = {
            ast.Import: self.parse_imports,
            ast.ImportFrom: self.parse_imports,
            ast.FunctionDef: self.parse_functions,
            ast.ClassDef: self.parse_classes,
            ast.Assign: self.parse_global_vars,
        }

    def __get_tree(self):
        """
        Parses the provided code to validate its syntax.
        It tries to parse the code into an abstract syntax tree (AST).

        Raises:
            CodeSyntaxError: (status 400) when ``ast.parse`` raises
                ``SyntaxError``.
        """
        try:
            tree = ast.parse(self.code)
        except SyntaxError as err:
            raise CodeSyntaxError(
                status_code=400,
                detail={"error": err.msg, "traceback": traceback.format_exc()},
            ) from err
        return tree

    @staticmethod
    def _target_name(target: ast.AST) -> str:
        """
        Returns the identifier of an assignment target, falling back to its
        unparsed source for targets that are not plain names (tuples,
        attributes, subscripts).
        """
        if isinstance(target, ast.Name):
            return target.id
        return ast.unparse(target)

    def parse_node(self, node: ast.AST) -> None:
        """
        Parses an AST node and updates the data
        dictionary with the relevant information.
        """
        if handler := self.handlers.get(type(node)):
            handler(node)

    def parse_imports(self, node: Union[ast.Import, ast.ImportFrom]) -> None:
        """
        Extracts "imports" from the code.

        ``import a, b`` adds one string per module; ``from m import x, y``
        adds a single ``(module, [names])`` tuple.
        """
        if isinstance(node, ast.Import):
            # Record every module of the statement, not just the first one.
            self.data["imports"].extend(alias.name for alias in node.names)
        elif isinstance(node, ast.ImportFrom):
            names = [alias.name for alias in node.names]
            self.data["imports"].append((node.module, names))

    def parse_functions(self, node: ast.FunctionDef) -> None:
        """
        Extracts "functions" from the code.
        """
        self.data["functions"].append(self.parse_callable_details(node))

    def parse_arg(self, arg, default):
        """
        Parses an argument and its default value.

        Returns a dict with "name" and "default", plus "type" when the
        argument is annotated.
        """
        arg_dict = {"name": arg.arg, "default": default}
        if arg.annotation:
            arg_dict["type"] = ast.unparse(arg.annotation)
        return arg_dict

    def parse_callable_details(self, node: ast.FunctionDef) -> Dict[str, Any]:
        """
        Extracts details from a single function or method node.

        Returns a dict with the keys "name", "doc", "args", "body" (one
        unparsed string per top-level statement) and "return_type".
        """
        signature = node.args
        func = {
            "name": node.name,
            "doc": ast.get_docstring(node),
            "args": [],
            "body": [ast.unparse(stmt) for stmt in node.body],
            "return_type": ast.unparse(node.returns) if node.returns else None,
        }

        # ``defaults`` belongs to the tail of positional-only + regular
        # parameters taken together, so both lists must be aligned as one.
        positional = signature.posonlyargs + signature.args
        defaults = [None] * (len(positional) - len(signature.defaults)) + [
            ast.unparse(default) for default in signature.defaults
        ]
        for arg, default in zip(positional, defaults):
            func["args"].append(self.parse_arg(arg, default))

        # Handle *args
        if signature.vararg:
            func["args"].append(self.parse_arg(signature.vararg, None))

        # ``kw_defaults`` is parallel to ``kwonlyargs``; a missing default is
        # stored as None in the AST.
        for arg, default in zip(signature.kwonlyargs, signature.kw_defaults):
            rendered = ast.unparse(default) if default is not None else None
            func["args"].append(self.parse_arg(arg, rendered))

        # Handle **kwargs
        if signature.kwarg:
            func["args"].append(self.parse_arg(signature.kwarg, None))

        return func

    def parse_classes(self, node: ast.ClassDef) -> None:
        """
        Extracts "classes" from the code, including
        inheritance and init methods.

        ``__init__`` is stored under the "init" key (only when present);
        every other method goes to "methods".
        """
        class_dict = {
            "name": node.name,
            "doc": ast.get_docstring(node),
            "bases": [ast.unparse(base) for base in node.bases],
            "attributes": [],
            "methods": [],
        }

        for stmt in node.body:
            if isinstance(stmt, ast.AnnAssign):
                class_dict["attributes"].append(
                    {
                        "name": self._target_name(stmt.target),
                        "type": ast.unparse(stmt.annotation),
                    }
                )
            elif isinstance(stmt, ast.Assign):
                # One entry per target so ``a = b = 1`` records both names and
                # tuple targets no longer raise AttributeError.
                value = ast.unparse(stmt.value)
                for target in stmt.targets:
                    class_dict["attributes"].append(
                        {"name": self._target_name(target), "value": value}
                    )
            elif isinstance(stmt, ast.FunctionDef):
                method = self.parse_callable_details(stmt)
                if stmt.name == "__init__":
                    class_dict["init"] = method
                else:
                    class_dict["methods"].append(method)

        self.data["classes"].append(class_dict)

    def parse_global_vars(self, node: ast.Assign) -> None:
        """
        Extracts global variables from the code.
        """
        global_var = {
            "targets": [
                t.id if hasattr(t, "id") else ast.dump(t) for t in node.targets
            ],
            "value": ast.unparse(node.value),
        }
        self.data["global_vars"].append(global_var)

    def parse_code(self) -> Dict[str, Any]:
        """
        Runs all parsing operations and returns the resulting data.

        Every node yielded by ``ast.walk`` is dispatched, so methods, nested
        functions and assignments that are not at module level are also
        recorded under "functions" / "global_vars". Results are appended to
        ``self.data``; calling this twice on one instance duplicates entries.
        """
        tree = self.__get_tree()
        for node in ast.walk(tree):
            self.parse_node(node)
        return self.data

View file

@ -0,0 +1,53 @@
from pydantic import BaseModel
from fastapi import HTTPException
from langflow.utils import validate
from langflow.interface.custom.code_parser import CodeParser
class ComponentCodeNullError(HTTPException):
    """HTTP error raised when a component has no Python code."""
class ComponentFunctionEntrypointNameNullError(HTTPException):
    """HTTP error raised when a component has no entrypoint function name."""
class Component(BaseModel):
    """
    Base model for a component defined by user-supplied Python code.

    Holds the source (``code``), the name of the function used as entrypoint
    and an optional ``field_config`` mapping. Subclasses implement ``build``.
    """

    # NOTE(review): these constants are un-annotated class attributes;
    # pydantic may treat them as model fields rather than plain constants —
    # confirm against the pydantic version in use.
    ERROR_CODE_NULL = "Python code must be provided."
    ERROR_FUNCTION_ENTRYPOINT_NAME_NULL = (
        "The name of the entrypoint function must be provided."
    )

    code: str
    function_entrypoint_name = "build"
    field_config: dict = {}

    # No __init__ override: the previous one only forwarded **data to
    # BaseModel.__init__, which is what happens by default.

    def get_code_tree(self, code: str):
        """Parse ``code`` with ``CodeParser`` and return the extracted data."""
        return CodeParser(code).parse_code()

    def get_function(self):
        """
        Build the entrypoint function from ``self.code``.

        Returns:
            Whatever ``validate.create_function`` returns for the code and
            the entrypoint name.

        Raises:
            ComponentCodeNullError: (status 400) when ``code`` is empty.
            ComponentFunctionEntrypointNameNullError: (status 400) when
                ``function_entrypoint_name`` is empty.
        """
        if not self.code:
            raise ComponentCodeNullError(
                status_code=400,
                detail={"error": self.ERROR_CODE_NULL, "traceback": ""},
            )

        if not self.function_entrypoint_name:
            raise ComponentFunctionEntrypointNameNullError(
                status_code=400,
                detail={
                    "error": self.ERROR_FUNCTION_ENTRYPOINT_NAME_NULL,
                    "traceback": "",
                },
            )

        return validate.create_function(self.code, self.function_entrypoint_name)

    def build(self):
        """Entry point to be implemented by subclasses."""
        raise NotImplementedError

View file

@ -1,220 +0,0 @@
import re
import ast
import traceback
from typing import Callable, Optional
from fastapi import HTTPException
from langflow.interface.custom.constants import LANGCHAIN_BASE_TYPES
from langflow.utils import validate
from pydantic import BaseModel
class CustomComponent(BaseModel):
    """
    Custom component described by user-supplied Python source.

    The source in ``code`` is parsed with ``ast`` to collect its imports, the
    class definition (name, bases, attributes, ``__init__`` arguments) and
    the class methods, which are then used to validate the entrypoint
    function and to expose its arguments and return type.
    """

    field_config: dict = {}
    code: str
    function: Optional[Callable] = None
    function_entrypoint_name = "build"
    return_type_valid_list = list(LANGCHAIN_BASE_TYPES.keys())
    class_template = {
        "imports": [],
        "class": {"inherited_classes": "", "name": "", "init": "", "attributes": {}},
        "functions": [],
    }

    def __init__(self, **data):
        super().__init__(**data)

    def _handle_import(self, node):
        """Record every name of an ``import`` / ``from ... import`` node."""
        module_name = getattr(node, "module", None)
        for alias in node.names:
            self.class_template["imports"].append(
                f"{module_name}.{alias.name}" if module_name else alias.name
            )

    def _handle_class(self, node):
        """Record name, bases, attributes and methods of a class node."""
        self.class_template["class"].update(
            {
                "name": node.name,
                "inherited_classes": [ast.unparse(base) for base in node.bases],
            }
        )

        attributes = {}  # attribute name -> unparsed value expression
        for inner_node in node.body:
            if isinstance(inner_node, ast.Assign):
                for target in inner_node.targets:
                    # Only simple names are recorded.
                    if isinstance(target, ast.Name):
                        attributes[target.id] = ast.unparse(inner_node.value)
            elif isinstance(inner_node, ast.AnnAssign):
                # Annotated assignments count only when they carry a value.
                if isinstance(inner_node.target, ast.Name) and inner_node.value:
                    attributes[inner_node.target.id] = ast.unparse(inner_node.value)
            elif isinstance(inner_node, ast.FunctionDef):
                self._handle_function(inner_node)

        self.class_template["class"]["attributes"] = attributes

    def _handle_function(self, node):
        """Record a method; ``__init__`` arguments are stored on the class entry."""
        function_args_str = ast.unparse(node.args)
        function_args = function_args_str.split(", ") if function_args_str else []
        return_type = ast.unparse(node.returns) if node.returns else "None"

        if node.name == "__init__":
            self.class_template["class"]["init"] = function_args
        else:
            self.class_template["functions"].append(
                {
                    "name": node.name,
                    "arguments": function_args,
                    "return_type": return_type,
                }
            )

    def _split_string(self, text):
        """
        Split a string by ':' or '=' and append None until the resulting list has 3 items.

        Parameters:
            text (str): The string to be split.

        Returns:
            list: A list of strings resulting from the split operation,
            padded with None until its length is 3.
        """
        items = [item.strip() for item in re.split(r"[:=]", text) if item.strip()]
        items.extend([None] * (3 - len(items)))
        return items

    def transform_list(self, input_list):
        """
        Transform a list of strings by splitting each string and padding with None.

        Parameters:
            input_list (list): The list of strings to be transformed.

        Returns:
            list: A list of lists, each containing the result of the split operation.
        """
        return [self._split_string(item) for item in input_list]

    def extract_class_info(self):
        """
        Parse ``self.code`` and return the filled ``class_template``.

        Only top-level imports and class definitions are inspected.

        Raises:
            HTTPException: (status 400) when the code has a syntax error.
        """
        try:
            module = ast.parse(self.code)
        except SyntaxError as err:
            raise HTTPException(
                status_code=400,
                detail={"error": err.msg, "traceback": traceback.format_exc()},
            ) from err

        # Start from an empty template on every call: the handlers append to
        # it, and this method runs once per access of ``data`` /
        # ``args_and_return_type`` / ``is_check_valid``, which previously made
        # imports and functions pile up with duplicates.
        self.class_template = {
            "imports": [],
            "class": {"inherited_classes": "", "name": "", "init": "", "attributes": {}},
            "functions": [],
        }

        for node in module.body:
            if isinstance(node, (ast.Import, ast.ImportFrom)):
                self._handle_import(node)
            elif isinstance(node, ast.ClassDef):
                self._handle_class(node)

        return self.class_template

    def get_entrypoint_function_args_and_return_type(self):
        """
        Return ``(function_args, return_type, template_config)`` for the
        entrypoint function; the first two are None when it is not found.
        """
        data = self.extract_class_info()
        attributes = data.get("class", {}).get("attributes", {})
        functions = data.get("functions", [])
        template_config = self._build_template_config(attributes)

        build_function = next(
            (f for f in functions if f["name"] == self.function_entrypoint_name),
            None,
        )
        if not build_function:
            return None, None, template_config

        function_args = self.transform_list(build_function.get("arguments", None))
        return_type = build_function.get("return_type", None)
        return function_args, return_type, template_config

    def _build_template_config(self, attributes):
        """Evaluate the literal class attributes that configure the template."""
        return {
            key: ast.literal_eval(attributes[key])
            for key in ("field_config", "display_name", "description")
            if key in attributes
        }

    def _class_template_validation(self, code: dict):
        """
        Validate the extracted template.

        Raises:
            HTTPException: (status 400) when the class has no name, the
                entrypoint function is missing, or its return type is not in
                ``return_type_valid_list``.
        """
        class_name = code.get("class", {}).get("name", None)
        if not class_name:  # this will also check for None, empty string, etc.
            raise HTTPException(
                status_code=400,
                detail={
                    "error": "The main class must have a valid name.",
                    "traceback": "",
                },
            )

        functions = code.get("functions", [])
        build_function = next(
            (f for f in functions if f["name"] == self.function_entrypoint_name),
            None,
        )

        if not build_function:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": "Invalid entrypoint function name",
                    # Plain string (a stray trailing comma used to make this a
                    # one-element tuple, unlike the error below).
                    "traceback": (
                        f"There needs to be at least one entrypoint function named '{self.function_entrypoint_name}'"
                        f" and it needs to return one of the types from this list {str(self.return_type_valid_list)}."
                    ),
                },
            )

        return_type = build_function.get("return_type")
        if return_type not in self.return_type_valid_list:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": "Invalid entrypoint function return",
                    "traceback": (
                        f"The entrypoint function return '{return_type}' needs to be an item "
                        f"from this list {str(self.return_type_valid_list)}."
                    ),
                },
            )

        return True

    def get_function(self):
        """Build the entrypoint function via ``validate.create_function``."""
        return validate.create_function(self.code, self.function_entrypoint_name)

    def build(self):
        """Entry point to be implemented by subclasses."""
        raise NotImplementedError

    @property
    def data(self):
        """Freshly extracted class template for ``self.code``."""
        return self.extract_class_info()

    def is_check_valid(self):
        """Return True when the extracted template passes validation."""
        return self._class_template_validation(self.data)

    @property
    def args_and_return_type(self):
        """Shortcut for ``get_entrypoint_function_args_and_return_type()``."""
        return self.get_entrypoint_function_args_and_return_type()

View file

@ -0,0 +1,119 @@
import ast
from typing import Any, Callable, Dict, List, Optional, Union

from fastapi import HTTPException

from langflow.interface.custom.component import Component
from langflow.interface.custom.constants import LANGCHAIN_BASE_TYPES
from langflow.utils import validate
class CustomComponent(Component):
    """
    Component whose behaviour comes from a user-written class.

    The source in ``code`` is expected to define a class that lists
    ``code_class_base_inheritance`` among its bases and has a method named
    ``function_entrypoint_name``; the properties below expose that method's
    arguments and return type as reported by ``get_code_tree``.
    """

    code: str
    field_config: dict = {}
    code_class_base_inheritance = "CustomComponent"
    function_entrypoint_name = "build"
    function: Optional[Callable] = None
    return_type_valid_list = list(LANGCHAIN_BASE_TYPES.keys())

    def _class_template_validation(self, code: str) -> bool:
        """
        Check that ``code`` is not empty.

        Raises:
            HTTPException: (status 400) when ``code`` is empty.
        """
        if not code:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": self.ERROR_CODE_NULL,
                    "traceback": "",
                },
            )

        # TODO: build logic

        return True

    def is_check_valid(self) -> bool:
        """Validate ``self.code``; see ``_class_template_validation``."""
        return self._class_template_validation(self.code)

    def _get_entrypoint_method(self) -> Optional[Dict[str, Any]]:
        """
        Return the parsed entrypoint method of the first class that inherits
        from ``code_class_base_inheritance``, or None when there is no such
        class or that class has no entrypoint method.
        """
        tree = self.get_code_tree(self.code)
        for cls in tree["classes"]:
            if self.code_class_base_inheritance not in cls["bases"]:
                continue
            # Assume the first Component class is the one we're interested in
            for method in cls["methods"]:
                if method["name"] == self.function_entrypoint_name:
                    return method
            return None
        return None

    @property
    def get_function_entrypoint_args(self) -> Union[List[Dict[str, Any]], str]:
        """Parsed arguments of the entrypoint method, or "" when not found."""
        method = self._get_entrypoint_method()
        return method["args"] if method is not None else ""

    @property
    def get_function_entrypoint_return_type(self) -> Optional[str]:
        """
        Return type of the entrypoint method as stored in the code tree, or
        "" when the method is not found.
        """
        method = self._get_entrypoint_method()
        return method["return_type"] if method is not None else ""

    @property
    def get_template_config(self) -> dict:
        """
        Template configuration evaluated from the class attributes.

        The attribute source is still a placeholder, so this currently always
        returns an empty dict.
        """
        extra_attributes = {}  # self.get_extra_attributes
        return {
            key: ast.literal_eval(extra_attributes[key])
            for key in ("field_config", "display_name", "description")
            if key in extra_attributes
        }

    # NOTE(review): this is a property while ``Component.get_function`` is a
    # regular method — kept as-is for existing callers; confirm the intent.
    @property
    def get_function(self):
        """Entrypoint function built by ``validate.create_function``."""
        return validate.create_function(self.code, self.function_entrypoint_name)

    def build(self):
        """Entry point to be implemented by subclasses."""
        raise NotImplementedError

View file

@ -9,7 +9,7 @@ from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.chat_models.base import BaseChatModel
from langchain.tools import BaseTool
from langflow.interface.custom.custom import CustomComponent
from langflow.interface.custom.custom_component import CustomComponent
from langflow.utils import validate
from langflow.interface.wrappers.base import wrapper_creator
@ -61,7 +61,9 @@ def import_by_type(_type: str, name: str) -> Any:
def import_custom_component(custom_component: str) -> CustomComponent:
"""Import custom component from custom component name"""
return import_class(f"langflow.interface.custom.custom.{custom_component}")
return import_class(
f"langflow.interface.custom.custom_component.{custom_component}"
)
def import_output_parser(output_parser: str) -> Any:
@ -183,5 +185,4 @@ def get_function(code):
def get_function_custom(code):
    """Create the class defined in ``code``, looked up by its extracted name."""
    return validate.create_class(code, validate.extract_class_name(code))

View file

@ -14,7 +14,7 @@ from langflow.interface.vector_store.base import vectorstore_creator
from langflow.interface.wrappers.base import wrapper_creator
from langflow.interface.output_parsers.base import output_parser_creator
from langflow.interface.custom.base import custom_component_creator
from langflow.interface.custom.custom import CustomComponent
from langflow.interface.custom.custom_component import CustomComponent
from langflow.template.field.base import TemplateField
from langflow.template.frontend_node.tools import CustomComponentNode
@ -92,9 +92,6 @@ def add_new_custom_field(
field_type = field_config.pop("field_type", field_type)
field_type = process_type(field_type)
if field_value is not None:
field_value = field_value.replace("'", "").replace('"', "")
if "name" in field_config:
warnings.warn(
"The 'name' key in field_config is used to build the object and can't be changed."
@ -158,29 +155,27 @@ def extract_type_from_optional(field_type):
return match[1] if match else None
def build_langchain_template_custom_component(extractor: CustomComponent):
def build_langchain_template_custom_component(custom_component: CustomComponent):
# Build base "CustomComponent" template
frontend_node = CustomComponentNode().to_dict().get(type(extractor).__name__)
frontend_node = CustomComponentNode().to_dict().get(type(custom_component).__name__)
function_args, return_type, template_config = extractor.args_and_return_type
if "display_name" in template_config and frontend_node is not None:
frontend_node["display_name"] = template_config["display_name"]
if "description" in template_config and frontend_node is not None:
frontend_node["description"] = template_config["description"]
raw_code = extractor.code
field_config = template_config.get("field_config", {})
function_args = custom_component.get_function_entrypoint_args
return_type = custom_component.get_function_entrypoint_return_type
# template_config = custom_component.get_template_config
if function_args is not None:
# Add extra fields
for extra_field in function_args:
field_required = True
field_name, field_type, field_value = extra_field
if not field_type:
field_type = ""
field_name = extra_field.get("name") if "name" in extra_field else ""
if field_name != "self":
field_type = extra_field.get("type") if "type" in extra_field else ""
field_value = (
extra_field.get("default") if "default" in extra_field else ""
)
field_required = True
field_config = {}
# TODO: Validate type - if is possible to render into frontend
if "optional" in field_type.lower():
field_type = extract_type_from_optional(field_type)
@ -189,17 +184,16 @@ def build_langchain_template_custom_component(extractor: CustomComponent):
if not field_type:
field_type = "str"
config = field_config.get(field_name, {})
frontend_node = add_new_custom_field(
frontend_node,
field_name,
field_type,
field_value,
field_required,
config,
field_config,
)
frontend_node = add_code_field(frontend_node, raw_code)
frontend_node = add_code_field(frontend_node, custom_component.code)
# Get base classes from "return_type" and add to template.base_classes
try:
@ -214,8 +208,10 @@ def build_langchain_template_custom_component(extractor: CustomComponent):
"traceback": traceback.format_exc(),
},
)
return_type_instance = LANGCHAIN_BASE_TYPES.get(return_type)
base_classes = get_base_classes(return_type_instance)
except (KeyError, AttributeError) as err:
raise HTTPException(
status_code=400,