Merge remote-tracking branch 'origin/dev' into celery

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-08-16 21:49:42 -03:00
commit eb17233cde
8 changed files with 174 additions and 38 deletions

38
poetry.lock generated
View file

@ -7144,6 +7144,17 @@ files = [
{file = "types_cachetools-5.3.0.6-py3-none-any.whl", hash = "sha256:f7f8a25bfe306f2e6bc2ad0a2f949d9e72f2d91036d509c36d3810bf728bc6e1"},
]
[[package]]
name = "types-passlib"
version = "1.7.7.13"
description = "Typing stubs for passlib"
optional = false
python-versions = "*"
files = [
{file = "types-passlib-1.7.7.13.tar.gz", hash = "sha256:f152639f1f2103d7f59a56e2aec5f9398a75a80830991d0d68aac5c2b9c32a77"},
{file = "types_passlib-1.7.7.13-py3-none-any.whl", hash = "sha256:414b5ee9c88313357c9261cfcf816509b1e8e4673f0796bd61e9ef249f6fe076"},
]
[[package]]
name = "types-pillow"
version = "9.5.0.6"
@ -7155,6 +7166,31 @@ files = [
{file = "types_Pillow-9.5.0.6-py3-none-any.whl", hash = "sha256:1d238abaa9d529b04941d805b7f4d3f7df30702bb14521ec507617f117406fb4"},
]
[[package]]
name = "types-pyasn1"
version = "0.4.0.6"
description = "Typing stubs for pyasn1"
optional = false
python-versions = "*"
files = [
{file = "types-pyasn1-0.4.0.6.tar.gz", hash = "sha256:8f1965d0b79152f9d1efc89f9aa9a8cdda7cd28b2619df6737c095cbedeff98b"},
{file = "types_pyasn1-0.4.0.6-py3-none-any.whl", hash = "sha256:dd5fc818864e63a66cd714be0a7a59a493f4a81b87ee9ac978c41f1eaa9a0cef"},
]
[[package]]
name = "types-python-jose"
version = "3.3.4.8"
description = "Typing stubs for python-jose"
optional = false
python-versions = "*"
files = [
{file = "types-python-jose-3.3.4.8.tar.gz", hash = "sha256:3c316675c3cee059ccb9aff87358254344915239fa7f19cee2787155a7db14ac"},
{file = "types_python_jose-3.3.4.8-py3-none-any.whl", hash = "sha256:95592273443b45dc5cc88f7c56aa5a97725428753fb738b794e63ccb4904954e"},
]
[package.dependencies]
types-pyasn1 = "*"
[[package]]
name = "types-pytz"
version = "2023.3.0.1"
@ -7909,4 +7945,4 @@ local = ["ctransformers", "llama-cpp-python", "sentence-transformers"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<3.11"
content-hash = "c1581fd2c650cc1365793ca33278917dd86bb4527cc956106a82a5c513ecc483"
content-hash = "8ad605e7ea30f2819dbc03eac6c2e67576a98d1efa4890912414a7568fc27441"

View file

@ -101,6 +101,8 @@ pandas-stubs = "^2.0.0.230412"
types-pillow = "^9.5.0.2"
types-appdirs = "^1.4.3.5"
types-pyyaml = "^6.0.12.8"
types-python-jose = "^3.3.4.8"
types-passlib = "^1.7.7.13"
[tool.poetry.extras]

View file

@ -1,12 +1,11 @@
from http import HTTPStatus
from typing import Annotated, Optional
from typing import Annotated, Optional, Union
from langflow.services.cache.utils import save_uploaded_file
from langflow.services.database.models.flow import Flow
from langflow.processing.process import process_tweaks
from langflow.processing.process import process_graph_cached, process_tweaks
from langflow.services.utils import get_settings_manager
from langflow.utils.logger import logger
from langflow.worker import process_graph_cached as process_graph_cached_worker
from fastapi import APIRouter, Depends, HTTPException, UploadFile, Body
from langflow.interface.custom.custom_component import CustomComponent
@ -76,6 +75,7 @@ async def process_flow(
inputs: Optional[dict] = None,
tweaks: Optional[dict] = None,
clear_cache: Annotated[bool, Body(embed=True)] = False, # noqa: F821
session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821
session: Session = Depends(get_session),
):
"""
@ -95,15 +95,10 @@ async def process_flow(
graph_data = process_tweaks(graph_data, tweaks)
except Exception as exc:
logger.error(f"Error processing tweaks: {exc}")
# ! This was added just for testing purposes
response = process_graph_cached_worker.delay(
graph_data=graph_data,
inputs=inputs,
clear_cache=clear_cache,
).get()
return ProcessResponse(
result=response,
response, session_id = process_graph_cached(
graph_data, inputs, clear_cache, session_id
)
return ProcessResponse(result=response, session_id=session_id)
except Exception as e:
# Log stack trace
logger.exception(e)

View file

@ -47,6 +47,7 @@ class ProcessResponse(BaseModel):
"""Process response schema."""
result: dict
session_id: Optional[str] = None
class ChatMessage(BaseModel):

View file

@ -1,3 +1,4 @@
from typing import Any, Dict, Tuple
from langflow.services.cache.utils import memoize_dict
from langflow.graph import Graph
from langflow.utils.logger import logger
@ -15,7 +16,7 @@ def build_langchain_object_with_caching(data_graph):
@memoize_dict(maxsize=10)
def build_sorted_vertices_with_caching(data_graph):
def build_sorted_vertices_with_caching(data_graph) -> Tuple[Any, Dict]:
"""
Build langchain object from data_graph.
"""

View file

@ -85,39 +85,55 @@ def get_input_str_if_only_one_input(inputs: dict) -> Optional[str]:
return list(inputs.values())[0] if len(inputs) == 1 else None
def process_graph_cached(
data_graph: Dict[str, Any], inputs: Optional[dict] = None, clear_cache=False
):
"""
Process graph by extracting input variables and replacing ZeroShotPrompt
with PromptTemplate,then run the graph and return the result and thought.
"""
# Load langchain object
def get_build_result(data_graph, session_id):
    """Return the (langchain_object, artifacts) pair for *data_graph*.

    When *session_id* is truthy and the memoized builder still holds a
    result for it, that cached pair is returned; otherwise the graph is
    (re)built through ``build_sorted_vertices_with_caching``.
    """
    if not session_id:
        logger.debug("Building langchain object")
        return build_sorted_vertices_with_caching(data_graph)
    logger.debug(f"Loading LangChain object from session {session_id}")
    cached = build_sorted_vertices_with_caching.get_result_by_session_id(session_id)
    if cached is None:
        # Session expired or never existed — fall back to a fresh build.
        logger.debug("Building langchain object")
        return build_sorted_vertices_with_caching(data_graph)
    logger.debug("Loaded LangChain object")
    return cached
def clear_caches_if_needed(clear_cache: bool):
    """Drop every memoized build result when *clear_cache* is set."""
    if not clear_cache:
        return
    build_sorted_vertices_with_caching.clear_cache()
    logger.debug("Cleared cache")
langchain_object, artifacts = build_sorted_vertices_with_caching(data_graph)
logger.debug("Loaded LangChain object")
if inputs is None:
inputs = {}
# Add artifacts to inputs
# artifacts can be documents loaded when building
# the flow
for (
key,
value,
) in artifacts.items():
if key not in inputs or not inputs[key]:
inputs[key] = value
def load_langchain_object(
    data_graph: Dict[str, Any], session_id: str
) -> Tuple[Union[Chain, VectorStore], Dict[str, Any], str]:
    """Load (or build) the LangChain object for *data_graph*.

    Returns the langchain object, the artifacts produced while building
    the flow, and the session id (the cache hash) under which the build
    result is stored.

    Raises:
        ValueError: user-facing error when the build produced no object.
    """
    langchain_object, artifacts = get_build_result(data_graph, session_id)
    # NOTE(review): the visible memoize_dict wrapper stores the current hash
    # on ``wrapper.session_id`` while initializing ``wrapper.hash = None`` —
    # confirm which attribute actually carries the session hash here.
    session_id = build_sorted_vertices_with_caching.hash
    logger.debug("Loaded LangChain object")
    if langchain_object is None:
        # Raise user facing error
        raise ValueError(
            "There was an error loading the langchain_object. Please, check all the nodes and try again."
        )
    return langchain_object, artifacts, session_id
def process_inputs(inputs: Optional[dict], artifacts: Dict[str, Any]) -> dict:
    """Merge build artifacts into the run inputs.

    Artifacts (e.g. documents loaded while building the flow) fill in any
    input key that is missing or falsy; explicit non-empty inputs win.

    Args:
        inputs: caller-supplied inputs, or None for an empty mapping.
        artifacts: values produced while building the flow.

    Returns:
        A new dict — the caller's ``inputs`` mapping is never mutated
        (the original implementation modified it in place).
    """
    merged = dict(inputs) if inputs else {}
    for key, value in artifacts.items():
        # dict.get(...) is falsy both for missing keys and empty values,
        # matching the original `key not in inputs or not inputs[key]` test.
        if not merged.get(key):
            merged[key] = value
    return merged
def generate_result(langchain_object: Union[Chain, VectorStore], inputs: dict):
if isinstance(langchain_object, Chain):
if inputs is None:
raise ValueError("Inputs must be provided for a Chain")
@ -130,9 +146,28 @@ def process_graph_cached(
raise ValueError(
f"Unknown langchain_object type: {type(langchain_object).__name__}"
)
return result
def process_graph_cached(
    data_graph: Dict[str, Any],
    inputs: Optional[dict] = None,
    clear_cache=False,
    session_id=None,
) -> Tuple[Any, str]:
    """Run a flow graph, reusing the cached LangChain object when possible.

    Returns the generated result together with the session id under which
    the built object is cached.
    """
    clear_caches_if_needed(clear_cache)
    # Either reuse the object cached for session_id or build it fresh;
    # the returned session_id identifies the (possibly new) cache entry.
    chain, artifacts, session_id = load_langchain_object(data_graph, session_id)
    merged_inputs = process_inputs(inputs, artifacts)
    return generate_result(chain, merged_inputs), session_id
def load_flow_from_json(
flow: Union[Path, str, dict], tweaks: Optional[dict] = None, build=True
):

View file

@ -30,6 +30,7 @@ def create_cache_folder(func):
def memoize_dict(maxsize=128):
cache = OrderedDict()
hash_to_key = {} # Mapping from hash to cache key
def decorator(func):
@functools.wraps(func)
@ -39,16 +40,29 @@ def memoize_dict(maxsize=128):
if key not in cache:
result = func(*args, **kwargs)
cache[key] = result
hash_to_key[hashed] = key # Store the mapping
if len(cache) > maxsize:
cache.popitem(last=False)
oldest_key = next(iter(cache))
oldest_hash = oldest_key[1]
del cache[oldest_key]
del hash_to_key[oldest_hash]
else:
result = cache[key]
wrapper.session_id = hashed # Store hash in the wrapper
return result
def clear_cache():
cache.clear()
hash_to_key.clear()
def get_result_by_session_id(session_id):
key = hash_to_key.get(session_id)
return cache.get(key) if key is not None else None
wrapper.clear_cache = clear_cache # type: ignore
wrapper.get_result_by_session_id = get_result_by_session_id # type: ignore
wrapper.hash = None
wrapper.cache = cache # type: ignore
return wrapper

View file

@ -1,4 +1,5 @@
from langflow.processing.process import process_tweaks
from langflow.interface.run import build_sorted_vertices_with_caching
from langflow.processing.process import load_langchain_object, process_tweaks
def test_no_tweaks():
@ -194,3 +195,54 @@ def test_tweak_not_in_template():
tweaks = {"node1": {"param3": 5}}
result = process_tweaks(graph_data, tweaks)
assert result == graph_data
def test_load_langchain_object_with_cached_session(client, basic_graph_data):
    """A second load with the returned session_id must hit the cache."""
    first_object, first_artifacts, first_session = load_langchain_object(
        basic_graph_data, None
    )
    # Loading again under the same session must return the cached objects.
    second_object, second_artifacts, second_session = load_langchain_object(
        basic_graph_data, first_session
    )
    assert first_session == second_session
    assert first_object is second_object
    assert first_artifacts == second_artifacts
def test_load_langchain_object_with_no_cached_session(client, basic_graph_data):
    """After the cache is cleared, a rebuild must yield a fresh object."""
    first_object, _first_artifacts, first_session = load_langchain_object(
        basic_graph_data, "non_existent_session"
    )
    # Drop every cached build result.
    build_sorted_vertices_with_caching.clear_cache()
    second_object, _second_artifacts, second_session = load_langchain_object(
        basic_graph_data, first_session
    )
    assert first_session == second_session
    # The cache was cleared, so the object had to be rebuilt.
    assert first_object is not second_object
def test_load_langchain_object_without_session_id(client, basic_graph_data):
    """Two sessionless loads of the same graph share one cache entry."""
    first_object, first_artifacts, first_session = load_langchain_object(
        basic_graph_data, None
    )
    second_object, second_artifacts, second_session = load_langchain_object(
        basic_graph_data, None
    )
    # The session hash is derived from graph_data, so both loads match.
    assert first_session == second_session
    assert first_object is second_object
    assert first_artifacts == second_artifacts