Merge branch 'feature/store' of github.com:logspace-ai/langflow into feature/store

This commit is contained in:
cristhianzl 2023-12-01 16:44:43 -03:00
commit 99f0943f23
13 changed files with 142 additions and 122 deletions

View file

@ -12,7 +12,8 @@ from dotenv import load_dotenv
from langflow.main import setup_app
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service, get_settings_service
from langflow.services.utils import initialize_services, initialize_settings_service
from langflow.services.utils import (initialize_services,
initialize_settings_service)
from langflow.utils.logger import configure, logger
from multiprocess import Process, cpu_count # type: ignore
from rich import box
@ -327,11 +328,19 @@ def superuser(
@app.command()
def migration(test: bool = typer.Option(True, help="Run migrations in test mode.")):
def migration(test: bool = typer.Option(True, help="Run migrations in test mode."),
fix: bool = typer.Option(False, help="Fix migrations. This is a destructive operation, and should only be used if you know what you are doing.")
):
"""
Run or test migrations.
"""
initialize_services()
if fix:
if not typer.confirm("This will delete all data necessary to fix migrations. Are you sure you want to continue?"):
raise typer.Abort()
initialize_services(fix_migration=fix)
db_service = get_db_service()
if not test:
db_service.run_migrations()

View file

@ -7,10 +7,10 @@ Create Date: 2023-10-18 23:08:57.744906
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import sqlmodel
from alembic import op
from loguru import logger
# revision identifiers, used by Alembic.
revision: str = "7843803a87b5"
@ -28,21 +28,23 @@ def upgrade() -> None:
with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.add_column(
sa.Column(
"store_api_key", sqlmodel.sql.sqltypes.AutoString(), nullable=True
"store_api_key", sqlmodel.AutoString(), nullable=True
)
)
except Exception:
pass
except Exception as e:
logger.exception(e)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.drop_column("store_api_key")
with op.batch_alter_table("flow", schema=None) as batch_op:
batch_op.drop_column("is_component")
try:
with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.drop_column("store_api_key")
with op.batch_alter_table("flow", schema=None) as batch_op:
batch_op.drop_column("is_component")
except Exception:
pass
# ### end Alembic commands ###

View file

@ -1,27 +1,9 @@
from .constants import (
AgentExecutor,
BaseChatMemory,
BaseLanguageModel,
BaseLLM,
BaseLoader,
BaseMemory,
BaseOutputParser,
BasePromptTemplate,
BaseRetriever,
Callable,
Chain,
ChatPromptTemplate,
Data,
Document,
Embeddings,
NestedDict,
Object,
PromptTemplate,
TextSplitter,
Tool,
VectorStore,
)
from .constants import (AgentExecutor, BaseChatMemory, BaseLanguageModel,
BaseLLM, BaseLoader, BaseMemory, BaseOutputParser,
BasePromptTemplate, BaseRetriever, Callable, Chain,
ChatPromptTemplate, Data, Document, Embeddings,
NestedDict, Object, Prompt, PromptTemplate,
TextSplitter, Tool, VectorStore)
__all__ = [
"NestedDict",
@ -45,4 +27,5 @@ __all__ = [
"Callable",
"BasePromptTemplate",
"ChatPromptTemplate",
"Prompt"
]

View file

@ -5,7 +5,8 @@ from langchain.chains.base import Chain
from langchain.document_loaders.base import BaseLoader
from langchain.llms.base import BaseLLM
from langchain.memory.chat_memory import BaseChatMemory
from langchain.prompts import BasePromptTemplate, ChatPromptTemplate, PromptTemplate
from langchain.prompts import (BasePromptTemplate, ChatPromptTemplate,
PromptTemplate)
from langchain.schema import BaseOutputParser, BaseRetriever, Document
from langchain.schema.embeddings import Embeddings
from langchain.schema.language_model import BaseLanguageModel
@ -25,6 +26,9 @@ class Object:
class Data:
pass
class Prompt:
pass
LANGCHAIN_BASE_TYPES = {
"Chain": Chain,
@ -44,18 +48,14 @@ LANGCHAIN_BASE_TYPES = {
"BaseOutputParser": BaseOutputParser,
"BaseMemory": BaseMemory,
"BaseChatMemory": BaseChatMemory,
}
# Langchain base types plus Python base types
CUSTOM_COMPONENT_SUPPORTED_TYPES = {
**LANGCHAIN_BASE_TYPES,
"str": str,
"int": int,
"float": float,
"bool": bool,
"list": list,
"dict": dict,
"NestedDict": NestedDict,
"Data": Data,
"Object": Object,
"Callable": Callable,
"Prompt": Prompt,
}

View file

@ -345,7 +345,7 @@ class Vertex:
if self.base_type == "custom_components":
message += " Make sure your build method returns a component."
raise ValueError(message)
logger.warning(message)
async def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
if not self._built or force:

View file

@ -6,7 +6,8 @@ from typing import Any, Dict, List, Type, Union
from cachetools import TTLCache, cachedmethod, keys
from fastapi import HTTPException
from langflow.interface.custom.schema import CallableCodeDetails, ClassCodeDetails
from langflow.interface.custom.schema import (CallableCodeDetails,
ClassCodeDetails)
class CodeSyntaxError(HTTPException):
@ -56,6 +57,9 @@ class CodeParser:
ast.Assign: self.parse_global_vars,
}
def __get_tree(self):
"""
Parses the provided code to validate its syntax.
@ -79,6 +83,7 @@ class CodeParser:
if handler := self.handlers.get(type(node)): # type: ignore
handler(node) # type: ignore
def parse_imports(self, node: Union[ast.Import, ast.ImportFrom]) -> None:
"""
Extracts "imports" from the code, including aliases.
@ -149,12 +154,16 @@ class CodeParser:
# Handle cases where the type is not found in the constructed environment
pass
func = CallableCodeDetails(
name=node.name, doc=ast.get_docstring(node), args=[], body=[], return_type=return_type or get_data_type()
name=node.name,
doc=ast.get_docstring(node),
args= self.parse_function_args(node),
body= self.parse_function_body(node),
return_type=return_type or get_data_type(),
has_return=self.parse_return_statement(node),
)
func.args = self.parse_function_args(node)
func.body = self.parse_function_body(node)
return func.model_dump()
@ -230,6 +239,14 @@ class CodeParser:
"""
return [ast.unparse(line) for line in node.body]
def parse_return_statement(self, node: ast.FunctionDef) -> bool:
"""
Parses the return statement of a function or method node.
"""
return any(isinstance(n, ast.Return) for n in node.body)
def parse_assign(self, stmt):
"""
Parses an Assign statement and returns a dictionary

View file

@ -5,13 +5,11 @@ from uuid import UUID
import yaml
from cachetools import TTLCache, cachedmethod
from fastapi import HTTPException
from langflow.field_typing.constants import CUSTOM_COMPONENT_SUPPORTED_TYPES
from langflow.interface.custom.component import Component
from langflow.interface.custom.directory_reader import DirectoryReader
from langflow.interface.custom.utils import (
extract_inner_type_from_generic_alias,
extract_union_types_from_generic_alias,
)
extract_union_types_from_generic_alias)
from langflow.services.database.models.flow import Flow
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_credential_service, get_db_service
@ -29,14 +27,13 @@ class CustomComponent(Component):
repr_value: Optional[Any] = ""
user_id: Optional[Union[UUID, str]] = None
status: Optional[Any] = None
_tree: Optional[dict] = None
def __init__(self, **data):
self.cache = TTLCache(maxsize=1024, ttl=60)
super().__init__(**data)
@property
def return_type_valid_list(self):
return list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())
def custom_repr(self):
if self.repr_value == "":
@ -78,8 +75,11 @@ class CustomComponent(Component):
def validate(self) -> bool:
return self._class_template_validation(self.code) if self.code else False
def get_code_tree(self, code: str):
return super().get_code_tree(code)
@property
def tree(self):
return self.get_code_tree(self.code)
@property
def get_function_entrypoint_args(self) -> list:
@ -108,9 +108,10 @@ class CustomComponent(Component):
def get_build_method(self):
if not self.code:
return []
tree = self.get_code_tree(self.code)
component_classes = [cls for cls in tree["classes"] if self.code_class_base_inheritance in cls["bases"]]
component_classes = [cls for cls in self.tree["classes"] if self.code_class_base_inheritance in cls["bases"]]
if not component_classes:
return []
@ -123,16 +124,19 @@ class CustomComponent(Component):
if not build_methods:
return []
return build_methods[0]
@property
def get_function_entrypoint_return_type(self) -> List[Any]:
build_method = self.get_build_method()
if not build_method:
return build_method
return_type = build_method["return_type"]
if not return_type:
return []
elif not build_method["has_return"]:
return []
return_type = build_method["return_type"]
# If list or List is in the return type, then we remove it and return the inner type
if hasattr(return_type, "__origin__") and return_type.__origin__ in [list, List]:
return_type = extract_inner_type_from_generic_alias(return_type)
@ -141,24 +145,23 @@ class CustomComponent(Component):
if not hasattr(return_type, "__origin__") or return_type.__origin__ != Union:
if isinstance(return_type, list):
return return_type
return [return_type] # if return_type in self.return_type_valid_list else []
return [return_type]
# If the return type is a Union, then we need to parse itx
return_type = extract_union_types_from_generic_alias(return_type)
# return [item for item in return_type if item in self.return_type_valid_list]
return return_type
@property
def get_main_class_name(self):
if not self.code:
return ""
tree = self.get_code_tree(self.code)
base_name = self.code_class_base_inheritance
method_name = self.function_entrypoint_name
classes = []
for item in tree.get("classes", []):
for item in self.tree.get("classes", []):
if base_name in item["bases"]:
method_names = [method["name"] for method in item["methods"]]
if method_name in method_names:
@ -171,11 +174,11 @@ class CustomComponent(Component):
def build_template_config(self):
if not self.code:
return {}
tree = self.get_code_tree(self.code)
attributes = [
main_class["attributes"]
for main_class in tree.get("classes", [])
for main_class in self.tree.get("classes", [])
if main_class["name"] == self.get_main_class_name
]
# Get just the first item
@ -219,7 +222,8 @@ class CustomComponent(Component):
return validate.create_function(self.code, self.function_entrypoint_name)
async def load_flow(self, flow_id: str, tweaks: Optional[dict] = None) -> Any:
from langflow.processing.process import build_sorted_vertices, process_tweaks
from langflow.processing.process import (build_sorted_vertices,
process_tweaks)
db_service = get_db_service()
with session_getter(db_service) as session:

View file

@ -26,3 +26,4 @@ class CallableCodeDetails(BaseModel):
args: list
body: list
return_type: Optional[Any] = None
has_return: bool = False

View file

@ -10,7 +10,6 @@ from cachetools import LRUCache, cached
from fastapi import HTTPException
from langflow.interface.agents.base import agent_creator
from langflow.interface.chains.base import chain_creator
from langflow.interface.custom.base import custom_component_creator
from langflow.interface.custom.custom_component import CustomComponent
from langflow.interface.custom.directory_reader import DirectoryReader
from langflow.interface.custom.utils import extract_inner_type
@ -30,7 +29,8 @@ from langflow.interface.vector_store.base import vectorstore_creator
from langflow.interface.wrappers.base import wrapper_creator
from langflow.template.field.base import TemplateField
from langflow.template.frontend_node.constants import CLASSES_TO_REMOVE
from langflow.template.frontend_node.custom_components import CustomComponentFrontendNode
from langflow.template.frontend_node.custom_components import \
CustomComponentFrontendNode
from langflow.utils.util import get_base_classes
from loguru import logger
@ -69,7 +69,6 @@ def build_langchain_types_dict(): # sourcery skip: dict-assign-update-to-union
utility_creator,
output_parser_creator,
retriever_creator,
custom_component_creator,
]
all_types = {}

View file

@ -5,17 +5,16 @@ from typing import TYPE_CHECKING
import sqlalchemy as sa
from alembic import command, util
from alembic.config import Config
from loguru import logger
from sqlalchemy import inspect
from sqlalchemy.exc import OperationalError
from sqlmodel import Session, SQLModel, create_engine
from langflow.services.base import Service
from langflow.services.database import models # noqa
from langflow.services.database.models.user.crud import get_user_by_username
from langflow.services.database.utils import Result, TableResults
from langflow.services.deps import get_settings_service
from langflow.services.utils import teardown_superuser
from loguru import logger
from sqlalchemy import inspect
from sqlalchemy.exc import OperationalError
from sqlmodel import Session, SQLModel, create_engine
if TYPE_CHECKING:
from sqlalchemy.engine import Engine
@ -118,9 +117,10 @@ class DatabaseService(Service):
alembic_cfg.set_main_option("script_location", str(self.script_location))
alembic_cfg.set_main_option("sqlalchemy.url", self.database_url)
command.stamp(alembic_cfg, "head")
# command.upgrade(alembic_cfg, "head")
logger.info("Alembic initialized")
def run_migrations(self):
def run_migrations(self, fix=False):
# First we need to check if alembic has been initialized
# If not, we need to initialize it
# if not self.script_location.exists(): # this is not the correct way to check if alembic has been initialized
@ -151,16 +151,32 @@ class DatabaseService(Service):
if isinstance(exc, util.exc.CommandError) or isinstance(exc, util.exc.AutogenerateDiffsDetected):
command.upgrade(alembic_cfg, "head")
# We should check the schema health after running migrations
try:
command.check(alembic_cfg)
except util.exc.AutogenerateDiffsDetected:
# downgrade to base and upgrade again
logger.warning("Autogenerate diffs detected, downgrading and upgrading")
command.downgrade(alembic_cfg, "-1")
# wait for the database to be ready
time.sleep(5)
command.upgrade(alembic_cfg, "head")
except util.exc.AutogenerateDiffsDetected as exc:
logger.exception("AutogenerateDiffsDetected: {exc}")
raise RuntimeError("Something went wrong running migrations. Please, run `langflow migration --fix`")
if fix:
self.try_downgrade_upgrade_until_success(alembic_cfg)
def try_downgrade_upgrade_until_success(self, alembic_cfg, retries=5):
# Try -1 then head, if it fails, try -2 then head, etc.
# until we reach the number of retries
for i in range(1, retries + 1):
try:
command.check(alembic_cfg)
break
except util.exc.AutogenerateDiffsDetected as exc:
# downgrade to base and upgrade again
logger.warning(f"AutogenerateDiffsDetected: {exc}")
command.downgrade(alembic_cfg, f"-{i}")
# wait for the database to be ready
time.sleep(3)
command.upgrade(alembic_cfg, "head")
def run_migrations_test(self):
# This method is used for testing purposes only

View file

@ -1,17 +1,18 @@
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING
from loguru import logger
from contextlib import contextmanager
from alembic.util.exc import CommandError
from loguru import logger
from sqlmodel import Session
if TYPE_CHECKING:
from langflow.services.database.service import DatabaseService
def initialize_database():
def initialize_database(fix_migration: bool = False):
logger.debug("Initializing database")
from langflow.services import service_manager, ServiceType
from langflow.services import ServiceType, service_manager
database_service: "DatabaseService" = service_manager.get(ServiceType.DATABASE_SERVICE)
try:
@ -28,7 +29,7 @@ def initialize_database():
logger.error(f"Error checking schema health: {exc}")
raise RuntimeError("Error checking schema health") from exc
try:
database_service.run_migrations()
database_service.run_migrations(fix=fix_migration)
except CommandError as exc:
# if "overlaps with other requested revisions" or "Can't locate revision identified by"
# are not in the exception, we can't handle it
@ -47,8 +48,9 @@ def initialize_database():
# if the exception involves tables already existing
# we can ignore it
if "already exists" not in str(exc):
logger.error(f"Error running migrations: {exc}")
raise RuntimeError("Error running migrations") from exc
logger.error(exc)
raise exc
logger.debug("Database initialized")

View file

@ -2,7 +2,8 @@ from langflow.services.auth.utils import create_super_user, verify_password
from langflow.services.database.utils import initialize_database
from langflow.services.manager import service_manager
from langflow.services.schema import ServiceType
from langflow.services.settings.constants import DEFAULT_SUPERUSER, DEFAULT_SUPERUSER_PASSWORD
from langflow.services.settings.constants import (DEFAULT_SUPERUSER,
DEFAULT_SUPERUSER_PASSWORD)
from loguru import logger
from sqlmodel import Session
@ -13,12 +14,13 @@ def get_factories_and_deps():
from langflow.services.auth import factory as auth_factory
from langflow.services.cache import factory as cache_factory
from langflow.services.chat import factory as chat_factory
from langflow.services.credentials import factory as credentials_factory
from langflow.services.database import factory as database_factory
from langflow.services.session import factory as session_service_factory # type: ignore
from langflow.services.session import \
factory as session_service_factory # type: ignore
from langflow.services.settings import factory as settings_factory
from langflow.services.store import factory as store_factory
from langflow.services.task import factory as task_factory
from langflow.services.credentials import factory as credentials_factory
return [
(settings_factory.SettingsServiceFactory(), []),
@ -171,7 +173,8 @@ def initialize_session_service():
Initialize the session manager.
"""
from langflow.services.cache import factory as cache_factory
from langflow.services.session import factory as session_service_factory # type: ignore
from langflow.services.session import \
factory as session_service_factory # type: ignore
initialize_settings_service()
@ -183,7 +186,7 @@ def initialize_session_service():
)
def initialize_services():
def initialize_services(fix_migration: bool = False):
"""
Initialize all the services needed.
"""
@ -197,7 +200,11 @@ def initialize_services():
# Test cache connection
service_manager.get(ServiceType.CACHE_SERVICE)
# Setup the superuser
initialize_database()
try:
initialize_database(fix_migration=fix_migration)
except Exception as exc:
logger.exception(exc)
raise exc
setup_superuser(service_manager.get(ServiceType.SETTINGS_SERVICE), next(get_session()))
try:
get_db_service().migrate_flows_if_auto_login()

View file

@ -4,6 +4,8 @@ import importlib
from types import FunctionType
from typing import Dict
from langflow.field_typing.constants import CUSTOM_COMPONENT_SUPPORTED_TYPES
def add_type_ignores():
if not hasattr(ast, "TypeIgnore"):
@ -266,29 +268,7 @@ def get_default_imports(code_string):
# Add more imports from the typing module as needed
}
langflow_imports = [
"AgentExecutor",
"BaseChatMemory",
"BaseLanguageModel",
"BaseLLM",
"BaseLoader",
"BaseMemory",
"BaseOutputParser",
"BasePromptTemplate",
"BaseRetriever",
"Callable",
"Chain",
"ChatPromptTemplate",
"Data",
"Document",
"Embeddings",
"NestedDict",
"Object",
"PromptTemplate",
"TextSplitter",
"Tool",
"VectorStore",
]
langflow_imports = list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())
necessary_imports = find_names_in_code(code_string, langflow_imports)
langflow_module = importlib.import_module("langflow.field_typing")
default_imports.update({name: getattr(langflow_module, name) for name in necessary_imports})