[extract_info_from_class.py] Remove unnecessary code

This commit removes the file `extract_info_from_class.py` which contained unnecessary code.

[v1/endpoints.py] Fix error handling in custom_component endpoint

This commit fixes the error handling in the `custom_component` endpoint in `endpoints.py`. If the class template extracted from the code is not valid, an error message is printed.

[importing/utils.py] Comment out unused code

This commit comments out unused code in `get_function_custom` function in `utils.py` file.

[initialize/loading.py] Comment out unused code

This commit comments out unused code in the `instantiate_tool` function in `loading.py` file.

[interface/tools/custom.py] Refactor code and add properties

This commit refactors the code in `CustomComponent` class in `custom.py` file. It adds properties for `data`, `is_valid`, and `args_and_return_type`.

[interface/types.py] Add base classes to custom component template

This commit adds base classes to the custom component template in the `build_langchain_template_custom_component` function in `types.py` file.

[utils/constants.py] Remove unnecessary import

This commit removes an unnecessary import in `DEFAULT_CUSTOM_COMPONENT_CODE` constant in `constants.py` file.
This commit is contained in:
gustavoschaedler 2023-07-04 21:00:02 +01:00
commit 13bb0280f5
8 changed files with 128 additions and 176 deletions

View file

@ -1,124 +0,0 @@
import ast
class ClassCodeExtractor:
def __init__(self, code):
self.code = code
self.function_entrypoint_name = "build"
self.data = {
"imports": [],
"class": {
"inherited_classes": "",
"name": "",
"init": ""
},
"functions": []
}
def _handle_import(self, node):
for alias in node.names:
module_name = getattr(node, 'module', None)
self.data['imports'].append(
f"{module_name}.{alias.name}" if module_name else alias.name)
def _handle_class(self, node):
self.data['class'].update({
'name': node.name,
'inherited_classes': [ast.unparse(base) for base in node.bases]
})
for inner_node in node.body:
if isinstance(inner_node, ast.FunctionDef):
self._handle_function(inner_node)
def _handle_function(self, node):
function_name = node.name
function_args_str = ast.unparse(node.args)
function_args = function_args_str.split(
", ") if function_args_str else []
return_type = ast.unparse(node.returns) if node.returns else "None"
function_data = {
"name": function_name,
"arguments": function_args,
"return_type": return_type
}
if function_name == "__init__":
self.data['class']['init'] = function_args_str.split(
", ") if function_args_str else []
else:
self.data["functions"].append(function_data)
def transform_list(self, input_list):
output_list = []
for item in input_list:
# Split each item on ':' to separate variable name and type
split_item = item.split(':')
# If there is a type, strip any leading/trailing spaces from it
if len(split_item) > 1:
split_item[1] = split_item[1].strip()
# If there isn't a type, append None
else:
split_item.append(None)
output_list.append(split_item)
return output_list
def extract_class_info(self):
module = ast.parse(self.code)
for node in module.body:
if isinstance(node, (ast.Import, ast.ImportFrom)):
self._handle_import(node)
elif isinstance(node, ast.ClassDef):
self._handle_class(node)
return self.data
def get_entrypoint_function_args_and_return_type(self):
data = self.extract_class_info()
functions = data.get("functions", [])
build_function = next(
(f for f in functions if f["name"] ==
self.function_entrypoint_name), None
)
if build_function:
function_args = build_function.get("arguments", None)
function_args = self.transform_list(function_args)
return_type = build_function.get("return_type", None)
else:
function_args = None
return_type = None
return function_args, return_type
def is_valid_class_template(code: dict):
extractor = ClassCodeExtractor(code)
return_type_valid_list = ["ConversationChain", "Tool"]
class_name = code.get("class", {}).get("name", None)
if not class_name: # this will also check for None, empty string, etc.
return False
functions = code.get("functions", [])
# use a generator and next to find if a function matching the criteria exists
build_function = next(
(f for f in functions if f["name"] ==
extractor.function_entrypoint_name), None
)
if not build_function:
return False
# Check if the return type of the build function is valid
if build_function.get("return_type") not in return_type_valid_list:
return False
return True

View file

@ -104,16 +104,12 @@ async def custom_component(
raw_code: CustomComponentCode,
):
extractor = CustomComponent(code=raw_code.code)
data = extractor.extract_class_info()
valid = extractor.is_valid_class_template(data)
function_args, function_return_type = extractor.get_entrypoint_function_args_and_return_type()
if not extractor.is_valid:
print("ERROR")
# TODO: Raise error
return build_langchain_template_custom_component(
raw_code.code,
function_args,
function_return_type
)
return build_langchain_template_custom_component(extractor)
# TODO: Just for test - will be remove

View file

@ -166,5 +166,28 @@ def get_function(code):
def get_function_custom(code):
function_name = "build"
class_name = "MyPythonClass"
return validate.create_function(code, function_name)
code = """
from langchain.chains import ConversationChain
class MyPythonClass:
def __init__(self, title: str, author: str, year_published: int):
self.title = title
self.author = author
self.year_published = year_published
def get_details(self):
return f"Title: {self.title}, Author: {self.author}, Year Published: {self.year_published}"
def update_year_published(self, new_year: int):
self.year_published = new_year
print(f"The year of publication has been updated to {new_year}.")
def build(self, name, my_int, my_str, my_bool, no_type):
# do something...
print("x")
return ""
"""
return validate.create_class(code, class_name)

View file

@ -178,8 +178,9 @@ def instantiate_tool(node_type, class_object, params):
params["func"] = get_function(params.get("code"))
return class_object(**params)
elif node_type == "CustomComponent":
params["func"] = get_function_custom(params.get("code"))
return class_object(**params)
return get_function_custom(params.get("code"))
# params["func"] = get_function_custom(params.get("code"))
# return class_object(**params)
# For backward compatibility
elif node_type == "PythonFunction":
function_string = params["code"]

View file

@ -196,3 +196,21 @@ class CustomComponent(BaseModel):
return build_function.get("return_type") in self.return_type_valid_list
else:
return False
def get_function(self):
return validate.create_function(
self.code,
self.function_entrypoint_name
)
@property
def data(self):
return self.extract_class_info()
@property
def is_valid(self):
return self.is_valid_class_template(self.data)
@property
def args_and_return_type(self):
return self.get_entrypoint_function_args_and_return_type()

View file

@ -12,6 +12,7 @@ from langflow.interface.utilities.base import utility_creator
from langflow.interface.vector_store.base import vectorstore_creator
from langflow.interface.wrappers.base import wrapper_creator
from langflow.interface.output_parsers.base import output_parser_creator
from langflow.interface.tools.custom import CustomComponent
from langflow.template.field.base import TemplateField
from langflow.template.frontend_node.tools import CustomComponentNode
@ -59,18 +60,6 @@ def build_langchain_types_dict(): # sourcery skip: dict-assign-update-to-union
return all_types
# TODO: Move to correct place
def find_class_type(class_name, classes_dict):
return next(
(
{"type": class_type, "class": class_name}
for class_type, class_list in classes_dict.items()
if class_name in class_list
),
{"error": "class not found"},
)
# TODO: Move to correct place
def add_new_custom_field(template, field_name: str, field_type: str):
new_field = TemplateField(
@ -107,25 +96,27 @@ def add_code_field(template, raw_code):
return template
def build_langchain_template_custom_component(raw_code, function_args, function_return_type):
# type_list = get_type_list()
# type_and_class = find_class_type("Tool", type_list)
# node = get_custom_nodes(node_type: str)
def build_langchain_template_custom_component(extractor: CustomComponent):
# Build base "CustomComponent" template
template = CustomComponentNode().to_dict().get(type(extractor).__name__)
# Build base CustomComponent template
template = CustomComponentNode().to_dict().get('CustomComponent')
function_args, return_type = extractor.args_and_return_type
raw_code = extractor.code
# Add extra fields
for extra_field in function_args:
if extra_field[0] != 'self':
def_field = extra_field[0]
def_type = extra_field[1]
if def_field != 'self':
# TODO: Validate type - if possible to render into frontend
if not extra_field[1]:
extra_field[1] = 'str'
if not def_type:
def_type = 'str'
template = add_new_custom_field(
template,
extra_field[0],
extra_field[1]
def_field,
def_type
)
template = add_code_field(
@ -133,7 +124,13 @@ def build_langchain_template_custom_component(raw_code, function_args, function_
raw_code
)
# criar um vertex
# olhar loading.py
# TODO: Build a vertex - loading.py
# TODO: Get base classes from "return_type" and add to template.base_classes
template.get('base_classes').append("ConversationChain")
template.get('base_classes').append("LLMChain")
template.get('base_classes').append("Chain")
template.get('base_classes').append("Serializable")
template.get('base_classes').append("function")
return template

View file

@ -50,27 +50,19 @@ def python_function(text: str) -> str:
"""
DEFAULT_CUSTOM_COMPONENT_CODE = """
from langflow.interface.chains.base import ChainCreator
from langflow.interface.tools.base import ToolCreator
from xyz.abc import MyClassA, MyClassB
from langchain.chains import ConversationChain
class MyPythonClass(MyClassA, MyClassB):
def __init__(self, title: str, author: str, year_published: int):
self.title = title
self.author = author
self.year_published = year_published
class MyPythonClass:
def __init__(self, name: str, year: int):
self.name = name
self.year = year
def get_details(self):
return f"Title: {self.title}, Author: {self.author}, Year Published: {self.year_published}"
return f"Name: {self.name}, Year: {self.year}"
def update_year_published(self, new_year: int):
self.year_published = new_year
print(f"The year of publication has been updated to {new_year}.")
def build(self, name: str, my_int: int, my_str: str, my_bool: bool, no_type) -> ConversationChain:
def build(self, name: str, year: int, true_or_false: bool, no_type) -> ConversationChain:
# do something...
return ConversationChain()
"""

View file

@ -108,7 +108,8 @@ def execute_function(code, function_name, *args, **kwargs):
try:
exec(code_obj, exec_globals, locals())
except Exception as exc:
raise ValueError("Function string does not contain a function") from exc
raise ValueError(
"Function string does not contain a function") from exc
# Add the function to the exec_globals dictionary
exec_globals[function_name] = locals()[function_name]
@ -163,6 +164,54 @@ def create_function(code, function_name):
return wrapped_function
def create_class(code, class_name):
if not hasattr(ast, "TypeIgnore"):
class TypeIgnore(ast.AST):
_fields = ()
ast.TypeIgnore = TypeIgnore
module = ast.parse(code)
exec_globals = globals().copy()
for node in module.body:
if isinstance(node, ast.Import):
for alias in node.names:
try:
exec_globals[alias.asname or alias.name] = importlib.import_module(
alias.name
)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Module {alias.name} not found. Please install it and try again."
) from e
class_code = next(
node
for node in module.body
if isinstance(node, ast.ClassDef) and node.name == class_name
)
class_code.parent = None
code_obj = compile(
ast.Module(body=[class_code], type_ignores=[]), "<string>", "exec"
)
with contextlib.suppress(Exception):
exec(code_obj, exec_globals, locals())
exec_globals[class_name] = locals()[class_name]
# Return a function that imports necessary modules and creates an instance of the target class
def build(*args, **kwargs):
for module_name, module in exec_globals.items():
if isinstance(module, type(importlib)):
globals()[module_name] = module
instance = exec_globals[class_name](*args, **kwargs)
return instance
return build
def extract_function_name(code):
module = ast.parse(code)
for node in module.body: