Fix type annotations and imports in code

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-12-04 16:07:51 -03:00
commit 53910dbce7
8 changed files with 23 additions and 23 deletions

View file

@ -4,6 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import func
from sqlalchemy.exc import IntegrityError
from sqlmodel import Session, select
from sqlmodel.sql.expression import SelectOfScalar
from langflow.api.v1.schemas import UsersResponse
from langflow.services.auth.utils import (
@ -62,7 +63,7 @@ def read_all_users(
"""
Retrieve a list of users from the database with pagination.
"""
query = select(User).offset(skip).limit(limit)
query: SelectOfScalar = select(User).offset(skip).limit(limit)
users = session.exec(query).fetchall()
count_query = select(func.count()).select_from(User) # type: ignore

View file

@ -1,6 +1,8 @@
from typing import Dict, Generator, List, Type, Union
from langchain.chains.base import Chain
from loguru import logger
from langflow.graph.edge.base import Edge
from langflow.graph.graph.constants import lazy_load_vertex_dict
from langflow.graph.graph.utils import process_flow
@ -8,7 +10,6 @@ from langflow.graph.vertex.base import Vertex
from langflow.graph.vertex.types import FileToolVertex, LLMVertex, ToolkitVertex
from langflow.interface.tools.constants import FILE_TOOLS
from langflow.utils import payload
from loguru import logger
class Graph:
@ -229,9 +230,9 @@ class Graph:
vertex_lc_type: str = vertex_data["node"]["template"]["_type"] # type: ignore
VertexClass = self._get_vertex_class(vertex_type, vertex_lc_type)
vertex = VertexClass(vertex, graph=self)
vertex.set_top_level(self.top_level_vertices)
vertices.append(vertex)
vertex_instance = VertexClass(vertex, graph=self)
vertex_instance.set_top_level(self.top_level_vertices)
vertices.append(vertex_instance)
return vertices

View file

@ -6,6 +6,7 @@ from typing import Any, Dict, List, Type, Union
from cachetools import TTLCache, cachedmethod, keys
from fastapi import HTTPException
from langflow.interface.custom.schema import CallableCodeDetails, ClassCodeDetails
@ -35,7 +36,7 @@ class CodeParser:
"""
Initializes the parser with the provided code.
"""
self.cache = TTLCache(maxsize=1024, ttl=60)
self.cache: TTLCache = TTLCache(maxsize=1024, ttl=60)
if isinstance(code, type):
if not inspect.isclass(code):
raise ValueError("The provided code must be a class.")

View file

@ -5,6 +5,7 @@ from uuid import UUID
import yaml
from cachetools import TTLCache, cachedmethod
from fastapi import HTTPException
from langflow.interface.custom.component import Component
from langflow.interface.custom.directory_reader import DirectoryReader
from langflow.interface.custom.utils import (
@ -187,7 +188,7 @@ class CustomComponent(Component):
# Retrieve and decrypt the credential by name for the current user
db_service = get_db_service()
with session_getter(db_service) as session:
return credential_service.get_credential(user_id=self._user_id, name=name, session=session)
return credential_service.get_credential(user_id=self._user_id or "", name=name, session=session)
return get_credential

View file

@ -1,21 +1,19 @@
import datetime
import secrets
import threading
from uuid import UUID
from typing import List, Optional
from uuid import UUID
from sqlmodel import Session, select
from langflow.services.database.models.api_key import (
ApiKey,
ApiKeyCreate,
UnmaskedApiKeyRead,
ApiKeyRead,
)
from sqlmodel.sql.expression import SelectOfScalar
from langflow.services.database.models.api_key import ApiKey, ApiKeyCreate, ApiKeyRead, UnmaskedApiKeyRead
def get_api_keys(session: Session, user_id: UUID) -> List[ApiKeyRead]:
query = select(ApiKey).where(ApiKey.user_id == user_id)
query: SelectOfScalar = select(ApiKey).where(ApiKey.user_id == user_id)
api_keys = session.exec(query).all()
return [ApiKeyRead.from_orm(api_key) for api_key in api_keys]
return [ApiKeyRead.model_validate(api_key) for api_key in api_keys]
def create_api_key(session: Session, api_key_create: ApiKeyCreate, user_id: UUID) -> UnmaskedApiKeyRead:
@ -46,7 +44,7 @@ def delete_api_key(session: Session, api_key_id: UUID) -> None:
def check_key(session: Session, api_key: str) -> Optional[ApiKey]:
"""Check if the API key is valid."""
query = select(ApiKey).where(ApiKey.api_key == api_key)
query: SelectOfScalar = select(ApiKey).where(ApiKey.api_key == api_key)
api_key_object: Optional[ApiKey] = session.exec(query).first()
if api_key_object is not None:
threading.Thread(

View file

@ -2,9 +2,10 @@ from datetime import datetime
from typing import TYPE_CHECKING, Optional
from uuid import UUID, uuid4
from langflow.services.database.models.credential.schema import CredentialType
from sqlmodel import Field, Relationship, SQLModel
from langflow.services.database.models.credential.schema import CredentialType
if TYPE_CHECKING:
from langflow.services.database.models.user import User
@ -24,9 +25,6 @@ class Credential(CredentialBase, table=True):
user_id: UUID = Field(description="User ID associated with this credential", foreign_key="user.id")
user: "User" = Relationship(back_populates="credentials")
if TYPE_CHECKING:
user: "User" = Relationship(back_populates="credentials")
class CredentialCreate(CredentialBase):
# AcceptedProviders is a custom Enum

View file

@ -24,7 +24,7 @@ class User(SQLModel, table=True):
back_populates="user",
sa_relationship_kwargs={"cascade": "delete"},
)
store_api_key: str = Field(default=None, nullable=True)
store_api_key: Optional[str] = Field(default=None, nullable=True)
flows: list["Flow"] = Relationship(back_populates="user")
credentials: list["Credential"] = Relationship(
back_populates="user",

View file

@ -482,7 +482,7 @@ class StoreService(Service):
result: List[ListComponentResponse] = []
authorized = False
metadata = {}
metadata: Dict = {}
comp_count = 0
try:
result, metadata = await self.query_components(