🔀 refactor(langflow): rename custom.py to custom_component.py for clarity
🔥 remove(langflow): delete custom.py as it's replaced by custom_component.py 📦 feat(langflow): add code_parser.py to parse Python source code 🐛 fix(langflow): update import paths due to file renaming 🎨 style(langflow): improve code formatting for readability 🐛 fix(langflow): correct handling of function arguments and return types in custom components 🔧 chore(langflow): update function calls due to changes in custom components
This commit is contained in:
parent
e8c844a75f
commit
79d2d551ff
10 changed files with 379 additions and 249 deletions
|
|
@ -1,3 +1,4 @@
|
|||
from datetime import timezone
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
from langflow.database.models.component import Component, ComponentModel
|
||||
|
|
@ -60,7 +61,7 @@ def update_component(
|
|||
for key, value in component_data.items():
|
||||
setattr(db_component, key, value)
|
||||
|
||||
db_component.update_at = datetime.utcnow()
|
||||
db_component.update_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
db.refresh(db_component)
|
||||
return db_component
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from langflow.utils.logger import logger
|
|||
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile
|
||||
|
||||
from langflow.interface.custom.custom import CustomComponent
|
||||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
|
||||
from langflow.api.v1.schemas import (
|
||||
ProcessResponse,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from langflow.interface.custom.base import CustomComponentCreator
|
||||
from langflow.interface.custom.custom import CustomComponent
|
||||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
|
||||
__all__ = ["CustomComponentCreator", "CustomComponent"]
|
||||
|
|
|
|||
|
|
@ -2,7 +2,9 @@ from typing import Any, Dict, List, Optional, Type
|
|||
|
||||
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.custom.custom import CustomComponent
|
||||
|
||||
# from langflow.interface.custom.custom import CustomComponent
|
||||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
from langflow.template.frontend_node.custom_components import (
|
||||
CustomComponentFrontendNode,
|
||||
)
|
||||
|
|
|
|||
178
src/backend/langflow/interface/custom/code_parser.py
Normal file
178
src/backend/langflow/interface/custom/code_parser.py
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
import ast
|
||||
import traceback
|
||||
|
||||
from typing import Dict, Any, Union
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
class CodeSyntaxError(HTTPException):
|
||||
pass
|
||||
|
||||
|
||||
class CodeParser:
|
||||
"""
|
||||
A parser for Python source code, extracting code details.
|
||||
"""
|
||||
|
||||
def __init__(self, code: str) -> None:
|
||||
"""
|
||||
Initializes the parser with the provided code.
|
||||
"""
|
||||
self.code = code
|
||||
self.data: Dict[str, Any] = {
|
||||
"imports": [],
|
||||
"functions": [],
|
||||
"classes": [],
|
||||
"global_vars": [],
|
||||
}
|
||||
self.handlers = {
|
||||
ast.Import: self.parse_imports,
|
||||
ast.ImportFrom: self.parse_imports,
|
||||
ast.FunctionDef: self.parse_functions,
|
||||
ast.ClassDef: self.parse_classes,
|
||||
ast.Assign: self.parse_global_vars,
|
||||
}
|
||||
|
||||
def __get_tree(self):
|
||||
"""
|
||||
Parses the provided code to validate its syntax.
|
||||
It tries to parse the code into an abstract syntax tree (AST).
|
||||
"""
|
||||
try:
|
||||
tree = ast.parse(self.code)
|
||||
except SyntaxError as err:
|
||||
raise CodeSyntaxError(
|
||||
status_code=400,
|
||||
detail={"error": err.msg, "traceback": traceback.format_exc()},
|
||||
) from err
|
||||
|
||||
return tree
|
||||
|
||||
def parse_node(self, node: ast.AST) -> None:
|
||||
"""
|
||||
Parses an AST node and updates the data
|
||||
dictionary with the relevant information.
|
||||
"""
|
||||
if handler := self.handlers.get(type(node)):
|
||||
handler(node)
|
||||
|
||||
def parse_imports(self, node: Union[ast.Import, ast.ImportFrom]) -> None:
|
||||
"""
|
||||
Extracts "imports" from the code.
|
||||
"""
|
||||
if isinstance(node, ast.Import):
|
||||
module = node.names[0].name
|
||||
self.data["imports"].append(module)
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
module = node.module
|
||||
names = [alias.name for alias in node.names]
|
||||
self.data["imports"].append((module, names))
|
||||
|
||||
def parse_functions(self, node: ast.FunctionDef) -> None:
|
||||
"""
|
||||
Extracts "functions" from the code.
|
||||
"""
|
||||
self.data["functions"].append(self.parse_callable_details(node))
|
||||
|
||||
def parse_arg(self, arg, default):
|
||||
"""
|
||||
Parses an argument and its default value.
|
||||
"""
|
||||
arg_dict = {"name": arg.arg, "default": default}
|
||||
if arg.annotation:
|
||||
arg_dict["type"] = ast.unparse(arg.annotation)
|
||||
return arg_dict
|
||||
|
||||
def parse_callable_details(self, node: ast.FunctionDef) -> Dict[str, Any]:
|
||||
"""
|
||||
Extracts details from a single function or method node.
|
||||
"""
|
||||
func = {
|
||||
"name": node.name,
|
||||
"doc": ast.get_docstring(node),
|
||||
"args": [],
|
||||
"body": [],
|
||||
"return_type": ast.unparse(node.returns) if node.returns else None,
|
||||
}
|
||||
|
||||
# Handle positional arguments with default values
|
||||
defaults = [None] * (len(node.args.args) - len(node.args.defaults)) + [
|
||||
ast.unparse(default) for default in node.args.defaults
|
||||
]
|
||||
|
||||
for arg, default in zip(node.args.args, defaults):
|
||||
func["args"].append(self.parse_arg(arg, default))
|
||||
|
||||
# Handle *args
|
||||
if node.args.vararg:
|
||||
func["args"].append(self.parse_arg(node.args.vararg, None))
|
||||
|
||||
# Handle keyword-only arguments with default values
|
||||
kw_defaults = [None] * (
|
||||
len(node.args.kwonlyargs) - len(node.args.kw_defaults)
|
||||
) + [
|
||||
ast.unparse(default) if default else None
|
||||
for default in node.args.kw_defaults
|
||||
]
|
||||
|
||||
for arg, default in zip(node.args.kwonlyargs, kw_defaults):
|
||||
func["args"].append(self.parse_arg(arg, default))
|
||||
|
||||
# Handle **kwargs
|
||||
if node.args.kwarg:
|
||||
func["args"].append(self.parse_arg(node.args.kwarg, None))
|
||||
|
||||
for line in node.body:
|
||||
func["body"].append(ast.unparse(line))
|
||||
return func
|
||||
|
||||
def parse_classes(self, node: ast.ClassDef) -> None:
|
||||
"""
|
||||
Extracts "classes" from the code, including
|
||||
inheritance and init methods.
|
||||
"""
|
||||
class_dict = {
|
||||
"name": node.name,
|
||||
"doc": ast.get_docstring(node),
|
||||
"bases": [ast.unparse(base) for base in node.bases],
|
||||
"attributes": [],
|
||||
"methods": [],
|
||||
}
|
||||
|
||||
for stmt in node.body:
|
||||
if isinstance(stmt, ast.AnnAssign):
|
||||
attr = {"name": stmt.target.id, "type": ast.unparse(stmt.annotation)}
|
||||
class_dict["attributes"].append(attr)
|
||||
elif isinstance(stmt, ast.Assign):
|
||||
attr = {"name": stmt.targets[0].id, "value": ast.unparse(stmt.value)}
|
||||
class_dict["attributes"].append(attr)
|
||||
elif isinstance(stmt, ast.FunctionDef):
|
||||
method = self.parse_callable_details(stmt)
|
||||
if stmt.name == "__init__":
|
||||
class_dict["init"] = method
|
||||
else:
|
||||
class_dict["methods"].append(method)
|
||||
|
||||
self.data["classes"].append(class_dict)
|
||||
|
||||
def parse_global_vars(self, node: ast.Assign) -> None:
|
||||
"""
|
||||
Extracts global variables from the code.
|
||||
"""
|
||||
global_var = {
|
||||
"targets": [
|
||||
t.id if hasattr(t, "id") else ast.dump(t) for t in node.targets
|
||||
],
|
||||
"value": ast.unparse(node.value),
|
||||
}
|
||||
self.data["global_vars"].append(global_var)
|
||||
|
||||
def parse_code(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Runs all parsing operations and returns the resulting data.
|
||||
"""
|
||||
tree = self.__get_tree()
|
||||
|
||||
for node in ast.walk(tree):
|
||||
self.parse_node(node)
|
||||
return self.data
|
||||
53
src/backend/langflow/interface/custom/component.py
Normal file
53
src/backend/langflow/interface/custom/component.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
|
||||
from pydantic import BaseModel
|
||||
from fastapi import HTTPException
|
||||
|
||||
from langflow.utils import validate
|
||||
from langflow.interface.custom.code_parser import CodeParser
|
||||
|
||||
|
||||
class ComponentCodeNullError(HTTPException):
|
||||
pass
|
||||
|
||||
|
||||
class ComponentFunctionEntrypointNameNullError(HTTPException):
|
||||
pass
|
||||
|
||||
|
||||
class Component(BaseModel):
|
||||
ERROR_CODE_NULL = "Python code must be provided."
|
||||
ERROR_FUNCTION_ENTRYPOINT_NAME_NULL = (
|
||||
"The name of the entrypoint function must be provided."
|
||||
)
|
||||
|
||||
code: str
|
||||
function_entrypoint_name = "build"
|
||||
field_config: dict = {}
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
def get_code_tree(self, code: str):
|
||||
parser = CodeParser(code)
|
||||
return parser.parse_code()
|
||||
|
||||
def get_function(self):
|
||||
if not self.code:
|
||||
raise ComponentCodeNullError(
|
||||
status_code=400,
|
||||
detail={"error": self.ERROR_CODE_NULL, "traceback": ""},
|
||||
)
|
||||
|
||||
if not self.function_entrypoint_name:
|
||||
raise ComponentFunctionEntrypointNameNullError(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": self.ERROR_FUNCTION_ENTRYPOINT_NAME_NULL,
|
||||
"traceback": "",
|
||||
},
|
||||
)
|
||||
|
||||
return validate.create_function(self.code, self.function_entrypoint_name)
|
||||
|
||||
def build(self):
|
||||
raise NotImplementedError
|
||||
|
|
@ -1,220 +0,0 @@
|
|||
import re
|
||||
import ast
|
||||
import traceback
|
||||
from typing import Callable, Optional
|
||||
from fastapi import HTTPException
|
||||
from langflow.interface.custom.constants import LANGCHAIN_BASE_TYPES
|
||||
|
||||
from langflow.utils import validate
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CustomComponent(BaseModel):
|
||||
field_config: dict = {}
|
||||
code: str
|
||||
function: Optional[Callable] = None
|
||||
function_entrypoint_name = "build"
|
||||
return_type_valid_list = list(LANGCHAIN_BASE_TYPES.keys())
|
||||
class_template = {
|
||||
"imports": [],
|
||||
"class": {"inherited_classes": "", "name": "", "init": "", "attributes": {}},
|
||||
"functions": [],
|
||||
}
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
def _handle_import(self, node):
|
||||
for alias in node.names:
|
||||
module_name = getattr(node, "module", None)
|
||||
self.class_template["imports"].append(
|
||||
f"{module_name}.{alias.name}" if module_name else alias.name
|
||||
)
|
||||
|
||||
def _handle_class(self, node):
|
||||
self.class_template["class"].update(
|
||||
{
|
||||
"name": node.name,
|
||||
"inherited_classes": [ast.unparse(base) for base in node.bases],
|
||||
}
|
||||
)
|
||||
|
||||
attributes = {} # To store the attributes and their values
|
||||
|
||||
for inner_node in node.body:
|
||||
if isinstance(inner_node, ast.Assign): # An assignment
|
||||
for target in inner_node.targets: # Targets of the assignment
|
||||
if isinstance(target, ast.Name): # A simple variable
|
||||
# Add the attribute and its value to the dictionary
|
||||
attributes[target.id] = ast.unparse(inner_node.value)
|
||||
elif isinstance(inner_node, ast.AnnAssign): # An annotated assignment
|
||||
if isinstance(inner_node.target, ast.Name) and inner_node.value:
|
||||
attributes[inner_node.target.id] = ast.unparse(inner_node.value)
|
||||
|
||||
elif isinstance(inner_node, ast.FunctionDef):
|
||||
self._handle_function(inner_node)
|
||||
|
||||
# You can add these attributes to your class_template if you want
|
||||
self.class_template["class"]["attributes"] = attributes
|
||||
|
||||
def _handle_function(self, node):
|
||||
function_name = node.name
|
||||
function_args_str = ast.unparse(node.args)
|
||||
function_args = function_args_str.split(", ") if function_args_str else []
|
||||
|
||||
return_type = ast.unparse(node.returns) if node.returns else "None"
|
||||
|
||||
function_data = {
|
||||
"name": function_name,
|
||||
"arguments": function_args,
|
||||
"return_type": return_type,
|
||||
}
|
||||
|
||||
if function_name == "__init__":
|
||||
self.class_template["class"]["init"] = (
|
||||
function_args_str.split(", ") if function_args_str else []
|
||||
)
|
||||
else:
|
||||
self.class_template["functions"].append(function_data)
|
||||
|
||||
def _split_string(self, text):
|
||||
"""
|
||||
Split a string by ':' or '=' and append None until the resulting list has 3 items.
|
||||
|
||||
Parameters:
|
||||
text (str): The string to be split.
|
||||
|
||||
Returns:
|
||||
list: A list of strings resulting from the split operation,
|
||||
padded with None until its length is 3.
|
||||
"""
|
||||
items = [item.strip() for item in re.split(r"[:=]", text) if item.strip()]
|
||||
while len(items) < 3:
|
||||
items.append(None)
|
||||
|
||||
return items
|
||||
|
||||
def transform_list(self, input_list):
|
||||
"""
|
||||
Transform a list of strings by splitting each string and padding with None.
|
||||
|
||||
Parameters:
|
||||
input_list (list): The list of strings to be transformed.
|
||||
|
||||
Returns:
|
||||
list: A list of lists, each containing the result of the split operation.
|
||||
"""
|
||||
return [self._split_string(item) for item in input_list]
|
||||
|
||||
def extract_class_info(self):
|
||||
try:
|
||||
module = ast.parse(self.code)
|
||||
except SyntaxError as err:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"error": err.msg, "traceback": traceback.format_exc()},
|
||||
) from err
|
||||
|
||||
for node in module.body:
|
||||
if isinstance(node, (ast.Import, ast.ImportFrom)):
|
||||
self._handle_import(node)
|
||||
elif isinstance(node, ast.ClassDef):
|
||||
self._handle_class(node)
|
||||
|
||||
return self.class_template
|
||||
|
||||
def get_entrypoint_function_args_and_return_type(self):
|
||||
data = self.extract_class_info()
|
||||
attributes = data.get("class", {}).get("attributes", {})
|
||||
functions = data.get("functions", [])
|
||||
template_config = self._build_template_config(attributes)
|
||||
|
||||
if build_function := next(
|
||||
(f for f in functions if f["name"] == self.function_entrypoint_name),
|
||||
None,
|
||||
):
|
||||
function_args = build_function.get("arguments", None)
|
||||
function_args = self.transform_list(function_args)
|
||||
|
||||
return_type = build_function.get("return_type", None)
|
||||
else:
|
||||
function_args = None
|
||||
return_type = None
|
||||
|
||||
return function_args, return_type, template_config
|
||||
|
||||
def _build_template_config(self, attributes):
|
||||
template_config = {}
|
||||
if "field_config" in attributes:
|
||||
template_config["field_config"] = ast.literal_eval(
|
||||
attributes["field_config"]
|
||||
)
|
||||
if "display_name" in attributes:
|
||||
template_config["display_name"] = ast.literal_eval(
|
||||
attributes["display_name"]
|
||||
)
|
||||
if "description" in attributes:
|
||||
template_config["description"] = ast.literal_eval(attributes["description"])
|
||||
|
||||
return template_config
|
||||
|
||||
def _class_template_validation(self, code: dict):
|
||||
class_name = code.get("class", {}).get("name", None)
|
||||
if not class_name: # this will also check for None, empty string, etc.
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "The main class must have a valid name.",
|
||||
"traceback": "",
|
||||
},
|
||||
)
|
||||
|
||||
functions = code.get("functions", [])
|
||||
build_function = next(
|
||||
(f for f in functions if f["name"] == self.function_entrypoint_name),
|
||||
None,
|
||||
)
|
||||
|
||||
if not build_function:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Invalid entrypoint function name",
|
||||
"traceback": (
|
||||
f"There needs to be at least one entrypoint function named '{self.function_entrypoint_name}'"
|
||||
f" and it needs to return one of the types from this list {str(self.return_type_valid_list)}.",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
return_type = build_function.get("return_type")
|
||||
if return_type not in self.return_type_valid_list:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Invalid entrypoint function return",
|
||||
"traceback": (
|
||||
f"The entrypoint function return '{return_type}' needs to be an item "
|
||||
f"from this list {str(self.return_type_valid_list)}."
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def get_function(self):
|
||||
return validate.create_function(self.code, self.function_entrypoint_name)
|
||||
|
||||
def build(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self.extract_class_info()
|
||||
|
||||
def is_check_valid(self):
|
||||
return self._class_template_validation(self.data)
|
||||
|
||||
@property
|
||||
def args_and_return_type(self):
|
||||
return self.get_entrypoint_function_args_and_return_type()
|
||||
119
src/backend/langflow/interface/custom/custom_component.py
Normal file
119
src/backend/langflow/interface/custom/custom_component.py
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
import ast
|
||||
from typing import Callable, Optional
|
||||
from fastapi import HTTPException
|
||||
from langflow.interface.custom.constants import LANGCHAIN_BASE_TYPES
|
||||
from langflow.interface.custom.component import Component
|
||||
|
||||
from langflow.utils import validate
|
||||
|
||||
|
||||
class CustomComponent(Component):
|
||||
code: str
|
||||
field_config: dict = {}
|
||||
code_class_base_inheritance = "CustomComponent"
|
||||
function_entrypoint_name = "build"
|
||||
function: Optional[Callable] = None
|
||||
return_type_valid_list = list(LANGCHAIN_BASE_TYPES.keys())
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
def _class_template_validation(self, code: str) -> bool:
|
||||
if not code:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": self.ERROR_CODE_NULL,
|
||||
"traceback": "",
|
||||
},
|
||||
)
|
||||
|
||||
# TODO: build logic
|
||||
return True
|
||||
|
||||
def is_check_valid(self) -> bool:
|
||||
return self._class_template_validation(self.code)
|
||||
|
||||
def get_code_tree(self, code: str):
|
||||
return super().get_code_tree(code)
|
||||
|
||||
@property
|
||||
def get_function_entrypoint_args(self) -> str:
|
||||
tree = self.get_code_tree(self.code)
|
||||
|
||||
component_classes = [
|
||||
cls
|
||||
for cls in tree["classes"]
|
||||
if self.code_class_base_inheritance in cls["bases"]
|
||||
]
|
||||
if not component_classes:
|
||||
return ""
|
||||
|
||||
# Assume the first Component class is the one we're interested in
|
||||
component_class = component_classes[0]
|
||||
build_methods = [
|
||||
method
|
||||
for method in component_class["methods"]
|
||||
if method["name"] == self.function_entrypoint_name
|
||||
]
|
||||
|
||||
if not build_methods:
|
||||
return ""
|
||||
|
||||
build_method = build_methods[0]
|
||||
|
||||
return build_method["args"]
|
||||
|
||||
@property
|
||||
def get_function_entrypoint_return_type(self) -> str:
|
||||
tree = self.get_code_tree(self.code)
|
||||
|
||||
component_classes = [
|
||||
cls
|
||||
for cls in tree["classes"]
|
||||
if self.code_class_base_inheritance in cls["bases"]
|
||||
]
|
||||
if not component_classes:
|
||||
return ""
|
||||
|
||||
# Assume the first Component class is the one we're interested in
|
||||
component_class = component_classes[0]
|
||||
build_methods = [
|
||||
method
|
||||
for method in component_class["methods"]
|
||||
if method["name"] == self.function_entrypoint_name
|
||||
]
|
||||
|
||||
if not build_methods:
|
||||
return ""
|
||||
|
||||
build_method = build_methods[0]
|
||||
|
||||
return build_method["return_type"]
|
||||
|
||||
@property
|
||||
def get_template_config(self) -> dict:
|
||||
extra_attributes = {} # self.get_extra_attributes
|
||||
template_config = {}
|
||||
|
||||
if "field_config" in extra_attributes:
|
||||
template_config["field_config"] = ast.literal_eval(
|
||||
extra_attributes["field_config"]
|
||||
)
|
||||
if "display_name" in extra_attributes:
|
||||
template_config["display_name"] = ast.literal_eval(
|
||||
extra_attributes["display_name"]
|
||||
)
|
||||
if "description" in extra_attributes:
|
||||
template_config["description"] = ast.literal_eval(
|
||||
extra_attributes["description"]
|
||||
)
|
||||
|
||||
return template_config
|
||||
|
||||
@property
|
||||
def get_function(self):
|
||||
return validate.create_function(self.code, self.function_entrypoint_name)
|
||||
|
||||
def build(self):
|
||||
raise NotImplementedError
|
||||
|
|
@ -9,7 +9,7 @@ from langchain.base_language import BaseLanguageModel
|
|||
from langchain.chains.base import Chain
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.tools import BaseTool
|
||||
from langflow.interface.custom.custom import CustomComponent
|
||||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
from langflow.utils import validate
|
||||
from langflow.interface.wrappers.base import wrapper_creator
|
||||
|
||||
|
|
@ -61,7 +61,9 @@ def import_by_type(_type: str, name: str) -> Any:
|
|||
|
||||
def import_custom_component(custom_component: str) -> CustomComponent:
|
||||
"""Import custom component from custom component name"""
|
||||
return import_class(f"langflow.interface.custom.custom.{custom_component}")
|
||||
return import_class(
|
||||
f"langflow.interface.custom.custom_component.{custom_component}"
|
||||
)
|
||||
|
||||
|
||||
def import_output_parser(output_parser: str) -> Any:
|
||||
|
|
@ -183,5 +185,4 @@ def get_function(code):
|
|||
|
||||
def get_function_custom(code):
|
||||
class_name = validate.extract_class_name(code)
|
||||
|
||||
return validate.create_class(code, class_name)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from langflow.interface.vector_store.base import vectorstore_creator
|
|||
from langflow.interface.wrappers.base import wrapper_creator
|
||||
from langflow.interface.output_parsers.base import output_parser_creator
|
||||
from langflow.interface.custom.base import custom_component_creator
|
||||
from langflow.interface.custom.custom import CustomComponent
|
||||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
|
||||
from langflow.template.field.base import TemplateField
|
||||
from langflow.template.frontend_node.tools import CustomComponentNode
|
||||
|
|
@ -92,9 +92,6 @@ def add_new_custom_field(
|
|||
field_type = field_config.pop("field_type", field_type)
|
||||
field_type = process_type(field_type)
|
||||
|
||||
if field_value is not None:
|
||||
field_value = field_value.replace("'", "").replace('"', "")
|
||||
|
||||
if "name" in field_config:
|
||||
warnings.warn(
|
||||
"The 'name' key in field_config is used to build the object and can't be changed."
|
||||
|
|
@ -158,29 +155,27 @@ def extract_type_from_optional(field_type):
|
|||
return match[1] if match else None
|
||||
|
||||
|
||||
def build_langchain_template_custom_component(extractor: CustomComponent):
|
||||
def build_langchain_template_custom_component(custom_component: CustomComponent):
|
||||
# Build base "CustomComponent" template
|
||||
frontend_node = CustomComponentNode().to_dict().get(type(extractor).__name__)
|
||||
frontend_node = CustomComponentNode().to_dict().get(type(custom_component).__name__)
|
||||
|
||||
function_args, return_type, template_config = extractor.args_and_return_type
|
||||
|
||||
if "display_name" in template_config and frontend_node is not None:
|
||||
frontend_node["display_name"] = template_config["display_name"]
|
||||
if "description" in template_config and frontend_node is not None:
|
||||
frontend_node["description"] = template_config["description"]
|
||||
raw_code = extractor.code
|
||||
field_config = template_config.get("field_config", {})
|
||||
function_args = custom_component.get_function_entrypoint_args
|
||||
return_type = custom_component.get_function_entrypoint_return_type
|
||||
# template_config = custom_component.get_template_config
|
||||
|
||||
if function_args is not None:
|
||||
# Add extra fields
|
||||
for extra_field in function_args:
|
||||
field_required = True
|
||||
field_name, field_type, field_value = extra_field
|
||||
|
||||
if not field_type:
|
||||
field_type = ""
|
||||
field_name = extra_field.get("name") if "name" in extra_field else ""
|
||||
|
||||
if field_name != "self":
|
||||
field_type = extra_field.get("type") if "type" in extra_field else ""
|
||||
field_value = (
|
||||
extra_field.get("default") if "default" in extra_field else ""
|
||||
)
|
||||
field_required = True
|
||||
field_config = {}
|
||||
|
||||
# TODO: Validate type - if is possible to render into frontend
|
||||
if "optional" in field_type.lower():
|
||||
field_type = extract_type_from_optional(field_type)
|
||||
|
|
@ -189,17 +184,16 @@ def build_langchain_template_custom_component(extractor: CustomComponent):
|
|||
if not field_type:
|
||||
field_type = "str"
|
||||
|
||||
config = field_config.get(field_name, {})
|
||||
frontend_node = add_new_custom_field(
|
||||
frontend_node,
|
||||
field_name,
|
||||
field_type,
|
||||
field_value,
|
||||
field_required,
|
||||
config,
|
||||
field_config,
|
||||
)
|
||||
|
||||
frontend_node = add_code_field(frontend_node, raw_code)
|
||||
frontend_node = add_code_field(frontend_node, custom_component.code)
|
||||
|
||||
# Get base classes from "return_type" and add to template.base_classes
|
||||
try:
|
||||
|
|
@ -214,8 +208,10 @@ def build_langchain_template_custom_component(extractor: CustomComponent):
|
|||
"traceback": traceback.format_exc(),
|
||||
},
|
||||
)
|
||||
|
||||
return_type_instance = LANGCHAIN_BASE_TYPES.get(return_type)
|
||||
base_classes = get_base_classes(return_type_instance)
|
||||
|
||||
except (KeyError, AttributeError) as err:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue