diff --git a/src/backend/langflow/interface/loading.py b/src/backend/langflow/interface/loading.py index bfb919f5d..1c73368e9 100644 --- a/src/backend/langflow/interface/loading.py +++ b/src/backend/langflow/interface/loading.py @@ -52,17 +52,13 @@ def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any: def load_flow_from_json(path: str): + # This is done to avoid circular imports + from langflow.graph.graph import Graph + """Load flow from json file""" with open(path, "r") as f: flow_graph = json.load(f) data_graph = flow_graph["data"] - extracted_json = extract_json(data_graph) - return load_langchain_type_from_config(config=extracted_json) - - -def extract_json(data_graph): - from langflow.graph.graph import Graph - nodes = data_graph["nodes"] # Substitute ZeroShotPrompt with PromptTemplate nodes = replace_zero_shot_prompt_with_prompt_template(nodes) @@ -71,8 +67,7 @@ def extract_json(data_graph): # Nodes, edges and root node edges = data_graph["edges"] graph = Graph(nodes, edges) - root = payload.get_root_node(graph) - return payload.build_json(root, graph) + return graph.build() def replace_zero_shot_prompt_with_prompt_template(nodes): diff --git a/tests/test_loading.py b/tests/test_loading.py index b5742462f..a824ec4e5 100644 --- a/tests/test_loading.py +++ b/tests/test_loading.py @@ -1,27 +1,17 @@ import json -from langchain import LLMChain, OpenAI from langflow.graph.graph import Graph import pytest from langflow import load_flow_from_json -from langflow.interface.loading import extract_json -from langflow.utils.payload import get_root_node, build_json -from langflow.interface.loading import load_langchain_type_from_config +from langflow.utils.payload import get_root_node +from langchain.agents import AgentExecutor def test_load_flow_from_json(): """Test loading a flow from a json file""" loaded = load_flow_from_json(pytest.BASIC_EXAMPLE_PATH) assert loaded is not None - - -def test_extract_json(): - with open(pytest.BASIC_EXAMPLE_PATH, "r") as f: - flow_graph = json.load(f) - data_graph = flow_graph["data"] - extracted = extract_json(data_graph) - assert extracted is not None - assert isinstance(extracted, dict) + assert isinstance(loaded, AgentExecutor) def test_get_root_node(): @@ -35,83 +25,3 @@ def test_get_root_node(): assert root is not None assert hasattr(root, "id") assert hasattr(root, "data") - - -def test_build_json(): - with open(pytest.BASIC_EXAMPLE_PATH, "r") as f: - flow_graph = json.load(f) - data_graph = flow_graph["data"] - nodes = data_graph["nodes"] - edges = data_graph["edges"] - graph = Graph(nodes, edges) - root = get_root_node(graph) - built_json = build_json(root, graph) - assert built_json is not None - assert isinstance(built_json, dict) - - -def test_build_json_missing_child(): - with open(pytest.BASIC_EXAMPLE_PATH, "r") as f: - flow_graph = json.load(f) - data_graph = flow_graph["data"] - nodes = data_graph["nodes"] - edges = data_graph["edges"] - - # Modify nodes to create a missing required child scenario - for node in nodes: - if "data" in node and "node" in node["data"]: - for key, value in node["data"]["node"]["template"].items(): - if isinstance(value, dict) and "required" in value: - value["required"] = True - - with pytest.raises(ValueError): - graph = Graph(nodes, edges) - root = get_root_node(graph) - build_json(root, graph) - - -def test_build_json_no_nodes(): - with pytest.raises(TypeError): - build_json(None, [], []) - - -def test_build_json_invalid_edge(): - with open(pytest.BASIC_EXAMPLE_PATH, "r") as f: - flow_graph = json.load(f) - data_graph = flow_graph["data"] - nodes = data_graph["nodes"] - edges = data_graph["edges"] - # Modify edges to create an invalid edge scenario - for edge in edges: - edge["source"] = "invalid_id" - - with pytest.raises(ValueError): - graph = Graph(nodes, edges) - root = get_root_node(graph) - build_json(root, nodes, edges) - - -def test_load_langchain_type_from_config(): - with open(pytest.BASIC_EXAMPLE_PATH, "r") as f: - flow_graph = json.load(f) - data_graph = flow_graph["data"] - extracted = extract_json(data_graph) - - agent_config = extracted.copy() - agent_type = "AgentExecutor" # Replace with the actual agent type in the JSON - - invalid_config = extracted.copy() - invalid_config["_type"] = "invalid_type" - - agent_loaded = load_langchain_type_from_config(agent_config) - assert agent_loaded is not None - assert agent_loaded.__class__.__name__ == agent_type - assert hasattr(agent_loaded.agent, "llm_chain") - assert isinstance( - agent_loaded.agent.llm_chain, LLMChain - ) # Replace Chain with the appropriate class - assert hasattr(agent_loaded.agent.llm_chain, "llm") - assert isinstance(agent_loaded.agent.llm_chain.llm, OpenAI) - - with pytest.raises(ValueError): - load_langchain_type_from_config(invalid_config)