From 6886828ddd020fed9a533ed84da7cc6f2f0e54b1 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Mon, 19 Jun 2023 11:36:43 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=80=20refactor(process.py):=20change?= =?UTF-8?q?=20load=5Fflow=5Ffrom=5Fjson=20function=20signature=20to=20acce?= =?UTF-8?q?pt=20either=20a=20JSON=20file=20path=20or=20a=20JSON=20object?= =?UTF-8?q?=20=F0=9F=94=80=20refactor(base.py):=20import=20Chain=20from=20?= =?UTF-8?q?langchain.chains.base=20instead=20of=20importing=20it=20from=20?= =?UTF-8?q?langflow.graph.vertex.types=20=F0=9F=94=80=20refactor(process.p?= =?UTF-8?q?y):=20remove=20print=20statement=20from=20process=5Ftweaks=20fu?= =?UTF-8?q?nction=20=F0=9F=94=80=20refactor(process.py):=20change=20load?= =?UTF-8?q?=5Fflow=5Ffrom=5Fjson=20function=20signature=20to=20accept=20op?= =?UTF-8?q?tional=20tweaks=20parameter=20=F0=9F=94=80=20refactor(process.p?= =?UTF-8?q?y):=20change=20return=20type=20of=20build=20method=20in=20Graph?= =?UTF-8?q?=20class=20from=20List[Vertex]=20to=20Chain=20=F0=9F=A7=AA=20te?= =?UTF-8?q?st(loading.py):=20add=20test=20case=20for=20loading=20a=20flow?= =?UTF-8?q?=20from=20a=20JSON=20file=20and=20applying=20tweaks=20?= =?UTF-8?q?=F0=9F=A7=AA=20test(loading.py):=20remove=20unused=20import=20s?= =?UTF-8?q?tatement=20The=20import=20statement=20for=20Chain=20in=20base.p?= =?UTF-8?q?y=20is=20now=20more=20explicit=20and=20imports=20it=20from=20la?= =?UTF-8?q?ngchain.chains.base=20instead=20of=20importing=20it=20from=20la?= =?UTF-8?q?ngflow.graph.vertex.types.=20The=20load=5Fflow=5Ffrom=5Fjson=20?= =?UTF-8?q?function=20in=20process.py=20now=20accepts=20either=20a=20JSON?= =?UTF-8?q?=20file=20path=20or=20a=20JSON=20object.=20The=20print=20statem?= =?UTF-8?q?ent=20in=20process=5Ftweaks=20function=20has=20been=20removed.?= =?UTF-8?q?=20The=20load=5Fflow=5Ffrom=5Fjson=20function=20in=20process.py?= =?UTF-8?q?=20now=20accepts=20an=20optional=20tweaks=20parameter.=20The=20?= =?UTF-8?q?return=20type=20of=20the=20build=20method=20in=20the=20Graph=20?= =?UTF-8?q?class=20has=20been=20changed=20from=20List[Vertex]=20to=20Chain?= =?UTF-8?q?.=20A=20new=20test=20case=20has=20been=20added=20to=20loading.p?= =?UTF-8?q?y=20to=20test=20loading=20a=20flow=20from=20a=20JSON=20file=20a?= =?UTF-8?q?nd=20applying=20tweaks.=20An=20unused=20import=20statement=20ha?= =?UTF-8?q?s=20been=20removed=20from=20loading.py.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/langflow/graph/graph/base.py | 3 +- src/backend/langflow/processing/process.py | 50 ++++++++++++++-------- tests/test_loading.py | 9 ++++ 3 files changed, 43 insertions(+), 19 deletions(-) diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 5cefdadae..4fa2f4d17 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -11,6 +11,7 @@ from langflow.graph.vertex.types import ( from langflow.interface.tools.constants import FILE_TOOLS from langflow.utils import payload from langflow.utils.logger import logger +from langchain.chains.base import Chain class Graph: @@ -99,7 +100,7 @@ class Graph: ] return connected_nodes - def build(self) -> List[Vertex]: + def build(self) -> Chain: """Builds the graph.""" # Get root node root_node = payload.get_root_node(self) diff --git a/src/backend/langflow/processing/process.py b/src/backend/langflow/processing/process.py index c25f7b3d1..760f73723 100644 --- a/src/backend/langflow/processing/process.py +++ b/src/backend/langflow/processing/process.py @@ -1,5 +1,6 @@ import contextlib import io +from pathlib import Path from langchain.schema import AgentAction import json from langflow.interface.run import ( @@ -11,7 +12,7 @@ from langflow.utils.logger import logger from langflow.graph import Graph -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union def fix_memory_inputs(langchain_object): @@ -131,34 +132,50 @@ def process_graph_cached(data_graph: Dict[str, Any], message: str): return {"result": str(result), "thought": thought.strip()} -def load_flow_from_json(path: str, build=True): - """Load flow from json file""" - # This is done to avoid circular imports +def load_flow_from_json( + input: Union[str, dict], tweaks: Optional[dict] = None, build=True +): + """ + Load flow from a JSON file or a JSON object. - with open(path, "r", encoding="utf-8") as f: - flow_graph = json.load(f) - data_graph = flow_graph["data"] - nodes = data_graph["nodes"] - # Substitute ZeroShotPrompt with PromptTemplate - # nodes = replace_zero_shot_prompt_with_prompt_template(nodes) - # Add input variables - # nodes = payload.extract_input_variables(nodes) + :param input: JSON file path or JSON object + :param tweaks: Optional tweaks to be processed + :param build: If True, build the graph, otherwise return the graph object + :return: Langchain object or Graph object depending on the build parameter + """ + # If input is a file path, load JSON from the file + if isinstance(input, (str, Path)): + with open(input, "r", encoding="utf-8") as f: + flow_graph = json.load(f) + # If input is a dictionary, assume it's a JSON object + elif isinstance(input, dict): + flow_graph = input + else: + raise TypeError( + "Input must be either a file path (str) or a JSON object (dict)" + ) - # Nodes, edges and root node - edges = data_graph["edges"] + graph_data = flow_graph["data"] + if tweaks is not None: + graph_data = process_tweaks(graph_data, tweaks) + nodes = graph_data["nodes"] + edges = graph_data["edges"] graph = Graph(nodes, edges) + if build: langchain_object = graph.build() + if hasattr(langchain_object, "verbose"): langchain_object.verbose = True if hasattr(langchain_object, "return_intermediate_steps"): - # https://github.com/hwchase17/langchain/issues/2068 # Deactivating until we have a frontend solution # to display intermediate steps langchain_object.return_intermediate_steps = False + fix_memory_inputs(langchain_object) return langchain_object + return graph @@ -181,7 +198,4 @@ def process_tweaks(graph_data: Dict, tweaks: Dict): for tweak_name, tweake_value in node_tweaks.items(): if tweak_name in template_data: template_data[tweak_name]["value"] = tweake_value - print( - f"Something changed in node {node_id} with tweak {tweak_name} and value {tweake_value}" - ) return graph_data diff --git a/tests/test_loading.py b/tests/test_loading.py index 885eb7a82..11fa8e471 100644 --- a/tests/test_loading.py +++ b/tests/test_loading.py @@ -14,6 +14,15 @@ def test_load_flow_from_json(): assert isinstance(loaded, Chain) +def test_load_flow_from_json_with_tweaks(): + """Test loading a flow from a json file and applying tweaks""" + tweaks = {"dndnode_82": {"model_name": "test model"}} + loaded = load_flow_from_json(pytest.BASIC_EXAMPLE_PATH, tweaks=tweaks) + assert loaded is not None + assert isinstance(loaded, Chain) + assert loaded.llm.model_name == "test model" + + def test_get_root_node(): with open(pytest.BASIC_EXAMPLE_PATH, "r") as f: flow_graph = json.load(f)