Add new build_and_cache_graph in utils.py
This commit is contained in:
parent
b90ff00cf4
commit
aebab3f7e3
1 changed files with 50 additions and 10 deletions
|
|
@ -1,13 +1,16 @@
|
|||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, List
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
from langchain_core.documents import Document
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.services.chat.service import ChatService
|
||||
from langflow.services.database.models.flow import Flow
|
||||
from langflow.services.store.schema import StoreComponentCreate
|
||||
from platformdirs import user_cache_dir
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langflow.services.store.schema import StoreComponentCreate
|
||||
from sqlmodel import Session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.database.models.flow.model import Flow
|
||||
|
|
@ -17,7 +20,9 @@ API_WORDS = ["api", "key", "token"]
|
|||
|
||||
|
||||
def has_api_terms(word: str):
|
||||
return "api" in word and ("key" in word or ("token" in word and "tokens" not in word))
|
||||
return "api" in word and (
|
||||
"key" in word or ("token" in word and "tokens" not in word)
|
||||
)
|
||||
|
||||
|
||||
def remove_api_keys(flow: dict):
|
||||
|
|
@ -27,7 +32,11 @@ def remove_api_keys(flow: dict):
|
|||
node_data = node.get("data").get("node")
|
||||
template = node_data.get("template")
|
||||
for value in template.values():
|
||||
if isinstance(value, dict) and has_api_terms(value["name"]) and value.get("password"):
|
||||
if (
|
||||
isinstance(value, dict)
|
||||
and has_api_terms(value["name"])
|
||||
and value.get("password")
|
||||
):
|
||||
value["value"] = None
|
||||
|
||||
return flow
|
||||
|
|
@ -48,7 +57,9 @@ def build_input_keys_response(langchain_object, artifacts):
|
|||
input_keys_response["input_keys"][key] = value
|
||||
# If the object has memory, that memory will have a memory_variables attribute
|
||||
# memory variables should be removed from the input keys
|
||||
if hasattr(langchain_object, "memory") and hasattr(langchain_object.memory, "memory_variables"):
|
||||
if hasattr(langchain_object, "memory") and hasattr(
|
||||
langchain_object.memory, "memory_variables"
|
||||
):
|
||||
# Remove memory variables from input keys
|
||||
input_keys_response["input_keys"] = {
|
||||
key: value
|
||||
|
|
@ -58,7 +69,9 @@ def build_input_keys_response(langchain_object, artifacts):
|
|||
# Add memory variables to memory_keys
|
||||
input_keys_response["memory_keys"] = langchain_object.memory.memory_variables
|
||||
|
||||
if hasattr(langchain_object, "prompt") and hasattr(langchain_object.prompt, "template"):
|
||||
if hasattr(langchain_object, "prompt") and hasattr(
|
||||
langchain_object.prompt, "template"
|
||||
):
|
||||
input_keys_response["template"] = langchain_object.prompt.template
|
||||
|
||||
return input_keys_response
|
||||
|
|
@ -93,7 +106,11 @@ def raw_frontend_data_is_valid(raw_frontend_data):
|
|||
def is_valid_data(frontend_node, raw_frontend_data):
|
||||
"""Check if the data is valid for processing."""
|
||||
|
||||
return frontend_node and "template" in frontend_node and raw_frontend_data_is_valid(raw_frontend_data)
|
||||
return (
|
||||
frontend_node
|
||||
and "template" in frontend_node
|
||||
and raw_frontend_data_is_valid(raw_frontend_data)
|
||||
)
|
||||
|
||||
|
||||
def update_template_values(frontend_template, raw_template):
|
||||
|
|
@ -133,7 +150,9 @@ def get_file_path_value(file_path):
|
|||
# If the path is not in the cache dir, return empty string
|
||||
# This is to prevent access to files outside the cache dir
|
||||
# If the path is not a file, return empty string
|
||||
if not path.exists() or not str(path).startswith(user_cache_dir("langflow", "langflow")):
|
||||
if not path.exists() or not str(path).startswith(
|
||||
user_cache_dir("langflow", "langflow")
|
||||
):
|
||||
return ""
|
||||
return file_path
|
||||
|
||||
|
|
@ -164,7 +183,9 @@ async def check_langflow_version(component: StoreComponentCreate):
|
|||
|
||||
langflow_version = get_lf_version_from_pypi()
|
||||
if langflow_version is None:
|
||||
raise HTTPException(status_code=500, detail="Unable to verify the latest version of Langflow")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Unable to verify the latest version of Langflow"
|
||||
)
|
||||
elif langflow_version != component.last_tested_version:
|
||||
warnings.warn(
|
||||
f"Your version of Langflow ({component.last_tested_version}) is outdated. "
|
||||
|
|
@ -206,3 +227,22 @@ def serialize_field(value):
|
|||
elif isinstance(value, str):
|
||||
return {"result": value}
|
||||
return value
|
||||
|
||||
|
||||
def build_and_cache_graph(
|
||||
flow_id: str,
|
||||
session: Session,
|
||||
chat_service: "ChatService",
|
||||
graph: Optional[Graph] = None,
|
||||
):
|
||||
"""Build and cache the graph."""
|
||||
flow: Flow = session.get(Flow, flow_id)
|
||||
if not flow or not flow.data:
|
||||
raise ValueError("Invalid flow ID")
|
||||
other_graph = Graph.from_payload(flow.data)
|
||||
if graph is None:
|
||||
graph = other_graph
|
||||
else:
|
||||
graph = graph.update(other_graph)
|
||||
chat_service.set_cache(flow_id, graph)
|
||||
return graph
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue