🐛 fix(custom.py): import re module to fix NameError when using re.split

🐛 fix(custom.py): fix indentation of class_template dictionary to improve readability
🐛 fix(custom.py): fix indentation of class dictionary to improve readability
🐛 fix(custom.py): fix indentation of _handle_function method to improve readability
🐛 fix(custom.py): fix indentation of transform_list method to improve readability
🐛 fix(custom.py): fix indentation of extract_class_info method to improve readability
🐛 fix(custom.py): fix indentation of _class_template_validation method to improve readability
🐛 fix(custom.py): fix indentation of build_langchain_template_custom_component method to improve readability
🐛 fix(custom.py): fix indentation of add_new_custom_field method to improve readability
🐛 fix(custom.py): fix indentation of add_code_field method to improve readability
🐛 fix(custom.py): fix indentation of extract_type_from_optional method to improve readability
🐛 fix(custom.py): fix indentation of build_langchain_template_custom_component method to improve readability
🔥 chore(custom.py): remove unused imports and variables
 feat(custom.py): add support for splitting a string by ':' or '=' and padding with None until length is 3 in _split_string method
 feat(custom.py): add support for transforming a list of strings by splitting each string and padding with None in transform_list method
 feat(custom.py): add support for extracting the type from a string formatted as "Optional[<type>]" in extract_type_from_optional method
 feat(custom.py): add support for passing field_value and field_required parameters to add_new_custom_field method
 feat(custom.py): add support for passing field_value and field_required parameters to build_langchain_template_custom_component method
 feat(custom.py): add support for passing field_value and field_required parameters to add_new_custom_field method
 feat(custom.py): add support for passing field_value and field_required parameters to build_langchain_template_custom_component method
 feat(custom.py): add support for passing field_value and field_required parameters to add_new_custom_field method
 feat(custom.py): add support for passing field_value and field_required parameters to build_langchain_template_custom_component method
 feat(custom.py): add support for passing field_value and field_required parameters to add_new_custom_field method
 feat(custom.py): add support for
This commit is contained in:
gustavoschaedler 2023-07-10 23:38:01 +01:00
commit 719015b5bb
2 changed files with 77 additions and 31 deletions

View file

@ -1,5 +1,5 @@
import re
import ast
import contextlib
import traceback
from typing import Callable, Optional
from fastapi import HTTPException
@ -77,26 +77,34 @@ class CustomComponent(BaseModel):
else:
self.class_template["functions"].append(function_data)
def _split_string(self, text):
"""
Split a string by ':' or '=' and append None until the resulting list has 3 items.
Parameters:
text (str): The string to be split.
Returns:
list: A list of strings resulting from the split operation,
padded with None until its length is 3.
"""
items = [item.strip() for item in re.split(r"[:=]", text) if item.strip()]
while len(items) < 3:
items.append(None)
return items
def transform_list(self, input_list):
output_list = []
for item in input_list:
# Split each item on ':' to separate variable name and type
split_item = item.split(":")
"""
Transform a list of strings by splitting each string and padding with None.
# If there is a type, strip any leading/trailing spaces from it
if len(split_item) > 1:
split_item[1] = split_item[1].strip()
# If there isn't a type, append None
else:
split_item.append(None)
for i in range(len(split_item)):
with contextlib.suppress(ValueError):
# Try to evaluate the item
split_item[i] = ast.literal_eval(split_item[i])
Parameters:
input_list (list): The list of strings to be transformed.
output_list.append(split_item)
return output_list
Returns:
list: A list of lists, each containing the result of the split operation.
"""
return [self._split_string(item) for item in input_list]
def extract_class_info(self):
try:
@ -120,6 +128,7 @@ class CustomComponent(BaseModel):
attributes = data.get("class", {}).get("attributes", {})
functions = data.get("functions", [])
template_config = self._build_template_config(attributes)
if build_function := next(
(f for f in functions if f["name"] == self.function_entrypoint_name),
None,
@ -146,6 +155,7 @@ class CustomComponent(BaseModel):
)
if "description" in attributes:
template_config["description"] = ast.literal_eval(attributes["description"])
return template_config
def _class_template_validation(self, code: dict):

View file

@ -20,14 +20,14 @@ from langflow.template.field.base import TemplateField
from langflow.template.frontend_node.tools import CustomComponentNode
from langflow.interface.retrievers.base import retriever_creator
from langflow.utils.util import get_base_classes
import re
import warnings
from fastapi import HTTPException
import traceback
from fastapi import HTTPException
from langflow.utils.util import get_base_classes
# Used to get the base_classes list
def get_type_list():
"""Get a list of all langchain types"""
all_types = build_langchain_types_dict()
@ -69,6 +69,7 @@ def build_langchain_types_dict(): # sourcery skip: dict-assign-update-to-union
created_types = creator.to_dict()
if created_types[creator.type_name].values():
all_types.update(created_types)
return all_types
@ -78,25 +79,35 @@ def process_type(field_type: str):
# TODO: Move to correct place
def add_new_custom_field(
template, field_name: str, field_type: str, field_config: dict
template,
field_name: str,
field_type: str,
field_value: str,
field_required: bool,
field_config: dict,
):
# Check field_config if any of the keys are in it
# if it is, update the value
display_name = field_config.pop("display_name", field_name)
field_type = field_config.pop("field_type", field_type)
field_type = process_type(field_type)
if field_value is not None:
field_value = field_value.replace("'", "").replace('"', "")
if "name" in field_config:
warnings.warn(
"The 'name' key in field_config is used to build the object and can't be changed."
)
field_config.pop("name", None)
required = field_config.pop("required", True)
required = field_config.pop("required", field_required)
placeholder = field_config.pop("placeholder", "")
new_field = TemplateField(
name=field_name,
field_type=field_type,
value=field_value,
show=True,
required=required,
advanced=False,
@ -133,6 +144,20 @@ def add_code_field(template, raw_code):
return template
def extract_type_from_optional(field_type):
"""
Extract the type from a string formatted as "Optional[<type>]".
Parameters:
field_type (str): The string from which to extract the type.
Returns:
str: The extracted type, or an empty string if no type was found.
"""
match = re.search(r"\[(.*?)\]", field_type)
return match[1] if match else None
def build_langchain_template_custom_component(extractor: CustomComponent):
# Build base "CustomComponent" template
frontend_node = CustomComponentNode().to_dict().get(type(extractor).__name__)
@ -145,19 +170,30 @@ def build_langchain_template_custom_component(extractor: CustomComponent):
frontend_node["description"] = template_config["description"]
raw_code = extractor.code
field_config = template_config.get("field_config", {})
if function_args is not None:
# Add extra fields
for extra_field in function_args:
def_field = extra_field[0]
def_type = extra_field[1]
field_required = True
field_name, field_type, field_value = extra_field
if def_field != "self":
if field_name != "self":
# TODO: Validate type - if is possible to render into frontend
if not def_type:
def_type = "str"
config = field_config.get(def_field, {})
if "optional" in field_type.lower():
field_type = extract_type_from_optional(field_type)
field_required = False
if not field_type:
field_type = "str"
config = field_config.get(field_name, {})
frontend_node = add_new_custom_field(
frontend_node, def_field, def_type, config
frontend_node,
field_name,
field_type,
field_value,
field_required,
config,
)
frontend_node = add_code_field(frontend_node, raw_code)