🔀 refactor(process.py): change load_flow_from_json function signature to accept either a JSON file path or a JSON object
🔀 refactor(base.py): import Chain from langchain.chains.base instead of importing it from langflow.graph.vertex.types 🔀 refactor(process.py): remove print statement from process_tweaks function 🔀 refactor(process.py): change load_flow_from_json function signature to accept optional tweaks parameter 🔀 refactor(process.py): change return type of build method in Graph class from List[Vertex] to Chain 🧪 test(loading.py): add test case for loading a flow from a JSON file and applying tweaks 🧪 test(loading.py): remove unused import statement The import statement for Chain in base.py is now more explicit and imports it from langchain.chains.base instead of importing it from langflow.graph.vertex.types. The load_flow_from_json function in process.py now accepts either a JSON file path or a JSON object. The print statement in process_tweaks function has been removed. The load_flow_from_json function in process.py now accepts an optional tweaks parameter. The return type of the build method in the Graph class has been changed from List[Vertex] to Chain. A new test case has been added to loading.py to test loading a flow from a JSON file and applying tweaks. An unused import statement has been removed from loading.py.
This commit is contained in:
parent
5ea20aa2f0
commit
6886828ddd
3 changed files with 43 additions and 19 deletions
|
|
@ -11,6 +11,7 @@ from langflow.graph.vertex.types import (
|
||||||
from langflow.interface.tools.constants import FILE_TOOLS
|
from langflow.interface.tools.constants import FILE_TOOLS
|
||||||
from langflow.utils import payload
|
from langflow.utils import payload
|
||||||
from langflow.utils.logger import logger
|
from langflow.utils.logger import logger
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
|
||||||
|
|
||||||
class Graph:
|
class Graph:
|
||||||
|
|
@ -99,7 +100,7 @@ class Graph:
|
||||||
]
|
]
|
||||||
return connected_nodes
|
return connected_nodes
|
||||||
|
|
||||||
def build(self) -> List[Vertex]:
|
def build(self) -> Chain:
|
||||||
"""Builds the graph."""
|
"""Builds the graph."""
|
||||||
# Get root node
|
# Get root node
|
||||||
root_node = payload.get_root_node(self)
|
root_node = payload.get_root_node(self)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import contextlib
|
import contextlib
|
||||||
import io
|
import io
|
||||||
|
from pathlib import Path
|
||||||
from langchain.schema import AgentAction
|
from langchain.schema import AgentAction
|
||||||
import json
|
import json
|
||||||
from langflow.interface.run import (
|
from langflow.interface.run import (
|
||||||
|
|
@ -11,7 +12,7 @@ from langflow.utils.logger import logger
|
||||||
from langflow.graph import Graph
|
from langflow.graph import Graph
|
||||||
|
|
||||||
|
|
||||||
from typing import Any, Dict, List, Tuple
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
|
||||||
def fix_memory_inputs(langchain_object):
|
def fix_memory_inputs(langchain_object):
|
||||||
|
|
@ -131,34 +132,50 @@ def process_graph_cached(data_graph: Dict[str, Any], message: str):
|
||||||
return {"result": str(result), "thought": thought.strip()}
|
return {"result": str(result), "thought": thought.strip()}
|
||||||
|
|
||||||
|
|
||||||
def load_flow_from_json(path: str, build=True):
|
def load_flow_from_json(
|
||||||
"""Load flow from json file"""
|
input: Union[str, dict], tweaks: Optional[dict] = None, build=True
|
||||||
# This is done to avoid circular imports
|
):
|
||||||
|
"""
|
||||||
|
Load flow from a JSON file or a JSON object.
|
||||||
|
|
||||||
with open(path, "r", encoding="utf-8") as f:
|
:param input: JSON file path or JSON object
|
||||||
flow_graph = json.load(f)
|
:param tweaks: Optional tweaks to be processed
|
||||||
data_graph = flow_graph["data"]
|
:param build: If True, build the graph, otherwise return the graph object
|
||||||
nodes = data_graph["nodes"]
|
:return: Langchain object or Graph object depending on the build parameter
|
||||||
# Substitute ZeroShotPrompt with PromptTemplate
|
"""
|
||||||
# nodes = replace_zero_shot_prompt_with_prompt_template(nodes)
|
# If input is a file path, load JSON from the file
|
||||||
# Add input variables
|
if isinstance(input, (str, Path)):
|
||||||
# nodes = payload.extract_input_variables(nodes)
|
with open(input, "r", encoding="utf-8") as f:
|
||||||
|
flow_graph = json.load(f)
|
||||||
|
# If input is a dictionary, assume it's a JSON object
|
||||||
|
elif isinstance(input, dict):
|
||||||
|
flow_graph = input
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
"Input must be either a file path (str) or a JSON object (dict)"
|
||||||
|
)
|
||||||
|
|
||||||
# Nodes, edges and root node
|
graph_data = flow_graph["data"]
|
||||||
edges = data_graph["edges"]
|
if tweaks is not None:
|
||||||
|
graph_data = process_tweaks(graph_data, tweaks)
|
||||||
|
nodes = graph_data["nodes"]
|
||||||
|
edges = graph_data["edges"]
|
||||||
graph = Graph(nodes, edges)
|
graph = Graph(nodes, edges)
|
||||||
|
|
||||||
if build:
|
if build:
|
||||||
langchain_object = graph.build()
|
langchain_object = graph.build()
|
||||||
|
|
||||||
if hasattr(langchain_object, "verbose"):
|
if hasattr(langchain_object, "verbose"):
|
||||||
langchain_object.verbose = True
|
langchain_object.verbose = True
|
||||||
|
|
||||||
if hasattr(langchain_object, "return_intermediate_steps"):
|
if hasattr(langchain_object, "return_intermediate_steps"):
|
||||||
# https://github.com/hwchase17/langchain/issues/2068
|
|
||||||
# Deactivating until we have a frontend solution
|
# Deactivating until we have a frontend solution
|
||||||
# to display intermediate steps
|
# to display intermediate steps
|
||||||
langchain_object.return_intermediate_steps = False
|
langchain_object.return_intermediate_steps = False
|
||||||
|
|
||||||
fix_memory_inputs(langchain_object)
|
fix_memory_inputs(langchain_object)
|
||||||
return langchain_object
|
return langchain_object
|
||||||
|
|
||||||
return graph
|
return graph
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -181,7 +198,4 @@ def process_tweaks(graph_data: Dict, tweaks: Dict):
|
||||||
for tweak_name, tweake_value in node_tweaks.items():
|
for tweak_name, tweake_value in node_tweaks.items():
|
||||||
if tweak_name in template_data:
|
if tweak_name in template_data:
|
||||||
template_data[tweak_name]["value"] = tweake_value
|
template_data[tweak_name]["value"] = tweake_value
|
||||||
print(
|
|
||||||
f"Something changed in node {node_id} with tweak {tweak_name} and value {tweake_value}"
|
|
||||||
)
|
|
||||||
return graph_data
|
return graph_data
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,15 @@ def test_load_flow_from_json():
|
||||||
assert isinstance(loaded, Chain)
|
assert isinstance(loaded, Chain)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_flow_from_json_with_tweaks():
|
||||||
|
"""Test loading a flow from a json file and applying tweaks"""
|
||||||
|
tweaks = {"dndnode_82": {"model_name": "test model"}}
|
||||||
|
loaded = load_flow_from_json(pytest.BASIC_EXAMPLE_PATH, tweaks=tweaks)
|
||||||
|
assert loaded is not None
|
||||||
|
assert isinstance(loaded, Chain)
|
||||||
|
assert loaded.llm.model_name == "test model"
|
||||||
|
|
||||||
|
|
||||||
def test_get_root_node():
|
def test_get_root_node():
|
||||||
with open(pytest.BASIC_EXAMPLE_PATH, "r") as f:
|
with open(pytest.BASIC_EXAMPLE_PATH, "r") as f:
|
||||||
flow_graph = json.load(f)
|
flow_graph = json.load(f)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue