diff --git a/src/backend/langflow/interface/importing/utils.py b/src/backend/langflow/interface/importing/utils.py index 1176de07c..2b1db217c 100644 --- a/src/backend/langflow/interface/importing/utils.py +++ b/src/backend/langflow/interface/importing/utils.py @@ -165,26 +165,6 @@ def get_function(code): def get_function_custom(code): - function_name = "build" - class_name = "MyPythonClass" - - code = """ -from langchain.llms import OpenAI -from langchain.chains import ConversationChain -from langchain.memory import ConversationBufferMemory - - -class MyPythonClass: - def my_conversation(self): - llm = OpenAI(temperature=0) - return ConversationChain( - llm=llm, verbose=True, memory=ConversationBufferMemory() - ) - - def build(self, my_custom_input: str) -> ConversationChain: - conversation = self.my_conversation() - - return conversation - """ + class_name = validate.extract_class_name(code) return validate.create_class(code, class_name) diff --git a/src/backend/langflow/interface/initialize/loading.py b/src/backend/langflow/interface/initialize/loading.py index 53679ea9a..f50b34ac3 100644 --- a/src/backend/langflow/interface/initialize/loading.py +++ b/src/backend/langflow/interface/initialize/loading.py @@ -178,8 +178,8 @@ def instantiate_tool(node_type, class_object, params): params["func"] = get_function(params.get("code")) return class_object(**params) elif node_type == "CustomComponent": - params["func"] = get_function_custom(params.get("code")) - return class_object(**params) + class_object = get_function_custom(params.pop("code")) + return class_object().build(**params) # For backward compatibility elif node_type == "PythonFunction": function_string = params["code"] diff --git a/src/backend/langflow/interface/tools/custom.py b/src/backend/langflow/interface/tools/custom.py index 522325a9e..8f5c9fcd8 100644 --- a/src/backend/langflow/interface/tools/custom.py +++ b/src/backend/langflow/interface/tools/custom.py @@ -84,6 +84,7 @@ class CustomComponent(BaseModel): function_entrypoint_name = "build" return_type_valid_list = [ "ConversationChain", + "BaseLLM", "Tool" ] class_template = { diff --git a/src/backend/langflow/utils/constants.py b/src/backend/langflow/utils/constants.py index a7de648b3..930cecc73 100644 --- a/src/backend/langflow/utils/constants.py +++ b/src/backend/langflow/utils/constants.py @@ -57,7 +57,10 @@ from langchain.memory import ConversationBufferMemory class MyPythonClass: def my_conversation(self): - llm = OpenAI(temperature=0) + llm = OpenAI( + openai_api_key='', + temperature=0 + ) return ConversationChain( llm=llm, verbose=True, memory=ConversationBufferMemory() ) diff --git a/src/backend/langflow/utils/validate.py b/src/backend/langflow/utils/validate.py index ae8589a3c..bc13238c3 100644 --- a/src/backend/langflow/utils/validate.py +++ b/src/backend/langflow/utils/validate.py @@ -230,3 +230,11 @@ def extract_function_name(code): if isinstance(node, ast.FunctionDef): return node.name raise ValueError("No function definition found in the code string") + + +def extract_class_name(code): + module = ast.parse(code) + for node in module.body: + if isinstance(node, ast.ClassDef): + return node.name + raise ValueError("No class definition found in the code string")