diff --git a/poetry.lock b/poetry.lock index f1ff37f2b..4fecda219 100644 --- a/poetry.lock +++ b/poetry.lock @@ -7144,6 +7144,17 @@ files = [ {file = "types_cachetools-5.3.0.6-py3-none-any.whl", hash = "sha256:f7f8a25bfe306f2e6bc2ad0a2f949d9e72f2d91036d509c36d3810bf728bc6e1"}, ] +[[package]] +name = "types-passlib" +version = "1.7.7.13" +description = "Typing stubs for passlib" +optional = false +python-versions = "*" +files = [ + {file = "types-passlib-1.7.7.13.tar.gz", hash = "sha256:f152639f1f2103d7f59a56e2aec5f9398a75a80830991d0d68aac5c2b9c32a77"}, + {file = "types_passlib-1.7.7.13-py3-none-any.whl", hash = "sha256:414b5ee9c88313357c9261cfcf816509b1e8e4673f0796bd61e9ef249f6fe076"}, +] + [[package]] name = "types-pillow" version = "9.5.0.6" @@ -7155,6 +7166,31 @@ files = [ {file = "types_Pillow-9.5.0.6-py3-none-any.whl", hash = "sha256:1d238abaa9d529b04941d805b7f4d3f7df30702bb14521ec507617f117406fb4"}, ] +[[package]] +name = "types-pyasn1" +version = "0.4.0.6" +description = "Typing stubs for pyasn1" +optional = false +python-versions = "*" +files = [ + {file = "types-pyasn1-0.4.0.6.tar.gz", hash = "sha256:8f1965d0b79152f9d1efc89f9aa9a8cdda7cd28b2619df6737c095cbedeff98b"}, + {file = "types_pyasn1-0.4.0.6-py3-none-any.whl", hash = "sha256:dd5fc818864e63a66cd714be0a7a59a493f4a81b87ee9ac978c41f1eaa9a0cef"}, +] + +[[package]] +name = "types-python-jose" +version = "3.3.4.8" +description = "Typing stubs for python-jose" +optional = false +python-versions = "*" +files = [ + {file = "types-python-jose-3.3.4.8.tar.gz", hash = "sha256:3c316675c3cee059ccb9aff87358254344915239fa7f19cee2787155a7db14ac"}, + {file = "types_python_jose-3.3.4.8-py3-none-any.whl", hash = "sha256:95592273443b45dc5cc88f7c56aa5a97725428753fb738b794e63ccb4904954e"}, +] + +[package.dependencies] +types-pyasn1 = "*" + [[package]] name = "types-pytz" version = "2023.3.0.1" @@ -7909,4 +7945,4 @@ local = ["ctransformers", "llama-cpp-python", "sentence-transformers"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.11" -content-hash = "c1581fd2c650cc1365793ca33278917dd86bb4527cc956106a82a5c513ecc483" +content-hash = "8ad605e7ea30f2819dbc03eac6c2e67576a98d1efa4890912414a7568fc27441" diff --git a/pyproject.toml b/pyproject.toml index 282b17e61..4e7a34ac0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,6 +101,8 @@ pandas-stubs = "^2.0.0.230412" types-pillow = "^9.5.0.2" types-appdirs = "^1.4.3.5" types-pyyaml = "^6.0.12.8" +types-python-jose = "^3.3.4.8" +types-passlib = "^1.7.7.13" [tool.poetry.extras] diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index f531fea7a..a6b85dde2 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -1,12 +1,11 @@ from http import HTTPStatus -from typing import Annotated, Optional +from typing import Annotated, Optional, Union from langflow.services.cache.utils import save_uploaded_file from langflow.services.database.models.flow import Flow -from langflow.processing.process import process_tweaks +from langflow.processing.process import process_graph_cached, process_tweaks from langflow.services.utils import get_settings_manager from langflow.utils.logger import logger -from langflow.worker import process_graph_cached as process_graph_cached_worker from fastapi import APIRouter, Depends, HTTPException, UploadFile, Body from langflow.interface.custom.custom_component import CustomComponent @@ -76,6 +75,7 @@ async def process_flow( inputs: Optional[dict] = None, tweaks: Optional[dict] = None, clear_cache: Annotated[bool, Body(embed=True)] = False, # noqa: F821 + session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821 session: Session = Depends(get_session), ): """ @@ -95,15 +95,10 @@ async def process_flow( graph_data = process_tweaks(graph_data, tweaks) except Exception as exc: logger.error(f"Error processing tweaks: {exc}") - # ! This was added just for testing purposes - response = process_graph_cached_worker.delay( - graph_data=graph_data, - inputs=inputs, - clear_cache=clear_cache, - ).get() - return ProcessResponse( - result=response, + response, session_id = process_graph_cached( + graph_data, inputs, clear_cache, session_id ) + return ProcessResponse(result=response, session_id=session_id) except Exception as e: # Log stack trace logger.exception(e) diff --git a/src/backend/langflow/api/v1/schemas.py b/src/backend/langflow/api/v1/schemas.py index 776e90034..65bf64dca 100644 --- a/src/backend/langflow/api/v1/schemas.py +++ b/src/backend/langflow/api/v1/schemas.py @@ -47,6 +47,7 @@ class ProcessResponse(BaseModel): """Process response schema.""" result: dict + session_id: Optional[str] = None class ChatMessage(BaseModel): diff --git a/src/backend/langflow/interface/run.py b/src/backend/langflow/interface/run.py index cb0573bf7..42cea0e98 100644 --- a/src/backend/langflow/interface/run.py +++ b/src/backend/langflow/interface/run.py @@ -1,3 +1,4 @@ +from typing import Any, Dict, Tuple from langflow.services.cache.utils import memoize_dict from langflow.graph import Graph from langflow.utils.logger import logger @@ -15,7 +16,7 @@ def build_langchain_object_with_caching(data_graph): @memoize_dict(maxsize=10) -def build_sorted_vertices_with_caching(data_graph): +def build_sorted_vertices_with_caching(data_graph) -> Tuple[Any, Dict]: """ Build langchain object from data_graph. """ diff --git a/src/backend/langflow/processing/process.py b/src/backend/langflow/processing/process.py index 8cefb1f44..396135e16 100644 --- a/src/backend/langflow/processing/process.py +++ b/src/backend/langflow/processing/process.py @@ -85,39 +85,55 @@ def get_input_str_if_only_one_input(inputs: dict) -> Optional[str]: return list(inputs.values())[0] if len(inputs) == 1 else None -def process_graph_cached( - data_graph: Dict[str, Any], inputs: Optional[dict] = None, clear_cache=False -): - """ - Process graph by extracting input variables and replacing ZeroShotPrompt - with PromptTemplate,then run the graph and return the result and thought. - """ - # Load langchain object +def get_build_result(data_graph, session_id): + # If session_id is provided, load the langchain_object from the session + # using build_sorted_vertices_with_caching.get_result_by_session_id + # if it returns something different than None, return it + # otherwise, build the graph and return the result + if session_id: + logger.debug(f"Loading LangChain object from session {session_id}") + result = build_sorted_vertices_with_caching.get_result_by_session_id(session_id) + if result is not None: + logger.debug("Loaded LangChain object") + return result + + logger.debug("Building langchain object") + return build_sorted_vertices_with_caching(data_graph) + + +def clear_caches_if_needed(clear_cache: bool): if clear_cache: build_sorted_vertices_with_caching.clear_cache() logger.debug("Cleared cache") - langchain_object, artifacts = build_sorted_vertices_with_caching(data_graph) - logger.debug("Loaded LangChain object") - if inputs is None: - inputs = {} - # Add artifacts to inputs - # artifacts can be documents loaded when building - # the flow - for ( - key, - value, - ) in artifacts.items(): - if key not in inputs or not inputs[key]: - inputs[key] = value + +def load_langchain_object( + data_graph: Dict[str, Any], session_id: str +) -> Tuple[Union[Chain, VectorStore], Dict[str, Any], str]: + langchain_object, artifacts = get_build_result(data_graph, session_id) + session_id = build_sorted_vertices_with_caching.hash + logger.debug("Loaded LangChain object") if langchain_object is None: - # Raise user facing error raise ValueError( "There was an error loading the langchain_object. Please, check all the nodes and try again." ) - # Generate result and thought + return langchain_object, artifacts, session_id + + +def process_inputs(inputs: Optional[dict], artifacts: Dict[str, Any]) -> dict: + if inputs is None: + inputs = {} + + for key, value in artifacts.items(): + if key not in inputs or not inputs[key]: + inputs[key] = value + + return inputs + + +def generate_result(langchain_object: Union[Chain, VectorStore], inputs: dict): if isinstance(langchain_object, Chain): if inputs is None: raise ValueError("Inputs must be provided for a Chain") @@ -130,9 +146,28 @@ def process_graph_cached( raise ValueError( f"Unknown langchain_object type: {type(langchain_object).__name__}" ) + return result +def process_graph_cached( + data_graph: Dict[str, Any], + inputs: Optional[dict] = None, + clear_cache=False, + session_id=None, +) -> Tuple[Any, str]: + clear_caches_if_needed(clear_cache) + # If session_id is provided, load the langchain_object from the session + # else build the graph and return the result and the new session_id + langchain_object, artifacts, session_id = load_langchain_object( + data_graph, session_id + ) + processed_inputs = process_inputs(inputs, artifacts) + result = generate_result(langchain_object, processed_inputs) + + return result, session_id + + def load_flow_from_json( flow: Union[Path, str, dict], tweaks: Optional[dict] = None, build=True ): diff --git a/src/backend/langflow/services/cache/utils.py b/src/backend/langflow/services/cache/utils.py index 3deabe9f4..2333eb5f4 100644 --- a/src/backend/langflow/services/cache/utils.py +++ b/src/backend/langflow/services/cache/utils.py @@ -30,6 +30,7 @@ def create_cache_folder(func): def memoize_dict(maxsize=128): cache = OrderedDict() + hash_to_key = {} # Mapping from hash to cache key def decorator(func): @functools.wraps(func) @@ -39,16 +40,29 @@ def memoize_dict(maxsize=128): if key not in cache: result = func(*args, **kwargs) cache[key] = result + hash_to_key[hashed] = key # Store the mapping if len(cache) > maxsize: - cache.popitem(last=False) + oldest_key = next(iter(cache)) + oldest_hash = oldest_key[1] + del cache[oldest_key] + del hash_to_key[oldest_hash] else: result = cache[key] + + wrapper.session_id = hashed # Store hash in the wrapper return result def clear_cache(): cache.clear() + hash_to_key.clear() + + def get_result_by_session_id(session_id): + key = hash_to_key.get(session_id) + return cache.get(key) if key is not None else None wrapper.clear_cache = clear_cache # type: ignore + wrapper.get_result_by_session_id = get_result_by_session_id # type: ignore + wrapper.hash = None wrapper.cache = cache # type: ignore return wrapper diff --git a/tests/test_process.py b/tests/test_process.py index 2d0c349ce..a0d91b5df 100644 --- a/tests/test_process.py +++ b/tests/test_process.py @@ -1,4 +1,5 @@ -from langflow.processing.process import process_tweaks +from langflow.interface.run import build_sorted_vertices_with_caching +from langflow.processing.process import load_langchain_object, process_tweaks def test_no_tweaks(): @@ -194,3 +195,54 @@ def test_tweak_not_in_template(): tweaks = {"node1": {"param3": 5}} result = process_tweaks(graph_data, tweaks) assert result == graph_data + + +def test_load_langchain_object_with_cached_session(client, basic_graph_data): + # Build the langchain_object once and get the session_id + langchain_object1, artifacts1, session_id1 = load_langchain_object( + basic_graph_data, None + ) + # Use the same session_id to get the langchain_object again + langchain_object2, artifacts2, session_id2 = load_langchain_object( + basic_graph_data, session_id1 + ) + + assert session_id1 == session_id2 + assert id(langchain_object1) == id(langchain_object2) + assert artifacts1 == artifacts2 + + +def test_load_langchain_object_with_no_cached_session(client, basic_graph_data): + # Provide a non-existent session_id + langchain_object1, artifacts1, session_id1 = load_langchain_object( + basic_graph_data, "non_existent_session" + ) + # Clear the cache + build_sorted_vertices_with_caching.clear_cache() + # Use the new session_id to get the langchain_object again + langchain_object2, artifacts2, session_id2 = load_langchain_object( + basic_graph_data, session_id1 + ) + + assert session_id1 == session_id2 + assert id(langchain_object1) != id( + langchain_object2 + ) # Since the cache was cleared, objects should be different + + +def test_load_langchain_object_without_session_id(client, basic_graph_data): + # Build the langchain_object without providing a session_id + langchain_object1, artifacts1, session_id1 = load_langchain_object( + basic_graph_data, None + ) + # Build the langchain_object again without providing a session_id + langchain_object2, artifacts2, session_id2 = load_langchain_object( + basic_graph_data, None + ) + + assert session_id1 == session_id2 + + assert id(langchain_object1) == id( + langchain_object2 + ) # Since no session_id was provided, the hash will be based on the graph_data + assert artifacts1 == artifacts2