Refactor cache service and fix async issues (#1512)

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-09 22:55:56 -03:00 committed by GitHub
commit 67bccdc753
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
24 changed files with 466 additions and 171 deletions

View file

@ -197,7 +197,7 @@ def format_elapsed_time(elapsed_time: float) -> str:
return f"{minutes} {minutes_unit}, {seconds} {seconds_unit}"
def build_and_cache_graph(
async def build_and_cache_graph(
flow_id: str,
session: Session,
chat_service: "ChatService",
@ -212,7 +212,7 @@ def build_and_cache_graph(
graph = other_graph
else:
graph = graph.update(other_graph)
chat_service.set_cache(flow_id, graph)
await chat_service.set_cache(flow_id, graph)
return graph

View file

@ -58,9 +58,9 @@ async def get_vertices(
try:
# First, we need to check if the flow_id is in the cache
graph = None
if cache := chat_service.get_cache(flow_id):
if cache := await chat_service.get_cache(flow_id):
graph = cache.get("result")
graph = build_and_cache_graph(flow_id, session, chat_service, graph)
graph = await build_and_cache_graph(flow_id, session, chat_service, graph)
if stop_component_id or start_component_id:
try:
vertices = graph.sort_vertices(stop_component_id, start_component_id)
@ -98,11 +98,11 @@ async def build_vertex(
next_vertices_ids = []
try:
start_time = time.perf_counter()
cache = chat_service.get_cache(flow_id)
cache = await chat_service.get_cache(flow_id)
if not cache:
# If there's no cache
logger.warning(f"No cache found for {flow_id}. Building graph starting at {vertex_id}")
graph = build_and_cache_graph(flow_id=flow_id, session=next(get_session()), chat_service=chat_service)
graph = await build_and_cache_graph(flow_id=flow_id, session=next(get_session()), chat_service=chat_service)
else:
graph = cache.get("result")
result_data_response = ResultDataResponse(results={})
@ -121,8 +121,11 @@ async def build_vertex(
artifacts = vertex.artifacts
else:
raise ValueError(f"No result found for vertex {vertex_id}")
next_vertices_ids = vertex.successors_ids
next_vertices_ids = [v for v in next_vertices_ids if graph.should_run_vertex(v)]
async with chat_service._cache_locks[flow_id] as lock:
graph.remove_from_predecessors(vertex_id)
next_vertices_ids = vertex.successors_ids
next_vertices_ids = [v for v in next_vertices_ids if graph.should_run_vertex(v)]
await chat_service.set_cache(flow_id=flow_id, data=graph, lock=lock)
result_data_response = ResultDataResponse(**result_dict.model_dump())
@ -134,7 +137,7 @@ async def build_vertex(
artifacts = {}
# If there's an error building the vertex
# we need to clear the cache
chat_service.clear_cache(flow_id)
await chat_service.clear_cache(flow_id)
# Log the vertex build
if not vertex.will_stream:
@ -157,7 +160,7 @@ async def build_vertex(
inactivated_vertices = list(graph.inactivated_vertices)
graph.reset_inactivated_vertices()
graph.reset_activated_vertices()
chat_service.set_cache(flow_id, graph)
await chat_service.set_cache(flow_id, graph)
# graph.stop_vertex tells us if the user asked
# to stop the build of the graph at a certain vertex

View file

@ -1,5 +1,7 @@
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.security import OAuth2PasswordRequestForm
from sqlmodel import Session
from langflow.api.v1.schemas import Token
from langflow.services.auth.utils import (
authenticate_user,
@ -8,7 +10,6 @@ from langflow.services.auth.utils import (
create_user_tokens,
)
from langflow.services.deps import get_session, get_settings_service
from sqlmodel import Session
router = APIRouter(tags=["Login"])

View file

@ -104,7 +104,8 @@ def parse_text_file_to_record(file_path: str, silent_errors: bool) -> Optional[R
elif file_path.endswith(".yaml") or file_path.endswith(".yml"):
text = yaml.safe_load(text)
elif file_path.endswith(".xml"):
text = ET.fromstring(text)
xml_element = ET.fromstring(text)
text = ET.tostring(xml_element, encoding="unicode")
except Exception as e:
if not silent_errors:
raise ValueError(f"Error loading file {file_path}: {e}") from e

View file

@ -26,6 +26,7 @@ class ChatComponent(CustomComponent):
"session_id": {
"display_name": "Session ID",
"info": "If provided, the message will be stored in the memory.",
"advanced": True,
},
"return_record": {
"display_name": "Return Record",

View file

@ -0,0 +1,70 @@
# from typing import Dict, List
# import dspy
# from langflow import CustomComponent
# from langflow.field_typing import Text
# class ReActAgentComponent(CustomComponent):
# display_name = "ReAct Agent"
# description = "A component to create a ReAct Agent."
# icon = "user-secret"
# def build_config(self):
# return {
# "input_value": {
# "display_name": "Input",
# "input_types": ["Text"],
# "info": "The input value for the ReAct Agent.",
# },
# "instructions": {
# "display_name": "Instructions",
# "info": "The Prompt.",
# },
# "inputs": {
# "display_name": "Inputs",
# "info": "The Name and Description of the Input Fields.",
# },
# "outputs": {
# "display_name": "Outputs",
# "info": "The Name and Description of the Output Fields.",
# },
# }
# def build(
# self,
# input_value: List[dict],
# instructions: Text,
# inputs: List[dict],
# outputs: List[Dict],
# ) -> Text:
# # inputs is a list of dictionaries where the key is the name of the input
# # and the value is the description of the input
# input_fields = (
# {}
# ) # dict[str, FieldInfo] InputField and OutputField are subclasses of pydantic.Field
# for input_dict in inputs:
# for name, description in input_dict.items():
# prefix = name if ":" in name else f"{name}:"
# input_fields[name] = dspy.InputField(
# prefix=prefix, description=description
# )
# output_fields = {} # dict[str, FieldInfo]
# for output_dict in outputs:
# for name, description in output_dict.items():
# prefix = name if ":" in name else f"{name}:"
# output_fields[name] = dspy.OutputField(
# prefix=prefix, description=description
# )
# signature = dspy.make_signature(inputs, instructions=instructions)
# agent = dspy.ReAct(
# signature=signature,
# )
# inputs_dict = {}
# for input_dict in input_value:
# inputs_dict.update(input_dict)
# result = agent(inputs_dict)

View file

@ -106,7 +106,7 @@ class APIRequest(CustomComponent):
bodies = [body.data]
if len(urls) != len(bodies):
# add bodies with None
bodies += [None] * (len(urls) - len(bodies))
bodies += [None] * (len(urls) - len(bodies)) # type: ignore
async with httpx.AsyncClient() as client:
results = await asyncio.gather(
*[self.make_request(client, method, u, headers, rec, timeout) for u, rec in zip(urls, bodies)]

View file

@ -3,12 +3,10 @@ from .ExtractDataFromRecord import ExtractKeyFromRecordComponent
from .GetNotified import GetNotifiedComponent
from .ListFlows import ListFlowsComponent
from .MergeRecords import MergeRecordsComponent
from .MessageHistory import MessageHistoryComponent
from .Notify import NotifyComponent
from .RunFlow import RunFlowComponent
from .RunnableExecutor import RunnableExecComponent
from .SQLExecutor import SQLExecutorComponent
from .TextToRecord import TextToRecordComponent
__all__ = [
"ClearMessageHistoryComponent",

View file

@ -1,8 +1,10 @@
from .CustomComponent import Component
from .DocumentToRecord import DocumentToRecordComponent
from .IDGenerator import UUIDGeneratorComponent
from .MessageHistory import MessageHistoryComponent
from .PythonFunction import PythonFunctionComponent
from .RecordsAsText import RecordsAsTextComponent
from .TextToRecord import TextToRecordComponent
__all__ = [
"Component",
@ -10,4 +12,6 @@ __all__ = [
"UUIDGeneratorComponent",
"PythonFunctionComponent",
"RecordsAsTextComponent",
"TextToRecordComponent",
"MessageHistoryComponent",
]

View file

@ -1,4 +1,5 @@
from langchain_community.chat_models.cohere import ChatCohere
from pydantic.v1 import SecretStr
from langflow.components.models.base.model import LCModelComponent
from langflow.field_typing import Text
@ -44,8 +45,9 @@ class CohereComponent(LCModelComponent):
temperature: float = 0.75,
stream: bool = False,
) -> Text:
api_key = SecretStr(cohere_api_key)
output = ChatCohere( # type: ignore
cohere_api_key=cohere_api_key,
cohere_api_key=api_key,
temperature=temperature,
)
return self.get_result(output=output, stream=stream, input_value=input_value)

View file

@ -122,7 +122,9 @@ class ContractEdge(Edge):
return
if not source._built:
await source.build()
# The system should be read-only, so we should not be building vertices
# that are not already built.
raise ValueError(f"Source vertex {source.id} is not built.")
if self.matched_type == "Text":
self.result = source._built_result
@ -132,7 +134,7 @@ class ContractEdge(Edge):
target.params[self.target_param] = self.result
self.is_fulfilled = True
async def get_result(self, source: "Vertex", target: "Vertex"):
async def get_result_from_source(self, source: "Vertex", target: "Vertex"):
# Fulfill the contract if it has not been fulfilled.
if not self.is_fulfilled:
await self.honor(source, target)

View file

@ -240,6 +240,7 @@ class Graph:
def build_graph_maps(self):
self.predecessor_map, self.successor_map = self.build_adjacency_maps()
self.in_degree_map = self.build_in_degree()
self.parent_child_map = self.build_parent_child_map()
@ -295,6 +296,15 @@ class Graph:
successor_map[edge.source_id].append(edge.target_id)
return predecessor_map, successor_map
def build_run_map(self):
    """
    Invert ``self.predecessor_map``.

    Returns:
        A ``defaultdict(list)`` that maps each predecessor id to the ids of
        the vertices listing it in ``self.predecessor_map``, i.e.
        ``{vertex_id: [every id whose predecessors contain vertex_id]}``.
    """
    dependents_by_vertex = defaultdict(list)
    # Flatten the map into (predecessor, dependent) pairs, preserving the
    # iteration order of predecessor_map and of each predecessor list.
    pairs = (
        (required_id, dependent_id)
        for dependent_id, required_ids in self.predecessor_map.items()
        for required_id in required_ids
    )
    for required_id, dependent_id in pairs:
        dependents_by_vertex[required_id].append(dependent_id)
    return dependents_by_vertex
@classmethod
def from_payload(cls, payload: Dict, flow_id: Optional[str] = None) -> "Graph":
"""
@ -939,16 +949,37 @@ class Graph:
# save the only the rest
self.vertices_layers = vertices_layers[1:]
self.vertices_to_run = {vertex_id for vertex_id in chain.from_iterable(vertices_layers)}
self.run_map, self.run_predecessors = (
self.build_run_map(),
self.predecessor_map.copy(),
)
# Return just the first layer
return first_layer
def vertex_has_no_more_predecessors(self, vertex_id: str) -> bool:
    """
    Tell whether ``vertex_id`` has no entries left in ``self.run_predecessors``.

    A vertex that is absent from the map, or whose list is empty, counts as
    having no more predecessors.
    """
    remaining = self.run_predecessors.get(vertex_id)
    return not remaining
def should_run_vertex(self, vertex_id: str) -> bool:
"""Returns whether a component should be run."""
should_run = vertex_id in self.vertices_to_run
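# self.run_predecessors maps each vertex_id to the predecessors that have not run yet
# (self.run_map is the inverse: vertex_id -> the vertices that list it as a predecessor).
# Each time a vertex is run, we remove it from the predecessor lists of its dependents;
# once a vertex has no predecessors left, it should be run.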
should_run = vertex_id in self.vertices_to_run and self.vertex_has_no_more_predecessors(vertex_id)
if should_run:
self.vertices_to_run.remove(vertex_id)
# remove the vertex from the run_map
self.remove_from_predecessors(vertex_id)
return should_run
def remove_from_predecessors(self, vertex_id: str):
    """
    Drop ``vertex_id`` from the pending-predecessor list of each dependent.

    ``self.run_map`` gives the vertices that list ``vertex_id`` as a
    predecessor; for each of them, ``vertex_id`` is removed (once) from its
    list in ``self.run_predecessors`` if present.
    """
    for dependent_id in self.run_map.get(vertex_id, []):
        pending = self.run_predecessors[dependent_id]
        if vertex_id in pending:
            pending.remove(vertex_id)
def sort_interface_components_first(self, vertices_layers: List[List[str]]) -> List[List[str]]:
"""Sorts the vertices in the graph so that vertices containing ChatInput or ChatOutput come first."""

View file

@ -1,4 +1,5 @@
import ast
import asyncio
import inspect
import types
from enum import Enum
@ -7,7 +8,6 @@ from typing import (
Any,
AsyncIterator,
Callable,
Coroutine,
Dict,
Iterator,
List,
@ -56,6 +56,7 @@ class Vertex:
) -> None:
# is_external means that the Vertex send or receives data from
# an external source (e.g the chat)
self._lock = asyncio.Lock()
self.will_stream = False
self.updated_raw_params = False
self.id: str = data["id"]
@ -171,6 +172,7 @@ class Vertex:
}
def __setstate__(self, state):
self._lock = asyncio.Lock()
self._data = state["_data"]
self.params = state["params"]
self.base_type = state["base_type"]
@ -383,7 +385,7 @@ class Vertex:
Initiate the build process.
"""
logger.debug(f"Building {self.display_name}")
await self._build_each_node_in_params_dict(user_id)
await self._build_each_vertex_in_params_dict(user_id)
await self._get_and_instantiate_class(user_id)
self._validate_built_object()
@ -452,105 +454,123 @@ class Vertex:
result = await generate_result(self._built_object, inputs, self.has_external_output, session_id)
self._built_result = result
async def _build_each_node_in_params_dict(self, user_id=None):
async def _build_each_vertex_in_params_dict(self, user_id=None):
"""
Iterates over each node in the params dictionary and builds it.
Iterates over each vertex in the params dictionary and builds it.
"""
for key, value in self._raw_params.items():
if self._is_node(value):
if self._is_vertex(value):
if value == self:
del self.params[key]
continue
await self._build_node_and_update_params(key, value, user_id)
elif isinstance(value, list) and self._is_list_of_nodes(value):
await self._build_list_of_nodes_and_update_params(key, value, user_id)
await self._build_vertex_and_update_params(
key,
value,
)
elif isinstance(value, list) and self._is_list_of_vertices(value):
await self._build_list_of_vertices_and_update_params(key, value)
elif isinstance(value, dict):
await self._build_dict_and_update_params(key, value, user_id)
await self._build_dict_and_update_params(
key,
value,
)
elif key not in self.params or self.updated_raw_params:
self.params[key] = value
async def _build_dict_and_update_params(self, key, nodes_dict: Dict[str, "Vertex"], user_id=None):
async def _build_dict_and_update_params(
self,
key,
vertices_dict: Dict[str, "Vertex"],
):
"""
Iterates over a dictionary of nodes, builds each and updates the params dictionary.
Iterates over a dictionary of vertices, builds each and updates the params dictionary.
"""
for sub_key, value in nodes_dict.items():
if not self._is_node(value):
for sub_key, value in vertices_dict.items():
if not self._is_vertex(value):
self.params[key][sub_key] = value
else:
built = await value.get_result(requester=self, user_id=user_id)
self.params[key][sub_key] = built
result = await value.get_result()
self.params[key][sub_key] = result
def _is_node(self, value):
def _is_vertex(self, value):
"""
Checks if the provided value is an instance of Vertex.
"""
return isinstance(value, Vertex)
def _is_list_of_nodes(self, value):
def _is_list_of_vertices(self, value):
"""
Checks if the provided value is a list of Vertex instances.
"""
return all(self._is_node(node) for node in value)
return all(self._is_vertex(vertex) for vertex in value)
async def get_result(self, requester: Optional["Vertex"] = None, user_id=None, timeout=None) -> Any:
# PLEASE REVIEW THIS IF STATEMENT
# Check if the Vertex was built already
if self._built:
return self._built_object if not self.use_result else self._built_result
if self.is_task and self.task_id is not None:
task = self.get_task()
result = task.get(timeout=timeout)
if isinstance(result, Coroutine):
result = await result
if result is not None: # If result is ready
self._update_built_object_and_artifacts(result)
return self._built_object
else:
# Handle the case when the result is not ready (retry, throw exception, etc.)
pass
# If there's no task_id, build the vertex locally
await self.build(requester=requester, user_id=user_id)
return self._built_object
async def _build_node_and_update_params(self, key, node: "Vertex", user_id=None):
async def get_result(
self,
) -> Any:
"""
Builds a given node and updates the params dictionary accordingly.
Retrieves the result of the vertex.
This is a read-only method so it raises an error if the vertex has not been built yet.
Returns:
The result of the vertex.
"""
async with self._lock:
return await self._get_result()
async def _get_result(self) -> Any:
"""
Retrieves the result of the built component.
If the component has not been built yet, a ValueError is raised.
Returns:
The built result if use_result is True, else the built object.
"""
if not self._built:
raise ValueError(f"Component {self.display_name} has not been built yet")
return self._built_result if self.use_result else self._built_object
async def _build_vertex_and_update_params(self, key, vertex: "Vertex"):
"""
Builds a given vertex and updates the params dictionary accordingly.
"""
result = await node.get_result(requester=self, user_id=user_id)
result = await vertex.get_result()
self._handle_func(key, result)
if isinstance(result, list):
self._extend_params_list_with_result(key, result)
self.params[key] = result
async def _build_list_of_nodes_and_update_params(self, key, nodes: List["Vertex"], user_id=None):
async def _build_list_of_vertices_and_update_params(
self,
key,
vertices: List["Vertex"],
):
"""
Iterates over a list of nodes, builds each and updates the params dictionary.
Iterates over a list of vertices, builds each and updates the params dictionary.
"""
self.params[key] = []
for node in nodes:
built = await node.get_result(requester=self, user_id=user_id)
for vertex in vertices:
result = await vertex.get_result()
# Weird check to see if the params[key] is a list
# because sometimes it is a Record and breaks the code
if not isinstance(self.params[key], list):
self.params[key] = [self.params[key]]
if isinstance(built, list):
self.params[key].extend(built)
if isinstance(result, list):
self.params[key].extend(result)
else:
try:
if self.params[key] == built:
if self.params[key] == result:
continue
self.params[key].append(built)
self.params[key].append(result)
except AttributeError as e:
logger.exception(e)
raise ValueError(
f"Params {key} ({self.params[key]}) is not a list and cannot be extended with {built}"
f"Error building node {self.display_name}: {str(e)}"
f"Params {key} ({self.params[key]}) is not a list and cannot be extended with {result}"
f"Error building vertex {self.display_name}: {str(e)}"
) from e
def _handle_func(self, key, result):
@ -580,12 +600,9 @@ class Vertex:
Gets the class from a dictionary and instantiates it with the params.
"""
if self.base_type is None:
raise ValueError(f"Base type for node {self.display_name} not found")
raise ValueError(f"Base type for vertex {self.display_name} not found")
try:
result = await loading.instantiate_class(
node_type=self.vertex_type,
base_type=self.base_type,
params=self.params,
user_id=user_id,
vertex=self,
)
@ -593,7 +610,7 @@ class Vertex:
except Exception as exc:
logger.exception(exc)
raise ValueError(f"Error building node {self.display_name}: {str(exc)}") from exc
raise ValueError(f"Error building vertex {self.display_name}: {str(exc)}") from exc
def _update_built_object_and_artifacts(self, result):
"""
@ -647,33 +664,34 @@ class Vertex:
requester: Optional["Vertex"] = None,
**kwargs,
) -> Any:
if self.state == VertexStates.INACTIVE:
# If the vertex is inactive, return None
self.build_inactive()
return
async with self._lock:
if self.state == VertexStates.INACTIVE:
# If the vertex is inactive, return None
self.build_inactive()
return
if self.frozen and self._built:
return self.get_requester_result(requester)
elif self._built and requester is not None:
# This means that the vertex has already been built
# and we are just getting the result for the requester
return await self.get_requester_result(requester)
self._reset()
if self.frozen and self._built:
return self.get_requester_result(requester)
elif self._built and requester is not None:
# This means that the vertex has already been built
# and we are just getting the result for the requester
return await self.get_requester_result(requester)
self._reset()
if self._is_chat_input() and inputs:
inputs = {"input_value": inputs.get(INPUT_FIELD_NAME, "")}
self.update_raw_params(inputs, overwrite=True)
if self._is_chat_input() and inputs:
inputs = {"input_value": inputs.get(INPUT_FIELD_NAME, "")}
self.update_raw_params(inputs, overwrite=True)
# Run steps
for step in self.steps:
if step not in self.steps_ran:
if inspect.iscoroutinefunction(step):
await step(user_id=user_id, **kwargs)
else:
step(user_id=user_id, **kwargs)
self.steps_ran.append(step)
# Run steps
for step in self.steps:
if step not in self.steps_ran:
if inspect.iscoroutinefunction(step):
await step(user_id=user_id, **kwargs)
else:
step(user_id=user_id, **kwargs)
self.steps_ran.append(step)
self._finalize_build()
self._finalize_build()
return await self.get_requester_result(requester)
@ -686,7 +704,11 @@ class Vertex:
# Get the requester edge
requester_edge = next((edge for edge in self.edges if edge.target_id == requester.id), None)
# Return the result of the requester edge
return None if requester_edge is None else await requester_edge.get_result(source=self, target=requester)
return (
None
if requester_edge is None
else await requester_edge.get_result_from_source(source=self, target=requester)
)
def add_edge(self, edge: "ContractEdge") -> None:
if edge not in self.edges:

View file

@ -1,6 +1,6 @@
import inspect
import json
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Type
from typing import TYPE_CHECKING, Any, Callable, Dict, Sequence, Type
import orjson
from langchain.agents import agent as agent_module
@ -40,27 +40,29 @@ if TYPE_CHECKING:
async def instantiate_class(
node_type: str,
base_type: str,
params: Dict,
vertex: "Vertex",
user_id=None,
vertex: Optional["Vertex"] = None,
) -> Any:
"""Instantiate class from module type and key, and params"""
vertex_type = vertex.vertex_type
base_type = vertex.base_type
params = vertex.params
params = convert_params_to_sets(params)
params = convert_kwargs(params)
if node_type in CUSTOM_NODES:
if custom_node := CUSTOM_NODES.get(node_type):
if vertex_type in CUSTOM_NODES:
if custom_node := CUSTOM_NODES.get(vertex_type):
if hasattr(custom_node, "initialize"):
return custom_node.initialize(**params)
return custom_node(**params)
logger.debug(f"Instantiating {node_type} of type {base_type}")
class_object = import_by_type(_type=base_type, name=node_type)
logger.debug(f"Instantiating {vertex_type} of type {base_type}")
if not base_type:
raise ValueError("No base type provided for vertex")
class_object = import_by_type(_type=base_type, name=vertex_type)
return await instantiate_based_on_type(
class_object=class_object,
base_type=base_type,
node_type=node_type,
node_type=vertex_type,
params=params,
user_id=user_id,
vertex=vertex,

View file

@ -16,7 +16,7 @@ class Record(BaseModel):
_default_value: str = ""
@model_validator(mode="before")
def validate_data(values):
def validate_data(cls, values):
if not values.get("data"):
values["data"] = {}
# Any other keyword should be added to the data dictionary

View file

@ -1,9 +1,17 @@
from . import factory, service
from langflow.services.cache.service import InMemoryCache
from langflow.services.cache.service import (
AsyncInMemoryCache,
BaseCacheService,
RedisCache,
ThreadingInMemoryCache,
)
from . import factory, service
__all__ = [
"factory",
"service",
"InMemoryCache",
"ThreadingInMemoryCache",
"AsyncInMemoryCache",
"BaseCacheService",
"RedisCache",
]

View file

@ -1,4 +1,7 @@
import abc
import asyncio
import threading
from typing import Optional
from langflow.services.base import Service
@ -11,7 +14,7 @@ class BaseCacheService(Service):
name = "cache_service"
@abc.abstractmethod
def get(self, key):
def get(self, key, lock: Optional[threading.Lock] = None):
"""
Retrieve an item from the cache.
@ -23,7 +26,7 @@ class BaseCacheService(Service):
"""
@abc.abstractmethod
def set(self, key, value):
def set(self, key, value, lock: Optional[threading.Lock] = None):
"""
Add an item to the cache.
@ -33,7 +36,7 @@ class BaseCacheService(Service):
"""
@abc.abstractmethod
def upsert(self, key, value):
def upsert(self, key, value, lock: Optional[threading.Lock] = None):
"""
Add an item to the cache if it doesn't exist, or update it if it does.
@ -43,7 +46,7 @@ class BaseCacheService(Service):
"""
@abc.abstractmethod
def delete(self, key):
def delete(self, key, lock: Optional[threading.Lock] = None):
"""
Remove an item from the cache.
@ -52,7 +55,7 @@ class BaseCacheService(Service):
"""
@abc.abstractmethod
def clear(self):
def clear(self, lock: Optional[threading.Lock] = None):
"""
Clear all items from the cache.
"""
@ -96,3 +99,70 @@ class BaseCacheService(Service):
Args:
key: The key of the item to remove.
"""
class AsyncBaseCacheService(Service):
    """
    Abstract base class for an async cache.

    Declares the coroutine interface (get/set/upsert/delete/clear) plus a
    synchronous ``__contains__``. Every coroutine accepts an optional
    ``asyncio.Lock``.
    NOTE(review): what a supplied lock means is not defined here; presumably
    the caller manages synchronization with it. Confirm against the
    concrete implementations.
    """

    # Service name; identical to the one declared by the synchronous BaseCacheService.
    name = "cache_service"

    @abc.abstractmethod
    async def get(self, key, lock: Optional[asyncio.Lock] = None):
        """
        Retrieve an item from the cache.

        Args:
            key: The key of the item to retrieve.
            lock: Optional asyncio lock associated with this operation.

        Returns:
            The value associated with the key, or None if the key is not found.
        """

    @abc.abstractmethod
    async def set(self, key, value, lock: Optional[asyncio.Lock] = None):
        """
        Add an item to the cache.

        Args:
            key: The key of the item.
            value: The value to cache.
            lock: Optional asyncio lock associated with this operation.
        """

    @abc.abstractmethod
    async def upsert(self, key, value, lock: Optional[asyncio.Lock] = None):
        """
        Add an item to the cache if it doesn't exist, or update it if it does.

        Args:
            key: The key of the item.
            value: The value to cache.
            lock: Optional asyncio lock associated with this operation.
        """

    @abc.abstractmethod
    async def delete(self, key, lock: Optional[asyncio.Lock] = None):
        """
        Remove an item from the cache.

        Args:
            key: The key of the item to remove.
            lock: Optional asyncio lock associated with this operation.
        """

    @abc.abstractmethod
    async def clear(self, lock: Optional[asyncio.Lock] = None):
        """
        Clear all items from the cache.

        Args:
            lock: Optional asyncio lock associated with this operation.
        """

    @abc.abstractmethod
    def __contains__(self, key):
        """
        Check if the key is in the cache.

        Args:
            key: The key of the item to check.

        Returns:
            True if the key is in the cache, False otherwise.
        """

View file

@ -1,6 +1,11 @@
from typing import TYPE_CHECKING
from langflow.services.cache.service import BaseCacheService, InMemoryCache, RedisCache
from langflow.services.cache.service import (
AsyncInMemoryCache,
BaseCacheService,
RedisCache,
ThreadingInMemoryCache,
)
from langflow.services.factory import ServiceFactory
from langflow.utils.logger import logger
@ -29,7 +34,9 @@ class CacheServiceFactory(ServiceFactory):
logger.debug("Redis cache is connected")
return redis_cache
logger.warning("Redis cache is not connected, falling back to in-memory cache")
return InMemoryCache()
return ThreadingInMemoryCache()
elif settings_service.settings.CACHE_TYPE == "memory":
return InMemoryCache()
return ThreadingInMemoryCache()
elif settings_service.settings.CACHE_TYPE == "async":
return AsyncInMemoryCache()

View file

@ -1,16 +1,17 @@
import asyncio
import pickle
import threading
import time
from collections import OrderedDict
from typing import Optional
from loguru import logger
from langflow.services.base import Service
from langflow.services.cache.base import BaseCacheService
from langflow.services.cache.base import AsyncBaseCacheService, BaseCacheService
class InMemoryCache(BaseCacheService, Service):
class ThreadingInMemoryCache(BaseCacheService, Service):
"""
A simple in-memory cache using an OrderedDict.
@ -49,7 +50,7 @@ class InMemoryCache(BaseCacheService, Service):
self.max_size = max_size
self.expiration_time = expiration_time
def get(self, key):
def get(self, key, lock: Optional[threading.Lock] = None):
"""
Retrieve an item from the cache.
@ -59,7 +60,7 @@ class InMemoryCache(BaseCacheService, Service):
Returns:
The value associated with the key, or None if the key is not found or the item has expired.
"""
with self._lock:
with lock or self._lock:
return self._get_without_lock(key)
def _get_without_lock(self, key):
@ -80,7 +81,7 @@ class InMemoryCache(BaseCacheService, Service):
self.delete(key)
return None
def set(self, key, value, pickle=False):
def set(self, key, value, lock: Optional[threading.Lock] = None):
"""
Add an item to the cache.
@ -90,7 +91,7 @@ class InMemoryCache(BaseCacheService, Service):
key: The key of the item.
value: The value to cache.
"""
with self._lock:
with lock or self._lock:
if key in self._cache:
# Remove existing key before re-inserting to update order
self.delete(key)
@ -98,12 +99,10 @@ class InMemoryCache(BaseCacheService, Service):
# Remove least recently used item
self._cache.popitem(last=False)
# pickle locally to mimic Redis
if pickle:
value = pickle.dumps(value)
self._cache[key] = {"value": value, "time": time.time()}
def upsert(self, key, value):
def upsert(self, key, value, lock: Optional[threading.Lock] = None):
"""
Inserts or updates a value in the cache.
If the existing value and the new value are both dictionaries, they are merged.
@ -112,7 +111,7 @@ class InMemoryCache(BaseCacheService, Service):
key: The key of the item.
value: The value to insert or update.
"""
with self._lock:
with lock or self._lock:
existing_value = self._get_without_lock(key)
if existing_value is not None and isinstance(existing_value, dict) and isinstance(value, dict):
existing_value.update(value)
@ -120,7 +119,7 @@ class InMemoryCache(BaseCacheService, Service):
self.set(key, value)
def get_or_set(self, key, value):
def get_or_set(self, key, value, lock: Optional[threading.Lock] = None):
"""
Retrieve an item from the cache. If the item does not exist,
set it with the provided value.
@ -132,27 +131,27 @@ class InMemoryCache(BaseCacheService, Service):
Returns:
The cached value associated with the key.
"""
with self._lock:
with lock or self._lock:
if key in self._cache:
return self.get(key)
self.set(key, value)
return value
def delete(self, key):
def delete(self, key, lock: Optional[threading.Lock] = None):
"""
Remove an item from the cache.
Args:
key: The key of the item to remove.
"""
with self._lock:
with lock or self._lock:
self._cache.pop(key, None)
def clear(self):
def clear(self, lock: Optional[threading.Lock] = None):
"""
Clear all items from the cache.
"""
with self._lock:
with lock or self._lock:
self._cache.clear()
def __contains__(self, key):
@ -323,3 +322,85 @@ class RedisCache(BaseCacheService, Service):
def __repr__(self):
"""Return a string representation of the RedisCache instance."""
return f"RedisCache(expiration_time={self.expiration_time})"
class AsyncInMemoryCache(AsyncBaseCacheService, Service):
    """
    In-memory cache for asyncio code, backed by an ``OrderedDict``.

    Each entry is stored as ``{"value": ..., "time": <timestamp at set>}`` and
    is treated as expired once it is ``expiration_time`` seconds old or more.
    When ``max_size`` is set, the oldest entry is evicted to make room for a
    new key.

    Locking convention: every public coroutine accepts an optional ``lock``.
    When it is omitted, the operation runs under the cache's own
    ``asyncio.Lock``. When a lock is supplied, it is NOT acquired here and the
    internal lock is skipped (presumably the caller already holds that lock).

    The private ``_get`` / ``_set`` / ``_delete`` / ``_clear`` / ``_upsert``
    helpers never take a lock and only call each other. ``asyncio.Lock`` is
    not reentrant, so calling a public (locking) method from inside a helper
    would deadlock whenever the internal lock is already held.
    """

    def __init__(self, max_size=None, expiration_time=3600):
        """
        Args:
            max_size: Maximum number of entries, or None/0 for no limit.
            expiration_time: Lifetime of an entry, in seconds.
        """
        self.cache = OrderedDict()
        self.lock = asyncio.Lock()
        self.max_size = max_size
        self.expiration_time = expiration_time

    async def get(self, key, lock: Optional[asyncio.Lock] = None):
        """
        Return the cached value for ``key``, or None if missing or expired.

        Args:
            key: The key of the item to retrieve.
            lock: When supplied, the internal lock is not taken.
        """
        if not lock:
            async with self.lock:
                return await self._get(key)
        else:
            return await self._get(key)

    async def _get(self, key):
        """Lock-free lookup; refreshes recency on a hit and drops expired entries."""
        item = self.cache.get(key, None)
        if item and (time.time() - item["time"] < self.expiration_time):
            self.cache.move_to_end(key)
            # NOTE: values stored as bytes are unpickled on read. pickle.loads must
            # never be fed untrusted data; only cache bytes written by trusted code.
            return pickle.loads(item["value"]) if isinstance(item["value"], bytes) else item["value"]
        if item:
            # Expired entry: remove it with the lock-free helper. The public
            # delete() would try to re-acquire self.lock, which get() may
            # already hold, and hang forever.
            await self._delete(key)
        return None

    async def set(self, key, value, lock: Optional[asyncio.Lock] = None):
        """
        Store ``value`` under ``key`` and stamp it with the current time.

        Args:
            key: The key of the item.
            value: The value to cache.
            lock: When supplied, the internal lock is not taken.
        """
        if not lock:
            async with self.lock:
                await self._set(
                    key,
                    value,
                )
        else:
            await self._set(
                key,
                value,
            )

    async def _set(self, key, value):
        """Lock-free insert/update with eviction of the oldest entry when full."""
        # Only evict when a NEW key would push the cache past max_size;
        # updating an existing key does not grow the cache, so evicting then
        # would needlessly drop an unrelated entry.
        if self.max_size and key not in self.cache and len(self.cache) >= self.max_size:
            self.cache.popitem(last=False)
        self.cache[key] = {"value": value, "time": time.time()}
        self.cache.move_to_end(key)

    async def delete(self, key, lock: Optional[asyncio.Lock] = None):
        """
        Remove ``key`` from the cache; a missing key is ignored.

        Args:
            key: The key of the item to remove.
            lock: When supplied, the internal lock is not taken.
        """
        if not lock:
            async with self.lock:
                await self._delete(key)
        else:
            await self._delete(key)

    async def _delete(self, key):
        """Lock-free removal of ``key`` if present."""
        if key in self.cache:
            del self.cache[key]

    async def clear(self, lock: Optional[asyncio.Lock] = None):
        """
        Remove every entry from the cache.

        Args:
            lock: When supplied, the internal lock is not taken.
        """
        if not lock:
            async with self.lock:
                await self._clear()
        else:
            await self._clear()

    async def _clear(self):
        """Lock-free removal of all entries."""
        self.cache.clear()

    async def upsert(self, key, value, lock: Optional[asyncio.Lock] = None):
        """
        Insert ``value`` under ``key``, merging into the existing value when
        both the existing and the new value are dictionaries.

        Args:
            key: The key of the item.
            value: The value to insert or merge.
            lock: When supplied, the internal lock is not taken.
        """
        if not lock:
            async with self.lock:
                await self._upsert(key, value)
        else:
            await self._upsert(key, value)

    async def _upsert(self, key, value):
        """Lock-free read-merge-write used by ``upsert``."""
        # Use the lock-free helpers: upsert() may already hold self.lock, and
        # the public get()/set() would block forever trying to take it again.
        existing_value = await self._get(key)
        if existing_value is not None and isinstance(existing_value, dict) and isinstance(value, dict):
            existing_value.update(value)
            value = existing_value
        await self._set(key, value)

    def __contains__(self, key):
        """Return True if ``key`` is stored (expiration is not checked here)."""
        return key in self.cache

View file

@ -1,4 +1,6 @@
from typing import Any
import asyncio
from collections import defaultdict
from typing import Any, Optional
from langflow.services.base import Service
from langflow.services.deps import get_cache_service
@ -8,30 +10,30 @@ class ChatService(Service):
name = "chat_service"
def __init__(self):
    """
    Set up the service's cache backend and its per-key lock registry.
    """
    # Backend obtained from the service dependency helper.
    self.cache_service = get_cache_service()
    # One asyncio.Lock per key, created lazily on first access.
    self._cache_locks = defaultdict(asyncio.Lock)
async def set_cache(self, flow_id: str, data: Any, lock: Optional[asyncio.Lock] = None) -> bool:
    """
    Store ``data`` in the cache under ``flow_id``.

    The data is wrapped in a dict carrying the value ("result") and its
    Python type ("type"), then upserted into the cache service.

    Args:
        flow_id: The cache key.
        data: The object to cache.
        lock: Lock forwarded to the cache service; defaults to the
            per-flow lock from ``self._cache_locks``.

    Returns:
        True if ``flow_id`` is present in the cache service afterwards.
    """
    # NOTE(review): an earlier comment here said the flow id "already exists
    # in the cache" and that the key should be changed to something else -
    # confirm whether that is still a concern.
    result_dict = {
        "result": data,
        "type": type(data),
    }
    await self.cache_service.upsert(flow_id, result_dict, lock=lock or self._cache_locks[flow_id])
    return flow_id in self.cache_service
async def get_cache(self, flow_id: str, lock: Optional[asyncio.Lock] = None) -> Any:
    """
    Fetch the cached entry for ``flow_id``.

    Args:
        flow_id: The cache key.
        lock: Lock forwarded to the cache service; defaults to the
            per-flow lock from ``self._cache_locks``.

    Returns:
        Whatever the cache service returns for ``flow_id``.
    """
    return await self.cache_service.get(flow_id, lock=lock or self._cache_locks[flow_id])
async def clear_cache(self, flow_id: str, lock: Optional[asyncio.Lock] = None):
    """
    Remove the cached entry for ``flow_id``.

    Args:
        flow_id: The cache key to remove.
        lock: Lock forwarded to the cache service; defaults to the
            per-flow lock from ``self._cache_locks``.
    """
    # The cache service's delete is a coroutine function: without ``await``
    # the coroutine object is created and discarded, and nothing is deleted.
    await self.cache_service.delete(flow_id, lock=lock or self._cache_locks[flow_id])

View file

@ -38,7 +38,7 @@ class Settings(BaseSettings):
DEV: bool = False
DATABASE_URL: Optional[str] = None
CACHE_TYPE: str = "memory"
CACHE_TYPE: str = "async"
REMOVE_API_KEYS: bool = False
COMPONENTS_PATH: List[str] = []
LANGCHAIN_CACHE: str = "InMemoryCache"

View file

@ -9,7 +9,7 @@ const SvgBotMessageSquare = (props) => (
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
class="lucide lucide-bot-message-square"
className="lucide lucide-bot-message-square"
{...props}
>
<path d="M12 6V2H8" />

View file

@ -1,8 +1,9 @@
import json
from langflow.graph import Graph
import pytest
from langflow.graph import Graph
def get_graph(_type="basic"):
"""Get a graph from a json file"""
@ -38,7 +39,8 @@ def langchain_objects_are_equal(obj1, obj2):
# Test build_graph
def test_build_graph(client, basic_data_graph):
@pytest.mark.asyncio
async def test_build_graph(client, basic_data_graph):
graph = Graph.from_payload(basic_data_graph)
assert graph is not None
assert len(graph.vertices) == len(basic_data_graph["nodes"])

View file

@ -28,9 +28,7 @@ async def test_successful_get_request(api_request):
respx.get(url).mock(return_value=Response(200, json=mock_response))
# Making the request
result = await api_request.make_request(
client=httpx.AsyncClient(), method=method, url=url
)
result = await api_request.make_request(client=httpx.AsyncClient(), method=method, url=url)
# Assertions
assert result.data["status_code"] == 200
@ -46,9 +44,7 @@ async def test_failed_request(api_request):
respx.get(url).mock(return_value=Response(404))
# Making the request
result = await api_request.make_request(
client=httpx.AsyncClient(), method=method, url=url
)
result = await api_request.make_request(client=httpx.AsyncClient(), method=method, url=url)
# Assertions
assert result.data["status_code"] == 404
@ -60,14 +56,10 @@ async def test_timeout(api_request):
# Mocking a timeout
url = "https://example.com/api/timeout"
method = "GET"
respx.get(url).mock(
side_effect=httpx.TimeoutException(message="Timeout", request=None)
)
respx.get(url).mock(side_effect=httpx.TimeoutException(message="Timeout", request=None))
# Making the request
result = await api_request.make_request(
client=httpx.AsyncClient(), method=method, url=url, timeout=1
)
result = await api_request.make_request(client=httpx.AsyncClient(), method=method, url=url, timeout=1)
# Assertions
assert result.data["status_code"] == 408
@ -106,7 +98,6 @@ def test_directory_component_build_with_multithreading(
# Arrange
directory_component = data.DirectoryComponent()
path = os.path.dirname(os.path.abspath(__file__))
types = ["py"]
depth = 1
max_concurrency = 2
load_hidden = False
@ -123,7 +114,6 @@ def test_directory_component_build_with_multithreading(
# Act
directory_component.build(
path,
types,
depth,
max_concurrency,
load_hidden,
@ -134,9 +124,7 @@ def test_directory_component_build_with_multithreading(
# Assert
mock_resolve_path.assert_called_once_with(path)
mock_retrieve_file_paths.assert_called_once_with(
path, types, load_hidden, recursive, depth
)
mock_retrieve_file_paths.assert_called_once_with(path, load_hidden, recursive, depth)
mock_parallel_load_records.assert_called_once_with(
mock_retrieve_file_paths.return_value, silent_errors, max_concurrency
)