Merge remote-tracking branch 'origin/dev' into two_edges

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-06-19 09:50:15 -03:00
commit cedab5c9d0
4 changed files with 43 additions and 55 deletions

View file

@ -1,5 +1,6 @@
import time
import uuid
from functools import partial
from typing import TYPE_CHECKING, Annotated, Optional
from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException
@ -166,8 +167,6 @@ async def build_vertex(
try:
lock = chat_service._cache_locks[flow_id_str]
(
next_runnable_vertices,
top_level_vertices,
result_dict,
params,
valid,
@ -181,6 +180,13 @@ async def build_vertex(
inputs_dict=inputs.model_dump() if inputs else {},
files=files,
)
set_cache_coro = partial(get_chat_service().set_cache, key=flow_id_str)
next_runnable_vertices = await graph.run_manager.get_next_runnable_vertices(
lock, set_cache_coro, graph=graph, vertex=vertex, cache=False
)
top_level_vertices = graph.run_manager.get_top_level_vertices(graph, next_runnable_vertices)
result_data_response = ResultDataResponse(**result_dict.model_dump())
result_data_response = ResultDataResponse.model_validate(result_dict, from_attributes=True)
except Exception as exc:
@ -214,7 +220,6 @@ async def build_vertex(
result_data_response.duration = duration
result_data_response.timedelta = timedelta
vertex.add_build_time(timedelta)
inactivated_vertices = None
inactivated_vertices = list(graph.inactivated_vertices)
graph.reset_inactivated_vertices()
graph.reset_activated_vertices()

View file

@ -3,7 +3,7 @@ import uuid
from collections import defaultdict, deque
from functools import partial
from itertools import chain
from typing import TYPE_CHECKING, Callable, Coroutine, Dict, Generator, List, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Tuple, Type, Union
from loguru import logger
@ -259,6 +259,7 @@ class Graph:
Returns:
List[Optional["ResultData"]]: The outputs of the graph.
"""
if input_components and not isinstance(input_components, list):
raise ValueError(f"Invalid components value: {input_components}. Expected list")
elif input_components is None:
@ -765,14 +766,12 @@ class Graph:
async def build_vertex(
self,
lock: asyncio.Lock,
chat_service: ChatService,
vertex_id: str,
inputs_dict: Optional[Dict[str, str]] = None,
files: Optional[list[str]] = None,
user_id: Optional[str] = None,
fallback_to_env_vars: bool = False,
cache: bool = True,
):
"""
Builds a vertex in the graph.
@ -827,43 +826,15 @@ class Graph:
artifacts = vertex.artifacts
else:
raise ValueError(f"No result found for vertex {vertex_id}")
set_cache_coro = partial(chat_service.set_cache, key=self.flow_id)
next_runnable_vertices, top_level_vertices = await self.get_next_and_top_level_vertices(
lock, set_cache_coro, vertex, cache=cache
)
flow_id = self.flow_id
log_transaction(flow_id, vertex, status="success")
return next_runnable_vertices, top_level_vertices, result_dict, params, valid, artifacts, vertex
return result_dict, params, valid, artifacts, vertex
except Exception as exc:
logger.exception(f"Error building Component: {exc}")
flow_id = self.flow_id
log_transaction(flow_id, vertex, status="failure", error=str(exc))
raise exc
async def get_next_and_top_level_vertices(
self,
lock: asyncio.Lock,
set_cache_coro: Callable[["Graph", asyncio.Lock], Coroutine],
vertex: Vertex,
cache: bool = True,
):
"""
Retrieves the next runnable vertices and the top level vertices for a given vertex.
Args:
lock (asyncio.Lock): The lock used to synchronize access to the graph.
set_cache_coro (Coroutine): The coroutine used to set the cache for the graph.
vertex (Vertex): The vertex for which to retrieve the next runnable and top level vertices.
Returns:
Tuple[List[Vertex], List[Vertex]]: A tuple containing the next runnable vertices and the top level vertices.
"""
next_runnable_vertices = await self.run_manager.get_next_runnable_vertices(
lock, set_cache_coro, self, vertex, cache=cache
)
top_level_vertices = self.run_manager.get_top_level_vertices(self, next_runnable_vertices)
return next_runnable_vertices, top_level_vertices
def get_vertex_edges(
self,
vertex_id: str,
@ -910,13 +881,11 @@ class Graph:
vertex = self.get_vertex(vertex_id)
task = asyncio.create_task(
self.build_vertex(
lock=lock,
chat_service=chat_service,
vertex_id=vertex_id,
user_id=self.user_id,
inputs_dict={},
fallback_to_env_vars=fallback_to_env_vars,
cache=False,
),
name=f"{vertex.display_name} Run {vertex_task_run_count.get(vertex_id, 0)}",
)
@ -925,7 +894,7 @@ class Graph:
logger.debug(f"Running layer {layer_index} with {len(tasks)} tasks")
try:
next_runnable_vertices = await self._execute_tasks(tasks)
next_runnable_vertices = await self._execute_tasks(tasks, lock=lock)
except Exception as e:
logger.error(f"Error executing tasks in layer {layer_index}: {e}")
break
@ -937,10 +906,11 @@ class Graph:
logger.debug("Graph processing complete")
return self
async def _execute_tasks(self, tasks: List[asyncio.Task]) -> List[str]:
async def _execute_tasks(self, tasks: List[asyncio.Task], lock: asyncio.Lock) -> List[str]:
"""Executes tasks in parallel, handling exceptions for each task."""
results = []
completed_tasks = await asyncio.gather(*tasks, return_exceptions=True)
vertices: List[Vertex] = []
for i, result in enumerate(completed_tasks):
task_name = tasks[i].get_name()
@ -950,13 +920,23 @@ class Graph:
for t in tasks[i + 1 :]:
t.cancel()
raise result
elif isinstance(result, tuple) and len(result) == 7:
# Get the next runnable vertices
next_runnable_vertices = result[0]
results.extend(next_runnable_vertices)
elif isinstance(result, tuple) and len(result) == 5:
vertices.append(result[4])
else:
raise ValueError(f"Invalid result from task {task_name}: {result}")
for v in vertices:
# set all executed vertices as non-runnable to not run them again.
# they could be calculated as predecessor or successors of parallel vertices
# This could usually happen with input vertices like ChatInput
self.run_manager.remove_vertex_from_runnables(v.id)
set_cache_coro = partial(get_chat_service().set_cache, key=self.flow_id)
for v in vertices:
next_runnable_vertices = await self.run_manager.get_next_runnable_vertices(
lock, set_cache_coro, graph=self, vertex=v, cache=False
)
results.extend(next_runnable_vertices)
return results
def topological_sort(self) -> List[Vertex]:
@ -1438,5 +1418,3 @@ class Graph:
predecessor_map[edge.target_id].append(edge.source_id)
successor_map[edge.source_id].append(edge.target_id)
return predecessor_map, successor_map
return predecessor_map, successor_map
return predecessor_map, successor_map

View file

@ -1,6 +1,6 @@
import asyncio
from collections import defaultdict
from typing import TYPE_CHECKING, Awaitable, Callable, List
from typing import TYPE_CHECKING, Callable, List, Coroutine
if TYPE_CHECKING:
from langflow.graph.graph.base import Graph
@ -83,20 +83,20 @@ class RunnableVerticesManager:
async def get_next_runnable_vertices(
self,
lock: asyncio.Lock,
set_cache_coro: Callable[["Graph", asyncio.Lock], Awaitable[None]],
set_cache_coro: Callable[["Graph", asyncio.Lock], Coroutine],
graph: "Graph",
vertex: "Vertex",
cache: bool = True,
):
) -> List[str]:
"""
Retrieves the next runnable vertices in the graph for a given vertex.
Args:
graph (Graph): The graph object representing the flow.
vertex (Vertex): The current vertex.
vertex_id (str): The ID of the current vertex.
chat_service (ChatService): The chat service object.
flow_id (str): The ID of the flow.
lock (asyncio.Lock): The lock object to be used for synchronization.
set_cache_coro (Callable): The coroutine function to set the cache.
graph (Graph): The graph object containing the vertices.
vertex (Vertex): The vertex object for which the next runnable vertices are to be retrieved.
cache (bool, optional): A flag to indicate if the cache should be updated. Defaults to True.
Returns:
list: A list of IDs of the next runnable vertices.
@ -112,12 +112,15 @@ class RunnableVerticesManager:
next_runnable_vertices = direct_successors_ready
for v_id in set(next_runnable_vertices): # Use set to avoid duplicates
self.update_vertex_run_state(v_id, is_runnable=False)
self.remove_from_predecessors(v_id)
self.remove_vertex_from_runnables(v_id)
if cache:
await set_cache_coro(data=graph, lock=lock) # type: ignore
return next_runnable_vertices
def remove_vertex_from_runnables(self, v_id):
self.update_vertex_run_state(v_id, is_runnable=False)
self.remove_from_predecessors(v_id)
@staticmethod
def get_top_level_vertices(graph, vertices_ids):
"""

View file

@ -510,6 +510,8 @@ def load_flows_from_directory():
if existing:
logger.info(f"Updating existing flow: {flow_id} with endpoint name {flow_endpoint_name}")
for key, value in flow.items():
if key == "last_tested_version":
continue
setattr(existing, key, value)
existing.updated_at = datetime.utcnow()
existing.user_id = user_id