feat: Add support for Custom Component in Langflow Interface
This commit adds support for Custom Component in the Langflow interface. It introduces a new class `CustomComponent`, which takes in a `code` as a parameter and validates it. The `CustomComponent` class also provides a method to get the function specified in the code. The commit also makes some modifications in `initialize/loading.py` file to handle the new `CustomComponent` class. It adds a new helper function `get_function_custom` which creates a function using `validate.create_function` and the `build` function name.
This commit is contained in:
parent
7d8304cb60
commit
dd009a2913
8 changed files with 228 additions and 58 deletions
|
|
@ -51,6 +51,22 @@ class ClassCodeExtractor:
|
|||
else:
|
||||
self.data["functions"].append(function_data)
|
||||
|
||||
def transform_list(self, input_list):
|
||||
output_list = []
|
||||
for item in input_list:
|
||||
# Split each item on ':' to separate variable name and type
|
||||
split_item = item.split(':')
|
||||
|
||||
# If there is a type, strip any leading/trailing spaces from it
|
||||
if len(split_item) > 1:
|
||||
split_item[1] = split_item[1].strip()
|
||||
# If there isn't a type, append None
|
||||
else:
|
||||
split_item.append(None)
|
||||
output_list.append(split_item)
|
||||
|
||||
return output_list
|
||||
|
||||
def extract_class_info(self):
|
||||
module = ast.parse(self.code)
|
||||
|
||||
|
|
@ -73,6 +89,8 @@ class ClassCodeExtractor:
|
|||
|
||||
if build_function:
|
||||
function_args = build_function.get("arguments", None)
|
||||
function_args = self.transform_list(function_args)
|
||||
|
||||
return_type = build_function.get("return_type", None)
|
||||
else:
|
||||
function_args = None
|
||||
|
|
@ -82,7 +100,7 @@ class ClassCodeExtractor:
|
|||
|
||||
|
||||
def is_valid_class_template(code: dict):
|
||||
function_entrypoint_name = "build"
|
||||
extractor = ClassCodeExtractor(code)
|
||||
return_type_valid_list = ["ConversationChain", "Tool"]
|
||||
|
||||
class_name = code.get("class", {}).get("name", None)
|
||||
|
|
@ -92,7 +110,8 @@ def is_valid_class_template(code: dict):
|
|||
functions = code.get("functions", [])
|
||||
# use a generator and next to find if a function matching the criteria exists
|
||||
build_function = next(
|
||||
(f for f in functions if f["name"] == function_entrypoint_name), None
|
||||
(f for f in functions if f["name"] ==
|
||||
extractor.function_entrypoint_name), None
|
||||
)
|
||||
|
||||
if not build_function:
|
||||
|
|
|
|||
|
|
@ -29,7 +29,8 @@ def import_module(module_path: str) -> Any:
|
|||
def import_by_type(_type: str, name: str) -> Any:
|
||||
"""Import class by type and name"""
|
||||
if _type is None:
|
||||
raise ValueError(f"Type cannot be None. Check if {name} is in the config file.")
|
||||
raise ValueError(
|
||||
f"Type cannot be None. Check if {name} is in the config file.")
|
||||
func_dict = {
|
||||
"agents": import_agent,
|
||||
"prompts": import_prompt,
|
||||
|
|
@ -155,3 +156,9 @@ def get_function(code):
|
|||
function_name = validate.extract_function_name(code)
|
||||
|
||||
return validate.create_function(code, function_name)
|
||||
|
||||
|
||||
def get_function_custom(code):
|
||||
function_name = "build"
|
||||
|
||||
return validate.create_function(code, function_name)
|
||||
|
|
|
|||
|
|
@ -10,8 +10,12 @@ from langflow.interface.initialize.vector_store import vecstore_initializer
|
|||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from langflow.interface.importing.utils import (
|
||||
get_function,
|
||||
import_by_type,
|
||||
get_function_custom
|
||||
)
|
||||
from langflow.interface.custom_lists import CUSTOM_NODES
|
||||
from langflow.interface.importing.utils import get_function, import_by_type
|
||||
from langflow.interface.toolkits.base import toolkits_creator
|
||||
from langflow.interface.chains.base import chain_creator
|
||||
from langflow.interface.utils import load_file_into_dict
|
||||
|
|
@ -131,9 +135,12 @@ def instantiate_tool(node_type, class_object, params):
|
|||
if node_type == "JsonSpec":
|
||||
params["dict_"] = load_file_into_dict(params.pop("path"))
|
||||
return class_object(**params)
|
||||
elif node_type in ["PythonFunctionTool", "CustomComponent"]:
|
||||
elif node_type == "PythonFunctionTool":
|
||||
params["func"] = get_function(params.get("code"))
|
||||
return class_object(**params)
|
||||
elif node_type == "CustomComponent":
|
||||
params["func"] = get_function_custom(params.get("code"))
|
||||
return class_object(**params)
|
||||
# For backward compatibility
|
||||
elif node_type == "PythonFunction":
|
||||
function_string = params["code"]
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import ast
|
||||
|
||||
from typing import Callable, Optional
|
||||
from langflow.interface.importing.utils import get_function
|
||||
|
||||
|
|
@ -34,8 +36,6 @@ class Function(BaseModel):
|
|||
|
||||
|
||||
class PythonFunctionTool(Function, Tool):
|
||||
"""Python function"""
|
||||
|
||||
name: str = "Custom Tool"
|
||||
description: str
|
||||
code: str
|
||||
|
|
@ -49,12 +49,149 @@ class PythonFunctionTool(Function, Tool):
|
|||
|
||||
|
||||
class PythonFunction(Function):
|
||||
"""Python function"""
|
||||
|
||||
code: str
|
||||
|
||||
|
||||
class CustomComponent(Function):
|
||||
"""Python function"""
|
||||
|
||||
class CustomComponent(BaseModel):
|
||||
code: str
|
||||
function: Optional[Callable] = None
|
||||
imports: Optional[str] = None
|
||||
|
||||
# Eval code and store the class
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
# Validate the Class code
|
||||
@validator("code")
|
||||
def validate_func(cls, v):
|
||||
try:
|
||||
validate.eval_function(v)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
return v
|
||||
|
||||
def get_function(self):
|
||||
"""Get the function"""
|
||||
function_name = validate.extract_function_name(self.code)
|
||||
|
||||
return validate.create_function(self.code, function_name)
|
||||
|
||||
|
||||
class CustomComponent1(BaseModel):
|
||||
code: str
|
||||
function_entrypoint_name = "build"
|
||||
return_type_valid_list = [
|
||||
"ConversationChain",
|
||||
"Tool"
|
||||
]
|
||||
class_template = {
|
||||
"imports": [],
|
||||
"class": {
|
||||
"inherited_classes": "",
|
||||
"name": "",
|
||||
"init": ""
|
||||
},
|
||||
"functions": []
|
||||
}
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
def _handle_import(self, node):
|
||||
for alias in node.names:
|
||||
module_name = getattr(node, 'module', None)
|
||||
self.class_template['imports'].append(
|
||||
f"{module_name}.{alias.name}" if module_name else alias.name)
|
||||
|
||||
def _handle_class(self, node):
|
||||
self.class_template['class'].update({
|
||||
'name': node.name,
|
||||
'inherited_classes': [ast.unparse(base) for base in node.bases]
|
||||
})
|
||||
|
||||
for inner_node in node.body:
|
||||
if isinstance(inner_node, ast.FunctionDef):
|
||||
self._handle_function(inner_node)
|
||||
|
||||
def _handle_function(self, node):
|
||||
function_name = node.name
|
||||
function_args_str = ast.unparse(node.args)
|
||||
function_args = function_args_str.split(
|
||||
", ") if function_args_str else []
|
||||
|
||||
return_type = ast.unparse(node.returns) if node.returns else "None"
|
||||
|
||||
function_data = {
|
||||
"name": function_name,
|
||||
"arguments": function_args,
|
||||
"return_type": return_type
|
||||
}
|
||||
|
||||
if function_name == "__init__":
|
||||
self.class_template['class']['init'] = function_args_str.split(
|
||||
", ") if function_args_str else []
|
||||
else:
|
||||
self.class_template["functions"].append(function_data)
|
||||
|
||||
def transform_list(self, input_list):
|
||||
output_list = []
|
||||
for item in input_list:
|
||||
# Split each item on ':' to separate variable name and type
|
||||
split_item = item.split(':')
|
||||
|
||||
# If there is a type, strip any leading/trailing spaces from it
|
||||
if len(split_item) > 1:
|
||||
split_item[1] = split_item[1].strip()
|
||||
# If there isn't a type, append None
|
||||
else:
|
||||
split_item.append(None)
|
||||
output_list.append(split_item)
|
||||
|
||||
return output_list
|
||||
|
||||
def extract_class_info(self):
|
||||
module = ast.parse(self.code)
|
||||
|
||||
for node in module.body:
|
||||
if isinstance(node, (ast.Import, ast.ImportFrom)):
|
||||
self._handle_import(node)
|
||||
elif isinstance(node, ast.ClassDef):
|
||||
self._handle_class(node)
|
||||
|
||||
return self.class_template
|
||||
|
||||
def get_entrypoint_function_args_and_return_type(self):
|
||||
data = self.extract_class_info()
|
||||
functions = data.get("functions", [])
|
||||
|
||||
if build_function := next(
|
||||
(f for f in functions if f["name"]
|
||||
== self.function_entrypoint_name),
|
||||
None,
|
||||
):
|
||||
function_args = build_function.get("arguments", None)
|
||||
function_args = self.transform_list(function_args)
|
||||
|
||||
return_type = build_function.get("return_type", None)
|
||||
else:
|
||||
function_args = None
|
||||
return_type = None
|
||||
|
||||
return function_args, return_type
|
||||
|
||||
def is_valid_class_template(self, code: dict):
|
||||
class_name = code.get("class", {}).get("name", None)
|
||||
if not class_name: # this will also check for None, empty string, etc.
|
||||
return False
|
||||
|
||||
functions = code.get("functions", [])
|
||||
if build_function := next(
|
||||
(f for f in functions if f["name"]
|
||||
== self.function_entrypoint_name),
|
||||
None,
|
||||
):
|
||||
# Check if the return type of the build function is valid
|
||||
return build_function.get("return_type") in self.return_type_valid_list
|
||||
else:
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -82,9 +82,8 @@ def add_new_custom_field(template, field_name: str, field_type: str):
|
|||
|
||||
return template
|
||||
|
||||
|
||||
# TODO: Move to correct place
|
||||
|
||||
|
||||
def add_code_field(template, raw_code):
|
||||
# Field with the Python code to allow update
|
||||
code_field = {
|
||||
|
|
@ -111,29 +110,21 @@ def build_langchain_template_custom_component(raw_code, function_args, function_
|
|||
# type_and_class = find_class_type("Tool", type_list)
|
||||
# node = get_custom_nodes(node_type: str)
|
||||
|
||||
# TODO: Build base template
|
||||
template = llm_creator.to_dict()['llms']['ChatOpenAI']
|
||||
|
||||
# Build base CustomComponent template
|
||||
template = CustomComponentNode().to_dict().get('CustomComponent')
|
||||
|
||||
# TODO: Add extra fields
|
||||
template = add_new_custom_field(
|
||||
template,
|
||||
"my_id",
|
||||
"str"
|
||||
)
|
||||
# Add extra fields
|
||||
for extra_field in function_args:
|
||||
if extra_field[0] != 'self':
|
||||
# TODO: Validate type - if possible to render into frontend
|
||||
if not extra_field[1]:
|
||||
extra_field[1] = 'str'
|
||||
|
||||
template = add_new_custom_field(
|
||||
template,
|
||||
"year",
|
||||
"int"
|
||||
)
|
||||
|
||||
template = add_new_custom_field(
|
||||
template,
|
||||
"other_field",
|
||||
"bool"
|
||||
)
|
||||
template = add_new_custom_field(
|
||||
template,
|
||||
extra_field[0],
|
||||
extra_field[1]
|
||||
)
|
||||
|
||||
template = add_code_field(
|
||||
template,
|
||||
|
|
@ -144,5 +135,3 @@ def build_langchain_template_custom_component(raw_code, function_args, function_
|
|||
# olhar loading.py
|
||||
|
||||
return template
|
||||
# return globals()['tool_creator'].to_dict()[type_and_class['type']][type_and_class['class']]
|
||||
# return chain_creator.to_dict()['chains']['ConversationChain']
|
||||
|
|
|
|||
|
|
@ -5,13 +5,6 @@ from langflow.api import router
|
|||
from langflow.database.base import create_db_and_tables
|
||||
from langflow.interface.utils import setup_llm_caching
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ErrorMessage(BaseModel):
|
||||
detail: str
|
||||
traceback: str
|
||||
|
||||
|
||||
def create_app():
|
||||
"""Create the FastAPI app and include the router."""
|
||||
|
|
|
|||
|
|
@ -50,10 +50,28 @@ def python_function(text: str) -> str:
|
|||
"""
|
||||
|
||||
DEFAULT_CUSTOM_COMPONENT_CODE = """
|
||||
def custom_component(text: str) -> str:
|
||||
\"\"\"This is a default custom component function that returns the input text\"\"\"
|
||||
\"\"\"TODO: Add a Class template\"\"\"
|
||||
return text
|
||||
from langflow.interface.chains.base import ChainCreator
|
||||
from langflow.interface.tools.base import ToolCreator
|
||||
from xyz.abc import MyClassA, MyClassB
|
||||
|
||||
|
||||
class MyPythonClass(MyClassA, MyClassB):
|
||||
def __init__(self, title: str, author: str, year_published: int):
|
||||
self.title = title
|
||||
self.author = author
|
||||
self.year_published = year_published
|
||||
|
||||
def get_details(self):
|
||||
return f"Title: {self.title}, Author: {self.author}, Year Published: {self.year_published}"
|
||||
|
||||
def update_year_published(self, new_year: int):
|
||||
self.year_published = new_year
|
||||
print(f"The year of publication has been updated to {new_year}.")
|
||||
|
||||
def build(self, name: str, my_int: int, my_str: str, my_bool: bool, no_type) -> ConversationChain:
|
||||
# do something...
|
||||
|
||||
return ConversationChain()
|
||||
"""
|
||||
|
||||
DIRECT_TYPES = ["str", "bool", "code", "int", "float", "Any", "prompt"]
|
||||
|
|
|
|||
|
|
@ -98,17 +98,17 @@ export default function CodeAreaModal({
|
|||
title: "There is something wrong with this code, please review it",
|
||||
});
|
||||
});
|
||||
// postCustomComponent(code, nodeClass).then((apiReturn) => {
|
||||
// const data = apiReturn.data;
|
||||
// if (data) {
|
||||
// setNodeClass(data);
|
||||
// setModalOpen(false);
|
||||
// }
|
||||
// });
|
||||
axios.get("/api/v1/custom_component_error").catch((err) => {
|
||||
console.log(err.response.data);
|
||||
setError(err.response.data);
|
||||
})
|
||||
postCustomComponent(code, nodeClass).then((apiReturn) => {
|
||||
const {data} = apiReturn;
|
||||
if (data) {
|
||||
setNodeClass(data);
|
||||
setModalOpen(false);
|
||||
}
|
||||
});
|
||||
// axios.get("/api/v1/custom_component_error").catch((err) => {
|
||||
// console.log(err.response.data);
|
||||
// setError(err.response.data);
|
||||
// })
|
||||
|
||||
}
|
||||
const tabs = [{ name: "code" }, { name: "errors" }]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue