Refactor code to extract class name from custom code

- Updatethe `get_function_custom` function to use the `validate.extract_class_name` function to extract class name from the `code` parameter.
- Modify `instantiate_tool` function to update the `class_object` with the return value of `get_function_custom` and call the `build` method on the instantiated object with the remaining params.
- Add a new function `extract_class_name` in the `validate` module to extract class name from the `code` parameter.
This commit is contained in:
gustavoschaedler 2023-07-05 03:34:16 +01:00
commit 839e9737cc
5 changed files with 16 additions and 24 deletions

View file

@ -165,26 +165,6 @@ def get_function(code):
def get_function_custom(code):
function_name = "build"
class_name = "MyPythonClass"
code = """
from langchain.llms import OpenAI
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
class MyPythonClass:
def my_conversation(self):
llm = OpenAI(temperature=0)
return ConversationChain(
llm=llm, verbose=True, memory=ConversationBufferMemory()
)
def build(self, my_custom_input: str) -> ConversationChain:
conversation = self.my_conversation()
return conversation
"""
class_name = validate.extract_class_name(code)
return validate.create_class(code, class_name)

View file

@ -178,8 +178,8 @@ def instantiate_tool(node_type, class_object, params):
params["func"] = get_function(params.get("code"))
return class_object(**params)
elif node_type == "CustomComponent":
params["func"] = get_function_custom(params.get("code"))
return class_object(**params)
class_object = get_function_custom(params.pop("code"))
return class_object().build(**params)
# For backward compatibility
elif node_type == "PythonFunction":
function_string = params["code"]

View file

@ -84,6 +84,7 @@ class CustomComponent(BaseModel):
function_entrypoint_name = "build"
return_type_valid_list = [
"ConversationChain",
"BaseLLM",
"Tool"
]
class_template = {

View file

@ -57,7 +57,10 @@ from langchain.memory import ConversationBufferMemory
class MyPythonClass:
def my_conversation(self):
llm = OpenAI(temperature=0)
llm = OpenAI(
openai_api_key='',
temperature=0
)
return ConversationChain(
llm=llm, verbose=True, memory=ConversationBufferMemory()
)

View file

@ -230,3 +230,11 @@ def extract_function_name(code):
if isinstance(node, ast.FunctionDef):
return node.name
raise ValueError("No function definition found in the code string")
def extract_class_name(code):
module = ast.parse(code)
for node in module.body:
if isinstance(node, ast.ClassDef):
return node.name
raise ValueError("No class definition found in the code string")