From dd009a2913ea16fcb6a812dcc5b3ae173fb9d5c6 Mon Sep 17 00:00:00 2001 From: gustavoschaedler Date: Fri, 30 Jun 2023 00:23:50 +0100 Subject: [PATCH] feat: Add support for Custom Component in Langflow Interface This commit adds support for Custom Component in the Langflow interface. It introduces a new class `CustomComponent`, which takes in a `code` as a parameter and validates it. The `CustomComponent` class also provides a method to get the function specified in the code. The commit also makes some modifications in `initialize/loading.py` file to handle the new `CustomComponent` class. It adds a new helper function `get_function_custom` which creates a function using `validate.create_function` and the `build` function name. --- .../langflow/api/extract_info_from_class.py | 23 ++- .../langflow/interface/importing/utils.py | 9 +- .../langflow/interface/initialize/loading.py | 11 +- .../langflow/interface/tools/custom.py | 151 +++++++++++++++++- src/backend/langflow/interface/types.py | 37 ++--- src/backend/langflow/main.py | 7 - src/backend/langflow/utils/constants.py | 26 ++- .../src/modals/codeAreaModal/index.tsx | 22 +-- 8 files changed, 228 insertions(+), 58 deletions(-) diff --git a/src/backend/langflow/api/extract_info_from_class.py b/src/backend/langflow/api/extract_info_from_class.py index 032652b44..2d76d6604 100644 --- a/src/backend/langflow/api/extract_info_from_class.py +++ b/src/backend/langflow/api/extract_info_from_class.py @@ -51,6 +51,22 @@ class ClassCodeExtractor: else: self.data["functions"].append(function_data) + def transform_list(self, input_list): + output_list = [] + for item in input_list: + # Split each item on ':' to separate variable name and type + split_item = item.split(':') + + # If there is a type, strip any leading/trailing spaces from it + if len(split_item) > 1: + split_item[1] = split_item[1].strip() + # If there isn't a type, append None + else: + split_item.append(None) + output_list.append(split_item) + + return output_list + def extract_class_info(self): module = ast.parse(self.code) @@ -73,6 +89,8 @@ class ClassCodeExtractor: if build_function: function_args = build_function.get("arguments", None) + function_args = self.transform_list(function_args) + return_type = build_function.get("return_type", None) else: function_args = None @@ -82,7 +100,7 @@ class ClassCodeExtractor: def is_valid_class_template(code: dict): - function_entrypoint_name = "build" + extractor = ClassCodeExtractor(code) return_type_valid_list = ["ConversationChain", "Tool"] class_name = code.get("class", {}).get("name", None) @@ -92,7 +110,8 @@ def is_valid_class_template(code: dict): functions = code.get("functions", []) # use a generator and next to find if a function matching the criteria exists build_function = next( - (f for f in functions if f["name"] == function_entrypoint_name), None + (f for f in functions if f["name"] == + extractor.function_entrypoint_name), None ) if not build_function: diff --git a/src/backend/langflow/interface/importing/utils.py b/src/backend/langflow/interface/importing/utils.py index f65376d48..3ba40ba31 100644 --- a/src/backend/langflow/interface/importing/utils.py +++ b/src/backend/langflow/interface/importing/utils.py @@ -29,7 +29,8 @@ def import_module(module_path: str) -> Any: def import_by_type(_type: str, name: str) -> Any: """Import class by type and name""" if _type is None: - raise ValueError(f"Type cannot be None. Check if {name} is in the config file.") + raise ValueError( + f"Type cannot be None. Check if {name} is in the config file.") func_dict = { "agents": import_agent, "prompts": import_prompt, @@ -155,3 +156,9 @@ def get_function(code): function_name = validate.extract_function_name(code) return validate.create_function(code, function_name) + + +def get_function_custom(code): + function_name = "build" + + return validate.create_function(code, function_name) diff --git a/src/backend/langflow/interface/initialize/loading.py b/src/backend/langflow/interface/initialize/loading.py index 1ddae19b7..26890fc1c 100644 --- a/src/backend/langflow/interface/initialize/loading.py +++ b/src/backend/langflow/interface/initialize/loading.py @@ -10,8 +10,12 @@ from langflow.interface.initialize.vector_store import vecstore_initializer from pydantic import ValidationError +from langflow.interface.importing.utils import ( + get_function, + import_by_type, + get_function_custom +) from langflow.interface.custom_lists import CUSTOM_NODES -from langflow.interface.importing.utils import get_function, import_by_type from langflow.interface.toolkits.base import toolkits_creator from langflow.interface.chains.base import chain_creator from langflow.interface.utils import load_file_into_dict @@ -131,9 +135,12 @@ def instantiate_tool(node_type, class_object, params): if node_type == "JsonSpec": params["dict_"] = load_file_into_dict(params.pop("path")) return class_object(**params) - elif node_type in ["PythonFunctionTool", "CustomComponent"]: + elif node_type == "PythonFunctionTool": params["func"] = get_function(params.get("code")) return class_object(**params) + elif node_type == "CustomComponent": + params["func"] = get_function_custom(params.get("code")) + return class_object(**params) # For backward compatibility elif node_type == "PythonFunction": function_string = params["code"] diff --git a/src/backend/langflow/interface/tools/custom.py b/src/backend/langflow/interface/tools/custom.py index 6c6703f36..14de8a205 100644 --- a/src/backend/langflow/interface/tools/custom.py +++ b/src/backend/langflow/interface/tools/custom.py @@ -1,3 +1,5 @@ +import ast + from typing import Callable, Optional from langflow.interface.importing.utils import get_function @@ -34,8 +36,6 @@ class Function(BaseModel): class PythonFunctionTool(Function, Tool): - """Python function""" - name: str = "Custom Tool" description: str code: str @@ -49,12 +49,149 @@ class PythonFunctionTool(Function, Tool): class PythonFunction(Function): - """Python function""" - code: str -class CustomComponent(Function): - """Python function""" - +class CustomComponent(BaseModel): code: str + function: Optional[Callable] = None + imports: Optional[str] = None + + # Eval code and store the class + def __init__(self, **data): + super().__init__(**data) + + # Validate the Class code + @validator("code") + def validate_func(cls, v): + try: + validate.eval_function(v) + except Exception as e: + raise e + + return v + + def get_function(self): + """Get the function""" + function_name = validate.extract_function_name(self.code) + + return validate.create_function(self.code, function_name) + + +class CustomComponent1(BaseModel): + code: str + function_entrypoint_name = "build" + return_type_valid_list = [ + "ConversationChain", + "Tool" + ] + class_template = { + "imports": [], + "class": { + "inherited_classes": "", + "name": "", + "init": "" + }, + "functions": [] + } + + def __init__(self, **data): + super().__init__(**data) + + def _handle_import(self, node): + for alias in node.names: + module_name = getattr(node, 'module', None) + self.class_template['imports'].append( + f"{module_name}.{alias.name}" if module_name else alias.name) + + def _handle_class(self, node): + self.class_template['class'].update({ + 'name': node.name, + 'inherited_classes': [ast.unparse(base) for base in node.bases] + }) + + for inner_node in node.body: + if isinstance(inner_node, ast.FunctionDef): + self._handle_function(inner_node) + + def _handle_function(self, node): + function_name = node.name + function_args_str = ast.unparse(node.args) + function_args = function_args_str.split( + ", ") if function_args_str else [] + + return_type = ast.unparse(node.returns) if node.returns else "None" + + function_data = { + "name": function_name, + "arguments": function_args, + "return_type": return_type + } + + if function_name == "__init__": + self.class_template['class']['init'] = function_args_str.split( + ", ") if function_args_str else [] + else: + self.class_template["functions"].append(function_data) + + def transform_list(self, input_list): + output_list = [] + for item in input_list: + # Split each item on ':' to separate variable name and type + split_item = item.split(':') + + # If there is a type, strip any leading/trailing spaces from it + if len(split_item) > 1: + split_item[1] = split_item[1].strip() + # If there isn't a type, append None + else: + split_item.append(None) + output_list.append(split_item) + + return output_list + + def extract_class_info(self): + module = ast.parse(self.code) + + for node in module.body: + if isinstance(node, (ast.Import, ast.ImportFrom)): + self._handle_import(node) + elif isinstance(node, ast.ClassDef): + self._handle_class(node) + + return self.class_template + + def get_entrypoint_function_args_and_return_type(self): + data = self.extract_class_info() + functions = data.get("functions", []) + + if build_function := next( + (f for f in functions if f["name"] + == self.function_entrypoint_name), + None, + ): + function_args = build_function.get("arguments", None) + function_args = self.transform_list(function_args) + + return_type = build_function.get("return_type", None) + else: + function_args = None + return_type = None + + return function_args, return_type + + def is_valid_class_template(self, code: dict): + class_name = code.get("class", {}).get("name", None) + if not class_name: # this will also check for None, empty string, etc. + return False + + functions = code.get("functions", []) + if build_function := next( + (f for f in functions if f["name"] + == self.function_entrypoint_name), + None, + ): + # Check if the return type of the build function is valid + return build_function.get("return_type") in self.return_type_valid_list + else: + return False diff --git a/src/backend/langflow/interface/types.py b/src/backend/langflow/interface/types.py index 28c3a8840..704be06fd 100644 --- a/src/backend/langflow/interface/types.py +++ b/src/backend/langflow/interface/types.py @@ -82,9 +82,8 @@ def add_new_custom_field(template, field_name: str, field_type: str): return template + # TODO: Move to correct place - - def add_code_field(template, raw_code): # Field with the Python code to allow update code_field = { @@ -111,29 +110,21 @@ def build_langchain_template_custom_component(raw_code, function_args, function_ # type_and_class = find_class_type("Tool", type_list) # node = get_custom_nodes(node_type: str) - # TODO: Build base template - template = llm_creator.to_dict()['llms']['ChatOpenAI'] - + # Build base CustomComponent template template = CustomComponentNode().to_dict().get('CustomComponent') - # TODO: Add extra fields - template = add_new_custom_field( - template, - "my_id", - "str" - ) + # Add extra fields + for extra_field in function_args: + if extra_field[0] != 'self': + # TODO: Validate type - if possible to render into frontend + if not extra_field[1]: + extra_field[1] = 'str' - template = add_new_custom_field( - template, - "year", - "int" - ) - - template = add_new_custom_field( - template, - "other_field", - "bool" - ) + template = add_new_custom_field( + template, + extra_field[0], + extra_field[1] + ) template = add_code_field( template, @@ -144,5 +135,3 @@ def build_langchain_template_custom_component(raw_code, function_args, function_ # olhar loading.py return template - # return globals()['tool_creator'].to_dict()[type_and_class['type']][type_and_class['class']] - # return chain_creator.to_dict()['chains']['ConversationChain'] diff --git a/src/backend/langflow/main.py b/src/backend/langflow/main.py index e6594742c..2a1293f2e 100644 --- a/src/backend/langflow/main.py +++ b/src/backend/langflow/main.py @@ -5,13 +5,6 @@ from langflow.api import router from langflow.database.base import create_db_and_tables from langflow.interface.utils import setup_llm_caching -from pydantic import BaseModel - - -class ErrorMessage(BaseModel): - detail: str - traceback: str - def create_app(): """Create the FastAPI app and include the router.""" diff --git a/src/backend/langflow/utils/constants.py b/src/backend/langflow/utils/constants.py index 70ad06ee3..ee03f71da 100644 --- a/src/backend/langflow/utils/constants.py +++ b/src/backend/langflow/utils/constants.py @@ -50,10 +50,28 @@ def python_function(text: str) -> str: """ DEFAULT_CUSTOM_COMPONENT_CODE = """ -def custom_component(text: str) -> str: - \"\"\"This is a default custom component function that returns the input text\"\"\" - \"\"\"TODO: Add a Class template\"\"\" - return text +from langflow.interface.chains.base import ChainCreator +from langflow.interface.tools.base import ToolCreator +from xyz.abc import MyClassA, MyClassB + + +class MyPythonClass(MyClassA, MyClassB): + def __init__(self, title: str, author: str, year_published: int): + self.title = title + self.author = author + self.year_published = year_published + + def get_details(self): + return f"Title: {self.title}, Author: {self.author}, Year Published: {self.year_published}" + + def update_year_published(self, new_year: int): + self.year_published = new_year + print(f"The year of publication has been updated to {new_year}.") + + def build(self, name: str, my_int: int, my_str: str, my_bool: bool, no_type) -> ConversationChain: + # do something... + + return ConversationChain() """ DIRECT_TYPES = ["str", "bool", "code", "int", "float", "Any", "prompt"] diff --git a/src/frontend/src/modals/codeAreaModal/index.tsx b/src/frontend/src/modals/codeAreaModal/index.tsx index 1e7166c00..1605a21db 100644 --- a/src/frontend/src/modals/codeAreaModal/index.tsx +++ b/src/frontend/src/modals/codeAreaModal/index.tsx @@ -98,17 +98,17 @@ export default function CodeAreaModal({ title: "There is something wrong with this code, please review it", }); }); - // postCustomComponent(code, nodeClass).then((apiReturn) => { - // const data = apiReturn.data; - // if (data) { - // setNodeClass(data); - // setModalOpen(false); - // } - // }); - axios.get("/api/v1/custom_component_error").catch((err) => { - console.log(err.response.data); - setError(err.response.data); - }) + postCustomComponent(code, nodeClass).then((apiReturn) => { + const {data} = apiReturn; + if (data) { + setNodeClass(data); + setModalOpen(false); + } + }); + // axios.get("/api/v1/custom_component_error").catch((err) => { + // console.log(err.response.data); + // setError(err.response.data); + // }) } const tabs = [{ name: "code" }, { name: "errors" }]