From aebab3f7e3d13d8d13eba3ec5e9f3b4d343ec3ac Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 20 Feb 2024 13:17:04 -0300 Subject: [PATCH] Add new build_and_cache_graph in utils.py --- src/backend/langflow/api/utils.py | 60 +++++++++++++++++++++++++------ 1 file changed, 50 insertions(+), 10 deletions(-) diff --git a/src/backend/langflow/api/utils.py b/src/backend/langflow/api/utils.py index d2d078b97..f419ea269 100644 --- a/src/backend/langflow/api/utils.py +++ b/src/backend/langflow/api/utils.py @@ -1,13 +1,16 @@ import warnings from pathlib import Path -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Optional from fastapi import HTTPException from langchain_core.documents import Document +from langflow.graph.graph.base import Graph +from langflow.services.chat.service import ChatService +from langflow.services.database.models.flow import Flow +from langflow.services.store.schema import StoreComponentCreate from platformdirs import user_cache_dir from pydantic import BaseModel - -from langflow.services.store.schema import StoreComponentCreate +from sqlmodel import Session if TYPE_CHECKING: from langflow.services.database.models.flow.model import Flow @@ -17,7 +20,9 @@ API_WORDS = ["api", "key", "token"] def has_api_terms(word: str): - return "api" in word and ("key" in word or ("token" in word and "tokens" not in word)) + return "api" in word and ( + "key" in word or ("token" in word and "tokens" not in word) + ) def remove_api_keys(flow: dict): @@ -27,7 +32,11 @@ def remove_api_keys(flow: dict): node_data = node.get("data").get("node") template = node_data.get("template") for value in template.values(): - if isinstance(value, dict) and has_api_terms(value["name"]) and value.get("password"): + if ( + isinstance(value, dict) + and has_api_terms(value["name"]) + and value.get("password") + ): value["value"] = None return flow @@ -48,7 +57,9 @@ def build_input_keys_response(langchain_object, artifacts): input_keys_response["input_keys"][key] = value # If the object has memory, that memory will have a memory_variables attribute # memory variables should be removed from the input keys - if hasattr(langchain_object, "memory") and hasattr(langchain_object.memory, "memory_variables"): + if hasattr(langchain_object, "memory") and hasattr( + langchain_object.memory, "memory_variables" + ): # Remove memory variables from input keys input_keys_response["input_keys"] = { key: value @@ -58,7 +69,9 @@ def build_input_keys_response(langchain_object, artifacts): # Add memory variables to memory_keys input_keys_response["memory_keys"] = langchain_object.memory.memory_variables - if hasattr(langchain_object, "prompt") and hasattr(langchain_object.prompt, "template"): + if hasattr(langchain_object, "prompt") and hasattr( + langchain_object.prompt, "template" + ): input_keys_response["template"] = langchain_object.prompt.template return input_keys_response @@ -93,7 +106,11 @@ def raw_frontend_data_is_valid(raw_frontend_data): def is_valid_data(frontend_node, raw_frontend_data): """Check if the data is valid for processing.""" - return frontend_node and "template" in frontend_node and raw_frontend_data_is_valid(raw_frontend_data) + return ( + frontend_node + and "template" in frontend_node + and raw_frontend_data_is_valid(raw_frontend_data) + ) def update_template_values(frontend_template, raw_template): @@ -133,7 +150,9 @@ def get_file_path_value(file_path): # If the path is not in the cache dir, return empty string # This is to prevent access to files outside the cache dir # If the path is not a file, return empty string - if not path.exists() or not str(path).startswith(user_cache_dir("langflow", "langflow")): + if not path.exists() or not str(path).startswith( + user_cache_dir("langflow", "langflow") + ): return "" return file_path @@ -164,7 +183,9 @@ async def check_langflow_version(component: StoreComponentCreate): langflow_version = get_lf_version_from_pypi() if langflow_version is None: - raise HTTPException(status_code=500, detail="Unable to verify the latest version of Langflow") + raise HTTPException( + status_code=500, detail="Unable to verify the latest version of Langflow" + ) elif langflow_version != component.last_tested_version: warnings.warn( f"Your version of Langflow ({component.last_tested_version}) is outdated. " @@ -206,3 +227,22 @@ def serialize_field(value): elif isinstance(value, str): return {"result": value} return value + + +def build_and_cache_graph( + flow_id: str, + session: Session, + chat_service: "ChatService", + graph: Optional[Graph] = None, +): + """Build and cache the graph.""" + flow: Flow = session.get(Flow, flow_id) + if not flow or not flow.data: + raise ValueError("Invalid flow ID") + other_graph = Graph.from_payload(flow.data) + if graph is None: + graph = other_graph + else: + graph = graph.update(other_graph) + chat_service.set_cache(flow_id, graph) + return graph