From fe391ffa80890ee4bbb268b476f863ecea10e1a0 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Sun, 3 Mar 2024 19:42:07 -0300 Subject: [PATCH] Flow can now be built from a starting node or from an end node --- src/backend/langflow/api/v1/chat.py | 5 ++-- src/backend/langflow/graph/graph/base.py | 23 +++++++++++++++---- .../src/CustomNodes/GenericNode/index.tsx | 2 +- src/frontend/src/components/IOview/index.tsx | 7 ++++-- src/frontend/src/controllers/API/index.ts | 9 +++++--- src/frontend/src/stores/flowStore.ts | 10 +++++--- src/frontend/src/types/zustand/flow/index.ts | 5 +++- src/frontend/src/utils/buildUtils.ts | 21 +++++++++++------ 8 files changed, 59 insertions(+), 23 deletions(-) diff --git a/src/backend/langflow/api/v1/chat.py b/src/backend/langflow/api/v1/chat.py index fa39a4d4a..d71e4a324 100644 --- a/src/backend/langflow/api/v1/chat.py +++ b/src/backend/langflow/api/v1/chat.py @@ -50,6 +50,7 @@ async def try_running_celery_task(vertex, user_id): async def get_vertices( flow_id: str, stop_component_id: Optional[str] = None, + start_component_id: Optional[str] = None, chat_service: "ChatService" = Depends(get_chat_service), session=Depends(get_session), ): @@ -60,9 +61,9 @@ async def get_vertices( if cache := chat_service.get_cache(flow_id): graph = cache.get("result") graph = build_and_cache_graph(flow_id, session, chat_service, graph) - if stop_component_id: + if stop_component_id or start_component_id: try: - vertices = graph.sort_vertices(stop_component_id) + vertices = graph.sort_vertices(stop_component_id, start_component_id) except Exception as exc: logger.error(exc) vertices = graph.sort_vertices() diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index aeb03e093..45612175f 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -630,7 +630,7 @@ class Graph: ) return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}" - def sort_up_to_vertex(self, vertex_id: str) -> List[Vertex]: + def sort_up_to_vertex(self, vertex_id: str, is_start: bool = False) -> List[Vertex]: """Cuts the graph up to a given vertex and sorts the resulting subgraph.""" # Initial setup visited = set() # To keep track of visited vertices @@ -664,11 +664,19 @@ class Graph: if current_id == vertex_id: # We should add to visited all the vertices that are successors of the current vertex # and their successors and so on + # if the vertex is a start, it means we are starting from the beginning + # and getting successors for successor in current_vertex.successors: - excluded.add(successor.id) + if is_start: + stack.append(successor.id) + else: + excluded.add(successor.id) all_successors = get_successors(successor) for successor in all_successors: - excluded.add(successor.id) + if is_start: + stack.append(successor.id) + else: + excluded.add(successor.id) # Filter the original graph's vertices and edges to keep only those in `visited` vertices_to_keep = [self.get_vertex(vid) for vid in visited] @@ -764,12 +772,19 @@ class Graph: return vertices_layers - def sort_vertices(self, stop_component_id: Optional[str] = None) -> List[List[str]]: + def sort_vertices( + self, + stop_component_id: Optional[str] = None, + start_component_id: Optional[str] = None, + ) -> List[str]: """Sorts the vertices in the graph.""" self.mark_all_vertices("ACTIVE") if stop_component_id: self.stop_vertex = stop_component_id vertices = self.sort_up_to_vertex(stop_component_id) + elif start_component_id: + vertices = self.sort_up_to_vertex(start_component_id) + else: vertices = self.vertices vertices_layers = self.layered_topological_sort(vertices) diff --git a/src/frontend/src/CustomNodes/GenericNode/index.tsx b/src/frontend/src/CustomNodes/GenericNode/index.tsx index 376c6b76d..1c27ed847 100644 --- a/src/frontend/src/CustomNodes/GenericNode/index.tsx +++ b/src/frontend/src/CustomNodes/GenericNode/index.tsx @@ -509,7 +509,7 @@ export default function GenericNode({ if (buildStatus === BuildStatus.BUILDING || isBuilding) return; setValidationStatus(null); - buildFlow({ nodeId: data.id }); + buildFlow({ stopNodeId: data.id }); }} className="generic-node-status-position flex items-center justify-center" > diff --git a/src/frontend/src/components/IOview/index.tsx b/src/frontend/src/components/IOview/index.tsx index 56dee33d6..3500160a0 100644 --- a/src/frontend/src/components/IOview/index.tsx +++ b/src/frontend/src/components/IOview/index.tsx @@ -70,7 +70,10 @@ export default function IOView({ children, open, setOpen }): JSX.Element { setLockChat(true); setChatValue(""); for (let i = 0; i < count; i++) { - await buildFlow({ input_value: chatValue }).catch((err) => { + await buildFlow({ + input_value: chatValue, + startNodeId: chatInput?.id, + }).catch((err) => { console.error(err); setLockChat(false); }); @@ -273,7 +276,7 @@ export default function IOView({ children, open, setOpen }): JSX.Element { )} {haveChat ? ( -
+
{selectedViewField && (
> { // nodeId is optional and is a query parameter // if nodeId is not provided, the API will return all vertices const config = {}; - if (nodeId) { - config["params"] = { stop_component_id: nodeId }; + if (stopNodeId) { + config["params"] = { stop_component_id: stopNodeId }; + } else if (startNodeId) { + config["params"] = { start_component_id: startNodeId }; } return await api.get(`${BASE_URL_API}build/${flowId}/vertices`, config); } diff --git a/src/frontend/src/stores/flowStore.ts b/src/frontend/src/stores/flowStore.ts index 826e81590..f1cb1e8c4 100644 --- a/src/frontend/src/stores/flowStore.ts +++ b/src/frontend/src/stores/flowStore.ts @@ -416,10 +416,12 @@ const useFlowStore = create((set, get) => ({ }); }, buildFlow: async ({ - nodeId, + startNodeId, + stopNodeId, input_value, }: { - nodeId?: string; + startNodeId?: string; + stopNodeId?: string; input_value?: string; }) => { get().setIsBuilding(true); @@ -480,11 +482,13 @@ const useFlowStore = create((set, get) => ({ await buildVertices({ input_value, flowId: currentFlow!.id, - nodeId, + startNodeId, + stopNodeId, onGetOrderSuccess: () => { setNoticeData({ title: "Running components" }); }, onBuildComplete: () => { + const nodeId = startNodeId || stopNodeId; if (nodeId) { setSuccessData({ title: `${ diff --git a/src/frontend/src/types/zustand/flow/index.ts b/src/frontend/src/types/zustand/flow/index.ts index bd4595d32..d7469f7a7 100644 --- a/src/frontend/src/types/zustand/flow/index.ts +++ b/src/frontend/src/types/zustand/flow/index.ts @@ -90,10 +90,13 @@ export type FlowStoreType = { onConnect: (connection: Connection) => void; unselectAll: () => void; buildFlow: ({ - nodeId, + startNodeId, + stopNodeId, input_value, }: { nodeId?: string; + startNodeId?: string; + stopNodeId?: string; input_value?: string; }) => Promise; getFlow: () => { nodes: Node[]; edges: Edge[]; viewport: Viewport }; diff --git a/src/frontend/src/utils/buildUtils.ts b/src/frontend/src/utils/buildUtils.ts index dbe2a360d..e631c91de 100644 --- a/src/frontend/src/utils/buildUtils.ts +++ b/src/frontend/src/utils/buildUtils.ts @@ -8,7 +8,8 @@ import { VertexBuildTypeAPI } from "../types/api"; type BuildVerticesParams = { flowId: string; // Assuming FlowType is the type for your flow input_value?: any; // Replace any with the actual type if it's not any - nodeId?: string | null; // Assuming nodeId is of type string, and it's optional + startNodeId?: string | null; // Assuming nodeId is of type string, and it's optional + stopNodeId?: string | null; // Assuming nodeId is of type string, and it's optional onGetOrderSuccess?: () => void; onBuildUpdate?: ( data: VertexBuildTypeAPI, @@ -43,7 +44,8 @@ function getInactiveVertexData(vertexId: string): VertexBuildTypeAPI { export async function updateVerticesOrder( flowId: string, - nodeId: string | null + startNodeId?: string | null, + stopNodeId?: string | null ): Promise<{ verticesLayers: string[][]; verticesIds: string[]; @@ -53,7 +55,7 @@ export async function updateVerticesOrder( const setErrorData = useAlertStore.getState().setErrorData; let orderResponse; try { - orderResponse = await getVerticesOrder(flowId, nodeId); + orderResponse = await getVerticesOrder(flowId, startNodeId, stopNodeId); } catch (error: any) { console.log(error); setErrorData({ @@ -95,7 +97,8 @@ export async function updateVerticesOrder( export async function buildVertices({ flowId, input_value, - nodeId = null, + startNodeId, + stopNodeId, onGetOrderSuccess, onBuildUpdate, onBuildComplete, @@ -104,9 +107,13 @@ export async function buildVertices({ validateNodes, }: BuildVerticesParams) { let verticesBuild = useFlowStore.getState().verticesBuild; - - if (!verticesBuild || nodeId) { - verticesBuild = await updateVerticesOrder(flowId, nodeId); + // if startNodeId and stopNodeId are provided + // something is wrong + if (startNodeId && stopNodeId) { + return; + } + if (!verticesBuild || startNodeId || stopNodeId) { + verticesBuild = await updateVerticesOrder(flowId, startNodeId, stopNodeId); } const verticesIds = verticesBuild?.verticesIds!;