diff --git a/src/backend/langflow/api/v1/chat.py b/src/backend/langflow/api/v1/chat.py index 80eafb1b9..1cf14e85c 100644 --- a/src/backend/langflow/api/v1/chat.py +++ b/src/backend/langflow/api/v1/chat.py @@ -1,6 +1,15 @@ import time +from typing import Optional -from fastapi import APIRouter, Body, Depends, HTTPException, WebSocket, WebSocketException, status +from fastapi import ( + APIRouter, + Body, + Depends, + HTTPException, + WebSocket, + WebSocketException, + status, +) from fastapi.responses import StreamingResponse from langflow.api.utils import build_input_keys_response, format_elapsed_time from langflow.api.v1.schemas import ( @@ -15,7 +24,10 @@ from langflow.api.v1.schemas import ( from langflow.graph.graph.base import Graph from langflow.graph.vertex.base import StatelessVertex from langflow.processing.process import process_tweaks_on_graph -from langflow.services.auth.utils import get_current_active_user, get_current_user_for_websocket +from langflow.services.auth.utils import ( + get_current_active_user, + get_current_user_for_websocket, +) from langflow.services.cache.service import BaseCacheService from langflow.services.cache.utils import update_build_status from langflow.services.chat.service import ChatService @@ -40,9 +52,13 @@ async def chat( user = await get_current_user_for_websocket(websocket, db) await websocket.accept() if not user: - await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized") + await websocket.close( + code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized" + ) elif not user.is_active: - await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized") + await websocket.close( + code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized" + ) if client_id in chat_service.cache_service: await chat_service.handle_websocket(client_id, websocket) @@ -58,7 +74,9 @@ async def chat( logger.error(f"Error in chat websocket: {exc}") messsage = exc.detail if isinstance(exc, HTTPException) else str(exc) if "Could not validate credentials" in str(exc): - await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized") + await websocket.close( + code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized" + ) else: await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=messsage) @@ -100,10 +118,15 @@ async def init_build( @router.get("/build/{flow_id}/status", response_model=BuiltResponse) -async def build_status(flow_id: str, cache_service: "BaseCacheService" = Depends(get_cache_service)): +async def build_status( + flow_id: str, cache_service: "BaseCacheService" = Depends(get_cache_service) +): """Check the flow_id is in the cache_service.""" try: - built = flow_id in cache_service and cache_service[flow_id]["status"] == BuildStatus.SUCCESS + built = ( + flow_id in cache_service + and cache_service[flow_id]["status"] == BuildStatus.SUCCESS + ) return BuiltResponse( built=built, @@ -174,7 +197,9 @@ async def stream_build( valid = True logger.debug(f"Building node {str(vertex.vertex_type)}") - logger.debug(f"Output: {params[:100]}{'...' if len(params) > 100 else ''}") + logger.debug( + f"Output: {params[:100]}{'...' if len(params) > 100 else ''}" + ) if vertex.artifacts: # The artifacts will be prompt variables # passed to build_input_keys_response @@ -187,7 +212,9 @@ async def stream_build( time_elapsed = format_elapsed_time(time.perf_counter() - start_time) update_build_status(cache_service, flow_id, BuildStatus.FAILURE) - vertex_id = vertex.parent_node_id if vertex.parent_is_top_level else vertex.id + vertex_id = ( + vertex.parent_node_id if vertex.parent_is_top_level else vertex.id + ) if vertex_id in graph.top_level_vertices: response = { "valid": valid, @@ -202,7 +229,9 @@ async def stream_build( langchain_object = await graph.build() # Now we need to check the input_keys to send them to the client if hasattr(langchain_object, "input_keys"): - input_keys_response = build_input_keys_response(langchain_object, artifacts) + input_keys_response = build_input_keys_response( + langchain_object, artifacts + ) else: input_keys_response = { "input_keys": None, @@ -249,6 +278,7 @@ async def try_running_celery_task(vertex, user_id): @router.get("/build/{flow_id}/vertices", response_model=VerticesOrderResponse) async def get_vertices( flow_id: str, + component_id: Optional[str] = None, chat_service: "ChatService" = Depends(get_chat_service), session=Depends(get_session), ): @@ -259,7 +289,11 @@ async def get_vertices( raise ValueError("Invalid flow ID") graph = Graph.from_payload(flow.data) chat_service.set_cache(flow_id, graph) - vertices = graph.layered_topological_sort() + + if component_id: + vertices = graph.sort_up_to_vertex(component_id) + else: + vertices = graph.layered_topological_sort() # Now vertices is a list of lists # We need to get the id of each vertex # and return the same structure but only with the ids @@ -276,7 +310,7 @@ async def build_vertex( flow_id: str, vertex_id: str, chat_service: "ChatService" = Depends(get_chat_service), - # current_user=Depends(get_current_active_user), + current_user=Depends(get_current_active_user), tweaks: dict = Body(None), inputs: dict = Body(None), ): @@ -295,7 +329,7 @@ async def build_vertex( raise ValueError("Invalid vertex") try: if isinstance(vertex, StatelessVertex) or not vertex._built: - await vertex.build(user_id=None) + await vertex.build(user_id=current_user.id) params = vertex._built_object_repr() valid = True result_dict = vertex.get_built_result() @@ -305,7 +339,12 @@ async def build_vertex( artifacts = vertex.artifacts timedelta = time.perf_counter() - start_time duration = format_elapsed_time(timedelta) - result_dict = ResultDict(results=result_dict, artifacts=artifacts, duration=duration, timedelta=timedelta) + result_dict = ResultDict( + results=result_dict, + artifacts=artifacts, + duration=duration, + timedelta=timedelta, + ) except Exception as exc: params = str(exc) valid = False