Merge branch 'toolkits' of https://github.com/logspace-ai/langflow into toolkits

This commit is contained in:
Ibis Prevedello 2023-04-01 16:19:19 -03:00
commit 0b760003b1
17 changed files with 1310 additions and 87 deletions

View file

49
src/backend/langflow/cache/utils.py vendored Normal file
View file

@ -0,0 +1,49 @@
import contextlib
import hashlib
import json
import os
from pathlib import Path
import tempfile
import dill
PREFIX = "langflow_cache"
def clear_old_cache_files(max_cache_size: int = 10):
cache_dir = Path(tempfile.gettempdir())
cache_files = list(cache_dir.glob(f"{PREFIX}_*.dill"))
if len(cache_files) > max_cache_size:
cache_files_sorted_by_mtime = sorted(
cache_files, key=lambda x: x.stat().st_mtime, reverse=True
)
for cache_file in cache_files_sorted_by_mtime[max_cache_size:]:
with contextlib.suppress(OSError):
os.remove(cache_file)
def remove_position_info(node):
node.pop("position", None)
def compute_hash(graph_data):
for node in graph_data["nodes"]:
remove_position_info(node)
cleaned_graph_json = json.dumps(graph_data, sort_keys=True)
return hashlib.sha256(cleaned_graph_json.encode("utf-8")).hexdigest()
def save_cache(hash_val, chat_data):
cache_path = Path(tempfile.gettempdir()) / f"{PREFIX}_{hash_val}.dill"
with cache_path.open("wb") as cache_file:
dill.dump(chat_data, cache_file)
def load_cache(hash_val):
cache_path = Path(tempfile.gettempdir()) / f"{PREFIX}_{hash_val}.dill"
if cache_path.exists():
with cache_path.open("rb") as cache_file:
return dill.load(cache_file)
return None

View file

@ -29,4 +29,8 @@ tools:
wrappers:
- RequestsWrapper
toolkits:
- OpenAPIToolkit
- JsonToolkit
dev: false

View file

@ -68,7 +68,9 @@ class PromptNode(Node):
)
self.params["tools"] = tools
prompt_params = [
key for key, value in self.params.items() if value["type"] == "str"
key
for key, value in self.params.items()
if isinstance(value, str) and key != "format_instructions"
]
else:
prompt_params = ["template"]

View file

@ -2,30 +2,48 @@ import contextlib
import io
import re
from typing import Any, Dict
from langflow.cache.utils import compute_hash, load_cache, save_cache
from langflow.graph.graph import Graph
from langflow.interface import loading
from langflow.utils import payload
def load_langchain_object(data_graph):
computed_hash = compute_hash(data_graph)
# Load langchain_object from cache if it exists
langchain_object = load_cache(computed_hash)
if langchain_object is None:
nodes = data_graph["nodes"]
# Add input variables
nodes = payload.extract_input_variables(nodes)
# Nodes, edges and root node
edges = data_graph["edges"]
graph = Graph(nodes, edges)
langchain_object = graph.build()
return computed_hash, langchain_object
def process_graph(data_graph: Dict[str, Any]):
"""
Process graph by extracting input variables and replacing ZeroShotPrompt
with PromptTemplate,then run the graph and return the result and thought.
"""
nodes = data_graph["nodes"]
# Add input variables
# ? Is this necessary?
nodes = payload.extract_input_variables(nodes)
# Nodes, edges and root node
edges = data_graph["edges"]
graph = Graph(nodes, edges)
langchain_object = graph.build()
# Load langchain object
computed_hash, langchain_object = load_langchain_object(data_graph)
message = data_graph["message"]
# Process json
# Generate result and thought
result, thought = get_result_and_thought_using_graph(langchain_object, message)
# Save langchain_object to cache
# We have to save it here because if the
# memory is updated we need to keep the new values
save_cache(computed_hash, langchain_object)
return {
"result": result,
"thought": re.sub(

View file

@ -4,6 +4,7 @@ from langchain.agents import agent_toolkits
from langflow.interface.base import LangChainTypeCreator
from langflow.interface.importing.utils import import_class, import_module
from langflow.settings import settings
from langflow.utils.util import build_template_from_class
@ -33,7 +34,7 @@ class ToolkitCreator(LangChainTypeCreator):
)
# if toolkit_name is not lower case it is a class
for toolkit_name in agent_toolkits.__all__
if not toolkit_name.islower()
if not toolkit_name.islower() and toolkit_name in settings.toolkits
}
return self.type_dict

View file

@ -112,7 +112,7 @@ class ToolCreator(LangChainTypeCreator):
# Copy the field and add the name
fields = []
for param in params:
field = TOOL_INPUTS.get(param, TOOL_INPUTS["str"])
field = TOOL_INPUTS.get(param, TOOL_INPUTS["str"]).copy()
field.name = param
if param == "aiosession":
field.show = False

View file

@ -13,6 +13,7 @@ class Settings(BaseSettings):
tools: List[str] = []
memories: List[str] = []
wrappers: List[str] = []
toolkits: List[str] = []
dev: bool = False
class Config:
@ -35,6 +36,7 @@ class Settings(BaseSettings):
self.tools = new_settings.tools or []
self.memories = new_settings.memories or []
self.wrappers = new_settings.wrappers or []
self.toolkits = new_settings.toolkits or []
self.dev = new_settings.dev or False

View file

@ -1,5 +1,6 @@
from abc import ABC
from typing import Any, Union
from typing import Any, Optional, Union, Dict
from langflow.utils import constants
from pydantic import BaseModel
@ -17,6 +18,7 @@ class TemplateFieldCreator(BaseModel, ABC):
file_types: list[str] = []
content: Union[str, None] = None
password: bool = False
options: list[str] = []
# _name will be used to store the name of the field
# in the template
name: str = ""
@ -37,6 +39,88 @@ class TemplateFieldCreator(BaseModel, ABC):
result["content"] = self.content
return result
def process_field(
self, key: str, value: Dict[str, Any], name: Optional[str] = None
) -> None:
_type = value["type"]
# Remove 'Optional' wrapper
if "Optional" in _type:
_type = _type.replace("Optional[", "")[:-1]
# Check for list type
if "List" in _type:
_type = _type.replace("List[", "")[:-1]
self.is_list = True
else:
self.is_list = False
# Replace 'Mapping' with 'dict'
if "Mapping" in _type:
_type = _type.replace("Mapping", "dict")
# Change type from str to Tool
self.field_type = "Tool" if key in ["allowed_tools"] else _type
self.field_type = "int" if key in ["max_value_length"] else self.field_type
# Show or not field
self.show = bool(
(self.required and key not in ["input_variables"])
or key
in [
"allowed_tools",
"memory",
"prefix",
"examples",
"temperature",
"model_name",
"headers",
"max_value_length",
]
or "api_key" in key
)
# Add password field
self.password = any(
text in key.lower() for text in ["password", "token", "api", "key"]
)
# Add multline
self.multiline = key in [
"suffix",
"prefix",
"template",
"examples",
"code",
"headers",
]
# Replace dict type with str
if "dict" in self.field_type.lower():
self.field_type = "code"
if key == "dict_":
self.field_type = "file"
self.suffixes = [".json", ".yaml", ".yml"]
self.file_types = ["json", "yaml", "yml"]
# Replace default value with actual value
if "default" in value:
self.value = value["default"]
if key == "headers":
self.value = """{'Authorization':
'Bearer <token>'}"""
# Add options to openai
if name == "OpenAI" and key == "model_name":
self.options = constants.OPENAI_MODELS
self.is_list = True
elif name == "OpenAIChat" and key == "model_name":
self.options = constants.CHAT_OPENAI_MODELS
self.is_list = True
class TemplateField(TemplateFieldCreator):
pass
@ -46,7 +130,13 @@ class Template(BaseModel):
type_name: str
fields: list[TemplateField]
def process_fields(self, name: Optional[str] = None) -> None:
for field in self.fields:
signature = field.to_dict()
field.process_field(field.name, signature, name)
def to_dict(self):
self.process_fields(self.type_name)
result = {field.name: field.to_dict() for field in self.fields}
result["_type"] = self.type_name # type: ignore
return result

View file

@ -98,7 +98,7 @@ def build_template_from_function(
return {
"template": format_dict(variables, name),
"description": docs["Description"],
"base_classes": get_base_classes(_class),
"base_classes": base_classes,
}
@ -173,7 +173,7 @@ def get_base_classes(cls):
result = [cls.__name__]
if not result:
result = [cls.__name__]
return list(set(result))
return list(set(result + [cls.__name__]))
def get_default_factory(module: str, function: str):
@ -333,8 +333,10 @@ def format_dict(d, name: Optional[str] = None):
# Add options to openai
if name == "OpenAI" and key == "model_name":
value["options"] = constants.OPENAI_MODELS
value["list"] = True
elif name == "OpenAIChat" and key == "model_name":
value["options"] = constants.CHAT_OPENAI_MODELS
value["list"] = True
return d

View file

@ -11,6 +11,9 @@ def pytest_configure():
pytest.COMPLEX_EXAMPLE_PATH = (
Path(__file__).parent.absolute() / "data" / "complex_example.json"
)
pytest.OPENAPI_EXAMPLE_PATH = (
Path(__file__).parent.absolute() / "data" / "Openapi.json"
)
pytest.CODE_WITH_SYNTAX_ERROR = """
def get_text():

445
tests/data/Openapi.json Normal file

File diff suppressed because one or more lines are too long

64
tests/test_cache.py Normal file
View file

@ -0,0 +1,64 @@
import json
import hashlib
from pathlib import Path
import dill
import tempfile
from langflow.cache.utils import compute_hash, load_cache, save_cache, PREFIX
from langflow.interface.run import load_langchain_object, process_graph
import pytest
def get_graph(_type="basic"):
"""Get a graph from a json file"""
if _type == "basic":
path = pytest.BASIC_EXAMPLE_PATH
elif _type == "complex":
path = pytest.COMPLEX_EXAMPLE_PATH
elif _type == "openapi":
path = pytest.OPENAPI_EXAMPLE_PATH
with open(path, "r") as f:
flow_graph = json.load(f)
return flow_graph["data"]
@pytest.fixture
def basic_data_graph():
return get_graph()
@pytest.fixture
def complex_data_graph():
return get_graph("complex")
@pytest.fixture
def openapi_data_graph():
return get_graph("openapi")
def langchain_objects_are_equal(obj1, obj2):
return str(obj1) == str(obj2)
def test_cache_creation(basic_data_graph):
# Compute hash for the input data_graph
computed_hash = compute_hash(basic_data_graph)
# Call process_graph function to build and cache the langchain_object
_ = load_langchain_object(basic_data_graph)
# Check if the cache file exists
cache_file = Path(tempfile.gettempdir()) / f"{PREFIX}_{computed_hash}.dill"
assert cache_file.exists()
def test_cache_reuse(basic_data_graph):
# Call process_graph function to build and cache the langchain_object
result1 = load_langchain_object(basic_data_graph)
# Call process_graph function again to use the cached langchain_object
result2 = load_langchain_object(basic_data_graph)
# Compare the results to ensure the same langchain_object was used
assert langchain_objects_are_equal(result1, result2)

49
tests/test_creators.py Normal file
View file

@ -0,0 +1,49 @@
from typing import Dict, List
from langflow.interface.base import LangChainTypeCreator
from langflow.interface.agents.base import AgentCreator
import pytest
@pytest.fixture
def sample_lang_chain_type_creator() -> LangChainTypeCreator:
class SampleLangChainTypeCreator(LangChainTypeCreator):
type_name: str = "test_type"
def type_to_loader_dict(self) -> Dict:
return {"test_type": "TestClass"}
def to_list(self) -> List[str]:
return ["node1", "node2"]
def get_signature(self, name: str) -> Dict:
return {
"template": {"test_field": {"type": "str"}},
"description": "test description",
"base_classes": ["base_class1", "base_class2"],
}
return SampleLangChainTypeCreator()
@pytest.fixture
def sample_agent_creator() -> AgentCreator:
return AgentCreator()
def test_lang_chain_type_creator_to_dict(
sample_lang_chain_type_creator: LangChainTypeCreator,
):
type_dict = sample_lang_chain_type_creator.to_dict()
assert len(type_dict) == 1
assert "test_type" in type_dict
assert "node1" in type_dict["test_type"]
assert "node2" in type_dict["test_type"]
assert "template" in type_dict["test_type"]["node1"]
assert "description" in type_dict["test_type"]["node1"]
assert "base_classes" in type_dict["test_type"]["node1"]
def test_agent_creator_type_to_loader_dict(sample_agent_creator: AgentCreator):
type_to_loader_dict = sample_agent_creator.type_to_loader_dict
assert len(type_to_loader_dict) > 0
assert "JsonAgent"

View file

@ -0,0 +1,60 @@
import pytest
from typing import Dict, List
from langflow.template.base import TemplateField, FrontendNode, Template
from langflow.interface.base import LangChainTypeCreator
from langflow.interface.agents.base import AgentCreator
@pytest.fixture
def sample_template_field() -> TemplateField:
return TemplateField(name="test_field", field_type="str")
@pytest.fixture
def sample_template(sample_template_field: TemplateField) -> Template:
return Template(type_name="test_template", fields=[sample_template_field])
@pytest.fixture
def sample_frontend_node(sample_template: Template) -> FrontendNode:
return FrontendNode(
template=sample_template,
description="test description",
base_classes=["base_class1", "base_class2"],
name="test_frontend_node",
)
def test_template_field_defaults(sample_template_field: TemplateField):
assert sample_template_field.field_type == "str"
assert sample_template_field.required == False
assert sample_template_field.placeholder == ""
assert sample_template_field.is_list == False
assert sample_template_field.show == True
assert sample_template_field.multiline == False
assert sample_template_field.value == None
assert sample_template_field.suffixes == []
assert sample_template_field.file_types == []
assert sample_template_field.content == None
assert sample_template_field.password == False
assert sample_template_field.name == "test_field"
def test_template_to_dict(
sample_template: Template, sample_template_field: TemplateField
):
template_dict = sample_template.to_dict()
assert template_dict["_type"] == "test_template"
assert len(template_dict) == 2 # _type and test_field
assert "test_field" in template_dict
assert "type" in template_dict["test_field"]
assert "required" in template_dict["test_field"]
def test_frontend_node_to_dict(sample_frontend_node: FrontendNode):
node_dict = sample_frontend_node.to_dict()
assert len(node_dict) == 1
assert "test_frontend_node" in node_dict
assert "description" in node_dict["test_frontend_node"]
assert "template" in node_dict["test_frontend_node"]
assert "base_classes" in node_dict["test_frontend_node"]

View file

@ -1,16 +1,36 @@
import json
from langflow.graph.nodes import (
WrapperNode,
AgentNode,
ToolNode,
ChainNode,
PromptNode,
LLMNode,
ToolkitNode,
FileToolNode,
)
import pytest
from langchain.agents import AgentExecutor
from langflow.graph import Edge, Graph, Node
from langflow.utils.payload import build_json, get_root_node
# Test cases for the graph module
# now we have three types of graph:
# BASIC_EXAMPLE_PATH, COMPLEX_EXAMPLE_PATH, OPENAPI_EXAMPLE_PATH
def get_graph(basic=True):
def get_graph(_type="basic"):
"""Get a graph from a json file"""
path = pytest.BASIC_EXAMPLE_PATH if basic else pytest.COMPLEX_EXAMPLE_PATH
if _type == "basic":
path = pytest.BASIC_EXAMPLE_PATH
elif _type == "complex":
path = pytest.COMPLEX_EXAMPLE_PATH
elif _type == "openapi":
path = pytest.OPENAPI_EXAMPLE_PATH
with open(path, "r") as f:
flow_graph = json.load(f)
data_graph = flow_graph["data"]
@ -19,26 +39,94 @@ def get_graph(basic=True):
return Graph(nodes, edges)
def test_get_nodes_with_target():
@pytest.fixture
def basic_graph():
return get_graph()
@pytest.fixture
def complex_graph():
return get_graph("complex")
@pytest.fixture
def openapi_graph():
return get_graph("openapi")
def get_node_by_type(graph, node_type):
"""Get a node by type"""
return next((node for node in graph.nodes if isinstance(node, node_type)), None)
def test_graph_structure(basic_graph):
assert isinstance(basic_graph, Graph)
assert len(basic_graph.nodes) > 0
assert len(basic_graph.edges) > 0
for node in basic_graph.nodes:
assert isinstance(node, Node)
for edge in basic_graph.edges:
assert isinstance(edge, Edge)
assert edge.source in basic_graph.nodes
assert edge.target in basic_graph.nodes
def test_circular_dependencies(basic_graph):
assert isinstance(basic_graph, Graph)
def check_circular(node, visited):
visited.add(node)
neighbors = basic_graph.get_nodes_with_target(node)
for neighbor in neighbors:
if neighbor in visited:
return True
if check_circular(neighbor, visited.copy()):
return True
return False
for node in basic_graph.nodes:
assert not check_circular(node, set())
def test_invalid_node_types():
graph_data = {
"nodes": [
{
"id": "1",
"data": {
"node": {
"base_classes": ["BaseClass"],
"template": {
"_type": "InvalidNodeType",
},
},
},
},
],
"edges": [],
}
with pytest.raises(Exception):
Graph(graph_data["nodes"], graph_data["edges"])
def test_get_nodes_with_target(basic_graph):
"""Test getting connected nodes"""
graph = get_graph()
assert isinstance(graph, Graph)
assert isinstance(basic_graph, Graph)
# Get root node
root = get_root_node(graph)
root = get_root_node(basic_graph)
assert root is not None
connected_nodes = graph.get_nodes_with_target(root)
connected_nodes = basic_graph.get_nodes_with_target(root)
assert connected_nodes is not None
def test_get_node_neighbors_basic():
def test_get_node_neighbors_basic(basic_graph):
"""Test getting node neighbors"""
graph = get_graph(basic=True)
assert isinstance(graph, Graph)
assert isinstance(basic_graph, Graph)
# Get root node
root = get_root_node(graph)
root = get_root_node(basic_graph)
assert root is not None
neighbors = graph.get_node_neighbors(root)
neighbors = basic_graph.get_node_neighbors(root)
assert neighbors is not None
assert isinstance(neighbors, dict)
# Root Node is an Agent, it requires an LLMChain and tools
@ -57,7 +145,7 @@ def test_get_node_neighbors_basic():
for neighbor, val in neighbors.items()
if "Chain" in neighbor.data["type"] and val
)
chain_neighbors = graph.get_node_neighbors(chain)
chain_neighbors = basic_graph.get_node_neighbors(chain)
assert chain_neighbors is not None
assert isinstance(chain_neighbors, dict)
# Check if there is a LLM in the chain's neighbors
@ -74,15 +162,13 @@ def test_get_node_neighbors_basic():
)
def test_get_node_neighbors_complex():
def test_get_node_neighbors_complex(complex_graph):
"""Test getting node neighbors"""
graph = get_graph(basic=False)
assert isinstance(graph, Graph)
assert isinstance(complex_graph, Graph)
# Get root node
root = get_root_node(graph)
root = get_root_node(complex_graph)
assert root is not None
neighbors = graph.get_nodes_with_target(root)
neighbors = complex_graph.get_nodes_with_target(root)
assert neighbors is not None
# Neighbors should be a list of nodes
assert isinstance(neighbors, list)
@ -93,7 +179,7 @@ def test_get_node_neighbors_complex():
assert any("Tool" in neighbor.data["type"] for neighbor in neighbors)
# Now on to the Chain's neighbors
chain = next(neighbor for neighbor in neighbors if "Chain" in neighbor.data["type"])
chain_neighbors = graph.get_nodes_with_target(chain)
chain_neighbors = complex_graph.get_nodes_with_target(chain)
assert chain_neighbors is not None
# Check if there is a LLM in the chain's neighbors
assert any("OpenAI" in neighbor.data["type"] for neighbor in chain_neighbors)
@ -101,7 +187,7 @@ def test_get_node_neighbors_complex():
assert any("Prompt" in neighbor.data["type"] for neighbor in chain_neighbors)
# Now on to the Tool's neighbors
tool = next(neighbor for neighbor in neighbors if "Tool" in neighbor.data["type"])
tool_neighbors = graph.get_nodes_with_target(tool)
tool_neighbors = complex_graph.get_nodes_with_target(tool)
assert tool_neighbors is not None
# Check if there is an Agent in the tool's neighbors
assert any("Agent" in neighbor.data["type"] for neighbor in tool_neighbors)
@ -109,7 +195,7 @@ def test_get_node_neighbors_complex():
agent = next(
neighbor for neighbor in tool_neighbors if "Agent" in neighbor.data["type"]
)
agent_neighbors = graph.get_nodes_with_target(agent)
agent_neighbors = complex_graph.get_nodes_with_target(agent)
assert agent_neighbors is not None
# Check if there is a Tool in the agent's neighbors
assert any("Tool" in neighbor.data["type"] for neighbor in agent_neighbors)
@ -117,62 +203,57 @@ def test_get_node_neighbors_complex():
tool = next(
neighbor for neighbor in agent_neighbors if "Tool" in neighbor.data["type"]
)
tool_neighbors = graph.get_nodes_with_target(tool)
tool_neighbors = complex_graph.get_nodes_with_target(tool)
assert tool_neighbors is not None
# Check if there is a PythonFunction in the tool's neighbors
assert any("PythonFunction" in neighbor.data["type"] for neighbor in tool_neighbors)
def test_get_node():
def test_get_node(basic_graph):
"""Test getting a single node"""
graph = get_graph()
node_id = graph.nodes[0].id
node = graph.get_node(node_id)
node_id = basic_graph.nodes[0].id
node = basic_graph.get_node(node_id)
assert isinstance(node, Node)
assert node.id == node_id
def test_build_nodes():
def test_build_nodes(basic_graph):
"""Test building nodes"""
graph = get_graph()
assert len(graph.nodes) == len(graph._nodes)
for node in graph.nodes:
assert len(basic_graph.nodes) == len(basic_graph._nodes)
for node in basic_graph.nodes:
assert isinstance(node, Node)
def test_build_edges():
def test_build_edges(basic_graph):
"""Test building edges"""
graph = get_graph()
assert len(graph.edges) == len(graph._edges)
for edge in graph.edges:
assert len(basic_graph.edges) == len(basic_graph._edges)
for edge in basic_graph.edges:
assert isinstance(edge, Edge)
assert isinstance(edge.source, Node)
assert isinstance(edge.target, Node)
def test_get_root_node():
def test_get_root_node(basic_graph, complex_graph):
"""Test getting root node"""
graph = get_graph(basic=True)
assert isinstance(graph, Graph)
root = get_root_node(graph)
assert isinstance(basic_graph, Graph)
root = get_root_node(basic_graph)
assert root is not None
assert isinstance(root, Node)
assert root.data["type"] == "ZeroShotAgent"
# For complex example, the root node is a ZeroShotAgent too
graph = get_graph(basic=False)
assert isinstance(graph, Graph)
root = get_root_node(graph)
assert isinstance(complex_graph, Graph)
root = get_root_node(complex_graph)
assert root is not None
assert isinstance(root, Node)
assert root.data["type"] == "ZeroShotAgent"
def test_build_json():
def test_build_json(basic_graph):
"""Test building JSON from graph"""
graph = get_graph()
assert isinstance(graph, Graph)
root = get_root_node(graph)
json_data = build_json(root, graph)
assert isinstance(basic_graph, Graph)
root = get_root_node(basic_graph)
json_data = build_json(root, basic_graph)
assert isinstance(json_data, dict)
assert json_data["_type"] == "zero-shot-react-description"
assert isinstance(json_data["llm_chain"], dict)
@ -188,38 +269,37 @@ def test_build_json():
assert all(isinstance(val, str) for val in json_data["return_values"])
def test_validate_edges():
def test_validate_edges(basic_graph):
"""Test validating edges"""
graph = get_graph()
assert isinstance(graph, Graph)
assert isinstance(basic_graph, Graph)
# all edges should be valid
assert all(edge.valid for edge in graph.edges)
assert all(edge.valid for edge in basic_graph.edges)
def test_matched_type():
def test_matched_type(basic_graph):
"""Test matched type attribute in Edge"""
graph = get_graph()
assert isinstance(graph, Graph)
assert isinstance(basic_graph, Graph)
# all edges should be valid
assert all(edge.valid for edge in graph.edges)
assert all(edge.valid for edge in basic_graph.edges)
# all edges should have a matched_type attribute
assert all(hasattr(edge, "matched_type") for edge in graph.edges)
assert all(hasattr(edge, "matched_type") for edge in basic_graph.edges)
# The matched_type attribute should be in the source_types attr
assert all(edge.matched_type in edge.source_types for edge in graph.edges)
assert all(edge.matched_type in edge.source_types for edge in basic_graph.edges)
def test_build_params():
def test_build_params(basic_graph):
"""Test building params"""
graph = get_graph()
assert isinstance(graph, Graph)
assert isinstance(basic_graph, Graph)
# all edges should be valid
assert all(edge.valid for edge in graph.edges)
assert all(edge.valid for edge in basic_graph.edges)
# all edges should have a matched_type attribute
assert all(hasattr(edge, "matched_type") for edge in graph.edges)
assert all(hasattr(edge, "matched_type") for edge in basic_graph.edges)
# The matched_type attribute should be in the source_types attr
assert all(edge.matched_type in edge.source_types for edge in graph.edges)
assert all(edge.matched_type in edge.source_types for edge in basic_graph.edges)
# Get the root node
root = get_root_node(graph)
root = get_root_node(basic_graph)
# Root node is a ZeroShotAgent
# which requires an llm_chain, allowed_tools and return_values
assert isinstance(root.params, dict)
@ -261,7 +341,7 @@ def test_build_params():
assert isinstance(llm_node.params["model_name"], str)
def test_build():
def test_build(basic_graph, complex_graph):
"""Test Node's build method"""
# def build(self):
# # The params dict is used to build the module
@ -284,18 +364,81 @@ def test_build():
# # and instantiate it with the params
# # and return the instance
# return LANGCHAIN_TYPES_DICT[self.node_type](**self.params)
graph = get_graph()
assert isinstance(graph, Graph)
assert isinstance(basic_graph, Graph)
# Now we test the build method
# Build the Agent
agent = graph.build()
agent = basic_graph.build()
# The agent should be a AgentExecutor
assert isinstance(agent, AgentExecutor)
# Now we test the complex example
graph = get_graph(basic=False)
assert isinstance(graph, Graph)
assert isinstance(complex_graph, Graph)
# Now we test the build method
agent = graph.build()
agent = complex_graph.build()
# The agent should be a AgentExecutor
assert isinstance(agent, AgentExecutor)
def test_agent_node_build(basic_graph):
agent_node = get_node_by_type(basic_graph, AgentNode)
assert agent_node is not None
built_object = agent_node.build()
assert built_object is not None
# Add any further assertions specific to the AgentNode's build() method
def test_tool_node_build(basic_graph):
tool_node = get_node_by_type(basic_graph, ToolNode)
assert tool_node is not None
built_object = tool_node.build()
assert built_object is not None
# Add any further assertions specific to the ToolNode's build() method
def test_chain_node_build(complex_graph):
chain_node = get_node_by_type(complex_graph, ChainNode)
assert chain_node is not None
built_object = chain_node.build()
assert built_object is not None
# Add any further assertions specific to the ChainNode's build() method
def test_prompt_node_build(complex_graph):
prompt_node = get_node_by_type(complex_graph, PromptNode)
assert prompt_node is not None
built_object = prompt_node.build()
assert built_object is not None
# Add any further assertions specific to the PromptNode's build() method
def test_llm_node_build(basic_graph):
llm_node = get_node_by_type(basic_graph, LLMNode)
assert llm_node is not None
built_object = llm_node.build()
assert built_object is not None
# Add any further assertions specific to the LLMNode's build() method
def test_toolkit_node_build(openapi_graph):
toolkit_node = get_node_by_type(openapi_graph, ToolkitNode)
assert toolkit_node is not None
built_object = toolkit_node.build()
assert built_object is not None
# Add any further assertions specific to the ToolkitNode's build() method
def test_file_tool_node_build(openapi_graph):
file_tool_node = get_node_by_type(openapi_graph, FileToolNode)
assert file_tool_node is not None
built_object = file_tool_node.build()
assert built_object is not None
# Add any further assertions specific to the FileToolNode's build() method
def test_wrapper_node_build(openapi_graph):
wrapper_node = get_node_by_type(openapi_graph, WrapperNode)
assert wrapper_node is not None
built_object = wrapper_node.build()
assert built_object is not None
# Add any further assertions specific to the WrapperNode's build() method

291
tests/test_template.py Normal file
View file

@ -0,0 +1,291 @@
from langflow.utils.constants import CHAT_OPENAI_MODELS, OPENAI_MODELS
from pydantic import BaseModel
import pytest
import re
import importlib
from typing import Dict, List, Optional
from langflow.utils.util import (
build_template_from_class,
build_template_from_function,
format_dict,
get_base_classes,
get_default_factory,
get_class_doc,
)
# Dummy classes for testing purposes
class Parent(BaseModel):
"""Parent Class"""
parent_field: str
class Child(Parent):
"""Child Class"""
child_field: int
class ExampleClass1(BaseModel):
"""Example class 1."""
def __init__(self, data: Optional[List[int]] = None):
self.data = data or [1, 2, 3]
class ExampleClass2(BaseModel):
"""Example class 2."""
def __init__(self, data: Optional[Dict[str, int]] = None):
self.data = data or {"a": 1, "b": 2, "c": 3}
def example_loader_1() -> ExampleClass1:
"""Example loader function 1."""
return ExampleClass1()
def example_loader_2() -> ExampleClass2:
"""Example loader function 2."""
return ExampleClass2()
def test_build_template_from_function():
type_to_loader_dict = {
"example1": example_loader_1,
"example2": example_loader_2,
}
# Test with valid name
result = build_template_from_function("ExampleClass1", type_to_loader_dict)
assert "template" in result
assert "description" in result
assert "base_classes" in result
# Test with add_function=True
result_with_function = build_template_from_function(
"ExampleClass1", type_to_loader_dict, add_function=True
)
assert "function" in result_with_function["base_classes"]
# Test with invalid name
with pytest.raises(ValueError, match=r".* not found"):
build_template_from_function("NonExistent", type_to_loader_dict)
# Test build_template_from_class
def test_build_template_from_class():
type_to_cls_dict: Dict[str, type] = {"parent": Parent, "child": Child}
# Test valid input
result = build_template_from_class("Child", type_to_cls_dict)
assert "template" in result
assert "description" in result
assert "base_classes" in result
assert "Child" in result["base_classes"]
assert "Parent" in result["base_classes"]
assert result["description"] == "Child Class"
# Test invalid input
with pytest.raises(ValueError, match="InvalidClass not found."):
build_template_from_class("InvalidClass", type_to_cls_dict)
# Test format_dict
def test_format_dict():
# Test 1: Optional type removal
input_dict = {
"field1": {"type": "Optional[str]", "required": False},
}
expected_output = {
"field1": {
"type": "str",
"required": False,
"list": False,
"show": False,
"password": False,
"multiline": False,
},
}
assert format_dict(input_dict) == expected_output
# Test 2: List type processing
input_dict = {
"field1": {"type": "List[str]", "required": False},
}
expected_output = {
"field1": {
"type": "str",
"required": False,
"list": True,
"show": False,
"password": False,
"multiline": False,
},
}
assert format_dict(input_dict) == expected_output
# Test 3: Mapping type replacement
input_dict = {
"field1": {"type": "Mapping[str, int]", "required": False},
}
expected_output = {
"field1": {
"type": "code", # Mapping type is replaced with dict which is replaced with code
"required": False,
"list": False,
"show": False,
"password": False,
"multiline": False,
},
}
assert format_dict(input_dict) == expected_output
# Test 4: Replace default value with actual value
input_dict = {
"field1": {"type": "str", "required": False, "default": "test"},
}
expected_output = {
"field1": {
"type": "str",
"required": False,
"list": False,
"show": False,
"password": False,
"multiline": False,
"value": "test",
},
}
assert format_dict(input_dict) == expected_output
# Test 5: Add password field
input_dict = {
"field1": {"type": "str", "required": False},
"api_key": {"type": "str", "required": False},
}
expected_output = {
"field1": {
"type": "str",
"required": False,
"list": False,
"show": False,
"password": False,
"multiline": False,
},
"api_key": {
"type": "str",
"required": False,
"list": False,
"show": True,
"password": True,
"multiline": False,
},
}
assert format_dict(input_dict) == expected_output
# Test 6: Add multiline
input_dict = {
"field1": {"type": "str", "required": False},
"prefix": {"type": "str", "required": False},
}
expected_output = {
"field1": {
"type": "str",
"required": False,
"list": False,
"show": False,
"password": False,
"multiline": False,
},
"prefix": {
"type": "str",
"required": False,
"list": False,
"show": True,
"password": False,
"multiline": True,
},
}
assert format_dict(input_dict) == expected_output
# Test 7: Check class name-specific cases (OpenAI, OpenAIChat)
input_dict = {
"model_name": {"type": "str", "required": False},
}
expected_output_openai = {
"model_name": {
"type": "str",
"required": False,
"list": True,
"show": True,
"password": False,
"multiline": False,
"options": OPENAI_MODELS,
},
}
expected_output_openai_chat = {
"model_name": {
"type": "str",
"required": False,
"list": True,
"show": True,
"password": False,
"multiline": False,
"options": CHAT_OPENAI_MODELS,
},
}
assert format_dict(input_dict, "OpenAI") == expected_output_openai
assert format_dict(input_dict, "OpenAIChat") == expected_output_openai_chat
# Test 8: Replace dict type with str
input_dict = {
"field1": {"type": "Dict[str, int]", "required": False},
}
expected_output = {
"field1": {
"type": "code",
"required": False,
"list": False,
"show": False,
"password": False,
"multiline": False,
},
}
assert format_dict(input_dict) == expected_output
# Test get_base_classes
def test_get_base_classes():
base_classes_parent = get_base_classes(Parent)
base_classes_child = get_base_classes(Child)
assert "Parent" in base_classes_parent
assert "Child" in base_classes_child
assert "Parent" in base_classes_child
# Test get_default_factory
def test_get_default_factory():
module_name = "langflow.utils.util"
function_repr = "<function dummy_function>"
def dummy_function():
return "default_value"
# Add dummy_function to your_module
setattr(importlib.import_module(module_name), "dummy_function", dummy_function)
default_value = get_default_factory(module_name, function_repr)
assert default_value == "default_value"
# Test get_class_doc
def test_get_class_doc():
class_doc_parent = get_class_doc(Parent)
class_doc_child = get_class_doc(Child)
assert class_doc_parent["Description"] == "Parent Class"
assert class_doc_child["Description"] == "Child Class"