fix: load_flow_from_json now uses the Graph
This commit is contained in:
parent
64bc7c40ed
commit
86e65575ee
2 changed files with 7 additions and 102 deletions
|
|
@ -52,17 +52,13 @@ def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any:
|
|||
|
||||
|
||||
def load_flow_from_json(path: str):
|
||||
# This is done to avoid circular imports
|
||||
from langflow.graph.graph import Graph
|
||||
|
||||
"""Load flow from json file"""
|
||||
with open(path, "r") as f:
|
||||
flow_graph = json.load(f)
|
||||
data_graph = flow_graph["data"]
|
||||
extracted_json = extract_json(data_graph)
|
||||
return load_langchain_type_from_config(config=extracted_json)
|
||||
|
||||
|
||||
def extract_json(data_graph):
|
||||
from langflow.graph.graph import Graph
|
||||
|
||||
nodes = data_graph["nodes"]
|
||||
# Substitute ZeroShotPrompt with PromptTemplate
|
||||
nodes = replace_zero_shot_prompt_with_prompt_template(nodes)
|
||||
|
|
@ -71,8 +67,7 @@ def extract_json(data_graph):
|
|||
# Nodes, edges and root node
|
||||
edges = data_graph["edges"]
|
||||
graph = Graph(nodes, edges)
|
||||
root = payload.get_root_node(graph)
|
||||
return payload.build_json(root, graph)
|
||||
return graph.build()
|
||||
|
||||
|
||||
def replace_zero_shot_prompt_with_prompt_template(nodes):
|
||||
|
|
|
|||
|
|
@ -1,27 +1,17 @@
|
|||
import json
|
||||
from langchain import LLMChain, OpenAI
|
||||
from langflow.graph.graph import Graph
|
||||
import pytest
|
||||
|
||||
from langflow import load_flow_from_json
|
||||
from langflow.interface.loading import extract_json
|
||||
from langflow.utils.payload import get_root_node, build_json
|
||||
from langflow.interface.loading import load_langchain_type_from_config
|
||||
from langflow.utils.payload import get_root_node
|
||||
from langchain.agents import AgentExecutor
|
||||
|
||||
|
||||
def test_load_flow_from_json():
|
||||
"""Test loading a flow from a json file"""
|
||||
loaded = load_flow_from_json(pytest.BASIC_EXAMPLE_PATH)
|
||||
assert loaded is not None
|
||||
|
||||
|
||||
def test_extract_json():
|
||||
with open(pytest.BASIC_EXAMPLE_PATH, "r") as f:
|
||||
flow_graph = json.load(f)
|
||||
data_graph = flow_graph["data"]
|
||||
extracted = extract_json(data_graph)
|
||||
assert extracted is not None
|
||||
assert isinstance(extracted, dict)
|
||||
assert isinstance(loaded, AgentExecutor)
|
||||
|
||||
|
||||
def test_get_root_node():
|
||||
|
|
@ -35,83 +25,3 @@ def test_get_root_node():
|
|||
assert root is not None
|
||||
assert hasattr(root, "id")
|
||||
assert hasattr(root, "data")
|
||||
|
||||
|
||||
def test_build_json():
|
||||
with open(pytest.BASIC_EXAMPLE_PATH, "r") as f:
|
||||
flow_graph = json.load(f)
|
||||
data_graph = flow_graph["data"]
|
||||
nodes = data_graph["nodes"]
|
||||
edges = data_graph["edges"]
|
||||
graph = Graph(nodes, edges)
|
||||
root = get_root_node(graph)
|
||||
built_json = build_json(root, graph)
|
||||
assert built_json is not None
|
||||
assert isinstance(built_json, dict)
|
||||
|
||||
|
||||
def test_build_json_missing_child():
|
||||
with open(pytest.BASIC_EXAMPLE_PATH, "r") as f:
|
||||
flow_graph = json.load(f)
|
||||
data_graph = flow_graph["data"]
|
||||
nodes = data_graph["nodes"]
|
||||
edges = data_graph["edges"]
|
||||
|
||||
# Modify nodes to create a missing required child scenario
|
||||
for node in nodes:
|
||||
if "data" in node and "node" in node["data"]:
|
||||
for key, value in node["data"]["node"]["template"].items():
|
||||
if isinstance(value, dict) and "required" in value:
|
||||
value["required"] = True
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
graph = Graph(nodes, edges)
|
||||
root = get_root_node(graph)
|
||||
build_json(root, graph)
|
||||
|
||||
|
||||
def test_build_json_no_nodes():
|
||||
with pytest.raises(TypeError):
|
||||
build_json(None, [], [])
|
||||
|
||||
|
||||
def test_build_json_invalid_edge():
|
||||
with open(pytest.BASIC_EXAMPLE_PATH, "r") as f:
|
||||
flow_graph = json.load(f)
|
||||
data_graph = flow_graph["data"]
|
||||
nodes = data_graph["nodes"]
|
||||
edges = data_graph["edges"]
|
||||
# Modify edges to create an invalid edge scenario
|
||||
for edge in edges:
|
||||
edge["source"] = "invalid_id"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
graph = Graph(nodes, edges)
|
||||
root = get_root_node(graph)
|
||||
build_json(root, nodes, edges)
|
||||
|
||||
|
||||
def test_load_langchain_type_from_config():
|
||||
with open(pytest.BASIC_EXAMPLE_PATH, "r") as f:
|
||||
flow_graph = json.load(f)
|
||||
data_graph = flow_graph["data"]
|
||||
extracted = extract_json(data_graph)
|
||||
|
||||
agent_config = extracted.copy()
|
||||
agent_type = "AgentExecutor" # Replace with the actual agent type in the JSON
|
||||
|
||||
invalid_config = extracted.copy()
|
||||
invalid_config["_type"] = "invalid_type"
|
||||
|
||||
agent_loaded = load_langchain_type_from_config(agent_config)
|
||||
assert agent_loaded is not None
|
||||
assert agent_loaded.__class__.__name__ == agent_type
|
||||
assert hasattr(agent_loaded.agent, "llm_chain")
|
||||
assert isinstance(
|
||||
agent_loaded.agent.llm_chain, LLMChain
|
||||
) # Replace Chain with the appropriate class
|
||||
assert hasattr(agent_loaded.agent.llm_chain, "llm")
|
||||
assert isinstance(agent_loaded.agent.llm_chain.llm, OpenAI)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
load_langchain_type_from_config(invalid_config)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue