Refactor build_template_from_class function to remove unused code and improve readability
This commit is contained in:
parent
a3821f4367
commit
88754b9cf3
2 changed files with 0 additions and 68 deletions
|
|
@ -117,51 +117,3 @@ def set_langchain_cache(settings):
|
|||
logger.warning(f"Could not import {cache_type}. ")
|
||||
else:
|
||||
logger.info("No LLM cache set.")
|
||||
|
||||
|
||||
def build_template_from_class(name: str, type_to_cls_dict: Dict, add_function: bool = False):
|
||||
classes = [item.__name__ for item in type_to_cls_dict.values()]
|
||||
|
||||
# Raise error if name is not in chains
|
||||
if name not in classes:
|
||||
raise ValueError(f"{name} not found.")
|
||||
|
||||
for _type, v in type_to_cls_dict.items():
|
||||
if v.__name__ == name:
|
||||
_class = v
|
||||
|
||||
# Get the docstring
|
||||
docs = parse(_class.__doc__)
|
||||
|
||||
variables = {"_type": _type}
|
||||
|
||||
if "__fields__" in _class.__dict__:
|
||||
for class_field_items, value in _class.__fields__.items():
|
||||
if class_field_items in ["callback_manager"]:
|
||||
continue
|
||||
variables[class_field_items] = {}
|
||||
for name_, value_ in value.__repr_args__():
|
||||
if name_ == "default_factory":
|
||||
try:
|
||||
variables[class_field_items]["default"] = get_default_factory(
|
||||
module=_class.__base__.__module__,
|
||||
function=value_,
|
||||
)
|
||||
except Exception:
|
||||
variables[class_field_items]["default"] = None
|
||||
elif name_ not in ["name"]:
|
||||
variables[class_field_items][name_] = value_
|
||||
|
||||
variables[class_field_items]["placeholder"] = (
|
||||
docs.params[class_field_items] if class_field_items in docs.params else ""
|
||||
)
|
||||
base_classes = get_base_classes(_class)
|
||||
# Adding function to base classes to allow
|
||||
# the output to be a function
|
||||
if add_function:
|
||||
base_classes.append("Callable")
|
||||
return {
|
||||
"template": format_dict(variables, name),
|
||||
"description": docs.short_description or "",
|
||||
"base_classes": base_classes,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import importlib
|
|||
from typing import Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
from langflow.interface.utils import build_template_from_class
|
||||
from langflow.utils.util import build_template_from_function, get_base_classes, get_default_factory
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -68,25 +67,6 @@ def test_build_template_from_function():
|
|||
build_template_from_function("NonExistent", type_to_loader_dict)
|
||||
|
||||
|
||||
# Test build_template_from_class
|
||||
def test_build_template_from_class():
|
||||
type_to_cls_dict: Dict[str, type] = {"parent": Parent, "child": Child}
|
||||
|
||||
# Test valid input
|
||||
result = build_template_from_class("Child", type_to_cls_dict)
|
||||
assert result is not None
|
||||
assert "template" in result
|
||||
assert "description" in result
|
||||
assert "base_classes" in result
|
||||
assert "Child" in result["base_classes"]
|
||||
assert "Parent" in result["base_classes"]
|
||||
assert result["description"] == "Child Class"
|
||||
|
||||
# Test invalid input
|
||||
with pytest.raises(ValueError, match="InvalidClass not found."):
|
||||
build_template_from_class("InvalidClass", type_to_cls_dict)
|
||||
|
||||
|
||||
# Test get_base_classes
|
||||
def test_get_base_classes():
|
||||
base_classes_parent = get_base_classes(Parent)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue