From 53910dbce7c32afbdf0fcb3caa2e8299cb61267e Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Mon, 4 Dec 2023 16:07:51 -0300 Subject: [PATCH] Fix type annotations and imports in code --- src/backend/langflow/api/v1/users.py | 3 ++- src/backend/langflow/graph/graph/base.py | 9 +++++---- .../langflow/interface/custom/code_parser.py | 3 ++- .../interface/custom/custom_component.py | 3 ++- .../services/database/models/api_key/crud.py | 18 ++++++++---------- .../database/models/credential/model.py | 6 ++---- .../services/database/models/user/model.py | 2 +- src/backend/langflow/services/store/service.py | 2 +- 8 files changed, 23 insertions(+), 23 deletions(-) diff --git a/src/backend/langflow/api/v1/users.py b/src/backend/langflow/api/v1/users.py index f3738b36e..1aba198e9 100644 --- a/src/backend/langflow/api/v1/users.py +++ b/src/backend/langflow/api/v1/users.py @@ -4,6 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException from sqlalchemy import func from sqlalchemy.exc import IntegrityError from sqlmodel import Session, select +from sqlmodel.sql.expression import SelectOfScalar from langflow.api.v1.schemas import UsersResponse from langflow.services.auth.utils import ( @@ -62,7 +63,7 @@ def read_all_users( """ Retrieve a list of users from the database with pagination. """ - query = select(User).offset(skip).limit(limit) + query: SelectOfScalar = select(User).offset(skip).limit(limit) users = session.exec(query).fetchall() count_query = select(func.count()).select_from(User) # type: ignore diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 3a69e5c60..3481fda87 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -1,6 +1,8 @@ from typing import Dict, Generator, List, Type, Union from langchain.chains.base import Chain +from loguru import logger + from langflow.graph.edge.base import Edge from langflow.graph.graph.constants import lazy_load_vertex_dict from langflow.graph.graph.utils import process_flow @@ -8,7 +10,6 @@ from langflow.graph.vertex.base import Vertex from langflow.graph.vertex.types import FileToolVertex, LLMVertex, ToolkitVertex from langflow.interface.tools.constants import FILE_TOOLS from langflow.utils import payload -from loguru import logger class Graph: @@ -229,9 +230,9 @@ class Graph: vertex_lc_type: str = vertex_data["node"]["template"]["_type"] # type: ignore VertexClass = self._get_vertex_class(vertex_type, vertex_lc_type) - vertex = VertexClass(vertex, graph=self) - vertex.set_top_level(self.top_level_vertices) - vertices.append(vertex) + vertex_instance = VertexClass(vertex, graph=self) + vertex_instance.set_top_level(self.top_level_vertices) + vertices.append(vertex_instance) return vertices diff --git a/src/backend/langflow/interface/custom/code_parser.py b/src/backend/langflow/interface/custom/code_parser.py index 427f40652..a7bf8023c 100644 --- a/src/backend/langflow/interface/custom/code_parser.py +++ b/src/backend/langflow/interface/custom/code_parser.py @@ -6,6 +6,7 @@ from typing import Any, Dict, List, Type, Union from cachetools import TTLCache, cachedmethod, keys from fastapi import HTTPException + from langflow.interface.custom.schema import CallableCodeDetails, ClassCodeDetails @@ -35,7 +36,7 @@ class CodeParser: """ Initializes the parser with the provided code. """ - self.cache = TTLCache(maxsize=1024, ttl=60) + self.cache: TTLCache = TTLCache(maxsize=1024, ttl=60) if isinstance(code, type): if not inspect.isclass(code): raise ValueError("The provided code must be a class.") diff --git a/src/backend/langflow/interface/custom/custom_component.py b/src/backend/langflow/interface/custom/custom_component.py index d5d229b9e..438cc50cf 100644 --- a/src/backend/langflow/interface/custom/custom_component.py +++ b/src/backend/langflow/interface/custom/custom_component.py @@ -5,6 +5,7 @@ from uuid import UUID import yaml from cachetools import TTLCache, cachedmethod from fastapi import HTTPException + from langflow.interface.custom.component import Component from langflow.interface.custom.directory_reader import DirectoryReader from langflow.interface.custom.utils import ( @@ -187,7 +188,7 @@ class CustomComponent(Component): # Retrieve and decrypt the credential by name for the current user db_service = get_db_service() with session_getter(db_service) as session: - return credential_service.get_credential(user_id=self._user_id, name=name, session=session) + return credential_service.get_credential(user_id=self._user_id or "", name=name, session=session) return get_credential diff --git a/src/backend/langflow/services/database/models/api_key/crud.py b/src/backend/langflow/services/database/models/api_key/crud.py index 806848218..e5c7d9ddd 100644 --- a/src/backend/langflow/services/database/models/api_key/crud.py +++ b/src/backend/langflow/services/database/models/api_key/crud.py @@ -1,21 +1,19 @@ import datetime import secrets import threading -from uuid import UUID from typing import List, Optional +from uuid import UUID + from sqlmodel import Session, select -from langflow.services.database.models.api_key import ( - ApiKey, - ApiKeyCreate, - UnmaskedApiKeyRead, - ApiKeyRead, -) +from sqlmodel.sql.expression import SelectOfScalar + +from langflow.services.database.models.api_key import ApiKey, ApiKeyCreate, ApiKeyRead, UnmaskedApiKeyRead def get_api_keys(session: Session, user_id: UUID) -> List[ApiKeyRead]: - query = select(ApiKey).where(ApiKey.user_id == user_id) + query: SelectOfScalar = select(ApiKey).where(ApiKey.user_id == user_id) api_keys = session.exec(query).all() - return [ApiKeyRead.from_orm(api_key) for api_key in api_keys] + return [ApiKeyRead.model_validate(api_key) for api_key in api_keys] def create_api_key(session: Session, api_key_create: ApiKeyCreate, user_id: UUID) -> UnmaskedApiKeyRead: @@ -46,7 +44,7 @@ def delete_api_key(session: Session, api_key_id: UUID) -> None: def check_key(session: Session, api_key: str) -> Optional[ApiKey]: """Check if the API key is valid.""" - query = select(ApiKey).where(ApiKey.api_key == api_key) + query: SelectOfScalar = select(ApiKey).where(ApiKey.api_key == api_key) api_key_object: Optional[ApiKey] = session.exec(query).first() if api_key_object is not None: threading.Thread( diff --git a/src/backend/langflow/services/database/models/credential/model.py b/src/backend/langflow/services/database/models/credential/model.py index 4a5424364..95bd4b829 100644 --- a/src/backend/langflow/services/database/models/credential/model.py +++ b/src/backend/langflow/services/database/models/credential/model.py @@ -2,9 +2,10 @@ from datetime import datetime from typing import TYPE_CHECKING, Optional from uuid import UUID, uuid4 -from langflow.services.database.models.credential.schema import CredentialType from sqlmodel import Field, Relationship, SQLModel +from langflow.services.database.models.credential.schema import CredentialType + if TYPE_CHECKING: from langflow.services.database.models.user import User @@ -24,9 +25,6 @@ class Credential(CredentialBase, table=True): user_id: UUID = Field(description="User ID associated with this credential", foreign_key="user.id") user: "User" = Relationship(back_populates="credentials") - if TYPE_CHECKING: - user: "User" = Relationship(back_populates="credentials") - class CredentialCreate(CredentialBase): # AcceptedProviders is a custom Enum diff --git a/src/backend/langflow/services/database/models/user/model.py b/src/backend/langflow/services/database/models/user/model.py index 3a57bc9ab..dccd3c305 100644 --- a/src/backend/langflow/services/database/models/user/model.py +++ b/src/backend/langflow/services/database/models/user/model.py @@ -24,7 +24,7 @@ class User(SQLModel, table=True): back_populates="user", sa_relationship_kwargs={"cascade": "delete"}, ) - store_api_key: str = Field(default=None, nullable=True) + store_api_key: Optional[str] = Field(default=None, nullable=True) flows: list["Flow"] = Relationship(back_populates="user") credentials: list["Credential"] = Relationship( back_populates="user", diff --git a/src/backend/langflow/services/store/service.py b/src/backend/langflow/services/store/service.py index 4d5acce8c..94a89f974 100644 --- a/src/backend/langflow/services/store/service.py +++ b/src/backend/langflow/services/store/service.py @@ -482,7 +482,7 @@ class StoreService(Service): result: List[ListComponentResponse] = [] authorized = False - metadata = {} + metadata: Dict = {} comp_count = 0 try: result, metadata = await self.query_components(