Merge branch 'feature/store' of github.com:logspace-ai/langflow into feature/store
This commit is contained in:
commit
99f0943f23
13 changed files with 142 additions and 122 deletions
|
|
@ -12,7 +12,8 @@ from dotenv import load_dotenv
|
|||
from langflow.main import setup_app
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service, get_settings_service
|
||||
from langflow.services.utils import initialize_services, initialize_settings_service
|
||||
from langflow.services.utils import (initialize_services,
|
||||
initialize_settings_service)
|
||||
from langflow.utils.logger import configure, logger
|
||||
from multiprocess import Process, cpu_count # type: ignore
|
||||
from rich import box
|
||||
|
|
@ -327,11 +328,19 @@ def superuser(
|
|||
|
||||
|
||||
@app.command()
|
||||
def migration(test: bool = typer.Option(True, help="Run migrations in test mode.")):
|
||||
def migration(test: bool = typer.Option(True, help="Run migrations in test mode."),
|
||||
fix: bool = typer.Option(False, help="Fix migrations. This is a destructive operation, and should only be used if you know what you are doing.")
|
||||
):
|
||||
"""
|
||||
Run or test migrations.
|
||||
"""
|
||||
initialize_services()
|
||||
if fix:
|
||||
if not typer.confirm("This will delete all data necessary to fix migrations. Are you sure you want to continue?"):
|
||||
raise typer.Abort()
|
||||
|
||||
|
||||
|
||||
initialize_services(fix_migration=fix)
|
||||
db_service = get_db_service()
|
||||
if not test:
|
||||
db_service.run_migrations()
|
||||
|
|
|
|||
|
|
@ -7,10 +7,10 @@ Create Date: 2023-10-18 23:08:57.744906
|
|||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
|
||||
from alembic import op
|
||||
from loguru import logger
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "7843803a87b5"
|
||||
|
|
@ -28,21 +28,23 @@ def upgrade() -> None:
|
|||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.add_column(
|
||||
sa.Column(
|
||||
"store_api_key", sqlmodel.sql.sqltypes.AutoString(), nullable=True
|
||||
"store_api_key", sqlmodel.AutoString(), nullable=True
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.drop_column("store_api_key")
|
||||
|
||||
with op.batch_alter_table("flow", schema=None) as batch_op:
|
||||
batch_op.drop_column("is_component")
|
||||
try:
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.drop_column("store_api_key")
|
||||
|
||||
with op.batch_alter_table("flow", schema=None) as batch_op:
|
||||
batch_op.drop_column("is_component")
|
||||
except Exception:
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
|
|
|
|||
|
|
@ -1,27 +1,9 @@
|
|||
from .constants import (
|
||||
AgentExecutor,
|
||||
BaseChatMemory,
|
||||
BaseLanguageModel,
|
||||
BaseLLM,
|
||||
BaseLoader,
|
||||
BaseMemory,
|
||||
BaseOutputParser,
|
||||
BasePromptTemplate,
|
||||
BaseRetriever,
|
||||
Callable,
|
||||
Chain,
|
||||
ChatPromptTemplate,
|
||||
Data,
|
||||
Document,
|
||||
Embeddings,
|
||||
NestedDict,
|
||||
Object,
|
||||
PromptTemplate,
|
||||
TextSplitter,
|
||||
Tool,
|
||||
VectorStore,
|
||||
)
|
||||
|
||||
from .constants import (AgentExecutor, BaseChatMemory, BaseLanguageModel,
|
||||
BaseLLM, BaseLoader, BaseMemory, BaseOutputParser,
|
||||
BasePromptTemplate, BaseRetriever, Callable, Chain,
|
||||
ChatPromptTemplate, Data, Document, Embeddings,
|
||||
NestedDict, Object, Prompt, PromptTemplate,
|
||||
TextSplitter, Tool, VectorStore)
|
||||
|
||||
__all__ = [
|
||||
"NestedDict",
|
||||
|
|
@ -45,4 +27,5 @@ __all__ = [
|
|||
"Callable",
|
||||
"BasePromptTemplate",
|
||||
"ChatPromptTemplate",
|
||||
"Prompt"
|
||||
]
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@ from langchain.chains.base import Chain
|
|||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.prompts import BasePromptTemplate, ChatPromptTemplate, PromptTemplate
|
||||
from langchain.prompts import (BasePromptTemplate, ChatPromptTemplate,
|
||||
PromptTemplate)
|
||||
from langchain.schema import BaseOutputParser, BaseRetriever, Document
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
|
@ -25,6 +26,9 @@ class Object:
|
|||
class Data:
|
||||
pass
|
||||
|
||||
class Prompt:
|
||||
pass
|
||||
|
||||
|
||||
LANGCHAIN_BASE_TYPES = {
|
||||
"Chain": Chain,
|
||||
|
|
@ -44,18 +48,14 @@ LANGCHAIN_BASE_TYPES = {
|
|||
"BaseOutputParser": BaseOutputParser,
|
||||
"BaseMemory": BaseMemory,
|
||||
"BaseChatMemory": BaseChatMemory,
|
||||
|
||||
}
|
||||
# Langchain base types plus Python base types
|
||||
CUSTOM_COMPONENT_SUPPORTED_TYPES = {
|
||||
**LANGCHAIN_BASE_TYPES,
|
||||
"str": str,
|
||||
"int": int,
|
||||
"float": float,
|
||||
"bool": bool,
|
||||
"list": list,
|
||||
"dict": dict,
|
||||
"NestedDict": NestedDict,
|
||||
"Data": Data,
|
||||
"Object": Object,
|
||||
"Callable": Callable,
|
||||
"Prompt": Prompt,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -345,7 +345,7 @@ class Vertex:
|
|||
if self.base_type == "custom_components":
|
||||
message += " Make sure your build method returns a component."
|
||||
|
||||
raise ValueError(message)
|
||||
logger.warning(message)
|
||||
|
||||
async def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
|
||||
if not self._built or force:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,8 @@ from typing import Any, Dict, List, Type, Union
|
|||
|
||||
from cachetools import TTLCache, cachedmethod, keys
|
||||
from fastapi import HTTPException
|
||||
from langflow.interface.custom.schema import CallableCodeDetails, ClassCodeDetails
|
||||
from langflow.interface.custom.schema import (CallableCodeDetails,
|
||||
ClassCodeDetails)
|
||||
|
||||
|
||||
class CodeSyntaxError(HTTPException):
|
||||
|
|
@ -56,6 +57,9 @@ class CodeParser:
|
|||
ast.Assign: self.parse_global_vars,
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
def __get_tree(self):
|
||||
"""
|
||||
Parses the provided code to validate its syntax.
|
||||
|
|
@ -79,6 +83,7 @@ class CodeParser:
|
|||
if handler := self.handlers.get(type(node)): # type: ignore
|
||||
handler(node) # type: ignore
|
||||
|
||||
|
||||
def parse_imports(self, node: Union[ast.Import, ast.ImportFrom]) -> None:
|
||||
"""
|
||||
Extracts "imports" from the code, including aliases.
|
||||
|
|
@ -149,12 +154,16 @@ class CodeParser:
|
|||
# Handle cases where the type is not found in the constructed environment
|
||||
pass
|
||||
|
||||
|
||||
func = CallableCodeDetails(
|
||||
name=node.name, doc=ast.get_docstring(node), args=[], body=[], return_type=return_type or get_data_type()
|
||||
name=node.name,
|
||||
doc=ast.get_docstring(node),
|
||||
args= self.parse_function_args(node),
|
||||
body= self.parse_function_body(node),
|
||||
return_type=return_type or get_data_type(),
|
||||
has_return=self.parse_return_statement(node),
|
||||
)
|
||||
|
||||
func.args = self.parse_function_args(node)
|
||||
func.body = self.parse_function_body(node)
|
||||
|
||||
return func.model_dump()
|
||||
|
||||
|
|
@ -230,6 +239,14 @@ class CodeParser:
|
|||
"""
|
||||
return [ast.unparse(line) for line in node.body]
|
||||
|
||||
def parse_return_statement(self, node: ast.FunctionDef) -> bool:
|
||||
"""
|
||||
Parses the return statement of a function or method node.
|
||||
"""
|
||||
|
||||
return any(isinstance(n, ast.Return) for n in node.body)
|
||||
|
||||
|
||||
def parse_assign(self, stmt):
|
||||
"""
|
||||
Parses an Assign statement and returns a dictionary
|
||||
|
|
|
|||
|
|
@ -5,13 +5,11 @@ from uuid import UUID
|
|||
import yaml
|
||||
from cachetools import TTLCache, cachedmethod
|
||||
from fastapi import HTTPException
|
||||
from langflow.field_typing.constants import CUSTOM_COMPONENT_SUPPORTED_TYPES
|
||||
from langflow.interface.custom.component import Component
|
||||
from langflow.interface.custom.directory_reader import DirectoryReader
|
||||
from langflow.interface.custom.utils import (
|
||||
extract_inner_type_from_generic_alias,
|
||||
extract_union_types_from_generic_alias,
|
||||
)
|
||||
extract_union_types_from_generic_alias)
|
||||
from langflow.services.database.models.flow import Flow
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_credential_service, get_db_service
|
||||
|
|
@ -29,14 +27,13 @@ class CustomComponent(Component):
|
|||
repr_value: Optional[Any] = ""
|
||||
user_id: Optional[Union[UUID, str]] = None
|
||||
status: Optional[Any] = None
|
||||
_tree: Optional[dict] = None
|
||||
|
||||
def __init__(self, **data):
|
||||
self.cache = TTLCache(maxsize=1024, ttl=60)
|
||||
super().__init__(**data)
|
||||
|
||||
@property
|
||||
def return_type_valid_list(self):
|
||||
return list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())
|
||||
|
||||
|
||||
def custom_repr(self):
|
||||
if self.repr_value == "":
|
||||
|
|
@ -78,8 +75,11 @@ class CustomComponent(Component):
|
|||
def validate(self) -> bool:
|
||||
return self._class_template_validation(self.code) if self.code else False
|
||||
|
||||
def get_code_tree(self, code: str):
|
||||
return super().get_code_tree(code)
|
||||
|
||||
|
||||
@property
|
||||
def tree(self):
|
||||
return self.get_code_tree(self.code)
|
||||
|
||||
@property
|
||||
def get_function_entrypoint_args(self) -> list:
|
||||
|
|
@ -108,9 +108,10 @@ class CustomComponent(Component):
|
|||
def get_build_method(self):
|
||||
if not self.code:
|
||||
return []
|
||||
tree = self.get_code_tree(self.code)
|
||||
|
||||
component_classes = [cls for cls in tree["classes"] if self.code_class_base_inheritance in cls["bases"]]
|
||||
|
||||
|
||||
component_classes = [cls for cls in self.tree["classes"] if self.code_class_base_inheritance in cls["bases"]]
|
||||
if not component_classes:
|
||||
return []
|
||||
|
||||
|
|
@ -123,16 +124,19 @@ class CustomComponent(Component):
|
|||
if not build_methods:
|
||||
return []
|
||||
|
||||
|
||||
return build_methods[0]
|
||||
|
||||
@property
|
||||
def get_function_entrypoint_return_type(self) -> List[Any]:
|
||||
build_method = self.get_build_method()
|
||||
if not build_method:
|
||||
return build_method
|
||||
return_type = build_method["return_type"]
|
||||
if not return_type:
|
||||
return []
|
||||
elif not build_method["has_return"]:
|
||||
return []
|
||||
|
||||
return_type = build_method["return_type"]
|
||||
|
||||
# If list or List is in the return type, then we remove it and return the inner type
|
||||
if hasattr(return_type, "__origin__") and return_type.__origin__ in [list, List]:
|
||||
return_type = extract_inner_type_from_generic_alias(return_type)
|
||||
|
|
@ -141,24 +145,23 @@ class CustomComponent(Component):
|
|||
if not hasattr(return_type, "__origin__") or return_type.__origin__ != Union:
|
||||
if isinstance(return_type, list):
|
||||
return return_type
|
||||
return [return_type] # if return_type in self.return_type_valid_list else []
|
||||
return [return_type]
|
||||
|
||||
# If the return type is a Union, then we need to parse itx
|
||||
return_type = extract_union_types_from_generic_alias(return_type)
|
||||
# return [item for item in return_type if item in self.return_type_valid_list]
|
||||
return return_type
|
||||
|
||||
@property
|
||||
def get_main_class_name(self):
|
||||
if not self.code:
|
||||
return ""
|
||||
tree = self.get_code_tree(self.code)
|
||||
|
||||
|
||||
base_name = self.code_class_base_inheritance
|
||||
method_name = self.function_entrypoint_name
|
||||
|
||||
classes = []
|
||||
for item in tree.get("classes", []):
|
||||
for item in self.tree.get("classes", []):
|
||||
if base_name in item["bases"]:
|
||||
method_names = [method["name"] for method in item["methods"]]
|
||||
if method_name in method_names:
|
||||
|
|
@ -171,11 +174,11 @@ class CustomComponent(Component):
|
|||
def build_template_config(self):
|
||||
if not self.code:
|
||||
return {}
|
||||
tree = self.get_code_tree(self.code)
|
||||
|
||||
|
||||
attributes = [
|
||||
main_class["attributes"]
|
||||
for main_class in tree.get("classes", [])
|
||||
for main_class in self.tree.get("classes", [])
|
||||
if main_class["name"] == self.get_main_class_name
|
||||
]
|
||||
# Get just the first item
|
||||
|
|
@ -219,7 +222,8 @@ class CustomComponent(Component):
|
|||
return validate.create_function(self.code, self.function_entrypoint_name)
|
||||
|
||||
async def load_flow(self, flow_id: str, tweaks: Optional[dict] = None) -> Any:
|
||||
from langflow.processing.process import build_sorted_vertices, process_tweaks
|
||||
from langflow.processing.process import (build_sorted_vertices,
|
||||
process_tweaks)
|
||||
|
||||
db_service = get_db_service()
|
||||
with session_getter(db_service) as session:
|
||||
|
|
|
|||
|
|
@ -26,3 +26,4 @@ class CallableCodeDetails(BaseModel):
|
|||
args: list
|
||||
body: list
|
||||
return_type: Optional[Any] = None
|
||||
has_return: bool = False
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ from cachetools import LRUCache, cached
|
|||
from fastapi import HTTPException
|
||||
from langflow.interface.agents.base import agent_creator
|
||||
from langflow.interface.chains.base import chain_creator
|
||||
from langflow.interface.custom.base import custom_component_creator
|
||||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
from langflow.interface.custom.directory_reader import DirectoryReader
|
||||
from langflow.interface.custom.utils import extract_inner_type
|
||||
|
|
@ -30,7 +29,8 @@ from langflow.interface.vector_store.base import vectorstore_creator
|
|||
from langflow.interface.wrappers.base import wrapper_creator
|
||||
from langflow.template.field.base import TemplateField
|
||||
from langflow.template.frontend_node.constants import CLASSES_TO_REMOVE
|
||||
from langflow.template.frontend_node.custom_components import CustomComponentFrontendNode
|
||||
from langflow.template.frontend_node.custom_components import \
|
||||
CustomComponentFrontendNode
|
||||
from langflow.utils.util import get_base_classes
|
||||
from loguru import logger
|
||||
|
||||
|
|
@ -69,7 +69,6 @@ def build_langchain_types_dict(): # sourcery skip: dict-assign-update-to-union
|
|||
utility_creator,
|
||||
output_parser_creator,
|
||||
retriever_creator,
|
||||
custom_component_creator,
|
||||
]
|
||||
|
||||
all_types = {}
|
||||
|
|
|
|||
|
|
@ -5,17 +5,16 @@ from typing import TYPE_CHECKING
|
|||
import sqlalchemy as sa
|
||||
from alembic import command, util
|
||||
from alembic.config import Config
|
||||
from loguru import logger
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
|
||||
from langflow.services.base import Service
|
||||
from langflow.services.database import models # noqa
|
||||
from langflow.services.database.models.user.crud import get_user_by_username
|
||||
from langflow.services.database.utils import Result, TableResults
|
||||
from langflow.services.deps import get_settings_service
|
||||
from langflow.services.utils import teardown_superuser
|
||||
from loguru import logger
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.engine import Engine
|
||||
|
|
@ -118,9 +117,10 @@ class DatabaseService(Service):
|
|||
alembic_cfg.set_main_option("script_location", str(self.script_location))
|
||||
alembic_cfg.set_main_option("sqlalchemy.url", self.database_url)
|
||||
command.stamp(alembic_cfg, "head")
|
||||
# command.upgrade(alembic_cfg, "head")
|
||||
logger.info("Alembic initialized")
|
||||
|
||||
def run_migrations(self):
|
||||
def run_migrations(self, fix=False):
|
||||
# First we need to check if alembic has been initialized
|
||||
# If not, we need to initialize it
|
||||
# if not self.script_location.exists(): # this is not the correct way to check if alembic has been initialized
|
||||
|
|
@ -151,16 +151,32 @@ class DatabaseService(Service):
|
|||
if isinstance(exc, util.exc.CommandError) or isinstance(exc, util.exc.AutogenerateDiffsDetected):
|
||||
command.upgrade(alembic_cfg, "head")
|
||||
|
||||
# We should check the schema health after running migrations
|
||||
try:
|
||||
command.check(alembic_cfg)
|
||||
except util.exc.AutogenerateDiffsDetected:
|
||||
# downgrade to base and upgrade again
|
||||
logger.warning("Autogenerate diffs detected, downgrading and upgrading")
|
||||
command.downgrade(alembic_cfg, "-1")
|
||||
# wait for the database to be ready
|
||||
time.sleep(5)
|
||||
command.upgrade(alembic_cfg, "head")
|
||||
except util.exc.AutogenerateDiffsDetected as exc:
|
||||
logger.exception("AutogenerateDiffsDetected: {exc}")
|
||||
raise RuntimeError("Something went wrong running migrations. Please, run `langflow migration --fix`")
|
||||
|
||||
if fix:
|
||||
self.try_downgrade_upgrade_until_success(alembic_cfg)
|
||||
|
||||
def try_downgrade_upgrade_until_success(self, alembic_cfg, retries=5):
|
||||
# Try -1 then head, if it fails, try -2 then head, etc.
|
||||
# until we reach the number of retries
|
||||
for i in range(1, retries + 1):
|
||||
try:
|
||||
command.check(alembic_cfg)
|
||||
break
|
||||
except util.exc.AutogenerateDiffsDetected as exc:
|
||||
|
||||
# downgrade to base and upgrade again
|
||||
logger.warning(f"AutogenerateDiffsDetected: {exc}")
|
||||
command.downgrade(alembic_cfg, f"-{i}")
|
||||
# wait for the database to be ready
|
||||
time.sleep(3)
|
||||
command.upgrade(alembic_cfg, "head")
|
||||
|
||||
|
||||
|
||||
def run_migrations_test(self):
|
||||
# This method is used for testing purposes only
|
||||
|
|
|
|||
|
|
@ -1,17 +1,18 @@
|
|||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
from loguru import logger
|
||||
from contextlib import contextmanager
|
||||
|
||||
from alembic.util.exc import CommandError
|
||||
from loguru import logger
|
||||
from sqlmodel import Session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.database.service import DatabaseService
|
||||
|
||||
|
||||
def initialize_database():
|
||||
def initialize_database(fix_migration: bool = False):
|
||||
logger.debug("Initializing database")
|
||||
from langflow.services import service_manager, ServiceType
|
||||
from langflow.services import ServiceType, service_manager
|
||||
|
||||
database_service: "DatabaseService" = service_manager.get(ServiceType.DATABASE_SERVICE)
|
||||
try:
|
||||
|
|
@ -28,7 +29,7 @@ def initialize_database():
|
|||
logger.error(f"Error checking schema health: {exc}")
|
||||
raise RuntimeError("Error checking schema health") from exc
|
||||
try:
|
||||
database_service.run_migrations()
|
||||
database_service.run_migrations(fix=fix_migration)
|
||||
except CommandError as exc:
|
||||
# if "overlaps with other requested revisions" or "Can't locate revision identified by"
|
||||
# are not in the exception, we can't handle it
|
||||
|
|
@ -47,8 +48,9 @@ def initialize_database():
|
|||
# if the exception involves tables already existing
|
||||
# we can ignore it
|
||||
if "already exists" not in str(exc):
|
||||
logger.error(f"Error running migrations: {exc}")
|
||||
raise RuntimeError("Error running migrations") from exc
|
||||
logger.error(exc)
|
||||
|
||||
raise exc
|
||||
logger.debug("Database initialized")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@ from langflow.services.auth.utils import create_super_user, verify_password
|
|||
from langflow.services.database.utils import initialize_database
|
||||
from langflow.services.manager import service_manager
|
||||
from langflow.services.schema import ServiceType
|
||||
from langflow.services.settings.constants import DEFAULT_SUPERUSER, DEFAULT_SUPERUSER_PASSWORD
|
||||
from langflow.services.settings.constants import (DEFAULT_SUPERUSER,
|
||||
DEFAULT_SUPERUSER_PASSWORD)
|
||||
from loguru import logger
|
||||
from sqlmodel import Session
|
||||
|
||||
|
|
@ -13,12 +14,13 @@ def get_factories_and_deps():
|
|||
from langflow.services.auth import factory as auth_factory
|
||||
from langflow.services.cache import factory as cache_factory
|
||||
from langflow.services.chat import factory as chat_factory
|
||||
from langflow.services.credentials import factory as credentials_factory
|
||||
from langflow.services.database import factory as database_factory
|
||||
from langflow.services.session import factory as session_service_factory # type: ignore
|
||||
from langflow.services.session import \
|
||||
factory as session_service_factory # type: ignore
|
||||
from langflow.services.settings import factory as settings_factory
|
||||
from langflow.services.store import factory as store_factory
|
||||
from langflow.services.task import factory as task_factory
|
||||
from langflow.services.credentials import factory as credentials_factory
|
||||
|
||||
return [
|
||||
(settings_factory.SettingsServiceFactory(), []),
|
||||
|
|
@ -171,7 +173,8 @@ def initialize_session_service():
|
|||
Initialize the session manager.
|
||||
"""
|
||||
from langflow.services.cache import factory as cache_factory
|
||||
from langflow.services.session import factory as session_service_factory # type: ignore
|
||||
from langflow.services.session import \
|
||||
factory as session_service_factory # type: ignore
|
||||
|
||||
initialize_settings_service()
|
||||
|
||||
|
|
@ -183,7 +186,7 @@ def initialize_session_service():
|
|||
)
|
||||
|
||||
|
||||
def initialize_services():
|
||||
def initialize_services(fix_migration: bool = False):
|
||||
"""
|
||||
Initialize all the services needed.
|
||||
"""
|
||||
|
|
@ -197,7 +200,11 @@ def initialize_services():
|
|||
# Test cache connection
|
||||
service_manager.get(ServiceType.CACHE_SERVICE)
|
||||
# Setup the superuser
|
||||
initialize_database()
|
||||
try:
|
||||
initialize_database(fix_migration=fix_migration)
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
raise exc
|
||||
setup_superuser(service_manager.get(ServiceType.SETTINGS_SERVICE), next(get_session()))
|
||||
try:
|
||||
get_db_service().migrate_flows_if_auto_login()
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ import importlib
|
|||
from types import FunctionType
|
||||
from typing import Dict
|
||||
|
||||
from langflow.field_typing.constants import CUSTOM_COMPONENT_SUPPORTED_TYPES
|
||||
|
||||
|
||||
def add_type_ignores():
|
||||
if not hasattr(ast, "TypeIgnore"):
|
||||
|
|
@ -266,29 +268,7 @@ def get_default_imports(code_string):
|
|||
# Add more imports from the typing module as needed
|
||||
}
|
||||
|
||||
langflow_imports = [
|
||||
"AgentExecutor",
|
||||
"BaseChatMemory",
|
||||
"BaseLanguageModel",
|
||||
"BaseLLM",
|
||||
"BaseLoader",
|
||||
"BaseMemory",
|
||||
"BaseOutputParser",
|
||||
"BasePromptTemplate",
|
||||
"BaseRetriever",
|
||||
"Callable",
|
||||
"Chain",
|
||||
"ChatPromptTemplate",
|
||||
"Data",
|
||||
"Document",
|
||||
"Embeddings",
|
||||
"NestedDict",
|
||||
"Object",
|
||||
"PromptTemplate",
|
||||
"TextSplitter",
|
||||
"Tool",
|
||||
"VectorStore",
|
||||
]
|
||||
langflow_imports = list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())
|
||||
necessary_imports = find_names_in_code(code_string, langflow_imports)
|
||||
langflow_module = importlib.import_module("langflow.field_typing")
|
||||
default_imports.update({name: getattr(langflow_module, name) for name in necessary_imports})
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue