diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index dd98f77d9..173a2c572 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -1,6 +1,6 @@ from http import HTTPStatus from typing import Annotated, Optional, Union -from langflow.services.auth.utils import get_current_active_user, validate_api_key +from langflow.services.auth.utils import api_key_security, get_current_active_user from langflow.services.cache.utils import save_uploaded_file from langflow.services.database.models.flow import Flow @@ -77,22 +77,27 @@ def get_all(current_user: User = Depends(get_current_active_user)): # For backwards compatibility we will keep the old endpoint -@router.post("/predict/{flow_id}", response_model=ProcessResponse) -@router.post("/process/{flow_id}", response_model=ProcessResponse) +@router.post( + "/predict/{flow_id}", + response_model=ProcessResponse, + dependencies=[Depends(api_key_security)], +) +@router.post( + "/process/{flow_id}", + response_model=ProcessResponse, + dependencies=[Depends(api_key_security)], +) async def process_flow( + session: Annotated[Session, Depends(get_session)], flow_id: str, inputs: Optional[dict] = None, tweaks: Optional[dict] = None, clear_cache: Annotated[bool, Body(embed=True)] = False, # noqa: F821 session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821 - session: Session = Depends(get_session), - valid: bool = Depends(validate_api_key), ): """ Endpoint to process an input with a given flow_id. """ - if not valid: - raise HTTPException(status_code=401, detail="Invalid API key") try: flow = session.get(Flow, flow_id) diff --git a/src/backend/langflow/services/database/models/api_key/api_key.py b/src/backend/langflow/services/database/models/api_key/api_key.py index 6fc58b312..f8015af40 100644 --- a/src/backend/langflow/services/database/models/api_key/api_key.py +++ b/src/backend/langflow/services/database/models/api_key/api_key.py @@ -13,13 +13,14 @@ class ApiKeyBase(SQLModelSerializable): name: Optional[str] = Field(index=True) created_at: datetime = Field(default_factory=datetime.utcnow) last_used_at: Optional[datetime] = Field(default=None) + total_uses: int = Field(default=0) + is_active: bool = Field(default=True) class ApiKey(ApiKeyBase, table=True): id: UUID = Field(default_factory=uuid4, primary_key=True, unique=True) api_key: str = Field(index=True, unique=True) - hashed_api_key: str = Field(index=True) # User relationship user_id: UUID = Field(index=True, foreign_key="user.id") user: "User" = Relationship(back_populates="api_keys") @@ -44,4 +45,4 @@ class ApiKeyRead(ApiKeyBase): @validator("api_key", always=True) def mask_api_key(cls, v): # This validator will always run, and will mask the API key - return f"{'*' * 8}{v[-4:]}" + return f"{v[:2]}{'*' * (len(v) - 4)}{v[-2:]}"