Merge remote-tracking branch 'origin/dev' into celery
This commit is contained in:
commit
eb17233cde
8 changed files with 174 additions and 38 deletions
38
poetry.lock
generated
38
poetry.lock
generated
|
|
@ -7144,6 +7144,17 @@ files = [
|
|||
{file = "types_cachetools-5.3.0.6-py3-none-any.whl", hash = "sha256:f7f8a25bfe306f2e6bc2ad0a2f949d9e72f2d91036d509c36d3810bf728bc6e1"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "types-passlib"
|
||||
version = "1.7.7.13"
|
||||
description = "Typing stubs for passlib"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "types-passlib-1.7.7.13.tar.gz", hash = "sha256:f152639f1f2103d7f59a56e2aec5f9398a75a80830991d0d68aac5c2b9c32a77"},
|
||||
{file = "types_passlib-1.7.7.13-py3-none-any.whl", hash = "sha256:414b5ee9c88313357c9261cfcf816509b1e8e4673f0796bd61e9ef249f6fe076"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "types-pillow"
|
||||
version = "9.5.0.6"
|
||||
|
|
@ -7155,6 +7166,31 @@ files = [
|
|||
{file = "types_Pillow-9.5.0.6-py3-none-any.whl", hash = "sha256:1d238abaa9d529b04941d805b7f4d3f7df30702bb14521ec507617f117406fb4"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "types-pyasn1"
|
||||
version = "0.4.0.6"
|
||||
description = "Typing stubs for pyasn1"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "types-pyasn1-0.4.0.6.tar.gz", hash = "sha256:8f1965d0b79152f9d1efc89f9aa9a8cdda7cd28b2619df6737c095cbedeff98b"},
|
||||
{file = "types_pyasn1-0.4.0.6-py3-none-any.whl", hash = "sha256:dd5fc818864e63a66cd714be0a7a59a493f4a81b87ee9ac978c41f1eaa9a0cef"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "types-python-jose"
|
||||
version = "3.3.4.8"
|
||||
description = "Typing stubs for python-jose"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "types-python-jose-3.3.4.8.tar.gz", hash = "sha256:3c316675c3cee059ccb9aff87358254344915239fa7f19cee2787155a7db14ac"},
|
||||
{file = "types_python_jose-3.3.4.8-py3-none-any.whl", hash = "sha256:95592273443b45dc5cc88f7c56aa5a97725428753fb738b794e63ccb4904954e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
types-pyasn1 = "*"
|
||||
|
||||
[[package]]
|
||||
name = "types-pytz"
|
||||
version = "2023.3.0.1"
|
||||
|
|
@ -7909,4 +7945,4 @@ local = ["ctransformers", "llama-cpp-python", "sentence-transformers"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.9,<3.11"
|
||||
content-hash = "c1581fd2c650cc1365793ca33278917dd86bb4527cc956106a82a5c513ecc483"
|
||||
content-hash = "8ad605e7ea30f2819dbc03eac6c2e67576a98d1efa4890912414a7568fc27441"
|
||||
|
|
|
|||
|
|
@ -101,6 +101,8 @@ pandas-stubs = "^2.0.0.230412"
|
|||
types-pillow = "^9.5.0.2"
|
||||
types-appdirs = "^1.4.3.5"
|
||||
types-pyyaml = "^6.0.12.8"
|
||||
types-python-jose = "^3.3.4.8"
|
||||
types-passlib = "^1.7.7.13"
|
||||
|
||||
|
||||
[tool.poetry.extras]
|
||||
|
|
|
|||
|
|
@ -1,12 +1,11 @@
|
|||
from http import HTTPStatus
|
||||
from typing import Annotated, Optional
|
||||
from typing import Annotated, Optional, Union
|
||||
|
||||
from langflow.services.cache.utils import save_uploaded_file
|
||||
from langflow.services.database.models.flow import Flow
|
||||
from langflow.processing.process import process_tweaks
|
||||
from langflow.processing.process import process_graph_cached, process_tweaks
|
||||
from langflow.services.utils import get_settings_manager
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.worker import process_graph_cached as process_graph_cached_worker
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, Body
|
||||
|
||||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
|
|
@ -76,6 +75,7 @@ async def process_flow(
|
|||
inputs: Optional[dict] = None,
|
||||
tweaks: Optional[dict] = None,
|
||||
clear_cache: Annotated[bool, Body(embed=True)] = False, # noqa: F821
|
||||
session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821
|
||||
session: Session = Depends(get_session),
|
||||
):
|
||||
"""
|
||||
|
|
@ -95,15 +95,10 @@ async def process_flow(
|
|||
graph_data = process_tweaks(graph_data, tweaks)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error processing tweaks: {exc}")
|
||||
# ! This was added just for testing purposes
|
||||
response = process_graph_cached_worker.delay(
|
||||
graph_data=graph_data,
|
||||
inputs=inputs,
|
||||
clear_cache=clear_cache,
|
||||
).get()
|
||||
return ProcessResponse(
|
||||
result=response,
|
||||
response, session_id = process_graph_cached(
|
||||
graph_data, inputs, clear_cache, session_id
|
||||
)
|
||||
return ProcessResponse(result=response, session_id=session_id)
|
||||
except Exception as e:
|
||||
# Log stack trace
|
||||
logger.exception(e)
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ class ProcessResponse(BaseModel):
|
|||
"""Process response schema."""
|
||||
|
||||
result: dict
|
||||
session_id: Optional[str] = None
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from typing import Any, Dict, Tuple
|
||||
from langflow.services.cache.utils import memoize_dict
|
||||
from langflow.graph import Graph
|
||||
from langflow.utils.logger import logger
|
||||
|
|
@ -15,7 +16,7 @@ def build_langchain_object_with_caching(data_graph):
|
|||
|
||||
|
||||
@memoize_dict(maxsize=10)
|
||||
def build_sorted_vertices_with_caching(data_graph):
|
||||
def build_sorted_vertices_with_caching(data_graph) -> Tuple[Any, Dict]:
|
||||
"""
|
||||
Build langchain object from data_graph.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -85,39 +85,55 @@ def get_input_str_if_only_one_input(inputs: dict) -> Optional[str]:
|
|||
return list(inputs.values())[0] if len(inputs) == 1 else None
|
||||
|
||||
|
||||
def process_graph_cached(
|
||||
data_graph: Dict[str, Any], inputs: Optional[dict] = None, clear_cache=False
|
||||
):
|
||||
"""
|
||||
Process graph by extracting input variables and replacing ZeroShotPrompt
|
||||
with PromptTemplate,then run the graph and return the result and thought.
|
||||
"""
|
||||
# Load langchain object
|
||||
def get_build_result(data_graph, session_id):
|
||||
# If session_id is provided, load the langchain_object from the session
|
||||
# using build_sorted_vertices_with_caching.get_result_by_session_id
|
||||
# if it returns something different than None, return it
|
||||
# otherwise, build the graph and return the result
|
||||
if session_id:
|
||||
logger.debug(f"Loading LangChain object from session {session_id}")
|
||||
result = build_sorted_vertices_with_caching.get_result_by_session_id(session_id)
|
||||
if result is not None:
|
||||
logger.debug("Loaded LangChain object")
|
||||
return result
|
||||
|
||||
logger.debug("Building langchain object")
|
||||
return build_sorted_vertices_with_caching(data_graph)
|
||||
|
||||
|
||||
def clear_caches_if_needed(clear_cache: bool):
|
||||
if clear_cache:
|
||||
build_sorted_vertices_with_caching.clear_cache()
|
||||
logger.debug("Cleared cache")
|
||||
langchain_object, artifacts = build_sorted_vertices_with_caching(data_graph)
|
||||
logger.debug("Loaded LangChain object")
|
||||
if inputs is None:
|
||||
inputs = {}
|
||||
|
||||
# Add artifacts to inputs
|
||||
# artifacts can be documents loaded when building
|
||||
# the flow
|
||||
for (
|
||||
key,
|
||||
value,
|
||||
) in artifacts.items():
|
||||
if key not in inputs or not inputs[key]:
|
||||
inputs[key] = value
|
||||
|
||||
def load_langchain_object(
|
||||
data_graph: Dict[str, Any], session_id: str
|
||||
) -> Tuple[Union[Chain, VectorStore], Dict[str, Any], str]:
|
||||
langchain_object, artifacts = get_build_result(data_graph, session_id)
|
||||
session_id = build_sorted_vertices_with_caching.hash
|
||||
logger.debug("Loaded LangChain object")
|
||||
|
||||
if langchain_object is None:
|
||||
# Raise user facing error
|
||||
raise ValueError(
|
||||
"There was an error loading the langchain_object. Please, check all the nodes and try again."
|
||||
)
|
||||
|
||||
# Generate result and thought
|
||||
return langchain_object, artifacts, session_id
|
||||
|
||||
|
||||
def process_inputs(inputs: Optional[dict], artifacts: Dict[str, Any]) -> dict:
|
||||
if inputs is None:
|
||||
inputs = {}
|
||||
|
||||
for key, value in artifacts.items():
|
||||
if key not in inputs or not inputs[key]:
|
||||
inputs[key] = value
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
def generate_result(langchain_object: Union[Chain, VectorStore], inputs: dict):
|
||||
if isinstance(langchain_object, Chain):
|
||||
if inputs is None:
|
||||
raise ValueError("Inputs must be provided for a Chain")
|
||||
|
|
@ -130,9 +146,28 @@ def process_graph_cached(
|
|||
raise ValueError(
|
||||
f"Unknown langchain_object type: {type(langchain_object).__name__}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def process_graph_cached(
|
||||
data_graph: Dict[str, Any],
|
||||
inputs: Optional[dict] = None,
|
||||
clear_cache=False,
|
||||
session_id=None,
|
||||
) -> Tuple[Any, str]:
|
||||
clear_caches_if_needed(clear_cache)
|
||||
# If session_id is provided, load the langchain_object from the session
|
||||
# else build the graph and return the result and the new session_id
|
||||
langchain_object, artifacts, session_id = load_langchain_object(
|
||||
data_graph, session_id
|
||||
)
|
||||
processed_inputs = process_inputs(inputs, artifacts)
|
||||
result = generate_result(langchain_object, processed_inputs)
|
||||
|
||||
return result, session_id
|
||||
|
||||
|
||||
def load_flow_from_json(
|
||||
flow: Union[Path, str, dict], tweaks: Optional[dict] = None, build=True
|
||||
):
|
||||
|
|
|
|||
16
src/backend/langflow/services/cache/utils.py
vendored
16
src/backend/langflow/services/cache/utils.py
vendored
|
|
@ -30,6 +30,7 @@ def create_cache_folder(func):
|
|||
|
||||
def memoize_dict(maxsize=128):
|
||||
cache = OrderedDict()
|
||||
hash_to_key = {} # Mapping from hash to cache key
|
||||
|
||||
def decorator(func):
|
||||
@functools.wraps(func)
|
||||
|
|
@ -39,16 +40,29 @@ def memoize_dict(maxsize=128):
|
|||
if key not in cache:
|
||||
result = func(*args, **kwargs)
|
||||
cache[key] = result
|
||||
hash_to_key[hashed] = key # Store the mapping
|
||||
if len(cache) > maxsize:
|
||||
cache.popitem(last=False)
|
||||
oldest_key = next(iter(cache))
|
||||
oldest_hash = oldest_key[1]
|
||||
del cache[oldest_key]
|
||||
del hash_to_key[oldest_hash]
|
||||
else:
|
||||
result = cache[key]
|
||||
|
||||
wrapper.session_id = hashed # Store hash in the wrapper
|
||||
return result
|
||||
|
||||
def clear_cache():
|
||||
cache.clear()
|
||||
hash_to_key.clear()
|
||||
|
||||
def get_result_by_session_id(session_id):
|
||||
key = hash_to_key.get(session_id)
|
||||
return cache.get(key) if key is not None else None
|
||||
|
||||
wrapper.clear_cache = clear_cache # type: ignore
|
||||
wrapper.get_result_by_session_id = get_result_by_session_id # type: ignore
|
||||
wrapper.hash = None
|
||||
wrapper.cache = cache # type: ignore
|
||||
return wrapper
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from langflow.processing.process import process_tweaks
|
||||
from langflow.interface.run import build_sorted_vertices_with_caching
|
||||
from langflow.processing.process import load_langchain_object, process_tweaks
|
||||
|
||||
|
||||
def test_no_tweaks():
|
||||
|
|
@ -194,3 +195,54 @@ def test_tweak_not_in_template():
|
|||
tweaks = {"node1": {"param3": 5}}
|
||||
result = process_tweaks(graph_data, tweaks)
|
||||
assert result == graph_data
|
||||
|
||||
|
||||
def test_load_langchain_object_with_cached_session(client, basic_graph_data):
|
||||
# Build the langchain_object once and get the session_id
|
||||
langchain_object1, artifacts1, session_id1 = load_langchain_object(
|
||||
basic_graph_data, None
|
||||
)
|
||||
# Use the same session_id to get the langchain_object again
|
||||
langchain_object2, artifacts2, session_id2 = load_langchain_object(
|
||||
basic_graph_data, session_id1
|
||||
)
|
||||
|
||||
assert session_id1 == session_id2
|
||||
assert id(langchain_object1) == id(langchain_object2)
|
||||
assert artifacts1 == artifacts2
|
||||
|
||||
|
||||
def test_load_langchain_object_with_no_cached_session(client, basic_graph_data):
|
||||
# Provide a non-existent session_id
|
||||
langchain_object1, artifacts1, session_id1 = load_langchain_object(
|
||||
basic_graph_data, "non_existent_session"
|
||||
)
|
||||
# Clear the cache
|
||||
build_sorted_vertices_with_caching.clear_cache()
|
||||
# Use the new session_id to get the langchain_object again
|
||||
langchain_object2, artifacts2, session_id2 = load_langchain_object(
|
||||
basic_graph_data, session_id1
|
||||
)
|
||||
|
||||
assert session_id1 == session_id2
|
||||
assert id(langchain_object1) != id(
|
||||
langchain_object2
|
||||
) # Since the cache was cleared, objects should be different
|
||||
|
||||
|
||||
def test_load_langchain_object_without_session_id(client, basic_graph_data):
|
||||
# Build the langchain_object without providing a session_id
|
||||
langchain_object1, artifacts1, session_id1 = load_langchain_object(
|
||||
basic_graph_data, None
|
||||
)
|
||||
# Build the langchain_object again without providing a session_id
|
||||
langchain_object2, artifacts2, session_id2 = load_langchain_object(
|
||||
basic_graph_data, None
|
||||
)
|
||||
|
||||
assert session_id1 == session_id2
|
||||
|
||||
assert id(langchain_object1) == id(
|
||||
langchain_object2
|
||||
) # Since no session_id was provided, the hash will be based on the graph_data
|
||||
assert artifacts1 == artifacts2
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue