Improves secret key implementation (#846)

Secret key is now set once if not passed using the env variable and
saved into the CONFIG_DIR.
There are separate implementations depending on the OS
This commit is contained in:
anovazzi1 2023-08-30 19:08:19 -03:00 committed by GitHub
commit f30c818053
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 316 additions and 201 deletions

View file

@ -18,6 +18,10 @@ from langflow.services.utils import get_session
from langflow.utils.logger import logger
from cachetools import LRUCache
from sqlmodel import Session
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from langflow.services.chat.manager import ChatManager
router = APIRouter(tags=["Chat"])
@ -33,21 +37,38 @@ async def chat(
):
"""Websocket endpoint for chat."""
try:
await websocket.accept()
user = await get_current_user(token, db)
if not user:
await websocket.close(
code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized"
)
if not user.is_active:
raise HTTPException(status_code=401, detail="Invalid token")
chat_manager = service_manager.get(ServiceType.CHAT_MANAGER)
await websocket.close(
code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized"
)
chat_manager: "ChatManager" = service_manager.get(ServiceType.CHAT_MANAGER)
if client_id in chat_manager.in_memory_cache:
await chat_manager.handle_websocket(client_id, websocket)
else:
# We accept the connection but close it immediately
# if the flow is not built yet
await websocket.accept()
message = "Please, build the flow before sending messages"
await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=message)
except WebSocketException as exc:
logger.error(f"Websocket error: {exc}")
await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=str(exc))
except Exception as exc:
logger.error(f"Error in chat websocket: {exc}")
if isinstance(exc, HTTPException):
exc = exc.detail
if "Could not validate credentials" in str(exc):
await websocket.close(
code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized"
)
else:
await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=str(exc))
@router.post("/build/init/{flow_id}", response_model=InitResponse, status_code=201)

View file

@ -88,6 +88,11 @@ async def get_current_user(
)
user_id: UUID = payload.get("sub") # type: ignore
token_type: str = payload.get("type") # type: ignore
if expires := payload.get("exp", None):
expires_datetime = datetime.fromtimestamp(expires, timezone.utc)
# TypeError: can't compare offset-naive and offset-aware datetimes
if datetime.now(timezone.utc) > expires_datetime:
raise credentials_exception
if user_id is None or token_type:
raise credentials_exception

View file

@ -92,7 +92,6 @@ class ChatManager(Service):
)
async def connect(self, client_id: str, websocket: WebSocket):
await websocket.accept()
self.active_connections[client_id] = websocket
def disconnect(self, client_id: str):

View file

@ -1,13 +1,21 @@
from pathlib import Path
from typing import Optional
import secrets
from langflow.services.settings.utils import read_secret_from_file, write_secret_to_file
from pydantic import BaseSettings
from pydantic import BaseSettings, Field, validator
from passlib.context import CryptContext
from langflow.utils.logger import logger
class AuthSettings(BaseSettings):
# Login settings
SECRET_KEY: str = secrets.token_hex(32)
CONFIG_DIR: str
SECRET_KEY: Optional[str] = Field(
None,
description="Secret key for JWT. If not provided, a random one will be generated.",
env="LANGFLOW_SECRET_KEY",
)
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60
REFRESH_TOKEN_EXPIRE_MINUTES: int = 70
@ -31,3 +39,33 @@ class AuthSettings(BaseSettings):
validate_assignment = True
extra = "ignore"
env_prefix = "LANGFLOW_"
@validator("SECRET_KEY", pre=True)
def get_secret_key(cls, value, values):
config_dir = values.get("CONFIG_DIR")
if not config_dir:
logger.debug("No CONFIG_DIR provided, not saving secret key")
return value or secrets.token_urlsafe(32)
secret_key_path = Path(config_dir) / "secret_key"
if value:
logger.debug("Secret key provided")
write_secret_to_file(secret_key_path, value)
else:
logger.debug("No secret key provided, generating a random one")
if secret_key_path.exists():
value = read_secret_from_file(secret_key_path)
logger.debug("Loaded secret key")
if not value:
value = secrets.token_urlsafe(32)
write_secret_to_file(secret_key_path, value)
logger.debug("Saved secret key")
else:
value = secrets.token_urlsafe(32)
write_secret_to_file(secret_key_path, value)
logger.debug("Saved secret key")
return value

View file

@ -35,5 +35,5 @@ class SettingsManager(Service):
)
settings = Settings(**settings_dict)
auth_settings = AuthSettings()
auth_settings = AuthSettings(CONFIG_DIR=settings.CONFIG_DIR)
return cls(settings, auth_settings)

View file

@ -0,0 +1,47 @@
import os
from pathlib import Path
import platform
from langflow.utils.logger import logger
def set_secure_permissions(file_path):
if platform.system() in ["Linux", "Darwin"]: # Unix/Linux/Mac
os.chmod(file_path, 0o600)
elif platform.system() == "Windows":
import win32api
import win32con
import win32security
user, domain, _ = win32security.LookupAccountName("", win32api.GetUserName())
sd = win32security.GetFileSecurity(
file_path, win32security.DACL_SECURITY_INFORMATION
)
dacl = win32security.ACL()
# Set the new DACL for the file: read and write access for the owner, no access for everyone else
dacl.AddAccessAllowedAce(
win32security.ACL_REVISION,
win32con.GENERIC_READ | win32con.GENERIC_WRITE,
user,
)
sd.SetSecurityDescriptorDacl(1, dacl, 0)
win32security.SetFileSecurity(
file_path, win32security.DACL_SECURITY_INFORMATION, sd
)
else:
print("Unsupported OS")
def write_secret_to_file(path: Path, value: str) -> None:
with path.open("wb") as f:
f.write(value.encode("utf-8"))
try:
set_secure_permissions(path)
except Exception:
logger.error("Failed to set secure permissions on secret key")
def read_secret_from_file(path: Path) -> str:
with path.open("rb") as f:
return f.read()

View file

@ -312,7 +312,7 @@ export async function getHealth() {
*/
export async function getBuildStatus(
flowId: string
): Promise<BuildStatusTypeAPI> {
): Promise<AxiosResponse<BuildStatusTypeAPI>> {
return await api.get(`${BASE_URL_API}build/${flowId}/status`);
}

View file

@ -8,7 +8,7 @@ import { classNames } from "../../utils/utils";
import ChatInput from "./chatInput";
import ChatMessage from "./chatMessage";
import _ from "lodash";
import _, { set } from "lodash";
import AccordionComponent from "../../components/AccordionComponent";
import IconComponent from "../../components/genericIconComponent";
import ToggleShadComponent from "../../components/toggleShadComponent";
@ -27,6 +27,7 @@ import { AuthContext } from "../../contexts/authContext";
import { TabsContext } from "../../contexts/tabsContext";
import { TabsState } from "../../types/tabs";
import { validateNodes } from "../../utils/reactflowUtils";
import { getBuildStatus } from "../../controllers/API";
export default function FormModal({
flow,
@ -155,9 +156,21 @@ export default function FormModal({
function handleOnClose(event: CloseEvent): void {
if (isOpen.current) {
getBuildStatus(flow.id).then((response) => {
if (response.data.built) {
connectWS();
}
else {
setErrorData({
title: "Please build the flow again before using the chat."
})
}
}).catch((error) => {
setErrorData({title:error.data?.detail?error.data.detail:error.message})
});
setErrorData({ title: event.reason });
setTimeout(() => {
connectWS();
setLockChat(false);
}, 1000);
}
@ -173,9 +186,8 @@ export default function FormModal({
const host = isDevelopment ? "localhost:7860" : window.location.host;
const chatEndpoint = `/api/v1/chat/${chatId}`;
return `${
isDevelopment ? "ws" : webSocketProtocol
}://${host}${chatEndpoint}?token=${accessToken}`;
return `${isDevelopment ? "ws" : webSocketProtocol
}://${host}${chatEndpoint}?token=${encodeURIComponent(accessToken!)}`;
}
function handleWsMessage(data: any) {
@ -197,20 +209,20 @@ export default function FormModal({
newChatHistory.push(
chatItem.files
? {
isSend: !chatItem.is_bot,
message: chatItem.message,
template: chatItem.template,
thought: chatItem.intermediate_steps,
files: chatItem.files,
chatKey: chatItem.chatKey,
}
isSend: !chatItem.is_bot,
message: chatItem.message,
template: chatItem.template,
thought: chatItem.intermediate_steps,
files: chatItem.files,
chatKey: chatItem.chatKey,
}
: {
isSend: !chatItem.is_bot,
message: chatItem.message,
template: chatItem.template,
thought: chatItem.intermediate_steps,
chatKey: chatItem.chatKey,
}
isSend: !chatItem.is_bot,
message: chatItem.message,
template: chatItem.template,
thought: chatItem.intermediate_steps,
chatKey: chatItem.chatKey,
}
);
}
}
@ -260,7 +272,6 @@ export default function FormModal({
};
newWs.onmessage = (event) => {
const data = JSON.parse(event.data);
console.log("Received data:", data);
handleWsMessage(data);
//get chat history
};
@ -268,7 +279,6 @@ export default function FormModal({
handleOnClose(event);
};
newWs.onerror = (ev) => {
console.log(ev, "error");
if (flow.id === "") {
connectWS();
} else {
@ -294,7 +304,6 @@ export default function FormModal({
useEffect(() => {
connectWS();
return () => {
console.log("unmount");
console.log(ws);
if (ws.current) {
ws.current.close();
@ -433,73 +442,73 @@ export default function FormModal({
{tabsState[id.current]?.formKeysData?.input_keys
? Object.keys(
tabsState[id.current].formKeysData.input_keys!
).map((key, index) => (
<div className="file-component-accordion-div" key={index}>
<AccordionComponent
trigger={
<div className="file-component-badge-div">
<Badge variant="gray" size="md">
{key}
</Badge>
tabsState[id.current].formKeysData.input_keys!
).map((key, index) => (
<div className="file-component-accordion-div" key={index}>
<AccordionComponent
trigger={
<div className="file-component-badge-div">
<Badge variant="gray" size="md">
{key}
</Badge>
<div
className="-mb-1"
onClick={(event) => {
event.stopPropagation();
}}
>
<ToggleShadComponent
enabled={chatKey === key}
setEnabled={(value) =>
handleOnCheckedChange(value, key)
}
size="small"
disabled={tabsState[
id.current
].formKeysData.handle_keys!.some(
(t) => t === key
)}
/>
</div>
<div
className="-mb-1"
onClick={(event) => {
event.stopPropagation();
}}
>
<ToggleShadComponent
enabled={chatKey === key}
setEnabled={(value) =>
handleOnCheckedChange(value, key)
}
size="small"
disabled={tabsState[
id.current
].formKeysData.handle_keys!.some(
(t) => t === key
)}
/>
</div>
}
key={index}
keyValue={key}
>
<div className="file-component-tab-column">
{tabsState[id.current].formKeysData.handle_keys!.some(
(t) => t === key
) && (
</div>
}
key={index}
keyValue={key}
>
<div className="file-component-tab-column">
{tabsState[id.current].formKeysData.handle_keys!.some(
(t) => t === key
) && (
<div className="font-normal text-muted-foreground ">
Source: Component
</div>
)}
<Textarea
className="custom-scroll"
value={
tabsState[id.current].formKeysData.input_keys![
key
]
}
onChange={(e) => {
//@ts-ignore
setTabsState((old: TabsState) => {
let newTabsState = _.cloneDeep(old);
newTabsState[
id.current
].formKeysData.input_keys![key] =
e.target.value;
return newTabsState;
});
}}
disabled={chatKey === key}
placeholder="Enter text..."
></Textarea>
</div>
</AccordionComponent>
</div>
))
<Textarea
className="custom-scroll"
value={
tabsState[id.current].formKeysData.input_keys![
key
]
}
onChange={(e) => {
//@ts-ignore
setTabsState((old: TabsState) => {
let newTabsState = _.cloneDeep(old);
newTabsState[
id.current
].formKeysData.input_keys![key] =
e.target.value;
return newTabsState;
});
}}
disabled={chatKey === key}
placeholder="Enter text..."
></Textarea>
</div>
</AccordionComponent>
</div>
))
: null}
{tabsState[id.current].formKeysData.memory_keys!.map(
(key, index) => (
@ -513,7 +522,7 @@ export default function FormModal({
<div className="-mb-1">
<ToggleShadComponent
enabled={chatKey === key}
setEnabled={() => {}}
setEnabled={() => { }}
size="small"
disabled={true}
/>