fix unit test

This commit is contained in:
ming luo 2024-06-21 20:46:54 -04:00 committed by Gabriel Luiz Freitas Almeida
commit 17adc8b5a5
5 changed files with 162 additions and 75 deletions

View file

@ -2,12 +2,10 @@ import abc
from typing import Optional, Union
from uuid import UUID
from fastapi import Depends
from sqlmodel import Session
from langflow.services.base import Service
from langflow.services.database.models.variable.model import Variable
from langflow.services.deps import get_session
class VariableService(Service):
@ -90,9 +88,9 @@ class VariableService(Service):
user_id: Union[UUID, str],
name: str,
value: str,
default_fields: list[str] = [],
_type: str = "Generic",
session: Session = Depends(get_session),
default_fields: list[str],
_type: str,
session: Session,
) -> Variable:
"""
Create a variable.

View file

@ -1,8 +1,11 @@
from kubernetes import client, config # type: ignore
from kubernetes.client.rest import ApiException
from kubernetes.client.rest import ApiException # type: ignore
from base64 import b64encode, b64decode
from loguru import logger
from typing import Union
from uuid import UUID
class KubernetesSecretManager:
"""
@ -38,15 +41,11 @@ class KubernetesSecretManager:
secret_metadata = client.V1ObjectMeta(name=name)
secret = client.V1Secret(
api_version="v1",
kind="Secret",
metadata=secret_metadata,
type=secret_type,
data=encoded_data
api_version="v1", kind="Secret", metadata=secret_metadata, type=secret_type, data=encoded_data
)
return self.core_api.create_namespaced_secret(self.namespace, secret)
def upsert_secret(self, secret_name: str, data: dict, secret_type: str = "Opaque"):
"""
Upsert a secret in the specified namespace.
@ -60,18 +59,18 @@ class KubernetesSecretManager:
try:
# Try to read the existing secret
existing_secret = self.core_api.read_namespaced_secret(secret_name, self.namespace)
# If secret exists, update it
existing_data = {k: b64decode(v).decode() for k, v in existing_secret.data.items()}
existing_data.update(data)
# Encode all data to base64
encoded_data = {k: b64encode(v.encode()).decode() for k, v in existing_data.items()}
# Update the existing secret
existing_secret.data = encoded_data
return self.core_api.replace_namespaced_secret(secret_name, self.namespace, existing_secret)
except ApiException as e:
if e.status == 404:
# Secret doesn't exist, create a new one
@ -113,14 +112,14 @@ class KubernetesSecretManager:
secret = self.core_api.read_namespaced_secret(name, self.namespace)
if secret is None:
raise ApiException(status=404, reason="Not Found", msg="Secret not found")
# Update the secret data
encoded_data = {k: b64encode(v.encode()).decode() for k, v in data.items()}
secret.data.update(encoded_data)
# Update the secret in Kubernetes
return self.core_api.replace_namespaced_secret(name, self.namespace, secret)
def delete_secret_key(self, name: str, key: str):
"""
Delete a key from the specified secret in the namespace.
@ -136,16 +135,16 @@ class KubernetesSecretManager:
secret = self.core_api.read_namespaced_secret(name, self.namespace)
if secret is None:
raise ApiException(status=404, reason="Not Found", msg="Secret not found")
# Delete the key from the secret data
if key in secret.data:
del secret.data[key]
else:
raise ApiException(status=404, reason="Not Found", msg="Key not found in the secret")
# Update the secret in Kubernetes
return self.core_api.replace_namespaced_secret(name, self.namespace, secret)
def delete_secret(self, name: str):
"""
Delete a secret from the specified namespace.
@ -158,7 +157,39 @@ class KubernetesSecretManager:
"""
return self.core_api.delete_namespaced_secret(name, self.namespace)
# utility function to encode user_id to base64 lower case and numbers only
# this is required by kubernetes secret name restrictions
def encode_user_id(user_id: str) -> str:
return b64encode(user_id.encode()).decode().lower().replace("=", "").replace("+", "-").replace("/", "_")
def encode_user_id(user_id: Union[UUID | str]) -> str:
# Handle UUID
if isinstance(user_id, UUID):
return f"uuid-{str(user_id).lower()}"[:253]
# Convert string to lowercase
id = str(user_id).lower()
# If the user_id looks like an email, replace @ and . with allowed characters
if "@" in id or "." in id:
id = id.replace("@", "-at-").replace(".", "-dot-")
# Encode the user_id to base64
# encoded = base64.b64encode(user_id.encode("utf-8")).decode("utf-8")
# Replace characters not allowed in Kubernetes names
id = id.replace("+", "-").replace("/", "_").rstrip("=")
# Ensure the name starts with an alphanumeric character
if not id[0].isalnum():
id = "a-" + id
# Truncate to 253 characters (Kubernetes name length limit)
id = id[:253]
if not all(c.isalnum() or c in "-_" for c in id):
raise ValueError(f"Invalid user_id: {id}")
# Ensure the name ends with an alphanumeric character
while not id[-1].isalnum():
id = id[:-1]
return id

View file

@ -154,7 +154,7 @@ class KubernetesSecretService(VariableService, Service):
variables[key] = str(value)
try:
secret_name = user_id
secret_name = encode_user_id(user_id)
self.kubernetes_secrets.create_secret(
name=secret_name,
data=variables,
@ -190,7 +190,7 @@ class KubernetesSecretService(VariableService, Service):
user_id: Union[UUID, str],
name: str,
field: str,
_session: Session = None,
_session: Session,
) -> str:
secret_name = encode_user_id(user_id)
key, value = self.resolve_variable(secret_name, user_id, name)
@ -204,7 +204,7 @@ class KubernetesSecretService(VariableService, Service):
def list_variables(
self,
user_id: Union[UUID, str],
_session: Session = None,
_session: Session,
) -> list[Optional[str]]:
variables = self.kubernetes_secrets.get_secret(name=encode_user_id(user_id))
if not variables:
@ -223,17 +223,17 @@ class KubernetesSecretService(VariableService, Service):
user_id: Union[UUID, str],
name: str,
value: str,
_session: Session = None,
_session: Session,
):
secret_name = encode_user_id(user_id)
secret_key, _ = self.resolve_variable(secret_name, user_id, name)
return self.kubernetes_secrets.update_secret_key(name=secret_name, data={secret_key: value})
return self.kubernetes_secrets.update_secret(name=secret_name, data={secret_key: value})
def delete_variable(
self,
user_id: Union[UUID, str],
name: str,
_session: Session = None,
_session: Session,
):
secret_name = encode_user_id(user_id)
secret_key, _ = self.resolve_variable(secret_name, user_id, name)
@ -245,13 +245,24 @@ class KubernetesSecretService(VariableService, Service):
user_id: Union[UUID, str],
name: str,
value: str,
default_fields: list[str] = [],
_type: str = "Generic",
_session: Session = None,
):
default_fields: list[str],
_type: str,
_session: Session,
) -> Variable:
secret_name = encode_user_id(user_id)
secret_key = name
if _type == CREDENTIAL_TYPE:
secret_key = CREDENTIAL_TYPE + "_" + name
else:
_type = GENERIC_TYPE
return self.kubernetes_secrets.upsert_secret(name=secret_name, data={secret_key: value})
self.kubernetes_secrets.upsert_secret(secret_name=secret_name, data={secret_key: value})
variable_base = VariableCreate(
name=name,
type=_type,
value=auth_utils.encrypt_api_key(value, settings_service=self.settings_service),
default_fields=default_fields,
)
variable = Variable.model_validate(variable_base, from_attributes=True, update={"user_id": user_id})
return variable