Refactor cache service and fix async issues (#1512)
This commit is contained in:
parent
23fe37326b
commit
67bccdc753
24 changed files with 466 additions and 171 deletions
|
|
@ -197,7 +197,7 @@ def format_elapsed_time(elapsed_time: float) -> str:
|
|||
return f"{minutes} {minutes_unit}, {seconds} {seconds_unit}"
|
||||
|
||||
|
||||
def build_and_cache_graph(
|
||||
async def build_and_cache_graph(
|
||||
flow_id: str,
|
||||
session: Session,
|
||||
chat_service: "ChatService",
|
||||
|
|
@ -212,7 +212,7 @@ def build_and_cache_graph(
|
|||
graph = other_graph
|
||||
else:
|
||||
graph = graph.update(other_graph)
|
||||
chat_service.set_cache(flow_id, graph)
|
||||
await chat_service.set_cache(flow_id, graph)
|
||||
return graph
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -58,9 +58,9 @@ async def get_vertices(
|
|||
try:
|
||||
# First, we need to check if the flow_id is in the cache
|
||||
graph = None
|
||||
if cache := chat_service.get_cache(flow_id):
|
||||
if cache := await chat_service.get_cache(flow_id):
|
||||
graph = cache.get("result")
|
||||
graph = build_and_cache_graph(flow_id, session, chat_service, graph)
|
||||
graph = await build_and_cache_graph(flow_id, session, chat_service, graph)
|
||||
if stop_component_id or start_component_id:
|
||||
try:
|
||||
vertices = graph.sort_vertices(stop_component_id, start_component_id)
|
||||
|
|
@ -98,11 +98,11 @@ async def build_vertex(
|
|||
next_vertices_ids = []
|
||||
try:
|
||||
start_time = time.perf_counter()
|
||||
cache = chat_service.get_cache(flow_id)
|
||||
cache = await chat_service.get_cache(flow_id)
|
||||
if not cache:
|
||||
# If there's no cache
|
||||
logger.warning(f"No cache found for {flow_id}. Building graph starting at {vertex_id}")
|
||||
graph = build_and_cache_graph(flow_id=flow_id, session=next(get_session()), chat_service=chat_service)
|
||||
graph = await build_and_cache_graph(flow_id=flow_id, session=next(get_session()), chat_service=chat_service)
|
||||
else:
|
||||
graph = cache.get("result")
|
||||
result_data_response = ResultDataResponse(results={})
|
||||
|
|
@ -121,8 +121,11 @@ async def build_vertex(
|
|||
artifacts = vertex.artifacts
|
||||
else:
|
||||
raise ValueError(f"No result found for vertex {vertex_id}")
|
||||
next_vertices_ids = vertex.successors_ids
|
||||
next_vertices_ids = [v for v in next_vertices_ids if graph.should_run_vertex(v)]
|
||||
async with chat_service._cache_locks[flow_id] as lock:
|
||||
graph.remove_from_predecessors(vertex_id)
|
||||
next_vertices_ids = vertex.successors_ids
|
||||
next_vertices_ids = [v for v in next_vertices_ids if graph.should_run_vertex(v)]
|
||||
await chat_service.set_cache(flow_id=flow_id, data=graph, lock=lock)
|
||||
|
||||
result_data_response = ResultDataResponse(**result_dict.model_dump())
|
||||
|
||||
|
|
@ -134,7 +137,7 @@ async def build_vertex(
|
|||
artifacts = {}
|
||||
# If there's an error building the vertex
|
||||
# we need to clear the cache
|
||||
chat_service.clear_cache(flow_id)
|
||||
await chat_service.clear_cache(flow_id)
|
||||
|
||||
# Log the vertex build
|
||||
if not vertex.will_stream:
|
||||
|
|
@ -157,7 +160,7 @@ async def build_vertex(
|
|||
inactivated_vertices = list(graph.inactivated_vertices)
|
||||
graph.reset_inactivated_vertices()
|
||||
graph.reset_activated_vertices()
|
||||
chat_service.set_cache(flow_id, graph)
|
||||
await chat_service.set_cache(flow_id, graph)
|
||||
|
||||
# graph.stop_vertex tells us if the user asked
|
||||
# to stop the build of the graph at a certain vertex
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlmodel import Session
|
||||
|
||||
from langflow.api.v1.schemas import Token
|
||||
from langflow.services.auth.utils import (
|
||||
authenticate_user,
|
||||
|
|
@ -8,7 +10,6 @@ from langflow.services.auth.utils import (
|
|||
create_user_tokens,
|
||||
)
|
||||
from langflow.services.deps import get_session, get_settings_service
|
||||
from sqlmodel import Session
|
||||
|
||||
router = APIRouter(tags=["Login"])
|
||||
|
||||
|
|
|
|||
|
|
@ -104,7 +104,8 @@ def parse_text_file_to_record(file_path: str, silent_errors: bool) -> Optional[R
|
|||
elif file_path.endswith(".yaml") or file_path.endswith(".yml"):
|
||||
text = yaml.safe_load(text)
|
||||
elif file_path.endswith(".xml"):
|
||||
text = ET.fromstring(text)
|
||||
xml_element = ET.fromstring(text)
|
||||
text = ET.tostring(xml_element, encoding="unicode")
|
||||
except Exception as e:
|
||||
if not silent_errors:
|
||||
raise ValueError(f"Error loading file {file_path}: {e}") from e
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ class ChatComponent(CustomComponent):
|
|||
"session_id": {
|
||||
"display_name": "Session ID",
|
||||
"info": "If provided, the message will be stored in the memory.",
|
||||
"advanced": True,
|
||||
},
|
||||
"return_record": {
|
||||
"display_name": "Return Record",
|
||||
|
|
|
|||
70
src/backend/langflow/components/agents/ReActAgent.py
Normal file
70
src/backend/langflow/components/agents/ReActAgent.py
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
# from typing import Dict, List
|
||||
|
||||
# import dspy
|
||||
|
||||
# from langflow import CustomComponent
|
||||
# from langflow.field_typing import Text
|
||||
|
||||
|
||||
# class ReActAgentComponent(CustomComponent):
|
||||
# display_name = "ReAct Agent"
|
||||
# description = "A component to create a ReAct Agent."
|
||||
# icon = "user-secret"
|
||||
|
||||
# def build_config(self):
|
||||
# return {
|
||||
# "input_value": {
|
||||
# "display_name": "Input",
|
||||
# "input_types": ["Text"],
|
||||
# "info": "The input value for the ReAct Agent.",
|
||||
# },
|
||||
# "instructions": {
|
||||
# "display_name": "Instructions",
|
||||
# "info": "The Prompt.",
|
||||
# },
|
||||
# "inputs": {
|
||||
# "display_name": "Inputs",
|
||||
# "info": "The Name and Description of the Input Fields.",
|
||||
# },
|
||||
# "outputs": {
|
||||
# "display_name": "Outputs",
|
||||
# "info": "The Name and Description of the Output Fields.",
|
||||
# },
|
||||
# }
|
||||
|
||||
# def build(
|
||||
# self,
|
||||
# input_value: List[dict],
|
||||
# instructions: Text,
|
||||
# inputs: List[dict],
|
||||
# outputs: List[Dict],
|
||||
# ) -> Text:
|
||||
# # inputs is a list of dictionaries where the key is the name of the input
|
||||
# # and the value is the description of the input
|
||||
# input_fields = (
|
||||
# {}
|
||||
# ) # dict[str, FieldInfo] InputField and OutputField are subclasses of pydantic.Field
|
||||
# for input_dict in inputs:
|
||||
# for name, description in input_dict.items():
|
||||
# prefix = name if ":" in name else f"{name}:"
|
||||
# input_fields[name] = dspy.InputField(
|
||||
# prefix=prefix, description=description
|
||||
# )
|
||||
|
||||
# output_fields = {} # dict[str, FieldInfo]
|
||||
# for output_dict in outputs:
|
||||
# for name, description in output_dict.items():
|
||||
# prefix = name if ":" in name else f"{name}:"
|
||||
# output_fields[name] = dspy.OutputField(
|
||||
# prefix=prefix, description=description
|
||||
# )
|
||||
|
||||
# signature = dspy.make_signature(inputs, instructions=instructions)
|
||||
# agent = dspy.ReAct(
|
||||
# signature=signature,
|
||||
# )
|
||||
# inputs_dict = {}
|
||||
# for input_dict in input_value:
|
||||
# inputs_dict.update(input_dict)
|
||||
|
||||
# result = agent(inputs_dict)
|
||||
|
|
@ -106,7 +106,7 @@ class APIRequest(CustomComponent):
|
|||
bodies = [body.data]
|
||||
if len(urls) != len(bodies):
|
||||
# add bodies with None
|
||||
bodies += [None] * (len(urls) - len(bodies))
|
||||
bodies += [None] * (len(urls) - len(bodies)) # type: ignore
|
||||
async with httpx.AsyncClient() as client:
|
||||
results = await asyncio.gather(
|
||||
*[self.make_request(client, method, u, headers, rec, timeout) for u, rec in zip(urls, bodies)]
|
||||
|
|
|
|||
|
|
@ -3,12 +3,10 @@ from .ExtractDataFromRecord import ExtractKeyFromRecordComponent
|
|||
from .GetNotified import GetNotifiedComponent
|
||||
from .ListFlows import ListFlowsComponent
|
||||
from .MergeRecords import MergeRecordsComponent
|
||||
from .MessageHistory import MessageHistoryComponent
|
||||
from .Notify import NotifyComponent
|
||||
from .RunFlow import RunFlowComponent
|
||||
from .RunnableExecutor import RunnableExecComponent
|
||||
from .SQLExecutor import SQLExecutorComponent
|
||||
from .TextToRecord import TextToRecordComponent
|
||||
|
||||
__all__ = [
|
||||
"ClearMessageHistoryComponent",
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
from .CustomComponent import Component
|
||||
from .DocumentToRecord import DocumentToRecordComponent
|
||||
from .IDGenerator import UUIDGeneratorComponent
|
||||
from .MessageHistory import MessageHistoryComponent
|
||||
from .PythonFunction import PythonFunctionComponent
|
||||
from .RecordsAsText import RecordsAsTextComponent
|
||||
from .TextToRecord import TextToRecordComponent
|
||||
|
||||
__all__ = [
|
||||
"Component",
|
||||
|
|
@ -10,4 +12,6 @@ __all__ = [
|
|||
"UUIDGeneratorComponent",
|
||||
"PythonFunctionComponent",
|
||||
"RecordsAsTextComponent",
|
||||
"TextToRecordComponent",
|
||||
"MessageHistoryComponent",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from langchain_community.chat_models.cohere import ChatCohere
|
||||
from pydantic.v1 import SecretStr
|
||||
|
||||
from langflow.components.models.base.model import LCModelComponent
|
||||
from langflow.field_typing import Text
|
||||
|
|
@ -44,8 +45,9 @@ class CohereComponent(LCModelComponent):
|
|||
temperature: float = 0.75,
|
||||
stream: bool = False,
|
||||
) -> Text:
|
||||
api_key = SecretStr(cohere_api_key)
|
||||
output = ChatCohere( # type: ignore
|
||||
cohere_api_key=cohere_api_key,
|
||||
cohere_api_key=api_key,
|
||||
temperature=temperature,
|
||||
)
|
||||
return self.get_result(output=output, stream=stream, input_value=input_value)
|
||||
|
|
|
|||
|
|
@ -122,7 +122,9 @@ class ContractEdge(Edge):
|
|||
return
|
||||
|
||||
if not source._built:
|
||||
await source.build()
|
||||
# The system should be read-only, so we should not be building vertices
|
||||
# that are not already built.
|
||||
raise ValueError(f"Source vertex {source.id} is not built.")
|
||||
|
||||
if self.matched_type == "Text":
|
||||
self.result = source._built_result
|
||||
|
|
@ -132,7 +134,7 @@ class ContractEdge(Edge):
|
|||
target.params[self.target_param] = self.result
|
||||
self.is_fulfilled = True
|
||||
|
||||
async def get_result(self, source: "Vertex", target: "Vertex"):
|
||||
async def get_result_from_source(self, source: "Vertex", target: "Vertex"):
|
||||
# Fulfill the contract if it has not been fulfilled.
|
||||
if not self.is_fulfilled:
|
||||
await self.honor(source, target)
|
||||
|
|
|
|||
|
|
@ -240,6 +240,7 @@ class Graph:
|
|||
|
||||
def build_graph_maps(self):
|
||||
self.predecessor_map, self.successor_map = self.build_adjacency_maps()
|
||||
|
||||
self.in_degree_map = self.build_in_degree()
|
||||
self.parent_child_map = self.build_parent_child_map()
|
||||
|
||||
|
|
@ -295,6 +296,15 @@ class Graph:
|
|||
successor_map[edge.source_id].append(edge.target_id)
|
||||
return predecessor_map, successor_map
|
||||
|
||||
def build_run_map(self):
|
||||
run_map = defaultdict(list)
|
||||
# The run map gets the predecessor_map and maps the info like this:
|
||||
# {vertex_id: every id that contains the vertex_id in the predecessor_map}
|
||||
for vertex_id, predecessors in self.predecessor_map.items():
|
||||
for predecessor in predecessors:
|
||||
run_map[predecessor].append(vertex_id)
|
||||
return run_map
|
||||
|
||||
@classmethod
|
||||
def from_payload(cls, payload: Dict, flow_id: Optional[str] = None) -> "Graph":
|
||||
"""
|
||||
|
|
@ -939,16 +949,37 @@ class Graph:
|
|||
# save the only the rest
|
||||
self.vertices_layers = vertices_layers[1:]
|
||||
self.vertices_to_run = {vertex_id for vertex_id in chain.from_iterable(vertices_layers)}
|
||||
self.run_map, self.run_predecessors = (
|
||||
self.build_run_map(),
|
||||
self.predecessor_map.copy(),
|
||||
)
|
||||
|
||||
# Return just the first layer
|
||||
return first_layer
|
||||
|
||||
def vertex_has_no_more_predecessors(self, vertex_id: str) -> bool:
|
||||
"""Returns whether a vertex has no more predecessors."""
|
||||
return not self.run_predecessors.get(vertex_id)
|
||||
|
||||
def should_run_vertex(self, vertex_id: str) -> bool:
|
||||
"""Returns whether a component should be run."""
|
||||
should_run = vertex_id in self.vertices_to_run
|
||||
# the self.run_map is a map of vertex_id to a list of predecessors
|
||||
# each time a vertex is run, we remove it from the list of predecessors
|
||||
# if a vertex has no more predecessors, it should be run
|
||||
should_run = vertex_id in self.vertices_to_run and self.vertex_has_no_more_predecessors(vertex_id)
|
||||
|
||||
if should_run:
|
||||
self.vertices_to_run.remove(vertex_id)
|
||||
# remove the vertex from the run_map
|
||||
self.remove_from_predecessors(vertex_id)
|
||||
return should_run
|
||||
|
||||
def remove_from_predecessors(self, vertex_id: str):
|
||||
predecessors = self.run_map.get(vertex_id, [])
|
||||
for predecessor in predecessors:
|
||||
if vertex_id in self.run_predecessors[predecessor]:
|
||||
self.run_predecessors[predecessor].remove(vertex_id)
|
||||
|
||||
def sort_interface_components_first(self, vertices_layers: List[List[str]]) -> List[List[str]]:
|
||||
"""Sorts the vertices in the graph so that vertices containing ChatInput or ChatOutput come first."""
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import ast
|
||||
import asyncio
|
||||
import inspect
|
||||
import types
|
||||
from enum import Enum
|
||||
|
|
@ -7,7 +8,6 @@ from typing import (
|
|||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
|
|
@ -56,6 +56,7 @@ class Vertex:
|
|||
) -> None:
|
||||
# is_external means that the Vertex send or receives data from
|
||||
# an external source (e.g the chat)
|
||||
self._lock = asyncio.Lock()
|
||||
self.will_stream = False
|
||||
self.updated_raw_params = False
|
||||
self.id: str = data["id"]
|
||||
|
|
@ -171,6 +172,7 @@ class Vertex:
|
|||
}
|
||||
|
||||
def __setstate__(self, state):
|
||||
self._lock = asyncio.Lock()
|
||||
self._data = state["_data"]
|
||||
self.params = state["params"]
|
||||
self.base_type = state["base_type"]
|
||||
|
|
@ -383,7 +385,7 @@ class Vertex:
|
|||
Initiate the build process.
|
||||
"""
|
||||
logger.debug(f"Building {self.display_name}")
|
||||
await self._build_each_node_in_params_dict(user_id)
|
||||
await self._build_each_vertex_in_params_dict(user_id)
|
||||
await self._get_and_instantiate_class(user_id)
|
||||
self._validate_built_object()
|
||||
|
||||
|
|
@ -452,105 +454,123 @@ class Vertex:
|
|||
result = await generate_result(self._built_object, inputs, self.has_external_output, session_id)
|
||||
self._built_result = result
|
||||
|
||||
async def _build_each_node_in_params_dict(self, user_id=None):
|
||||
async def _build_each_vertex_in_params_dict(self, user_id=None):
|
||||
"""
|
||||
Iterates over each node in the params dictionary and builds it.
|
||||
Iterates over each vertex in the params dictionary and builds it.
|
||||
"""
|
||||
for key, value in self._raw_params.items():
|
||||
if self._is_node(value):
|
||||
if self._is_vertex(value):
|
||||
if value == self:
|
||||
del self.params[key]
|
||||
continue
|
||||
await self._build_node_and_update_params(key, value, user_id)
|
||||
elif isinstance(value, list) and self._is_list_of_nodes(value):
|
||||
await self._build_list_of_nodes_and_update_params(key, value, user_id)
|
||||
await self._build_vertex_and_update_params(
|
||||
key,
|
||||
value,
|
||||
)
|
||||
elif isinstance(value, list) and self._is_list_of_vertices(value):
|
||||
await self._build_list_of_vertices_and_update_params(key, value)
|
||||
elif isinstance(value, dict):
|
||||
await self._build_dict_and_update_params(key, value, user_id)
|
||||
await self._build_dict_and_update_params(
|
||||
key,
|
||||
value,
|
||||
)
|
||||
elif key not in self.params or self.updated_raw_params:
|
||||
self.params[key] = value
|
||||
|
||||
async def _build_dict_and_update_params(self, key, nodes_dict: Dict[str, "Vertex"], user_id=None):
|
||||
async def _build_dict_and_update_params(
|
||||
self,
|
||||
key,
|
||||
vertices_dict: Dict[str, "Vertex"],
|
||||
):
|
||||
"""
|
||||
Iterates over a dictionary of nodes, builds each and updates the params dictionary.
|
||||
Iterates over a dictionary of vertices, builds each and updates the params dictionary.
|
||||
"""
|
||||
for sub_key, value in nodes_dict.items():
|
||||
if not self._is_node(value):
|
||||
for sub_key, value in vertices_dict.items():
|
||||
if not self._is_vertex(value):
|
||||
self.params[key][sub_key] = value
|
||||
else:
|
||||
built = await value.get_result(requester=self, user_id=user_id)
|
||||
self.params[key][sub_key] = built
|
||||
result = await value.get_result()
|
||||
self.params[key][sub_key] = result
|
||||
|
||||
def _is_node(self, value):
|
||||
def _is_vertex(self, value):
|
||||
"""
|
||||
Checks if the provided value is an instance of Vertex.
|
||||
"""
|
||||
return isinstance(value, Vertex)
|
||||
|
||||
def _is_list_of_nodes(self, value):
|
||||
def _is_list_of_vertices(self, value):
|
||||
"""
|
||||
Checks if the provided value is a list of Vertex instances.
|
||||
"""
|
||||
return all(self._is_node(node) for node in value)
|
||||
return all(self._is_vertex(vertex) for vertex in value)
|
||||
|
||||
async def get_result(self, requester: Optional["Vertex"] = None, user_id=None, timeout=None) -> Any:
|
||||
# PLEASE REVIEW THIS IF STATEMENT
|
||||
# Check if the Vertex was built already
|
||||
if self._built:
|
||||
return self._built_object if not self.use_result else self._built_result
|
||||
|
||||
if self.is_task and self.task_id is not None:
|
||||
task = self.get_task()
|
||||
|
||||
result = task.get(timeout=timeout)
|
||||
if isinstance(result, Coroutine):
|
||||
result = await result
|
||||
if result is not None: # If result is ready
|
||||
self._update_built_object_and_artifacts(result)
|
||||
return self._built_object
|
||||
else:
|
||||
# Handle the case when the result is not ready (retry, throw exception, etc.)
|
||||
pass
|
||||
|
||||
# If there's no task_id, build the vertex locally
|
||||
await self.build(requester=requester, user_id=user_id)
|
||||
return self._built_object
|
||||
|
||||
async def _build_node_and_update_params(self, key, node: "Vertex", user_id=None):
|
||||
async def get_result(
|
||||
self,
|
||||
) -> Any:
|
||||
"""
|
||||
Builds a given node and updates the params dictionary accordingly.
|
||||
Retrieves the result of the vertex.
|
||||
|
||||
This is a read-only method so it raises an error if the vertex has not been built yet.
|
||||
|
||||
Returns:
|
||||
The result of the vertex.
|
||||
"""
|
||||
async with self._lock:
|
||||
return await self._get_result()
|
||||
|
||||
async def _get_result(self) -> Any:
|
||||
"""
|
||||
Retrieves the result of the built component.
|
||||
|
||||
If the component has not been built yet, a ValueError is raised.
|
||||
|
||||
Returns:
|
||||
The built result if use_result is True, else the built object.
|
||||
"""
|
||||
if not self._built:
|
||||
raise ValueError(f"Component {self.display_name} has not been built yet")
|
||||
return self._built_result if self.use_result else self._built_object
|
||||
|
||||
async def _build_vertex_and_update_params(self, key, vertex: "Vertex"):
|
||||
"""
|
||||
Builds a given vertex and updates the params dictionary accordingly.
|
||||
"""
|
||||
|
||||
result = await node.get_result(requester=self, user_id=user_id)
|
||||
result = await vertex.get_result()
|
||||
self._handle_func(key, result)
|
||||
if isinstance(result, list):
|
||||
self._extend_params_list_with_result(key, result)
|
||||
self.params[key] = result
|
||||
|
||||
async def _build_list_of_nodes_and_update_params(self, key, nodes: List["Vertex"], user_id=None):
|
||||
async def _build_list_of_vertices_and_update_params(
|
||||
self,
|
||||
key,
|
||||
vertices: List["Vertex"],
|
||||
):
|
||||
"""
|
||||
Iterates over a list of nodes, builds each and updates the params dictionary.
|
||||
Iterates over a list of vertices, builds each and updates the params dictionary.
|
||||
"""
|
||||
self.params[key] = []
|
||||
for node in nodes:
|
||||
built = await node.get_result(requester=self, user_id=user_id)
|
||||
for vertex in vertices:
|
||||
result = await vertex.get_result()
|
||||
# Weird check to see if the params[key] is a list
|
||||
# because sometimes it is a Record and breaks the code
|
||||
if not isinstance(self.params[key], list):
|
||||
self.params[key] = [self.params[key]]
|
||||
|
||||
if isinstance(built, list):
|
||||
self.params[key].extend(built)
|
||||
if isinstance(result, list):
|
||||
self.params[key].extend(result)
|
||||
else:
|
||||
try:
|
||||
if self.params[key] == built:
|
||||
if self.params[key] == result:
|
||||
continue
|
||||
|
||||
self.params[key].append(built)
|
||||
self.params[key].append(result)
|
||||
except AttributeError as e:
|
||||
logger.exception(e)
|
||||
raise ValueError(
|
||||
f"Params {key} ({self.params[key]}) is not a list and cannot be extended with {built}"
|
||||
f"Error building node {self.display_name}: {str(e)}"
|
||||
f"Params {key} ({self.params[key]}) is not a list and cannot be extended with {result}"
|
||||
f"Error building vertex {self.display_name}: {str(e)}"
|
||||
) from e
|
||||
|
||||
def _handle_func(self, key, result):
|
||||
|
|
@ -580,12 +600,9 @@ class Vertex:
|
|||
Gets the class from a dictionary and instantiates it with the params.
|
||||
"""
|
||||
if self.base_type is None:
|
||||
raise ValueError(f"Base type for node {self.display_name} not found")
|
||||
raise ValueError(f"Base type for vertex {self.display_name} not found")
|
||||
try:
|
||||
result = await loading.instantiate_class(
|
||||
node_type=self.vertex_type,
|
||||
base_type=self.base_type,
|
||||
params=self.params,
|
||||
user_id=user_id,
|
||||
vertex=self,
|
||||
)
|
||||
|
|
@ -593,7 +610,7 @@ class Vertex:
|
|||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
|
||||
raise ValueError(f"Error building node {self.display_name}: {str(exc)}") from exc
|
||||
raise ValueError(f"Error building vertex {self.display_name}: {str(exc)}") from exc
|
||||
|
||||
def _update_built_object_and_artifacts(self, result):
|
||||
"""
|
||||
|
|
@ -647,33 +664,34 @@ class Vertex:
|
|||
requester: Optional["Vertex"] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
if self.state == VertexStates.INACTIVE:
|
||||
# If the vertex is inactive, return None
|
||||
self.build_inactive()
|
||||
return
|
||||
async with self._lock:
|
||||
if self.state == VertexStates.INACTIVE:
|
||||
# If the vertex is inactive, return None
|
||||
self.build_inactive()
|
||||
return
|
||||
|
||||
if self.frozen and self._built:
|
||||
return self.get_requester_result(requester)
|
||||
elif self._built and requester is not None:
|
||||
# This means that the vertex has already been built
|
||||
# and we are just getting the result for the requester
|
||||
return await self.get_requester_result(requester)
|
||||
self._reset()
|
||||
if self.frozen and self._built:
|
||||
return self.get_requester_result(requester)
|
||||
elif self._built and requester is not None:
|
||||
# This means that the vertex has already been built
|
||||
# and we are just getting the result for the requester
|
||||
return await self.get_requester_result(requester)
|
||||
self._reset()
|
||||
|
||||
if self._is_chat_input() and inputs:
|
||||
inputs = {"input_value": inputs.get(INPUT_FIELD_NAME, "")}
|
||||
self.update_raw_params(inputs, overwrite=True)
|
||||
if self._is_chat_input() and inputs:
|
||||
inputs = {"input_value": inputs.get(INPUT_FIELD_NAME, "")}
|
||||
self.update_raw_params(inputs, overwrite=True)
|
||||
|
||||
# Run steps
|
||||
for step in self.steps:
|
||||
if step not in self.steps_ran:
|
||||
if inspect.iscoroutinefunction(step):
|
||||
await step(user_id=user_id, **kwargs)
|
||||
else:
|
||||
step(user_id=user_id, **kwargs)
|
||||
self.steps_ran.append(step)
|
||||
# Run steps
|
||||
for step in self.steps:
|
||||
if step not in self.steps_ran:
|
||||
if inspect.iscoroutinefunction(step):
|
||||
await step(user_id=user_id, **kwargs)
|
||||
else:
|
||||
step(user_id=user_id, **kwargs)
|
||||
self.steps_ran.append(step)
|
||||
|
||||
self._finalize_build()
|
||||
self._finalize_build()
|
||||
|
||||
return await self.get_requester_result(requester)
|
||||
|
||||
|
|
@ -686,7 +704,11 @@ class Vertex:
|
|||
# Get the requester edge
|
||||
requester_edge = next((edge for edge in self.edges if edge.target_id == requester.id), None)
|
||||
# Return the result of the requester edge
|
||||
return None if requester_edge is None else await requester_edge.get_result(source=self, target=requester)
|
||||
return (
|
||||
None
|
||||
if requester_edge is None
|
||||
else await requester_edge.get_result_from_source(source=self, target=requester)
|
||||
)
|
||||
|
||||
def add_edge(self, edge: "ContractEdge") -> None:
|
||||
if edge not in self.edges:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import inspect
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Type
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Sequence, Type
|
||||
|
||||
import orjson
|
||||
from langchain.agents import agent as agent_module
|
||||
|
|
@ -40,27 +40,29 @@ if TYPE_CHECKING:
|
|||
|
||||
|
||||
async def instantiate_class(
|
||||
node_type: str,
|
||||
base_type: str,
|
||||
params: Dict,
|
||||
vertex: "Vertex",
|
||||
user_id=None,
|
||||
vertex: Optional["Vertex"] = None,
|
||||
) -> Any:
|
||||
"""Instantiate class from module type and key, and params"""
|
||||
vertex_type = vertex.vertex_type
|
||||
base_type = vertex.base_type
|
||||
params = vertex.params
|
||||
params = convert_params_to_sets(params)
|
||||
params = convert_kwargs(params)
|
||||
|
||||
if node_type in CUSTOM_NODES:
|
||||
if custom_node := CUSTOM_NODES.get(node_type):
|
||||
if vertex_type in CUSTOM_NODES:
|
||||
if custom_node := CUSTOM_NODES.get(vertex_type):
|
||||
if hasattr(custom_node, "initialize"):
|
||||
return custom_node.initialize(**params)
|
||||
return custom_node(**params)
|
||||
logger.debug(f"Instantiating {node_type} of type {base_type}")
|
||||
class_object = import_by_type(_type=base_type, name=node_type)
|
||||
logger.debug(f"Instantiating {vertex_type} of type {base_type}")
|
||||
if not base_type:
|
||||
raise ValueError("No base type provided for vertex")
|
||||
class_object = import_by_type(_type=base_type, name=vertex_type)
|
||||
return await instantiate_based_on_type(
|
||||
class_object=class_object,
|
||||
base_type=base_type,
|
||||
node_type=node_type,
|
||||
node_type=vertex_type,
|
||||
params=params,
|
||||
user_id=user_id,
|
||||
vertex=vertex,
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ class Record(BaseModel):
|
|||
_default_value: str = ""
|
||||
|
||||
@model_validator(mode="before")
|
||||
def validate_data(values):
|
||||
def validate_data(cls, values):
|
||||
if not values.get("data"):
|
||||
values["data"] = {}
|
||||
# Any other keyword should be added to the data dictionary
|
||||
|
|
|
|||
14
src/backend/langflow/services/cache/__init__.py
vendored
14
src/backend/langflow/services/cache/__init__.py
vendored
|
|
@ -1,9 +1,17 @@
|
|||
from . import factory, service
|
||||
from langflow.services.cache.service import InMemoryCache
|
||||
from langflow.services.cache.service import (
|
||||
AsyncInMemoryCache,
|
||||
BaseCacheService,
|
||||
RedisCache,
|
||||
ThreadingInMemoryCache,
|
||||
)
|
||||
|
||||
from . import factory, service
|
||||
|
||||
__all__ = [
|
||||
"factory",
|
||||
"service",
|
||||
"InMemoryCache",
|
||||
"ThreadingInMemoryCache",
|
||||
"AsyncInMemoryCache",
|
||||
"BaseCacheService",
|
||||
"RedisCache",
|
||||
]
|
||||
|
|
|
|||
80
src/backend/langflow/services/cache/base.py
vendored
80
src/backend/langflow/services/cache/base.py
vendored
|
|
@ -1,4 +1,7 @@
|
|||
import abc
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from langflow.services.base import Service
|
||||
|
||||
|
|
@ -11,7 +14,7 @@ class BaseCacheService(Service):
|
|||
name = "cache_service"
|
||||
|
||||
@abc.abstractmethod
|
||||
def get(self, key):
|
||||
def get(self, key, lock: Optional[threading.Lock] = None):
|
||||
"""
|
||||
Retrieve an item from the cache.
|
||||
|
||||
|
|
@ -23,7 +26,7 @@ class BaseCacheService(Service):
|
|||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def set(self, key, value):
|
||||
def set(self, key, value, lock: Optional[threading.Lock] = None):
|
||||
"""
|
||||
Add an item to the cache.
|
||||
|
||||
|
|
@ -33,7 +36,7 @@ class BaseCacheService(Service):
|
|||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def upsert(self, key, value):
|
||||
def upsert(self, key, value, lock: Optional[threading.Lock] = None):
|
||||
"""
|
||||
Add an item to the cache if it doesn't exist, or update it if it does.
|
||||
|
||||
|
|
@ -43,7 +46,7 @@ class BaseCacheService(Service):
|
|||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete(self, key):
|
||||
def delete(self, key, lock: Optional[threading.Lock] = None):
|
||||
"""
|
||||
Remove an item from the cache.
|
||||
|
||||
|
|
@ -52,7 +55,7 @@ class BaseCacheService(Service):
|
|||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def clear(self):
|
||||
def clear(self, lock: Optional[threading.Lock] = None):
|
||||
"""
|
||||
Clear all items from the cache.
|
||||
"""
|
||||
|
|
@ -96,3 +99,70 @@ class BaseCacheService(Service):
|
|||
Args:
|
||||
key: The key of the item to remove.
|
||||
"""
|
||||
|
||||
|
||||
class AsyncBaseCacheService(Service):
|
||||
"""
|
||||
Abstract base class for a async cache.
|
||||
"""
|
||||
|
||||
name = "cache_service"
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get(self, key, lock: Optional[asyncio.Lock] = None):
|
||||
"""
|
||||
Retrieve an item from the cache.
|
||||
|
||||
Args:
|
||||
key: The key of the item to retrieve.
|
||||
|
||||
Returns:
|
||||
The value associated with the key, or None if the key is not found.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def set(self, key, value, lock: Optional[asyncio.Lock] = None):
|
||||
"""
|
||||
Add an item to the cache.
|
||||
|
||||
Args:
|
||||
key: The key of the item.
|
||||
value: The value to cache.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def upsert(self, key, value, lock: Optional[asyncio.Lock] = None):
|
||||
"""
|
||||
Add an item to the cache if it doesn't exist, or update it if it does.
|
||||
|
||||
Args:
|
||||
key: The key of the item.
|
||||
value: The value to cache.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete(self, key, lock: Optional[asyncio.Lock] = None):
|
||||
"""
|
||||
Remove an item from the cache.
|
||||
|
||||
Args:
|
||||
key: The key of the item to remove.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def clear(self, lock: Optional[asyncio.Lock] = None):
|
||||
"""
|
||||
Clear all items from the cache.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def __contains__(self, key):
|
||||
"""
|
||||
Check if the key is in the cache.
|
||||
|
||||
Args:
|
||||
key: The key of the item to check.
|
||||
|
||||
Returns:
|
||||
True if the key is in the cache, False otherwise.
|
||||
"""
|
||||
|
|
|
|||
13
src/backend/langflow/services/cache/factory.py
vendored
13
src/backend/langflow/services/cache/factory.py
vendored
|
|
@ -1,6 +1,11 @@
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
from langflow.services.cache.service import BaseCacheService, InMemoryCache, RedisCache
|
||||
from langflow.services.cache.service import (
|
||||
AsyncInMemoryCache,
|
||||
BaseCacheService,
|
||||
RedisCache,
|
||||
ThreadingInMemoryCache,
|
||||
)
|
||||
from langflow.services.factory import ServiceFactory
|
||||
from langflow.utils.logger import logger
|
||||
|
||||
|
|
@ -29,7 +34,9 @@ class CacheServiceFactory(ServiceFactory):
|
|||
logger.debug("Redis cache is connected")
|
||||
return redis_cache
|
||||
logger.warning("Redis cache is not connected, falling back to in-memory cache")
|
||||
return InMemoryCache()
|
||||
return ThreadingInMemoryCache()
|
||||
|
||||
elif settings_service.settings.CACHE_TYPE == "memory":
|
||||
return InMemoryCache()
|
||||
return ThreadingInMemoryCache()
|
||||
elif settings_service.settings.CACHE_TYPE == "async":
|
||||
return AsyncInMemoryCache()
|
||||
|
|
|
|||
115
src/backend/langflow/services/cache/service.py
vendored
115
src/backend/langflow/services/cache/service.py
vendored
|
|
@ -1,16 +1,17 @@
|
|||
import asyncio
|
||||
import pickle
|
||||
import threading
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from langflow.services.base import Service
|
||||
from langflow.services.cache.base import BaseCacheService
|
||||
from langflow.services.cache.base import AsyncBaseCacheService, BaseCacheService
|
||||
|
||||
|
||||
class InMemoryCache(BaseCacheService, Service):
|
||||
|
||||
class ThreadingInMemoryCache(BaseCacheService, Service):
|
||||
"""
|
||||
A simple in-memory cache using an OrderedDict.
|
||||
|
||||
|
|
@ -49,7 +50,7 @@ class InMemoryCache(BaseCacheService, Service):
|
|||
self.max_size = max_size
|
||||
self.expiration_time = expiration_time
|
||||
|
||||
def get(self, key):
|
||||
def get(self, key, lock: Optional[threading.Lock] = None):
|
||||
"""
|
||||
Retrieve an item from the cache.
|
||||
|
||||
|
|
@ -59,7 +60,7 @@ class InMemoryCache(BaseCacheService, Service):
|
|||
Returns:
|
||||
The value associated with the key, or None if the key is not found or the item has expired.
|
||||
"""
|
||||
with self._lock:
|
||||
with lock or self._lock:
|
||||
return self._get_without_lock(key)
|
||||
|
||||
def _get_without_lock(self, key):
|
||||
|
|
@ -80,7 +81,7 @@ class InMemoryCache(BaseCacheService, Service):
|
|||
self.delete(key)
|
||||
return None
|
||||
|
||||
def set(self, key, value, pickle=False):
|
||||
def set(self, key, value, lock: Optional[threading.Lock] = None):
|
||||
"""
|
||||
Add an item to the cache.
|
||||
|
||||
|
|
@ -90,7 +91,7 @@ class InMemoryCache(BaseCacheService, Service):
|
|||
key: The key of the item.
|
||||
value: The value to cache.
|
||||
"""
|
||||
with self._lock:
|
||||
with lock or self._lock:
|
||||
if key in self._cache:
|
||||
# Remove existing key before re-inserting to update order
|
||||
self.delete(key)
|
||||
|
|
@ -98,12 +99,10 @@ class InMemoryCache(BaseCacheService, Service):
|
|||
# Remove least recently used item
|
||||
self._cache.popitem(last=False)
|
||||
# pickle locally to mimic Redis
|
||||
if pickle:
|
||||
value = pickle.dumps(value)
|
||||
|
||||
self._cache[key] = {"value": value, "time": time.time()}
|
||||
|
||||
def upsert(self, key, value):
|
||||
def upsert(self, key, value, lock: Optional[threading.Lock] = None):
|
||||
"""
|
||||
Inserts or updates a value in the cache.
|
||||
If the existing value and the new value are both dictionaries, they are merged.
|
||||
|
|
@ -112,7 +111,7 @@ class InMemoryCache(BaseCacheService, Service):
|
|||
key: The key of the item.
|
||||
value: The value to insert or update.
|
||||
"""
|
||||
with self._lock:
|
||||
with lock or self._lock:
|
||||
existing_value = self._get_without_lock(key)
|
||||
if existing_value is not None and isinstance(existing_value, dict) and isinstance(value, dict):
|
||||
existing_value.update(value)
|
||||
|
|
@ -120,7 +119,7 @@ class InMemoryCache(BaseCacheService, Service):
|
|||
|
||||
self.set(key, value)
|
||||
|
||||
def get_or_set(self, key, value):
|
||||
def get_or_set(self, key, value, lock: Optional[threading.Lock] = None):
|
||||
"""
|
||||
Retrieve an item from the cache. If the item does not exist,
|
||||
set it with the provided value.
|
||||
|
|
@ -132,27 +131,27 @@ class InMemoryCache(BaseCacheService, Service):
|
|||
Returns:
|
||||
The cached value associated with the key.
|
||||
"""
|
||||
with self._lock:
|
||||
with lock or self._lock:
|
||||
if key in self._cache:
|
||||
return self.get(key)
|
||||
self.set(key, value)
|
||||
return value
|
||||
|
||||
def delete(self, key):
|
||||
def delete(self, key, lock: Optional[threading.Lock] = None):
|
||||
"""
|
||||
Remove an item from the cache.
|
||||
|
||||
Args:
|
||||
key: The key of the item to remove.
|
||||
"""
|
||||
with self._lock:
|
||||
with lock or self._lock:
|
||||
self._cache.pop(key, None)
|
||||
|
||||
def clear(self):
|
||||
def clear(self, lock: Optional[threading.Lock] = None):
|
||||
"""
|
||||
Clear all items from the cache.
|
||||
"""
|
||||
with self._lock:
|
||||
with lock or self._lock:
|
||||
self._cache.clear()
|
||||
|
||||
def __contains__(self, key):
|
||||
|
|
@ -323,3 +322,85 @@ class RedisCache(BaseCacheService, Service):
|
|||
def __repr__(self):
|
||||
"""Return a string representation of the RedisCache instance."""
|
||||
return f"RedisCache(expiration_time={self.expiration_time})"
|
||||
|
||||
|
||||
class AsyncInMemoryCache(AsyncBaseCacheService, Service):
|
||||
def __init__(self, max_size=None, expiration_time=3600):
|
||||
self.cache = OrderedDict()
|
||||
|
||||
self.lock = asyncio.Lock()
|
||||
self.max_size = max_size
|
||||
self.expiration_time = expiration_time
|
||||
|
||||
async def get(self, key, lock: Optional[asyncio.Lock] = None):
|
||||
if not lock:
|
||||
async with self.lock:
|
||||
return await self._get(key)
|
||||
else:
|
||||
return await self._get(key)
|
||||
|
||||
async def _get(self, key):
|
||||
item = self.cache.get(key, None)
|
||||
if item and (time.time() - item["time"] < self.expiration_time):
|
||||
self.cache.move_to_end(key)
|
||||
return pickle.loads(item["value"]) if isinstance(item["value"], bytes) else item["value"]
|
||||
if item:
|
||||
await self.delete(key)
|
||||
return None
|
||||
|
||||
async def set(self, key, value, lock: Optional[asyncio.Lock] = None):
|
||||
if not lock:
|
||||
async with self.lock:
|
||||
await self._set(
|
||||
key,
|
||||
value,
|
||||
)
|
||||
else:
|
||||
await self._set(
|
||||
key,
|
||||
value,
|
||||
)
|
||||
|
||||
async def _set(self, key, value):
|
||||
if self.max_size and len(self.cache) >= self.max_size:
|
||||
self.cache.popitem(last=False)
|
||||
self.cache[key] = {"value": value, "time": time.time()}
|
||||
self.cache.move_to_end(key)
|
||||
|
||||
async def delete(self, key, lock: Optional[asyncio.Lock] = None):
|
||||
if not lock:
|
||||
async with self.lock:
|
||||
await self._delete(key)
|
||||
else:
|
||||
await self._delete(key)
|
||||
|
||||
async def _delete(self, key):
|
||||
if key in self.cache:
|
||||
del self.cache[key]
|
||||
|
||||
async def clear(self, lock: Optional[asyncio.Lock] = None):
|
||||
if not lock:
|
||||
async with self.lock:
|
||||
await self._clear()
|
||||
else:
|
||||
await self._clear()
|
||||
|
||||
async def _clear(self):
|
||||
self.cache.clear()
|
||||
|
||||
async def upsert(self, key, value, lock: Optional[asyncio.Lock] = None):
|
||||
if not lock:
|
||||
async with self.lock:
|
||||
await self._upsert(key, value)
|
||||
else:
|
||||
await self._upsert(key, value)
|
||||
|
||||
async def _upsert(self, key, value):
|
||||
existing_value = await self.get(key)
|
||||
if existing_value is not None and isinstance(existing_value, dict) and isinstance(value, dict):
|
||||
existing_value.update(value)
|
||||
value = existing_value
|
||||
await self.set(key, value)
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self.cache
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
from typing import Any
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from typing import Any, Optional
|
||||
|
||||
from langflow.services.base import Service
|
||||
from langflow.services.deps import get_cache_service
|
||||
|
|
@ -8,30 +10,30 @@ class ChatService(Service):
|
|||
name = "chat_service"
|
||||
|
||||
def __init__(self):
|
||||
self._cache_locks = defaultdict(asyncio.Lock)
|
||||
self.cache_service = get_cache_service()
|
||||
|
||||
def set_cache(self, client_id: str, data: Any) -> bool:
|
||||
async def set_cache(self, flow_id: str, data: Any, lock: Optional[asyncio.Lock] = None) -> bool:
|
||||
"""
|
||||
Set the cache for a client.
|
||||
"""
|
||||
# client_id is the flow id but that already exists in the cache
|
||||
# so we need to change it to something else
|
||||
|
||||
result_dict = {
|
||||
"result": data,
|
||||
"type": type(data),
|
||||
}
|
||||
self.cache_service.upsert(client_id, result_dict)
|
||||
return client_id in self.cache_service
|
||||
await self.cache_service.upsert(flow_id, result_dict, lock=lock or self._cache_locks[flow_id])
|
||||
return flow_id in self.cache_service
|
||||
|
||||
def get_cache(self, client_id: str) -> Any:
|
||||
async def get_cache(self, flow_id: str, lock: Optional[asyncio.Lock] = None) -> Any:
|
||||
"""
|
||||
Get the cache for a client.
|
||||
"""
|
||||
return self.cache_service.get(client_id)
|
||||
return await self.cache_service.get(flow_id, lock=lock or self._cache_locks[flow_id])
|
||||
|
||||
def clear_cache(self, client_id: str):
|
||||
async def clear_cache(self, flow_id: str, lock: Optional[asyncio.Lock] = None):
|
||||
"""
|
||||
Clear the cache for a client.
|
||||
"""
|
||||
self.cache_service.delete(client_id)
|
||||
self.cache_service.delete(flow_id, lock=lock or self._cache_locks[flow_id])
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ class Settings(BaseSettings):
|
|||
|
||||
DEV: bool = False
|
||||
DATABASE_URL: Optional[str] = None
|
||||
CACHE_TYPE: str = "memory"
|
||||
CACHE_TYPE: str = "async"
|
||||
REMOVE_API_KEYS: bool = False
|
||||
COMPONENTS_PATH: List[str] = []
|
||||
LANGCHAIN_CACHE: str = "InMemoryCache"
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ const SvgBotMessageSquare = (props) => (
|
|||
stroke-width="2"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
class="lucide lucide-bot-message-square"
|
||||
className="lucide lucide-bot-message-square"
|
||||
{...props}
|
||||
>
|
||||
<path d="M12 6V2H8" />
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
import json
|
||||
from langflow.graph import Graph
|
||||
|
||||
import pytest
|
||||
|
||||
from langflow.graph import Graph
|
||||
|
||||
|
||||
def get_graph(_type="basic"):
|
||||
"""Get a graph from a json file"""
|
||||
|
|
@ -38,7 +39,8 @@ def langchain_objects_are_equal(obj1, obj2):
|
|||
|
||||
|
||||
# Test build_graph
|
||||
def test_build_graph(client, basic_data_graph):
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_graph(client, basic_data_graph):
|
||||
graph = Graph.from_payload(basic_data_graph)
|
||||
assert graph is not None
|
||||
assert len(graph.vertices) == len(basic_data_graph["nodes"])
|
||||
|
|
|
|||
|
|
@ -28,9 +28,7 @@ async def test_successful_get_request(api_request):
|
|||
respx.get(url).mock(return_value=Response(200, json=mock_response))
|
||||
|
||||
# Making the request
|
||||
result = await api_request.make_request(
|
||||
client=httpx.AsyncClient(), method=method, url=url
|
||||
)
|
||||
result = await api_request.make_request(client=httpx.AsyncClient(), method=method, url=url)
|
||||
|
||||
# Assertions
|
||||
assert result.data["status_code"] == 200
|
||||
|
|
@ -46,9 +44,7 @@ async def test_failed_request(api_request):
|
|||
respx.get(url).mock(return_value=Response(404))
|
||||
|
||||
# Making the request
|
||||
result = await api_request.make_request(
|
||||
client=httpx.AsyncClient(), method=method, url=url
|
||||
)
|
||||
result = await api_request.make_request(client=httpx.AsyncClient(), method=method, url=url)
|
||||
|
||||
# Assertions
|
||||
assert result.data["status_code"] == 404
|
||||
|
|
@ -60,14 +56,10 @@ async def test_timeout(api_request):
|
|||
# Mocking a timeout
|
||||
url = "https://example.com/api/timeout"
|
||||
method = "GET"
|
||||
respx.get(url).mock(
|
||||
side_effect=httpx.TimeoutException(message="Timeout", request=None)
|
||||
)
|
||||
respx.get(url).mock(side_effect=httpx.TimeoutException(message="Timeout", request=None))
|
||||
|
||||
# Making the request
|
||||
result = await api_request.make_request(
|
||||
client=httpx.AsyncClient(), method=method, url=url, timeout=1
|
||||
)
|
||||
result = await api_request.make_request(client=httpx.AsyncClient(), method=method, url=url, timeout=1)
|
||||
|
||||
# Assertions
|
||||
assert result.data["status_code"] == 408
|
||||
|
|
@ -106,7 +98,6 @@ def test_directory_component_build_with_multithreading(
|
|||
# Arrange
|
||||
directory_component = data.DirectoryComponent()
|
||||
path = os.path.dirname(os.path.abspath(__file__))
|
||||
types = ["py"]
|
||||
depth = 1
|
||||
max_concurrency = 2
|
||||
load_hidden = False
|
||||
|
|
@ -123,7 +114,6 @@ def test_directory_component_build_with_multithreading(
|
|||
# Act
|
||||
directory_component.build(
|
||||
path,
|
||||
types,
|
||||
depth,
|
||||
max_concurrency,
|
||||
load_hidden,
|
||||
|
|
@ -134,9 +124,7 @@ def test_directory_component_build_with_multithreading(
|
|||
|
||||
# Assert
|
||||
mock_resolve_path.assert_called_once_with(path)
|
||||
mock_retrieve_file_paths.assert_called_once_with(
|
||||
path, types, load_hidden, recursive, depth
|
||||
)
|
||||
mock_retrieve_file_paths.assert_called_once_with(path, load_hidden, recursive, depth)
|
||||
mock_parallel_load_records.assert_called_once_with(
|
||||
mock_retrieve_file_paths.return_value, silent_errors, max_concurrency
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue