-
diff --git a/src/frontend/src/components/chatComponent/index.tsx b/src/frontend/src/components/chatComponent/index.tsx
index 61d0f35a1..610ea5c4c 100644
--- a/src/frontend/src/components/chatComponent/index.tsx
+++ b/src/frontend/src/components/chatComponent/index.tsx
@@ -1,13 +1,18 @@
-import { useEffect, useRef, useState } from "react";
-
+import { Context, useEffect, useRef, useState, useContext } from "react";
+import ReactFlow, { useNodes } from "reactflow";
import { ChatMessageType, ChatType } from "../../types/chat";
import ChatTrigger from "./chatTrigger";
+import BuildTrigger from "./buildTrigger";
import ChatModal from "../../modals/chatModal";
-import _ from "lodash";
+import _, { set } from "lodash";
+import { getBuildStatus } from "../../controllers/API";
+import { NodeType } from "../../types/flow";
export default function Chat({ flow }: ChatType) {
const [open, setOpen] = useState(false);
+ const [isBuilt, setIsBuilt] = useState(false);
+
useEffect(() => {
const handleKeyDown = (event: KeyboardEvent) => {
if (
@@ -23,10 +28,58 @@ export default function Chat({ flow }: ChatType) {
document.removeEventListener("keydown", handleKeyDown);
};
}, []);
+
+ useEffect(() => {
+ // Define an async function within the useEffect hook
+ const fetchBuildStatus = async () => {
+ const response = await getBuildStatus(flow.id);
+ setIsBuilt(response.built);
+ };
+
+ // Call the async function
+ fetchBuildStatus();
+ }, [flow]);
+
+ const prevNodesRef = useRef
();
+ const nodes = useNodes();
+ useEffect(() => {
+ const prevNodes = prevNodesRef.current;
+ const currentNodes = nodes.map(
+ (node: NodeType) => node.data.node.template.value
+ );
+
+ if (
+ prevNodes &&
+ JSON.stringify(prevNodes) !== JSON.stringify(currentNodes)
+ ) {
+ setIsBuilt(false);
+ console.log("Nodes changed");
+ }
+
+ prevNodesRef.current = currentNodes;
+ }, [nodes]);
+
return (
<>
-
-
+ {isBuilt ? (
+
+
+
+
+
+ ) : (
+
+ )}
>
);
}
diff --git a/src/frontend/src/components/ui/loading.tsx b/src/frontend/src/components/ui/loading.tsx
new file mode 100644
index 000000000..7de6c8a6a
--- /dev/null
+++ b/src/frontend/src/components/ui/loading.tsx
@@ -0,0 +1,39 @@
+import { SVGProps } from "react";
+
+// https://github.com/feathericons/feather/issues/695#issuecomment-1503699643
+export const Loading = (props: SVGProps) => (
+
+);
+export default Loading;
diff --git a/src/frontend/src/contexts/SSEContext.tsx b/src/frontend/src/contexts/SSEContext.tsx
new file mode 100644
index 000000000..8d90b26e2
--- /dev/null
+++ b/src/frontend/src/contexts/SSEContext.tsx
@@ -0,0 +1,35 @@
+import {
+ createContext,
+ useContext,
+ useState,
+ useEffect,
+ useCallback,
+} from "react";
+
+const initialValue = {
+ updateSSEData: ({}) => {},
+ sseData: {},
+};
+
+const SSEContext = createContext(initialValue);
+
+export function useSSE() {
+ return useContext(SSEContext);
+}
+
+export function SSEProvider({ children }) {
+ const [sseData, setSSEData] = useState({});
+
+ const updateSSEData = useCallback((newData: any) => {
+ setSSEData((prevData) => ({
+ ...prevData,
+ ...newData,
+ }));
+ }, []);
+
+ return (
+
+ {children}
+
+ );
+}
diff --git a/src/frontend/src/controllers/API/index.ts b/src/frontend/src/controllers/API/index.ts
index b8b89784b..169f63d31 100644
--- a/src/frontend/src/controllers/API/index.ts
+++ b/src/frontend/src/controllers/API/index.ts
@@ -1,4 +1,9 @@
-import { PromptTypeAPI, errorsTypeAPI } from "./../../types/api/index";
+import {
+ BuildStatusTypeAPI,
+ PromptTypeAPI,
+ errorsTypeAPI,
+ InitTypeAPI,
+} from "./../../types/api/index";
import { APIObjectType, sendAllProps } from "../../types/api/index";
import axios, { AxiosResponse } from "axios";
import { FlowStyleType, FlowType } from "../../types/flow";
@@ -272,4 +277,14 @@ export async function getVersion() {
*/
export async function getHealth() {
return await axios.get("/health"); // Health is the only endpoint that doesn't require /api/v1
+export async function getBuildStatus(
+ flowId: string
+): Promise {
+ return await axios.get(`/api/v1/build/${flowId}/status`);
+}
+
+export async function postBuildInit(
+ flow: FlowType
+): Promise> {
+ return await axios.post(`/api/v1/build/init`, flow);
}
diff --git a/src/frontend/src/modals/EditNodeModal/index.tsx b/src/frontend/src/modals/EditNodeModal/index.tsx
index 316303c92..1a270f46e 100644
--- a/src/frontend/src/modals/EditNodeModal/index.tsx
+++ b/src/frontend/src/modals/EditNodeModal/index.tsx
@@ -22,7 +22,6 @@ import IntComponent from "../../components/intComponent";
import InputFileComponent from "../../components/inputFileComponent";
import PromptAreaComponent from "../../components/promptComponent";
import CodeAreaComponent from "../../components/codeAreaComponent";
-import { TabsContext } from "../../contexts/tabsContext";
import {
Dialog,
DialogContent,
@@ -33,7 +32,6 @@ import {
DialogTrigger,
} from "../../components/ui/dialog";
import { Button } from "../../components/ui/button";
-import { Edit } from "lucide-react";
import { Badge } from "../../components/ui/badge";
export default function EditNodeModal({ data }: { data: NodeDataType }) {
diff --git a/src/frontend/src/modals/chatModal/index.tsx b/src/frontend/src/modals/chatModal/index.tsx
index 8cdff93b8..ed6629cff 100644
--- a/src/frontend/src/modals/chatModal/index.tsx
+++ b/src/frontend/src/modals/chatModal/index.tsx
@@ -3,7 +3,7 @@ import { ChatBubbleOvalLeftEllipsisIcon } from "@heroicons/react/24/outline";
import { Fragment, useContext, useEffect, useRef, useState } from "react";
import { FlowType, NodeType } from "../../types/flow";
import { alertContext } from "../../contexts/alertContext";
-import { toNormalCase } from "../../utils";
+import { toNormalCase, validateNodes } from "../../utils";
import { typesContext } from "../../contexts/typesContext";
import ChatMessage from "./chatMessage";
import { FaEraser } from "react-icons/fa";
@@ -185,6 +185,17 @@ export default function ChatModal({
}://${host}${chatEndpoint}`;
}
+ function getWebSocketUrl(chatId, isDevelopment = false) {
+ const isSecureProtocol = window.location.protocol === "https:";
+ const webSocketProtocol = isSecureProtocol ? "wss" : "ws";
+ const host = isDevelopment ? "localhost:7860" : window.location.host;
+ const chatEndpoint = `/api/v1/chat/${chatId}`;
+
+ return `${
+ isDevelopment ? "ws" : webSocketProtocol
+ }://${host}${chatEndpoint}`;
+ }
+
function connectWS() {
try {
const urlWs = getWebSocketUrl(
@@ -269,53 +280,6 @@ export default function ChatModal({
if (ref.current) ref.current.scrollIntoView({ behavior: "smooth" });
}, [chatHistory]);
- function validateNode(n: NodeType): Array {
- if (!n.data?.node?.template || !Object.keys(n.data.node.template)) {
- setNoticeData({
- title:
- "We've noticed a potential issue with a node in the flow. Please review it and, if necessary, submit a bug report with your exported flow file. Thank you for your help!",
- });
- return [];
- }
-
- const {
- type,
- node: { template },
- } = n.data;
-
- return Object.keys(template).reduce(
- (errors: Array, t) =>
- errors.concat(
- template[t].required &&
- template[t].show &&
- (template[t].value === undefined ||
- template[t].value === null ||
- template[t].value === "") &&
- !reactFlowInstance
- .getEdges()
- .some(
- (e) =>
- e.targetHandle.split("|")[1] === t &&
- e.targetHandle.split("|")[2] === n.id
- )
- ? [
- `${type} is missing ${
- template.display_name
- ? template.display_name
- : toNormalCase(template[t].name)
- }.`,
- ]
- : []
- ),
- [] as string[]
- );
- }
-
- function validateNodes() {
- return reactFlowInstance
- .getNodes()
- .flatMap((n: NodeType) => validateNode(n));
- }
const ref = useRef(null);
@@ -327,7 +291,7 @@ export default function ChatModal({
function sendMessage() {
if (chatValue !== "") {
- let nodeValidationErrors = validateNodes();
+ let nodeValidationErrors = validateNodes(reactFlowInstance);
if (nodeValidationErrors.length === 0) {
setLockChat(true);
let message = chatValue;
diff --git a/src/frontend/src/types/api/index.ts b/src/frontend/src/types/api/index.ts
index 8d26c4e15..534657644 100644
--- a/src/frontend/src/types/api/index.ts
+++ b/src/frontend/src/types/api/index.ts
@@ -38,3 +38,11 @@ export type errorsTypeAPI = {
imports: { errors: Array };
};
export type PromptTypeAPI = { input_variables: Array };
+
+export type BuildStatusTypeAPI = {
+ built: boolean;
+};
+
+export type InitTypeAPI = {
+ flowId: string;
+};
diff --git a/src/frontend/src/utils.ts b/src/frontend/src/utils.ts
index 60bfdf3dd..714eb795b 100644
--- a/src/frontend/src/utils.ts
+++ b/src/frontend/src/utils.ts
@@ -17,7 +17,7 @@ import {
Bars3CenterLeftIcon,
} from "@heroicons/react/24/outline";
import { Connection, Edge, Node, ReactFlowInstance } from "reactflow";
-import { FlowType, NodeDataType, NodeType } from "./types/flow";
+import { FlowType, NodeType } from "./types/flow";
import { APITemplateType } from "./types/api";
import _ from "lodash";
import { ChromaIcon } from "./icons/ChromaIcon";
@@ -737,3 +737,56 @@ export function buildTweaks(flow) {
return acc;
}, {});
}
+export function validateNode(
+ n: NodeType,
+ reactFlowInstance: ReactFlowInstance
+): Array {
+ if (!n.data?.node?.template || !Object.keys(n.data.node.template)) {
+ return [
+ "We've noticed a potential issue with a node in the flow. Please review it and, if necessary, submit a bug report with your exported flow file. Thank you for your help!",
+ ];
+ }
+
+ const {
+ type,
+ node: { template },
+ } = n.data;
+
+ return Object.keys(template).reduce(
+ (errors: Array, t) =>
+ errors.concat(
+ template[t].required &&
+ template[t].show &&
+ (template[t].value === undefined ||
+ template[t].value === null ||
+ template[t].value === "") &&
+ !reactFlowInstance
+ .getEdges()
+ .some(
+ (e) =>
+ e.targetHandle.split("|")[1] === t &&
+ e.targetHandle.split("|")[2] === n.id
+ )
+ ? [
+ `${type} is missing ${
+ template.display_name
+ ? template.display_name
+ : toNormalCase(template[t].name)
+ }.`,
+ ]
+ : []
+ ),
+ [] as string[]
+ );
+}
+
+export function validateNodes(reactFlowInstance: ReactFlowInstance) {
+ if (reactFlowInstance.getNodes().length === 0) {
+ return [
+ "No nodes found in the flow. Please add at least one node to the flow.",
+ ];
+ }
+ return reactFlowInstance
+ .getNodes()
+ .flatMap((n: NodeType) => validateNode(n, reactFlowInstance));
+}
diff --git a/src/frontend/vite.config.ts b/src/frontend/vite.config.ts
index 40cd6af0f..860c690ea 100644
--- a/src/frontend/vite.config.ts
+++ b/src/frontend/vite.config.ts
@@ -12,11 +12,9 @@ const proxyTargets = apiRoutes.reduce((proxyObj, route) => {
changeOrigin: true,
secure: false,
ws: true,
- // rewrite: (path) => `/api/v1${path}`,
};
return proxyObj;
}, {});
-
export default defineConfig(() => {
return {
build: {
diff --git a/tests/conftest.py b/tests/conftest.py
index 35ec5eac6..f893533ac 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -66,6 +66,12 @@ def get_graph(_type="basic"):
return Graph(nodes, edges)
+@pytest.fixture
+def basic_graph_data():
+ with open(pytest.BASIC_EXAMPLE_PATH, "r") as f:
+ return json.load(f)
+
+
@pytest.fixture
def basic_graph():
return get_graph()
diff --git a/tests/test_websocket.py b/tests/test_websocket.py
index 611faff79..a628e7928 100644
--- a/tests/test_websocket.py
+++ b/tests/test_websocket.py
@@ -1,47 +1,47 @@
-import json
-from unittest.mock import patch
+from fastapi import WebSocketDisconnect
-from fastapi.testclient import TestClient
+# from langflow.chat.manager import ChatManager
+
+import pytest
-def test_websocket_connection(client: TestClient):
- with client.websocket_connect("api/v1/chat/test_client") as websocket:
- assert websocket.scope["client"] == ["testclient", 50000]
- assert websocket.scope["path"] == "/api/v1/chat/test_client"
+def test_init_build(client):
+ response = client.post(
+ "api/v1/build/init", json={"id": "test", "data": {"key": "value"}}
+ )
+ assert response.status_code == 200
+ assert response.json() == {"flowId": "test"}
-def test_chat_history(client: TestClient):
- # Mock the process_graph function to return a specific value
- with patch("langflow.chat.manager.process_graph") as mock_process_graph:
- mock_process_graph.return_value = ("Hello, I'm a mock response!", "")
+def test_stream_build(client):
+ client.post("/build/init", json={"id": "stream_test", "data": {"key": "value"}})
- with client.websocket_connect("api/v1/chat/test_client") as websocket:
- # First message should be the history
- history = websocket.receive_json()
- assert history == [] # Empty history
- # Send a message
- payload = {"message": "Hello"}
- websocket.send_json(json.dumps(payload))
+ # Test the stream
+ response = client.get("api/v1/build/stream/stream_test")
+ assert response.status_code == 200
+ assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
- # Receive the response from the server
- response = websocket.receive_json()
- assert response == {
- "is_bot": True,
- "message": None,
- "type": "start",
- "intermediate_steps": "",
- "files": [],
- }
- # Send another message
- payload = {"message": "How are you?"}
- websocket.send_json(json.dumps(payload))
- # Receive the response from the server
- response = websocket.receive_json()
- assert response == {
- "is_bot": True,
- "message": "Hello, I'm a mock response!",
- "type": "end",
- "intermediate_steps": "",
- "files": [],
- }
+def test_websocket_endpoint(client):
+ with pytest.raises(WebSocketDisconnect):
+ with client.websocket_connect(
+ "api/v1/chat/non_existing_client_id"
+ ) as websocket:
+ websocket.send_json({"type": "test"})
+ data = websocket.receive_json()
+ assert "Please, build the flow before sending messages" in data["message"]
+
+
+def test_websocket_endpoint_after_build(client, basic_graph_data):
+ # Assuming your websocket_endpoint uses chat_manager which caches data from stream_build
+ client.post("/build/init", json=basic_graph_data)
+ client.get("/build/stream/websocket_test")
+
+ # There should be more to test here, but it depends on the inner workings of your websocket handler
+ # and how your chat_manager and other classes behave. The following is just an example structure.
+ with pytest.raises(WebSocketDisconnect):
+ with client.websocket_connect("api/v1/chat/websocket_test") as websocket:
+ websocket.send_json({"type": "test"})
+ # Perform assertions here, based on what you expect the websocket to return
+ # data = websocket.receive_json()
+ # assert ...