fix: load_flow_from_json now uses the Graph

This commit is contained in:
Gabriel Almeida 2023-03-28 22:04:27 -03:00
commit 86e65575ee
2 changed files with 7 additions and 102 deletions

View file

@ -52,17 +52,13 @@ def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any:
def load_flow_from_json(path: str):
# This is done to avoid circular imports
from langflow.graph.graph import Graph
"""Load flow from json file"""
with open(path, "r") as f:
flow_graph = json.load(f)
data_graph = flow_graph["data"]
extracted_json = extract_json(data_graph)
return load_langchain_type_from_config(config=extracted_json)
def extract_json(data_graph):
from langflow.graph.graph import Graph
nodes = data_graph["nodes"]
# Substitute ZeroShotPrompt with PromptTemplate
nodes = replace_zero_shot_prompt_with_prompt_template(nodes)
@ -71,8 +67,7 @@ def extract_json(data_graph):
# Nodes, edges and root node
edges = data_graph["edges"]
graph = Graph(nodes, edges)
root = payload.get_root_node(graph)
return payload.build_json(root, graph)
return graph.build()
def replace_zero_shot_prompt_with_prompt_template(nodes):

View file

@ -1,27 +1,17 @@
import json
from langchain import LLMChain, OpenAI
from langflow.graph.graph import Graph
import pytest
from langflow import load_flow_from_json
from langflow.interface.loading import extract_json
from langflow.utils.payload import get_root_node, build_json
from langflow.interface.loading import load_langchain_type_from_config
from langflow.utils.payload import get_root_node
from langchain.agents import AgentExecutor
def test_load_flow_from_json():
"""Test loading a flow from a json file"""
loaded = load_flow_from_json(pytest.BASIC_EXAMPLE_PATH)
assert loaded is not None
def test_extract_json():
with open(pytest.BASIC_EXAMPLE_PATH, "r") as f:
flow_graph = json.load(f)
data_graph = flow_graph["data"]
extracted = extract_json(data_graph)
assert extracted is not None
assert isinstance(extracted, dict)
assert isinstance(loaded, AgentExecutor)
def test_get_root_node():
@ -35,83 +25,3 @@ def test_get_root_node():
assert root is not None
assert hasattr(root, "id")
assert hasattr(root, "data")
def test_build_json():
with open(pytest.BASIC_EXAMPLE_PATH, "r") as f:
flow_graph = json.load(f)
data_graph = flow_graph["data"]
nodes = data_graph["nodes"]
edges = data_graph["edges"]
graph = Graph(nodes, edges)
root = get_root_node(graph)
built_json = build_json(root, graph)
assert built_json is not None
assert isinstance(built_json, dict)
def test_build_json_missing_child():
with open(pytest.BASIC_EXAMPLE_PATH, "r") as f:
flow_graph = json.load(f)
data_graph = flow_graph["data"]
nodes = data_graph["nodes"]
edges = data_graph["edges"]
# Modify nodes to create a missing required child scenario
for node in nodes:
if "data" in node and "node" in node["data"]:
for key, value in node["data"]["node"]["template"].items():
if isinstance(value, dict) and "required" in value:
value["required"] = True
with pytest.raises(ValueError):
graph = Graph(nodes, edges)
root = get_root_node(graph)
build_json(root, graph)
def test_build_json_no_nodes():
with pytest.raises(TypeError):
build_json(None, [], [])
def test_build_json_invalid_edge():
with open(pytest.BASIC_EXAMPLE_PATH, "r") as f:
flow_graph = json.load(f)
data_graph = flow_graph["data"]
nodes = data_graph["nodes"]
edges = data_graph["edges"]
# Modify edges to create an invalid edge scenario
for edge in edges:
edge["source"] = "invalid_id"
with pytest.raises(ValueError):
graph = Graph(nodes, edges)
root = get_root_node(graph)
build_json(root, nodes, edges)
def test_load_langchain_type_from_config():
with open(pytest.BASIC_EXAMPLE_PATH, "r") as f:
flow_graph = json.load(f)
data_graph = flow_graph["data"]
extracted = extract_json(data_graph)
agent_config = extracted.copy()
agent_type = "AgentExecutor" # Replace with the actual agent type in the JSON
invalid_config = extracted.copy()
invalid_config["_type"] = "invalid_type"
agent_loaded = load_langchain_type_from_config(agent_config)
assert agent_loaded is not None
assert agent_loaded.__class__.__name__ == agent_type
assert hasattr(agent_loaded.agent, "llm_chain")
assert isinstance(
agent_loaded.agent.llm_chain, LLMChain
) # Replace Chain with the appropriate class
assert hasattr(agent_loaded.agent.llm_chain, "llm")
assert isinstance(agent_loaded.agent.llm_chain.llm, OpenAI)
with pytest.raises(ValueError):
load_langchain_type_from_config(invalid_config)