Refactor graph building and caching

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-20 13:13:01 -03:00
commit 5cef5b868a

View file

@ -22,7 +22,6 @@ from langflow.api.v1.schemas import (
VerticesOrderResponse,
)
from langflow.graph.graph.base import Graph
from langflow.graph.vertex.base import StatelessVertex
from langflow.processing.process import process_tweaks_on_graph
from langflow.services.auth.utils import (
get_current_active_user,
@ -275,6 +274,25 @@ async def try_running_celery_task(vertex, user_id):
return vertex
def build_and_cache_graph(
flow_id: str,
session: Session,
chat_service: "ChatService",
graph: Optional[Graph] = None,
):
"""Build and cache the graph."""
flow: Flow = session.get(Flow, flow_id)
if not flow or not flow.data:
raise ValueError("Invalid flow ID")
other_graph = Graph.from_payload(flow.data)
if graph is None:
graph = other_graph
else:
graph = graph.update(other_graph)
chat_service.set_cache(flow_id, graph)
return graph
@router.get("/build/{flow_id}/vertices", response_model=VerticesOrderResponse)
async def get_vertices(
flow_id: str,
@ -288,17 +306,7 @@ async def get_vertices(
graph = None
if cache := chat_service.get_cache(flow_id):
graph: Graph = cache.get("result")
flow: Flow = session.get(Flow, flow_id)
if not flow or not flow.data:
raise ValueError("Invalid flow ID")
other_graph = Graph.from_payload(flow.data)
if graph is None:
graph = other_graph
else:
graph = graph.update(other_graph)
chat_service.set_cache(flow_id, graph)
graph = build_and_cache_graph(flow_id, session, chat_service, graph)
if component_id:
try:
vertices = graph.sort_up_to_vertex(component_id)
@ -332,34 +340,48 @@ async def build_vertex(
"""Build a vertex instead of the entire graph."""
try:
cache = chat_service.get_cache(flow_id)
graph = cache.get("result")
if not cache:
# If there's no cache
logger.warning(
f"No cache found for {flow_id}. Building graph starting at {vertex_id}"
)
graph = build_and_cache_graph(
flow_id=flow_id, session=get_session(), chat_service=chat_service
)
else:
graph = cache.get("result")
result_dict = {}
duration = ""
start_time = time.perf_counter()
if tweaks:
graph = process_tweaks_on_graph(graph, tweaks)
if not isinstance(graph, Graph):
raise ValueError("Invalid graph")
if not (vertex := graph.get_vertex(vertex_id)):
raise ValueError("Invalid vertex")
raise ValueError(f"Invalid vertex {vertex_id}")
try:
if isinstance(vertex, StatelessVertex) or not vertex._built:
if not vertex.pinned or not vertex._built:
await vertex.build(user_id=current_user.id)
params = vertex._built_object_repr()
valid = True
result_dict = vertex.get_built_result()
# We need to set the artifacts to pass information
# to the frontend
vertex.set_artifacts()
artifacts = vertex.artifacts
timedelta = time.perf_counter() - start_time
duration = format_elapsed_time(timedelta)
result_dict = ResultDict(
results=result_dict,
artifacts=artifacts,
duration=duration,
timedelta=timedelta,
)
params = vertex._built_object_repr()
valid = True
result_dict = vertex.get_built_result()
# We need to set the artifacts to pass information
# to the frontend
vertex.set_artifacts()
artifacts = vertex.artifacts
timedelta = time.perf_counter() - start_time
duration = format_elapsed_time(timedelta)
result_dict = ResultDict(
results=result_dict,
artifacts=artifacts,
duration=duration,
timedelta=timedelta,
)
vertex.set_result(result_dict)
elif vertex.result is not None:
params = vertex._built_object_repr()
valid = True
result_dict = vertex.result
else:
raise ValueError(f"No result found for vertex {vertex_id}")
chat_service.set_cache(flow_id, graph)
except Exception as exc:
params = str(exc)