Flow can now be built from a starting node or from an end node

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-03 19:42:07 -03:00
commit fe391ffa80
8 changed files with 59 additions and 23 deletions

View file

@ -50,6 +50,7 @@ async def try_running_celery_task(vertex, user_id):
async def get_vertices(
flow_id: str,
stop_component_id: Optional[str] = None,
start_component_id: Optional[str] = None,
chat_service: "ChatService" = Depends(get_chat_service),
session=Depends(get_session),
):
@ -60,9 +61,9 @@ async def get_vertices(
if cache := chat_service.get_cache(flow_id):
graph = cache.get("result")
graph = build_and_cache_graph(flow_id, session, chat_service, graph)
if stop_component_id:
if stop_component_id or start_component_id:
try:
vertices = graph.sort_vertices(stop_component_id)
vertices = graph.sort_vertices(stop_component_id, start_component_id)
except Exception as exc:
logger.error(exc)
vertices = graph.sort_vertices()

View file

@ -630,7 +630,7 @@ class Graph:
)
return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}"
def sort_up_to_vertex(self, vertex_id: str) -> List[Vertex]:
def sort_up_to_vertex(self, vertex_id: str, is_start: bool = False) -> List[Vertex]:
"""Cuts the graph up to a given vertex and sorts the resulting subgraph."""
# Initial setup
visited = set() # To keep track of visited vertices
@ -664,11 +664,19 @@ class Graph:
if current_id == vertex_id:
# We should add to visited all the vertices that are successors of the current vertex
# and their successors and so on
# if the vertex is a start, it means we are starting from the beginning
# and getting successors
for successor in current_vertex.successors:
excluded.add(successor.id)
if is_start:
stack.append(successor.id)
else:
excluded.add(successor.id)
all_successors = get_successors(successor)
for successor in all_successors:
excluded.add(successor.id)
if is_start:
stack.append(successor.id)
else:
excluded.add(successor.id)
# Filter the original graph's vertices and edges to keep only those in `visited`
vertices_to_keep = [self.get_vertex(vid) for vid in visited]
@ -764,12 +772,19 @@ class Graph:
return vertices_layers
def sort_vertices(self, stop_component_id: Optional[str] = None) -> List[List[str]]:
def sort_vertices(
self,
stop_component_id: Optional[str] = None,
start_component_id: Optional[str] = None,
) -> List[str]:
"""Sorts the vertices in the graph."""
self.mark_all_vertices("ACTIVE")
if stop_component_id:
self.stop_vertex = stop_component_id
vertices = self.sort_up_to_vertex(stop_component_id)
elif start_component_id:
vertices = self.sort_up_to_vertex(start_component_id)
else:
vertices = self.vertices
vertices_layers = self.layered_topological_sort(vertices)

View file

@ -509,7 +509,7 @@ export default function GenericNode({
if (buildStatus === BuildStatus.BUILDING || isBuilding)
return;
setValidationStatus(null);
buildFlow({ nodeId: data.id });
buildFlow({ stopNodeId: data.id });
}}
className="generic-node-status-position flex items-center justify-center"
>

View file

@ -70,7 +70,10 @@ export default function IOView({ children, open, setOpen }): JSX.Element {
setLockChat(true);
setChatValue("");
for (let i = 0; i < count; i++) {
await buildFlow({ input_value: chatValue }).catch((err) => {
await buildFlow({
input_value: chatValue,
startNodeId: chatInput?.id,
}).catch((err) => {
console.error(err);
setLockChat(false);
});
@ -273,7 +276,7 @@ export default function IOView({ children, open, setOpen }): JSX.Element {
)}
{haveChat ? (
<div className="flex h-full flex-grow min-w-96">
<div className="flex h-full min-w-96 flex-grow">
{selectedViewField && (
<div
className={cn(

View file

@ -856,13 +856,16 @@ export async function requestLogout() {
export async function getVerticesOrder(
flowId: string,
nodeId?: string | null
startNodeId?: string | null,
stopNodeId?: string | null
): Promise<AxiosResponse<VerticesOrderTypeAPI>> {
// nodeId is optional and is a query parameter
// if nodeId is not provided, the API will return all vertices
const config = {};
if (nodeId) {
config["params"] = { stop_component_id: nodeId };
if (stopNodeId) {
config["params"] = { stop_component_id: stopNodeId };
} else if (startNodeId) {
config["params"] = { start_component_id: startNodeId };
}
return await api.get(`${BASE_URL_API}build/${flowId}/vertices`, config);
}

View file

@ -416,10 +416,12 @@ const useFlowStore = create<FlowStoreType>((set, get) => ({
});
},
buildFlow: async ({
nodeId,
startNodeId,
stopNodeId,
input_value,
}: {
nodeId?: string;
startNodeId?: string;
stopNodeId?: string;
input_value?: string;
}) => {
get().setIsBuilding(true);
@ -480,11 +482,13 @@ const useFlowStore = create<FlowStoreType>((set, get) => ({
await buildVertices({
input_value,
flowId: currentFlow!.id,
nodeId,
startNodeId,
stopNodeId,
onGetOrderSuccess: () => {
setNoticeData({ title: "Running components" });
},
onBuildComplete: () => {
const nodeId = startNodeId || stopNodeId;
if (nodeId) {
setSuccessData({
title: `${

View file

@ -90,10 +90,13 @@ export type FlowStoreType = {
onConnect: (connection: Connection) => void;
unselectAll: () => void;
buildFlow: ({
nodeId,
startNodeId,
stopNodeId,
input_value,
}: {
nodeId?: string;
startNodeId?: string;
stopNodeId?: string;
input_value?: string;
}) => Promise<void>;
getFlow: () => { nodes: Node[]; edges: Edge[]; viewport: Viewport };

View file

@ -8,7 +8,8 @@ import { VertexBuildTypeAPI } from "../types/api";
type BuildVerticesParams = {
flowId: string; // Assuming FlowType is the type for your flow
input_value?: any; // Replace any with the actual type if it's not any
nodeId?: string | null; // Assuming nodeId is of type string, and it's optional
startNodeId?: string | null; // Assuming nodeId is of type string, and it's optional
stopNodeId?: string | null; // Assuming nodeId is of type string, and it's optional
onGetOrderSuccess?: () => void;
onBuildUpdate?: (
data: VertexBuildTypeAPI,
@ -43,7 +44,8 @@ function getInactiveVertexData(vertexId: string): VertexBuildTypeAPI {
export async function updateVerticesOrder(
flowId: string,
nodeId: string | null
startNodeId?: string | null,
stopNodeId?: string | null
): Promise<{
verticesLayers: string[][];
verticesIds: string[];
@ -53,7 +55,7 @@ export async function updateVerticesOrder(
const setErrorData = useAlertStore.getState().setErrorData;
let orderResponse;
try {
orderResponse = await getVerticesOrder(flowId, nodeId);
orderResponse = await getVerticesOrder(flowId, startNodeId, stopNodeId);
} catch (error: any) {
console.log(error);
setErrorData({
@ -95,7 +97,8 @@ export async function updateVerticesOrder(
export async function buildVertices({
flowId,
input_value,
nodeId = null,
startNodeId,
stopNodeId,
onGetOrderSuccess,
onBuildUpdate,
onBuildComplete,
@ -104,9 +107,13 @@ export async function buildVertices({
validateNodes,
}: BuildVerticesParams) {
let verticesBuild = useFlowStore.getState().verticesBuild;
if (!verticesBuild || nodeId) {
verticesBuild = await updateVerticesOrder(flowId, nodeId);
// if startNodeId and stopNodeId are provided
// something is wrong
if (startNodeId && stopNodeId) {
return;
}
if (!verticesBuild || startNodeId || stopNodeId) {
verticesBuild = await updateVerticesOrder(flowId, startNodeId, stopNodeId);
}
const verticesIds = verticesBuild?.verticesIds!;