diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 5cefdadae..4fa2f4d17 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -11,6 +11,7 @@ from langflow.graph.vertex.types import ( from langflow.interface.tools.constants import FILE_TOOLS from langflow.utils import payload from langflow.utils.logger import logger +from langchain.chains.base import Chain class Graph: @@ -99,7 +100,7 @@ class Graph: ] return connected_nodes - def build(self) -> List[Vertex]: + def build(self) -> Chain: """Builds the graph.""" # Get root node root_node = payload.get_root_node(self) diff --git a/src/backend/langflow/processing/process.py b/src/backend/langflow/processing/process.py index c25f7b3d1..760f73723 100644 --- a/src/backend/langflow/processing/process.py +++ b/src/backend/langflow/processing/process.py @@ -1,5 +1,6 @@ import contextlib import io +from pathlib import Path from langchain.schema import AgentAction import json from langflow.interface.run import ( @@ -11,7 +12,7 @@ from langflow.utils.logger import logger from langflow.graph import Graph -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union def fix_memory_inputs(langchain_object): @@ -131,34 +132,50 @@ def process_graph_cached(data_graph: Dict[str, Any], message: str): return {"result": str(result), "thought": thought.strip()} -def load_flow_from_json(path: str, build=True): - """Load flow from json file""" - # This is done to avoid circular imports +def load_flow_from_json( + input: Union[str, dict], tweaks: Optional[dict] = None, build=True +): + """ + Load flow from a JSON file or a JSON object. - with open(path, "r", encoding="utf-8") as f: - flow_graph = json.load(f) - data_graph = flow_graph["data"] - nodes = data_graph["nodes"] - # Substitute ZeroShotPrompt with PromptTemplate - # nodes = replace_zero_shot_prompt_with_prompt_template(nodes) - # Add input variables - # nodes = payload.extract_input_variables(nodes) + :param input: JSON file path or JSON object + :param tweaks: Optional tweaks to be processed + :param build: If True, build the graph, otherwise return the graph object + :return: Langchain object or Graph object depending on the build parameter + """ + # If input is a file path, load JSON from the file + if isinstance(input, (str, Path)): + with open(input, "r", encoding="utf-8") as f: + flow_graph = json.load(f) + # If input is a dictionary, assume it's a JSON object + elif isinstance(input, dict): + flow_graph = input + else: + raise TypeError( + "Input must be either a file path (str) or a JSON object (dict)" + ) - # Nodes, edges and root node - edges = data_graph["edges"] + graph_data = flow_graph["data"] + if tweaks is not None: + graph_data = process_tweaks(graph_data, tweaks) + nodes = graph_data["nodes"] + edges = graph_data["edges"] graph = Graph(nodes, edges) + if build: langchain_object = graph.build() + if hasattr(langchain_object, "verbose"): langchain_object.verbose = True if hasattr(langchain_object, "return_intermediate_steps"): - # https://github.com/hwchase17/langchain/issues/2068 # Deactivating until we have a frontend solution # to display intermediate steps langchain_object.return_intermediate_steps = False + fix_memory_inputs(langchain_object) return langchain_object + return graph @@ -181,7 +198,4 @@ def process_tweaks(graph_data: Dict, tweaks: Dict): for tweak_name, tweake_value in node_tweaks.items(): if tweak_name in template_data: template_data[tweak_name]["value"] = tweake_value - print( - f"Something changed in node {node_id} with tweak {tweak_name} and value {tweake_value}" - ) return graph_data diff --git a/tests/test_loading.py b/tests/test_loading.py index 885eb7a82..11fa8e471 100644 --- a/tests/test_loading.py +++ b/tests/test_loading.py @@ -14,6 +14,15 @@ def test_load_flow_from_json(): assert isinstance(loaded, Chain) +def test_load_flow_from_json_with_tweaks(): + """Test loading a flow from a json file and applying tweaks""" + tweaks = {"dndnode_82": {"model_name": "test model"}} + loaded = load_flow_from_json(pytest.BASIC_EXAMPLE_PATH, tweaks=tweaks) + assert loaded is not None + assert isinstance(loaded, Chain) + assert loaded.llm.model_name == "test model" + + def test_get_root_node(): with open(pytest.BASIC_EXAMPLE_PATH, "r") as f: flow_graph = json.load(f)