diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index 49426aeb7..813aaf415 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -1,5 +1,5 @@ from http import HTTPStatus -from typing import Annotated, Optional, Union +from typing import Annotated, Any, Optional, Union from langflow.services.auth.utils import api_key_security, get_current_active_user from langflow.services.cache.utils import save_uploaded_file @@ -40,7 +40,7 @@ def get_all(current_user: User = Depends(get_current_active_user)): native_components = build_langchain_types_dict() # custom_components is a list of dicts # need to merge all the keys into one dict - custom_components_from_file = {} + custom_components_from_file: dict[str, Any] = {} settings_manager = get_settings_manager() if settings_manager.settings.COMPONENTS_PATH: logger.info( @@ -93,19 +93,19 @@ async def process_flow( tweaks: Optional[dict] = None, clear_cache: Annotated[bool, Body(embed=True)] = False, # noqa: F821 session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821 - api_key=Depends(api_key_security), + api_key_user: User = Depends(api_key_security), ): """ Endpoint to process an input with a given flow_id. """ try: - if api_key is None: + if api_key_user is None: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key", ) - api_key_user = api_key.user + # Get the flow that matches the flow_id and belongs to the user flow = ( session.query(Flow) diff --git a/src/backend/langflow/graph/vertex/types.py b/src/backend/langflow/graph/vertex/types.py index 9a2dc21c5..a70a06e88 100644 --- a/src/backend/langflow/graph/vertex/types.py +++ b/src/backend/langflow/graph/vertex/types.py @@ -21,18 +21,18 @@ class AgentVertex(Vertex): elif isinstance(source_node, ChainVertex): self.chains.append(source_node) - def build(self, force: bool = False) -> Any: + def build(self, force: bool = False, user_id=None) -> Any: if not self._built or force: self._set_tools_and_chains() # First, build the tools for tool_node in self.tools: - tool_node.build() + tool_node.build(user_id=user_id) # Next, build the chains and the rest for chain_node in self.chains: - chain_node.build(tools=self.tools) + chain_node.build(tools=self.tools, user_id=user_id) - self._build() + self._build(user_id=user_id) return self._built_object @@ -49,13 +49,13 @@ class LLMVertex(Vertex): def __init__(self, data: Dict): super().__init__(data, base_type="llms") - def build(self, force: bool = False) -> Any: + def build(self, force: bool = False, user_id=None) -> Any: # LLM is different because some models might take up too much memory # or time to load. So we only load them when we need them.ß if self.vertex_type == self.built_node_type: return self.class_built_object if not self._built or force: - self._build() + self._build(user_id=user_id) self.built_node_type = self.vertex_type self.class_built_object = self._built_object # Avoid deepcopying the LLM @@ -77,11 +77,11 @@ class WrapperVertex(Vertex): def __init__(self, data: Dict): super().__init__(data, base_type="wrappers") - def build(self, force: bool = False) -> Any: + def build(self, force: bool = False, user_id=None) -> Any: if not self._built or force: if "headers" in self.params: self.params["headers"] = ast.literal_eval(self.params["headers"]) - self._build() + self._build(user_id=user_id) return self._built_object @@ -149,6 +149,7 @@ class ChainVertex(Vertex): self, force: bool = False, tools: Optional[List[Union[ToolkitVertex, ToolVertex]]] = None, + user_id=None, ) -> Any: if not self._built or force: # Check if the chain requires a PromptVertex @@ -157,7 +158,7 @@ class ChainVertex(Vertex): # Build the PromptVertex, passing the tools if available self.params[key] = value.build(tools=tools, force=force) - self._build() + self._build(user_id=user_id) return self._built_object @@ -170,6 +171,7 @@ class PromptVertex(Vertex): self, force: bool = False, tools: Optional[List[Union[ToolkitVertex, ToolVertex]]] = None, + user_id=None, ) -> Any: if not self._built or force: if ( @@ -180,7 +182,7 @@ class PromptVertex(Vertex): # Check if it is a ZeroShotPrompt and needs a tool if "ShotPrompt" in self.vertex_type: tools = ( - [tool_node.build() for tool_node in tools] + [tool_node.build(user_id=user_id) for tool_node in tools] if tools is not None else [] ) @@ -208,7 +210,7 @@ class PromptVertex(Vertex): else: self.params.pop("input_variables", None) - self._build() + self._build(user_id=user_id) return self._built_object def _built_object_repr(self): diff --git a/src/backend/langflow/services/auth/service.py b/src/backend/langflow/services/auth/service.py index c80b984bb..29984a75c 100644 --- a/src/backend/langflow/services/auth/service.py +++ b/src/backend/langflow/services/auth/service.py @@ -1,4 +1,3 @@ -from fastapi import Request from langflow.services.base import Service from typing import TYPE_CHECKING @@ -11,8 +10,3 @@ class AuthManager(Service): def __init__(self, settings_manager: "SettingsManager"): self.settings_manager = settings_manager - - # We need to define a function that can be passed to the Depends() function. - # This function will be called by FastAPI to run oauth2_scheme - def run_oauth2_scheme(self, request: Request): - return self.settings_manager.auth_settings.oauth2_scheme(request=request) diff --git a/src/backend/langflow/services/auth/utils.py b/src/backend/langflow/services/auth/utils.py index 8377b26cb..333ba226b 100644 --- a/src/backend/langflow/services/auth/utils.py +++ b/src/backend/langflow/services/auth/utils.py @@ -36,7 +36,11 @@ async def api_key_security( settings_manager = get_settings_manager() result = None if settings_manager.auth_settings.AUTO_LOGIN: - return settings_manager.auth_settings.API_KEY_SECRET_KEY + # Get the first user + settings_manager.auth_settings.FIRST_SUPERUSER + result = get_user_by_username( + db, settings_manager.auth_settings.FIRST_SUPERUSER + ) elif not query_param and not header_param: raise HTTPException( @@ -50,13 +54,15 @@ async def api_key_security( else: result = check_key(db, header_param) - if result: - return result - else: + if not result: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Invalid or missing API key", ) + if isinstance(result, ApiKey): + return result.user + elif isinstance(result, User): + return result async def get_current_user( @@ -139,7 +145,9 @@ def create_token(data: dict, expires_delta: timedelta): def create_super_user( - db: Session = Depends(get_session), username: str = None, password: str = None + db: Session = Depends(get_session), + username: Optional[str] = None, + password: Optional[str] = None, ) -> User: settings_manager = get_settings_manager() diff --git a/src/backend/langflow/utils/util.py b/src/backend/langflow/utils/util.py index 8f0ff216a..890201294 100644 --- a/src/backend/langflow/utils/util.py +++ b/src/backend/langflow/utils/util.py @@ -10,7 +10,7 @@ from langflow.__main__ import console # type: ignore from langflow.template.frontend_node.constants import FORCE_SHOW_FIELDS from langflow.utils import constants from langflow.utils.logger import logger -from multiprocess import cpu_count +from multiprocess import cpu_count # type: ignore from rich.table import Table # type: ignore @@ -267,7 +267,7 @@ def format_dict( _type: Union[str, type] = get_type(value) - if "BaseModel" in _type: + if "BaseModel" in str(_type): continue _type = remove_optional_wrapper(_type)