Add component_id to sort endpoint
This commit is contained in:
parent
b86ca31cd4
commit
3d42e1b902
1 changed files with 53 additions and 14 deletions
|
|
@ -1,6 +1,15 @@
|
|||
import time
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, WebSocket, WebSocketException, status
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Body,
|
||||
Depends,
|
||||
HTTPException,
|
||||
WebSocket,
|
||||
WebSocketException,
|
||||
status,
|
||||
)
|
||||
from fastapi.responses import StreamingResponse
|
||||
from langflow.api.utils import build_input_keys_response, format_elapsed_time
|
||||
from langflow.api.v1.schemas import (
|
||||
|
|
@ -15,7 +24,10 @@ from langflow.api.v1.schemas import (
|
|||
from langflow.graph.graph.base import Graph
|
||||
from langflow.graph.vertex.base import StatelessVertex
|
||||
from langflow.processing.process import process_tweaks_on_graph
|
||||
from langflow.services.auth.utils import get_current_active_user, get_current_user_for_websocket
|
||||
from langflow.services.auth.utils import (
|
||||
get_current_active_user,
|
||||
get_current_user_for_websocket,
|
||||
)
|
||||
from langflow.services.cache.service import BaseCacheService
|
||||
from langflow.services.cache.utils import update_build_status
|
||||
from langflow.services.chat.service import ChatService
|
||||
|
|
@ -40,9 +52,13 @@ async def chat(
|
|||
user = await get_current_user_for_websocket(websocket, db)
|
||||
await websocket.accept()
|
||||
if not user:
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized")
|
||||
await websocket.close(
|
||||
code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized"
|
||||
)
|
||||
elif not user.is_active:
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized")
|
||||
await websocket.close(
|
||||
code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized"
|
||||
)
|
||||
|
||||
if client_id in chat_service.cache_service:
|
||||
await chat_service.handle_websocket(client_id, websocket)
|
||||
|
|
@ -58,7 +74,9 @@ async def chat(
|
|||
logger.error(f"Error in chat websocket: {exc}")
|
||||
messsage = exc.detail if isinstance(exc, HTTPException) else str(exc)
|
||||
if "Could not validate credentials" in str(exc):
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized")
|
||||
await websocket.close(
|
||||
code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized"
|
||||
)
|
||||
else:
|
||||
await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=messsage)
|
||||
|
||||
|
|
@ -100,10 +118,15 @@ async def init_build(
|
|||
|
||||
|
||||
@router.get("/build/{flow_id}/status", response_model=BuiltResponse)
|
||||
async def build_status(flow_id: str, cache_service: "BaseCacheService" = Depends(get_cache_service)):
|
||||
async def build_status(
|
||||
flow_id: str, cache_service: "BaseCacheService" = Depends(get_cache_service)
|
||||
):
|
||||
"""Check the flow_id is in the cache_service."""
|
||||
try:
|
||||
built = flow_id in cache_service and cache_service[flow_id]["status"] == BuildStatus.SUCCESS
|
||||
built = (
|
||||
flow_id in cache_service
|
||||
and cache_service[flow_id]["status"] == BuildStatus.SUCCESS
|
||||
)
|
||||
|
||||
return BuiltResponse(
|
||||
built=built,
|
||||
|
|
@ -174,7 +197,9 @@ async def stream_build(
|
|||
valid = True
|
||||
|
||||
logger.debug(f"Building node {str(vertex.vertex_type)}")
|
||||
logger.debug(f"Output: {params[:100]}{'...' if len(params) > 100 else ''}")
|
||||
logger.debug(
|
||||
f"Output: {params[:100]}{'...' if len(params) > 100 else ''}"
|
||||
)
|
||||
if vertex.artifacts:
|
||||
# The artifacts will be prompt variables
|
||||
# passed to build_input_keys_response
|
||||
|
|
@ -187,7 +212,9 @@ async def stream_build(
|
|||
time_elapsed = format_elapsed_time(time.perf_counter() - start_time)
|
||||
update_build_status(cache_service, flow_id, BuildStatus.FAILURE)
|
||||
|
||||
vertex_id = vertex.parent_node_id if vertex.parent_is_top_level else vertex.id
|
||||
vertex_id = (
|
||||
vertex.parent_node_id if vertex.parent_is_top_level else vertex.id
|
||||
)
|
||||
if vertex_id in graph.top_level_vertices:
|
||||
response = {
|
||||
"valid": valid,
|
||||
|
|
@ -202,7 +229,9 @@ async def stream_build(
|
|||
langchain_object = await graph.build()
|
||||
# Now we need to check the input_keys to send them to the client
|
||||
if hasattr(langchain_object, "input_keys"):
|
||||
input_keys_response = build_input_keys_response(langchain_object, artifacts)
|
||||
input_keys_response = build_input_keys_response(
|
||||
langchain_object, artifacts
|
||||
)
|
||||
else:
|
||||
input_keys_response = {
|
||||
"input_keys": None,
|
||||
|
|
@ -249,6 +278,7 @@ async def try_running_celery_task(vertex, user_id):
|
|||
@router.get("/build/{flow_id}/vertices", response_model=VerticesOrderResponse)
|
||||
async def get_vertices(
|
||||
flow_id: str,
|
||||
component_id: Optional[str] = None,
|
||||
chat_service: "ChatService" = Depends(get_chat_service),
|
||||
session=Depends(get_session),
|
||||
):
|
||||
|
|
@ -259,7 +289,11 @@ async def get_vertices(
|
|||
raise ValueError("Invalid flow ID")
|
||||
graph = Graph.from_payload(flow.data)
|
||||
chat_service.set_cache(flow_id, graph)
|
||||
vertices = graph.layered_topological_sort()
|
||||
|
||||
if component_id:
|
||||
vertices = graph.sort_up_to_vertex(component_id)
|
||||
else:
|
||||
vertices = graph.layered_topological_sort()
|
||||
# Now vertices is a list of lists
|
||||
# We need to get the id of each vertex
|
||||
# and return the same structure but only with the ids
|
||||
|
|
@ -276,7 +310,7 @@ async def build_vertex(
|
|||
flow_id: str,
|
||||
vertex_id: str,
|
||||
chat_service: "ChatService" = Depends(get_chat_service),
|
||||
# current_user=Depends(get_current_active_user),
|
||||
current_user=Depends(get_current_active_user),
|
||||
tweaks: dict = Body(None),
|
||||
inputs: dict = Body(None),
|
||||
):
|
||||
|
|
@ -295,7 +329,7 @@ async def build_vertex(
|
|||
raise ValueError("Invalid vertex")
|
||||
try:
|
||||
if isinstance(vertex, StatelessVertex) or not vertex._built:
|
||||
await vertex.build(user_id=None)
|
||||
await vertex.build(user_id=current_user.id)
|
||||
params = vertex._built_object_repr()
|
||||
valid = True
|
||||
result_dict = vertex.get_built_result()
|
||||
|
|
@ -305,7 +339,12 @@ async def build_vertex(
|
|||
artifacts = vertex.artifacts
|
||||
timedelta = time.perf_counter() - start_time
|
||||
duration = format_elapsed_time(timedelta)
|
||||
result_dict = ResultDict(results=result_dict, artifacts=artifacts, duration=duration, timedelta=timedelta)
|
||||
result_dict = ResultDict(
|
||||
results=result_dict,
|
||||
artifacts=artifacts,
|
||||
duration=duration,
|
||||
timedelta=timedelta,
|
||||
)
|
||||
except Exception as exc:
|
||||
params = str(exc)
|
||||
valid = False
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue