Refactor code to extract class name from custom code
- Updatethe `get_function_custom` function to use the `validate.extract_class_name` function to extract class name from the `code` parameter. - Modify `instantiate_tool` function to update the `class_object` with the return value of `get_function_custom` and call the `build` method on the instantiated object with the remaining params. - Add a new function `extract_class_name` in the `validate` module to extract class name from the `code` parameter.
This commit is contained in:
parent
ac569f3405
commit
839e9737cc
5 changed files with 16 additions and 24 deletions
|
|
@ -165,26 +165,6 @@ def get_function(code):
|
|||
|
||||
|
||||
def get_function_custom(code):
|
||||
function_name = "build"
|
||||
class_name = "MyPythonClass"
|
||||
|
||||
code = """
|
||||
from langchain.llms import OpenAI
|
||||
from langchain.chains import ConversationChain
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
|
||||
|
||||
class MyPythonClass:
|
||||
def my_conversation(self):
|
||||
llm = OpenAI(temperature=0)
|
||||
return ConversationChain(
|
||||
llm=llm, verbose=True, memory=ConversationBufferMemory()
|
||||
)
|
||||
|
||||
def build(self, my_custom_input: str) -> ConversationChain:
|
||||
conversation = self.my_conversation()
|
||||
|
||||
return conversation
|
||||
"""
|
||||
class_name = validate.extract_class_name(code)
|
||||
|
||||
return validate.create_class(code, class_name)
|
||||
|
|
|
|||
|
|
@ -178,8 +178,8 @@ def instantiate_tool(node_type, class_object, params):
|
|||
params["func"] = get_function(params.get("code"))
|
||||
return class_object(**params)
|
||||
elif node_type == "CustomComponent":
|
||||
params["func"] = get_function_custom(params.get("code"))
|
||||
return class_object(**params)
|
||||
class_object = get_function_custom(params.pop("code"))
|
||||
return class_object().build(**params)
|
||||
# For backward compatibility
|
||||
elif node_type == "PythonFunction":
|
||||
function_string = params["code"]
|
||||
|
|
|
|||
|
|
@ -84,6 +84,7 @@ class CustomComponent(BaseModel):
|
|||
function_entrypoint_name = "build"
|
||||
return_type_valid_list = [
|
||||
"ConversationChain",
|
||||
"BaseLLM",
|
||||
"Tool"
|
||||
]
|
||||
class_template = {
|
||||
|
|
|
|||
|
|
@ -57,7 +57,10 @@ from langchain.memory import ConversationBufferMemory
|
|||
|
||||
class MyPythonClass:
|
||||
def my_conversation(self):
|
||||
llm = OpenAI(temperature=0)
|
||||
llm = OpenAI(
|
||||
openai_api_key='',
|
||||
temperature=0
|
||||
)
|
||||
return ConversationChain(
|
||||
llm=llm, verbose=True, memory=ConversationBufferMemory()
|
||||
)
|
||||
|
|
|
|||
|
|
@ -230,3 +230,11 @@ def extract_function_name(code):
|
|||
if isinstance(node, ast.FunctionDef):
|
||||
return node.name
|
||||
raise ValueError("No function definition found in the code string")
|
||||
|
||||
|
||||
def extract_class_name(code):
|
||||
module = ast.parse(code)
|
||||
for node in module.body:
|
||||
if isinstance(node, ast.ClassDef):
|
||||
return node.name
|
||||
raise ValueError("No class definition found in the code string")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue