From 60a55784af5efd82b7ba1f6a011c652eb699468b Mon Sep 17 00:00:00 2001
From: Gabriel Luiz Freitas Almeida
Date: Sun, 17 Dec 2023 21:27:07 -0300
Subject: [PATCH] Refactor flow loading and processing

---
 src/backend/langflow/__init__.py           |  7 +--
 src/backend/langflow/processing/load.py    | 50 ++++++++++++++++++++++
 src/backend/langflow/processing/process.py | 48 +-------------------
 tests/test_loading.py                      |  2 +-
 4 files changed, 56 insertions(+), 51 deletions(-)
 create mode 100644 src/backend/langflow/processing/load.py

diff --git a/src/backend/langflow/__init__.py b/src/backend/langflow/__init__.py
index d3afbb4af..2b6dd4fb8 100644
--- a/src/backend/langflow/__init__.py
+++ b/src/backend/langflow/__init__.py
@@ -1,9 +1,10 @@
 from importlib import metadata
 
+from langflow.interface.custom.custom_component import CustomComponent
+
 # Deactivate cache manager for now
 # from langflow.services.cache import cache_service
-from langflow.processing.process import load_flow_from_json
-from langflow.interface.custom.custom_component import CustomComponent
+from langflow.processing.load import load_flow_from_json
 
 try:
     __version__ = metadata.version(__package__)
@@ -12,4 +13,4 @@ except metadata.PackageNotFoundError:
     __version__ = ""
 del metadata # optional, avoids polluting the results of dir(__package__)
 
-__all__ = ["load_flow_from_json", "cache_service", "CustomComponent"]
+__all__ = ["load_flow_from_json", "CustomComponent"]
diff --git a/src/backend/langflow/processing/load.py b/src/backend/langflow/processing/load.py
new file mode 100644
index 000000000..8733b7b12
--- /dev/null
+++ b/src/backend/langflow/processing/load.py
@@ -0,0 +1,50 @@
+import asyncio
+import json
+from pathlib import Path
+from typing import Optional, Union
+
+from langflow.graph import Graph
+from langflow.processing.process import fix_memory_inputs, process_tweaks
+
+
+def load_flow_from_json(flow: Union[Path, str, dict], tweaks: Optional[dict] = None, build=True):
+    """
+    Load flow from a JSON file or a JSON object.
+
+    :param flow: JSON file path or JSON object
+    :param tweaks: Optional tweaks to be processed
+    :param build: If True, build the graph, otherwise return the graph object
+    :return: Langchain object or Graph object depending on the build parameter
+    """
+    # If input is a file path, load JSON from the file
+    if isinstance(flow, (str, Path)):
+        with open(flow, "r", encoding="utf-8") as f:
+            flow_graph = json.load(f)
+    # If input is a dictionary, assume it's a JSON object
+    elif isinstance(flow, dict):
+        flow_graph = flow
+    else:
+        raise TypeError("Input must be either a file path (str) or a JSON object (dict)")
+
+    graph_data = flow_graph["data"]
+    if tweaks is not None:
+        graph_data = process_tweaks(graph_data, tweaks)
+    nodes = graph_data["nodes"]
+    edges = graph_data["edges"]
+    graph = Graph(nodes, edges)
+
+    if build:
+        langchain_object = asyncio.run(graph.build())
+
+        if hasattr(langchain_object, "verbose"):
+            langchain_object.verbose = True
+
+        if hasattr(langchain_object, "return_intermediate_steps"):
+            # Deactivating until we have a frontend solution
+            # to display intermediate steps
+            langchain_object.return_intermediate_steps = False
+
+        fix_memory_inputs(langchain_object)
+        return langchain_object
+
+    return graph
diff --git a/src/backend/langflow/processing/process.py b/src/backend/langflow/processing/process.py
index 80846ec11..4b7cf8470 100644
--- a/src/backend/langflow/processing/process.py
+++ b/src/backend/langflow/processing/process.py
@@ -1,19 +1,16 @@
 import asyncio
-import json
-from pathlib import Path
 from typing import Any, Coroutine, Dict, List, Optional, Tuple, Union
 
 from langchain.agents import AgentExecutor
 from langchain.chains.base import Chain
 from langchain.schema import AgentAction, Document
 from langchain.vectorstores.base import VectorStore
+from langchain_core.runnables.base import Runnable
 from langflow.components.custom_components import CustomComponent
-from langflow.graph import Graph
 from langflow.interface.run import build_sorted_vertices, get_memory_key, update_memory_keys
 from langflow.services.deps import get_session_service
 from loguru import logger
 from pydantic import BaseModel
-from langchain_core.runnables.base import Runnable
 
 
 def fix_memory_inputs(langchain_object):
@@ -179,49 +176,6 @@ async def process_graph_cached(
     return Result(result=result, session_id=session_id)
 
 
-def load_flow_from_json(flow: Union[Path, str, dict], tweaks: Optional[dict] = None, build=True):
-    """
-    Load flow from a JSON file or a JSON object.
-
-    :param flow: JSON file path or JSON object
-    :param tweaks: Optional tweaks to be processed
-    :param build: If True, build the graph, otherwise return the graph object
-    :return: Langchain object or Graph object depending on the build parameter
-    """
-    # If input is a file path, load JSON from the file
-    if isinstance(flow, (str, Path)):
-        with open(flow, "r", encoding="utf-8") as f:
-            flow_graph = json.load(f)
-    # If input is a dictionary, assume it's a JSON object
-    elif isinstance(flow, dict):
-        flow_graph = flow
-    else:
-        raise TypeError("Input must be either a file path (str) or a JSON object (dict)")
-
-    graph_data = flow_graph["data"]
-    if tweaks is not None:
-        graph_data = process_tweaks(graph_data, tweaks)
-    nodes = graph_data["nodes"]
-    edges = graph_data["edges"]
-    graph = Graph(nodes, edges)
-
-    if build:
-        langchain_object = asyncio.run(graph.build())
-
-        if hasattr(langchain_object, "verbose"):
-            langchain_object.verbose = True
-
-        if hasattr(langchain_object, "return_intermediate_steps"):
-            # Deactivating until we have a frontend solution
-            # to display intermediate steps
-            langchain_object.return_intermediate_steps = False
-
-        fix_memory_inputs(langchain_object)
-        return langchain_object
-
-    return graph
-
-
 def validate_input(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
     if not isinstance(graph_data, dict) or not isinstance(tweaks, dict):
         raise ValueError("graph_data and tweaks should be dictionaries")
diff --git a/tests/test_loading.py b/tests/test_loading.py
index 94cf34d8e..eb3987e93 100644
--- a/tests/test_loading.py
+++ b/tests/test_loading.py
@@ -3,7 +3,7 @@ import json
 import pytest
 from langchain.chains.base import Chain
 from langflow.graph import Graph
-from langflow.processing.process import load_flow_from_json
+from langflow.processing.load import load_flow_from_json
 from langflow.utils.payload import get_root_vertex