diff --git a/src/backend/langflow/services/socket/service.py b/src/backend/langflow/services/socket/service.py index 14ebcaaea..a84c9ca86 100644 --- a/src/backend/langflow/services/socket/service.py +++ b/src/backend/langflow/services/socket/service.py @@ -1,17 +1,11 @@ -import time -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any import socketio -from langflow.api.utils import format_elapsed_time -from langflow.api.v1.schemas import ResultDict, VertexBuildResponse -from langflow.graph.graph.base import Graph -from langflow.graph.vertex.base import StatelessVertex -from langflow.services.base import Service -from langflow.services.database.models.flow.model import Flow -from langflow.services.deps import get_chat_service, get_session -from langflow.services.monitor.utils import log_vertex_build from loguru import logger -from sqlmodel import select + +from langflow.services.base import Service +from langflow.services.deps import get_chat_service +from langflow.services.socket.utils import build_vertex, get_vertices if TYPE_CHECKING: from langflow.services.cache.service import BaseCacheService @@ -89,81 +83,3 @@ class SocketIOService(Service): } self.cache_service.upsert(sid, result_dict) return sid in self.cache_service - - -async def build_vertex( - sio: socketio.AsyncServer, - sid: str, - flow_id: str, - vertex_id: str, - get_cache: Callable, - set_cache: Callable, - tweaks=None, - inputs=None, -): - try: - cache = get_cache(flow_id) - graph = cache.get("result") - - if not isinstance(graph, Graph): - await sio.emit("error", data="Invalid graph", to=sid) - return - - vertex = graph.get_vertex(vertex_id) - if not vertex: - await sio.emit("error", data="Invalid vertex", to=sid) - return - start_time = time.perf_counter() - try: - if isinstance(vertex, StatelessVertex) or not vertex._built: - await vertex.build(user_id=None) - params = vertex._built_object_repr() - valid = True - result_dict = vertex.get_built_result() - # We need to set the artifacts to pass information - # to the frontend - vertex.set_artifacts() - artifacts = vertex.artifacts - timedelta = time.perf_counter() - start_time - duration = format_elapsed_time(timedelta) - result_dict = ResultDict(results=result_dict, artifacts=artifacts, duration=duration, timedelta=timedelta) - except Exception as exc: - params = str(exc) - valid = False - result_dict = ResultDict(results={}) - artifacts = {} - set_cache(flow_id, graph) - await log_vertex_build( - flow_id=flow_id, - vertex_id=vertex_id, - valid=valid, - params=params, - data=result_dict, - artifacts=artifacts, - ) - - # Emit the vertex build response - response = VertexBuildResponse(valid=valid, params=params, id=vertex.id, data=result_dict) - await sio.emit("vertex_build", data=response.model_dump(), to=sid) - - except Exception as exc: - await sio.emit("error", data=str(exc), to=sid) - - -async def get_vertices(sio, sid, flow_id, chat_service): - try: - session = get_session() - flow: Flow = session.exec(select(Flow).where(Flow.id == flow_id)).first() - if not flow or not flow.data: - await sio.emit("error", data="Invalid flow ID", to=sid) - return - - graph = Graph.from_payload(flow.data) - chat_service.set_cache(flow_id, graph) - vertices = graph.layered_topological_sort() - - # Emit the vertices to the client - await sio.emit("vertices_order", data=vertices, to=sid) - - except Exception as exc: - await sio.emit("error", data=str(exc), to=sid) diff --git a/src/backend/langflow/services/socket/utils.py b/src/backend/langflow/services/socket/utils.py new file mode 100644 index 000000000..6db623392 --- /dev/null +++ b/src/backend/langflow/services/socket/utils.py @@ -0,0 +1,97 @@ +import time +from typing import Callable + +import socketio +from langflow.api.utils import format_elapsed_time +from langflow.api.v1.schemas import ResultDict, VertexBuildResponse +from langflow.graph.graph.base import Graph +from langflow.graph.vertex.base import StatelessVertex +from langflow.services.database.models.flow.model import Flow +from langflow.services.deps import get_session +from langflow.services.monitor.utils import log_vertex_build +from sqlmodel import select + + +def set_socketio_server(socketio_server): + from langflow.services.deps import get_socket_service + + socket_service = get_socket_service() + socket_service.init(socketio_server) + + +async def get_vertices(sio, sid, flow_id, chat_service): + try: + session = get_session() + flow: Flow = session.exec(select(Flow).where(Flow.id == flow_id)).first() + if not flow or not flow.data: + await sio.emit("error", data="Invalid flow ID", to=sid) + return + + graph = Graph.from_payload(flow.data) + chat_service.set_cache(flow_id, graph) + vertices = graph.layered_topological_sort() + + # Emit the vertices to the client + await sio.emit("vertices_order", data=vertices, to=sid) + + except Exception as exc: + await sio.emit("error", data=str(exc), to=sid) + + +async def build_vertex( + sio: socketio.AsyncServer, + sid: str, + flow_id: str, + vertex_id: str, + get_cache: Callable, + set_cache: Callable, + tweaks=None, + inputs=None, +): + try: + cache = get_cache(flow_id) + graph = cache.get("result") + + if not isinstance(graph, Graph): + await sio.emit("error", data="Invalid graph", to=sid) + return + + vertex = graph.get_vertex(vertex_id) + if not vertex: + await sio.emit("error", data="Invalid vertex", to=sid) + return + start_time = time.perf_counter() + try: + if isinstance(vertex, StatelessVertex) or not vertex._built: + await vertex.build(user_id=None) + params = vertex._built_object_repr() + valid = True + result_dict = vertex.get_built_result() + # We need to set the artifacts to pass information + # to the frontend + vertex.set_artifacts() + artifacts = vertex.artifacts + timedelta = time.perf_counter() - start_time + duration = format_elapsed_time(timedelta) + result_dict = ResultDict(results=result_dict, artifacts=artifacts, duration=duration, timedelta=timedelta) + except Exception as exc: + params = str(exc) + valid = False + result_dict = ResultDict(results={}) + artifacts = {} + set_cache(flow_id, graph) + await log_vertex_build( + flow_id=flow_id, + vertex_id=vertex_id, + valid=valid, + params=params, + data=result_dict, + artifacts=artifacts, + ) + + # Emit the vertex build response + response = VertexBuildResponse(valid=valid, params=params, id=vertex.id, data=result_dict) + await sio.emit("vertex_build", data=response.model_dump(), to=sid) + + except Exception as exc: + await sio.emit("error", data=str(exc), to=sid)