Merge remote-tracking branch 'origin/dev' into FileLoaders
This commit is contained in:
commit
a899d8a081
420 changed files with 37922 additions and 8145 deletions
|
|
@ -11,4 +11,4 @@ RUN rm *.whl
|
|||
|
||||
EXPOSE 80
|
||||
|
||||
CMD [ "uvicorn", "--host", "0.0.0.0", "--port", "80", "langflow.backend.app:app" ]
|
||||
CMD [ "uvicorn", "--host", "0.0.0.0", "--port", "7860", "--factory", "langflow.main:create_app" ]
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
from importlib import metadata
|
||||
from langflow.cache import cache_manager
|
||||
|
||||
# Deactivate cache manager for now
|
||||
# from langflow.services.cache import cache_service
|
||||
from langflow.processing.process import load_flow_from_json
|
||||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
|
||||
|
|
@ -10,4 +12,4 @@ except metadata.PackageNotFoundError:
|
|||
__version__ = ""
|
||||
del metadata # optional, avoids polluting the results of dir(__package__)
|
||||
|
||||
__all__ = ["load_flow_from_json", "cache_manager", "CustomComponent"]
|
||||
__all__ = ["load_flow_from_json", "cache_service", "CustomComponent"]
|
||||
|
|
|
|||
|
|
@ -1,23 +1,69 @@
|
|||
import platform
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
import httpx
|
||||
from langflow.utils.util import get_number_of_workers
|
||||
from multiprocess import Process # type: ignore
|
||||
import platform
|
||||
import webbrowser
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import socket
|
||||
from rich.panel import Panel
|
||||
|
||||
import httpx
|
||||
import typer
|
||||
from dotenv import load_dotenv
|
||||
from langflow.main import setup_app
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.getters import get_db_service, get_settings_service
|
||||
from langflow.services.utils import initialize_services, initialize_settings_service
|
||||
from langflow.utils.logger import configure, logger
|
||||
from multiprocess import Process, cpu_count # type: ignore
|
||||
from rich import box
|
||||
from rich import print as rprint
|
||||
import typer
|
||||
from langflow.main import setup_app
|
||||
from langflow.settings import settings
|
||||
from langflow.utils.logger import configure, logger
|
||||
import webbrowser
|
||||
from dotenv import load_dotenv
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
|
||||
app = typer.Typer()
|
||||
console = Console()
|
||||
|
||||
app = typer.Typer(no_args_is_help=True)
|
||||
|
||||
|
||||
def get_number_of_workers(workers=None):
|
||||
if workers == -1 or workers is None:
|
||||
workers = (cpu_count() * 2) + 1
|
||||
logger.debug(f"Number of workers: {workers}")
|
||||
return workers
|
||||
|
||||
|
||||
def display_results(results):
|
||||
"""
|
||||
Display the results of the migration.
|
||||
"""
|
||||
for table_results in results:
|
||||
table = Table(title=f"Migration {table_results.table_name}")
|
||||
table.add_column("Name")
|
||||
table.add_column("Type")
|
||||
table.add_column("Status")
|
||||
|
||||
for result in table_results.results:
|
||||
status = "Success" if result.success else "Failure"
|
||||
color = "green" if result.success else "red"
|
||||
table.add_row(result.name, result.type, f"[{color}]{status}[/{color}]")
|
||||
|
||||
console.print(table)
|
||||
console.print() # Print a new line
|
||||
|
||||
|
||||
def set_var_for_macos_issue():
|
||||
# OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
|
||||
# we need to set this var is we are running on MacOS
|
||||
# otherwise we get an error when running gunicorn
|
||||
|
||||
if platform.system() in ["Darwin"]:
|
||||
import os
|
||||
|
||||
os.environ["OBJC_DISABLE_INITIALIZE_FORK_SAFETY"] = "YES"
|
||||
# https://stackoverflow.com/questions/75747888/uwsgi-segmentation-fault-with-flask-python-app-behind-nginx-after-running-for-2 # noqa
|
||||
os.environ["no_proxy"] = "*" # to avoid error with gunicorn
|
||||
logger.debug("Set OBJC_DISABLE_INITIALIZE_FORK_SAFETY to YES to avoid error")
|
||||
|
||||
|
||||
def update_settings(
|
||||
|
|
@ -30,74 +76,29 @@ def update_settings(
|
|||
"""Update the settings from a config file."""
|
||||
|
||||
# Check for database_url in the environment variables
|
||||
|
||||
initialize_settings_service()
|
||||
settings_service = get_settings_service()
|
||||
if config:
|
||||
logger.debug(f"Loading settings from {config}")
|
||||
settings.update_from_yaml(config, dev=dev)
|
||||
settings_service.settings.update_from_yaml(config, dev=dev)
|
||||
if remove_api_keys:
|
||||
logger.debug(f"Setting remove_api_keys to {remove_api_keys}")
|
||||
settings.update_settings(REMOVE_API_KEYS=remove_api_keys)
|
||||
settings_service.settings.update_settings(REMOVE_API_KEYS=remove_api_keys)
|
||||
if cache:
|
||||
logger.debug(f"Setting cache to {cache}")
|
||||
settings.update_settings(CACHE=cache)
|
||||
settings_service.settings.update_settings(CACHE=cache)
|
||||
if components_path:
|
||||
logger.debug(f"Adding component path {components_path}")
|
||||
settings.update_settings(COMPONENTS_PATH=components_path)
|
||||
|
||||
|
||||
def serve_on_jcloud():
|
||||
"""
|
||||
Deploy Langflow server on Jina AI Cloud
|
||||
"""
|
||||
import asyncio
|
||||
from importlib.metadata import version as mod_version
|
||||
|
||||
import click
|
||||
|
||||
try:
|
||||
from lcserve.__main__ import serve_on_jcloud # type: ignore
|
||||
except ImportError:
|
||||
click.secho(
|
||||
"🚨 Please install langchain-serve to deploy Langflow server on Jina AI Cloud "
|
||||
"using `pip install langchain-serve`",
|
||||
fg="red",
|
||||
)
|
||||
return
|
||||
|
||||
app_name = "langflow.lcserve:app"
|
||||
app_dir = str(Path(__file__).parent)
|
||||
version = mod_version("langflow")
|
||||
base_image = "jinaai+docker://deepankarm/langflow"
|
||||
|
||||
click.echo("🚀 Deploying Langflow server on Jina AI Cloud")
|
||||
app_id = asyncio.run(
|
||||
serve_on_jcloud(
|
||||
fastapi_app_str=app_name,
|
||||
app_dir=app_dir,
|
||||
uses=f"{base_image}:{version}",
|
||||
name="langflow",
|
||||
)
|
||||
)
|
||||
click.secho(
|
||||
"🎉 Langflow server successfully deployed on Jina AI Cloud 🎉", fg="green"
|
||||
)
|
||||
click.secho(
|
||||
"🔗 Click on the link to open the server (please allow ~1-2 minutes for the server to startup): ",
|
||||
nl=False,
|
||||
fg="green",
|
||||
)
|
||||
click.secho(f"https://{app_id}.wolf.jina.ai/", fg="blue")
|
||||
click.secho("📖 Read more about managing the server: ", nl=False, fg="green")
|
||||
click.secho("https://github.com/jina-ai/langchain-serve", fg="blue")
|
||||
settings_service.settings.update_settings(COMPONENTS_PATH=components_path)
|
||||
|
||||
|
||||
@app.command()
|
||||
def serve(
|
||||
def run(
|
||||
host: str = typer.Option(
|
||||
"127.0.0.1", help="Host to bind the server to.", envvar="LANGFLOW_HOST"
|
||||
),
|
||||
workers: int = typer.Option(
|
||||
2, help="Number of worker processes.", envvar="LANGFLOW_WORKERS"
|
||||
1, help="Number of worker processes.", envvar="LANGFLOW_WORKERS"
|
||||
),
|
||||
timeout: int = typer.Option(300, help="Worker timeout in seconds."),
|
||||
port: int = typer.Option(7860, help="Port to listen on.", envvar="LANGFLOW_PORT"),
|
||||
|
|
@ -106,7 +107,9 @@ def serve(
|
|||
help="Path to the directory containing custom components.",
|
||||
envvar="LANGFLOW_COMPONENTS_PATH",
|
||||
),
|
||||
config: str = typer.Option("config.yaml", help="Path to the configuration file."),
|
||||
config: str = typer.Option(
|
||||
Path(__file__).parent / "config.yaml", help="Path to the configuration file."
|
||||
),
|
||||
# .env file param
|
||||
env_file: Path = typer.Option(
|
||||
None, help="Path to the .env file containing environment variables."
|
||||
|
|
@ -122,7 +125,6 @@ def serve(
|
|||
help="Type of cache to use. (InMemoryCache, SQLiteCache)",
|
||||
default=None,
|
||||
),
|
||||
jcloud: bool = typer.Option(False, help="Deploy on Jina AI Cloud"),
|
||||
dev: bool = typer.Option(False, help="Run in development mode (may contain bugs)"),
|
||||
# This variable does not work but is set by the .env file
|
||||
# and works with Pydantic
|
||||
|
|
@ -146,17 +148,22 @@ def serve(
|
|||
help="Remove API keys from the projects saved in the database.",
|
||||
envvar="LANGFLOW_REMOVE_API_KEYS",
|
||||
),
|
||||
backend_only: bool = typer.Option(
|
||||
False,
|
||||
help="Run only the backend server without the frontend.",
|
||||
envvar="LANGFLOW_BACKEND_ONLY",
|
||||
),
|
||||
):
|
||||
"""
|
||||
Run the Langflow server.
|
||||
Run the Langflow.
|
||||
"""
|
||||
|
||||
set_var_for_macos_issue()
|
||||
# override env variables with .env file
|
||||
|
||||
if env_file:
|
||||
load_dotenv(env_file, override=True)
|
||||
|
||||
if jcloud:
|
||||
return serve_on_jcloud()
|
||||
|
||||
configure(log_level=log_level, log_file=log_file)
|
||||
update_settings(
|
||||
config,
|
||||
|
|
@ -167,7 +174,7 @@ def serve(
|
|||
)
|
||||
# create path object if path is provided
|
||||
static_files_dir: Optional[Path] = Path(path) if path else None
|
||||
app = setup_app(static_files_dir=static_files_dir)
|
||||
app = setup_app(static_files_dir=static_files_dir, backend_only=backend_only)
|
||||
# check if port is being used
|
||||
if is_port_in_use(port, host):
|
||||
port = get_free_port(port)
|
||||
|
|
@ -175,10 +182,13 @@ def serve(
|
|||
options = {
|
||||
"bind": f"{host}:{port}",
|
||||
"workers": get_number_of_workers(workers),
|
||||
"worker_class": "uvicorn.workers.UvicornWorker",
|
||||
"timeout": timeout,
|
||||
}
|
||||
|
||||
# Define an env variable to know if we are just testing the server
|
||||
if "pytest" in sys.modules:
|
||||
return
|
||||
|
||||
if platform.system() in ["Windows"]:
|
||||
# Run using uvicorn on MacOS and Windows
|
||||
# Windows doesn't support gunicorn
|
||||
|
|
@ -299,6 +309,53 @@ def run_langflow(host, port, log_level, options, app):
|
|||
sys.exit(1)
|
||||
|
||||
|
||||
@app.command()
|
||||
def superuser(
|
||||
username: str = typer.Option(..., prompt=True, help="Username for the superuser."),
|
||||
password: str = typer.Option(
|
||||
..., prompt=True, hide_input=True, help="Password for the superuser."
|
||||
),
|
||||
log_level: str = typer.Option(
|
||||
"critical", help="Logging level.", envvar="LANGFLOW_LOG_LEVEL"
|
||||
),
|
||||
):
|
||||
"""
|
||||
Create a superuser.
|
||||
"""
|
||||
configure(log_level=log_level)
|
||||
initialize_services()
|
||||
db_service = get_db_service()
|
||||
with session_getter(db_service) as session:
|
||||
from langflow.services.auth.utils import create_super_user
|
||||
|
||||
if create_super_user(db=session, username=username, password=password):
|
||||
# Verify that the superuser was created
|
||||
from langflow.services.database.models.user.user import User
|
||||
|
||||
user: User = session.query(User).filter(User.username == username).first()
|
||||
if user is None or not user.is_superuser:
|
||||
typer.echo("Superuser creation failed.")
|
||||
return
|
||||
|
||||
typer.echo("Superuser created successfully.")
|
||||
|
||||
else:
|
||||
typer.echo("Superuser creation failed.")
|
||||
|
||||
|
||||
@app.command()
|
||||
def migration(test: bool = typer.Option(True, help="Run migrations in test mode.")):
|
||||
"""
|
||||
Run or test migrations.
|
||||
"""
|
||||
initialize_services()
|
||||
db_service = get_db_service()
|
||||
if not test:
|
||||
db_service.run_migrations()
|
||||
results = db_service.run_migrations_test()
|
||||
display_results(results)
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
|
|
|||
113
src/backend/langflow/alembic.ini
Normal file
113
src/backend/langflow/alembic.ini
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts
|
||||
script_location = alembic
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||
# for all available tokens
|
||||
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory.
|
||||
prepend_sys_path = .
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the python-dateutil library that can be
|
||||
# installed by adding `alembic[tz]` to the pip requirements
|
||||
# string value is passed to dateutil.tz.gettz()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the
|
||||
# "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; This defaults
|
||||
# to alembic/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path.
|
||||
# The path separator used here should be the separator specified by "version_path_separator" below.
|
||||
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
|
||||
|
||||
# version path separator; As mentioned above, this is the character used to split
|
||||
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
|
||||
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
|
||||
# Valid values for version_path_separator are:
|
||||
#
|
||||
# version_path_separator = :
|
||||
# version_path_separator = ;
|
||||
# version_path_separator = space
|
||||
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
# new in Alembic version 1.10
|
||||
# recursive_version_locations = false
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
# This is the path to the db in the root of the project.
|
||||
# When the user runs the Langflow the database url will
|
||||
# be set dinamically.
|
||||
sqlalchemy.url = sqlite:///../../../langflow.db
|
||||
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
|
||||
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||
# hooks = black
|
||||
# black.type = console_scripts
|
||||
# black.entrypoint = black
|
||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Logging configuration
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
1
src/backend/langflow/alembic/README
Normal file
1
src/backend/langflow/alembic/README
Normal file
|
|
@ -0,0 +1 @@
|
|||
Generic single-database configuration.
|
||||
81
src/backend/langflow/alembic/env.py
Normal file
81
src/backend/langflow/alembic/env.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
|
||||
from alembic import context
|
||||
|
||||
from langflow.services.database.manager import SQLModel
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# add your model's MetaData object here
|
||||
# for 'autogenerate' support
|
||||
# from myapp import mymodel
|
||||
# target_metadata = mymodel.Base.metadata
|
||||
target_metadata = SQLModel.metadata
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
|
||||
This configures the context with just a URL
|
||||
and not an Engine, though an Engine is acceptable
|
||||
here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
|
||||
Calls to context.execute() here emit the given string to the
|
||||
script output.
|
||||
|
||||
"""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
render_as_batch=True,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode.
|
||||
|
||||
In this scenario we need to create an Engine
|
||||
and associate a connection with the context.
|
||||
|
||||
"""
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection, target_metadata=target_metadata, render_as_batch=True
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
27
src/backend/langflow/alembic/script.py.mako
Normal file
27
src/backend/langflow/alembic/script.py.mako
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
|
|
@ -0,0 +1,177 @@
|
|||
"""Adds tables
|
||||
|
||||
Revision ID: 260dbcc8b680
|
||||
Revises:
|
||||
Create Date: 2023-08-27 19:49:02.681355
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "260dbcc8b680"
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
|
||||
conn = op.get_bind()
|
||||
inspector = Inspector.from_engine(conn)
|
||||
# List existing tables
|
||||
existing_tables = inspector.get_table_names()
|
||||
# Drop 'flowstyle' table if it exists
|
||||
# and other related indices
|
||||
if "flowstyle" in existing_tables:
|
||||
op.drop_table("flowstyle")
|
||||
if "ix_flowstyle_flow_id" in [
|
||||
index["name"] for index in inspector.get_indexes("flowstyle")
|
||||
]:
|
||||
op.drop_index("ix_flowstyle_flow_id", table_name="flowstyle")
|
||||
|
||||
existing_indices_flow = []
|
||||
existing_fks_flow = []
|
||||
if "flow" in existing_tables:
|
||||
existing_indices_flow = [
|
||||
index["name"] for index in inspector.get_indexes("flow")
|
||||
]
|
||||
# Existing foreign keys for the 'flow' table, if it exists
|
||||
existing_fks_flow = [
|
||||
fk["referred_table"] + "." + fk["referred_columns"][0]
|
||||
for fk in inspector.get_foreign_keys("flow")
|
||||
]
|
||||
# Now check if the columns user_id exists in the 'flow' table
|
||||
# If it does not exist, we need to create the foreign key
|
||||
|
||||
if "user" not in existing_tables:
|
||||
op.create_table(
|
||||
"user",
|
||||
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.Column("username", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("password", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("is_superuser", sa.Boolean(), nullable=False),
|
||||
sa.Column("create_at", sa.DateTime(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(), nullable=False),
|
||||
sa.Column("last_login_at", sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("id"),
|
||||
)
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.create_index(
|
||||
batch_op.f("ix_user_username"), ["username"], unique=True
|
||||
)
|
||||
|
||||
if "apikey" not in existing_tables:
|
||||
op.create_table(
|
||||
"apikey",
|
||||
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(), nullable=False),
|
||||
sa.Column("last_used_at", sa.DateTime(), nullable=True),
|
||||
sa.Column("total_uses", sa.Integer(), nullable=False, default=0),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False, default=True),
|
||||
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.Column("api_key", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("id"),
|
||||
)
|
||||
with op.batch_alter_table("apikey", schema=None) as batch_op:
|
||||
batch_op.create_index(
|
||||
batch_op.f("ix_apikey_api_key"), ["api_key"], unique=True
|
||||
)
|
||||
batch_op.create_index(batch_op.f("ix_apikey_name"), ["name"], unique=False)
|
||||
batch_op.create_index(
|
||||
batch_op.f("ix_apikey_user_id"), ["user_id"], unique=False
|
||||
)
|
||||
if "flow" not in existing_tables:
|
||||
op.create_table(
|
||||
"flow",
|
||||
sa.Column("data", sa.JSON(), nullable=True),
|
||||
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("id"),
|
||||
)
|
||||
# Conditionally create indices for 'flow' table
|
||||
# if _alembic_tmp_flow exists, then we need to drop it first
|
||||
# This is to deal with SQLite not being able to ROLLBACK
|
||||
# for some unknown reason
|
||||
if "_alembic_tmp_flow" in existing_tables:
|
||||
op.drop_table("_alembic_tmp_flow")
|
||||
with op.batch_alter_table("flow", schema=None) as batch_op:
|
||||
flow_columns = [col["name"] for col in inspector.get_columns("flow")]
|
||||
if "user_id" not in flow_columns:
|
||||
batch_op.add_column(
|
||||
sa.Column(
|
||||
"user_id",
|
||||
sqlmodel.sql.sqltypes.GUID(),
|
||||
nullable=True, # This should be False, but we need to allow NULL values for now
|
||||
)
|
||||
)
|
||||
if "user.id" not in existing_fks_flow:
|
||||
batch_op.create_foreign_key("fk_flow_user_id", "user", ["user_id"], ["id"])
|
||||
if "ix_flow_description" not in existing_indices_flow:
|
||||
batch_op.create_index(
|
||||
batch_op.f("ix_flow_description"), ["description"], unique=False
|
||||
)
|
||||
if "ix_flow_name" not in existing_indices_flow:
|
||||
batch_op.create_index(batch_op.f("ix_flow_name"), ["name"], unique=False)
|
||||
with op.batch_alter_table("flow", schema=None) as batch_op:
|
||||
if "ix_flow_user_id" not in existing_indices_flow:
|
||||
batch_op.create_index(
|
||||
batch_op.f("ix_flow_user_id"), ["user_id"], unique=False
|
||||
)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
|
||||
conn = op.get_bind()
|
||||
inspector = Inspector.from_engine(conn)
|
||||
# List existing tables
|
||||
existing_tables = inspector.get_table_names()
|
||||
if "flow" in existing_tables:
|
||||
with op.batch_alter_table("flow", schema=None) as batch_op:
|
||||
batch_op.drop_index(batch_op.f("ix_flow_user_id"))
|
||||
batch_op.drop_index(batch_op.f("ix_flow_name"))
|
||||
batch_op.drop_index(batch_op.f("ix_flow_description"))
|
||||
|
||||
op.drop_table("flow")
|
||||
if "apikey" in existing_tables:
|
||||
with op.batch_alter_table("apikey", schema=None) as batch_op:
|
||||
batch_op.drop_index(batch_op.f("ix_apikey_user_id"))
|
||||
batch_op.drop_index(batch_op.f("ix_apikey_name"))
|
||||
batch_op.drop_index(batch_op.f("ix_apikey_api_key"))
|
||||
|
||||
op.drop_table("apikey")
|
||||
if "user" in existing_tables:
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.drop_index(batch_op.f("ix_user_username"))
|
||||
|
||||
op.drop_table("user")
|
||||
|
||||
if "flowstyle" in existing_tables:
|
||||
op.drop_table("flowstyle")
|
||||
|
||||
if "component" in existing_tables:
|
||||
op.drop_table("component")
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
"""Add profile-image column
|
||||
|
||||
Revision ID: 67cc006d50bf
|
||||
Revises: 260dbcc8b680
|
||||
Create Date: 2023-09-08 07:36:13.387318
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "67cc006d50bf"
|
||||
down_revision: Union[str, None] = "260dbcc8b680"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
inspector = Inspector.from_engine(conn)
|
||||
if "user" in inspector.get_table_names() and "profile_image" not in [
|
||||
column["name"] for column in inspector.get_columns("user")
|
||||
]:
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.add_column(
|
||||
sa.Column(
|
||||
"profile_image", sqlmodel.sql.sqltypes.AutoString(), nullable=True
|
||||
)
|
||||
)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
inspector = Inspector.from_engine(conn)
|
||||
if "user" in inspector.get_table_names() and "profile_image" in [
|
||||
column["name"] for column in inspector.get_columns("user")
|
||||
]:
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.drop_column("profile_image")
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -0,0 +1,85 @@
|
|||
"""Change columns to be nullable
|
||||
|
||||
Revision ID: eb5866d51fd2
|
||||
Revises: 67cc006d50bf
|
||||
Create Date: 2023-10-04 10:18:25.640458
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import exc
|
||||
import sqlmodel # noqa: F401
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "eb5866d51fd2"
|
||||
down_revision: Union[str, None] = "67cc006d50bf"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
connection = op.get_bind()
|
||||
try:
|
||||
op.drop_table("flowstyle")
|
||||
with op.batch_alter_table("component", schema=None) as batch_op:
|
||||
batch_op.drop_index("ix_component_frontend_node_id")
|
||||
batch_op.drop_index("ix_component_name")
|
||||
except exc.SQLAlchemyError:
|
||||
connection.execute("ROLLBACK")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
op.drop_table("component")
|
||||
except exc.SQLAlchemyError:
|
||||
connection.execute("ROLLBACK")
|
||||
except Exception:
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
try:
|
||||
op.create_table(
|
||||
"component",
|
||||
sa.Column("id", sa.CHAR(length=32), nullable=False),
|
||||
sa.Column("frontend_node_id", sa.CHAR(length=32), nullable=False),
|
||||
sa.Column("name", sa.VARCHAR(), nullable=False),
|
||||
sa.Column("description", sa.VARCHAR(), nullable=True),
|
||||
sa.Column("python_code", sa.VARCHAR(), nullable=True),
|
||||
sa.Column("return_type", sa.VARCHAR(), nullable=True),
|
||||
sa.Column("is_disabled", sa.BOOLEAN(), nullable=False),
|
||||
sa.Column("is_read_only", sa.BOOLEAN(), nullable=False),
|
||||
sa.Column("create_at", sa.DATETIME(), nullable=False),
|
||||
sa.Column("update_at", sa.DATETIME(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
with op.batch_alter_table("component", schema=None) as batch_op:
|
||||
batch_op.create_index("ix_component_name", ["name"], unique=False)
|
||||
batch_op.create_index(
|
||||
"ix_component_frontend_node_id", ["frontend_node_id"], unique=False
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
op.create_table(
|
||||
"flowstyle",
|
||||
sa.Column("color", sa.VARCHAR(), nullable=False),
|
||||
sa.Column("emoji", sa.VARCHAR(), nullable=False),
|
||||
sa.Column("flow_id", sa.CHAR(length=32), nullable=True),
|
||||
sa.Column("id", sa.CHAR(length=32), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["flow_id"],
|
||||
["flow.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("id"),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -5,8 +5,10 @@ from langflow.api.v1 import (
|
|||
endpoints_router,
|
||||
validate_router,
|
||||
flows_router,
|
||||
flow_styles_router,
|
||||
component_router,
|
||||
users_router,
|
||||
api_key_router,
|
||||
login_router,
|
||||
)
|
||||
|
||||
router = APIRouter(
|
||||
|
|
@ -17,4 +19,6 @@ router.include_router(endpoints_router)
|
|||
router.include_router(validate_router)
|
||||
router.include_router(component_router)
|
||||
router.include_router(flows_router)
|
||||
router.include_router(flow_styles_router)
|
||||
router.include_router(users_router)
|
||||
router.include_router(api_key_router)
|
||||
router.include_router(login_router)
|
||||
|
|
|
|||
|
|
@ -59,33 +59,6 @@ def build_input_keys_response(langchain_object, artifacts):
|
|||
return input_keys_response
|
||||
|
||||
|
||||
def merge_nested_dicts(dict1, dict2):
|
||||
for key, value in dict2.items():
|
||||
if isinstance(value, dict) and isinstance(dict1.get(key), dict):
|
||||
dict1[key] = merge_nested_dicts(dict1[key], value)
|
||||
else:
|
||||
dict1[key] = value
|
||||
return dict1
|
||||
|
||||
|
||||
def merge_nested_dicts_with_renaming(dict1, dict2):
|
||||
for key, value in dict2.items():
|
||||
if (
|
||||
key in dict1
|
||||
and isinstance(value, dict)
|
||||
and isinstance(dict1.get(key), dict)
|
||||
):
|
||||
for sub_key, sub_value in value.items():
|
||||
if sub_key in dict1[key]:
|
||||
new_key = get_new_key(dict1[key], sub_key)
|
||||
dict1[key][new_key] = sub_value
|
||||
else:
|
||||
dict1[key][sub_key] = sub_value
|
||||
else:
|
||||
dict1[key] = value
|
||||
return dict1
|
||||
|
||||
|
||||
def get_new_key(dictionary, original_key):
|
||||
counter = 1
|
||||
new_key = original_key + " (" + str(counter) + ")"
|
||||
|
|
|
|||
|
|
@ -2,8 +2,10 @@ from langflow.api.v1.endpoints import router as endpoints_router
|
|||
from langflow.api.v1.validate import router as validate_router
|
||||
from langflow.api.v1.chat import router as chat_router
|
||||
from langflow.api.v1.flows import router as flows_router
|
||||
from langflow.api.v1.flow_styles import router as flow_styles_router
|
||||
from langflow.api.v1.components import router as component_router
|
||||
from langflow.api.v1.users import router as users_router
|
||||
from langflow.api.v1.api_key import router as api_key_router
|
||||
from langflow.api.v1.login import router as login_router
|
||||
|
||||
__all__ = [
|
||||
"chat_router",
|
||||
|
|
@ -11,5 +13,7 @@ __all__ = [
|
|||
"component_router",
|
||||
"validate_router",
|
||||
"flows_router",
|
||||
"flow_styles_router",
|
||||
"users_router",
|
||||
"api_key_router",
|
||||
"login_router",
|
||||
]
|
||||
|
|
|
|||
61
src/backend/langflow/api/v1/api_key.py
Normal file
61
src/backend/langflow/api/v1/api_key.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
from uuid import UUID
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from langflow.api.v1.schemas import ApiKeysResponse
|
||||
from langflow.services.auth.utils import get_current_active_user
|
||||
from langflow.services.database.models.api_key.api_key import (
|
||||
ApiKeyCreate,
|
||||
UnmaskedApiKeyRead,
|
||||
)
|
||||
|
||||
# Assuming you have these methods in your service layer
|
||||
from langflow.services.database.models.api_key.crud import (
|
||||
get_api_keys,
|
||||
create_api_key,
|
||||
delete_api_key,
|
||||
)
|
||||
from langflow.services.database.models.user.user import User
|
||||
from langflow.services.getters import get_session
|
||||
from sqlmodel import Session
|
||||
|
||||
|
||||
router = APIRouter(tags=["APIKey"], prefix="/api_key")
|
||||
|
||||
|
||||
@router.get("/", response_model=ApiKeysResponse)
|
||||
def get_api_keys_route(
|
||||
db: Session = Depends(get_session),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
try:
|
||||
user_id = current_user.id
|
||||
keys = get_api_keys(db, user_id)
|
||||
|
||||
return ApiKeysResponse(total_count=len(keys), user_id=user_id, api_keys=keys)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@router.post("/", response_model=UnmaskedApiKeyRead)
|
||||
def create_api_key_route(
|
||||
req: ApiKeyCreate,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
try:
|
||||
user_id = current_user.id
|
||||
return create_api_key(db, req, user_id=user_id)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.delete("/{api_key_id}")
|
||||
def delete_api_key_route(
|
||||
api_key_id: UUID,
|
||||
current_user=Depends(get_current_active_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
try:
|
||||
delete_api_key(db, api_key_id)
|
||||
return {"detail": "API Key deleted"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
from typing import Optional
|
||||
from langflow.template.frontend_node.base import FrontendNode
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
|
|
@ -20,7 +21,8 @@ class FrontendNodeRequest(FrontendNode):
|
|||
class ValidatePromptRequest(BaseModel):
|
||||
name: str
|
||||
template: str
|
||||
frontend_node: FrontendNodeRequest
|
||||
# optional for tweak call
|
||||
frontend_node: Optional[FrontendNodeRequest] = None
|
||||
|
||||
|
||||
# Build ValidationResponse class for {"imports": {"errors": []}, "function": {"errors": []}}
|
||||
|
|
@ -39,7 +41,8 @@ class CodeValidationResponse(BaseModel):
|
|||
|
||||
class PromptValidationResponse(BaseModel):
|
||||
input_variables: list
|
||||
frontend_node: FrontendNodeRequest
|
||||
# object return for tweak call
|
||||
frontend_node: Optional[FrontendNodeRequest] = None
|
||||
|
||||
|
||||
INVALID_CHARACTERS = {
|
||||
|
|
|
|||
|
|
@ -1,55 +1,33 @@
|
|||
import asyncio
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
|
||||
|
||||
from langflow.api.v1.schemas import ChatResponse
|
||||
from langflow.api.v1.schemas import ChatResponse, PromptResponse
|
||||
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
from fastapi import WebSocket
|
||||
from typing import Any, Dict, List, Optional
|
||||
from langflow.services.getters import get_chat_service
|
||||
|
||||
|
||||
from langchain.schema import AgentAction, LLMResult, AgentFinish
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.utils.util import remove_ansi_escape_codes
|
||||
from langchain.schema import AgentAction, AgentFinish
|
||||
from loguru import logger
|
||||
|
||||
|
||||
# https://github.com/hwchase17/chat-langchain/blob/master/callback.py
|
||||
class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
|
||||
"""Callback handler for streaming LLM responses."""
|
||||
|
||||
def __init__(self, websocket: WebSocket):
|
||||
self.websocket = websocket
|
||||
def __init__(self, client_id: str):
|
||||
self.chat_service = get_chat_service()
|
||||
self.client_id = client_id
|
||||
self.websocket = self.chat_service.active_connections[self.client_id]
|
||||
|
||||
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
resp = ChatResponse(message=token, type="stream", intermediate_steps="")
|
||||
await self.websocket.send_json(resp.dict())
|
||||
|
||||
async def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> Any:
|
||||
"""Run when LLM starts running."""
|
||||
|
||||
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
|
||||
"""Run when LLM ends running."""
|
||||
|
||||
async def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> Any:
|
||||
"""Run when LLM errors."""
|
||||
|
||||
async def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> Any:
|
||||
"""Run when chain starts running."""
|
||||
|
||||
async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
|
||||
"""Run when chain ends running."""
|
||||
|
||||
async def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> Any:
|
||||
"""Run when chain errors."""
|
||||
|
||||
async def on_tool_start(
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> Any:
|
||||
|
|
@ -95,8 +73,14 @@ class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
|
|||
logger.error(f"Error sending response: {exc}")
|
||||
|
||||
async def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> Any:
|
||||
self,
|
||||
error: BaseException,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
|
||||
async def on_text(self, text: str, **kwargs: Any) -> Any:
|
||||
|
|
@ -104,6 +88,14 @@ class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
|
|||
# This runs when first sending the prompt
|
||||
# to the LLM, adding it will send the final prompt
|
||||
# to the frontend
|
||||
if "Prompt after formatting" in text:
|
||||
text = text.replace("Prompt after formatting:\n", "")
|
||||
text = remove_ansi_escape_codes(text)
|
||||
resp = PromptResponse(
|
||||
prompt=text,
|
||||
)
|
||||
await self.websocket.send_json(resp.dict())
|
||||
self.chat_service.chat_history.add_message(self.client_id, resp)
|
||||
|
||||
async def on_agent_action(self, action: AgentAction, **kwargs: Any):
|
||||
log = f"Thought: {action.log}"
|
||||
|
|
@ -131,8 +123,10 @@ class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
|
|||
class StreamingLLMCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback handler for streaming LLM responses."""
|
||||
|
||||
def __init__(self, websocket):
|
||||
self.websocket = websocket
|
||||
def __init__(self, client_id: str):
|
||||
self.chat_service = get_chat_service()
|
||||
self.client_id = client_id
|
||||
self.websocket = self.chat_service.active_connections[self.client_id]
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
resp = ChatResponse(message=token, type="stream", intermediate_steps="")
|
||||
|
|
|
|||
|
|
@ -1,57 +1,99 @@
|
|||
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketException, status
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
HTTPException,
|
||||
Query,
|
||||
WebSocket,
|
||||
WebSocketException,
|
||||
status,
|
||||
)
|
||||
from fastapi.responses import StreamingResponse
|
||||
from langflow.api.utils import build_input_keys_response
|
||||
from langflow.api.v1.schemas import BuildStatus, BuiltResponse, InitResponse, StreamData
|
||||
|
||||
from langflow.chat.manager import ChatManager
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.utils.logger import logger
|
||||
from cachetools import LRUCache
|
||||
from langflow.services.auth.utils import get_current_active_user, get_current_user
|
||||
from langflow.services.cache.utils import update_build_status
|
||||
from loguru import logger
|
||||
from langflow.services.getters import get_chat_service, get_session, get_cache_service
|
||||
from sqlmodel import Session
|
||||
from langflow.services.chat.manager import ChatService
|
||||
from langflow.services.cache.manager import BaseCacheService
|
||||
|
||||
|
||||
router = APIRouter(tags=["Chat"])
|
||||
chat_manager = ChatManager()
|
||||
flow_data_store: LRUCache = LRUCache(maxsize=10)
|
||||
|
||||
|
||||
@router.websocket("/chat/{client_id}")
|
||||
async def chat(client_id: str, websocket: WebSocket):
|
||||
async def chat(
|
||||
client_id: str,
|
||||
websocket: WebSocket,
|
||||
token: str = Query(...),
|
||||
db: Session = Depends(get_session),
|
||||
chat_service: "ChatService" = Depends(get_chat_service),
|
||||
):
|
||||
"""Websocket endpoint for chat."""
|
||||
try:
|
||||
if client_id in chat_manager.in_memory_cache:
|
||||
await chat_manager.handle_websocket(client_id, websocket)
|
||||
await websocket.accept()
|
||||
user = await get_current_user(token, db)
|
||||
if not user:
|
||||
await websocket.close(
|
||||
code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized"
|
||||
)
|
||||
if not user.is_active:
|
||||
await websocket.close(
|
||||
code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized"
|
||||
)
|
||||
|
||||
if client_id in chat_service.cache_service:
|
||||
await chat_service.handle_websocket(client_id, websocket)
|
||||
else:
|
||||
# We accept the connection but close it immediately
|
||||
# if the flow is not built yet
|
||||
await websocket.accept()
|
||||
message = "Please, build the flow before sending messages"
|
||||
await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=message)
|
||||
except WebSocketException as exc:
|
||||
logger.error(f"Websocket error: {exc}")
|
||||
logger.error(f"Websocket exrror: {exc}")
|
||||
await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=str(exc))
|
||||
except Exception as exc:
|
||||
logger.error(f"Error in chat websocket: {exc}")
|
||||
messsage = exc.detail if isinstance(exc, HTTPException) else str(exc)
|
||||
if "Could not validate credentials" in str(exc):
|
||||
await websocket.close(
|
||||
code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized"
|
||||
)
|
||||
else:
|
||||
await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=messsage)
|
||||
|
||||
|
||||
@router.post("/build/init/{flow_id}", response_model=InitResponse, status_code=201)
|
||||
async def init_build(graph_data: dict, flow_id: str):
|
||||
async def init_build(
|
||||
graph_data: dict,
|
||||
flow_id: str,
|
||||
current_user=Depends(get_current_active_user),
|
||||
chat_service: "ChatService" = Depends(get_chat_service),
|
||||
cache_service: "BaseCacheService" = Depends(get_cache_service),
|
||||
):
|
||||
"""Initialize the build by storing graph data and returning a unique session ID."""
|
||||
|
||||
try:
|
||||
if flow_id is None:
|
||||
raise ValueError("No ID provided")
|
||||
# Check if already building
|
||||
if (
|
||||
flow_id in flow_data_store
|
||||
and flow_data_store[flow_id]["status"] == BuildStatus.IN_PROGRESS
|
||||
flow_id in cache_service
|
||||
and isinstance(cache_service[flow_id], dict)
|
||||
and cache_service[flow_id].get("status") == BuildStatus.IN_PROGRESS
|
||||
):
|
||||
return InitResponse(flowId=flow_id)
|
||||
|
||||
# Delete from cache if already exists
|
||||
if flow_id in chat_manager.in_memory_cache:
|
||||
with chat_manager.in_memory_cache._lock:
|
||||
chat_manager.in_memory_cache.delete(flow_id)
|
||||
logger.debug(f"Deleted flow {flow_id} from cache")
|
||||
flow_data_store[flow_id] = {
|
||||
if flow_id in chat_service.cache_service:
|
||||
chat_service.cache_service.delete(flow_id)
|
||||
logger.debug(f"Deleted flow {flow_id} from cache")
|
||||
cache_service[flow_id] = {
|
||||
"graph_data": graph_data,
|
||||
"status": BuildStatus.STARTED,
|
||||
"user_id": current_user.id,
|
||||
}
|
||||
|
||||
return InitResponse(flowId=flow_id)
|
||||
|
|
@ -61,12 +103,14 @@ async def init_build(graph_data: dict, flow_id: str):
|
|||
|
||||
|
||||
@router.get("/build/{flow_id}/status", response_model=BuiltResponse)
|
||||
async def build_status(flow_id: str):
|
||||
"""Check the flow_id is in the flow_data_store."""
|
||||
async def build_status(
|
||||
flow_id: str, cache_service: "BaseCacheService" = Depends(get_cache_service)
|
||||
):
|
||||
"""Check the flow_id is in the cache_service."""
|
||||
try:
|
||||
built = (
|
||||
flow_id in flow_data_store
|
||||
and flow_data_store[flow_id]["status"] == BuildStatus.SUCCESS
|
||||
flow_id in cache_service
|
||||
and cache_service[flow_id]["status"] == BuildStatus.SUCCESS
|
||||
)
|
||||
|
||||
return BuiltResponse(
|
||||
|
|
@ -79,24 +123,29 @@ async def build_status(flow_id: str):
|
|||
|
||||
|
||||
@router.get("/build/stream/{flow_id}", response_class=StreamingResponse)
|
||||
async def stream_build(flow_id: str):
|
||||
async def stream_build(
|
||||
flow_id: str,
|
||||
chat_service: "ChatService" = Depends(get_chat_service),
|
||||
cache_service: "BaseCacheService" = Depends(get_cache_service),
|
||||
):
|
||||
"""Stream the build process based on stored flow data."""
|
||||
|
||||
async def event_stream(flow_id):
|
||||
final_response = {"end_of_stream": True}
|
||||
artifacts = {}
|
||||
try:
|
||||
if flow_id not in flow_data_store:
|
||||
if flow_id not in cache_service:
|
||||
error_message = "Invalid session ID"
|
||||
yield str(StreamData(event="error", data={"error": error_message}))
|
||||
return
|
||||
|
||||
if flow_data_store[flow_id].get("status") == BuildStatus.IN_PROGRESS:
|
||||
if cache_service[flow_id].get("status") == BuildStatus.IN_PROGRESS:
|
||||
error_message = "Already building"
|
||||
yield str(StreamData(event="error", data={"error": error_message}))
|
||||
return
|
||||
|
||||
graph_data = flow_data_store[flow_id].get("graph_data")
|
||||
graph_data = cache_service[flow_id].get("graph_data")
|
||||
cache_service[flow_id]["user_id"]
|
||||
|
||||
if not graph_data:
|
||||
error_message = "No data provided"
|
||||
|
|
@ -109,7 +158,7 @@ async def stream_build(flow_id: str):
|
|||
graph = Graph.from_payload(graph_data)
|
||||
|
||||
number_of_nodes = len(graph.nodes)
|
||||
flow_data_store[flow_id]["status"] = BuildStatus.IN_PROGRESS
|
||||
update_build_status(cache_service, flow_id, BuildStatus.IN_PROGRESS)
|
||||
|
||||
for i, vertex in enumerate(graph.generator_build(), 1):
|
||||
try:
|
||||
|
|
@ -117,7 +166,10 @@ async def stream_build(flow_id: str):
|
|||
"log": f"Building node {vertex.vertex_type}",
|
||||
}
|
||||
yield str(StreamData(event="log", data=log_dict))
|
||||
vertex.build()
|
||||
if vertex.is_task:
|
||||
vertex = try_running_celery_task(vertex)
|
||||
else:
|
||||
vertex.build()
|
||||
params = vertex._built_object_repr()
|
||||
valid = True
|
||||
logger.debug(f"Building node {str(vertex.vertex_type)}")
|
||||
|
|
@ -133,16 +185,20 @@ async def stream_build(flow_id: str):
|
|||
logger.exception(exc)
|
||||
params = str(exc)
|
||||
valid = False
|
||||
flow_data_store[flow_id]["status"] = BuildStatus.FAILURE
|
||||
update_build_status(cache_service, flow_id, BuildStatus.FAILURE)
|
||||
|
||||
response = {
|
||||
"valid": valid,
|
||||
"params": params,
|
||||
"id": vertex.id,
|
||||
"progress": round(i / number_of_nodes, 2),
|
||||
}
|
||||
vertex_id = (
|
||||
vertex.parent_node_id if vertex.parent_is_top_level else vertex.id
|
||||
)
|
||||
if vertex_id in graph.top_level_nodes:
|
||||
response = {
|
||||
"valid": valid,
|
||||
"params": params,
|
||||
"id": vertex_id,
|
||||
"progress": round(i / number_of_nodes, 2),
|
||||
}
|
||||
|
||||
yield str(StreamData(event="message", data=response))
|
||||
yield str(StreamData(event="message", data=response))
|
||||
|
||||
langchain_object = graph.build()
|
||||
# Now we need to check the input_keys to send them to the client
|
||||
|
|
@ -157,15 +213,15 @@ async def stream_build(flow_id: str):
|
|||
"handle_keys": [],
|
||||
}
|
||||
yield str(StreamData(event="message", data=input_keys_response))
|
||||
|
||||
chat_manager.set_cache(flow_id, langchain_object)
|
||||
chat_service.set_cache(flow_id, langchain_object)
|
||||
# We need to reset the chat history
|
||||
chat_manager.chat_history.empty_history(flow_id)
|
||||
flow_data_store[flow_id]["status"] = BuildStatus.SUCCESS
|
||||
chat_service.chat_history.empty_history(flow_id)
|
||||
update_build_status(cache_service, flow_id, BuildStatus.SUCCESS)
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
logger.error("Error while building the flow: %s", exc)
|
||||
flow_data_store[flow_id]["status"] = BuildStatus.FAILURE
|
||||
|
||||
update_build_status(cache_service, flow_id, BuildStatus.FAILURE)
|
||||
yield str(StreamData(event="error", data={"error": str(exc)}))
|
||||
finally:
|
||||
yield str(StreamData(event="message", data=final_response))
|
||||
|
|
@ -175,3 +231,19 @@ async def stream_build(flow_id: str):
|
|||
except Exception as exc:
|
||||
logger.error(f"Error streaming build: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
def try_running_celery_task(vertex):
|
||||
# Try running the task in celery
|
||||
# and set the task_id to the local vertex
|
||||
# if it fails, run the task locally
|
||||
try:
|
||||
from langflow.worker import build_vertex
|
||||
|
||||
task = build_vertex.delay(vertex)
|
||||
vertex.task_id = task.id
|
||||
except Exception as exc:
|
||||
logger.debug(f"Error running task in celery: {exc}")
|
||||
vertex.task_id = None
|
||||
vertex.build()
|
||||
return vertex
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from datetime import timezone
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
from langflow.database.models.component import Component, ComponentModel
|
||||
from langflow.database.base import get_session
|
||||
from langflow.services.database.models.component import Component, ComponentModel
|
||||
from langflow.services.getters import get_session
|
||||
from sqlmodel import Session, select
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
|
|
|||
|
|
@ -1,89 +1,103 @@
|
|||
from http import HTTPStatus
|
||||
from typing import Annotated, Optional
|
||||
from typing import Annotated, Optional, Union
|
||||
from langflow.services.auth.utils import api_key_security, get_current_active_user
|
||||
|
||||
from langflow.cache.utils import save_uploaded_file
|
||||
from langflow.database.models.flow import Flow
|
||||
|
||||
from langflow.services.cache.utils import save_uploaded_file
|
||||
from langflow.services.database.models.flow import Flow
|
||||
from langflow.processing.process import process_graph_cached, process_tweaks
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.settings import settings
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, Body
|
||||
|
||||
from langflow.services.database.models.user.user import User
|
||||
from langflow.services.getters import (
|
||||
get_session_service,
|
||||
get_settings_service,
|
||||
get_task_service,
|
||||
)
|
||||
from loguru import logger
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, Body, status
|
||||
import sqlalchemy as sa
|
||||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
|
||||
|
||||
from langflow.api.v1.schemas import (
|
||||
ProcessResponse,
|
||||
TaskResponse,
|
||||
TaskStatusResponse,
|
||||
UploadFileResponse,
|
||||
CustomComponentCode,
|
||||
)
|
||||
|
||||
from langflow.api.utils import merge_nested_dicts_with_renaming
|
||||
|
||||
from langflow.interface.types import (
|
||||
build_langchain_types_dict,
|
||||
build_langchain_template_custom_component,
|
||||
build_langchain_custom_component_list_from_path,
|
||||
)
|
||||
from langflow.services.getters import get_session
|
||||
|
||||
try:
|
||||
from langflow.worker import process_graph_cached_task
|
||||
except ImportError:
|
||||
|
||||
def process_graph_cached_task(*args, **kwargs):
|
||||
raise NotImplementedError("Celery is not installed")
|
||||
|
||||
|
||||
from langflow.database.base import get_session
|
||||
from sqlmodel import Session
|
||||
|
||||
|
||||
from langflow.services.task.manager import TaskService
|
||||
|
||||
# build router
|
||||
router = APIRouter(tags=["Base"])
|
||||
|
||||
|
||||
@router.get("/all")
|
||||
def get_all():
|
||||
@router.get("/all", dependencies=[Depends(get_current_active_user)])
|
||||
def get_all(
|
||||
settings_service=Depends(get_settings_service),
|
||||
):
|
||||
from langflow.interface.types import get_all_types_dict
|
||||
|
||||
logger.debug("Building langchain types dict")
|
||||
native_components = build_langchain_types_dict()
|
||||
# custom_components is a list of dicts
|
||||
# need to merge all the keys into one dict
|
||||
custom_components_from_file = {}
|
||||
if settings.COMPONENTS_PATH:
|
||||
logger.info(f"Building custom components from {settings.COMPONENTS_PATH}")
|
||||
|
||||
custom_component_dicts = []
|
||||
processed_paths = []
|
||||
for path in settings.COMPONENTS_PATH:
|
||||
if str(path) in processed_paths:
|
||||
continue
|
||||
custom_component_dict = build_langchain_custom_component_list_from_path(
|
||||
str(path)
|
||||
)
|
||||
custom_component_dicts.append(custom_component_dict)
|
||||
processed_paths.append(str(path))
|
||||
|
||||
logger.info(f"Loading {len(custom_component_dicts)} category(ies)")
|
||||
for custom_component_dict in custom_component_dicts:
|
||||
logger.debug(
|
||||
{key: len(value) for key, value in custom_component_dict.items()}
|
||||
)
|
||||
custom_components_from_file = merge_nested_dicts_with_renaming(
|
||||
custom_components_from_file, custom_component_dict
|
||||
)
|
||||
|
||||
return merge_nested_dicts_with_renaming(
|
||||
native_components, custom_components_from_file
|
||||
)
|
||||
try:
|
||||
return get_all_types_dict(settings_service)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
||||
|
||||
# For backwards compatibility we will keep the old endpoint
|
||||
@router.post("/predict/{flow_id}", response_model=ProcessResponse)
|
||||
@router.post("/process/{flow_id}", response_model=ProcessResponse)
|
||||
async def process_flow(
|
||||
@router.post(
|
||||
"/predict/{flow_id}",
|
||||
response_model=ProcessResponse,
|
||||
dependencies=[Depends(api_key_security)],
|
||||
)
|
||||
@router.post(
|
||||
"/process/{flow_id}",
|
||||
response_model=ProcessResponse,
|
||||
)
|
||||
async def process(
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
flow_id: str,
|
||||
inputs: Optional[dict] = None,
|
||||
tweaks: Optional[dict] = None,
|
||||
clear_cache: Annotated[bool, Body(embed=True)] = False, # noqa: F821
|
||||
session: Session = Depends(get_session),
|
||||
session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821
|
||||
task_service: "TaskService" = Depends(get_task_service),
|
||||
api_key_user: User = Depends(api_key_security),
|
||||
sync: Annotated[bool, Body(embed=True)] = True, # noqa: F821
|
||||
):
|
||||
"""
|
||||
Endpoint to process an input with a given flow_id.
|
||||
"""
|
||||
|
||||
try:
|
||||
flow = session.get(Flow, flow_id)
|
||||
if api_key_user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API Key",
|
||||
)
|
||||
|
||||
# Get the flow that matches the flow_id and belongs to the user
|
||||
flow = (
|
||||
session.query(Flow)
|
||||
.filter(Flow.id == flow_id)
|
||||
.filter(Flow.user_id == api_key_user.id)
|
||||
.first()
|
||||
)
|
||||
if flow is None:
|
||||
raise ValueError(f"Flow {flow_id} not found")
|
||||
|
||||
|
|
@ -95,16 +109,94 @@ async def process_flow(
|
|||
graph_data = process_tweaks(graph_data, tweaks)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error processing tweaks: {exc}")
|
||||
response = process_graph_cached(graph_data, inputs, clear_cache)
|
||||
if sync:
|
||||
task_id, result = await task_service.launch_and_await_task(
|
||||
process_graph_cached_task
|
||||
if task_service.use_celery
|
||||
else process_graph_cached,
|
||||
graph_data,
|
||||
inputs,
|
||||
clear_cache,
|
||||
session_id,
|
||||
)
|
||||
if isinstance(result, dict) and "result" in result:
|
||||
task_result = result["result"]
|
||||
session_id = result["session_id"]
|
||||
elif hasattr(result, "result") and hasattr(result, "session_id"):
|
||||
task_result = result.result
|
||||
|
||||
session_id = result.session_id
|
||||
else:
|
||||
logger.warning(
|
||||
"This is an experimental feature and may not work as expected."
|
||||
"Please report any issues to our GitHub repository."
|
||||
)
|
||||
if session_id is None:
|
||||
# Generate a session ID
|
||||
session_id = get_session_service().generate_key(
|
||||
session_id=session_id, data_graph=graph_data
|
||||
)
|
||||
task_id, task = await task_service.launch_task(
|
||||
process_graph_cached_task
|
||||
if task_service.use_celery
|
||||
else process_graph_cached,
|
||||
graph_data,
|
||||
inputs,
|
||||
clear_cache,
|
||||
session_id,
|
||||
)
|
||||
task_result = task.status
|
||||
|
||||
if task_id:
|
||||
task_response = TaskResponse(id=task_id, href=f"api/v1/task/{task_id}")
|
||||
else:
|
||||
task_response = None
|
||||
|
||||
return ProcessResponse(
|
||||
result=response,
|
||||
result=task_result,
|
||||
task=task_response,
|
||||
session_id=session_id,
|
||||
backend=task_service.backend_name,
|
||||
)
|
||||
except sa.exc.StatementError as exc:
|
||||
# StatementError('(builtins.ValueError) badly formed hexadecimal UUID string')
|
||||
if "badly formed hexadecimal UUID string" in str(exc):
|
||||
# This means the Flow ID is not a valid UUID which means it can't find the flow
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
) from exc
|
||||
except ValueError as exc:
|
||||
if f"Flow {flow_id} not found" in str(exc):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
) from exc
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)
|
||||
) from exc
|
||||
except Exception as e:
|
||||
# Log stack trace
|
||||
logger.exception(e)
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.get("/task/{task_id}", response_model=TaskStatusResponse)
|
||||
async def get_task_status(task_id: str):
|
||||
task_service = get_task_service()
|
||||
task = task_service.get_task(task_id)
|
||||
result = None
|
||||
if task.ready():
|
||||
result = task.result
|
||||
if isinstance(result, dict) and "result" in result:
|
||||
result = result["result"]
|
||||
elif hasattr(result, "result"):
|
||||
result = result.result
|
||||
|
||||
if task is None:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
return TaskStatusResponse(status=task.status, result=result)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/upload/{flow_id}",
|
||||
response_model=UploadFileResponse,
|
||||
|
|
@ -113,7 +205,7 @@ async def process_flow(
|
|||
async def create_upload_file(file: UploadFile, flow_id: str):
|
||||
# Cache file
|
||||
try:
|
||||
file_path = save_uploaded_file(file.file, folder_name=flow_id)
|
||||
file_path = save_uploaded_file(file, folder_name=flow_id)
|
||||
|
||||
return UploadFileResponse(
|
||||
flowId=flow_id,
|
||||
|
|
@ -136,6 +228,10 @@ def get_version():
|
|||
async def custom_component(
|
||||
raw_code: CustomComponentCode,
|
||||
):
|
||||
from langflow.interface.types import (
|
||||
build_langchain_template_custom_component,
|
||||
)
|
||||
|
||||
extractor = CustomComponent(code=raw_code.code)
|
||||
extractor.is_check_valid()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,83 +0,0 @@
|
|||
from uuid import UUID
|
||||
from langflow.database.models.flow_style import (
|
||||
FlowStyle,
|
||||
FlowStyleCreate,
|
||||
FlowStyleRead,
|
||||
FlowStyleUpdate,
|
||||
)
|
||||
from langflow.database.base import get_session
|
||||
from sqlmodel import Session, select
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
|
||||
# build router
|
||||
router = APIRouter(prefix="/flow_styles", tags=["FlowStyles"])
|
||||
|
||||
# FlowStyleCreate:
|
||||
# class FlowStyleBase(SQLModel):
|
||||
# color: str = Field(index=True)
|
||||
# emoji: str = Field(index=False)
|
||||
# flow_id: UUID = Field(default=None, foreign_key="flow.id")
|
||||
|
||||
|
||||
@router.post("/", response_model=FlowStyleRead)
|
||||
def create_flow_style(
|
||||
*, session: Session = Depends(get_session), flow_style: FlowStyleCreate
|
||||
):
|
||||
"""Create a new flow_style."""
|
||||
db_flow_style = FlowStyle.from_orm(flow_style)
|
||||
session.add(db_flow_style)
|
||||
session.commit()
|
||||
session.refresh(db_flow_style)
|
||||
return db_flow_style
|
||||
|
||||
|
||||
@router.get("/", response_model=list[FlowStyleRead])
|
||||
def read_flow_styles(*, session: Session = Depends(get_session)):
|
||||
"""Read all flows."""
|
||||
try:
|
||||
flows = session.exec(select(FlowStyle)).all()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
return flows
|
||||
|
||||
|
||||
@router.get("/{flow_styles_id}", response_model=FlowStyleRead)
|
||||
def read_flow_style(*, session: Session = Depends(get_session), flow_styles_id: UUID):
|
||||
"""Read a flow_style."""
|
||||
if flow_style := session.get(FlowStyle, flow_styles_id):
|
||||
return flow_style
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="FlowStyle not found")
|
||||
|
||||
|
||||
@router.patch("/{flow_style_id}", response_model=FlowStyleRead)
|
||||
def update_flow_style(
|
||||
*,
|
||||
session: Session = Depends(get_session),
|
||||
flow_style_id: UUID,
|
||||
flow_style: FlowStyleUpdate,
|
||||
):
|
||||
"""Update a flow_style."""
|
||||
db_flow_style = session.get(FlowStyle, flow_style_id)
|
||||
if not db_flow_style:
|
||||
raise HTTPException(status_code=404, detail="FlowStyle not found")
|
||||
flow_data = flow_style.dict(exclude_unset=True)
|
||||
for key, value in flow_data.items():
|
||||
if hasattr(db_flow_style, key) and value is not None:
|
||||
setattr(db_flow_style, key, value)
|
||||
session.add(db_flow_style)
|
||||
session.commit()
|
||||
session.refresh(db_flow_style)
|
||||
return db_flow_style
|
||||
|
||||
|
||||
@router.delete("/{flow_id}")
|
||||
def delete_flow_style(*, session: Session = Depends(get_session), flow_id: UUID):
|
||||
"""Delete a flow_style."""
|
||||
flow_style = session.get(FlowStyle, flow_id)
|
||||
if not flow_style:
|
||||
raise HTTPException(status_code=404, detail="FlowStyle not found")
|
||||
session.delete(flow_style)
|
||||
session.commit()
|
||||
return {"message": "FlowStyle deleted successfully"}
|
||||
|
|
@ -1,19 +1,21 @@
|
|||
from typing import List
|
||||
from uuid import UUID
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from langflow.settings import settings
|
||||
|
||||
from langflow.api.utils import remove_api_keys
|
||||
from langflow.api.v1.schemas import FlowListCreate, FlowListRead
|
||||
from langflow.database.models.flow import (
|
||||
from langflow.services.auth.utils import get_current_active_user
|
||||
from langflow.services.database.models.flow import (
|
||||
Flow,
|
||||
FlowCreate,
|
||||
FlowRead,
|
||||
FlowReadWithStyle,
|
||||
FlowUpdate,
|
||||
)
|
||||
from langflow.database.base import get_session
|
||||
from langflow.services.database.models.user.user import User
|
||||
from langflow.services.getters import get_session
|
||||
from langflow.services.getters import get_settings_service
|
||||
import orjson
|
||||
from sqlmodel import Session, select
|
||||
from sqlmodel import Session
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from fastapi import File, UploadFile
|
||||
|
|
@ -23,48 +25,77 @@ router = APIRouter(prefix="/flows", tags=["Flows"])
|
|||
|
||||
|
||||
@router.post("/", response_model=FlowRead, status_code=201)
|
||||
def create_flow(*, session: Session = Depends(get_session), flow: FlowCreate):
|
||||
def create_flow(
|
||||
*,
|
||||
session: Session = Depends(get_session),
|
||||
flow: FlowCreate,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
"""Create a new flow."""
|
||||
if flow.user_id is None:
|
||||
flow.user_id = current_user.id
|
||||
|
||||
db_flow = Flow.from_orm(flow)
|
||||
|
||||
session.add(db_flow)
|
||||
session.commit()
|
||||
session.refresh(db_flow)
|
||||
return db_flow
|
||||
|
||||
|
||||
@router.get("/", response_model=list[FlowReadWithStyle], status_code=200)
|
||||
def read_flows(*, session: Session = Depends(get_session)):
|
||||
@router.get("/", response_model=list[FlowRead], status_code=200)
|
||||
def read_flows(
|
||||
*,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
"""Read all flows."""
|
||||
try:
|
||||
flows = session.exec(select(Flow)).all()
|
||||
flows = current_user.flows
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
return [jsonable_encoder(flow) for flow in flows]
|
||||
|
||||
|
||||
@router.get("/{flow_id}", response_model=FlowReadWithStyle, status_code=200)
|
||||
def read_flow(*, session: Session = Depends(get_session), flow_id: UUID):
|
||||
@router.get("/{flow_id}", response_model=FlowRead, status_code=200)
|
||||
def read_flow(
|
||||
*,
|
||||
session: Session = Depends(get_session),
|
||||
flow_id: UUID,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
"""Read a flow."""
|
||||
if flow := session.get(Flow, flow_id):
|
||||
return flow
|
||||
if user_flow := (
|
||||
session.query(Flow)
|
||||
.filter(Flow.id == flow_id)
|
||||
.filter(Flow.user_id == current_user.id)
|
||||
.first()
|
||||
):
|
||||
return user_flow
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Flow not found")
|
||||
|
||||
|
||||
@router.patch("/{flow_id}", response_model=FlowRead, status_code=200)
|
||||
def update_flow(
|
||||
*, session: Session = Depends(get_session), flow_id: UUID, flow: FlowUpdate
|
||||
*,
|
||||
session: Session = Depends(get_session),
|
||||
flow_id: UUID,
|
||||
flow: FlowUpdate,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
settings_service=Depends(get_settings_service),
|
||||
):
|
||||
"""Update a flow."""
|
||||
|
||||
db_flow = session.get(Flow, flow_id)
|
||||
db_flow = read_flow(session=session, flow_id=flow_id, current_user=current_user)
|
||||
if not db_flow:
|
||||
raise HTTPException(status_code=404, detail="Flow not found")
|
||||
flow_data = flow.dict(exclude_unset=True)
|
||||
if settings.REMOVE_API_KEYS:
|
||||
if settings_service.settings.REMOVE_API_KEYS:
|
||||
flow_data = remove_api_keys(flow_data)
|
||||
for key, value in flow_data.items():
|
||||
setattr(db_flow, key, value)
|
||||
if value is not None:
|
||||
setattr(db_flow, key, value)
|
||||
session.add(db_flow)
|
||||
session.commit()
|
||||
session.refresh(db_flow)
|
||||
|
|
@ -72,9 +103,14 @@ def update_flow(
|
|||
|
||||
|
||||
@router.delete("/{flow_id}", status_code=200)
|
||||
def delete_flow(*, session: Session = Depends(get_session), flow_id: UUID):
|
||||
def delete_flow(
|
||||
*,
|
||||
session: Session = Depends(get_session),
|
||||
flow_id: UUID,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
"""Delete a flow."""
|
||||
flow = session.get(Flow, flow_id)
|
||||
flow = read_flow(session=session, flow_id=flow_id, current_user=current_user)
|
||||
if not flow:
|
||||
raise HTTPException(status_code=404, detail="Flow not found")
|
||||
session.delete(flow)
|
||||
|
|
@ -86,10 +122,16 @@ def delete_flow(*, session: Session = Depends(get_session), flow_id: UUID):
|
|||
|
||||
|
||||
@router.post("/batch/", response_model=List[FlowRead], status_code=201)
|
||||
def create_flows(*, session: Session = Depends(get_session), flow_list: FlowListCreate):
|
||||
def create_flows(
|
||||
*,
|
||||
session: Session = Depends(get_session),
|
||||
flow_list: FlowListCreate,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
"""Create multiple new flows."""
|
||||
db_flows = []
|
||||
for flow in flow_list.flows:
|
||||
flow.user_id = current_user.id
|
||||
db_flow = Flow.from_orm(flow)
|
||||
session.add(db_flow)
|
||||
db_flows.append(db_flow)
|
||||
|
|
@ -101,7 +143,10 @@ def create_flows(*, session: Session = Depends(get_session), flow_list: FlowList
|
|||
|
||||
@router.post("/upload/", response_model=List[FlowRead], status_code=201)
|
||||
async def upload_file(
|
||||
*, session: Session = Depends(get_session), file: UploadFile = File(...)
|
||||
*,
|
||||
session: Session = Depends(get_session),
|
||||
file: UploadFile = File(...),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
"""Upload flows from a file."""
|
||||
contents = await file.read()
|
||||
|
|
@ -110,11 +155,19 @@ async def upload_file(
|
|||
flow_list = FlowListCreate(**data)
|
||||
else:
|
||||
flow_list = FlowListCreate(flows=[FlowCreate(**flow) for flow in data])
|
||||
return create_flows(session=session, flow_list=flow_list)
|
||||
# Now we set the user_id for all flows
|
||||
for flow in flow_list.flows:
|
||||
flow.user_id = current_user.id
|
||||
|
||||
return create_flows(session=session, flow_list=flow_list, current_user=current_user)
|
||||
|
||||
|
||||
@router.get("/download/", response_model=FlowListRead, status_code=200)
|
||||
async def download_file(*, session: Session = Depends(get_session)):
|
||||
async def download_file(
|
||||
*,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
"""Download all flows as a file."""
|
||||
flows = read_flows(session=session)
|
||||
flows = read_flows(session=session, current_user=current_user)
|
||||
return FlowListRead(flows=flows)
|
||||
|
|
|
|||
73
src/backend/langflow/api/v1/login.py
Normal file
73
src/backend/langflow/api/v1/login.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
from sqlmodel import Session
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
from langflow.services.getters import get_session
|
||||
from langflow.api.v1.schemas import Token
|
||||
from langflow.services.auth.utils import (
|
||||
authenticate_user,
|
||||
create_user_tokens,
|
||||
create_refresh_token,
|
||||
create_user_longterm_token,
|
||||
get_current_active_user,
|
||||
)
|
||||
|
||||
from langflow.services.getters import get_settings_service
|
||||
|
||||
router = APIRouter(tags=["Login"])
|
||||
|
||||
|
||||
@router.post("/login", response_model=Token)
|
||||
async def login_to_get_access_token(
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
db: Session = Depends(get_session),
|
||||
# _: Session = Depends(get_current_active_user)
|
||||
):
|
||||
try:
|
||||
user = authenticate_user(form_data.username, form_data.password, db)
|
||||
except Exception as exc:
|
||||
if isinstance(exc, HTTPException):
|
||||
raise exc
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(exc),
|
||||
) from exc
|
||||
|
||||
if user:
|
||||
return create_user_tokens(user_id=user.id, db=db, update_last_login=True)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/auto_login")
|
||||
async def auto_login(
|
||||
db: Session = Depends(get_session), settings_service=Depends(get_settings_service)
|
||||
):
|
||||
if settings_service.auth_settings.AUTO_LOGIN:
|
||||
return create_user_longterm_token(db)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"message": "Auto login is disabled. Please enable it in the settings",
|
||||
"auto_login": False,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh")
|
||||
async def refresh_token(
|
||||
token: str, current_user: Session = Depends(get_current_active_user)
|
||||
):
|
||||
if token:
|
||||
return create_refresh_token(token)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
|
@ -1,8 +1,12 @@
|
|||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from langflow.database.models.base import orjson_dumps
|
||||
from langflow.database.models.flow import FlowCreate, FlowRead
|
||||
from uuid import UUID
|
||||
from langflow.services.database.models.api_key.api_key import ApiKeyRead
|
||||
from langflow.services.database.models.flow import FlowCreate, FlowRead
|
||||
from langflow.services.database.models.user import UserRead
|
||||
from langflow.services.database.models.base import orjson_dumps
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
|
||||
|
|
@ -43,10 +47,30 @@ class UpdateTemplateRequest(BaseModel):
|
|||
template: dict
|
||||
|
||||
|
||||
class TaskResponse(BaseModel):
|
||||
"""Task response schema."""
|
||||
|
||||
id: Optional[str] = Field(None)
|
||||
href: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class ProcessResponse(BaseModel):
|
||||
"""Process response schema."""
|
||||
|
||||
result: dict
|
||||
result: Any
|
||||
task: Optional[TaskResponse] = None
|
||||
session_id: Optional[str] = None
|
||||
backend: Optional[str] = None
|
||||
|
||||
|
||||
# TaskStatusResponse(
|
||||
# status=task.status, result=task.result if task.ready() else None
|
||||
# )
|
||||
class TaskStatusResponse(BaseModel):
|
||||
"""Task status response schema."""
|
||||
|
||||
status: str
|
||||
result: Optional[Any] = None
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
|
|
@ -54,6 +78,7 @@ class ChatMessage(BaseModel):
|
|||
|
||||
is_bot: bool = False
|
||||
message: Union[str, None, dict] = None
|
||||
chatKey: Optional[str] = None
|
||||
type: str = "human"
|
||||
|
||||
|
||||
|
|
@ -61,6 +86,7 @@ class ChatResponse(ChatMessage):
|
|||
"""Chat response schema."""
|
||||
|
||||
intermediate_steps: str
|
||||
|
||||
type: str
|
||||
is_bot: bool = True
|
||||
files: list = []
|
||||
|
|
@ -72,6 +98,14 @@ class ChatResponse(ChatMessage):
|
|||
return v
|
||||
|
||||
|
||||
class PromptResponse(ChatMessage):
|
||||
"""Prompt response schema."""
|
||||
|
||||
prompt: str
|
||||
type: str = "prompt"
|
||||
is_bot: bool = True
|
||||
|
||||
|
||||
class FileResponse(ChatMessage):
|
||||
"""File response schema."""
|
||||
|
||||
|
|
@ -135,3 +169,32 @@ class ComponentListCreate(BaseModel):
|
|||
|
||||
class ComponentListRead(BaseModel):
|
||||
flows: List[FlowRead]
|
||||
|
||||
|
||||
class UsersResponse(BaseModel):
|
||||
total_count: int
|
||||
users: List[UserRead]
|
||||
|
||||
|
||||
class ApiKeyResponse(BaseModel):
|
||||
id: str
|
||||
api_key: str
|
||||
name: str
|
||||
created_at: str
|
||||
last_used_at: str
|
||||
|
||||
|
||||
class ApiKeysResponse(BaseModel):
|
||||
total_count: int
|
||||
user_id: UUID
|
||||
api_keys: List[ApiKeyRead]
|
||||
|
||||
|
||||
class CreateApiKeyRequest(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str
|
||||
|
|
|
|||
169
src/backend/langflow/api/v1/users.py
Normal file
169
src/backend/langflow/api/v1/users.py
Normal file
|
|
@ -0,0 +1,169 @@
|
|||
from uuid import UUID
|
||||
from langflow.api.v1.schemas import UsersResponse
|
||||
from langflow.services.database.models.user import (
|
||||
User,
|
||||
UserCreate,
|
||||
UserRead,
|
||||
UserUpdate,
|
||||
)
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from sqlmodel import Session, select
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from langflow.services.getters import get_session, get_settings_service
|
||||
from langflow.services.auth.utils import (
|
||||
get_current_active_superuser,
|
||||
get_current_active_user,
|
||||
get_password_hash,
|
||||
verify_password,
|
||||
)
|
||||
from langflow.services.database.models.user.crud import (
|
||||
get_user_by_id,
|
||||
update_user,
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["Users"], prefix="/users")
|
||||
|
||||
|
||||
@router.post("/", response_model=UserRead, status_code=201)
|
||||
def add_user(
|
||||
user: UserCreate,
|
||||
session: Session = Depends(get_session),
|
||||
settings_service=Depends(get_settings_service),
|
||||
) -> User:
|
||||
"""
|
||||
Add a new user to the database.
|
||||
"""
|
||||
new_user = User.from_orm(user)
|
||||
try:
|
||||
new_user.password = get_password_hash(user.password)
|
||||
new_user.is_active = settings_service.auth_settings.NEW_USER_IS_ACTIVE
|
||||
session.add(new_user)
|
||||
session.commit()
|
||||
session.refresh(new_user)
|
||||
except IntegrityError as e:
|
||||
session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=400, detail="This username is unavailable."
|
||||
) from e
|
||||
|
||||
return new_user
|
||||
|
||||
|
||||
@router.get("/whoami", response_model=UserRead)
|
||||
def read_current_user(
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
) -> User:
|
||||
"""
|
||||
Retrieve the current user's data.
|
||||
"""
|
||||
return current_user
|
||||
|
||||
|
||||
@router.get("/", response_model=UsersResponse)
|
||||
def read_all_users(
|
||||
skip: int = 0,
|
||||
limit: int = 10,
|
||||
_: Session = Depends(get_current_active_superuser),
|
||||
session: Session = Depends(get_session),
|
||||
) -> UsersResponse:
|
||||
"""
|
||||
Retrieve a list of users from the database with pagination.
|
||||
"""
|
||||
query = select(User).offset(skip).limit(limit)
|
||||
users = session.execute(query).fetchall()
|
||||
|
||||
count_query = select(func.count()).select_from(User) # type: ignore
|
||||
total_count = session.execute(count_query).scalar()
|
||||
|
||||
return UsersResponse(
|
||||
total_count=total_count, # type: ignore
|
||||
users=[UserRead(**dict(user.User)) for user in users],
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/{user_id}", response_model=UserRead)
|
||||
def patch_user(
|
||||
user_id: UUID,
|
||||
user_update: UserUpdate,
|
||||
user: User = Depends(get_current_active_user),
|
||||
session: Session = Depends(get_session),
|
||||
) -> User:
|
||||
"""
|
||||
Update an existing user's data.
|
||||
"""
|
||||
if not user.is_superuser and user.id != user_id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have the permission to update this user"
|
||||
)
|
||||
if user_update.password:
|
||||
if not user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="You can't change your password here"
|
||||
)
|
||||
user_update.password = get_password_hash(user_update.password)
|
||||
|
||||
if user_db := get_user_by_id(session, user_id):
|
||||
return update_user(user_db, user_update, session)
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
|
||||
@router.patch("/{user_id}/reset-password", response_model=UserRead)
|
||||
def reset_password(
|
||||
user_id: UUID,
|
||||
user_update: UserUpdate,
|
||||
user: User = Depends(get_current_active_user),
|
||||
session: Session = Depends(get_session),
|
||||
) -> User:
|
||||
"""
|
||||
Reset a user's password.
|
||||
"""
|
||||
if user_id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="You can't change another user's password"
|
||||
)
|
||||
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
if verify_password(user_update.password, user.password):
|
||||
raise HTTPException(
|
||||
status_code=400, detail="You can't use your current password"
|
||||
)
|
||||
new_password = get_password_hash(user_update.password)
|
||||
user.password = new_password
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@router.delete("/{user_id}", response_model=dict)
|
||||
def delete_user(
|
||||
user_id: UUID,
|
||||
current_user: User = Depends(get_current_active_superuser),
|
||||
session: Session = Depends(get_session),
|
||||
) -> dict:
|
||||
"""
|
||||
Delete a user from the database.
|
||||
"""
|
||||
if current_user.id == user_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="You can't delete your own user account"
|
||||
)
|
||||
elif not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have the permission to delete this user"
|
||||
)
|
||||
|
||||
user_db = session.query(User).filter(User.id == user_id).first()
|
||||
if not user_db:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
session.delete(user_db)
|
||||
session.commit()
|
||||
|
||||
return {"detail": "User deleted"}
|
||||
|
|
@ -8,7 +8,7 @@ from langflow.api.v1.base import (
|
|||
validate_prompt,
|
||||
)
|
||||
from langflow.template.field.base import TemplateField
|
||||
from langflow.utils.logger import logger
|
||||
from loguru import logger
|
||||
from langflow.utils.validate import validate_code
|
||||
|
||||
# build router
|
||||
|
|
@ -31,7 +31,12 @@ def post_validate_code(code: Code):
|
|||
def post_validate_prompt(prompt_request: ValidatePromptRequest):
|
||||
try:
|
||||
input_variables = validate_prompt(prompt_request.template)
|
||||
|
||||
# Check if frontend_node is None before proceeding to avoid attempting to update a non-existent node.
|
||||
if prompt_request.frontend_node is None:
|
||||
return PromptValidationResponse(
|
||||
input_variables=input_variables,
|
||||
frontend_node=None,
|
||||
)
|
||||
old_custom_fields = get_old_custom_fields(prompt_request)
|
||||
|
||||
add_new_variables_to_template(input_variables, prompt_request)
|
||||
|
|
|
|||
7
src/backend/langflow/cache/__init__.py
vendored
7
src/backend/langflow/cache/__init__.py
vendored
|
|
@ -1,7 +0,0 @@
|
|||
from langflow.cache.manager import cache_manager
|
||||
from langflow.cache.flow import InMemoryCache
|
||||
|
||||
__all__ = [
|
||||
"cache_manager",
|
||||
"InMemoryCache",
|
||||
]
|
||||
146
src/backend/langflow/cache/flow.py
vendored
146
src/backend/langflow/cache/flow.py
vendored
|
|
@ -1,146 +0,0 @@
|
|||
import threading
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
|
||||
from langflow.cache.base import BaseCache
|
||||
|
||||
|
||||
class InMemoryCache(BaseCache):
|
||||
"""
|
||||
A simple in-memory cache using an OrderedDict.
|
||||
|
||||
This cache supports setting a maximum size and expiration time for cached items.
|
||||
When the cache is full, it uses a Least Recently Used (LRU) eviction policy.
|
||||
Thread-safe using a threading Lock.
|
||||
|
||||
Attributes:
|
||||
max_size (int, optional): Maximum number of items to store in the cache.
|
||||
expiration_time (int, optional): Time in seconds after which a cached item expires. Default is 1 hour.
|
||||
|
||||
Example:
|
||||
|
||||
cache = InMemoryCache(max_size=3, expiration_time=5)
|
||||
|
||||
# setting cache values
|
||||
cache.set("a", 1)
|
||||
cache.set("b", 2)
|
||||
cache["c"] = 3
|
||||
|
||||
# getting cache values
|
||||
a = cache.get("a")
|
||||
b = cache["b"]
|
||||
"""
|
||||
|
||||
def __init__(self, max_size=None, expiration_time=60 * 60):
|
||||
"""
|
||||
Initialize a new InMemoryCache instance.
|
||||
|
||||
Args:
|
||||
max_size (int, optional): Maximum number of items to store in the cache.
|
||||
expiration_time (int, optional): Time in seconds after which a cached item expires. Default is 1 hour.
|
||||
"""
|
||||
self._cache = OrderedDict()
|
||||
self._lock = threading.Lock()
|
||||
self.max_size = max_size
|
||||
self.expiration_time = expiration_time
|
||||
|
||||
def get(self, key):
|
||||
"""
|
||||
Retrieve an item from the cache.
|
||||
|
||||
Args:
|
||||
key: The key of the item to retrieve.
|
||||
|
||||
Returns:
|
||||
The value associated with the key, or None if the key is not found or the item has expired.
|
||||
"""
|
||||
with self._lock:
|
||||
if key in self._cache:
|
||||
item = self._cache.pop(key)
|
||||
if (
|
||||
self.expiration_time is None
|
||||
or time.time() - item["time"] < self.expiration_time
|
||||
):
|
||||
# Move the key to the end to make it recently used
|
||||
self._cache[key] = item
|
||||
return item["value"]
|
||||
else:
|
||||
self.delete(key)
|
||||
return None
|
||||
|
||||
def set(self, key, value):
|
||||
"""
|
||||
Add an item to the cache.
|
||||
|
||||
If the cache is full, the least recently used item is evicted.
|
||||
|
||||
Args:
|
||||
key: The key of the item.
|
||||
value: The value to cache.
|
||||
"""
|
||||
with self._lock:
|
||||
if key in self._cache:
|
||||
# Remove existing key before re-inserting to update order
|
||||
self.delete(key)
|
||||
elif self.max_size and len(self._cache) >= self.max_size:
|
||||
# Remove least recently used item
|
||||
self._cache.popitem(last=False)
|
||||
self._cache[key] = {"value": value, "time": time.time()}
|
||||
|
||||
def get_or_set(self, key, value):
|
||||
"""
|
||||
Retrieve an item from the cache. If the item does not exist, set it with the provided value.
|
||||
|
||||
Args:
|
||||
key: The key of the item.
|
||||
value: The value to cache if the item doesn't exist.
|
||||
|
||||
Returns:
|
||||
The cached value associated with the key.
|
||||
"""
|
||||
with self._lock:
|
||||
if key in self._cache:
|
||||
return self.get(key)
|
||||
self.set(key, value)
|
||||
return value
|
||||
|
||||
def delete(self, key):
|
||||
"""
|
||||
Remove an item from the cache.
|
||||
|
||||
Args:
|
||||
key: The key of the item to remove.
|
||||
"""
|
||||
# with self._lock:
|
||||
self._cache.pop(key, None)
|
||||
|
||||
def clear(self):
|
||||
"""
|
||||
Clear all items from the cache.
|
||||
"""
|
||||
with self._lock:
|
||||
self._cache.clear()
|
||||
|
||||
def __contains__(self, key):
|
||||
"""Check if the key is in the cache."""
|
||||
return key in self._cache
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Retrieve an item from the cache using the square bracket notation."""
|
||||
return self.get(key)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
"""Add an item to the cache using the square bracket notation."""
|
||||
self.set(key, value)
|
||||
|
||||
def __delitem__(self, key):
|
||||
"""Remove an item from the cache using the square bracket notation."""
|
||||
self.delete(key)
|
||||
|
||||
def __len__(self):
|
||||
"""Return the number of items in the cache."""
|
||||
return len(self._cache)
|
||||
|
||||
def __repr__(self):
|
||||
"""Return a string representation of the InMemoryCache instance."""
|
||||
return f"InMemoryCache(max_size={self.max_size}, expiration_time={self.expiration_time})"
|
||||
|
|
@ -79,4 +79,5 @@ class ConversationalAgent(CustomComponent):
|
|||
memory=memory,
|
||||
verbose=True,
|
||||
return_intermediate_steps=True,
|
||||
handle_parsing_errors=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from langflow import CustomComponent
|
||||
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain import PromptTemplate
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.schema import Document
|
||||
|
||||
|
||||
|
|
@ -16,17 +16,14 @@ class PromptRunner(CustomComponent):
|
|||
"info": "Make sure the prompt has all variables filled.",
|
||||
},
|
||||
"code": {"show": False},
|
||||
"inputs": {"field_type": "code"},
|
||||
}
|
||||
|
||||
def build(
|
||||
self,
|
||||
llm: BaseLLM,
|
||||
prompt: PromptTemplate,
|
||||
self, llm: BaseLLM, prompt: PromptTemplate, inputs: dict = {}
|
||||
) -> Document:
|
||||
chain = prompt | llm
|
||||
# The input is an empty dict because the prompt is already filled
|
||||
result = chain.invoke({})
|
||||
result = chain.invoke(input=inputs)
|
||||
if hasattr(result, "content"):
|
||||
result = result.content
|
||||
self.repr_value = result
|
||||
|
|
|
|||
45
src/backend/langflow/components/llms/AmazonBedrock.py
Normal file
45
src/backend/langflow/components/llms/AmazonBedrock.py
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
from typing import Optional
|
||||
from langflow import CustomComponent
|
||||
from langchain.llms.bedrock import Bedrock
|
||||
from langchain.llms.base import BaseLLM
|
||||
|
||||
|
||||
class AmazonBedrockComponent(CustomComponent):
|
||||
display_name: str = "Amazon Bedrock"
|
||||
description: str = "LLM model from Amazon Bedrock."
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"model_id": {
|
||||
"display_name": "Model Id",
|
||||
"options": [
|
||||
"ai21.j2-grande-instruct",
|
||||
"ai21.j2-jumbo-instruct",
|
||||
"ai21.j2-mid",
|
||||
"ai21.j2-mid-v1",
|
||||
"ai21.j2-ultra",
|
||||
"ai21.j2-ultra-v1",
|
||||
"anthropic.claude-instant-v1",
|
||||
"anthropic.claude-v1",
|
||||
"anthropic.claude-v2",
|
||||
"cohere.command-text-v14",
|
||||
],
|
||||
},
|
||||
"credentials_profile_name": {"display_name": "Credentials Profile Name"},
|
||||
"streaming": {"display_name": "Streaming", "field_type": "bool"},
|
||||
"code": {"show": False},
|
||||
}
|
||||
|
||||
def build(
|
||||
self,
|
||||
model_id: str = "anthropic.claude-instant-v1",
|
||||
credentials_profile_name: Optional[str] = None,
|
||||
) -> BaseLLM:
|
||||
try:
|
||||
output = Bedrock(
|
||||
credentials_profile_name=credentials_profile_name,
|
||||
model_id=model_id,
|
||||
) # type: ignore
|
||||
except Exception as e:
|
||||
raise ValueError("Could not connect to AmazonBedrock API.") from e
|
||||
return output
|
||||
|
|
@ -0,0 +1,92 @@
|
|||
from typing import Optional
|
||||
from langflow import CustomComponent
|
||||
from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
|
||||
from langchain.llms.base import BaseLLM
|
||||
|
||||
|
||||
class QianfanChatEndpointComponent(CustomComponent):
|
||||
display_name: str = "QianfanChatEndpoint"
|
||||
description: str = (
|
||||
"Baidu Qianfan chat models. Get more detail from "
|
||||
"https://python.langchain.com/docs/integrations/chat/baidu_qianfan_endpoint."
|
||||
)
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"model": {
|
||||
"display_name": "Model Name",
|
||||
"options": [
|
||||
"ERNIE-Bot",
|
||||
"ERNIE-Bot-turbo",
|
||||
"BLOOMZ-7B",
|
||||
"Llama-2-7b-chat",
|
||||
"Llama-2-13b-chat",
|
||||
"Llama-2-70b-chat",
|
||||
"Qianfan-BLOOMZ-7B-compressed",
|
||||
"Qianfan-Chinese-Llama-2-7B",
|
||||
"ChatGLM2-6B-32K",
|
||||
"AquilaChat-7B",
|
||||
],
|
||||
"info": "https://python.langchain.com/docs/integrations/chat/baidu_qianfan_endpoint",
|
||||
"required": True,
|
||||
},
|
||||
"qianfan_ak": {
|
||||
"display_name": "Qianfan Ak",
|
||||
"required": True,
|
||||
"password": True,
|
||||
"info": "which you could get from https://cloud.baidu.com/product/wenxinworkshop",
|
||||
},
|
||||
"qianfan_sk": {
|
||||
"display_name": "Qianfan Sk",
|
||||
"required": True,
|
||||
"password": True,
|
||||
"info": "which you could get from https://cloud.baidu.com/product/wenxinworkshop",
|
||||
},
|
||||
"top_p": {
|
||||
"display_name": "Top p",
|
||||
"field_type": "float",
|
||||
"info": "Model params, only supported in ERNIE-Bot and ERNIE-Bot-turbo",
|
||||
"value": 0.8,
|
||||
},
|
||||
"temperature": {
|
||||
"display_name": "Temperature",
|
||||
"field_type": "float",
|
||||
"info": "Model params, only supported in ERNIE-Bot and ERNIE-Bot-turbo",
|
||||
"value": 0.95,
|
||||
},
|
||||
"penalty_score": {
|
||||
"display_name": "Penalty Score",
|
||||
"field_type": "float",
|
||||
"info": "Model params, only supported in ERNIE-Bot and ERNIE-Bot-turbo",
|
||||
"value": 1.0,
|
||||
},
|
||||
"endpoint": {
|
||||
"display_name": "Endpoint",
|
||||
"info": "Endpoint of the Qianfan LLM, required if custom model used.",
|
||||
},
|
||||
"code": {"show": False},
|
||||
}
|
||||
|
||||
def build(
|
||||
self,
|
||||
model: str = "ERNIE-Bot-turbo",
|
||||
qianfan_ak: Optional[str] = None,
|
||||
qianfan_sk: Optional[str] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None,
|
||||
penalty_score: Optional[float] = None,
|
||||
endpoint: Optional[str] = None,
|
||||
) -> BaseLLM:
|
||||
try:
|
||||
output = QianfanChatEndpoint( # type: ignore
|
||||
model=model,
|
||||
qianfan_ak=qianfan_ak,
|
||||
qianfan_sk=qianfan_sk,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
penalty_score=penalty_score,
|
||||
endpoint=endpoint,
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError("Could not connect to Baidu Qianfan API.") from e
|
||||
return output # type: ignore
|
||||
|
|
@ -0,0 +1,92 @@
|
|||
from typing import Optional
|
||||
from langflow import CustomComponent
|
||||
from langchain.llms.baidu_qianfan_endpoint import QianfanLLMEndpoint
|
||||
from langchain.llms.base import BaseLLM
|
||||
|
||||
|
||||
class QianfanLLMEndpointComponent(CustomComponent):
|
||||
display_name: str = "QianfanLLMEndpoint"
|
||||
description: str = (
|
||||
"Baidu Qianfan hosted open source or customized models. "
|
||||
"Get more detail from https://python.langchain.com/docs/integrations/chat/baidu_qianfan_endpoint"
|
||||
)
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"model": {
|
||||
"display_name": "Model Name",
|
||||
"options": [
|
||||
"ERNIE-Bot",
|
||||
"ERNIE-Bot-turbo",
|
||||
"BLOOMZ-7B",
|
||||
"Llama-2-7b-chat",
|
||||
"Llama-2-13b-chat",
|
||||
"Llama-2-70b-chat",
|
||||
"Qianfan-BLOOMZ-7B-compressed",
|
||||
"Qianfan-Chinese-Llama-2-7B",
|
||||
"ChatGLM2-6B-32K",
|
||||
"AquilaChat-7B",
|
||||
],
|
||||
"info": "https://python.langchain.com/docs/integrations/chat/baidu_qianfan_endpoint",
|
||||
"required": True,
|
||||
},
|
||||
"qianfan_ak": {
|
||||
"display_name": "Qianfan Ak",
|
||||
"required": True,
|
||||
"password": True,
|
||||
"info": "which you could get from https://cloud.baidu.com/product/wenxinworkshop",
|
||||
},
|
||||
"qianfan_sk": {
|
||||
"display_name": "Qianfan Sk",
|
||||
"required": True,
|
||||
"password": True,
|
||||
"info": "which you could get from https://cloud.baidu.com/product/wenxinworkshop",
|
||||
},
|
||||
"top_p": {
|
||||
"display_name": "Top p",
|
||||
"field_type": "float",
|
||||
"info": "Model params, only supported in ERNIE-Bot and ERNIE-Bot-turbo",
|
||||
"value": 0.8,
|
||||
},
|
||||
"temperature": {
|
||||
"display_name": "Temperature",
|
||||
"field_type": "float",
|
||||
"info": "Model params, only supported in ERNIE-Bot and ERNIE-Bot-turbo",
|
||||
"value": 0.95,
|
||||
},
|
||||
"penalty_score": {
|
||||
"display_name": "Penalty Score",
|
||||
"field_type": "float",
|
||||
"info": "Model params, only supported in ERNIE-Bot and ERNIE-Bot-turbo",
|
||||
"value": 1.0,
|
||||
},
|
||||
"endpoint": {
|
||||
"display_name": "Endpoint",
|
||||
"info": "Endpoint of the Qianfan LLM, required if custom model used.",
|
||||
},
|
||||
"code": {"show": False},
|
||||
}
|
||||
|
||||
def build(
|
||||
self,
|
||||
model: str = "ERNIE-Bot-turbo",
|
||||
qianfan_ak: Optional[str] = None,
|
||||
qianfan_sk: Optional[str] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None,
|
||||
penalty_score: Optional[float] = None,
|
||||
endpoint: Optional[str] = None,
|
||||
) -> BaseLLM:
|
||||
try:
|
||||
output = QianfanLLMEndpoint( # type: ignore
|
||||
model=model,
|
||||
qianfan_ak=qianfan_ak,
|
||||
qianfan_sk=qianfan_sk,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
penalty_score=penalty_score,
|
||||
endpoint=endpoint,
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError("Could not connect to Baidu Qianfan API.") from e
|
||||
return output # type: ignore
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Optional
|
||||
from langflow import CustomComponent
|
||||
from langchain.llms import HuggingFaceEndpoint
|
||||
from langchain.llms.huggingface_endpoint import HuggingFaceEndpoint
|
||||
from langchain.llms.base import BaseLLM
|
||||
|
||||
|
||||
|
|
@ -13,7 +13,6 @@ class HuggingFaceEndpointsComponent(CustomComponent):
|
|||
"endpoint_url": {"display_name": "Endpoint URL", "password": True},
|
||||
"task": {
|
||||
"display_name": "Task",
|
||||
"type": "select",
|
||||
"options": ["text2text-generation", "text-generation", "summarization"],
|
||||
},
|
||||
"huggingfacehub_api_token": {"display_name": "API token", "password": True},
|
||||
|
|
@ -27,7 +26,7 @@ class HuggingFaceEndpointsComponent(CustomComponent):
|
|||
def build(
|
||||
self,
|
||||
endpoint_url: str,
|
||||
task="text2text-generation",
|
||||
task: str = "text2text-generation",
|
||||
huggingfacehub_api_token: Optional[str] = None,
|
||||
model_kwargs: Optional[dict] = None,
|
||||
) -> BaseLLM:
|
||||
|
|
@ -36,6 +35,7 @@ class HuggingFaceEndpointsComponent(CustomComponent):
|
|||
endpoint_url=endpoint_url,
|
||||
task=task,
|
||||
huggingfacehub_api_token=huggingfacehub_api_token,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError("Could not connect to HuggingFace Endpoints API.") from e
|
||||
|
|
|
|||
48
src/backend/langflow/components/retrievers/AmazonKendra.py
Normal file
48
src/backend/langflow/components/retrievers/AmazonKendra.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
from typing import Optional
|
||||
from langflow import CustomComponent
|
||||
from langchain.retrievers import AmazonKendraRetriever
|
||||
from langchain.schema import BaseRetriever
|
||||
|
||||
|
||||
class AmazonKendraRetrieverComponent(CustomComponent):
|
||||
display_name: str = "Amazon Kendra Retriever"
|
||||
description: str = "Retriever that uses the Amazon Kendra API."
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"index_id": {"display_name": "Index ID"},
|
||||
"region_name": {"display_name": "Region Name"},
|
||||
"credentials_profile_name": {"display_name": "Credentials Profile Name"},
|
||||
"attribute_filter": {
|
||||
"display_name": "Attribute Filter",
|
||||
"field_type": "code",
|
||||
},
|
||||
"top_k": {"display_name": "Top K", "field_type": "int"},
|
||||
"user_context": {
|
||||
"display_name": "User Context",
|
||||
"field_type": "code",
|
||||
},
|
||||
"code": {"show": False},
|
||||
}
|
||||
|
||||
def build(
|
||||
self,
|
||||
index_id: str,
|
||||
top_k: int = 3,
|
||||
region_name: Optional[str] = None,
|
||||
credentials_profile_name: Optional[str] = None,
|
||||
attribute_filter: Optional[dict] = None,
|
||||
user_context: Optional[dict] = None,
|
||||
) -> BaseRetriever:
|
||||
try:
|
||||
output = AmazonKendraRetriever(
|
||||
index_id=index_id,
|
||||
top_k=top_k,
|
||||
region_name=region_name,
|
||||
credentials_profile_name=credentials_profile_name,
|
||||
attribute_filter=attribute_filter,
|
||||
user_context=user_context,
|
||||
) # type: ignore
|
||||
except Exception as e:
|
||||
raise ValueError("Could not connect to AmazonKendra API.") from e
|
||||
return output
|
||||
|
|
@ -14,7 +14,7 @@ class MetalRetrieverComponent(CustomComponent):
|
|||
"api_key": {"display_name": "API Key", "password": True},
|
||||
"client_id": {"display_name": "Client ID", "password": True},
|
||||
"index_id": {"display_name": "Index ID"},
|
||||
"params": {"display_name": "Parameters", "field_type": "code"},
|
||||
"params": {"display_name": "Parameters"},
|
||||
"code": {"show": False},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ from typing import Optional
|
|||
from langflow import CustomComponent
|
||||
from langchain.text_splitter import Language
|
||||
from langchain.schema import Document
|
||||
from langflow.utils.util import build_loader_repr_from_documents
|
||||
|
||||
|
||||
class LanguageRecursiveTextSplitterComponent(CustomComponent):
|
||||
|
|
@ -78,5 +77,4 @@ class LanguageRecursiveTextSplitterComponent(CustomComponent):
|
|||
)
|
||||
|
||||
docs = splitter.split_documents(documents)
|
||||
self.repr_value = build_loader_repr_from_documents(docs)
|
||||
return docs
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from typing import Optional
|
||||
from langflow import CustomComponent
|
||||
from langchain.schema import Document
|
||||
from langflow.utils.util import build_loader_repr_from_documents
|
||||
|
||||
|
||||
class RecursiveCharacterTextSplitterComponent(CustomComponent):
|
||||
|
|
@ -74,6 +75,5 @@ class RecursiveCharacterTextSplitterComponent(CustomComponent):
|
|||
)
|
||||
|
||||
docs = splitter.split_documents(documents)
|
||||
# self.repr_value = build_loader_repr_from_documents(docs)
|
||||
self.repr_value = separators
|
||||
self.repr_value = build_loader_repr_from_documents(docs)
|
||||
return docs
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from langflow import CustomComponent
|
||||
from langchain.schema import Document
|
||||
from langflow.database.models.base import orjson_dumps
|
||||
from langflow.services.database.models.base import orjson_dumps
|
||||
import requests
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -19,7 +19,6 @@ class GetRequest(CustomComponent):
|
|||
},
|
||||
"headers": {
|
||||
"display_name": "Headers",
|
||||
"field_type": "code",
|
||||
"info": "The headers to send with the request.",
|
||||
},
|
||||
"code": {"show": False},
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@
|
|||
|
||||
from langflow import CustomComponent
|
||||
from langchain.schema import Document
|
||||
from langflow.database.models.base import orjson_dumps
|
||||
from langflow.services.database.models.base import orjson_dumps
|
||||
|
||||
|
||||
class JSONDocumentBuilder(CustomComponent):
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from langflow import CustomComponent
|
||||
from langchain.schema import Document
|
||||
from langflow.database.models.base import orjson_dumps
|
||||
from langflow.services.database.models.base import orjson_dumps
|
||||
import requests
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -15,7 +15,6 @@ class PostRequest(CustomComponent):
|
|||
"url": {"display_name": "URL", "info": "The URL to make the request to."},
|
||||
"headers": {
|
||||
"display_name": "Headers",
|
||||
"field_type": "code",
|
||||
"info": "The headers to send with the request.",
|
||||
},
|
||||
"code": {"show": False},
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from typing import List, Optional
|
|||
import requests
|
||||
from langflow import CustomComponent
|
||||
from langchain.schema import Document
|
||||
from langflow.database.models.base import orjson_dumps
|
||||
from langflow.services.database.models.base import orjson_dumps
|
||||
|
||||
|
||||
class UpdateRequest(CustomComponent):
|
||||
|
|
@ -15,7 +15,7 @@ class UpdateRequest(CustomComponent):
|
|||
"url": {"display_name": "URL", "info": "The URL to make the request to."},
|
||||
"headers": {
|
||||
"display_name": "Headers",
|
||||
"field_type": "code",
|
||||
"field_type": "NestedDict",
|
||||
"info": "The headers to send with the request.",
|
||||
},
|
||||
"code": {"show": False},
|
||||
|
|
|
|||
109
src/backend/langflow/components/vectorstores/Chroma.py
Normal file
109
src/backend/langflow/components/vectorstores/Chroma.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
from typing import Optional, Union
|
||||
from langflow import CustomComponent
|
||||
|
||||
from langchain.vectorstores import Chroma
|
||||
from langchain.schema import Document
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.schema import BaseRetriever
|
||||
from langchain.embeddings.base import Embeddings
|
||||
import chromadb # type: ignore
|
||||
|
||||
|
||||
class ChromaComponent(CustomComponent):
|
||||
"""
|
||||
A custom component for implementing a Vector Store using Chroma.
|
||||
"""
|
||||
|
||||
display_name: str = "Chroma"
|
||||
description: str = "Implementation of Vector Store using Chroma"
|
||||
documentation = "https://python.langchain.com/docs/integrations/vectorstores/chroma"
|
||||
beta = True
|
||||
|
||||
def build_config(self):
|
||||
"""
|
||||
Builds the configuration for the component.
|
||||
|
||||
Returns:
|
||||
- dict: A dictionary containing the configuration options for the component.
|
||||
"""
|
||||
return {
|
||||
"collection_name": {"display_name": "Collection Name", "value": "langflow"},
|
||||
"persist": {"display_name": "Persist"},
|
||||
"persist_directory": {"display_name": "Persist Directory"},
|
||||
"code": {"show": False, "display_name": "Code"},
|
||||
"documents": {"display_name": "Documents", "is_list": True},
|
||||
"embedding": {"display_name": "Embedding"},
|
||||
"chroma_server_cors_allow_origins": {
|
||||
"display_name": "Server CORS Allow Origins",
|
||||
"advanced": True,
|
||||
},
|
||||
"chroma_server_host": {"display_name": "Server Host", "advanced": True},
|
||||
"chroma_server_port": {"display_name": "Server Port", "advanced": True},
|
||||
"chroma_server_grpc_port": {
|
||||
"display_name": "Server gRPC Port",
|
||||
"advanced": True,
|
||||
},
|
||||
"chroma_server_ssl_enabled": {
|
||||
"display_name": "Server SSL Enabled",
|
||||
"advanced": True,
|
||||
},
|
||||
}
|
||||
|
||||
def build(
|
||||
self,
|
||||
collection_name: str,
|
||||
persist: bool,
|
||||
chroma_server_ssl_enabled: bool,
|
||||
persist_directory: Optional[str] = None,
|
||||
embedding: Optional[Embeddings] = None,
|
||||
documents: Optional[Document] = None,
|
||||
chroma_server_cors_allow_origins: Optional[str] = None,
|
||||
chroma_server_host: Optional[str] = None,
|
||||
chroma_server_port: Optional[int] = None,
|
||||
chroma_server_grpc_port: Optional[int] = None,
|
||||
) -> Union[VectorStore, BaseRetriever]:
|
||||
"""
|
||||
Builds the Vector Store or BaseRetriever object.
|
||||
|
||||
Args:
|
||||
- collection_name (str): The name of the collection.
|
||||
- persist_directory (Optional[str]): The directory to persist the Vector Store to.
|
||||
- chroma_server_ssl_enabled (bool): Whether to enable SSL for the Chroma server.
|
||||
- persist (bool): Whether to persist the Vector Store or not.
|
||||
- embedding (Optional[Embeddings]): The embeddings to use for the Vector Store.
|
||||
- documents (Optional[Document]): The documents to use for the Vector Store.
|
||||
- chroma_server_cors_allow_origins (Optional[str]): The CORS allow origins for the Chroma server.
|
||||
- chroma_server_host (Optional[str]): The host for the Chroma server.
|
||||
- chroma_server_port (Optional[int]): The port for the Chroma server.
|
||||
- chroma_server_grpc_port (Optional[int]): The gRPC port for the Chroma server.
|
||||
|
||||
Returns:
|
||||
- Union[VectorStore, BaseRetriever]: The Vector Store or BaseRetriever object.
|
||||
"""
|
||||
|
||||
# Chroma settings
|
||||
chroma_settings = None
|
||||
|
||||
if chroma_server_host is not None:
|
||||
chroma_settings = chromadb.config.Settings(
|
||||
chroma_server_cors_allow_origins=chroma_server_cors_allow_origins
|
||||
or None,
|
||||
chroma_server_host=chroma_server_host,
|
||||
chroma_server_port=chroma_server_port or None,
|
||||
chroma_server_grpc_port=chroma_server_grpc_port or None,
|
||||
chroma_server_ssl_enabled=chroma_server_ssl_enabled,
|
||||
)
|
||||
|
||||
# If documents, then we need to create a Chroma instance using .from_documents
|
||||
if documents is not None and embedding is not None:
|
||||
return Chroma.from_documents(
|
||||
documents=documents, # type: ignore
|
||||
persist_directory=persist_directory if persist else None,
|
||||
collection_name=collection_name,
|
||||
embedding=embedding,
|
||||
client_settings=chroma_settings,
|
||||
)
|
||||
|
||||
return Chroma(
|
||||
persist_directory=persist_directory, client_settings=chroma_settings
|
||||
)
|
||||
|
|
@ -5,7 +5,6 @@ from langchain.vectorstores import Vectara
|
|||
from langchain.schema import Document
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.schema import BaseRetriever
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
|
||||
class VectaraComponent(CustomComponent):
|
||||
|
|
@ -22,7 +21,6 @@ class VectaraComponent(CustomComponent):
|
|||
"vectara_api_key": {"display_name": "Vectara API Key", "password": True},
|
||||
"code": {"show": False},
|
||||
"documents": {"display_name": "Documents"},
|
||||
"embedding": {"display_name": "Embedding"},
|
||||
}
|
||||
|
||||
def build(
|
||||
|
|
@ -30,21 +28,21 @@ class VectaraComponent(CustomComponent):
|
|||
vectara_customer_id: str,
|
||||
vectara_corpus_id: str,
|
||||
vectara_api_key: str,
|
||||
embedding: Optional[Embeddings] = None,
|
||||
documents: Optional[Document] = None,
|
||||
) -> Union[VectorStore, BaseRetriever]:
|
||||
# If documents, then we need to create a Vectara instance using .from_documents
|
||||
if documents is not None and embedding is not None:
|
||||
if documents is not None:
|
||||
return Vectara.from_documents(
|
||||
documents=documents, # type: ignore
|
||||
vectara_customer_id=vectara_customer_id,
|
||||
vectara_corpus_id=vectara_corpus_id,
|
||||
vectara_api_key=vectara_api_key,
|
||||
embedding=embedding,
|
||||
source="langflow",
|
||||
)
|
||||
|
||||
return Vectara(
|
||||
vectara_customer_id=vectara_customer_id,
|
||||
vectara_corpus_id=vectara_corpus_id,
|
||||
vectara_api_key=vectara_api_key,
|
||||
source="langflow",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -104,6 +104,8 @@ embeddings:
|
|||
documentation: "https://python.langchain.com/docs/modules/data_connection/text_embedding/integrations/sentence_transformers"
|
||||
CohereEmbeddings:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/text_embedding/integrations/cohere"
|
||||
VertexAIEmbeddings:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/text_embedding/integrations/google_vertex_ai_palm"
|
||||
llms:
|
||||
OpenAI:
|
||||
documentation: "https://python.langchain.com/docs/modules/model_io/models/llms/integrations/openai"
|
||||
|
|
@ -127,8 +129,8 @@ llms:
|
|||
# There's a bug in this component deactivating until we get it sorted: _language_models.py", line 804, in send_message
|
||||
# is_blocked=safety_attributes.get("blocked", False),
|
||||
# AttributeError: 'list' object has no attribute 'get'
|
||||
# ChatVertexAI:
|
||||
# documentation: "https://python.langchain.com/docs/modules/model_io/models/chat/integrations/google_vertex_ai_palm"
|
||||
ChatVertexAI:
|
||||
documentation: "https://python.langchain.com/docs/modules/model_io/models/chat/integrations/google_vertex_ai_palm"
|
||||
###
|
||||
memories:
|
||||
# https://github.com/supabase-community/supabase-py/issues/482
|
||||
|
|
@ -263,8 +265,8 @@ retrievers:
|
|||
# ZepRetriever:
|
||||
# documentation: "https://python.langchain.com/docs/modules/data_connection/retrievers/integrations/zep_memorystore"
|
||||
vectorstores:
|
||||
Chroma:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/vectorstores/integrations/chroma"
|
||||
# Chroma:
|
||||
# documentation: "https://python.langchain.com/docs/modules/data_connection/vectorstores/integrations/chroma"
|
||||
Qdrant:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/vectorstores/integrations/qdrant"
|
||||
Weaviate:
|
||||
|
|
|
|||
11
src/backend/langflow/core/celery_app.py
Normal file
11
src/backend/langflow/core/celery_app.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
from celery import Celery # type: ignore
|
||||
|
||||
|
||||
def make_celery(app_name: str, config: str) -> Celery:
|
||||
celery_app = Celery(app_name)
|
||||
celery_app.config_from_object(config)
|
||||
celery_app.conf.task_routes = {"langflow.worker.tasks.*": {"queue": "langflow"}}
|
||||
return celery_app
|
||||
|
||||
|
||||
celery_app = make_celery("langflow", "langflow.core.celeryconfig")
|
||||
14
src/backend/langflow/core/celeryconfig.py
Normal file
14
src/backend/langflow/core/celeryconfig.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
# celeryconfig.py
|
||||
import os
|
||||
|
||||
langflow_redis_host = os.environ.get("LANGFLOW_REDIS_HOST")
|
||||
langflow_redis_port = os.environ.get("LANGFLOW_REDIS_PORT")
|
||||
if "BROKER_URL" in os.environ and "RESULT_BACKEND" in os.environ:
|
||||
# RabbitMQ
|
||||
broker_url = os.environ.get("BROKER_URL", "amqp://localhost")
|
||||
result_backend = os.environ.get("RESULT_BACKEND", "redis://localhost:6379/0")
|
||||
elif langflow_redis_host and langflow_redis_port:
|
||||
broker_url = f"redis://{langflow_redis_host}:{langflow_redis_port}/0"
|
||||
result_backend = f"redis://{langflow_redis_host}:{langflow_redis_port}/0"
|
||||
# tasks should be json or pickle
|
||||
accept_content = ["json", "pickle"]
|
||||
|
|
@ -1,93 +0,0 @@
|
|||
from contextlib import contextmanager
|
||||
import os
|
||||
|
||||
from sqlmodel import SQLModel, Session, create_engine
|
||||
from langflow.utils.logger import logger
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
class Engine:
|
||||
_instance = None
|
||||
|
||||
@classmethod
|
||||
def get(cls):
|
||||
logger.debug("Getting database engine")
|
||||
if cls._instance is None:
|
||||
cls.create()
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def create(cls):
|
||||
logger.debug("Creating database engine")
|
||||
from langflow.settings import settings
|
||||
|
||||
if langflow_database_url := os.getenv("LANGFLOW_DATABASE_URL"):
|
||||
settings.DATABASE_URL = langflow_database_url
|
||||
logger.debug("Using LANGFLOW_DATABASE_URL")
|
||||
|
||||
if settings.DATABASE_URL and settings.DATABASE_URL.startswith("sqlite"):
|
||||
connect_args = {"check_same_thread": False}
|
||||
else:
|
||||
connect_args = {}
|
||||
if not settings.DATABASE_URL:
|
||||
raise RuntimeError("No database_url provided")
|
||||
cls._instance = create_engine(settings.DATABASE_URL, connect_args=connect_args)
|
||||
|
||||
@classmethod
|
||||
def update(cls):
|
||||
logger.debug("Updating database engine")
|
||||
cls._instance = None
|
||||
cls.create()
|
||||
|
||||
@classmethod
|
||||
def teardown(cls):
|
||||
logger.debug("Tearing down database engine")
|
||||
if cls._instance is not None:
|
||||
cls._instance.dispose()
|
||||
cls._instance = None
|
||||
|
||||
|
||||
def create_db_and_tables():
|
||||
logger.debug("Creating database and tables")
|
||||
try:
|
||||
SQLModel.metadata.create_all(Engine.get())
|
||||
except sa.exc.OperationalError as exc:
|
||||
# Check if the error is because the table already exists
|
||||
if "already exists" in str(exc):
|
||||
logger.debug("Database and tables already exist")
|
||||
else:
|
||||
logger.error(f"Error creating database and tables: {exc}")
|
||||
raise RuntimeError("Error creating database and tables") from exc
|
||||
except Exception as exc:
|
||||
logger.error(f"Error creating database and tables: {exc}")
|
||||
raise RuntimeError("Error creating database and tables") from exc
|
||||
# Now check if the table Flow exists, if not, something went wrong
|
||||
# and we need to create the tables again.
|
||||
from sqlalchemy import inspect
|
||||
|
||||
inspector = inspect(Engine.get())
|
||||
if "flow" not in inspector.get_table_names():
|
||||
logger.error("Something went wrong creating the database and tables.")
|
||||
logger.error("Please check your database settings.")
|
||||
|
||||
raise RuntimeError("Something went wrong creating the database and tables.")
|
||||
else:
|
||||
logger.debug("Database and tables created successfully")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def session_getter():
|
||||
try:
|
||||
session = Session(Engine.get())
|
||||
yield session
|
||||
except Exception as e:
|
||||
print("Session rollback because of exception:", e)
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
def get_session():
|
||||
with session_getter() as session:
|
||||
yield session
|
||||
|
|
@ -1,33 +0,0 @@
|
|||
# Path: src/backend/langflow/database/models/flowstyle.py
|
||||
|
||||
from langflow.database.models.base import SQLModelSerializable
|
||||
from sqlmodel import Field, Relationship
|
||||
from uuid import UUID, uuid4
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.database.models.flow import Flow
|
||||
|
||||
|
||||
class FlowStyleBase(SQLModelSerializable):
|
||||
color: str
|
||||
emoji: str
|
||||
flow_id: UUID = Field(default=None, foreign_key="flow.id")
|
||||
|
||||
|
||||
class FlowStyle(FlowStyleBase, table=True):
|
||||
id: UUID = Field(default_factory=uuid4, primary_key=True, unique=True)
|
||||
flow: "Flow" = Relationship(back_populates="style")
|
||||
|
||||
|
||||
class FlowStyleUpdate(SQLModelSerializable):
|
||||
color: Optional[str] = None
|
||||
emoji: Optional[str] = None
|
||||
|
||||
|
||||
class FlowStyleCreate(FlowStyleBase):
|
||||
pass
|
||||
|
||||
|
||||
class FlowStyleRead(FlowStyleBase):
|
||||
id: UUID
|
||||
53
src/backend/langflow/field_typing/__init__.py
Normal file
53
src/backend/langflow/field_typing/__init__.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
# LANGCHAIN_BASE_TYPES = {
|
||||
# "Chain": Chain,
|
||||
# "AgentExecutor": AgentExecutor,
|
||||
# "Tool": Tool,
|
||||
# "BaseLLM": BaseLLM,
|
||||
# "PromptTemplate": PromptTemplate,
|
||||
# "BaseLoader": BaseLoader,
|
||||
# "Document": Document,
|
||||
# "TextSplitter": TextSplitter,
|
||||
# "VectorStore": VectorStore,
|
||||
# "Embeddings": Embeddings,
|
||||
# "BaseRetriever": BaseRetriever,
|
||||
# "BaseOutputParser": BaseOutputParser,
|
||||
# "BaseMemory": BaseMemory,
|
||||
# "BaseChatMemory": BaseChatMemory,
|
||||
# }
|
||||
from .constants import (
|
||||
Tool,
|
||||
PromptTemplate,
|
||||
Chain,
|
||||
BaseChatMemory,
|
||||
BaseLLM,
|
||||
BaseLoader,
|
||||
BaseMemory,
|
||||
BaseOutputParser,
|
||||
BaseRetriever,
|
||||
VectorStore,
|
||||
Embeddings,
|
||||
TextSplitter,
|
||||
Document,
|
||||
AgentExecutor,
|
||||
NestedDict,
|
||||
Data,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"NestedDict",
|
||||
"Data",
|
||||
"Tool",
|
||||
"PromptTemplate",
|
||||
"Chain",
|
||||
"BaseChatMemory",
|
||||
"BaseLLM",
|
||||
"BaseLoader",
|
||||
"BaseMemory",
|
||||
"BaseOutputParser",
|
||||
"BaseRetriever",
|
||||
"VectorStore",
|
||||
"Embeddings",
|
||||
"TextSplitter",
|
||||
"Document",
|
||||
"AgentExecutor",
|
||||
]
|
||||
50
src/backend/langflow/field_typing/constants.py
Normal file
50
src/backend/langflow/field_typing/constants.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.schema import BaseOutputParser, BaseRetriever, Document
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.schema.memory import BaseMemory
|
||||
from langchain.text_splitter import TextSplitter
|
||||
from langchain.tools import Tool
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from typing import Union, Dict
|
||||
|
||||
# Type alias for more complex dicts
|
||||
NestedDict = Dict[str, Union[str, Dict]]
|
||||
|
||||
|
||||
class Data:
|
||||
pass
|
||||
|
||||
|
||||
LANGCHAIN_BASE_TYPES = {
|
||||
"Chain": Chain,
|
||||
"AgentExecutor": AgentExecutor,
|
||||
"Tool": Tool,
|
||||
"BaseLLM": BaseLLM,
|
||||
"PromptTemplate": PromptTemplate,
|
||||
"BaseLoader": BaseLoader,
|
||||
"Document": Document,
|
||||
"TextSplitter": TextSplitter,
|
||||
"VectorStore": VectorStore,
|
||||
"Embeddings": Embeddings,
|
||||
"BaseRetriever": BaseRetriever,
|
||||
"BaseOutputParser": BaseOutputParser,
|
||||
"BaseMemory": BaseMemory,
|
||||
"BaseChatMemory": BaseChatMemory,
|
||||
}
|
||||
# Langchain base types plus Python base types
|
||||
CUSTOM_COMPONENT_SUPPORTED_TYPES = {
|
||||
**LANGCHAIN_BASE_TYPES,
|
||||
"str": str,
|
||||
"int": int,
|
||||
"float": float,
|
||||
"bool": bool,
|
||||
"list": list,
|
||||
"dict": dict,
|
||||
"NestedDict": NestedDict,
|
||||
"Data": Data,
|
||||
}
|
||||
|
|
@ -1,22 +1,84 @@
|
|||
from langflow.utils.logger import logger
|
||||
from loguru import logger
|
||||
from typing import TYPE_CHECKING
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
|
||||
|
||||
class SourceHandle(BaseModel):
|
||||
baseClasses: List[str] = Field(
|
||||
..., description="List of base classes for the source handle."
|
||||
)
|
||||
dataType: str = Field(..., description="Data type for the source handle.")
|
||||
id: str = Field(..., description="Unique identifier for the source handle.")
|
||||
|
||||
|
||||
class TargetHandle(BaseModel):
|
||||
fieldName: str = Field(..., description="Field name for the target handle.")
|
||||
id: str = Field(..., description="Unique identifier for the target handle.")
|
||||
inputTypes: Optional[List[str]] = Field(
|
||||
None, description="List of input types for the target handle."
|
||||
)
|
||||
type: str = Field(..., description="Type of the target handle.")
|
||||
|
||||
|
||||
class Edge:
|
||||
def __init__(self, source: "Vertex", target: "Vertex", edge: dict):
|
||||
self.source: "Vertex" = source
|
||||
self.target: "Vertex" = target
|
||||
self.source_handle = edge.get("sourceHandle", "")
|
||||
self.target_handle = edge.get("targetHandle", "")
|
||||
# 'BaseLoader;BaseOutputParser|documents|PromptTemplate-zmTlD'
|
||||
# target_param is documents
|
||||
self.target_param = self.target_handle.split("|")[1]
|
||||
|
||||
if data := edge.get("data", {}):
|
||||
self._source_handle = data.get("sourceHandle", {})
|
||||
self._target_handle = data.get("targetHandle", {})
|
||||
self.source_handle: SourceHandle = SourceHandle(**self._source_handle)
|
||||
self.target_handle: TargetHandle = TargetHandle(**self._target_handle)
|
||||
self.target_param = self.target_handle.fieldName
|
||||
# validate handles
|
||||
self.validate_handles()
|
||||
else:
|
||||
# Logging here because this is a breaking change
|
||||
logger.error("Edge data is empty")
|
||||
self._source_handle = edge.get("sourceHandle", "")
|
||||
self._target_handle = edge.get("targetHandle", "")
|
||||
# 'BaseLoader;BaseOutputParser|documents|PromptTemplate-zmTlD'
|
||||
# target_param is documents
|
||||
self.target_param = self._target_handle.split("|")[1]
|
||||
# Validate in __init__ to fail fast
|
||||
self.validate_edge()
|
||||
|
||||
def validate_handles(self) -> None:
|
||||
if self.target_handle.inputTypes is None:
|
||||
self.valid_handles = (
|
||||
self.target_handle.type in self.source_handle.baseClasses
|
||||
)
|
||||
else:
|
||||
self.valid_handles = (
|
||||
any(
|
||||
baseClass in self.target_handle.inputTypes
|
||||
for baseClass in self.source_handle.baseClasses
|
||||
)
|
||||
or self.target_handle.type in self.source_handle.baseClasses
|
||||
)
|
||||
if not self.valid_handles:
|
||||
logger.debug(self.source_handle)
|
||||
logger.debug(self.target_handle)
|
||||
raise ValueError(
|
||||
f"Edge between {self.source.vertex_type} and {self.target.vertex_type} "
|
||||
f"has invalid handles"
|
||||
)
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.source = state["source"]
|
||||
self.target = state["target"]
|
||||
self.target_param = state["target_param"]
|
||||
self.source_handle = state.get("source_handle")
|
||||
self.target_handle = state.get("target_handle")
|
||||
|
||||
def reset(self) -> None:
|
||||
self.source._build_params()
|
||||
self.target._build_params()
|
||||
|
||||
def validate_edge(self) -> None:
|
||||
# Validate that the outputs of the source node are valid inputs
|
||||
# for the target node
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
from typing import Dict, Generator, List, Type, Union
|
||||
|
||||
from langflow.graph.edge.base import Edge
|
||||
from langflow.graph.graph.constants import VERTEX_TYPE_MAP
|
||||
from langflow.graph.graph.constants import lazy_load_vertex_dict
|
||||
from langflow.graph.graph.utils import process_flow
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.graph.vertex.types import (
|
||||
FileToolVertex,
|
||||
|
|
@ -10,7 +11,7 @@ from langflow.graph.vertex.types import (
|
|||
)
|
||||
from langflow.interface.tools.constants import FILE_TOOLS
|
||||
from langflow.utils import payload
|
||||
from langflow.utils.logger import logger
|
||||
from loguru import logger
|
||||
from langchain.chains.base import Chain
|
||||
|
||||
|
||||
|
|
@ -19,13 +20,29 @@ class Graph:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
nodes: List[Dict[str, Union[str, Dict[str, Union[str, List[str]]]]]],
|
||||
nodes: List[Dict],
|
||||
edges: List[Dict[str, str]],
|
||||
) -> None:
|
||||
self._nodes = nodes
|
||||
self._edges = edges
|
||||
self.raw_graph_data = {"nodes": nodes, "edges": edges}
|
||||
|
||||
self.top_level_nodes = []
|
||||
for node in self._nodes:
|
||||
if node_id := node.get("id"):
|
||||
self.top_level_nodes.append(node_id)
|
||||
|
||||
self._graph_data = process_flow(self.raw_graph_data)
|
||||
self._nodes = self._graph_data["nodes"]
|
||||
self._edges = self._graph_data["edges"]
|
||||
self._build_graph()
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__.update(state)
|
||||
for edge in self.edges:
|
||||
edge.reset()
|
||||
edge.validate_edge()
|
||||
|
||||
@classmethod
|
||||
def from_payload(cls, payload: Dict) -> "Graph":
|
||||
"""
|
||||
|
|
@ -44,10 +61,16 @@ class Graph:
|
|||
edges = payload["edges"]
|
||||
return cls(nodes, edges)
|
||||
except KeyError as exc:
|
||||
logger.exception(exc)
|
||||
raise ValueError(
|
||||
f"Invalid payload. Expected keys 'nodes' and 'edges'. Found {list(payload.keys())}"
|
||||
) from exc
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, Graph):
|
||||
return False
|
||||
return self.__repr__() == other.__repr__()
|
||||
|
||||
def _build_graph(self) -> None:
|
||||
"""Builds the graph from the nodes and edges."""
|
||||
self.nodes = self._build_vertices()
|
||||
|
|
@ -144,10 +167,10 @@ class Graph:
|
|||
|
||||
return list(reversed(sorted_vertices))
|
||||
|
||||
def generator_build(self) -> Generator:
|
||||
def generator_build(self) -> Generator[Vertex, None, None]:
|
||||
"""Builds each vertex in the graph and yields it."""
|
||||
sorted_vertices = self.topological_sort()
|
||||
logger.debug("Sorted vertices: %s", sorted_vertices)
|
||||
logger.debug("There are %s vertices in the graph", len(sorted_vertices))
|
||||
yield from sorted_vertices
|
||||
|
||||
def get_node_neighbors(self, node: Vertex) -> Dict[Vertex, int]:
|
||||
|
|
@ -187,10 +210,12 @@ class Graph:
|
|||
"""Returns the node class based on the node type."""
|
||||
if node_type in FILE_TOOLS:
|
||||
return FileToolVertex
|
||||
if node_type in VERTEX_TYPE_MAP:
|
||||
return VERTEX_TYPE_MAP[node_type]
|
||||
if node_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP:
|
||||
return lazy_load_vertex_dict.VERTEX_TYPE_MAP[node_type]
|
||||
return (
|
||||
VERTEX_TYPE_MAP[node_lc_type] if node_lc_type in VERTEX_TYPE_MAP else Vertex
|
||||
lazy_load_vertex_dict.VERTEX_TYPE_MAP[node_lc_type]
|
||||
if node_lc_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP
|
||||
else Vertex
|
||||
)
|
||||
|
||||
def _build_vertices(self) -> List[Vertex]:
|
||||
|
|
@ -202,7 +227,9 @@ class Graph:
|
|||
node_lc_type: str = node_data["node"]["template"]["_type"] # type: ignore
|
||||
|
||||
VertexClass = self._get_vertex_class(node_type, node_lc_type)
|
||||
nodes.append(VertexClass(node))
|
||||
vertex = VertexClass(node)
|
||||
vertex.set_top_level(self.top_level_nodes)
|
||||
nodes.append(vertex)
|
||||
|
||||
return nodes
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.graph.vertex import types
|
||||
from langflow.interface.agents.base import agent_creator
|
||||
from langflow.interface.chains.base import chain_creator
|
||||
|
|
@ -15,23 +14,45 @@ from langflow.interface.wrappers.base import wrapper_creator
|
|||
from langflow.interface.output_parsers.base import output_parser_creator
|
||||
from langflow.interface.retrievers.base import retriever_creator
|
||||
from langflow.interface.custom.base import custom_component_creator
|
||||
from typing import Dict, Type
|
||||
from langflow.utils.lazy_load import LazyLoadDictBase
|
||||
|
||||
|
||||
VERTEX_TYPE_MAP: Dict[str, Type[Vertex]] = {
|
||||
**{t: types.PromptVertex for t in prompt_creator.to_list()},
|
||||
**{t: types.AgentVertex for t in agent_creator.to_list()},
|
||||
**{t: types.ChainVertex for t in chain_creator.to_list()},
|
||||
**{t: types.ToolVertex for t in tool_creator.to_list()},
|
||||
**{t: types.ToolkitVertex for t in toolkits_creator.to_list()},
|
||||
**{t: types.WrapperVertex for t in wrapper_creator.to_list()},
|
||||
**{t: types.LLMVertex for t in llm_creator.to_list()},
|
||||
**{t: types.MemoryVertex for t in memory_creator.to_list()},
|
||||
**{t: types.EmbeddingVertex for t in embedding_creator.to_list()},
|
||||
**{t: types.VectorStoreVertex for t in vectorstore_creator.to_list()},
|
||||
**{t: types.DocumentLoaderVertex for t in documentloader_creator.to_list()},
|
||||
**{t: types.TextSplitterVertex for t in textsplitter_creator.to_list()},
|
||||
**{t: types.OutputParserVertex for t in output_parser_creator.to_list()},
|
||||
**{t: types.CustomComponentVertex for t in custom_component_creator.to_list()},
|
||||
**{t: types.RetrieverVertex for t in retriever_creator.to_list()},
|
||||
}
|
||||
class VertexTypesDict(LazyLoadDictBase):
|
||||
def __init__(self):
|
||||
self._all_types_dict = None
|
||||
|
||||
@property
|
||||
def VERTEX_TYPE_MAP(self):
|
||||
return self.all_types_dict
|
||||
|
||||
def _build_dict(self):
|
||||
langchain_types_dict = self.get_type_dict()
|
||||
return {
|
||||
**langchain_types_dict,
|
||||
"Custom": ["Custom Tool", "Python Function"],
|
||||
}
|
||||
|
||||
def get_type_dict(self):
|
||||
return {
|
||||
**{t: types.PromptVertex for t in prompt_creator.to_list()},
|
||||
**{t: types.AgentVertex for t in agent_creator.to_list()},
|
||||
**{t: types.ChainVertex for t in chain_creator.to_list()},
|
||||
**{t: types.ToolVertex for t in tool_creator.to_list()},
|
||||
**{t: types.ToolkitVertex for t in toolkits_creator.to_list()},
|
||||
**{t: types.WrapperVertex for t in wrapper_creator.to_list()},
|
||||
**{t: types.LLMVertex for t in llm_creator.to_list()},
|
||||
**{t: types.MemoryVertex for t in memory_creator.to_list()},
|
||||
**{t: types.EmbeddingVertex for t in embedding_creator.to_list()},
|
||||
**{t: types.VectorStoreVertex for t in vectorstore_creator.to_list()},
|
||||
**{t: types.DocumentLoaderVertex for t in documentloader_creator.to_list()},
|
||||
**{t: types.TextSplitterVertex for t in textsplitter_creator.to_list()},
|
||||
**{t: types.OutputParserVertex for t in output_parser_creator.to_list()},
|
||||
**{
|
||||
t: types.CustomComponentVertex
|
||||
for t in custom_component_creator.to_list()
|
||||
},
|
||||
**{t: types.RetrieverVertex for t in retriever_creator.to_list()},
|
||||
}
|
||||
|
||||
|
||||
lazy_load_vertex_dict = VertexTypesDict()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,230 @@
|
|||
from collections import deque
|
||||
import copy
|
||||
|
||||
|
||||
def find_last_node(nodes, edges):
|
||||
"""
|
||||
This function receives a flow and returns the last node.
|
||||
"""
|
||||
return next((n for n in nodes if all(e["source"] != n["id"] for e in edges)), None)
|
||||
|
||||
|
||||
def add_parent_node_id(nodes, parent_node_id):
|
||||
"""
|
||||
This function receives a list of nodes and adds a parent_node_id to each node.
|
||||
"""
|
||||
for node in nodes:
|
||||
node["parent_node_id"] = parent_node_id
|
||||
|
||||
|
||||
def ungroup_node(group_node_data, base_flow):
|
||||
template, flow = (
|
||||
group_node_data["node"]["template"],
|
||||
group_node_data["node"]["flow"],
|
||||
)
|
||||
parent_node_id = group_node_data["id"]
|
||||
g_nodes = flow["data"]["nodes"]
|
||||
add_parent_node_id(g_nodes, parent_node_id)
|
||||
g_edges = flow["data"]["edges"]
|
||||
|
||||
# Redirect edges to the correct proxy node
|
||||
updated_edges = get_updated_edges(
|
||||
base_flow, g_nodes, g_edges, group_node_data["id"]
|
||||
)
|
||||
|
||||
# Update template values
|
||||
update_template(template, g_nodes)
|
||||
|
||||
nodes = [
|
||||
n for n in base_flow["nodes"] if n["id"] != group_node_data["id"]
|
||||
] + g_nodes
|
||||
edges = (
|
||||
[
|
||||
e
|
||||
for e in base_flow["edges"]
|
||||
if e["target"] != group_node_data["id"]
|
||||
and e["source"] != group_node_data["id"]
|
||||
]
|
||||
+ g_edges
|
||||
+ updated_edges
|
||||
)
|
||||
|
||||
base_flow["nodes"] = nodes
|
||||
base_flow["edges"] = edges
|
||||
|
||||
return nodes
|
||||
|
||||
|
||||
def process_flow(flow_object):
|
||||
cloned_flow = copy.deepcopy(flow_object)
|
||||
processed_nodes = set() # To keep track of processed nodes
|
||||
|
||||
def process_node(node):
|
||||
node_id = node.get("id")
|
||||
|
||||
# If node already processed, skip
|
||||
if node_id in processed_nodes:
|
||||
return
|
||||
|
||||
if (
|
||||
node.get("data")
|
||||
and node["data"].get("node")
|
||||
and node["data"]["node"].get("flow")
|
||||
):
|
||||
process_flow(node["data"]["node"]["flow"]["data"])
|
||||
new_nodes = ungroup_node(node["data"], cloned_flow)
|
||||
# Add new nodes to the queue for future processing
|
||||
nodes_to_process.extend(new_nodes)
|
||||
|
||||
# Mark node as processed
|
||||
processed_nodes.add(node_id)
|
||||
|
||||
nodes_to_process = deque(cloned_flow["nodes"])
|
||||
|
||||
while nodes_to_process:
|
||||
node = nodes_to_process.popleft()
|
||||
process_node(node)
|
||||
|
||||
return cloned_flow
|
||||
|
||||
|
||||
def update_template(template, g_nodes):
|
||||
"""
|
||||
Updates the template of a node in a graph with the given template.
|
||||
|
||||
Args:
|
||||
template (dict): The new template to update the node with.
|
||||
g_nodes (list): The list of nodes in the graph.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for _, value in template.items():
|
||||
if not value.get("proxy"):
|
||||
continue
|
||||
proxy_dict = value["proxy"]
|
||||
field, id_ = proxy_dict["field"], proxy_dict["id"]
|
||||
node_index = next((i for i, n in enumerate(g_nodes) if n["id"] == id_), -1)
|
||||
if node_index != -1:
|
||||
display_name = None
|
||||
show = g_nodes[node_index]["data"]["node"]["template"][field]["show"]
|
||||
advanced = g_nodes[node_index]["data"]["node"]["template"][field][
|
||||
"advanced"
|
||||
]
|
||||
if "display_name" in g_nodes[node_index]["data"]["node"]["template"][field]:
|
||||
display_name = g_nodes[node_index]["data"]["node"]["template"][field][
|
||||
"display_name"
|
||||
]
|
||||
else:
|
||||
display_name = g_nodes[node_index]["data"]["node"]["template"][field][
|
||||
"name"
|
||||
]
|
||||
|
||||
g_nodes[node_index]["data"]["node"]["template"][field] = value
|
||||
g_nodes[node_index]["data"]["node"]["template"][field]["show"] = show
|
||||
g_nodes[node_index]["data"]["node"]["template"][field][
|
||||
"advanced"
|
||||
] = advanced
|
||||
g_nodes[node_index]["data"]["node"]["template"][field][
|
||||
"display_name"
|
||||
] = display_name
|
||||
|
||||
|
||||
def update_target_handle(new_edge, g_nodes, group_node_id):
|
||||
"""
|
||||
Updates the target handle of a given edge if it is a proxy node.
|
||||
|
||||
Args:
|
||||
new_edge (dict): The edge to update.
|
||||
g_nodes (list): The list of nodes in the graph.
|
||||
group_node_id (str): The ID of the group node.
|
||||
|
||||
Returns:
|
||||
dict: The updated edge.
|
||||
"""
|
||||
target_handle = new_edge["data"]["targetHandle"]
|
||||
if target_handle.get("proxy"):
|
||||
proxy_id = target_handle["proxy"]["id"]
|
||||
if node := next((n for n in g_nodes if n["id"] == proxy_id), None):
|
||||
set_new_target_handle(proxy_id, new_edge, target_handle, node)
|
||||
return new_edge
|
||||
|
||||
|
||||
def set_new_target_handle(proxy_id, new_edge, target_handle, node):
|
||||
"""
|
||||
Sets a new target handle for a given edge.
|
||||
|
||||
Args:
|
||||
proxy_id (str): The ID of the proxy.
|
||||
new_edge (dict): The new edge to be created.
|
||||
target_handle (dict): The target handle of the edge.
|
||||
node (dict): The node containing the edge.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
new_edge["target"] = proxy_id
|
||||
_type = target_handle.get("type")
|
||||
if _type is None:
|
||||
raise KeyError("The 'type' key must be present in target_handle.")
|
||||
|
||||
field = target_handle["proxy"]["field"]
|
||||
new_target_handle = {
|
||||
"fieldName": field,
|
||||
"type": _type,
|
||||
"id": proxy_id,
|
||||
}
|
||||
if node["data"]["node"].get("flow"):
|
||||
new_target_handle["proxy"] = {
|
||||
"field": node["data"]["node"]["template"][field]["proxy"]["field"],
|
||||
"id": node["data"]["node"]["template"][field]["proxy"]["id"],
|
||||
}
|
||||
if input_types := target_handle.get("inputTypes"):
|
||||
new_target_handle["inputTypes"] = input_types
|
||||
new_edge["data"]["targetHandle"] = new_target_handle
|
||||
|
||||
|
||||
def update_source_handle(new_edge, g_nodes, g_edges):
|
||||
"""
|
||||
Updates the source handle of a given edge to the last node in the flow data.
|
||||
|
||||
Args:
|
||||
new_edge (dict): The edge to update.
|
||||
flow_data (dict): The flow data containing the nodes and edges.
|
||||
|
||||
Returns:
|
||||
dict: The updated edge with the new source handle.
|
||||
"""
|
||||
last_node = copy.deepcopy(find_last_node(g_nodes, g_edges))
|
||||
new_edge["source"] = last_node["id"]
|
||||
new_source_handle = new_edge["data"]["sourceHandle"]
|
||||
new_source_handle["id"] = last_node["id"]
|
||||
new_edge["data"]["sourceHandle"] = new_source_handle
|
||||
return new_edge
|
||||
|
||||
|
||||
def get_updated_edges(base_flow, g_nodes, g_edges, group_node_id):
|
||||
"""
|
||||
Given a base flow, a list of graph nodes and a group node id, returns a list of updated edges.
|
||||
An updated edge is an edge that has its target or source handle updated based on the group node id.
|
||||
|
||||
Args:
|
||||
base_flow (dict): The base flow containing a list of edges.
|
||||
g_nodes (list): A list of graph nodes.
|
||||
group_node_id (str): The id of the group node.
|
||||
|
||||
Returns:
|
||||
list: A list of updated edges.
|
||||
"""
|
||||
updated_edges = []
|
||||
for edge in base_flow["edges"]:
|
||||
new_edge = copy.deepcopy(edge)
|
||||
if new_edge["target"] == group_node_id:
|
||||
new_edge = update_target_handle(new_edge, g_nodes, group_node_id)
|
||||
|
||||
if new_edge["source"] == group_node_id:
|
||||
new_edge = update_source_handle(new_edge, g_nodes, g_edges)
|
||||
|
||||
if edge["target"] == group_node_id or edge["source"] == group_node_id:
|
||||
updated_edges.append(new_edge)
|
||||
return updated_edges
|
||||
|
|
@ -1,9 +1,11 @@
|
|||
import ast
|
||||
import pickle
|
||||
from langflow.graph.utils import UnbuiltObject
|
||||
from langflow.graph.vertex.utils import is_basic_type
|
||||
from langflow.interface.initialize import loading
|
||||
from langflow.interface.listing import ALL_TYPES_DICT
|
||||
from langflow.interface.listing import lazy_load_dict
|
||||
from langflow.utils.constants import DIRECT_TYPES
|
||||
from langflow.utils.logger import logger
|
||||
from loguru import logger
|
||||
from langflow.utils.util import sync_to_async
|
||||
|
||||
|
||||
|
|
@ -12,12 +14,19 @@ import types
|
|||
from typing import Any, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.graph.edge.base import Edge
|
||||
|
||||
|
||||
class Vertex:
|
||||
def __init__(self, data: Dict, base_type: Optional[str] = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
data: Dict,
|
||||
base_type: Optional[str] = None,
|
||||
is_task: bool = False,
|
||||
params: Optional[Dict] = None,
|
||||
) -> None:
|
||||
self.id: str = data["id"]
|
||||
self._data = data
|
||||
self.edges: List["Edge"] = []
|
||||
|
|
@ -26,6 +35,66 @@ class Vertex:
|
|||
self._built_object = UnbuiltObject()
|
||||
self._built = False
|
||||
self.artifacts: Dict[str, Any] = {}
|
||||
self.task_id: Optional[str] = None
|
||||
self.is_task = is_task
|
||||
self.params = params or {}
|
||||
self.parent_node_id: Optional[str] = self._data.get("parent_node_id")
|
||||
self.parent_is_top_level = False
|
||||
|
||||
def reset_params(self):
|
||||
for edge in self.edges:
|
||||
if edge.source != self:
|
||||
target_param = edge.target_param
|
||||
if target_param in ["document", "texts"]:
|
||||
# this means they got data and have already ingested it
|
||||
# so we continue after removing the param
|
||||
self.params.pop(target_param, None)
|
||||
continue
|
||||
|
||||
if target_param in self.params and not is_basic_type(
|
||||
self.params[target_param]
|
||||
):
|
||||
# edge.source.params = {}
|
||||
edge.source._build_params()
|
||||
edge.source._built_object = UnbuiltObject()
|
||||
edge.source._built = False
|
||||
|
||||
self.params[target_param] = edge.source
|
||||
|
||||
def __getstate__(self):
|
||||
state_dict = self.__dict__.copy()
|
||||
try:
|
||||
# try pickling the built object
|
||||
# if it fails, then we need to delete it
|
||||
# and build it again
|
||||
pickle.dumps(state_dict["_built_object"])
|
||||
except Exception:
|
||||
self.reset_params()
|
||||
del state_dict["_built_object"]
|
||||
del state_dict["_built"]
|
||||
return state_dict
|
||||
|
||||
def __setstate__(self, state):
|
||||
self._data = state["_data"]
|
||||
self.params = state["params"]
|
||||
self.base_type = state["base_type"]
|
||||
self.is_task = state["is_task"]
|
||||
self.edges = state["edges"]
|
||||
self.id = state["id"]
|
||||
self._parse_data()
|
||||
if "_built_object" in state:
|
||||
self._built_object = state["_built_object"]
|
||||
self._built = state["_built"]
|
||||
else:
|
||||
self._built_object = UnbuiltObject()
|
||||
self._built = False
|
||||
self.artifacts: Dict[str, Any] = {}
|
||||
self.task_id: Optional[str] = None
|
||||
self.parent_node_id = state["parent_node_id"]
|
||||
self.parent_is_top_level = state["parent_is_top_level"]
|
||||
|
||||
def set_top_level(self, top_level_nodes: List[str]) -> None:
|
||||
self.parent_is_top_level = self.parent_node_id in top_level_nodes
|
||||
|
||||
def _parse_data(self) -> None:
|
||||
self.data = self._data["data"]
|
||||
|
|
@ -63,11 +132,18 @@ class Vertex:
|
|||
)
|
||||
|
||||
if self.base_type is None:
|
||||
for base_type, value in ALL_TYPES_DICT.items():
|
||||
for base_type, value in lazy_load_dict.ALL_TYPES_DICT.items():
|
||||
if self.vertex_type in value:
|
||||
self.base_type = base_type
|
||||
break
|
||||
|
||||
def get_task(self):
|
||||
# using the task_id, get the task from celery
|
||||
# and return it
|
||||
from celery.result import AsyncResult # type: ignore
|
||||
|
||||
return AsyncResult(self.task_id)
|
||||
|
||||
def _build_params(self):
|
||||
# sourcery skip: merge-list-append, remove-redundant-if
|
||||
# Some params are required, some are optional
|
||||
|
|
@ -89,9 +165,11 @@ class Vertex:
|
|||
for key, value in self.data["node"]["template"].items()
|
||||
if isinstance(value, dict)
|
||||
}
|
||||
params = {}
|
||||
params = self.params.copy() if self.params else {}
|
||||
|
||||
for edge in self.edges:
|
||||
if not hasattr(edge, "target_param"):
|
||||
continue
|
||||
param_key = edge.target_param
|
||||
if param_key in template_dict:
|
||||
if template_dict[param_key]["list"]:
|
||||
|
|
@ -102,6 +180,8 @@ class Vertex:
|
|||
params[param_key] = edge.source
|
||||
|
||||
for key, value in template_dict.items():
|
||||
if key in params:
|
||||
continue
|
||||
# Skip _type and any value that has show == False and is not code
|
||||
# If we don't want to show code but we want to use it
|
||||
if key == "_type" or (not value.get("show") and key != "code"):
|
||||
|
|
@ -112,9 +192,10 @@ class Vertex:
|
|||
# Load the type in value.get('suffixes') using
|
||||
# what is inside value.get('content')
|
||||
# value.get('value') is the file name
|
||||
file_path = value.get("file_path")
|
||||
|
||||
params[key] = file_path
|
||||
if file_path := value.get("file_path"):
|
||||
params[key] = file_path
|
||||
else:
|
||||
raise ValueError(f"File path not found for {self.vertex_type}")
|
||||
elif value.get("type") in DIRECT_TYPES and params.get(key) is None:
|
||||
if value.get("type") == "code":
|
||||
try:
|
||||
|
|
@ -122,6 +203,29 @@ class Vertex:
|
|||
except Exception as exc:
|
||||
logger.debug(f"Error parsing code: {exc}")
|
||||
params[key] = value.get("value")
|
||||
elif value.get("type") in ["dict", "NestedDict"]:
|
||||
# When dict comes from the frontend it comes as a
|
||||
# list of dicts, so we need to convert it to a dict
|
||||
# before passing it to the build method
|
||||
_value = value.get("value")
|
||||
if isinstance(_value, list):
|
||||
params[key] = {
|
||||
k: v
|
||||
for item in value.get("value", [])
|
||||
for k, v in item.items()
|
||||
}
|
||||
elif isinstance(_value, dict):
|
||||
params[key] = _value
|
||||
elif value.get("type") == "int" and value.get("value") is not None:
|
||||
try:
|
||||
params[key] = int(value.get("value"))
|
||||
except ValueError:
|
||||
params[key] = value.get("value")
|
||||
elif value.get("type") == "float" and value.get("value") is not None:
|
||||
try:
|
||||
params[key] = float(value.get("value"))
|
||||
except ValueError:
|
||||
params[key] = value.get("value")
|
||||
else:
|
||||
params[key] = value.get("value")
|
||||
|
||||
|
|
@ -131,20 +235,21 @@ class Vertex:
|
|||
else:
|
||||
params.pop(key, None)
|
||||
# Add _type to params
|
||||
self._raw_params = params
|
||||
self.params = params
|
||||
|
||||
def _build(self):
|
||||
def _build(self, user_id=None):
|
||||
"""
|
||||
Initiate the build process.
|
||||
"""
|
||||
logger.debug(f"Building {self.vertex_type}")
|
||||
self._build_each_node_in_params_dict()
|
||||
self._get_and_instantiate_class()
|
||||
self._build_each_node_in_params_dict(user_id)
|
||||
self._get_and_instantiate_class(user_id)
|
||||
self._validate_built_object()
|
||||
|
||||
self._built = True
|
||||
|
||||
def _build_each_node_in_params_dict(self):
|
||||
def _build_each_node_in_params_dict(self, user_id=None):
|
||||
"""
|
||||
Iterates over each node in the params dictionary and builds it.
|
||||
"""
|
||||
|
|
@ -153,9 +258,9 @@ class Vertex:
|
|||
if value == self:
|
||||
del self.params[key]
|
||||
continue
|
||||
self._build_node_and_update_params(key, value)
|
||||
self._build_node_and_update_params(key, value, user_id)
|
||||
elif isinstance(value, list) and self._is_list_of_nodes(value):
|
||||
self._build_list_of_nodes_and_update_params(key, value)
|
||||
self._build_list_of_nodes_and_update_params(key, value, user_id)
|
||||
|
||||
def _is_node(self, value):
|
||||
"""
|
||||
|
|
@ -169,23 +274,45 @@ class Vertex:
|
|||
"""
|
||||
return all(self._is_node(node) for node in value)
|
||||
|
||||
def _build_node_and_update_params(self, key, node):
|
||||
def get_result(self, user_id=None, timeout=None) -> Any:
|
||||
# Check if the Vertex was built already
|
||||
if self._built:
|
||||
return self._built_object
|
||||
|
||||
if self.is_task and self.task_id is not None:
|
||||
task = self.get_task()
|
||||
result = task.get(timeout=timeout)
|
||||
if result is not None: # If result is ready
|
||||
self._update_built_object_and_artifacts(result)
|
||||
return self._built_object
|
||||
else:
|
||||
# Handle the case when the result is not ready (retry, throw exception, etc.)
|
||||
pass
|
||||
|
||||
# If there's no task_id, build the vertex locally
|
||||
self.build(user_id)
|
||||
return self._built_object
|
||||
|
||||
def _build_node_and_update_params(self, key, node, user_id=None):
|
||||
"""
|
||||
Builds a given node and updates the params dictionary accordingly.
|
||||
"""
|
||||
result = node.build()
|
||||
|
||||
result = node.get_result(user_id)
|
||||
self._handle_func(key, result)
|
||||
if isinstance(result, list):
|
||||
self._extend_params_list_with_result(key, result)
|
||||
self.params[key] = result
|
||||
|
||||
def _build_list_of_nodes_and_update_params(self, key, nodes):
|
||||
def _build_list_of_nodes_and_update_params(
|
||||
self, key, nodes: List["Vertex"], user_id=None
|
||||
):
|
||||
"""
|
||||
Iterates over a list of nodes, builds each and updates the params dictionary.
|
||||
"""
|
||||
self.params[key] = []
|
||||
for node in nodes:
|
||||
built = node.build()
|
||||
built = node.get_result(user_id)
|
||||
if isinstance(built, list):
|
||||
if key not in self.params:
|
||||
self.params[key] = []
|
||||
|
|
@ -215,7 +342,7 @@ class Vertex:
|
|||
if isinstance(self.params[key], list):
|
||||
self.params[key].extend(result)
|
||||
|
||||
def _get_and_instantiate_class(self):
|
||||
def _get_and_instantiate_class(self, user_id=None):
|
||||
"""
|
||||
Gets the class from a dictionary and instantiates it with the params.
|
||||
"""
|
||||
|
|
@ -226,11 +353,13 @@ class Vertex:
|
|||
node_type=self.vertex_type,
|
||||
base_type=self.base_type,
|
||||
params=self.params,
|
||||
user_id=user_id,
|
||||
)
|
||||
self._update_built_object_and_artifacts(result)
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
raise ValueError(
|
||||
f"Error building node {self.vertex_type}: {str(exc)}"
|
||||
f"Error building node {self.vertex_type}(ID:{self.id}): {str(exc)}"
|
||||
) from exc
|
||||
|
||||
def _update_built_object_and_artifacts(self, result):
|
||||
|
|
@ -255,9 +384,9 @@ class Vertex:
|
|||
|
||||
raise ValueError(message)
|
||||
|
||||
def build(self, force: bool = False) -> Any:
|
||||
def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
|
||||
if not self._built or force:
|
||||
self._build()
|
||||
self._build(user_id, *args, **kwargs)
|
||||
|
||||
return self._built_object
|
||||
|
||||
|
|
@ -269,7 +398,10 @@ class Vertex:
|
|||
return f"Vertex(id={self.id}, data={self.data})"
|
||||
|
||||
def __eq__(self, __o: object) -> bool:
|
||||
return self.id == __o.id if isinstance(__o, Vertex) else False
|
||||
try:
|
||||
return self.id == __o.id if isinstance(__o, Vertex) else False
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return id(self)
|
||||
|
|
|
|||
|
|
@ -7,55 +7,68 @@ from langflow.interface.utils import extract_input_variables_from_prompt
|
|||
|
||||
|
||||
class AgentVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="agents")
|
||||
def __init__(self, data: Dict, params: Optional[Dict] = None):
|
||||
super().__init__(data, base_type="agents", params=params)
|
||||
|
||||
self.tools: List[Union[ToolkitVertex, ToolVertex]] = []
|
||||
self.chains: List[ChainVertex] = []
|
||||
|
||||
def __getstate__(self):
|
||||
state = super().__getstate__()
|
||||
state["tools"] = self.tools
|
||||
state["chains"] = self.chains
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.tools = state["tools"]
|
||||
self.chains = state["chains"]
|
||||
super().__setstate__(state)
|
||||
|
||||
def _set_tools_and_chains(self) -> None:
|
||||
for edge in self.edges:
|
||||
if not hasattr(edge, "source"):
|
||||
continue
|
||||
source_node = edge.source
|
||||
if isinstance(source_node, (ToolVertex, ToolkitVertex)):
|
||||
self.tools.append(source_node)
|
||||
elif isinstance(source_node, ChainVertex):
|
||||
self.chains.append(source_node)
|
||||
|
||||
def build(self, force: bool = False) -> Any:
|
||||
def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
|
||||
if not self._built or force:
|
||||
self._set_tools_and_chains()
|
||||
# First, build the tools
|
||||
for tool_node in self.tools:
|
||||
tool_node.build()
|
||||
tool_node.build(user_id=user_id)
|
||||
|
||||
# Next, build the chains and the rest
|
||||
for chain_node in self.chains:
|
||||
chain_node.build(tools=self.tools)
|
||||
chain_node.build(tools=self.tools, user_id=user_id)
|
||||
|
||||
self._build()
|
||||
self._build(user_id=user_id)
|
||||
|
||||
return self._built_object
|
||||
|
||||
|
||||
class ToolVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="tools")
|
||||
def __init__(self, data: Dict, params: Optional[Dict] = None):
|
||||
super().__init__(data, base_type="tools", params=params)
|
||||
|
||||
|
||||
class LLMVertex(Vertex):
|
||||
built_node_type = None
|
||||
class_built_object = None
|
||||
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="llms")
|
||||
def __init__(self, data: Dict, params: Optional[Dict] = None):
|
||||
super().__init__(data, base_type="llms", params=params)
|
||||
|
||||
def build(self, force: bool = False) -> Any:
|
||||
def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
|
||||
# LLM is different because some models might take up too much memory
|
||||
# or time to load. So we only load them when we need them.ß
|
||||
if self.vertex_type == self.built_node_type:
|
||||
return self.class_built_object
|
||||
if not self._built or force:
|
||||
self._build()
|
||||
self._build(user_id=user_id)
|
||||
self.built_node_type = self.vertex_type
|
||||
self.class_built_object = self._built_object
|
||||
# Avoid deepcopying the LLM
|
||||
|
|
@ -64,39 +77,41 @@ class LLMVertex(Vertex):
|
|||
|
||||
|
||||
class ToolkitVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="toolkits")
|
||||
def __init__(self, data: Dict, params=None):
|
||||
super().__init__(data, base_type="toolkits", params=params)
|
||||
|
||||
|
||||
class FileToolVertex(ToolVertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data)
|
||||
def __init__(self, data: Dict, params=None):
|
||||
super().__init__(data, params=params)
|
||||
|
||||
|
||||
class WrapperVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="wrappers")
|
||||
|
||||
def build(self, force: bool = False) -> Any:
|
||||
def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
|
||||
if not self._built or force:
|
||||
if "headers" in self.params:
|
||||
self.params["headers"] = ast.literal_eval(self.params["headers"])
|
||||
self._build()
|
||||
self._build(user_id=user_id)
|
||||
return self._built_object
|
||||
|
||||
|
||||
class DocumentLoaderVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="documentloaders")
|
||||
def __init__(self, data: Dict, params: Optional[Dict] = None):
|
||||
super().__init__(data, base_type="documentloaders", params=params)
|
||||
|
||||
def _built_object_repr(self):
|
||||
# This built_object is a list of documents. Maybe we should
|
||||
# show how many documents are in the list?
|
||||
|
||||
if self._built_object:
|
||||
avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(
|
||||
self._built_object
|
||||
)
|
||||
avg_length = sum(
|
||||
len(doc.page_content)
|
||||
for doc in self._built_object
|
||||
if hasattr(doc, "page_content")
|
||||
) / len(self._built_object)
|
||||
return f"""{self.vertex_type}({len(self._built_object)} documents)
|
||||
\nAvg. Document Length (characters): {int(avg_length)}
|
||||
Documents: {self._built_object[:3]}..."""
|
||||
|
|
@ -104,14 +119,51 @@ class DocumentLoaderVertex(Vertex):
|
|||
|
||||
|
||||
class EmbeddingVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="embeddings")
|
||||
def __init__(self, data: Dict, params: Optional[Dict] = None):
|
||||
super().__init__(data, base_type="embeddings", params=params)
|
||||
|
||||
|
||||
class VectorStoreVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
def __init__(self, data: Dict, params=None):
|
||||
super().__init__(data, base_type="vectorstores")
|
||||
|
||||
self.params = params or {}
|
||||
|
||||
# VectorStores may contain databse connections
|
||||
# so we need to define the __reduce__ method and the __setstate__ method
|
||||
# to avoid pickling errors
|
||||
def clean_edges_for_pickling(self):
|
||||
# for each edge that has self as source
|
||||
# we need to clear the _built_object of the target
|
||||
# so that we don't try to pickle a database connection
|
||||
for edge in self.edges:
|
||||
if edge.source == self:
|
||||
edge.target._built_object = None
|
||||
edge.target._built = False
|
||||
edge.target.params[edge.target_param] = self
|
||||
|
||||
def remove_docs_and_texts_from_params(self):
|
||||
# remove documents and texts from params
|
||||
# so that we don't try to pickle a database connection
|
||||
self.params.pop("documents", None)
|
||||
self.params.pop("texts", None)
|
||||
|
||||
def __getstate__(self):
|
||||
# We want to save the params attribute
|
||||
# and if "documents" or "texts" are in the params
|
||||
# we want to remove them because they have already
|
||||
# been processed.
|
||||
params = self.params.copy()
|
||||
params.pop("documents", None)
|
||||
params.pop("texts", None)
|
||||
self.clean_edges_for_pickling()
|
||||
|
||||
return super().__getstate__()
|
||||
|
||||
def __setstate__(self, state):
|
||||
super().__setstate__(state)
|
||||
self.remove_docs_and_texts_from_params()
|
||||
|
||||
|
||||
class MemoryVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
|
|
@ -124,8 +176,8 @@ class RetrieverVertex(Vertex):
|
|||
|
||||
|
||||
class TextSplitterVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="textsplitters")
|
||||
def __init__(self, data: Dict, params: Optional[Dict] = None):
|
||||
super().__init__(data, base_type="textsplitters", params=params)
|
||||
|
||||
def _built_object_repr(self):
|
||||
# This built_object is a list of documents. Maybe we should
|
||||
|
|
@ -148,16 +200,19 @@ class ChainVertex(Vertex):
|
|||
def build(
|
||||
self,
|
||||
force: bool = False,
|
||||
tools: Optional[List[Union[ToolkitVertex, ToolVertex]]] = None,
|
||||
user_id=None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
if not self._built or force:
|
||||
# Check if the chain requires a PromptVertex
|
||||
for key, value in self.params.items():
|
||||
if isinstance(value, PromptVertex):
|
||||
# Build the PromptVertex, passing the tools if available
|
||||
tools = kwargs.get("tools", None)
|
||||
self.params[key] = value.build(tools=tools, force=force)
|
||||
|
||||
self._build()
|
||||
self._build(user_id=user_id)
|
||||
|
||||
return self._built_object
|
||||
|
||||
|
|
@ -169,7 +224,10 @@ class PromptVertex(Vertex):
|
|||
def build(
|
||||
self,
|
||||
force: bool = False,
|
||||
user_id=None,
|
||||
tools: Optional[List[Union[ToolkitVertex, ToolVertex]]] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
if not self._built or force:
|
||||
if (
|
||||
|
|
@ -180,7 +238,7 @@ class PromptVertex(Vertex):
|
|||
# Check if it is a ZeroShotPrompt and needs a tool
|
||||
if "ShotPrompt" in self.vertex_type:
|
||||
tools = (
|
||||
[tool_node.build() for tool_node in tools]
|
||||
[tool_node.build(user_id=user_id) for tool_node in tools]
|
||||
if tools is not None
|
||||
else []
|
||||
)
|
||||
|
|
@ -205,10 +263,10 @@ class PromptVertex(Vertex):
|
|||
self.params["input_variables"] = list(
|
||||
set(self.params["input_variables"])
|
||||
)
|
||||
else:
|
||||
elif isinstance(self.params, dict):
|
||||
self.params.pop("input_variables", None)
|
||||
|
||||
self._build()
|
||||
self._build(user_id=user_id)
|
||||
return self._built_object
|
||||
|
||||
def _built_object_repr(self):
|
||||
|
|
@ -252,8 +310,13 @@ class OutputParserVertex(Vertex):
|
|||
|
||||
class CustomComponentVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="custom_components")
|
||||
super().__init__(data, base_type="custom_components", is_task=True)
|
||||
|
||||
def _built_object_repr(self):
|
||||
if self.task_id and self.is_task:
|
||||
if task := self.get_task():
|
||||
return str(task.info)
|
||||
else:
|
||||
return f"Task {self.task_id} is not running"
|
||||
if self.artifacts and "repr" in self.artifacts:
|
||||
return self.artifacts["repr"] or super()._built_object_repr()
|
||||
|
|
|
|||
5
src/backend/langflow/graph/vertex/utils.py
Normal file
5
src/backend/langflow/graph/vertex/utils.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
from langflow.utils.constants import PYTHON_BASIC_TYPES
|
||||
|
||||
|
||||
def is_basic_type(obj):
|
||||
return type(obj) in PYTHON_BASIC_TYPES
|
||||
|
|
@ -5,9 +5,10 @@ from langchain.agents import types
|
|||
from langflow.custom.customs import get_custom_nodes
|
||||
from langflow.interface.agents.custom import CUSTOM_AGENTS
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.settings import settings
|
||||
from langflow.services.getters import get_settings_service
|
||||
|
||||
from langflow.template.frontend_node.agents import AgentFrontendNode
|
||||
from langflow.utils.logger import logger
|
||||
from loguru import logger
|
||||
from langflow.utils.util import build_template_from_class, build_template_from_method
|
||||
|
||||
|
||||
|
|
@ -53,13 +54,17 @@ class AgentCreator(LangChainTypeCreator):
|
|||
# Now this is a generator
|
||||
def to_list(self) -> List[str]:
|
||||
names = []
|
||||
settings_service = get_settings_service()
|
||||
for _, agent in self.type_to_loader_dict.items():
|
||||
agent_name = (
|
||||
agent.function_name()
|
||||
if hasattr(agent, "function_name")
|
||||
else agent.__name__
|
||||
)
|
||||
if agent_name in settings.AGENTS or settings.DEV:
|
||||
if (
|
||||
agent_name in settings_service.settings.AGENTS
|
||||
or settings_service.settings.DEV
|
||||
):
|
||||
names.append(agent_name)
|
||||
return names
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Any, List, Optional
|
||||
|
||||
from langchain import LLMChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.agents import (
|
||||
AgentExecutor,
|
||||
Tool,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from langchain import LLMChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.agents import AgentExecutor, ZeroShotAgent
|
||||
from langchain.agents.agent_toolkits.json.prompt import JSON_PREFIX, JSON_SUFFIX
|
||||
from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit
|
||||
|
|
|
|||
|
|
@ -2,13 +2,14 @@ from abc import ABC, abstractmethod
|
|||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.agents import AgentExecutor
|
||||
from langflow.services.getters import get_settings_service
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langflow.template.field.base import TemplateField
|
||||
from langflow.template.frontend_node.base import FrontendNode
|
||||
from langflow.template.template.base import Template
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.settings import settings
|
||||
from loguru import logger
|
||||
|
||||
|
||||
# Assuming necessary imports for Field, Template, and FrontendNode classes
|
||||
|
||||
|
|
@ -26,9 +27,12 @@ class LangChainTypeCreator(BaseModel, ABC):
|
|||
@property
|
||||
def docs_map(self) -> Dict[str, str]:
|
||||
"""A dict with the name of the component as key and the documentation link as value."""
|
||||
settings_service = get_settings_service()
|
||||
if self.name_docs_dict is None:
|
||||
try:
|
||||
type_settings = getattr(settings, self.type_name.upper())
|
||||
type_settings = getattr(
|
||||
settings_service.settings, self.type_name.upper()
|
||||
)
|
||||
self.name_docs_dict = {
|
||||
name: value_dict["documentation"]
|
||||
for name, value_dict in type_settings.items()
|
||||
|
|
|
|||
|
|
@ -3,9 +3,10 @@ from typing import Any, Dict, List, Optional, Type
|
|||
from langflow.custom.customs import get_custom_nodes
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.importing.utils import import_class
|
||||
from langflow.settings import settings
|
||||
from langflow.services.getters import get_settings_service
|
||||
|
||||
from langflow.template.frontend_node.chains import ChainFrontendNode
|
||||
from langflow.utils.logger import logger
|
||||
from loguru import logger
|
||||
from langflow.utils.util import build_template_from_class, build_template_from_method
|
||||
from langchain import chains
|
||||
from langchain_experimental.sql import SQLDatabaseChain # type: ignore
|
||||
|
|
@ -30,6 +31,7 @@ class ChainCreator(LangChainTypeCreator):
|
|||
@property
|
||||
def type_to_loader_dict(self) -> Dict:
|
||||
if self.type_dict is None:
|
||||
settings_service = get_settings_service()
|
||||
self.type_dict: dict[str, Any] = {
|
||||
chain_name: import_class(f"langchain.chains.{chain_name}")
|
||||
for chain_name in chains.__all__
|
||||
|
|
@ -43,7 +45,8 @@ class ChainCreator(LangChainTypeCreator):
|
|||
self.type_dict = {
|
||||
name: chain
|
||||
for name, chain in self.type_dict.items()
|
||||
if name in settings.CHAINS or settings.DEV
|
||||
if name in settings_service.settings.CHAINS
|
||||
or settings_service.settings.DEV
|
||||
}
|
||||
return self.type_dict
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from langflow.interface.custom.custom_component import CustomComponent
|
|||
from langflow.template.frontend_node.custom_components import (
|
||||
CustomComponentFrontendNode,
|
||||
)
|
||||
from langflow.utils.logger import logger
|
||||
from loguru import logger
|
||||
|
||||
# Assuming necessary imports for Field, Template, and FrontendNode classes
|
||||
|
||||
|
|
|
|||
|
|
@ -1,65 +1,33 @@
|
|||
from langchain import PromptTemplate
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.text_splitter import TextSplitter
|
||||
from langchain.tools import Tool
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.schema import BaseOutputParser
|
||||
from langchain.schema.memory import BaseMemory
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
|
||||
LANGCHAIN_BASE_TYPES = {
|
||||
"Chain": Chain,
|
||||
"AgentExecutor": AgentExecutor,
|
||||
"Tool": Tool,
|
||||
"BaseLLM": BaseLLM,
|
||||
"PromptTemplate": PromptTemplate,
|
||||
"BaseLoader": BaseLoader,
|
||||
"Document": Document,
|
||||
"TextSplitter": TextSplitter,
|
||||
"VectorStore": VectorStore,
|
||||
"Embeddings": Embeddings,
|
||||
"BaseRetriever": BaseRetriever,
|
||||
"BaseOutputParser": BaseOutputParser,
|
||||
"BaseMemory": BaseMemory,
|
||||
"BaseChatMemory": BaseChatMemory,
|
||||
}
|
||||
|
||||
# Langchain base types plus Python base types
|
||||
CUSTOM_COMPONENT_SUPPORTED_TYPES = {
|
||||
**LANGCHAIN_BASE_TYPES,
|
||||
"str": str,
|
||||
"int": int,
|
||||
"float": float,
|
||||
"bool": bool,
|
||||
"list": list,
|
||||
"dict": dict,
|
||||
}
|
||||
|
||||
|
||||
DEFAULT_CUSTOM_COMPONENT_CODE = """from langflow import CustomComponent
|
||||
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.chains import LLMChain
|
||||
from langchain import PromptTemplate
|
||||
from langchain.schema import Document
|
||||
from langflow.field_typing import (
|
||||
Tool,
|
||||
PromptTemplate,
|
||||
Chain,
|
||||
BaseChatMemory,
|
||||
BaseLLM,
|
||||
BaseLoader,
|
||||
BaseMemory,
|
||||
BaseOutputParser,
|
||||
BaseRetriever,
|
||||
VectorStore,
|
||||
Embeddings,
|
||||
TextSplitter,
|
||||
Document,
|
||||
AgentExecutor,
|
||||
NestedDict,
|
||||
Data,
|
||||
)
|
||||
|
||||
import requests
|
||||
|
||||
class YourComponent(CustomComponent):
|
||||
class Component(CustomComponent):
|
||||
display_name: str = "Custom Component"
|
||||
description: str = "Create any custom component you want!"
|
||||
|
||||
def build_config(self):
|
||||
return { "url": { "multiline": True, "required": True } }
|
||||
return {"param": {"display_name": "Parameter"}}
|
||||
|
||||
def build(self, param: Data) -> Data:
|
||||
return param
|
||||
|
||||
def build(self, url: str, llm: BaseLLM, prompt: PromptTemplate) -> Document:
|
||||
response = requests.get(url)
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
result = chain.run(response.text[:300])
|
||||
return Document(page_content=str(result))
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,14 +1,16 @@
|
|||
from typing import Any, Callable, List, Optional
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
from uuid import UUID
|
||||
from fastapi import HTTPException
|
||||
from langflow.interface.custom.constants import CUSTOM_COMPONENT_SUPPORTED_TYPES
|
||||
from langflow.field_typing.constants import CUSTOM_COMPONENT_SUPPORTED_TYPES
|
||||
from langflow.interface.custom.component import Component
|
||||
from langflow.interface.custom.directory_reader import DirectoryReader
|
||||
from langflow.services.getters import get_db_service
|
||||
from langflow.interface.custom.utils import extract_inner_type
|
||||
|
||||
from langflow.utils import validate
|
||||
|
||||
from langflow.database.base import session_getter
|
||||
from langflow.database.models.flow import Flow
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.database.models.flow import Flow
|
||||
from pydantic import Extra
|
||||
import yaml
|
||||
|
||||
|
|
@ -21,6 +23,7 @@ class CustomComponent(Component, extra=Extra.allow):
|
|||
function: Optional[Callable] = None
|
||||
return_type_valid_list = list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())
|
||||
repr_value: Optional[Any] = ""
|
||||
user_id: Optional[Union[UUID, str]] = None
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
|
@ -92,7 +95,23 @@ class CustomComponent(Component, extra=Extra.allow):
|
|||
|
||||
build_method = build_methods[0]
|
||||
|
||||
return build_method["args"]
|
||||
args = build_method["args"]
|
||||
for arg in args:
|
||||
if arg.get("type") == "prompt":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Type hint Error",
|
||||
"traceback": (
|
||||
"Prompt type is not supported in the build method."
|
||||
" Try using PromptTemplate instead."
|
||||
),
|
||||
},
|
||||
)
|
||||
elif not arg.get("type"):
|
||||
# Set the type to Data
|
||||
arg["type"] = "Data"
|
||||
return args
|
||||
|
||||
@property
|
||||
def get_function_entrypoint_return_type(self) -> List[str]:
|
||||
|
|
@ -173,22 +192,29 @@ class CustomComponent(Component, extra=Extra.allow):
|
|||
return validate.create_function(self.code, self.function_entrypoint_name)
|
||||
|
||||
def load_flow(self, flow_id: str, tweaks: Optional[dict] = None) -> Any:
|
||||
from langflow.processing.process import build_sorted_vertices_with_caching
|
||||
from langflow.processing.process import build_sorted_vertices
|
||||
from langflow.processing.process import process_tweaks
|
||||
|
||||
with session_getter() as session:
|
||||
db_service = get_db_service()
|
||||
with session_getter(db_service) as session:
|
||||
graph_data = flow.data if (flow := session.get(Flow, flow_id)) else None
|
||||
if not graph_data:
|
||||
raise ValueError(f"Flow {flow_id} not found")
|
||||
if tweaks:
|
||||
graph_data = process_tweaks(graph_data=graph_data, tweaks=tweaks)
|
||||
return build_sorted_vertices_with_caching(graph_data)
|
||||
return build_sorted_vertices(graph_data)
|
||||
|
||||
def list_flows(self, *, get_session: Optional[Callable] = None) -> List[Flow]:
|
||||
get_session = get_session or session_getter
|
||||
with get_session() as session:
|
||||
flows = session.query(Flow).all()
|
||||
return flows
|
||||
if not self.user_id:
|
||||
raise ValueError("Session is invalid")
|
||||
try:
|
||||
get_session = get_session or session_getter
|
||||
db_service = get_db_service()
|
||||
with get_session(db_service) as session:
|
||||
flows = session.query(Flow).filter(Flow.user_id == self.user_id).all()
|
||||
return flows
|
||||
except Exception as e:
|
||||
raise ValueError("Session is invalid") from e
|
||||
|
||||
def get_flow(
|
||||
self,
|
||||
|
|
@ -199,12 +225,16 @@ class CustomComponent(Component, extra=Extra.allow):
|
|||
get_session: Optional[Callable] = None,
|
||||
) -> Flow:
|
||||
get_session = get_session or session_getter
|
||||
|
||||
with get_session() as session:
|
||||
db_service = get_db_service()
|
||||
with get_session(db_service) as session:
|
||||
if flow_id:
|
||||
flow = session.query(Flow).get(flow_id)
|
||||
elif flow_name:
|
||||
flow = session.query(Flow).filter(Flow.name == flow_name).first()
|
||||
flow = (
|
||||
session.query(Flow)
|
||||
.filter(Flow.name == flow_name)
|
||||
.filter(Flow.user_id == self.user_id)
|
||||
).first()
|
||||
else:
|
||||
raise ValueError("Either flow_name or flow_id must be provided")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import os
|
||||
import ast
|
||||
import zlib
|
||||
from langflow.utils.logger import logger
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class CustomComponentPathValueError(ValueError):
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
from typing import Dict, List, Optional, Type
|
||||
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.services.getters import get_settings_service
|
||||
from langflow.template.frontend_node.documentloaders import DocumentLoaderFrontNode
|
||||
from langflow.interface.custom_lists import documentloaders_type_to_cls_dict
|
||||
from langflow.settings import settings
|
||||
from langflow.utils.logger import logger
|
||||
|
||||
from loguru import logger
|
||||
from langflow.utils.util import build_template_from_class
|
||||
|
||||
|
||||
|
|
@ -30,10 +31,12 @@ class DocumentLoaderCreator(LangChainTypeCreator):
|
|||
return None
|
||||
|
||||
def to_list(self) -> List[str]:
|
||||
settings_service = get_settings_service()
|
||||
return [
|
||||
documentloader.__name__
|
||||
for documentloader in self.type_to_loader_dict.values()
|
||||
if documentloader.__name__ in settings.DOCUMENTLOADERS or settings.DEV
|
||||
if documentloader.__name__ in settings_service.settings.DOCUMENTLOADERS
|
||||
or settings_service.settings.DEV
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,10 +2,11 @@ from typing import Dict, List, Optional, Type
|
|||
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.custom_lists import embedding_type_to_cls_dict
|
||||
from langflow.settings import settings
|
||||
from langflow.services.getters import get_settings_service
|
||||
|
||||
from langflow.template.frontend_node.base import FrontendNode
|
||||
from langflow.template.frontend_node.embeddings import EmbeddingFrontendNode
|
||||
from langflow.utils.logger import logger
|
||||
from loguru import logger
|
||||
from langflow.utils.util import build_template_from_class
|
||||
|
||||
|
||||
|
|
@ -32,10 +33,12 @@ class EmbeddingCreator(LangChainTypeCreator):
|
|||
return None
|
||||
|
||||
def to_list(self) -> List[str]:
|
||||
settings_service = get_settings_service()
|
||||
return [
|
||||
embedding.__name__
|
||||
for embedding in self.type_to_loader_dict.values()
|
||||
if embedding.__name__ in settings.EMBEDDINGS or settings.DEV
|
||||
if embedding.__name__ in settings_service.settings.EMBEDDINGS
|
||||
or settings_service.settings.DEV
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
import importlib
|
||||
from typing import Any, Type
|
||||
|
||||
from langchain import PromptTemplate
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.agents import Agent
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.base import Chain
|
||||
|
|
@ -144,6 +144,8 @@ def import_chain(chain: str) -> Type[Chain]:
|
|||
|
||||
if chain in CUSTOM_CHAINS:
|
||||
return CUSTOM_CHAINS[chain]
|
||||
if chain == "SQLDatabaseChain":
|
||||
return import_class("langchain_experimental.sql.SQLDatabaseChain")
|
||||
return import_class(f"langchain.chains.{chain}")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
import orjson
|
||||
from typing import Any, Callable, Dict, Sequence, Type
|
||||
|
||||
from typing import Any, Callable, Dict, Sequence, Type, TYPE_CHECKING
|
||||
from langchain.schema import Document
|
||||
from langchain.agents import agent as agent_module
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.agents.agent_toolkits.base import BaseToolkit
|
||||
|
|
@ -34,13 +34,29 @@ from langflow.utils import validate
|
|||
from langchain.chains.base import Chain
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langflow.utils.logger import logger
|
||||
from loguru import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow import CustomComponent
|
||||
|
||||
|
||||
def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any:
|
||||
def build_vertex_in_params(params: Dict) -> Dict:
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
|
||||
# If any of the values in params is a Vertex, we will build it
|
||||
return {
|
||||
key: value.build() if isinstance(value, Vertex) else value
|
||||
for key, value in params.items()
|
||||
}
|
||||
|
||||
|
||||
def instantiate_class(
|
||||
node_type: str, base_type: str, params: Dict, user_id=None
|
||||
) -> Any:
|
||||
"""Instantiate class from module type and key, and params"""
|
||||
params = convert_params_to_sets(params)
|
||||
params = convert_kwargs(params)
|
||||
|
||||
if node_type in CUSTOM_NODES:
|
||||
if custom_node := CUSTOM_NODES.get(node_type):
|
||||
if hasattr(custom_node, "initialize"):
|
||||
|
|
@ -48,7 +64,9 @@ def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any:
|
|||
return custom_node(**params)
|
||||
logger.debug(f"Instantiating {node_type} of type {base_type}")
|
||||
class_object = import_by_type(_type=base_type, name=node_type)
|
||||
return instantiate_based_on_type(class_object, base_type, node_type, params)
|
||||
return instantiate_based_on_type(
|
||||
class_object, base_type, node_type, params, user_id=user_id
|
||||
)
|
||||
|
||||
|
||||
def convert_params_to_sets(params):
|
||||
|
|
@ -75,7 +93,7 @@ def convert_kwargs(params):
|
|||
return params
|
||||
|
||||
|
||||
def instantiate_based_on_type(class_object, base_type, node_type, params):
|
||||
def instantiate_based_on_type(class_object, base_type, node_type, params, user_id):
|
||||
if base_type == "agents":
|
||||
return instantiate_agent(node_type, class_object, params)
|
||||
elif base_type == "prompts":
|
||||
|
|
@ -89,11 +107,11 @@ def instantiate_based_on_type(class_object, base_type, node_type, params):
|
|||
elif base_type == "toolkits":
|
||||
return instantiate_toolkit(node_type, class_object, params)
|
||||
elif base_type == "embeddings":
|
||||
return instantiate_embedding(class_object, params)
|
||||
return instantiate_embedding(node_type, class_object, params)
|
||||
elif base_type == "vectorstores":
|
||||
return instantiate_vectorstore(class_object, params)
|
||||
elif base_type == "documentloaders":
|
||||
return instantiate_documentloader(class_object, params)
|
||||
return instantiate_documentloader(node_type, class_object, params)
|
||||
elif base_type == "textsplitters":
|
||||
return instantiate_textsplitter(class_object, params)
|
||||
elif base_type == "utilities":
|
||||
|
|
@ -109,19 +127,19 @@ def instantiate_based_on_type(class_object, base_type, node_type, params):
|
|||
elif base_type == "memory":
|
||||
return instantiate_memory(node_type, class_object, params)
|
||||
elif base_type == "custom_components":
|
||||
return instantiate_custom_component(node_type, class_object, params)
|
||||
return instantiate_custom_component(node_type, class_object, params, user_id)
|
||||
elif base_type == "wrappers":
|
||||
return instantiate_wrapper(node_type, class_object, params)
|
||||
else:
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
def instantiate_custom_component(node_type, class_object, params):
|
||||
def instantiate_custom_component(node_type, class_object, params, user_id):
|
||||
# we need to make a copy of the params because we will be
|
||||
# modifying it
|
||||
params_copy = params.copy()
|
||||
class_object = get_function_custom(params_copy.pop("code"))
|
||||
custom_component = class_object()
|
||||
class_object: "CustomComponent" = get_function_custom(params_copy.pop("code"))
|
||||
custom_component = class_object(user_id=user_id)
|
||||
built_object = custom_component.build(**params_copy)
|
||||
return built_object, {"repr": custom_component.custom_repr()}
|
||||
|
||||
|
|
@ -148,7 +166,7 @@ def instantiate_llm(node_type, class_object, params: Dict):
|
|||
# This is a workaround so JinaChat works until streaming is implemented
|
||||
# if "openai_api_base" in params and "jina" in params["openai_api_base"]:
|
||||
# False if condition is True
|
||||
if node_type == "VertexAI":
|
||||
if "VertexAI" in node_type:
|
||||
return initialize_vertexai(class_object=class_object, params=params)
|
||||
# max_tokens sometimes is a string and should be an int
|
||||
if "max_tokens" in params:
|
||||
|
|
@ -262,9 +280,13 @@ def instantiate_toolkit(node_type, class_object: Type[BaseToolkit], params: Dict
|
|||
return loaded_toolkit
|
||||
|
||||
|
||||
def instantiate_embedding(class_object, params: Dict):
|
||||
def instantiate_embedding(node_type, class_object, params: Dict):
|
||||
params.pop("model", None)
|
||||
params.pop("headers", None)
|
||||
|
||||
if "VertexAI" in node_type:
|
||||
return initialize_vertexai(class_object=class_object, params=params)
|
||||
|
||||
try:
|
||||
return class_object(**params)
|
||||
except ValidationError:
|
||||
|
|
@ -278,6 +300,15 @@ def instantiate_embedding(class_object, params: Dict):
|
|||
|
||||
def instantiate_vectorstore(class_object: Type[VectorStore], params: Dict):
|
||||
search_kwargs = params.pop("search_kwargs", {})
|
||||
if search_kwargs == {"yourkey": "value"}:
|
||||
search_kwargs = {}
|
||||
# clean up docs or texts to have only documents
|
||||
if "texts" in params:
|
||||
params["documents"] = params.pop("texts")
|
||||
if "documents" in params:
|
||||
params["documents"] = [
|
||||
doc for doc in params["documents"] if isinstance(doc, Document)
|
||||
]
|
||||
if initializer := vecstore_initializer.get(class_object.__name__):
|
||||
vecstore = initializer(class_object, params)
|
||||
else:
|
||||
|
|
@ -292,7 +323,9 @@ def instantiate_vectorstore(class_object: Type[VectorStore], params: Dict):
|
|||
return vecstore
|
||||
|
||||
|
||||
def instantiate_documentloader(class_object: Type[BaseLoader], params: Dict):
|
||||
def instantiate_documentloader(
|
||||
node_type: str, class_object: Type[BaseLoader], params: Dict
|
||||
):
|
||||
if "file_filter" in params:
|
||||
# file_filter will be a string but we need a function
|
||||
# that will be used to filter the files using file_filter
|
||||
|
|
@ -312,6 +345,11 @@ def instantiate_documentloader(class_object: Type[BaseLoader], params: Dict):
|
|||
raise ValueError(
|
||||
"The metadata you provided is not a valid JSON string."
|
||||
) from exc
|
||||
|
||||
if node_type == "WebBaseLoader":
|
||||
if web_path := params.pop("web_path", None):
|
||||
params["web_paths"] = [web_path]
|
||||
|
||||
docs = class_object(**params).load()
|
||||
# Now if metadata is an empty dict, we will not add it to the documents
|
||||
if metadata:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import contextlib
|
||||
import json
|
||||
from langflow.database.models.base import orjson_dumps
|
||||
from langflow.services.database.models.base import orjson_dumps
|
||||
import orjson
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from langchain.vectorstores import (
|
|||
SupabaseVectorStore,
|
||||
MongoDBAtlasVectorSearch,
|
||||
)
|
||||
|
||||
from langchain.schema import Document
|
||||
import os
|
||||
|
||||
import orjson
|
||||
|
|
@ -201,11 +201,16 @@ def initialize_chroma(class_object: Type[Chroma], params: dict):
|
|||
if "texts" in params:
|
||||
params["documents"] = params.pop("texts")
|
||||
for doc in params["documents"]:
|
||||
if not isinstance(doc, Document):
|
||||
# remove any non-Document objects from the list
|
||||
params["documents"].remove(doc)
|
||||
continue
|
||||
if doc.metadata is None:
|
||||
doc.metadata = {}
|
||||
for key, value in doc.metadata.items():
|
||||
if value is None:
|
||||
doc.metadata[key] = ""
|
||||
|
||||
chromadb = class_object.from_documents(**params)
|
||||
if persist:
|
||||
chromadb.persist()
|
||||
|
|
|
|||
|
|
@ -1,47 +1,27 @@
|
|||
from langflow.interface.agents.base import agent_creator
|
||||
from langflow.interface.chains.base import chain_creator
|
||||
from langflow.interface.document_loaders.base import documentloader_creator
|
||||
from langflow.interface.embeddings.base import embedding_creator
|
||||
from langflow.interface.llms.base import llm_creator
|
||||
from langflow.interface.memories.base import memory_creator
|
||||
from langflow.interface.prompts.base import prompt_creator
|
||||
from langflow.interface.text_splitters.base import textsplitter_creator
|
||||
from langflow.interface.toolkits.base import toolkits_creator
|
||||
from langflow.interface.tools.base import tool_creator
|
||||
from langflow.interface.utilities.base import utility_creator
|
||||
from langflow.interface.vector_store.base import vectorstore_creator
|
||||
from langflow.interface.wrappers.base import wrapper_creator
|
||||
from langflow.interface.output_parsers.base import output_parser_creator
|
||||
from langflow.interface.retrievers.base import retriever_creator
|
||||
from langflow.interface.custom.base import custom_component_creator
|
||||
from langflow.services.getters import get_settings_service
|
||||
from langflow.utils.lazy_load import LazyLoadDictBase
|
||||
|
||||
|
||||
def get_type_dict():
|
||||
return {
|
||||
"agents": agent_creator.to_list(),
|
||||
"prompts": prompt_creator.to_list(),
|
||||
"llms": llm_creator.to_list(),
|
||||
"tools": tool_creator.to_list(),
|
||||
"chains": chain_creator.to_list(),
|
||||
"memory": memory_creator.to_list(),
|
||||
"toolkits": toolkits_creator.to_list(),
|
||||
"wrappers": wrapper_creator.to_list(),
|
||||
"documentLoaders": documentloader_creator.to_list(),
|
||||
"vectorStore": vectorstore_creator.to_list(),
|
||||
"embeddings": embedding_creator.to_list(),
|
||||
"textSplitters": textsplitter_creator.to_list(),
|
||||
"utilities": utility_creator.to_list(),
|
||||
"outputParsers": output_parser_creator.to_list(),
|
||||
"retrievers": retriever_creator.to_list(),
|
||||
"custom_components": custom_component_creator.to_list(),
|
||||
}
|
||||
class AllTypesDict(LazyLoadDictBase):
|
||||
def __init__(self):
|
||||
self._all_types_dict = None
|
||||
|
||||
@property
|
||||
def ALL_TYPES_DICT(self):
|
||||
return self.all_types_dict
|
||||
|
||||
def _build_dict(self):
|
||||
langchain_types_dict = self.get_type_dict()
|
||||
return {
|
||||
**langchain_types_dict,
|
||||
"Custom": ["Custom Tool", "Python Function"],
|
||||
}
|
||||
|
||||
def get_type_dict(self):
|
||||
from langflow.interface.types import get_all_types_dict
|
||||
|
||||
settings_service = get_settings_service()
|
||||
return get_all_types_dict(settings_service=settings_service)
|
||||
|
||||
|
||||
LANGCHAIN_TYPES_DICT = get_type_dict()
|
||||
|
||||
# Now we'll build a dict with Langchain types and ours
|
||||
|
||||
ALL_TYPES_DICT = {
|
||||
**LANGCHAIN_TYPES_DICT,
|
||||
"Custom": ["Custom Tool", "Python Function"],
|
||||
}
|
||||
lazy_load_dict = AllTypesDict()
|
||||
|
|
|
|||
|
|
@ -2,9 +2,10 @@ from typing import Dict, List, Optional, Type
|
|||
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.custom_lists import llm_type_to_cls_dict
|
||||
from langflow.settings import settings
|
||||
from langflow.services.getters import get_settings_service
|
||||
|
||||
from langflow.template.frontend_node.llms import LLMFrontendNode
|
||||
from langflow.utils.logger import logger
|
||||
from loguru import logger
|
||||
from langflow.utils.util import build_template_from_class
|
||||
|
||||
|
||||
|
|
@ -33,10 +34,12 @@ class LLMCreator(LangChainTypeCreator):
|
|||
return None
|
||||
|
||||
def to_list(self) -> List[str]:
|
||||
settings_service = get_settings_service()
|
||||
return [
|
||||
llm.__name__
|
||||
for llm in self.type_to_loader_dict.values()
|
||||
if llm.__name__ in settings.LLMS or settings.DEV
|
||||
if llm.__name__ in settings_service.settings.LLMS
|
||||
or settings_service.settings.DEV
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,10 +2,11 @@ from typing import Dict, List, Optional, Type
|
|||
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.custom_lists import memory_type_to_cls_dict
|
||||
from langflow.settings import settings
|
||||
from langflow.services.getters import get_settings_service
|
||||
|
||||
from langflow.template.frontend_node.base import FrontendNode
|
||||
from langflow.template.frontend_node.memories import MemoryFrontendNode
|
||||
from langflow.utils.logger import logger
|
||||
from loguru import logger
|
||||
from langflow.utils.util import build_template_from_class, build_template_from_method
|
||||
from langflow.custom.customs import get_custom_nodes
|
||||
|
||||
|
|
@ -48,10 +49,12 @@ class MemoryCreator(LangChainTypeCreator):
|
|||
return None
|
||||
|
||||
def to_list(self) -> List[str]:
|
||||
settings_service = get_settings_service()
|
||||
return [
|
||||
memory.__name__
|
||||
for memory in self.type_to_loader_dict.values()
|
||||
if memory.__name__ in settings.MEMORIES or settings.DEV
|
||||
if memory.__name__ in settings_service.settings.MEMORIES
|
||||
or settings_service.settings.DEV
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,9 +4,10 @@ from langchain import output_parsers
|
|||
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.importing.utils import import_class
|
||||
from langflow.settings import settings
|
||||
from langflow.services.getters import get_settings_service
|
||||
|
||||
from langflow.template.frontend_node.output_parsers import OutputParserFrontendNode
|
||||
from langflow.utils.logger import logger
|
||||
from loguru import logger
|
||||
from langflow.utils.util import build_template_from_class, build_template_from_method
|
||||
|
||||
|
||||
|
|
@ -23,6 +24,7 @@ class OutputParserCreator(LangChainTypeCreator):
|
|||
@property
|
||||
def type_to_loader_dict(self) -> Dict:
|
||||
if self.type_dict is None:
|
||||
settings_service = get_settings_service()
|
||||
self.type_dict = {
|
||||
output_parser_name: import_class(
|
||||
f"langchain.output_parsers.{output_parser_name}"
|
||||
|
|
@ -33,7 +35,8 @@ class OutputParserCreator(LangChainTypeCreator):
|
|||
self.type_dict = {
|
||||
name: output_parser
|
||||
for name, output_parser in self.type_dict.items()
|
||||
if name in settings.OUTPUT_PARSERS or settings.DEV
|
||||
if name in settings_service.settings.OUTPUT_PARSERS
|
||||
or settings_service.settings.DEV
|
||||
}
|
||||
return self.type_dict
|
||||
|
||||
|
|
|
|||
|
|
@ -5,9 +5,10 @@ from langchain import prompts
|
|||
from langflow.custom.customs import get_custom_nodes
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.importing.utils import import_class
|
||||
from langflow.settings import settings
|
||||
from langflow.services.getters import get_settings_service
|
||||
|
||||
from langflow.template.frontend_node.prompts import PromptFrontendNode
|
||||
from langflow.utils.logger import logger
|
||||
from loguru import logger
|
||||
from langflow.utils.util import build_template_from_class
|
||||
|
||||
|
||||
|
|
@ -20,6 +21,7 @@ class PromptCreator(LangChainTypeCreator):
|
|||
|
||||
@property
|
||||
def type_to_loader_dict(self) -> Dict:
|
||||
settings_service = get_settings_service()
|
||||
if self.type_dict is None:
|
||||
self.type_dict = {
|
||||
prompt_name: import_class(f"langchain.prompts.{prompt_name}")
|
||||
|
|
@ -34,7 +36,8 @@ class PromptCreator(LangChainTypeCreator):
|
|||
self.type_dict = {
|
||||
name: prompt
|
||||
for name, prompt in self.type_dict.items()
|
||||
if name in settings.PROMPTS or settings.DEV
|
||||
if name in settings_service.settings.PROMPTS
|
||||
or settings_service.settings.DEV
|
||||
}
|
||||
return self.type_dict
|
||||
|
||||
|
|
|
|||
|
|
@ -4,9 +4,10 @@ from langchain import retrievers
|
|||
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.importing.utils import import_class
|
||||
from langflow.settings import settings
|
||||
from langflow.services.getters import get_settings_service
|
||||
|
||||
from langflow.template.frontend_node.retrievers import RetrieverFrontendNode
|
||||
from langflow.utils.logger import logger
|
||||
from loguru import logger
|
||||
from langflow.utils.util import build_template_from_method, build_template_from_class
|
||||
|
||||
|
||||
|
|
@ -48,10 +49,12 @@ class RetrieverCreator(LangChainTypeCreator):
|
|||
return None
|
||||
|
||||
def to_list(self) -> List[str]:
|
||||
settings_service = get_settings_service()
|
||||
return [
|
||||
retriever
|
||||
for retriever in self.type_to_loader_dict.keys()
|
||||
if retriever in settings.RETRIEVERS or settings.DEV
|
||||
if retriever in settings_service.settings.RETRIEVERS
|
||||
or settings_service.settings.DEV
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,21 +1,9 @@
|
|||
from langflow.cache.utils import memoize_dict
|
||||
from typing import Dict, Tuple
|
||||
from langflow.graph import Graph
|
||||
from langflow.utils.logger import logger
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@memoize_dict(maxsize=10)
|
||||
def build_langchain_object_with_caching(data_graph):
|
||||
"""
|
||||
Build langchain object from data_graph.
|
||||
"""
|
||||
|
||||
logger.debug("Building langchain object")
|
||||
graph = Graph.from_payload(data_graph)
|
||||
return graph.build()
|
||||
|
||||
|
||||
@memoize_dict(maxsize=10)
|
||||
def build_sorted_vertices_with_caching(data_graph):
|
||||
def build_sorted_vertices(data_graph) -> Tuple[Graph, Dict]:
|
||||
"""
|
||||
Build langchain object from data_graph.
|
||||
"""
|
||||
|
|
@ -28,7 +16,7 @@ def build_sorted_vertices_with_caching(data_graph):
|
|||
vertex.build()
|
||||
if vertex.artifacts:
|
||||
artifacts.update(vertex.artifacts)
|
||||
return graph.build(), artifacts
|
||||
return graph, artifacts
|
||||
|
||||
|
||||
def build_langchain_object(data_graph):
|
||||
|
|
@ -57,8 +45,12 @@ def get_memory_key(langchain_object):
|
|||
"chat_history": "history",
|
||||
"history": "chat_history",
|
||||
}
|
||||
memory_key = langchain_object.memory.memory_key
|
||||
return mem_key_dict.get(memory_key)
|
||||
# Check if memory_key attribute exists
|
||||
if hasattr(langchain_object.memory, "memory_key"):
|
||||
memory_key = langchain_object.memory.memory_key
|
||||
return mem_key_dict.get(memory_key)
|
||||
else:
|
||||
return None # or some other default value or action
|
||||
|
||||
|
||||
def update_memory_keys(langchain_object, possible_new_mem_key):
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
from typing import Dict, List, Optional, Type
|
||||
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.services.getters import get_settings_service
|
||||
from langflow.template.frontend_node.textsplitters import TextSplittersFrontendNode
|
||||
from langflow.interface.custom_lists import textsplitter_type_to_cls_dict
|
||||
from langflow.settings import settings
|
||||
from langflow.utils.logger import logger
|
||||
|
||||
from loguru import logger
|
||||
from langflow.utils.util import build_template_from_class
|
||||
|
||||
|
||||
|
|
@ -30,10 +31,12 @@ class TextSplitterCreator(LangChainTypeCreator):
|
|||
return None
|
||||
|
||||
def to_list(self) -> List[str]:
|
||||
settings_service = get_settings_service()
|
||||
return [
|
||||
textsplitter.__name__
|
||||
for textsplitter in self.type_to_loader_dict.values()
|
||||
if textsplitter.__name__ in settings.TEXTSPLITTERS or settings.DEV
|
||||
if textsplitter.__name__ in settings_service.settings.TEXTSPLITTERS
|
||||
or settings_service.settings.DEV
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,8 +4,9 @@ from langchain.agents import agent_toolkits
|
|||
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.importing.utils import import_class, import_module
|
||||
from langflow.settings import settings
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.services.getters import get_settings_service
|
||||
|
||||
from loguru import logger
|
||||
from langflow.utils.util import build_template_from_class
|
||||
|
||||
|
||||
|
|
@ -29,13 +30,15 @@ class ToolkitCreator(LangChainTypeCreator):
|
|||
@property
|
||||
def type_to_loader_dict(self) -> Dict:
|
||||
if self.type_dict is None:
|
||||
settings_service = get_settings_service()
|
||||
self.type_dict = {
|
||||
toolkit_name: import_class(
|
||||
f"langchain.agents.agent_toolkits.{toolkit_name}"
|
||||
)
|
||||
# if toolkit_name is not lower case it is a class
|
||||
for toolkit_name in agent_toolkits.__all__
|
||||
if not toolkit_name.islower() and toolkit_name in settings.TOOLKITS
|
||||
if not toolkit_name.islower()
|
||||
and toolkit_name in settings_service.settings.TOOLKITS
|
||||
}
|
||||
|
||||
return self.type_dict
|
||||
|
|
|
|||
|
|
@ -15,7 +15,8 @@ from langflow.interface.tools.constants import (
|
|||
OTHER_TOOLS,
|
||||
)
|
||||
from langflow.interface.tools.util import get_tool_params
|
||||
from langflow.settings import settings
|
||||
from langflow.services.getters import get_settings_service
|
||||
|
||||
from langflow.template.field.base import TemplateField
|
||||
from langflow.template.template.base import Template
|
||||
from langflow.utils import util
|
||||
|
|
@ -66,6 +67,7 @@ class ToolCreator(LangChainTypeCreator):
|
|||
|
||||
@property
|
||||
def type_to_loader_dict(self) -> Dict:
|
||||
settings_service = get_settings_service()
|
||||
if self.tools_dict is None:
|
||||
all_tools = {}
|
||||
|
||||
|
|
@ -74,7 +76,10 @@ class ToolCreator(LangChainTypeCreator):
|
|||
|
||||
tool_name = tool_params.get("name") or tool
|
||||
|
||||
if tool_name in settings.TOOLS or settings.DEV:
|
||||
if (
|
||||
tool_name in settings_service.settings.TOOLS
|
||||
or settings_service.settings.DEV
|
||||
):
|
||||
if tool_name == "JsonSpec":
|
||||
tool_params["path"] = tool_params.pop("dict_") # type: ignore
|
||||
all_tools[tool_name] = {
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
import ast
|
||||
import contextlib
|
||||
from typing import Any, List
|
||||
from langflow.api.utils import merge_nested_dicts_with_renaming
|
||||
from langflow.api.utils import get_new_key
|
||||
from langflow.interface.agents.base import agent_creator
|
||||
from langflow.interface.chains.base import chain_creator
|
||||
from langflow.interface.custom.constants import CUSTOM_COMPONENT_SUPPORTED_TYPES
|
||||
from langflow.field_typing.constants import CUSTOM_COMPONENT_SUPPORTED_TYPES
|
||||
from langflow.interface.custom.utils import extract_inner_type
|
||||
from langflow.interface.document_loaders.base import documentloader_creator
|
||||
from langflow.interface.embeddings.base import embedding_creator
|
||||
|
|
@ -30,7 +30,7 @@ from langflow.template.frontend_node.custom_components import (
|
|||
from langflow.interface.retrievers.base import retriever_creator
|
||||
|
||||
from langflow.interface.custom.directory_reader import DirectoryReader
|
||||
from langflow.utils.logger import logger
|
||||
from loguru import logger
|
||||
from langflow.utils.util import get_base_classes
|
||||
|
||||
import re
|
||||
|
|
@ -288,6 +288,24 @@ def add_base_classes(frontend_node, return_types: List[str]):
|
|||
frontend_node.get("base_classes").append(base_class)
|
||||
|
||||
|
||||
def add_output_types(frontend_node, return_types: List[str]):
|
||||
"""Add output types to the frontend node"""
|
||||
for return_type in return_types:
|
||||
if return_type not in CUSTOM_COMPONENT_SUPPORTED_TYPES or return_type is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": (
|
||||
"Invalid return type should be one of: "
|
||||
f"{list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())}"
|
||||
),
|
||||
"traceback": traceback.format_exc(),
|
||||
},
|
||||
)
|
||||
|
||||
frontend_node.get("output_types").append(return_type)
|
||||
|
||||
|
||||
def build_langchain_template_custom_component(custom_component: CustomComponent):
|
||||
"""Build a custom component template for the langchain"""
|
||||
try:
|
||||
|
|
@ -303,9 +321,9 @@ def build_langchain_template_custom_component(custom_component: CustomComponent)
|
|||
logger.debug("Updated attributes")
|
||||
field_config = build_field_config(custom_component)
|
||||
logger.debug("Built field config")
|
||||
add_extra_fields(
|
||||
frontend_node, field_config, custom_component.get_function_entrypoint_args
|
||||
)
|
||||
entrypoint_args = custom_component.get_function_entrypoint_args
|
||||
|
||||
add_extra_fields(frontend_node, field_config, entrypoint_args)
|
||||
logger.debug("Added extra fields")
|
||||
frontend_node = add_code_field(
|
||||
frontend_node, custom_component.code, field_config.get("code", {})
|
||||
|
|
@ -314,9 +332,14 @@ def build_langchain_template_custom_component(custom_component: CustomComponent)
|
|||
add_base_classes(
|
||||
frontend_node, custom_component.get_function_entrypoint_return_type
|
||||
)
|
||||
add_output_types(
|
||||
frontend_node, custom_component.get_function_entrypoint_return_type
|
||||
)
|
||||
logger.debug("Added base classes")
|
||||
return frontend_node
|
||||
except Exception as exc:
|
||||
if isinstance(exc, HTTPException):
|
||||
raise exc
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
|
|
@ -433,6 +456,24 @@ def build_invalid_menu(invalid_components):
|
|||
return invalid_menu
|
||||
|
||||
|
||||
def merge_nested_dicts_with_renaming(dict1, dict2):
|
||||
for key, value in dict2.items():
|
||||
if (
|
||||
key in dict1
|
||||
and isinstance(value, dict)
|
||||
and isinstance(dict1.get(key), dict)
|
||||
):
|
||||
for sub_key, sub_value in value.items():
|
||||
if sub_key in dict1[key]:
|
||||
new_key = get_new_key(dict1[key], sub_key)
|
||||
dict1[key][new_key] = sub_value
|
||||
else:
|
||||
dict1[key][sub_key] = sub_value
|
||||
else:
|
||||
dict1[key] = value
|
||||
return dict1
|
||||
|
||||
|
||||
def build_langchain_custom_component_list_from_path(path: str):
|
||||
"""Build a list of custom components for the langchain from a given path"""
|
||||
file_list = load_files_from_path(path)
|
||||
|
|
@ -446,3 +487,51 @@ def build_langchain_custom_component_list_from_path(path: str):
|
|||
invalid_menu = build_invalid_menu(invalid_components)
|
||||
|
||||
return merge_nested_dicts_with_renaming(valid_menu, invalid_menu)
|
||||
|
||||
|
||||
def get_all_types_dict(settings_service):
|
||||
native_components = build_langchain_types_dict()
|
||||
# custom_components is a list of dicts
|
||||
# need to merge all the keys into one dict
|
||||
custom_components_from_file: dict[str, Any] = {}
|
||||
if settings_service.settings.COMPONENTS_PATH:
|
||||
logger.info(
|
||||
f"Building custom components from {settings_service.settings.COMPONENTS_PATH}"
|
||||
)
|
||||
|
||||
custom_component_dicts = []
|
||||
processed_paths = []
|
||||
for path in settings_service.settings.COMPONENTS_PATH:
|
||||
if str(path) in processed_paths:
|
||||
continue
|
||||
custom_component_dict = build_langchain_custom_component_list_from_path(
|
||||
str(path)
|
||||
)
|
||||
custom_component_dicts.append(custom_component_dict)
|
||||
processed_paths.append(str(path))
|
||||
|
||||
logger.info(f"Loading {len(custom_component_dicts)} category(ies)")
|
||||
for custom_component_dict in custom_component_dicts:
|
||||
# custom_component_dict is a dict of dicts
|
||||
if not custom_component_dict:
|
||||
continue
|
||||
category = list(custom_component_dict.keys())[0]
|
||||
logger.info(
|
||||
f"Loading {len(custom_component_dict[category])} component(s) from category {category}"
|
||||
)
|
||||
custom_components_from_file = merge_nested_dicts_with_renaming(
|
||||
custom_components_from_file, custom_component_dict
|
||||
)
|
||||
|
||||
return merge_nested_dicts_with_renaming(
|
||||
native_components, custom_components_from_file
|
||||
)
|
||||
|
||||
|
||||
def merge_nested_dicts(dict1, dict2):
|
||||
for key, value in dict2.items():
|
||||
if isinstance(value, dict) and isinstance(dict1.get(key), dict):
|
||||
dict1[key] = merge_nested_dicts(dict1[key], value)
|
||||
else:
|
||||
dict1[key] = value
|
||||
return dict1
|
||||
|
|
|
|||
|
|
@ -1,13 +1,14 @@
|
|||
from typing import Dict, List, Optional, Type
|
||||
|
||||
from langchain import SQLDatabase, utilities
|
||||
from langchain import utilities
|
||||
|
||||
from langflow.custom.customs import get_custom_nodes
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.importing.utils import import_class
|
||||
from langflow.settings import settings
|
||||
from langflow.services.getters import get_settings_service
|
||||
|
||||
from langflow.template.frontend_node.utilities import UtilitiesFrontendNode
|
||||
from langflow.utils.logger import logger
|
||||
from loguru import logger
|
||||
from langflow.utils.util import build_template_from_class
|
||||
|
||||
|
||||
|
|
@ -26,16 +27,18 @@ class UtilityCreator(LangChainTypeCreator):
|
|||
from the langchain.chains module and filtering them according to the settings.utilities list.
|
||||
"""
|
||||
if self.type_dict is None:
|
||||
settings_service = get_settings_service()
|
||||
self.type_dict = {
|
||||
utility_name: import_class(f"langchain.utilities.{utility_name}")
|
||||
for utility_name in utilities.__all__
|
||||
}
|
||||
self.type_dict["SQLDatabase"] = SQLDatabase
|
||||
self.type_dict["SQLDatabase"] = utilities.SQLDatabase
|
||||
# Filter according to settings.utilities
|
||||
self.type_dict = {
|
||||
name: utility
|
||||
for name, utility in self.type_dict.items()
|
||||
if name in settings.UTILITIES or settings.DEV
|
||||
if name in settings_service.settings.UTILITIES
|
||||
or settings_service.settings.DEV
|
||||
}
|
||||
|
||||
return self.type_dict
|
||||
|
|
|
|||
|
|
@ -8,8 +8,9 @@ import re
|
|||
import yaml
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from PIL.Image import Image
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.chat.config import ChatConfig
|
||||
from loguru import logger
|
||||
from langflow.services.chat.config import ChatConfig
|
||||
from langflow.services.getters import get_settings_service
|
||||
|
||||
|
||||
def load_file_into_dict(file_path: str) -> dict:
|
||||
|
|
@ -35,7 +36,7 @@ def pil_to_base64(image: Image) -> str:
|
|||
return img_str.decode("utf-8")
|
||||
|
||||
|
||||
def try_setting_streaming_options(langchain_object, websocket):
|
||||
def try_setting_streaming_options(langchain_object):
|
||||
# If the LLM type is OpenAI or ChatOpenAI,
|
||||
# set streaming to True
|
||||
# First we need to find the LLM
|
||||
|
|
@ -63,13 +64,11 @@ def extract_input_variables_from_prompt(prompt: str) -> list[str]:
|
|||
|
||||
def setup_llm_caching():
|
||||
"""Setup LLM caching."""
|
||||
|
||||
from langflow.settings import settings
|
||||
|
||||
settings_service = get_settings_service()
|
||||
try:
|
||||
set_langchain_cache(settings)
|
||||
set_langchain_cache(settings_service.settings)
|
||||
except ImportError:
|
||||
logger.warning(f"Could not import {settings.CACHE}. ")
|
||||
logger.warning(f"Could not import {settings_service.settings.CACHE_TYPE}. ")
|
||||
except Exception as exc:
|
||||
logger.warning(f"Could not setup LLM caching. Error: {exc}")
|
||||
|
||||
|
|
@ -81,7 +80,7 @@ def set_langchain_cache(settings):
|
|||
if cache_type := os.getenv("LANGFLOW_LANGCHAIN_CACHE"):
|
||||
try:
|
||||
cache_class = import_class(
|
||||
f"langchain.cache.{cache_type or settings.CACHE}"
|
||||
f"langchain.cache.{cache_type or settings.LANGCHAIN_CACHE}"
|
||||
)
|
||||
|
||||
logger.debug(f"Setting up LLM caching with {cache_class.__name__}")
|
||||
|
|
|
|||
|
|
@ -4,9 +4,10 @@ from langchain import vectorstores
|
|||
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.importing.utils import import_class
|
||||
from langflow.settings import settings
|
||||
from langflow.services.getters import get_settings_service
|
||||
|
||||
from langflow.template.frontend_node.vectorstores import VectorStoreFrontendNode
|
||||
from langflow.utils.logger import logger
|
||||
from loguru import logger
|
||||
from langflow.utils.util import build_template_from_method
|
||||
|
||||
|
||||
|
|
@ -43,10 +44,12 @@ class VectorstoreCreator(LangChainTypeCreator):
|
|||
return None
|
||||
|
||||
def to_list(self) -> List[str]:
|
||||
settings_service = get_settings_service()
|
||||
return [
|
||||
vectorstore
|
||||
for vectorstore in self.type_to_loader_dict.keys()
|
||||
if vectorstore in settings.VECTORSTORES or settings.DEV
|
||||
if vectorstore in settings_service.settings.VECTORSTORES
|
||||
or settings_service.settings.DEV
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
from typing import Dict, List, Optional
|
||||
|
||||
from langchain import requests, sql_database
|
||||
from langchain.utilities import requests, sql_database
|
||||
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.utils.logger import logger
|
||||
from loguru import logger
|
||||
from langflow.utils.util import build_template_from_class, build_template_from_method
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,2 +0,0 @@
|
|||
instance: C4
|
||||
autoscale_min: 1
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
# This file is used by lc-serve to load the mounted app and serve it.
|
||||
|
||||
import os
|
||||
|
||||
# Use the JCLOUD_WORKSPACE for db URL if it's provided by JCloud.
|
||||
if "JCLOUD_WORKSPACE" in os.environ:
|
||||
os.environ[
|
||||
"LANGFLOW_DATABASE_URL"
|
||||
] = f"sqlite:///{os.environ['JCLOUD_WORKSPACE']}/langflow.db"
|
||||
|
||||
from langflow.main import setup_app
|
||||
from langflow.utils.logger import configure
|
||||
|
||||
configure(log_level="DEBUG")
|
||||
app = setup_app()
|
||||
|
|
@ -6,25 +6,25 @@ from fastapi.responses import FileResponse
|
|||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from langflow.api import router
|
||||
from langflow.database.base import create_db_and_tables, Engine
|
||||
|
||||
|
||||
from langflow.interface.utils import setup_llm_caching
|
||||
from langflow.services.utils import initialize_services
|
||||
from langflow.services.plugins.langfuse import LangfuseInstance
|
||||
from langflow.services.utils import (
|
||||
teardown_services,
|
||||
)
|
||||
from langflow.utils.logger import configure
|
||||
|
||||
|
||||
def create_app():
|
||||
"""Create the FastAPI app and include the router."""
|
||||
|
||||
configure()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
origins = [
|
||||
"*",
|
||||
]
|
||||
|
||||
@app.get("/health")
|
||||
def get_health():
|
||||
return {"status": "OK"}
|
||||
origins = ["*"]
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
|
|
@ -34,13 +34,19 @@ def create_app():
|
|||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
app.include_router(router)
|
||||
app.on_event("startup")(Engine.update)
|
||||
app.on_event("startup")(create_db_and_tables)
|
||||
|
||||
app.on_event("startup")(initialize_services)
|
||||
app.on_event("startup")(setup_llm_caching)
|
||||
app.on_event("startup")(LangfuseInstance.update)
|
||||
app.on_event("shutdown")(Engine.teardown)
|
||||
|
||||
app.on_event("shutdown")(teardown_services)
|
||||
app.on_event("shutdown")(LangfuseInstance.teardown)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
|
|
@ -72,22 +78,25 @@ def get_static_files_dir():
|
|||
return frontend_path / "frontend"
|
||||
|
||||
|
||||
def setup_app(static_files_dir: Optional[Path] = None) -> FastAPI:
|
||||
def setup_app(
|
||||
static_files_dir: Optional[Path] = None, backend_only: bool = False
|
||||
) -> FastAPI:
|
||||
"""Setup the FastAPI app."""
|
||||
# get the directory of the current file
|
||||
if not static_files_dir:
|
||||
static_files_dir = get_static_files_dir()
|
||||
|
||||
if not static_files_dir or not static_files_dir.exists():
|
||||
if not backend_only and (not static_files_dir or not static_files_dir.exists()):
|
||||
raise RuntimeError(f"Static files directory {static_files_dir} does not exist.")
|
||||
app = create_app()
|
||||
setup_static_files(app, static_files_dir)
|
||||
if not backend_only and static_files_dir is not None:
|
||||
setup_static_files(app, static_files_dir)
|
||||
return app
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
from langflow.utils.util import get_number_of_workers
|
||||
from langflow.__main__ import get_number_of_workers
|
||||
|
||||
configure()
|
||||
uvicorn.run(
|
||||
|
|
|
|||
|
|
@ -4,8 +4,7 @@ from langflow.api.v1.callback import (
|
|||
StreamingLLMCallbackHandler,
|
||||
)
|
||||
from langflow.processing.process import fix_memory_inputs, format_actions
|
||||
|
||||
from langflow.utils.logger import logger
|
||||
from loguru import logger
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
|
||||
|
|
@ -35,7 +34,9 @@ def get_langfuse_callback(trace_id):
|
|||
if langfuse := LangfuseInstance.get():
|
||||
logger.debug("Langfuse credentials found")
|
||||
try:
|
||||
trace = langfuse.trace(CreateTrace(id=trace_id))
|
||||
trace = langfuse.trace(
|
||||
CreateTrace(name="langflow-" + trace_id, id=trace_id)
|
||||
)
|
||||
return trace.getNewHandler()
|
||||
except Exception as exc:
|
||||
logger.error(f"Error initializing langfuse callback: {exc}")
|
||||
|
|
|
|||
|
|
@ -2,17 +2,20 @@ import json
|
|||
from pathlib import Path
|
||||
from langchain.schema import AgentAction
|
||||
from langflow.interface.run import (
|
||||
build_sorted_vertices_with_caching,
|
||||
build_sorted_vertices,
|
||||
get_memory_key,
|
||||
update_memory_keys,
|
||||
)
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.services.getters import get_session_service
|
||||
from loguru import logger
|
||||
from langflow.graph import Graph
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from langchain.schema import Document
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def fix_memory_inputs(langchain_object):
|
||||
"""
|
||||
|
|
@ -65,7 +68,7 @@ def get_result_and_thought(langchain_object: Any, inputs: dict):
|
|||
langchain_object.verbose = True
|
||||
|
||||
if hasattr(langchain_object, "return_intermediate_steps"):
|
||||
langchain_object.return_intermediate_steps = True
|
||||
langchain_object.return_intermediate_steps = False
|
||||
|
||||
fix_memory_inputs(langchain_object)
|
||||
|
||||
|
|
@ -86,44 +89,54 @@ def get_input_str_if_only_one_input(inputs: dict) -> Optional[str]:
|
|||
return list(inputs.values())[0] if len(inputs) == 1 else None
|
||||
|
||||
|
||||
def process_graph_cached(
|
||||
data_graph: Dict[str, Any], inputs: Optional[dict] = None, clear_cache=False
|
||||
):
|
||||
"""
|
||||
Process graph by extracting input variables and replacing ZeroShotPrompt
|
||||
with PromptTemplate,then run the graph and return the result and thought.
|
||||
"""
|
||||
# Load langchain object
|
||||
if clear_cache:
|
||||
build_sorted_vertices_with_caching.clear_cache()
|
||||
logger.debug("Cleared cache")
|
||||
langchain_object, artifacts = build_sorted_vertices_with_caching(data_graph)
|
||||
logger.debug("Loaded LangChain object")
|
||||
if inputs is None:
|
||||
inputs = {}
|
||||
def get_build_result(data_graph, session_id):
|
||||
# If session_id is provided, load the langchain_object from the session
|
||||
# using build_sorted_vertices_with_caching.get_result_by_session_id
|
||||
# if it returns something different than None, return it
|
||||
# otherwise, build the graph and return the result
|
||||
if session_id:
|
||||
logger.debug(f"Loading LangChain object from session {session_id}")
|
||||
result = build_sorted_vertices(data_graph=data_graph)
|
||||
if result is not None:
|
||||
logger.debug("Loaded LangChain object")
|
||||
return result
|
||||
|
||||
# Add artifacts to inputs
|
||||
# artifacts can be documents loaded when building
|
||||
# the flow
|
||||
for (
|
||||
key,
|
||||
value,
|
||||
) in artifacts.items():
|
||||
if key not in inputs or not inputs[key]:
|
||||
inputs[key] = value
|
||||
logger.debug("Building langchain object")
|
||||
return build_sorted_vertices(data_graph)
|
||||
|
||||
|
||||
def load_langchain_object(
|
||||
data_graph: Dict[str, Any], session_id: str
|
||||
) -> Tuple[Union[Chain, VectorStore], Dict[str, Any], str]:
|
||||
langchain_object, artifacts = get_build_result(data_graph, session_id)
|
||||
logger.debug("Loaded LangChain object")
|
||||
|
||||
if langchain_object is None:
|
||||
# Raise user facing error
|
||||
raise ValueError(
|
||||
"There was an error loading the langchain_object. Please, check all the nodes and try again."
|
||||
)
|
||||
|
||||
# Generate result and thought
|
||||
return langchain_object, artifacts, session_id
|
||||
|
||||
|
||||
def process_inputs(inputs: Optional[dict], artifacts: Dict[str, Any]) -> dict:
|
||||
if inputs is None:
|
||||
inputs = {}
|
||||
|
||||
for key, value in artifacts.items():
|
||||
if key not in inputs or not inputs[key]:
|
||||
inputs[key] = value
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
def generate_result(langchain_object: Union[Chain, VectorStore], inputs: dict):
|
||||
if isinstance(langchain_object, Chain):
|
||||
if inputs is None:
|
||||
raise ValueError("Inputs must be provided for a Chain")
|
||||
logger.debug("Generating result and thought")
|
||||
result = get_result_and_thought(langchain_object, inputs)
|
||||
|
||||
logger.debug("Generated result and thought")
|
||||
elif isinstance(langchain_object, VectorStore):
|
||||
result = langchain_object.search(**inputs)
|
||||
|
|
@ -136,6 +149,36 @@ def process_graph_cached(
|
|||
return result
|
||||
|
||||
|
||||
class Result(BaseModel):
|
||||
result: Any
|
||||
session_id: str
|
||||
|
||||
|
||||
async def process_graph_cached(
|
||||
data_graph: Dict[str, Any],
|
||||
inputs: Optional[dict] = None,
|
||||
clear_cache=False,
|
||||
session_id=None,
|
||||
) -> Result:
|
||||
session_service = get_session_service()
|
||||
if clear_cache:
|
||||
session_service.clear_session(session_id)
|
||||
if session_id is None:
|
||||
session_id = session_service.generate_key(
|
||||
session_id=session_id, data_graph=data_graph
|
||||
)
|
||||
# Load the graph using SessionService
|
||||
graph, artifacts = session_service.load_session(session_id, data_graph)
|
||||
built_object = graph.build()
|
||||
processed_inputs = process_inputs(inputs, artifacts)
|
||||
result = generate_result(built_object, processed_inputs)
|
||||
# langchain_object is now updated with the new memory
|
||||
# we need to update the cache with the updated langchain_object
|
||||
session_service.update_session(session_id, (graph, artifacts))
|
||||
|
||||
return Result(result=result, session_id=session_id)
|
||||
|
||||
|
||||
def load_flow_from_json(
|
||||
flow: Union[Path, str, dict], tweaks: Optional[dict] = None, build=True
|
||||
):
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ from gunicorn.app.base import BaseApplication # type: ignore
|
|||
class LangflowApplication(BaseApplication):
|
||||
def __init__(self, app, options=None):
|
||||
self.options = options or {}
|
||||
|
||||
self.options["worker_class"] = "uvicorn.workers.UvicornWorker"
|
||||
self.application = app
|
||||
super().__init__()
|
||||
|
||||
|
|
|
|||
4
src/backend/langflow/services/__init__.py
Normal file
4
src/backend/langflow/services/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
from .manager import service_manager
|
||||
from .schema import ServiceType
|
||||
|
||||
__all__ = ["service_manager", "ServiceType"]
|
||||
12
src/backend/langflow/services/auth/factory.py
Normal file
12
src/backend/langflow/services/auth/factory.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
from langflow.services.factory import ServiceFactory
|
||||
from langflow.services.auth.service import AuthService
|
||||
|
||||
|
||||
class AuthServiceFactory(ServiceFactory):
|
||||
name = "auth_service"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(AuthService)
|
||||
|
||||
def create(self, settings_service):
|
||||
return AuthService(settings_service)
|
||||
12
src/backend/langflow/services/auth/service.py
Normal file
12
src/backend/langflow/services/auth/service.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
from langflow.services.base import Service
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.settings.manager import SettingsService
|
||||
|
||||
|
||||
class AuthService(Service):
|
||||
name = "auth_service"
|
||||
|
||||
def __init__(self, settings_service: "SettingsService"):
|
||||
self.settings_service = settings_service
|
||||
296
src/backend/langflow/services/auth/utils.py
Normal file
296
src/backend/langflow/services/auth/utils.py
Normal file
|
|
@ -0,0 +1,296 @@
|
|||
from datetime import datetime, timedelta, timezone
|
||||
from fastapi import Depends, HTTPException, Security, status
|
||||
from fastapi.security import APIKeyHeader, APIKeyQuery, OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
from typing import Annotated, Coroutine, Optional, Union
|
||||
from uuid import UUID
|
||||
from langflow.services.database.models.api_key.api_key import ApiKey
|
||||
from langflow.services.database.models.api_key.crud import check_key
|
||||
from langflow.services.database.models.user.user import User
|
||||
from langflow.services.database.models.user.crud import (
|
||||
get_user_by_id,
|
||||
get_user_by_username,
|
||||
update_user_last_login_at,
|
||||
)
|
||||
from langflow.services.getters import get_session, get_settings_service
|
||||
from sqlmodel import Session
|
||||
|
||||
oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login")
|
||||
|
||||
API_KEY_NAME = "x-api-key"
|
||||
|
||||
api_key_query = APIKeyQuery(
|
||||
name=API_KEY_NAME, scheme_name="API key query", auto_error=False
|
||||
)
|
||||
api_key_header = APIKeyHeader(
|
||||
name=API_KEY_NAME, scheme_name="API key header", auto_error=False
|
||||
)
|
||||
|
||||
|
||||
# Source: https://github.com/mrtolkien/fastapi_simple_security/blob/master/fastapi_simple_security/security_api_key.py
|
||||
async def api_key_security(
|
||||
query_param: str = Security(api_key_query),
|
||||
header_param: str = Security(api_key_header),
|
||||
db: Session = Depends(get_session),
|
||||
) -> Optional[User]:
|
||||
settings_service = get_settings_service()
|
||||
result: Optional[Union[ApiKey, User]] = None
|
||||
if settings_service.auth_settings.AUTO_LOGIN:
|
||||
# Get the first user
|
||||
if not settings_service.auth_settings.SUPERUSER:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Missing first superuser credentials",
|
||||
)
|
||||
|
||||
result = get_user_by_username(db, settings_service.auth_settings.SUPERUSER)
|
||||
|
||||
elif not query_param and not header_param:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="An API key must be passed as query or header",
|
||||
)
|
||||
|
||||
elif query_param:
|
||||
result = check_key(db, query_param)
|
||||
|
||||
else:
|
||||
result = check_key(db, header_param)
|
||||
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid or missing API key",
|
||||
)
|
||||
if isinstance(result, ApiKey):
|
||||
return result.user
|
||||
elif isinstance(result, User):
|
||||
return result
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: Annotated[str, Depends(oauth2_login)],
|
||||
db: Session = Depends(get_session),
|
||||
) -> User:
|
||||
settings_service = get_settings_service()
|
||||
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
if isinstance(token, Coroutine):
|
||||
token = await token
|
||||
|
||||
if settings_service.auth_settings.SECRET_KEY is None:
|
||||
raise credentials_exception
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings_service.auth_settings.SECRET_KEY,
|
||||
algorithms=[settings_service.auth_settings.ALGORITHM],
|
||||
)
|
||||
user_id: UUID = payload.get("sub") # type: ignore
|
||||
token_type: str = payload.get("type") # type: ignore
|
||||
if expires := payload.get("exp", None):
|
||||
expires_datetime = datetime.fromtimestamp(expires, timezone.utc)
|
||||
# TypeError: can't compare offset-naive and offset-aware datetimes
|
||||
if datetime.now(timezone.utc) > expires_datetime:
|
||||
raise credentials_exception
|
||||
|
||||
if user_id is None or token_type:
|
||||
raise credentials_exception
|
||||
except JWTError as e:
|
||||
raise credentials_exception from e
|
||||
|
||||
user = get_user_by_id(db, user_id) # type: ignore
|
||||
if user is None or not user.is_active:
|
||||
raise credentials_exception
|
||||
return user
|
||||
|
||||
|
||||
def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]):
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
return current_user
|
||||
|
||||
|
||||
def get_current_active_superuser(
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
) -> User:
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(status_code=401, detail="Inactive user")
|
||||
if not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="The user doesn't have enough privileges"
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
def verify_password(plain_password, hashed_password):
|
||||
settings_service = get_settings_service()
|
||||
return settings_service.auth_settings.pwd_context.verify(
|
||||
plain_password, hashed_password
|
||||
)
|
||||
|
||||
|
||||
def get_password_hash(password):
|
||||
settings_service = get_settings_service()
|
||||
return settings_service.auth_settings.pwd_context.hash(password)
|
||||
|
||||
|
||||
def create_token(data: dict, expires_delta: timedelta):
|
||||
settings_service = get_settings_service()
|
||||
|
||||
to_encode = data.copy()
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
to_encode["exp"] = expire
|
||||
|
||||
return jwt.encode(
|
||||
to_encode,
|
||||
settings_service.auth_settings.SECRET_KEY,
|
||||
algorithm=settings_service.auth_settings.ALGORITHM,
|
||||
)
|
||||
|
||||
|
||||
def create_super_user(
|
||||
username: str,
|
||||
password: str,
|
||||
db: Session = Depends(get_session),
|
||||
) -> User:
|
||||
super_user = get_user_by_username(db, username)
|
||||
|
||||
if not super_user:
|
||||
super_user = User(
|
||||
username=username,
|
||||
password=get_password_hash(password),
|
||||
is_superuser=True,
|
||||
is_active=True,
|
||||
last_login_at=None,
|
||||
)
|
||||
|
||||
db.add(super_user)
|
||||
db.commit()
|
||||
db.refresh(super_user)
|
||||
|
||||
return super_user
|
||||
|
||||
|
||||
def create_user_longterm_token(db: Session = Depends(get_session)) -> dict:
|
||||
settings_service = get_settings_service()
|
||||
username = settings_service.auth_settings.SUPERUSER
|
||||
password = settings_service.auth_settings.SUPERUSER_PASSWORD
|
||||
if not username or not password:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Missing first superuser credentials",
|
||||
)
|
||||
super_user = create_super_user(db=db, username=username, password=password)
|
||||
|
||||
access_token_expires_longterm = timedelta(days=365)
|
||||
access_token = create_token(
|
||||
data={"sub": str(super_user.id)},
|
||||
expires_delta=access_token_expires_longterm,
|
||||
)
|
||||
|
||||
# Update: last_login_at
|
||||
update_user_last_login_at(super_user.id, db)
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": None,
|
||||
"token_type": "bearer",
|
||||
}
|
||||
|
||||
|
||||
def create_user_api_key(user_id: UUID) -> dict:
|
||||
access_token = create_token(
|
||||
data={"sub": str(user_id), "role": "api_key"},
|
||||
expires_delta=timedelta(days=365 * 2),
|
||||
)
|
||||
|
||||
return {"api_key": access_token}
|
||||
|
||||
|
||||
def get_user_id_from_token(token: str) -> UUID:
|
||||
try:
|
||||
user_id = jwt.get_unverified_claims(token)["sub"]
|
||||
return UUID(user_id)
|
||||
except (KeyError, JWTError, ValueError):
|
||||
return UUID(int=0)
|
||||
|
||||
|
||||
def create_user_tokens(
|
||||
user_id: UUID, db: Session = Depends(get_session), update_last_login: bool = False
|
||||
) -> dict:
|
||||
settings_service = get_settings_service()
|
||||
|
||||
access_token_expires = timedelta(
|
||||
minutes=settings_service.auth_settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
access_token = create_token(
|
||||
data={"sub": str(user_id)},
|
||||
expires_delta=access_token_expires,
|
||||
)
|
||||
|
||||
refresh_token_expires = timedelta(
|
||||
minutes=settings_service.auth_settings.REFRESH_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
refresh_token = create_token(
|
||||
data={"sub": str(user_id), "type": "rf"},
|
||||
expires_delta=refresh_token_expires,
|
||||
)
|
||||
|
||||
# Update: last_login_at
|
||||
if update_last_login:
|
||||
update_user_last_login_at(user_id, db)
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"token_type": "bearer",
|
||||
}
|
||||
|
||||
|
||||
def create_refresh_token(refresh_token: str, db: Session = Depends(get_session)):
|
||||
settings_service = get_settings_service()
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
refresh_token,
|
||||
settings_service.auth_settings.SECRET_KEY,
|
||||
algorithms=[settings_service.auth_settings.ALGORITHM],
|
||||
)
|
||||
user_id: UUID = payload.get("sub") # type: ignore
|
||||
token_type: str = payload.get("type") # type: ignore
|
||||
|
||||
if user_id is None or token_type is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
|
||||
)
|
||||
|
||||
return create_user_tokens(user_id, db)
|
||||
|
||||
except JWTError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token",
|
||||
) from e
|
||||
|
||||
|
||||
def authenticate_user(
|
||||
username: str, password: str, db: Session = Depends(get_session)
|
||||
) -> Optional[User]:
|
||||
user = get_user_by_username(db, username)
|
||||
|
||||
if not user:
|
||||
return None
|
||||
|
||||
if not user.is_active:
|
||||
if not user.last_login_at:
|
||||
raise HTTPException(status_code=400, detail="Waiting for approval")
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
|
||||
return user if verify_password(password, user.password) else None
|
||||
12
src/backend/langflow/services/base.py
Normal file
12
src/backend/langflow/services/base.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
from abc import ABC
|
||||
|
||||
|
||||
class Service(ABC):
|
||||
name: str
|
||||
ready: bool = False
|
||||
|
||||
def teardown(self):
|
||||
pass
|
||||
|
||||
def set_ready(self):
|
||||
self.ready = True
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue