Merge branch 'toolkits' of https://github.com/logspace-ai/langflow into toolkits
This commit is contained in:
commit
0b760003b1
17 changed files with 1310 additions and 87 deletions
0
src/backend/langflow/cache/__init__.py
vendored
Normal file
0
src/backend/langflow/cache/__init__.py
vendored
Normal file
49
src/backend/langflow/cache/utils.py
vendored
Normal file
49
src/backend/langflow/cache/utils.py
vendored
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
import contextlib
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import dill
|
||||
|
||||
PREFIX = "langflow_cache"
|
||||
|
||||
|
||||
def clear_old_cache_files(max_cache_size: int = 10):
|
||||
cache_dir = Path(tempfile.gettempdir())
|
||||
cache_files = list(cache_dir.glob(f"{PREFIX}_*.dill"))
|
||||
|
||||
if len(cache_files) > max_cache_size:
|
||||
cache_files_sorted_by_mtime = sorted(
|
||||
cache_files, key=lambda x: x.stat().st_mtime, reverse=True
|
||||
)
|
||||
|
||||
for cache_file in cache_files_sorted_by_mtime[max_cache_size:]:
|
||||
with contextlib.suppress(OSError):
|
||||
os.remove(cache_file)
|
||||
|
||||
|
||||
def remove_position_info(node):
|
||||
node.pop("position", None)
|
||||
|
||||
|
||||
def compute_hash(graph_data):
|
||||
for node in graph_data["nodes"]:
|
||||
remove_position_info(node)
|
||||
|
||||
cleaned_graph_json = json.dumps(graph_data, sort_keys=True)
|
||||
return hashlib.sha256(cleaned_graph_json.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def save_cache(hash_val, chat_data):
|
||||
cache_path = Path(tempfile.gettempdir()) / f"{PREFIX}_{hash_val}.dill"
|
||||
with cache_path.open("wb") as cache_file:
|
||||
dill.dump(chat_data, cache_file)
|
||||
|
||||
|
||||
def load_cache(hash_val):
|
||||
cache_path = Path(tempfile.gettempdir()) / f"{PREFIX}_{hash_val}.dill"
|
||||
if cache_path.exists():
|
||||
with cache_path.open("rb") as cache_file:
|
||||
return dill.load(cache_file)
|
||||
return None
|
||||
|
|
@ -29,4 +29,8 @@ tools:
|
|||
wrappers:
|
||||
- RequestsWrapper
|
||||
|
||||
toolkits:
|
||||
- OpenAPIToolkit
|
||||
- JsonToolkit
|
||||
|
||||
dev: false
|
||||
|
|
|
|||
|
|
@ -68,7 +68,9 @@ class PromptNode(Node):
|
|||
)
|
||||
self.params["tools"] = tools
|
||||
prompt_params = [
|
||||
key for key, value in self.params.items() if value["type"] == "str"
|
||||
key
|
||||
for key, value in self.params.items()
|
||||
if isinstance(value, str) and key != "format_instructions"
|
||||
]
|
||||
else:
|
||||
prompt_params = ["template"]
|
||||
|
|
|
|||
|
|
@ -2,30 +2,48 @@ import contextlib
|
|||
import io
|
||||
import re
|
||||
from typing import Any, Dict
|
||||
from langflow.cache.utils import compute_hash, load_cache, save_cache
|
||||
|
||||
from langflow.graph.graph import Graph
|
||||
from langflow.interface import loading
|
||||
from langflow.utils import payload
|
||||
|
||||
|
||||
def load_langchain_object(data_graph):
|
||||
computed_hash = compute_hash(data_graph)
|
||||
|
||||
# Load langchain_object from cache if it exists
|
||||
langchain_object = load_cache(computed_hash)
|
||||
if langchain_object is None:
|
||||
nodes = data_graph["nodes"]
|
||||
# Add input variables
|
||||
nodes = payload.extract_input_variables(nodes)
|
||||
# Nodes, edges and root node
|
||||
edges = data_graph["edges"]
|
||||
graph = Graph(nodes, edges)
|
||||
|
||||
langchain_object = graph.build()
|
||||
|
||||
return computed_hash, langchain_object
|
||||
|
||||
|
||||
def process_graph(data_graph: Dict[str, Any]):
|
||||
"""
|
||||
Process graph by extracting input variables and replacing ZeroShotPrompt
|
||||
with PromptTemplate,then run the graph and return the result and thought.
|
||||
"""
|
||||
nodes = data_graph["nodes"]
|
||||
# Add input variables
|
||||
# ? Is this necessary?
|
||||
nodes = payload.extract_input_variables(nodes)
|
||||
# Nodes, edges and root node
|
||||
edges = data_graph["edges"]
|
||||
graph = Graph(nodes, edges)
|
||||
|
||||
langchain_object = graph.build()
|
||||
# Load langchain object
|
||||
computed_hash, langchain_object = load_langchain_object(data_graph)
|
||||
message = data_graph["message"]
|
||||
# Process json
|
||||
|
||||
# Generate result and thought
|
||||
result, thought = get_result_and_thought_using_graph(langchain_object, message)
|
||||
|
||||
# Save langchain_object to cache
|
||||
# We have to save it here because if the
|
||||
# memory is updated we need to keep the new values
|
||||
save_cache(computed_hash, langchain_object)
|
||||
|
||||
return {
|
||||
"result": result,
|
||||
"thought": re.sub(
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from langchain.agents import agent_toolkits
|
|||
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.importing.utils import import_class, import_module
|
||||
from langflow.settings import settings
|
||||
from langflow.utils.util import build_template_from_class
|
||||
|
||||
|
||||
|
|
@ -33,7 +34,7 @@ class ToolkitCreator(LangChainTypeCreator):
|
|||
)
|
||||
# if toolkit_name is not lower case it is a class
|
||||
for toolkit_name in agent_toolkits.__all__
|
||||
if not toolkit_name.islower()
|
||||
if not toolkit_name.islower() and toolkit_name in settings.toolkits
|
||||
}
|
||||
|
||||
return self.type_dict
|
||||
|
|
|
|||
|
|
@ -112,7 +112,7 @@ class ToolCreator(LangChainTypeCreator):
|
|||
# Copy the field and add the name
|
||||
fields = []
|
||||
for param in params:
|
||||
field = TOOL_INPUTS.get(param, TOOL_INPUTS["str"])
|
||||
field = TOOL_INPUTS.get(param, TOOL_INPUTS["str"]).copy()
|
||||
field.name = param
|
||||
if param == "aiosession":
|
||||
field.show = False
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ class Settings(BaseSettings):
|
|||
tools: List[str] = []
|
||||
memories: List[str] = []
|
||||
wrappers: List[str] = []
|
||||
toolkits: List[str] = []
|
||||
dev: bool = False
|
||||
|
||||
class Config:
|
||||
|
|
@ -35,6 +36,7 @@ class Settings(BaseSettings):
|
|||
self.tools = new_settings.tools or []
|
||||
self.memories = new_settings.memories or []
|
||||
self.wrappers = new_settings.wrappers or []
|
||||
self.toolkits = new_settings.toolkits or []
|
||||
self.dev = new_settings.dev or False
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from abc import ABC
|
||||
from typing import Any, Union
|
||||
from typing import Any, Optional, Union, Dict
|
||||
from langflow.utils import constants
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -17,6 +18,7 @@ class TemplateFieldCreator(BaseModel, ABC):
|
|||
file_types: list[str] = []
|
||||
content: Union[str, None] = None
|
||||
password: bool = False
|
||||
options: list[str] = []
|
||||
# _name will be used to store the name of the field
|
||||
# in the template
|
||||
name: str = ""
|
||||
|
|
@ -37,6 +39,88 @@ class TemplateFieldCreator(BaseModel, ABC):
|
|||
result["content"] = self.content
|
||||
return result
|
||||
|
||||
def process_field(
|
||||
self, key: str, value: Dict[str, Any], name: Optional[str] = None
|
||||
) -> None:
|
||||
_type = value["type"]
|
||||
|
||||
# Remove 'Optional' wrapper
|
||||
if "Optional" in _type:
|
||||
_type = _type.replace("Optional[", "")[:-1]
|
||||
|
||||
# Check for list type
|
||||
if "List" in _type:
|
||||
_type = _type.replace("List[", "")[:-1]
|
||||
self.is_list = True
|
||||
else:
|
||||
self.is_list = False
|
||||
|
||||
# Replace 'Mapping' with 'dict'
|
||||
if "Mapping" in _type:
|
||||
_type = _type.replace("Mapping", "dict")
|
||||
|
||||
# Change type from str to Tool
|
||||
self.field_type = "Tool" if key in ["allowed_tools"] else _type
|
||||
|
||||
self.field_type = "int" if key in ["max_value_length"] else self.field_type
|
||||
|
||||
# Show or not field
|
||||
self.show = bool(
|
||||
(self.required and key not in ["input_variables"])
|
||||
or key
|
||||
in [
|
||||
"allowed_tools",
|
||||
"memory",
|
||||
"prefix",
|
||||
"examples",
|
||||
"temperature",
|
||||
"model_name",
|
||||
"headers",
|
||||
"max_value_length",
|
||||
]
|
||||
or "api_key" in key
|
||||
)
|
||||
|
||||
# Add password field
|
||||
self.password = any(
|
||||
text in key.lower() for text in ["password", "token", "api", "key"]
|
||||
)
|
||||
|
||||
# Add multline
|
||||
self.multiline = key in [
|
||||
"suffix",
|
||||
"prefix",
|
||||
"template",
|
||||
"examples",
|
||||
"code",
|
||||
"headers",
|
||||
]
|
||||
|
||||
# Replace dict type with str
|
||||
if "dict" in self.field_type.lower():
|
||||
self.field_type = "code"
|
||||
|
||||
if key == "dict_":
|
||||
self.field_type = "file"
|
||||
self.suffixes = [".json", ".yaml", ".yml"]
|
||||
self.file_types = ["json", "yaml", "yml"]
|
||||
|
||||
# Replace default value with actual value
|
||||
if "default" in value:
|
||||
self.value = value["default"]
|
||||
|
||||
if key == "headers":
|
||||
self.value = """{'Authorization':
|
||||
'Bearer <token>'}"""
|
||||
|
||||
# Add options to openai
|
||||
if name == "OpenAI" and key == "model_name":
|
||||
self.options = constants.OPENAI_MODELS
|
||||
self.is_list = True
|
||||
elif name == "OpenAIChat" and key == "model_name":
|
||||
self.options = constants.CHAT_OPENAI_MODELS
|
||||
self.is_list = True
|
||||
|
||||
|
||||
class TemplateField(TemplateFieldCreator):
|
||||
pass
|
||||
|
|
@ -46,7 +130,13 @@ class Template(BaseModel):
|
|||
type_name: str
|
||||
fields: list[TemplateField]
|
||||
|
||||
def process_fields(self, name: Optional[str] = None) -> None:
|
||||
for field in self.fields:
|
||||
signature = field.to_dict()
|
||||
field.process_field(field.name, signature, name)
|
||||
|
||||
def to_dict(self):
|
||||
self.process_fields(self.type_name)
|
||||
result = {field.name: field.to_dict() for field in self.fields}
|
||||
result["_type"] = self.type_name # type: ignore
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -98,7 +98,7 @@ def build_template_from_function(
|
|||
return {
|
||||
"template": format_dict(variables, name),
|
||||
"description": docs["Description"],
|
||||
"base_classes": get_base_classes(_class),
|
||||
"base_classes": base_classes,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -173,7 +173,7 @@ def get_base_classes(cls):
|
|||
result = [cls.__name__]
|
||||
if not result:
|
||||
result = [cls.__name__]
|
||||
return list(set(result))
|
||||
return list(set(result + [cls.__name__]))
|
||||
|
||||
|
||||
def get_default_factory(module: str, function: str):
|
||||
|
|
@ -333,8 +333,10 @@ def format_dict(d, name: Optional[str] = None):
|
|||
# Add options to openai
|
||||
if name == "OpenAI" and key == "model_name":
|
||||
value["options"] = constants.OPENAI_MODELS
|
||||
value["list"] = True
|
||||
elif name == "OpenAIChat" and key == "model_name":
|
||||
value["options"] = constants.CHAT_OPENAI_MODELS
|
||||
value["list"] = True
|
||||
|
||||
return d
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,9 @@ def pytest_configure():
|
|||
pytest.COMPLEX_EXAMPLE_PATH = (
|
||||
Path(__file__).parent.absolute() / "data" / "complex_example.json"
|
||||
)
|
||||
pytest.OPENAPI_EXAMPLE_PATH = (
|
||||
Path(__file__).parent.absolute() / "data" / "Openapi.json"
|
||||
)
|
||||
|
||||
pytest.CODE_WITH_SYNTAX_ERROR = """
|
||||
def get_text():
|
||||
|
|
|
|||
445
tests/data/Openapi.json
Normal file
445
tests/data/Openapi.json
Normal file
File diff suppressed because one or more lines are too long
64
tests/test_cache.py
Normal file
64
tests/test_cache.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
import json
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
import dill
|
||||
import tempfile
|
||||
from langflow.cache.utils import compute_hash, load_cache, save_cache, PREFIX
|
||||
from langflow.interface.run import load_langchain_object, process_graph
|
||||
import pytest
|
||||
|
||||
|
||||
def get_graph(_type="basic"):
|
||||
"""Get a graph from a json file"""
|
||||
if _type == "basic":
|
||||
path = pytest.BASIC_EXAMPLE_PATH
|
||||
elif _type == "complex":
|
||||
path = pytest.COMPLEX_EXAMPLE_PATH
|
||||
elif _type == "openapi":
|
||||
path = pytest.OPENAPI_EXAMPLE_PATH
|
||||
|
||||
with open(path, "r") as f:
|
||||
flow_graph = json.load(f)
|
||||
return flow_graph["data"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def basic_data_graph():
|
||||
return get_graph()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def complex_data_graph():
|
||||
return get_graph("complex")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_data_graph():
|
||||
return get_graph("openapi")
|
||||
|
||||
|
||||
def langchain_objects_are_equal(obj1, obj2):
|
||||
return str(obj1) == str(obj2)
|
||||
|
||||
|
||||
def test_cache_creation(basic_data_graph):
|
||||
# Compute hash for the input data_graph
|
||||
computed_hash = compute_hash(basic_data_graph)
|
||||
|
||||
# Call process_graph function to build and cache the langchain_object
|
||||
_ = load_langchain_object(basic_data_graph)
|
||||
|
||||
# Check if the cache file exists
|
||||
cache_file = Path(tempfile.gettempdir()) / f"{PREFIX}_{computed_hash}.dill"
|
||||
assert cache_file.exists()
|
||||
|
||||
|
||||
def test_cache_reuse(basic_data_graph):
|
||||
# Call process_graph function to build and cache the langchain_object
|
||||
result1 = load_langchain_object(basic_data_graph)
|
||||
|
||||
# Call process_graph function again to use the cached langchain_object
|
||||
result2 = load_langchain_object(basic_data_graph)
|
||||
|
||||
# Compare the results to ensure the same langchain_object was used
|
||||
assert langchain_objects_are_equal(result1, result2)
|
||||
49
tests/test_creators.py
Normal file
49
tests/test_creators.py
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
from typing import Dict, List
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.agents.base import AgentCreator
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_lang_chain_type_creator() -> LangChainTypeCreator:
|
||||
class SampleLangChainTypeCreator(LangChainTypeCreator):
|
||||
type_name: str = "test_type"
|
||||
|
||||
def type_to_loader_dict(self) -> Dict:
|
||||
return {"test_type": "TestClass"}
|
||||
|
||||
def to_list(self) -> List[str]:
|
||||
return ["node1", "node2"]
|
||||
|
||||
def get_signature(self, name: str) -> Dict:
|
||||
return {
|
||||
"template": {"test_field": {"type": "str"}},
|
||||
"description": "test description",
|
||||
"base_classes": ["base_class1", "base_class2"],
|
||||
}
|
||||
|
||||
return SampleLangChainTypeCreator()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_agent_creator() -> AgentCreator:
|
||||
return AgentCreator()
|
||||
|
||||
|
||||
def test_lang_chain_type_creator_to_dict(
|
||||
sample_lang_chain_type_creator: LangChainTypeCreator,
|
||||
):
|
||||
type_dict = sample_lang_chain_type_creator.to_dict()
|
||||
assert len(type_dict) == 1
|
||||
assert "test_type" in type_dict
|
||||
assert "node1" in type_dict["test_type"]
|
||||
assert "node2" in type_dict["test_type"]
|
||||
assert "template" in type_dict["test_type"]["node1"]
|
||||
assert "description" in type_dict["test_type"]["node1"]
|
||||
assert "base_classes" in type_dict["test_type"]["node1"]
|
||||
|
||||
|
||||
def test_agent_creator_type_to_loader_dict(sample_agent_creator: AgentCreator):
|
||||
type_to_loader_dict = sample_agent_creator.type_to_loader_dict
|
||||
assert len(type_to_loader_dict) > 0
|
||||
assert "JsonAgent"
|
||||
60
tests/test_frontend_nodes.py
Normal file
60
tests/test_frontend_nodes.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
import pytest
|
||||
from typing import Dict, List
|
||||
from langflow.template.base import TemplateField, FrontendNode, Template
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.agents.base import AgentCreator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_template_field() -> TemplateField:
|
||||
return TemplateField(name="test_field", field_type="str")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_template(sample_template_field: TemplateField) -> Template:
|
||||
return Template(type_name="test_template", fields=[sample_template_field])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_frontend_node(sample_template: Template) -> FrontendNode:
|
||||
return FrontendNode(
|
||||
template=sample_template,
|
||||
description="test description",
|
||||
base_classes=["base_class1", "base_class2"],
|
||||
name="test_frontend_node",
|
||||
)
|
||||
|
||||
|
||||
def test_template_field_defaults(sample_template_field: TemplateField):
|
||||
assert sample_template_field.field_type == "str"
|
||||
assert sample_template_field.required == False
|
||||
assert sample_template_field.placeholder == ""
|
||||
assert sample_template_field.is_list == False
|
||||
assert sample_template_field.show == True
|
||||
assert sample_template_field.multiline == False
|
||||
assert sample_template_field.value == None
|
||||
assert sample_template_field.suffixes == []
|
||||
assert sample_template_field.file_types == []
|
||||
assert sample_template_field.content == None
|
||||
assert sample_template_field.password == False
|
||||
assert sample_template_field.name == "test_field"
|
||||
|
||||
|
||||
def test_template_to_dict(
|
||||
sample_template: Template, sample_template_field: TemplateField
|
||||
):
|
||||
template_dict = sample_template.to_dict()
|
||||
assert template_dict["_type"] == "test_template"
|
||||
assert len(template_dict) == 2 # _type and test_field
|
||||
assert "test_field" in template_dict
|
||||
assert "type" in template_dict["test_field"]
|
||||
assert "required" in template_dict["test_field"]
|
||||
|
||||
|
||||
def test_frontend_node_to_dict(sample_frontend_node: FrontendNode):
|
||||
node_dict = sample_frontend_node.to_dict()
|
||||
assert len(node_dict) == 1
|
||||
assert "test_frontend_node" in node_dict
|
||||
assert "description" in node_dict["test_frontend_node"]
|
||||
assert "template" in node_dict["test_frontend_node"]
|
||||
assert "base_classes" in node_dict["test_frontend_node"]
|
||||
|
|
@ -1,16 +1,36 @@
|
|||
import json
|
||||
from langflow.graph.nodes import (
|
||||
WrapperNode,
|
||||
AgentNode,
|
||||
ToolNode,
|
||||
ChainNode,
|
||||
PromptNode,
|
||||
LLMNode,
|
||||
ToolkitNode,
|
||||
FileToolNode,
|
||||
)
|
||||
|
||||
import pytest
|
||||
from langchain.agents import AgentExecutor
|
||||
from langflow.graph import Edge, Graph, Node
|
||||
from langflow.utils.payload import build_json, get_root_node
|
||||
|
||||
|
||||
# Test cases for the graph module
|
||||
|
||||
# now we have three types of graph:
|
||||
# BASIC_EXAMPLE_PATH, COMPLEX_EXAMPLE_PATH, OPENAPI_EXAMPLE_PATH
|
||||
|
||||
def get_graph(basic=True):
|
||||
|
||||
def get_graph(_type="basic"):
|
||||
"""Get a graph from a json file"""
|
||||
path = pytest.BASIC_EXAMPLE_PATH if basic else pytest.COMPLEX_EXAMPLE_PATH
|
||||
if _type == "basic":
|
||||
path = pytest.BASIC_EXAMPLE_PATH
|
||||
elif _type == "complex":
|
||||
path = pytest.COMPLEX_EXAMPLE_PATH
|
||||
elif _type == "openapi":
|
||||
path = pytest.OPENAPI_EXAMPLE_PATH
|
||||
|
||||
with open(path, "r") as f:
|
||||
flow_graph = json.load(f)
|
||||
data_graph = flow_graph["data"]
|
||||
|
|
@ -19,26 +39,94 @@ def get_graph(basic=True):
|
|||
return Graph(nodes, edges)
|
||||
|
||||
|
||||
def test_get_nodes_with_target():
|
||||
@pytest.fixture
|
||||
def basic_graph():
|
||||
return get_graph()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def complex_graph():
|
||||
return get_graph("complex")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_graph():
|
||||
return get_graph("openapi")
|
||||
|
||||
|
||||
def get_node_by_type(graph, node_type):
|
||||
"""Get a node by type"""
|
||||
return next((node for node in graph.nodes if isinstance(node, node_type)), None)
|
||||
|
||||
|
||||
def test_graph_structure(basic_graph):
|
||||
assert isinstance(basic_graph, Graph)
|
||||
assert len(basic_graph.nodes) > 0
|
||||
assert len(basic_graph.edges) > 0
|
||||
for node in basic_graph.nodes:
|
||||
assert isinstance(node, Node)
|
||||
for edge in basic_graph.edges:
|
||||
assert isinstance(edge, Edge)
|
||||
assert edge.source in basic_graph.nodes
|
||||
assert edge.target in basic_graph.nodes
|
||||
|
||||
|
||||
def test_circular_dependencies(basic_graph):
|
||||
assert isinstance(basic_graph, Graph)
|
||||
|
||||
def check_circular(node, visited):
|
||||
visited.add(node)
|
||||
neighbors = basic_graph.get_nodes_with_target(node)
|
||||
for neighbor in neighbors:
|
||||
if neighbor in visited:
|
||||
return True
|
||||
if check_circular(neighbor, visited.copy()):
|
||||
return True
|
||||
return False
|
||||
|
||||
for node in basic_graph.nodes:
|
||||
assert not check_circular(node, set())
|
||||
|
||||
|
||||
def test_invalid_node_types():
|
||||
graph_data = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "1",
|
||||
"data": {
|
||||
"node": {
|
||||
"base_classes": ["BaseClass"],
|
||||
"template": {
|
||||
"_type": "InvalidNodeType",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
with pytest.raises(Exception):
|
||||
Graph(graph_data["nodes"], graph_data["edges"])
|
||||
|
||||
|
||||
def test_get_nodes_with_target(basic_graph):
|
||||
"""Test getting connected nodes"""
|
||||
graph = get_graph()
|
||||
assert isinstance(graph, Graph)
|
||||
assert isinstance(basic_graph, Graph)
|
||||
# Get root node
|
||||
root = get_root_node(graph)
|
||||
root = get_root_node(basic_graph)
|
||||
assert root is not None
|
||||
connected_nodes = graph.get_nodes_with_target(root)
|
||||
connected_nodes = basic_graph.get_nodes_with_target(root)
|
||||
assert connected_nodes is not None
|
||||
|
||||
|
||||
def test_get_node_neighbors_basic():
|
||||
def test_get_node_neighbors_basic(basic_graph):
|
||||
"""Test getting node neighbors"""
|
||||
|
||||
graph = get_graph(basic=True)
|
||||
assert isinstance(graph, Graph)
|
||||
assert isinstance(basic_graph, Graph)
|
||||
# Get root node
|
||||
root = get_root_node(graph)
|
||||
root = get_root_node(basic_graph)
|
||||
assert root is not None
|
||||
neighbors = graph.get_node_neighbors(root)
|
||||
neighbors = basic_graph.get_node_neighbors(root)
|
||||
assert neighbors is not None
|
||||
assert isinstance(neighbors, dict)
|
||||
# Root Node is an Agent, it requires an LLMChain and tools
|
||||
|
|
@ -57,7 +145,7 @@ def test_get_node_neighbors_basic():
|
|||
for neighbor, val in neighbors.items()
|
||||
if "Chain" in neighbor.data["type"] and val
|
||||
)
|
||||
chain_neighbors = graph.get_node_neighbors(chain)
|
||||
chain_neighbors = basic_graph.get_node_neighbors(chain)
|
||||
assert chain_neighbors is not None
|
||||
assert isinstance(chain_neighbors, dict)
|
||||
# Check if there is a LLM in the chain's neighbors
|
||||
|
|
@ -74,15 +162,13 @@ def test_get_node_neighbors_basic():
|
|||
)
|
||||
|
||||
|
||||
def test_get_node_neighbors_complex():
|
||||
def test_get_node_neighbors_complex(complex_graph):
|
||||
"""Test getting node neighbors"""
|
||||
|
||||
graph = get_graph(basic=False)
|
||||
assert isinstance(graph, Graph)
|
||||
assert isinstance(complex_graph, Graph)
|
||||
# Get root node
|
||||
root = get_root_node(graph)
|
||||
root = get_root_node(complex_graph)
|
||||
assert root is not None
|
||||
neighbors = graph.get_nodes_with_target(root)
|
||||
neighbors = complex_graph.get_nodes_with_target(root)
|
||||
assert neighbors is not None
|
||||
# Neighbors should be a list of nodes
|
||||
assert isinstance(neighbors, list)
|
||||
|
|
@ -93,7 +179,7 @@ def test_get_node_neighbors_complex():
|
|||
assert any("Tool" in neighbor.data["type"] for neighbor in neighbors)
|
||||
# Now on to the Chain's neighbors
|
||||
chain = next(neighbor for neighbor in neighbors if "Chain" in neighbor.data["type"])
|
||||
chain_neighbors = graph.get_nodes_with_target(chain)
|
||||
chain_neighbors = complex_graph.get_nodes_with_target(chain)
|
||||
assert chain_neighbors is not None
|
||||
# Check if there is a LLM in the chain's neighbors
|
||||
assert any("OpenAI" in neighbor.data["type"] for neighbor in chain_neighbors)
|
||||
|
|
@ -101,7 +187,7 @@ def test_get_node_neighbors_complex():
|
|||
assert any("Prompt" in neighbor.data["type"] for neighbor in chain_neighbors)
|
||||
# Now on to the Tool's neighbors
|
||||
tool = next(neighbor for neighbor in neighbors if "Tool" in neighbor.data["type"])
|
||||
tool_neighbors = graph.get_nodes_with_target(tool)
|
||||
tool_neighbors = complex_graph.get_nodes_with_target(tool)
|
||||
assert tool_neighbors is not None
|
||||
# Check if there is an Agent in the tool's neighbors
|
||||
assert any("Agent" in neighbor.data["type"] for neighbor in tool_neighbors)
|
||||
|
|
@ -109,7 +195,7 @@ def test_get_node_neighbors_complex():
|
|||
agent = next(
|
||||
neighbor for neighbor in tool_neighbors if "Agent" in neighbor.data["type"]
|
||||
)
|
||||
agent_neighbors = graph.get_nodes_with_target(agent)
|
||||
agent_neighbors = complex_graph.get_nodes_with_target(agent)
|
||||
assert agent_neighbors is not None
|
||||
# Check if there is a Tool in the agent's neighbors
|
||||
assert any("Tool" in neighbor.data["type"] for neighbor in agent_neighbors)
|
||||
|
|
@ -117,62 +203,57 @@ def test_get_node_neighbors_complex():
|
|||
tool = next(
|
||||
neighbor for neighbor in agent_neighbors if "Tool" in neighbor.data["type"]
|
||||
)
|
||||
tool_neighbors = graph.get_nodes_with_target(tool)
|
||||
tool_neighbors = complex_graph.get_nodes_with_target(tool)
|
||||
assert tool_neighbors is not None
|
||||
# Check if there is a PythonFunction in the tool's neighbors
|
||||
assert any("PythonFunction" in neighbor.data["type"] for neighbor in tool_neighbors)
|
||||
|
||||
|
||||
def test_get_node():
|
||||
def test_get_node(basic_graph):
|
||||
"""Test getting a single node"""
|
||||
graph = get_graph()
|
||||
node_id = graph.nodes[0].id
|
||||
node = graph.get_node(node_id)
|
||||
node_id = basic_graph.nodes[0].id
|
||||
node = basic_graph.get_node(node_id)
|
||||
assert isinstance(node, Node)
|
||||
assert node.id == node_id
|
||||
|
||||
|
||||
def test_build_nodes():
|
||||
def test_build_nodes(basic_graph):
|
||||
"""Test building nodes"""
|
||||
graph = get_graph()
|
||||
assert len(graph.nodes) == len(graph._nodes)
|
||||
for node in graph.nodes:
|
||||
|
||||
assert len(basic_graph.nodes) == len(basic_graph._nodes)
|
||||
for node in basic_graph.nodes:
|
||||
assert isinstance(node, Node)
|
||||
|
||||
|
||||
def test_build_edges():
|
||||
def test_build_edges(basic_graph):
|
||||
"""Test building edges"""
|
||||
graph = get_graph()
|
||||
assert len(graph.edges) == len(graph._edges)
|
||||
for edge in graph.edges:
|
||||
assert len(basic_graph.edges) == len(basic_graph._edges)
|
||||
for edge in basic_graph.edges:
|
||||
assert isinstance(edge, Edge)
|
||||
assert isinstance(edge.source, Node)
|
||||
assert isinstance(edge.target, Node)
|
||||
|
||||
|
||||
def test_get_root_node():
|
||||
def test_get_root_node(basic_graph, complex_graph):
|
||||
"""Test getting root node"""
|
||||
graph = get_graph(basic=True)
|
||||
assert isinstance(graph, Graph)
|
||||
root = get_root_node(graph)
|
||||
assert isinstance(basic_graph, Graph)
|
||||
root = get_root_node(basic_graph)
|
||||
assert root is not None
|
||||
assert isinstance(root, Node)
|
||||
assert root.data["type"] == "ZeroShotAgent"
|
||||
# For complex example, the root node is a ZeroShotAgent too
|
||||
graph = get_graph(basic=False)
|
||||
assert isinstance(graph, Graph)
|
||||
root = get_root_node(graph)
|
||||
assert isinstance(complex_graph, Graph)
|
||||
root = get_root_node(complex_graph)
|
||||
assert root is not None
|
||||
assert isinstance(root, Node)
|
||||
assert root.data["type"] == "ZeroShotAgent"
|
||||
|
||||
|
||||
def test_build_json():
|
||||
def test_build_json(basic_graph):
|
||||
"""Test building JSON from graph"""
|
||||
graph = get_graph()
|
||||
assert isinstance(graph, Graph)
|
||||
root = get_root_node(graph)
|
||||
json_data = build_json(root, graph)
|
||||
assert isinstance(basic_graph, Graph)
|
||||
root = get_root_node(basic_graph)
|
||||
json_data = build_json(root, basic_graph)
|
||||
assert isinstance(json_data, dict)
|
||||
assert json_data["_type"] == "zero-shot-react-description"
|
||||
assert isinstance(json_data["llm_chain"], dict)
|
||||
|
|
@ -188,38 +269,37 @@ def test_build_json():
|
|||
assert all(isinstance(val, str) for val in json_data["return_values"])
|
||||
|
||||
|
||||
def test_validate_edges():
|
||||
def test_validate_edges(basic_graph):
|
||||
"""Test validating edges"""
|
||||
graph = get_graph()
|
||||
assert isinstance(graph, Graph)
|
||||
|
||||
assert isinstance(basic_graph, Graph)
|
||||
# all edges should be valid
|
||||
assert all(edge.valid for edge in graph.edges)
|
||||
assert all(edge.valid for edge in basic_graph.edges)
|
||||
|
||||
|
||||
def test_matched_type():
|
||||
def test_matched_type(basic_graph):
|
||||
"""Test matched type attribute in Edge"""
|
||||
graph = get_graph()
|
||||
assert isinstance(graph, Graph)
|
||||
assert isinstance(basic_graph, Graph)
|
||||
# all edges should be valid
|
||||
assert all(edge.valid for edge in graph.edges)
|
||||
assert all(edge.valid for edge in basic_graph.edges)
|
||||
# all edges should have a matched_type attribute
|
||||
assert all(hasattr(edge, "matched_type") for edge in graph.edges)
|
||||
assert all(hasattr(edge, "matched_type") for edge in basic_graph.edges)
|
||||
# The matched_type attribute should be in the source_types attr
|
||||
assert all(edge.matched_type in edge.source_types for edge in graph.edges)
|
||||
assert all(edge.matched_type in edge.source_types for edge in basic_graph.edges)
|
||||
|
||||
|
||||
def test_build_params():
|
||||
def test_build_params(basic_graph):
|
||||
"""Test building params"""
|
||||
graph = get_graph()
|
||||
assert isinstance(graph, Graph)
|
||||
|
||||
assert isinstance(basic_graph, Graph)
|
||||
# all edges should be valid
|
||||
assert all(edge.valid for edge in graph.edges)
|
||||
assert all(edge.valid for edge in basic_graph.edges)
|
||||
# all edges should have a matched_type attribute
|
||||
assert all(hasattr(edge, "matched_type") for edge in graph.edges)
|
||||
assert all(hasattr(edge, "matched_type") for edge in basic_graph.edges)
|
||||
# The matched_type attribute should be in the source_types attr
|
||||
assert all(edge.matched_type in edge.source_types for edge in graph.edges)
|
||||
assert all(edge.matched_type in edge.source_types for edge in basic_graph.edges)
|
||||
# Get the root node
|
||||
root = get_root_node(graph)
|
||||
root = get_root_node(basic_graph)
|
||||
# Root node is a ZeroShotAgent
|
||||
# which requires an llm_chain, allowed_tools and return_values
|
||||
assert isinstance(root.params, dict)
|
||||
|
|
@ -261,7 +341,7 @@ def test_build_params():
|
|||
assert isinstance(llm_node.params["model_name"], str)
|
||||
|
||||
|
||||
def test_build():
|
||||
def test_build(basic_graph, complex_graph):
|
||||
"""Test Node's build method"""
|
||||
# def build(self):
|
||||
# # The params dict is used to build the module
|
||||
|
|
@ -284,18 +364,81 @@ def test_build():
|
|||
# # and instantiate it with the params
|
||||
# # and return the instance
|
||||
# return LANGCHAIN_TYPES_DICT[self.node_type](**self.params)
|
||||
graph = get_graph()
|
||||
assert isinstance(graph, Graph)
|
||||
|
||||
assert isinstance(basic_graph, Graph)
|
||||
# Now we test the build method
|
||||
# Build the Agent
|
||||
agent = graph.build()
|
||||
agent = basic_graph.build()
|
||||
# The agent should be a AgentExecutor
|
||||
assert isinstance(agent, AgentExecutor)
|
||||
|
||||
# Now we test the complex example
|
||||
graph = get_graph(basic=False)
|
||||
assert isinstance(graph, Graph)
|
||||
assert isinstance(complex_graph, Graph)
|
||||
# Now we test the build method
|
||||
agent = graph.build()
|
||||
agent = complex_graph.build()
|
||||
# The agent should be a AgentExecutor
|
||||
assert isinstance(agent, AgentExecutor)
|
||||
|
||||
|
||||
def test_agent_node_build(basic_graph):
|
||||
agent_node = get_node_by_type(basic_graph, AgentNode)
|
||||
assert agent_node is not None
|
||||
built_object = agent_node.build()
|
||||
assert built_object is not None
|
||||
# Add any further assertions specific to the AgentNode's build() method
|
||||
|
||||
|
||||
def test_tool_node_build(basic_graph):
|
||||
tool_node = get_node_by_type(basic_graph, ToolNode)
|
||||
assert tool_node is not None
|
||||
built_object = tool_node.build()
|
||||
assert built_object is not None
|
||||
# Add any further assertions specific to the ToolNode's build() method
|
||||
|
||||
|
||||
def test_chain_node_build(complex_graph):
|
||||
chain_node = get_node_by_type(complex_graph, ChainNode)
|
||||
assert chain_node is not None
|
||||
built_object = chain_node.build()
|
||||
assert built_object is not None
|
||||
# Add any further assertions specific to the ChainNode's build() method
|
||||
|
||||
|
||||
def test_prompt_node_build(complex_graph):
|
||||
prompt_node = get_node_by_type(complex_graph, PromptNode)
|
||||
assert prompt_node is not None
|
||||
built_object = prompt_node.build()
|
||||
assert built_object is not None
|
||||
# Add any further assertions specific to the PromptNode's build() method
|
||||
|
||||
|
||||
def test_llm_node_build(basic_graph):
|
||||
llm_node = get_node_by_type(basic_graph, LLMNode)
|
||||
assert llm_node is not None
|
||||
built_object = llm_node.build()
|
||||
assert built_object is not None
|
||||
# Add any further assertions specific to the LLMNode's build() method
|
||||
|
||||
|
||||
def test_toolkit_node_build(openapi_graph):
|
||||
toolkit_node = get_node_by_type(openapi_graph, ToolkitNode)
|
||||
assert toolkit_node is not None
|
||||
built_object = toolkit_node.build()
|
||||
assert built_object is not None
|
||||
# Add any further assertions specific to the ToolkitNode's build() method
|
||||
|
||||
|
||||
def test_file_tool_node_build(openapi_graph):
|
||||
file_tool_node = get_node_by_type(openapi_graph, FileToolNode)
|
||||
assert file_tool_node is not None
|
||||
built_object = file_tool_node.build()
|
||||
assert built_object is not None
|
||||
# Add any further assertions specific to the FileToolNode's build() method
|
||||
|
||||
|
||||
def test_wrapper_node_build(openapi_graph):
|
||||
wrapper_node = get_node_by_type(openapi_graph, WrapperNode)
|
||||
assert wrapper_node is not None
|
||||
built_object = wrapper_node.build()
|
||||
assert built_object is not None
|
||||
# Add any further assertions specific to the WrapperNode's build() method
|
||||
|
|
|
|||
291
tests/test_template.py
Normal file
291
tests/test_template.py
Normal file
|
|
@ -0,0 +1,291 @@
|
|||
from langflow.utils.constants import CHAT_OPENAI_MODELS, OPENAI_MODELS
|
||||
from pydantic import BaseModel
|
||||
import pytest
|
||||
import re
|
||||
import importlib
|
||||
from typing import Dict, List, Optional
|
||||
from langflow.utils.util import (
|
||||
build_template_from_class,
|
||||
build_template_from_function,
|
||||
format_dict,
|
||||
get_base_classes,
|
||||
get_default_factory,
|
||||
get_class_doc,
|
||||
)
|
||||
|
||||
|
||||
# Dummy classes for testing purposes
|
||||
class Parent(BaseModel):
|
||||
"""Parent Class"""
|
||||
|
||||
parent_field: str
|
||||
|
||||
|
||||
class Child(Parent):
|
||||
"""Child Class"""
|
||||
|
||||
child_field: int
|
||||
|
||||
|
||||
class ExampleClass1(BaseModel):
|
||||
"""Example class 1."""
|
||||
|
||||
def __init__(self, data: Optional[List[int]] = None):
|
||||
self.data = data or [1, 2, 3]
|
||||
|
||||
|
||||
class ExampleClass2(BaseModel):
|
||||
"""Example class 2."""
|
||||
|
||||
def __init__(self, data: Optional[Dict[str, int]] = None):
|
||||
self.data = data or {"a": 1, "b": 2, "c": 3}
|
||||
|
||||
|
||||
def example_loader_1() -> ExampleClass1:
|
||||
"""Example loader function 1."""
|
||||
return ExampleClass1()
|
||||
|
||||
|
||||
def example_loader_2() -> ExampleClass2:
|
||||
"""Example loader function 2."""
|
||||
return ExampleClass2()
|
||||
|
||||
|
||||
def test_build_template_from_function():
|
||||
type_to_loader_dict = {
|
||||
"example1": example_loader_1,
|
||||
"example2": example_loader_2,
|
||||
}
|
||||
|
||||
# Test with valid name
|
||||
result = build_template_from_function("ExampleClass1", type_to_loader_dict)
|
||||
|
||||
assert "template" in result
|
||||
assert "description" in result
|
||||
assert "base_classes" in result
|
||||
|
||||
# Test with add_function=True
|
||||
result_with_function = build_template_from_function(
|
||||
"ExampleClass1", type_to_loader_dict, add_function=True
|
||||
)
|
||||
assert "function" in result_with_function["base_classes"]
|
||||
|
||||
# Test with invalid name
|
||||
with pytest.raises(ValueError, match=r".* not found"):
|
||||
build_template_from_function("NonExistent", type_to_loader_dict)
|
||||
|
||||
|
||||
# Test build_template_from_class
|
||||
def test_build_template_from_class():
|
||||
type_to_cls_dict: Dict[str, type] = {"parent": Parent, "child": Child}
|
||||
|
||||
# Test valid input
|
||||
result = build_template_from_class("Child", type_to_cls_dict)
|
||||
assert "template" in result
|
||||
assert "description" in result
|
||||
assert "base_classes" in result
|
||||
assert "Child" in result["base_classes"]
|
||||
assert "Parent" in result["base_classes"]
|
||||
assert result["description"] == "Child Class"
|
||||
|
||||
# Test invalid input
|
||||
with pytest.raises(ValueError, match="InvalidClass not found."):
|
||||
build_template_from_class("InvalidClass", type_to_cls_dict)
|
||||
|
||||
|
||||
# Test format_dict
|
||||
def test_format_dict():
|
||||
# Test 1: Optional type removal
|
||||
input_dict = {
|
||||
"field1": {"type": "Optional[str]", "required": False},
|
||||
}
|
||||
expected_output = {
|
||||
"field1": {
|
||||
"type": "str",
|
||||
"required": False,
|
||||
"list": False,
|
||||
"show": False,
|
||||
"password": False,
|
||||
"multiline": False,
|
||||
},
|
||||
}
|
||||
assert format_dict(input_dict) == expected_output
|
||||
|
||||
# Test 2: List type processing
|
||||
input_dict = {
|
||||
"field1": {"type": "List[str]", "required": False},
|
||||
}
|
||||
expected_output = {
|
||||
"field1": {
|
||||
"type": "str",
|
||||
"required": False,
|
||||
"list": True,
|
||||
"show": False,
|
||||
"password": False,
|
||||
"multiline": False,
|
||||
},
|
||||
}
|
||||
assert format_dict(input_dict) == expected_output
|
||||
|
||||
# Test 3: Mapping type replacement
|
||||
input_dict = {
|
||||
"field1": {"type": "Mapping[str, int]", "required": False},
|
||||
}
|
||||
expected_output = {
|
||||
"field1": {
|
||||
"type": "code", # Mapping type is replaced with dict which is replaced with code
|
||||
"required": False,
|
||||
"list": False,
|
||||
"show": False,
|
||||
"password": False,
|
||||
"multiline": False,
|
||||
},
|
||||
}
|
||||
assert format_dict(input_dict) == expected_output
|
||||
|
||||
# Test 4: Replace default value with actual value
|
||||
input_dict = {
|
||||
"field1": {"type": "str", "required": False, "default": "test"},
|
||||
}
|
||||
expected_output = {
|
||||
"field1": {
|
||||
"type": "str",
|
||||
"required": False,
|
||||
"list": False,
|
||||
"show": False,
|
||||
"password": False,
|
||||
"multiline": False,
|
||||
"value": "test",
|
||||
},
|
||||
}
|
||||
assert format_dict(input_dict) == expected_output
|
||||
|
||||
# Test 5: Add password field
|
||||
input_dict = {
|
||||
"field1": {"type": "str", "required": False},
|
||||
"api_key": {"type": "str", "required": False},
|
||||
}
|
||||
expected_output = {
|
||||
"field1": {
|
||||
"type": "str",
|
||||
"required": False,
|
||||
"list": False,
|
||||
"show": False,
|
||||
"password": False,
|
||||
"multiline": False,
|
||||
},
|
||||
"api_key": {
|
||||
"type": "str",
|
||||
"required": False,
|
||||
"list": False,
|
||||
"show": True,
|
||||
"password": True,
|
||||
"multiline": False,
|
||||
},
|
||||
}
|
||||
assert format_dict(input_dict) == expected_output
|
||||
|
||||
# Test 6: Add multiline
|
||||
input_dict = {
|
||||
"field1": {"type": "str", "required": False},
|
||||
"prefix": {"type": "str", "required": False},
|
||||
}
|
||||
expected_output = {
|
||||
"field1": {
|
||||
"type": "str",
|
||||
"required": False,
|
||||
"list": False,
|
||||
"show": False,
|
||||
"password": False,
|
||||
"multiline": False,
|
||||
},
|
||||
"prefix": {
|
||||
"type": "str",
|
||||
"required": False,
|
||||
"list": False,
|
||||
"show": True,
|
||||
"password": False,
|
||||
"multiline": True,
|
||||
},
|
||||
}
|
||||
assert format_dict(input_dict) == expected_output
|
||||
|
||||
# Test 7: Check class name-specific cases (OpenAI, OpenAIChat)
|
||||
input_dict = {
|
||||
"model_name": {"type": "str", "required": False},
|
||||
}
|
||||
expected_output_openai = {
|
||||
"model_name": {
|
||||
"type": "str",
|
||||
"required": False,
|
||||
"list": True,
|
||||
"show": True,
|
||||
"password": False,
|
||||
"multiline": False,
|
||||
"options": OPENAI_MODELS,
|
||||
},
|
||||
}
|
||||
expected_output_openai_chat = {
|
||||
"model_name": {
|
||||
"type": "str",
|
||||
"required": False,
|
||||
"list": True,
|
||||
"show": True,
|
||||
"password": False,
|
||||
"multiline": False,
|
||||
"options": CHAT_OPENAI_MODELS,
|
||||
},
|
||||
}
|
||||
assert format_dict(input_dict, "OpenAI") == expected_output_openai
|
||||
assert format_dict(input_dict, "OpenAIChat") == expected_output_openai_chat
|
||||
|
||||
# Test 8: Replace dict type with str
|
||||
input_dict = {
|
||||
"field1": {"type": "Dict[str, int]", "required": False},
|
||||
}
|
||||
expected_output = {
|
||||
"field1": {
|
||||
"type": "code",
|
||||
"required": False,
|
||||
"list": False,
|
||||
"show": False,
|
||||
"password": False,
|
||||
"multiline": False,
|
||||
},
|
||||
}
|
||||
assert format_dict(input_dict) == expected_output
|
||||
|
||||
|
||||
# Test get_base_classes
|
||||
def test_get_base_classes():
|
||||
base_classes_parent = get_base_classes(Parent)
|
||||
base_classes_child = get_base_classes(Child)
|
||||
|
||||
assert "Parent" in base_classes_parent
|
||||
assert "Child" in base_classes_child
|
||||
assert "Parent" in base_classes_child
|
||||
|
||||
|
||||
# Test get_default_factory
|
||||
def test_get_default_factory():
|
||||
module_name = "langflow.utils.util"
|
||||
function_repr = "<function dummy_function>"
|
||||
|
||||
def dummy_function():
|
||||
return "default_value"
|
||||
|
||||
# Add dummy_function to your_module
|
||||
setattr(importlib.import_module(module_name), "dummy_function", dummy_function)
|
||||
|
||||
default_value = get_default_factory(module_name, function_repr)
|
||||
|
||||
assert default_value == "default_value"
|
||||
|
||||
|
||||
# Test get_class_doc
|
||||
def test_get_class_doc():
|
||||
class_doc_parent = get_class_doc(Parent)
|
||||
class_doc_child = get_class_doc(Child)
|
||||
|
||||
assert class_doc_parent["Description"] == "Parent Class"
|
||||
assert class_doc_child["Description"] == "Child Class"
|
||||
Loading…
Add table
Add a link
Reference in a new issue