Merge remote-tracking branch 'origin/dev' into merge
This commit is contained in:
commit
e3a2abacae
231 changed files with 22158 additions and 6728 deletions
|
|
@ -9,17 +9,19 @@ from typing import Optional
|
|||
import httpx
|
||||
import typer
|
||||
from dotenv import load_dotenv
|
||||
from langflow.main import setup_app
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.getters import get_db_service, get_settings_service
|
||||
from langflow.services.utils import initialize_services, initialize_settings_service
|
||||
from langflow.utils.logger import configure, logger
|
||||
from multiprocess import Process, cpu_count # type: ignore
|
||||
from rich import box
|
||||
from rich import print as rprint
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
from sqlmodel import select
|
||||
|
||||
from langflow.main import setup_app
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_db_service, get_settings_service
|
||||
from langflow.services.utils import initialize_services, initialize_settings_service
|
||||
from langflow.utils.logger import configure, logger
|
||||
|
||||
console = Console()
|
||||
|
||||
|
|
@ -70,6 +72,7 @@ def update_settings(
|
|||
dev: bool = False,
|
||||
remove_api_keys: bool = False,
|
||||
components_path: Optional[Path] = None,
|
||||
store: bool = False,
|
||||
):
|
||||
"""Update the settings from a config file."""
|
||||
|
||||
|
|
@ -88,16 +91,38 @@ def update_settings(
|
|||
if components_path:
|
||||
logger.debug(f"Adding component path {components_path}")
|
||||
settings_service.settings.update_settings(COMPONENTS_PATH=components_path)
|
||||
if not store:
|
||||
logger.debug("Setting store to False")
|
||||
settings_service.settings.update_settings(STORE=False)
|
||||
|
||||
|
||||
def version_callback(value: bool):
|
||||
"""
|
||||
Show the version and exit.
|
||||
"""
|
||||
from langflow import __version__
|
||||
|
||||
if value:
|
||||
typer.echo(f"Langflow Version: {__version__}")
|
||||
raise typer.Exit()
|
||||
|
||||
|
||||
@app.callback()
|
||||
def main_entry_point(
|
||||
version: bool = typer.Option(
|
||||
None, "--version", callback=version_callback, is_eager=True, help="Show the version and exit."
|
||||
),
|
||||
):
|
||||
"""
|
||||
Main entry point for the Langflow CLI.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@app.command()
|
||||
def run(
|
||||
host: str = typer.Option(
|
||||
"127.0.0.1", help="Host to bind the server to.", envvar="LANGFLOW_HOST"
|
||||
),
|
||||
workers: int = typer.Option(
|
||||
1, help="Number of worker processes.", envvar="LANGFLOW_WORKERS"
|
||||
),
|
||||
host: str = typer.Option("127.0.0.1", help="Host to bind the server to.", envvar="LANGFLOW_HOST"),
|
||||
workers: int = typer.Option(1, help="Number of worker processes.", envvar="LANGFLOW_WORKERS"),
|
||||
timeout: int = typer.Option(300, help="Worker timeout in seconds."),
|
||||
port: int = typer.Option(7860, help="Port to listen on.", envvar="LANGFLOW_PORT"),
|
||||
components_path: Optional[Path] = typer.Option(
|
||||
|
|
@ -105,32 +130,17 @@ def run(
|
|||
help="Path to the directory containing custom components.",
|
||||
envvar="LANGFLOW_COMPONENTS_PATH",
|
||||
),
|
||||
config: str = typer.Option(
|
||||
Path(__file__).parent / "config.yaml", help="Path to the configuration file."
|
||||
),
|
||||
config: str = typer.Option(Path(__file__).parent / "config.yaml", help="Path to the configuration file."),
|
||||
# .env file param
|
||||
env_file: Path = typer.Option(
|
||||
None, help="Path to the .env file containing environment variables."
|
||||
),
|
||||
log_level: str = typer.Option(
|
||||
"critical", help="Logging level.", envvar="LANGFLOW_LOG_LEVEL"
|
||||
),
|
||||
log_file: Path = typer.Option(
|
||||
"logs/langflow.log", help="Path to the log file.", envvar="LANGFLOW_LOG_FILE"
|
||||
),
|
||||
env_file: Path = typer.Option(None, help="Path to the .env file containing environment variables."),
|
||||
log_level: str = typer.Option("critical", help="Logging level.", envvar="LANGFLOW_LOG_LEVEL"),
|
||||
log_file: Path = typer.Option("logs/langflow.log", help="Path to the log file.", envvar="LANGFLOW_LOG_FILE"),
|
||||
cache: Optional[str] = typer.Option(
|
||||
envvar="LANGFLOW_LANGCHAIN_CACHE",
|
||||
help="Type of cache to use. (InMemoryCache, SQLiteCache)",
|
||||
default=None,
|
||||
),
|
||||
dev: bool = typer.Option(False, help="Run in development mode (may contain bugs)"),
|
||||
# This variable does not work but is set by the .env file
|
||||
# and works with Pydantic
|
||||
# database_url: str = typer.Option(
|
||||
# None,
|
||||
# help="Database URL to connect to. If not provided, a local SQLite database will be used.",
|
||||
# envvar="LANGFLOW_DATABASE_URL",
|
||||
# ),
|
||||
path: str = typer.Option(
|
||||
None,
|
||||
help="Path to the frontend directory containing build files. This is for development purposes only.",
|
||||
|
|
@ -151,6 +161,11 @@ def run(
|
|||
help="Run only the backend server without the frontend.",
|
||||
envvar="LANGFLOW_BACKEND_ONLY",
|
||||
),
|
||||
store: bool = typer.Option(
|
||||
True,
|
||||
help="Enables the store features.",
|
||||
envvar="LANGFLOW_STORE",
|
||||
),
|
||||
):
|
||||
"""
|
||||
Run the Langflow.
|
||||
|
|
@ -169,6 +184,7 @@ def run(
|
|||
remove_api_keys=remove_api_keys,
|
||||
cache=cache,
|
||||
components_path=components_path,
|
||||
store=store,
|
||||
)
|
||||
# create path object if path is provided
|
||||
static_files_dir: Optional[Path] = Path(path) if path else None
|
||||
|
|
@ -198,9 +214,7 @@ def run(
|
|||
|
||||
|
||||
def run_on_mac_or_linux(host, port, log_level, options, app, open_browser=True):
|
||||
webapp_process = Process(
|
||||
target=run_langflow, args=(host, port, log_level, options, app)
|
||||
)
|
||||
webapp_process = Process(target=run_langflow, args=(host, port, log_level, options, app))
|
||||
webapp_process.start()
|
||||
status_code = 0
|
||||
while status_code != 200:
|
||||
|
|
@ -276,9 +290,7 @@ def print_banner(host, port):
|
|||
)
|
||||
|
||||
# Create a panel with the title and the info text, and a border around it
|
||||
panel = Panel(
|
||||
f"{title}\n{info_text}", box=box.ROUNDED, border_style="blue", expand=False
|
||||
)
|
||||
panel = Panel(f"{title}\n{info_text}", box=box.ROUNDED, border_style="blue", expand=False)
|
||||
|
||||
# Print the banner with a separator line before and after
|
||||
rprint(panel)
|
||||
|
|
@ -310,12 +322,8 @@ def run_langflow(host, port, log_level, options, app):
|
|||
@app.command()
|
||||
def superuser(
|
||||
username: str = typer.Option(..., prompt=True, help="Username for the superuser."),
|
||||
password: str = typer.Option(
|
||||
..., prompt=True, hide_input=True, help="Password for the superuser."
|
||||
),
|
||||
log_level: str = typer.Option(
|
||||
"critical", help="Logging level.", envvar="LANGFLOW_LOG_LEVEL"
|
||||
),
|
||||
password: str = typer.Option(..., prompt=True, hide_input=True, help="Password for the superuser."),
|
||||
log_level: str = typer.Option("critical", help="Logging level.", envvar="LANGFLOW_LOG_LEVEL"),
|
||||
):
|
||||
"""
|
||||
Create a superuser.
|
||||
|
|
@ -328,9 +336,9 @@ def superuser(
|
|||
|
||||
if create_super_user(db=session, username=username, password=password):
|
||||
# Verify that the superuser was created
|
||||
from langflow.services.database.models.user.user import User
|
||||
from langflow.services.database.models.user.model import User
|
||||
|
||||
user: User = session.query(User).filter(User.username == username).first()
|
||||
user: User = session.exec(select(User).where(User.username == username)).first()
|
||||
if user is None or not user.is_superuser:
|
||||
typer.echo("Superuser creation failed.")
|
||||
return
|
||||
|
|
@ -342,11 +350,23 @@ def superuser(
|
|||
|
||||
|
||||
@app.command()
|
||||
def migration(test: bool = typer.Option(True, help="Run migrations in test mode.")):
|
||||
def migration(
|
||||
test: bool = typer.Option(True, help="Run migrations in test mode."),
|
||||
fix: bool = typer.Option(
|
||||
False,
|
||||
help="Fix migrations. This is a destructive operation, and should only be used if you know what you are doing.",
|
||||
),
|
||||
):
|
||||
"""
|
||||
Run or test migrations.
|
||||
"""
|
||||
initialize_services()
|
||||
if fix:
|
||||
if not typer.confirm(
|
||||
"This will delete all data necessary to fix migrations. Are you sure you want to continue?"
|
||||
):
|
||||
raise typer.Abort()
|
||||
|
||||
initialize_services(fix_migration=fix)
|
||||
db_service = get_db_service()
|
||||
if not test:
|
||||
db_service.run_migrations()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from sqlalchemy import pool
|
|||
|
||||
from alembic import context
|
||||
|
||||
from langflow.services.database.manager import SQLModel
|
||||
from langflow.services.database.service import SQLModel
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
|
|
|
|||
43
src/backend/langflow/alembic/versions/1ef9c4f3765d_.py
Normal file
43
src/backend/langflow/alembic/versions/1ef9c4f3765d_.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
"""
|
||||
|
||||
|
||||
Revision ID: 1ef9c4f3765d
|
||||
Revises: fd531f8868b1
|
||||
Create Date: 2023-12-04 15:00:27.968998
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '1ef9c4f3765d'
|
||||
down_revision: Union[str, None] = 'fd531f8868b1'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
try:
|
||||
with op.batch_alter_table('apikey', schema=None) as batch_op:
|
||||
batch_op.alter_column('name',
|
||||
existing_type=sqlmodel.sql.sqltypes.AutoString(),
|
||||
nullable=True)
|
||||
except Exception as e:
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
try:
|
||||
with op.batch_alter_table('apikey', schema=None) as batch_op:
|
||||
batch_op.alter_column('name',
|
||||
existing_type=sa.VARCHAR(),
|
||||
nullable=False)
|
||||
except Exception as e:
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -7,9 +7,9 @@ Create Date: 2023-08-27 19:49:02.681355
|
|||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
|
@ -23,7 +23,7 @@ def upgrade() -> None:
|
|||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
|
||||
conn = op.get_bind()
|
||||
inspector = Inspector.from_engine(conn)
|
||||
inspector = Inspector.from_engine(conn) # type: ignore
|
||||
# List existing tables
|
||||
existing_tables = inspector.get_table_names()
|
||||
# Drop 'flowstyle' table if it exists
|
||||
|
|
@ -145,8 +145,8 @@ def upgrade() -> None:
|
|||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
|
||||
conn = op.get_bind()
|
||||
inspector = Inspector.from_engine(conn)
|
||||
conn = op.get_bind()
|
||||
inspector = Inspector.from_engine(conn) # type: ignore
|
||||
# List existing tables
|
||||
existing_tables = inspector.get_table_names()
|
||||
if "flow" in existing_tables:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,45 @@
|
|||
"""Adds Credential table
|
||||
|
||||
Revision ID: c1c8e217a069
|
||||
Revises: 7d2162acc8b2
|
||||
Create Date: 2023-11-24 10:45:38.465302
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '2ac71eb9c3ae'
|
||||
down_revision: Union[str, None] = '7d2162acc8b2'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
try:
|
||||
op.create_table('credential',
|
||||
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('value', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('provider', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('user_id', sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.Column('id', sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
try:
|
||||
op.drop_table('credential')
|
||||
except Exception:
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -7,9 +7,9 @@ Create Date: 2023-09-08 07:36:13.387318
|
|||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
|
@ -22,7 +22,7 @@ depends_on: Union[str, Sequence[str], None] = None
|
|||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
inspector = Inspector.from_engine(conn)
|
||||
inspector = Inspector.from_engine(conn) # type: ignore
|
||||
if "user" in inspector.get_table_names() and "profile_image" not in [
|
||||
column["name"] for column in inspector.get_columns("user")
|
||||
]:
|
||||
|
|
@ -39,7 +39,7 @@ def upgrade() -> None:
|
|||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
inspector = Inspector.from_engine(conn)
|
||||
inspector = Inspector.from_engine(conn) # type: ignore
|
||||
if "user" in inspector.get_table_names() and "profile_image" in [
|
||||
column["name"] for column in inspector.get_columns("user")
|
||||
]:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,50 @@
|
|||
"""Store updates
|
||||
|
||||
Revision ID: 7843803a87b5
|
||||
Revises: eb5866d51fd2
|
||||
Create Date: 2023-10-18 23:08:57.744906
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
from loguru import logger
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "7843803a87b5"
|
||||
down_revision: Union[str, None] = "eb5866d51fd2"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
try:
|
||||
with op.batch_alter_table("flow", schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column("is_component", sa.Boolean(), nullable=True))
|
||||
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.add_column(
|
||||
sa.Column(
|
||||
"store_api_key", sqlmodel.AutoString(), nullable=True
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
try:
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.drop_column("store_api_key")
|
||||
|
||||
with op.batch_alter_table("flow", schema=None) as batch_op:
|
||||
batch_op.drop_column("is_component")
|
||||
except Exception:
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -0,0 +1,93 @@
|
|||
"""Adds updated_at and folder cols
|
||||
|
||||
Revision ID: 7d2162acc8b2
|
||||
Revises: f5ee9749d1a6
|
||||
Create Date: 2023-11-21 20:56:53.998781
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '7d2162acc8b2'
|
||||
down_revision: Union[str, None] = 'f5ee9749d1a6'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
try:
|
||||
with op.batch_alter_table('component', schema=None) as batch_op:
|
||||
batch_op.drop_index('ix_component_frontend_node_id')
|
||||
batch_op.drop_index('ix_component_name')
|
||||
op.drop_table('component')
|
||||
op.drop_table('flowstyle')
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
with op.batch_alter_table('apikey', schema=None) as batch_op:
|
||||
batch_op.alter_column('name',
|
||||
existing_type=sa.VARCHAR(),
|
||||
nullable=False)
|
||||
|
||||
with op.batch_alter_table('flow', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('updated_at', sa.DateTime(), nullable=True))
|
||||
batch_op.add_column(sa.Column('folder', sqlmodel.sql.sqltypes.AutoString(), nullable=True))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
try:
|
||||
with op.batch_alter_table('flow', schema=None) as batch_op:
|
||||
batch_op.drop_column('folder')
|
||||
batch_op.drop_column('updated_at')
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
|
||||
try:
|
||||
|
||||
with op.batch_alter_table('apikey', schema=None) as batch_op:
|
||||
batch_op.alter_column('name',
|
||||
existing_type=sa.VARCHAR(),
|
||||
nullable=True)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
try:
|
||||
op.create_table('flowstyle',
|
||||
sa.Column('color', sa.VARCHAR(), nullable=False),
|
||||
sa.Column('emoji', sa.VARCHAR(), nullable=False),
|
||||
sa.Column('flow_id', sa.CHAR(length=32), nullable=True),
|
||||
sa.Column('id', sa.CHAR(length=32), nullable=False),
|
||||
sa.ForeignKeyConstraint(['flow_id'], ['flow.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id')
|
||||
)
|
||||
op.create_table('component',
|
||||
sa.Column('id', sa.CHAR(length=32), nullable=False),
|
||||
sa.Column('frontend_node_id', sa.CHAR(length=32), nullable=False),
|
||||
sa.Column('name', sa.VARCHAR(), nullable=False),
|
||||
sa.Column('description', sa.VARCHAR(), nullable=True),
|
||||
sa.Column('python_code', sa.VARCHAR(), nullable=True),
|
||||
sa.Column('return_type', sa.VARCHAR(), nullable=True),
|
||||
sa.Column('is_disabled', sa.BOOLEAN(), nullable=False),
|
||||
sa.Column('is_read_only', sa.BOOLEAN(), nullable=False),
|
||||
sa.Column('create_at', sa.DATETIME(), nullable=False),
|
||||
sa.Column('update_at', sa.DATETIME(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
|
||||
with op.batch_alter_table('component', schema=None) as batch_op:
|
||||
batch_op.create_index('ix_component_name', ['name'], unique=False)
|
||||
batch_op.create_index('ix_component_frontend_node_id', ['frontend_node_id'], unique=False)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -7,10 +7,9 @@ Create Date: 2023-10-04 10:18:25.640458
|
|||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy import exc
|
||||
import sqlmodel # noqa: F401
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "eb5866d51fd2"
|
||||
|
|
@ -28,14 +27,16 @@ def upgrade() -> None:
|
|||
batch_op.drop_index("ix_component_frontend_node_id")
|
||||
batch_op.drop_index("ix_component_name")
|
||||
except exc.SQLAlchemyError:
|
||||
connection.execute("ROLLBACK")
|
||||
# connection.execute(text("ROLLBACK"))
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
op.drop_table("component")
|
||||
except exc.SQLAlchemyError:
|
||||
connection.execute("ROLLBACK")
|
||||
# connection.execute(text("ROLLBACK"))
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
|
|
|
|||
|
|
@ -0,0 +1,45 @@
|
|||
"""User id can be null in Flow
|
||||
|
||||
Revision ID: f5ee9749d1a6
|
||||
Revises: 7843803a87b5
|
||||
Create Date: 2023-10-18 23:12:27.297016
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "f5ee9749d1a6"
|
||||
down_revision: Union[str, None] = "7843803a87b5"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
try:
|
||||
with op.batch_alter_table("flow", schema=None) as batch_op:
|
||||
batch_op.alter_column(
|
||||
"user_id", existing_type=sa.CHAR(length=32), nullable=True
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
try:
|
||||
with op.batch_alter_table("flow", schema=None) as batch_op:
|
||||
batch_op.alter_column(
|
||||
"user_id", existing_type=sa.CHAR(length=32), nullable=False
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
"""Fix Credential table
|
||||
|
||||
Revision ID: fd531f8868b1
|
||||
Revises: 2ac71eb9c3ae
|
||||
Create Date: 2023-11-24 15:07:37.566516
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'fd531f8868b1'
|
||||
down_revision: Union[str, None] = '2ac71eb9c3ae'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
try:
|
||||
with op.batch_alter_table('credential', schema=None) as batch_op:
|
||||
batch_op.create_foreign_key("fk_credential_user_id", 'user', ['user_id'], ['id'])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
try:
|
||||
with op.batch_alter_table('credential', schema=None) as batch_op:
|
||||
batch_op.drop_constraint("fk_credential_user_id", type_='foreignkey')
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -1,14 +1,16 @@
|
|||
# Router for base api
|
||||
from fastapi import APIRouter
|
||||
|
||||
from langflow.api.v1 import (
|
||||
chat_router,
|
||||
endpoints_router,
|
||||
validate_router,
|
||||
flows_router,
|
||||
component_router,
|
||||
users_router,
|
||||
api_key_router,
|
||||
chat_router,
|
||||
credentials_router,
|
||||
endpoints_router,
|
||||
flows_router,
|
||||
login_router,
|
||||
store_router,
|
||||
users_router,
|
||||
validate_router,
|
||||
)
|
||||
|
||||
router = APIRouter(
|
||||
|
|
@ -17,8 +19,9 @@ router = APIRouter(
|
|||
router.include_router(chat_router)
|
||||
router.include_router(endpoints_router)
|
||||
router.include_router(validate_router)
|
||||
router.include_router(component_router)
|
||||
router.include_router(store_router)
|
||||
router.include_router(flows_router)
|
||||
router.include_router(users_router)
|
||||
router.include_router(api_key_router)
|
||||
router.include_router(login_router)
|
||||
router.include_router(credentials_router)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,22 @@
|
|||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from fastapi import HTTPException
|
||||
from platformdirs import user_cache_dir
|
||||
|
||||
from langflow.services.store.schema import StoreComponentCreate
|
||||
from langflow.services.store.utils import get_lf_version_from_pypi
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.services.database.models.flow.model import Flow
|
||||
|
||||
|
||||
API_WORDS = ["api", "key", "token"]
|
||||
|
||||
|
||||
def has_api_terms(word: str):
|
||||
return "api" in word and (
|
||||
"key" in word or ("token" in word and "tokens" not in word)
|
||||
)
|
||||
return "api" in word and ("key" in word or ("token" in word and "tokens" not in word))
|
||||
|
||||
|
||||
def remove_api_keys(flow: dict):
|
||||
|
|
@ -14,11 +26,7 @@ def remove_api_keys(flow: dict):
|
|||
node_data = node.get("data").get("node")
|
||||
template = node_data.get("template")
|
||||
for value in template.values():
|
||||
if (
|
||||
isinstance(value, dict)
|
||||
and has_api_terms(value["name"])
|
||||
and value.get("password")
|
||||
):
|
||||
if isinstance(value, dict) and has_api_terms(value["name"]) and value.get("password"):
|
||||
value["value"] = None
|
||||
|
||||
return flow
|
||||
|
|
@ -39,9 +47,7 @@ def build_input_keys_response(langchain_object, artifacts):
|
|||
input_keys_response["input_keys"][key] = value
|
||||
# If the object has memory, that memory will have a memory_variables attribute
|
||||
# memory variables should be removed from the input keys
|
||||
if hasattr(langchain_object, "memory") and hasattr(
|
||||
langchain_object.memory, "memory_variables"
|
||||
):
|
||||
if hasattr(langchain_object, "memory") and hasattr(langchain_object.memory, "memory_variables"):
|
||||
# Remove memory variables from input keys
|
||||
input_keys_response["input_keys"] = {
|
||||
key: value
|
||||
|
|
@ -51,18 +57,133 @@ def build_input_keys_response(langchain_object, artifacts):
|
|||
# Add memory variables to memory_keys
|
||||
input_keys_response["memory_keys"] = langchain_object.memory.memory_variables
|
||||
|
||||
if hasattr(langchain_object, "prompt") and hasattr(
|
||||
langchain_object.prompt, "template"
|
||||
):
|
||||
if hasattr(langchain_object, "prompt") and hasattr(langchain_object.prompt, "template"):
|
||||
input_keys_response["template"] = langchain_object.prompt.template
|
||||
|
||||
return input_keys_response
|
||||
|
||||
|
||||
def get_new_key(dictionary, original_key):
|
||||
counter = 1
|
||||
new_key = original_key + " (" + str(counter) + ")"
|
||||
while new_key in dictionary:
|
||||
counter += 1
|
||||
new_key = original_key + " (" + str(counter) + ")"
|
||||
return new_key
|
||||
def update_frontend_node_with_template_values(frontend_node, raw_frontend_node):
|
||||
"""
|
||||
Updates the given frontend node with values from the raw template data.
|
||||
|
||||
:param frontend_node: A dict representing a built frontend node.
|
||||
:param raw_template_data: A dict representing raw template data.
|
||||
:return: Updated frontend node.
|
||||
"""
|
||||
if not is_valid_data(frontend_node, raw_frontend_node):
|
||||
return frontend_node
|
||||
|
||||
# Check if the display_name is different than "CustomComponent"
|
||||
# if so, update the display_name in the frontend_node
|
||||
if raw_frontend_node["display_name"] != "CustomComponent":
|
||||
frontend_node["display_name"] = raw_frontend_node["display_name"]
|
||||
|
||||
update_template_values(frontend_node["template"], raw_frontend_node["template"])
|
||||
|
||||
return frontend_node
|
||||
|
||||
|
||||
def raw_frontend_data_is_valid(raw_frontend_data):
|
||||
"""Check if the raw frontend data is valid for processing."""
|
||||
return "template" in raw_frontend_data and "display_name" in raw_frontend_data
|
||||
|
||||
|
||||
def is_valid_data(frontend_node, raw_frontend_data):
|
||||
"""Check if the data is valid for processing."""
|
||||
|
||||
return frontend_node and "template" in frontend_node and raw_frontend_data_is_valid(raw_frontend_data)
|
||||
|
||||
|
||||
def update_template_values(frontend_template, raw_template):
|
||||
"""Updates the frontend template with values from the raw template."""
|
||||
for key, value_dict in raw_template.items():
|
||||
if key == "code" or not isinstance(value_dict, dict):
|
||||
continue
|
||||
|
||||
update_template_field(frontend_template, key, value_dict)
|
||||
|
||||
|
||||
def update_template_field(frontend_template, key, value_dict):
|
||||
"""Updates a specific field in the frontend template."""
|
||||
template_field = frontend_template.get(key)
|
||||
if not template_field or template_field.get("type") != value_dict.get("type"):
|
||||
return
|
||||
|
||||
if "value" in value_dict and value_dict["value"]:
|
||||
template_field["value"] = value_dict["value"]
|
||||
|
||||
if "file_path" in value_dict and value_dict["file_path"]:
|
||||
file_path_value = get_file_path_value(value_dict["file_path"])
|
||||
if not file_path_value:
|
||||
# If the file does not exist, remove the value from the template_field["value"]
|
||||
template_field["value"] = ""
|
||||
template_field["file_path"] = file_path_value
|
||||
|
||||
|
||||
def get_file_path_value(file_path):
|
||||
"""Get the file path value if the file exists, else return empty string."""
|
||||
try:
|
||||
path = Path(file_path)
|
||||
except TypeError:
|
||||
return ""
|
||||
|
||||
# Check for safety
|
||||
# If the path is not in the cache dir, return empty string
|
||||
# This is to prevent access to files outside the cache dir
|
||||
# If the path is not a file, return empty string
|
||||
if not path.exists() or not str(path).startswith(user_cache_dir("langflow", "langflow")):
|
||||
return ""
|
||||
return file_path
|
||||
|
||||
|
||||
def validate_is_component(flows: List["Flow"]):
|
||||
for flow in flows:
|
||||
if not flow.data or flow.is_component is not None:
|
||||
continue
|
||||
|
||||
is_component = get_is_component_from_data(flow.data)
|
||||
if is_component is not None:
|
||||
flow.is_component = is_component
|
||||
else:
|
||||
flow.is_component = len(flow.data.get("nodes", [])) == 1
|
||||
return flows
|
||||
|
||||
|
||||
def get_is_component_from_data(data: dict):
|
||||
"""Returns True if the data is a component."""
|
||||
return data.get("is_component")
|
||||
|
||||
|
||||
async def check_langflow_version(component: StoreComponentCreate):
|
||||
from langflow import __version__ as current_version
|
||||
|
||||
if not component.last_tested_version:
|
||||
component.last_tested_version = current_version
|
||||
|
||||
langflow_version = get_lf_version_from_pypi()
|
||||
if langflow_version is None:
|
||||
raise HTTPException(status_code=500, detail="Unable to verify the latest version of Langflow")
|
||||
elif langflow_version != component.last_tested_version:
|
||||
warnings.warn(
|
||||
f"Your version of Langflow ({component.last_tested_version}) is outdated. "
|
||||
f"Please update to the latest version ({langflow_version}) and try again."
|
||||
)
|
||||
|
||||
|
||||
def format_elapsed_time(elapsed_time) -> str:
|
||||
# Format elapsed time to human readable format coming from
|
||||
# perf_counter()
|
||||
# If the elapsed time is less than 1 second, return ms
|
||||
# If the elapsed time is less than 1 minute, return seconds rounded to 2 decimals
|
||||
time_str = ""
|
||||
if elapsed_time < 1:
|
||||
elapsed_time = int(round(elapsed_time * 1000))
|
||||
time_str = f"{elapsed_time} ms"
|
||||
elif elapsed_time < 60:
|
||||
elapsed_time = round(elapsed_time, 2)
|
||||
time_str = f"{elapsed_time} seconds"
|
||||
else:
|
||||
elapsed_time = round(elapsed_time / 60, 2)
|
||||
time_str = f"{elapsed_time} minutes"
|
||||
return time_str
|
||||
|
|
|
|||
|
|
@ -1,19 +1,21 @@
|
|||
from langflow.api.v1.endpoints import router as endpoints_router
|
||||
from langflow.api.v1.validate import router as validate_router
|
||||
from langflow.api.v1.chat import router as chat_router
|
||||
from langflow.api.v1.flows import router as flows_router
|
||||
from langflow.api.v1.components import router as component_router
|
||||
from langflow.api.v1.users import router as users_router
|
||||
from langflow.api.v1.api_key import router as api_key_router
|
||||
from langflow.api.v1.chat import router as chat_router
|
||||
from langflow.api.v1.credential import router as credentials_router
|
||||
from langflow.api.v1.endpoints import router as endpoints_router
|
||||
from langflow.api.v1.flows import router as flows_router
|
||||
from langflow.api.v1.login import router as login_router
|
||||
from langflow.api.v1.store import router as store_router
|
||||
from langflow.api.v1.users import router as users_router
|
||||
from langflow.api.v1.validate import router as validate_router
|
||||
|
||||
__all__ = [
|
||||
"chat_router",
|
||||
"endpoints_router",
|
||||
"component_router",
|
||||
"store_router",
|
||||
"validate_router",
|
||||
"flows_router",
|
||||
"users_router",
|
||||
"api_key_router",
|
||||
"login_router",
|
||||
"credentials_router",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,22 +1,30 @@
|
|||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from langflow.api.v1.schemas import ApiKeysResponse
|
||||
from langflow.services.auth.utils import get_current_active_user
|
||||
from langflow.services.database.models.api_key.api_key import (
|
||||
ApiKeyCreate,
|
||||
UnmaskedApiKeyRead,
|
||||
)
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlmodel import Session
|
||||
|
||||
from langflow.api.v1.schemas import ApiKeyCreateRequest, ApiKeysResponse
|
||||
from langflow.services.auth import utils as auth_utils
|
||||
|
||||
# Assuming you have these methods in your service layer
|
||||
from langflow.services.database.models.api_key.crud import (
|
||||
get_api_keys,
|
||||
create_api_key,
|
||||
delete_api_key,
|
||||
get_api_keys,
|
||||
)
|
||||
from langflow.services.database.models.api_key.model import (
|
||||
ApiKeyCreate,
|
||||
UnmaskedApiKeyRead,
|
||||
)
|
||||
from langflow.services.database.models.user.model import User
|
||||
from langflow.services.deps import (
|
||||
get_session,
|
||||
get_settings_service,
|
||||
)
|
||||
from langflow.services.database.models.user.user import User
|
||||
from langflow.services.getters import get_session
|
||||
from sqlmodel import Session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
router = APIRouter(tags=["APIKey"], prefix="/api_key")
|
||||
|
||||
|
|
@ -24,7 +32,7 @@ router = APIRouter(tags=["APIKey"], prefix="/api_key")
|
|||
@router.get("/", response_model=ApiKeysResponse)
|
||||
def get_api_keys_route(
|
||||
db: Session = Depends(get_session),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
current_user: User = Depends(auth_utils.get_current_active_user),
|
||||
):
|
||||
try:
|
||||
user_id = current_user.id
|
||||
|
|
@ -38,7 +46,7 @@ def get_api_keys_route(
|
|||
@router.post("/", response_model=UnmaskedApiKeyRead)
|
||||
def create_api_key_route(
|
||||
req: ApiKeyCreate,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
current_user: User = Depends(auth_utils.get_current_active_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
try:
|
||||
|
|
@ -51,7 +59,7 @@ def create_api_key_route(
|
|||
@router.delete("/{api_key_id}")
|
||||
def delete_api_key_route(
|
||||
api_key_id: UUID,
|
||||
current_user=Depends(get_current_active_user),
|
||||
current_user=Depends(auth_utils.get_current_active_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
try:
|
||||
|
|
@ -59,3 +67,34 @@ def delete_api_key_route(
|
|||
return {"detail": "API Key deleted"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.post("/store")
|
||||
def save_store_api_key(
|
||||
api_key_request: ApiKeyCreateRequest,
|
||||
current_user: User = Depends(auth_utils.get_current_active_user),
|
||||
db: Session = Depends(get_session),
|
||||
settings_service=Depends(get_settings_service),
|
||||
):
|
||||
try:
|
||||
api_key = api_key_request.api_key
|
||||
# Encrypt the API key
|
||||
encrypted = auth_utils.encrypt_api_key(api_key, settings_service=settings_service)
|
||||
current_user.store_api_key = encrypted
|
||||
db.commit()
|
||||
return {"detail": "API Key saved"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.delete("/store")
|
||||
def delete_store_api_key(
|
||||
current_user: User = Depends(auth_utils.get_current_active_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
try:
|
||||
current_user.store_api_key = None
|
||||
db.commit()
|
||||
return {"detail": "API Key deleted"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Optional
|
||||
from langflow.template.frontend_node.base import FrontendNode
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
from langchain.prompts import PromptTemplate
|
||||
from pydantic import BaseModel, field_validator, model_serializer
|
||||
|
||||
from langflow.interface.utils import extract_input_variables_from_prompt
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langflow.template.frontend_node.base import FrontendNode
|
||||
|
||||
|
||||
class CacheResponse(BaseModel):
|
||||
|
|
@ -17,6 +18,12 @@ class Code(BaseModel):
|
|||
class FrontendNodeRequest(FrontendNode):
|
||||
template: dict # type: ignore
|
||||
|
||||
@model_serializer(mode="wrap")
|
||||
def serialize_model(self, handler):
|
||||
# Override the default serialization method in FrontendNode
|
||||
# because we don't need the name in the response (i.e. {name: {}})
|
||||
return handler(self)
|
||||
|
||||
|
||||
class ValidatePromptRequest(BaseModel):
|
||||
name: str
|
||||
|
|
@ -30,11 +37,13 @@ class CodeValidationResponse(BaseModel):
|
|||
imports: dict
|
||||
function: dict
|
||||
|
||||
@validator("imports")
|
||||
@field_validator("imports")
|
||||
@classmethod
|
||||
def validate_imports(cls, v):
|
||||
return v or {"errors": []}
|
||||
|
||||
@validator("function")
|
||||
@field_validator("function")
|
||||
@classmethod
|
||||
def validate_function(cls, v):
|
||||
return v or {"errors": []}
|
||||
|
||||
|
|
@ -79,9 +88,7 @@ def validate_prompt(template: str):
|
|||
# Check if there are invalid characters in the input_variables
|
||||
input_variables = check_input_variables(input_variables)
|
||||
if any(var in INVALID_NAMES for var in input_variables):
|
||||
raise ValueError(
|
||||
f"Invalid input variables. None of the variables can be named {', '.join(input_variables)}. "
|
||||
)
|
||||
raise ValueError(f"Invalid input variables. None of the variables can be named {', '.join(input_variables)}. ")
|
||||
|
||||
try:
|
||||
PromptTemplate(template=template, input_variables=input_variables)
|
||||
|
|
@ -132,9 +139,7 @@ def check_input_variables(input_variables: list):
|
|||
return input_variables
|
||||
|
||||
|
||||
def build_error_message(
|
||||
input_variables, invalid_chars, wrong_variables, fixed_variables, empty_variables
|
||||
):
|
||||
def build_error_message(input_variables, invalid_chars, wrong_variables, fixed_variables, empty_variables):
|
||||
input_variables_str = ", ".join([f"'{var}'" for var in input_variables])
|
||||
error_string = f"Invalid input variables: {input_variables_str}. "
|
||||
|
||||
|
|
|
|||
|
|
@ -1,19 +1,15 @@
|
|||
import asyncio
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
|
||||
|
||||
from langflow.api.v1.schemas import ChatResponse, PromptResponse
|
||||
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from langflow.services.getters import get_chat_service
|
||||
|
||||
|
||||
from langflow.utils.util import remove_ansi_escape_codes
|
||||
from langchain.schema import AgentAction, AgentFinish
|
||||
from loguru import logger
|
||||
|
||||
from langflow.api.v1.schemas import ChatResponse, PromptResponse
|
||||
from langflow.services.deps import get_chat_service
|
||||
from langflow.utils.util import remove_ansi_escape_codes
|
||||
|
||||
|
||||
# https://github.com/hwchase17/chat-langchain/blob/master/callback.py
|
||||
class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
|
||||
|
|
@ -26,18 +22,16 @@ class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
|
|||
|
||||
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
resp = ChatResponse(message=token, type="stream", intermediate_steps="")
|
||||
await self.websocket.send_json(resp.dict())
|
||||
await self.websocket.send_json(resp.model_dump())
|
||||
|
||||
async def on_tool_start(
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> Any:
|
||||
async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> Any:
|
||||
"""Run when tool starts running."""
|
||||
resp = ChatResponse(
|
||||
message="",
|
||||
type="stream",
|
||||
intermediate_steps=f"Tool input: {input_str}",
|
||||
)
|
||||
await self.websocket.send_json(resp.dict())
|
||||
await self.websocket.send_json(resp.model_dump())
|
||||
|
||||
async def on_tool_end(self, output: str, **kwargs: Any) -> Any:
|
||||
"""Run when tool ends running."""
|
||||
|
|
@ -68,7 +62,7 @@ class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
|
|||
try:
|
||||
# This is to emulate the stream of tokens
|
||||
for resp in resps:
|
||||
await self.websocket.send_json(resp.dict())
|
||||
await self.websocket.send_json(resp.model_dump())
|
||||
except Exception as exc:
|
||||
logger.error(f"Error sending response: {exc}")
|
||||
|
||||
|
|
@ -94,7 +88,7 @@ class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
|
|||
resp = PromptResponse(
|
||||
prompt=text,
|
||||
)
|
||||
await self.websocket.send_json(resp.dict())
|
||||
await self.websocket.send_json(resp.model_dump())
|
||||
self.chat_service.chat_history.add_message(self.client_id, resp)
|
||||
|
||||
async def on_agent_action(self, action: AgentAction, **kwargs: Any):
|
||||
|
|
@ -105,10 +99,10 @@ class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
|
|||
logs = log.split("\n")
|
||||
for log in logs:
|
||||
resp = ChatResponse(message="", type="stream", intermediate_steps=log)
|
||||
await self.websocket.send_json(resp.dict())
|
||||
await self.websocket.send_json(resp.model_dump())
|
||||
else:
|
||||
resp = ChatResponse(message="", type="stream", intermediate_steps=log)
|
||||
await self.websocket.send_json(resp.dict())
|
||||
await self.websocket.send_json(resp.model_dump())
|
||||
|
||||
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
|
||||
"""Run on agent end."""
|
||||
|
|
@ -117,7 +111,7 @@ class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
|
|||
type="stream",
|
||||
intermediate_steps=finish.log,
|
||||
)
|
||||
await self.websocket.send_json(resp.dict())
|
||||
await self.websocket.send_json(resp.model_dump())
|
||||
|
||||
|
||||
class StreamingLLMCallbackHandler(BaseCallbackHandler):
|
||||
|
|
@ -132,5 +126,5 @@ class StreamingLLMCallbackHandler(BaseCallbackHandler):
|
|||
resp = ChatResponse(message=token, type="stream", intermediate_steps="")
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
coroutine = self.websocket.send_json(resp.dict())
|
||||
coroutine = self.websocket.send_json(resp.model_dump())
|
||||
asyncio.run_coroutine_threadsafe(coroutine, loop)
|
||||
|
|
|
|||
|
|
@ -1,25 +1,22 @@
|
|||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
HTTPException,
|
||||
Query,
|
||||
WebSocket,
|
||||
WebSocketException,
|
||||
status,
|
||||
)
|
||||
import time
|
||||
|
||||
from fastapi import (APIRouter, Depends, HTTPException, Query, WebSocket,
|
||||
WebSocketException, status)
|
||||
from fastapi.responses import StreamingResponse
|
||||
from langflow.api.utils import build_input_keys_response
|
||||
from langflow.api.v1.schemas import BuildStatus, BuiltResponse, InitResponse, StreamData
|
||||
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.services.auth.utils import get_current_active_user, get_current_user
|
||||
from langflow.services.cache.utils import update_build_status
|
||||
from loguru import logger
|
||||
from langflow.services.getters import get_chat_service, get_session, get_cache_service
|
||||
from sqlmodel import Session
|
||||
from langflow.services.chat.manager import ChatService
|
||||
from langflow.services.cache.manager import BaseCacheService
|
||||
|
||||
from langflow.api.utils import build_input_keys_response, format_elapsed_time
|
||||
from langflow.api.v1.schemas import (BuildStatus, BuiltResponse, InitResponse,
|
||||
StreamData)
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.services.auth.utils import (get_current_active_user,
|
||||
get_current_user_by_jwt)
|
||||
from langflow.services.cache.service import BaseCacheService
|
||||
from langflow.services.cache.utils import update_build_status
|
||||
from langflow.services.chat.service import ChatService
|
||||
from langflow.services.deps import (get_cache_service, get_chat_service,
|
||||
get_session)
|
||||
|
||||
router = APIRouter(tags=["Chat"])
|
||||
|
||||
|
|
@ -34,16 +31,12 @@ async def chat(
|
|||
):
|
||||
"""Websocket endpoint for chat."""
|
||||
try:
|
||||
user = await get_current_user_by_jwt(token, db)
|
||||
await websocket.accept()
|
||||
user = await get_current_user(token, db)
|
||||
if not user:
|
||||
await websocket.close(
|
||||
code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized"
|
||||
)
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized")
|
||||
if not user.is_active:
|
||||
await websocket.close(
|
||||
code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized"
|
||||
)
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized")
|
||||
|
||||
if client_id in chat_service.cache_service:
|
||||
await chat_service.handle_websocket(client_id, websocket)
|
||||
|
|
@ -59,9 +52,7 @@ async def chat(
|
|||
logger.error(f"Error in chat websocket: {exc}")
|
||||
messsage = exc.detail if isinstance(exc, HTTPException) else str(exc)
|
||||
if "Could not validate credentials" in str(exc):
|
||||
await websocket.close(
|
||||
code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized"
|
||||
)
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized")
|
||||
else:
|
||||
await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=messsage)
|
||||
|
||||
|
|
@ -103,15 +94,10 @@ async def init_build(
|
|||
|
||||
|
||||
@router.get("/build/{flow_id}/status", response_model=BuiltResponse)
|
||||
async def build_status(
|
||||
flow_id: str, cache_service: "BaseCacheService" = Depends(get_cache_service)
|
||||
):
|
||||
async def build_status(flow_id: str, cache_service: "BaseCacheService" = Depends(get_cache_service)):
|
||||
"""Check the flow_id is in the cache_service."""
|
||||
try:
|
||||
built = (
|
||||
flow_id in cache_service
|
||||
and cache_service[flow_id]["status"] == BuildStatus.SUCCESS
|
||||
)
|
||||
built = flow_id in cache_service and cache_service[flow_id]["status"] == BuildStatus.SUCCESS
|
||||
|
||||
return BuiltResponse(
|
||||
built=built,
|
||||
|
|
@ -133,19 +119,20 @@ async def stream_build(
|
|||
async def event_stream(flow_id):
|
||||
final_response = {"end_of_stream": True}
|
||||
artifacts = {}
|
||||
flow_cache = cache_service[flow_id]
|
||||
flow_cache = flow_cache if isinstance(flow_cache, dict) else {}
|
||||
try:
|
||||
if flow_id not in cache_service:
|
||||
error_message = "Invalid session ID"
|
||||
yield str(StreamData(event="error", data={"error": error_message}))
|
||||
return
|
||||
|
||||
if cache_service[flow_id].get("status") == BuildStatus.IN_PROGRESS:
|
||||
if flow_cache.get("status") == BuildStatus.IN_PROGRESS:
|
||||
error_message = "Already building"
|
||||
yield str(StreamData(event="error", data={"error": error_message}))
|
||||
return
|
||||
|
||||
graph_data = cache_service[flow_id].get("graph_data")
|
||||
cache_service[flow_id]["user_id"]
|
||||
graph_data = flow_cache.get("graph_data")
|
||||
|
||||
if not graph_data:
|
||||
error_message = "No data provided"
|
||||
|
|
@ -157,25 +144,32 @@ async def stream_build(
|
|||
# Some error could happen when building the graph
|
||||
graph = Graph.from_payload(graph_data)
|
||||
|
||||
number_of_nodes = len(graph.nodes)
|
||||
number_of_nodes = len(graph.vertices)
|
||||
update_build_status(cache_service, flow_id, BuildStatus.IN_PROGRESS)
|
||||
|
||||
try:
|
||||
user_id = flow_cache["user_id"]
|
||||
except KeyError:
|
||||
logger.debug("No user_id found in cache_service")
|
||||
user_id = None
|
||||
for i, vertex in enumerate(graph.generator_build(), 1):
|
||||
try:
|
||||
log_dict = {
|
||||
"log": f"Building node {vertex.vertex_type}",
|
||||
}
|
||||
yield str(StreamData(event="log", data=log_dict))
|
||||
# time this
|
||||
start_time = time.perf_counter()
|
||||
if vertex.is_task:
|
||||
vertex = try_running_celery_task(vertex)
|
||||
vertex = await try_running_celery_task(vertex, user_id)
|
||||
else:
|
||||
vertex.build()
|
||||
await vertex.build(user_id=user_id)
|
||||
time_elapsed = format_elapsed_time(time.perf_counter() - start_time)
|
||||
params = vertex._built_object_repr()
|
||||
valid = True
|
||||
|
||||
logger.debug(f"Building node {str(vertex.vertex_type)}")
|
||||
logger.debug(
|
||||
f"Output: {params[:100]}{'...' if len(params) > 100 else ''}"
|
||||
)
|
||||
logger.debug(f"Output: {params[:100]}{'...' if len(params) > 100 else ''}")
|
||||
if vertex.artifacts:
|
||||
# The artifacts will be prompt variables
|
||||
# passed to build_input_keys_response
|
||||
|
|
@ -187,21 +181,22 @@ async def stream_build(
|
|||
valid = False
|
||||
update_build_status(cache_service, flow_id, BuildStatus.FAILURE)
|
||||
|
||||
response = {
|
||||
"valid": valid,
|
||||
"params": params,
|
||||
"id": vertex.id,
|
||||
"progress": round(i / number_of_nodes, 2),
|
||||
}
|
||||
vertex_id = vertex.parent_node_id if vertex.parent_is_top_level else vertex.id
|
||||
if vertex_id in graph.top_level_vertices:
|
||||
response = {
|
||||
"valid": valid,
|
||||
"params": params,
|
||||
"id": vertex_id,
|
||||
"progress": round(i / number_of_nodes, 2),
|
||||
"duration": time_elapsed,
|
||||
}
|
||||
|
||||
yield str(StreamData(event="message", data=response))
|
||||
|
||||
langchain_object = graph.build()
|
||||
langchain_object = await graph.build()
|
||||
# Now we need to check the input_keys to send them to the client
|
||||
if hasattr(langchain_object, "input_keys"):
|
||||
input_keys_response = build_input_keys_response(
|
||||
langchain_object, artifacts
|
||||
)
|
||||
input_keys_response = build_input_keys_response(langchain_object, artifacts)
|
||||
else:
|
||||
input_keys_response = {
|
||||
"input_keys": None,
|
||||
|
|
@ -229,7 +224,7 @@ async def stream_build(
|
|||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
def try_running_celery_task(vertex):
|
||||
async def try_running_celery_task(vertex, user_id):
|
||||
# Try running the task in celery
|
||||
# and set the task_id to the local vertex
|
||||
# if it fails, run the task locally
|
||||
|
|
@ -241,5 +236,5 @@ def try_running_celery_task(vertex):
|
|||
except Exception as exc:
|
||||
logger.debug(f"Error running task in celery: {exc}")
|
||||
vertex.task_id = None
|
||||
vertex.build()
|
||||
await vertex.build(user_id=user_id)
|
||||
return vertex
|
||||
|
|
|
|||
|
|
@ -1,77 +0,0 @@
|
|||
from datetime import timezone
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
from langflow.services.database.models.component import Component, ComponentModel
|
||||
from langflow.services.getters import get_session
|
||||
from sqlmodel import Session, select
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
COMPONENT_NOT_FOUND = "Component not found"
|
||||
COMPONENT_ALREADY_EXISTS = "A component with the same id already exists."
|
||||
COMPONENT_DELETED = "Component deleted"
|
||||
|
||||
|
||||
router = APIRouter(prefix="/components", tags=["Components"])
|
||||
|
||||
|
||||
@router.post("/", response_model=Component)
|
||||
def create_component(component: ComponentModel, db: Session = Depends(get_session)):
|
||||
db_component = Component(**component.dict())
|
||||
try:
|
||||
db.add(db_component)
|
||||
db.commit()
|
||||
db.refresh(db_component)
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=COMPONENT_ALREADY_EXISTS,
|
||||
) from e
|
||||
return db_component
|
||||
|
||||
|
||||
@router.get("/{component_id}", response_model=Component)
|
||||
def read_component(component_id: UUID, db: Session = Depends(get_session)):
|
||||
if component := db.get(Component, component_id):
|
||||
return component
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail=COMPONENT_NOT_FOUND)
|
||||
|
||||
|
||||
@router.get("/", response_model=List[Component])
|
||||
def read_components(skip: int = 0, limit: int = 50, db: Session = Depends(get_session)):
|
||||
query = select(Component)
|
||||
query = query.offset(skip).limit(limit)
|
||||
|
||||
return db.execute(query).fetchall()
|
||||
|
||||
|
||||
@router.patch("/{component_id}", response_model=Component)
|
||||
def update_component(
|
||||
component_id: UUID, component: ComponentModel, db: Session = Depends(get_session)
|
||||
):
|
||||
db_component = db.get(Component, component_id)
|
||||
if not db_component:
|
||||
raise HTTPException(status_code=404, detail=COMPONENT_NOT_FOUND)
|
||||
component_data = component.dict(exclude_unset=True)
|
||||
|
||||
for key, value in component_data.items():
|
||||
setattr(db_component, key, value)
|
||||
|
||||
db_component.update_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
db.refresh(db_component)
|
||||
return db_component
|
||||
|
||||
|
||||
@router.delete("/{component_id}")
|
||||
def delete_component(component_id: UUID, db: Session = Depends(get_session)):
|
||||
component = db.get(Component, component_id)
|
||||
if not component:
|
||||
raise HTTPException(status_code=404, detail=COMPONENT_NOT_FOUND)
|
||||
db.delete(component)
|
||||
db.commit()
|
||||
return {"detail": COMPONENT_DELETED}
|
||||
86
src/backend/langflow/api/v1/credential.py
Normal file
86
src/backend/langflow/api/v1/credential.py
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from langflow.services.auth import utils as auth_utils
|
||||
from langflow.services.auth.utils import get_current_active_user
|
||||
from langflow.services.database.models.credential import Credential, CredentialCreate, CredentialRead, CredentialUpdate
|
||||
from langflow.services.database.models.user.model import User
|
||||
from langflow.services.deps import get_session, get_settings_service
|
||||
from sqlmodel import Session, select
|
||||
|
||||
router = APIRouter(prefix="/credentials", tags=["Credentials"])
|
||||
|
||||
|
||||
@router.post("/", response_model=CredentialRead, status_code=201)
|
||||
def create_credential(
|
||||
*,
|
||||
session: Session = Depends(get_session),
|
||||
credential: CredentialCreate,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
settings_service=Depends(get_settings_service),
|
||||
):
|
||||
"""Create a new credential."""
|
||||
try:
|
||||
# check if credential name already exists
|
||||
credential_exists = session.exec(
|
||||
select(Credential).where(Credential.name == credential.name, Credential.user_id == current_user.id)
|
||||
).first()
|
||||
if credential_exists:
|
||||
raise HTTPException(status_code=400, detail="Credential name already exists")
|
||||
|
||||
db_credential = Credential.model_validate(credential, from_attributes=True)
|
||||
if not db_credential.value:
|
||||
raise HTTPException(status_code=400, detail="Credential value cannot be empty")
|
||||
encrypted = auth_utils.encrypt_api_key(db_credential.value, settings_service=settings_service)
|
||||
db_credential.value = encrypted
|
||||
db_credential.user_id = current_user.id
|
||||
session.add(db_credential)
|
||||
session.commit()
|
||||
session.refresh(db_credential)
|
||||
return db_credential
|
||||
except Exception as e:
|
||||
if isinstance(e, HTTPException):
|
||||
raise e
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.get("/", response_model=list[CredentialRead], status_code=200)
|
||||
def read_credentials(
|
||||
*,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
"""Read all credentials."""
|
||||
try:
|
||||
credentials = session.exec(select(Credential).where(Credential.user_id == current_user.id)).all()
|
||||
return credentials
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.patch("/{credential_id}", response_model=CredentialRead, status_code=200)
|
||||
def update_credential(
|
||||
*,
|
||||
session: Session = Depends(get_session),
|
||||
credential_id: UUID,
|
||||
credential: CredentialUpdate,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
"""Update a credential."""
|
||||
try:
|
||||
db_credential = session.exec(
|
||||
select(Credential).where(Credential.id == credential_id, Credential.user_id == current_user.id)
|
||||
).first()
|
||||
if not db_credential:
|
||||
raise HTTPException(status_code=404, detail="Credential not found")
|
||||
|
||||
credential_data = credential.model_dump(exclude_unset=True)
|
||||
for key, value in credential_data.items():
|
||||
setattr(db_credential, key, value)
|
||||
db_credential.updated_at = datetime.utcnow()
|
||||
session.commit()
|
||||
session.refresh(db_credential)
|
||||
return db_credential
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
|
@ -1,33 +1,28 @@
|
|||
from http import HTTPStatus
|
||||
from typing import Annotated, Optional, Union
|
||||
from langflow.services.auth.utils import api_key_security, get_current_active_user
|
||||
|
||||
|
||||
from langflow.services.cache.utils import save_uploaded_file
|
||||
from langflow.services.database.models.flow import Flow
|
||||
from langflow.processing.process import process_graph_cached, process_tweaks
|
||||
from langflow.services.database.models.user.user import User
|
||||
from langflow.services.getters import (
|
||||
get_session_service,
|
||||
get_settings_service,
|
||||
get_task_service,
|
||||
)
|
||||
from loguru import logger
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, Body, status
|
||||
import sqlalchemy as sa
|
||||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, UploadFile, status
|
||||
from loguru import logger
|
||||
from sqlmodel import select
|
||||
|
||||
from langflow.api.utils import update_frontend_node_with_template_values
|
||||
from langflow.api.v1.schemas import (
|
||||
CustomComponentCode,
|
||||
ProcessResponse,
|
||||
TaskResponse,
|
||||
TaskStatusResponse,
|
||||
UploadFileResponse,
|
||||
CustomComponentCode,
|
||||
)
|
||||
|
||||
|
||||
from langflow.services.getters import get_session
|
||||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
from langflow.interface.custom.directory_reader import DirectoryReader
|
||||
from langflow.interface.types import build_custom_component_template, create_and_validate_component
|
||||
from langflow.processing.process import process_graph_cached, process_tweaks
|
||||
from langflow.services.auth.utils import api_key_security, get_current_active_user
|
||||
from langflow.services.cache.utils import save_uploaded_file
|
||||
from langflow.services.database.models.flow import Flow
|
||||
from langflow.services.database.models.user.model import User
|
||||
from langflow.services.deps import get_session, get_session_service, get_settings_service, get_task_service
|
||||
|
||||
try:
|
||||
from langflow.worker import process_graph_cached_task
|
||||
|
|
@ -39,8 +34,7 @@ except ImportError:
|
|||
|
||||
from sqlmodel import Session
|
||||
|
||||
|
||||
from langflow.services.task.manager import TaskService
|
||||
from langflow.services.task.service import TaskService
|
||||
|
||||
# build router
|
||||
router = APIRouter(tags=["Base"])
|
||||
|
|
@ -93,18 +87,15 @@ async def process_flow(
|
|||
)
|
||||
|
||||
# Get the flow that matches the flow_id and belongs to the user
|
||||
flow = (
|
||||
session.query(Flow)
|
||||
.filter(Flow.id == flow_id)
|
||||
.filter(Flow.user_id == api_key_user.id)
|
||||
.first()
|
||||
)
|
||||
# flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first()
|
||||
flow = session.exec(select(Flow).where(Flow.id == flow_id).where(Flow.user_id == api_key_user.id)).first()
|
||||
if flow is None:
|
||||
raise ValueError(f"Flow {flow_id} not found")
|
||||
|
||||
if flow.data is None:
|
||||
raise ValueError(f"Flow {flow_id} has no data")
|
||||
graph_data = flow.data
|
||||
task_result = None
|
||||
if tweaks:
|
||||
try:
|
||||
graph_data = process_tweaks(graph_data, tweaks)
|
||||
|
|
@ -112,9 +103,7 @@ async def process_flow(
|
|||
logger.error(f"Error processing tweaks: {exc}")
|
||||
if sync:
|
||||
task_id, result = await task_service.launch_and_await_task(
|
||||
process_graph_cached_task
|
||||
if task_service.use_celery
|
||||
else process_graph_cached,
|
||||
process_graph_cached_task if task_service.use_celery else process_graph_cached,
|
||||
graph_data,
|
||||
inputs,
|
||||
clear_cache,
|
||||
|
|
@ -134,13 +123,9 @@ async def process_flow(
|
|||
)
|
||||
if session_id is None:
|
||||
# Generate a session ID
|
||||
session_id = get_session_service().generate_key(
|
||||
session_id=session_id, data_graph=graph_data
|
||||
)
|
||||
session_id = get_session_service().generate_key(session_id=session_id, data_graph=graph_data)
|
||||
task_id, task = await task_service.launch_task(
|
||||
process_graph_cached_task
|
||||
if task_service.use_celery
|
||||
else process_graph_cached,
|
||||
process_graph_cached_task if task_service.use_celery else process_graph_cached,
|
||||
graph_data,
|
||||
inputs,
|
||||
clear_cache,
|
||||
|
|
@ -163,18 +148,12 @@ async def process_flow(
|
|||
# StatementError('(builtins.ValueError) badly formed hexadecimal UUID string')
|
||||
if "badly formed hexadecimal UUID string" in str(exc):
|
||||
# This means the Flow ID is not a valid UUID which means it can't find the flow
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
) from exc
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||
except ValueError as exc:
|
||||
if f"Flow {flow_id} not found" in str(exc):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
) from exc
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)
|
||||
) from exc
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
|
||||
except Exception as e:
|
||||
# Log stack trace
|
||||
logger.exception(e)
|
||||
|
|
@ -188,6 +167,10 @@ async def get_task_status(task_id: str):
|
|||
result = None
|
||||
if task.ready():
|
||||
result = task.result
|
||||
# If result isinstance of Exception, can we get the traceback?
|
||||
if isinstance(result, Exception):
|
||||
logger.exception(task.traceback)
|
||||
|
||||
if isinstance(result, dict) and "result" in result:
|
||||
result = result["result"]
|
||||
elif hasattr(result, "result"):
|
||||
|
|
@ -195,6 +178,10 @@ async def get_task_status(task_id: str):
|
|||
|
||||
if task is None:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
if task.status == "FAILURE":
|
||||
result = str(task.result)
|
||||
logger.error(f"Task {task_id} failed: {task.traceback}")
|
||||
|
||||
return TaskStatusResponse(status=task.status, result=result)
|
||||
|
||||
|
||||
|
|
@ -228,12 +215,40 @@ def get_version():
|
|||
@router.post("/custom_component", status_code=HTTPStatus.OK)
|
||||
async def custom_component(
|
||||
raw_code: CustomComponentCode,
|
||||
user: User = Depends(get_current_active_user),
|
||||
):
|
||||
from langflow.interface.types import (
|
||||
build_langchain_template_custom_component,
|
||||
)
|
||||
component = create_and_validate_component(raw_code.code)
|
||||
|
||||
extractor = CustomComponent(code=raw_code.code)
|
||||
extractor.is_check_valid()
|
||||
built_frontend_node = build_custom_component_template(component, user_id=user.id)
|
||||
|
||||
return build_langchain_template_custom_component(extractor)
|
||||
built_frontend_node = update_frontend_node_with_template_values(built_frontend_node, raw_code.frontend_node)
|
||||
return built_frontend_node
|
||||
|
||||
|
||||
@router.post("/custom_component/reload", status_code=HTTPStatus.OK)
|
||||
async def reload_custom_component(path: str, user: User = Depends(get_current_active_user)):
|
||||
from langflow.interface.types import build_custom_component_template
|
||||
|
||||
try:
|
||||
reader = DirectoryReader("")
|
||||
valid, content = reader.process_file(path)
|
||||
if not valid:
|
||||
raise ValueError(content)
|
||||
|
||||
extractor = CustomComponent(code=content)
|
||||
extractor.validate()
|
||||
return build_custom_component_template(extractor, user_id=user.id)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/custom_component/update", status_code=HTTPStatus.OK)
|
||||
async def custom_component_update(
|
||||
raw_code: CustomComponentCode,
|
||||
user: User = Depends(get_current_active_user),
|
||||
):
|
||||
component = create_and_validate_component(raw_code.code)
|
||||
|
||||
component_node = build_custom_component_template(component, user_id=user.id, update_field=raw_code.field)
|
||||
# Update the field
|
||||
return component_node
|
||||
|
|
|
|||
|
|
@ -1,24 +1,17 @@
|
|||
from datetime import datetime
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
|
||||
from langflow.api.utils import remove_api_keys
|
||||
import orjson
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from langflow.api.utils import remove_api_keys, validate_is_component
|
||||
from langflow.api.v1.schemas import FlowListCreate, FlowListRead
|
||||
from langflow.services.auth.utils import get_current_active_user
|
||||
from langflow.services.database.models.flow import (
|
||||
Flow,
|
||||
FlowCreate,
|
||||
FlowRead,
|
||||
FlowUpdate,
|
||||
)
|
||||
from langflow.services.database.models.user.user import User
|
||||
from langflow.services.getters import get_session
|
||||
from langflow.services.getters import get_settings_service
|
||||
import orjson
|
||||
from sqlmodel import Session
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from fastapi import File, UploadFile
|
||||
from langflow.services.database.models.flow import Flow, FlowCreate, FlowRead, FlowUpdate
|
||||
from langflow.services.database.models.user.model import User
|
||||
from langflow.services.deps import get_session, get_settings_service
|
||||
from sqlmodel import Session, select
|
||||
|
||||
# build router
|
||||
router = APIRouter(prefix="/flows", tags=["Flows"])
|
||||
|
|
@ -35,7 +28,8 @@ def create_flow(
|
|||
if flow.user_id is None:
|
||||
flow.user_id = current_user.id
|
||||
|
||||
db_flow = Flow.from_orm(flow)
|
||||
db_flow = Flow.model_validate(flow, from_attributes=True)
|
||||
db_flow.updated_at = datetime.utcnow()
|
||||
|
||||
session.add(db_flow)
|
||||
session.commit()
|
||||
|
|
@ -46,12 +40,12 @@ def create_flow(
|
|||
@router.get("/", response_model=list[FlowRead], status_code=200)
|
||||
def read_flows(
|
||||
*,
|
||||
session: Session = Depends(get_session),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
"""Read all flows."""
|
||||
try:
|
||||
flows = current_user.flows
|
||||
flows = validate_is_component(flows)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
return [jsonable_encoder(flow) for flow in flows]
|
||||
|
|
@ -65,12 +59,7 @@ def read_flow(
|
|||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
"""Read a flow."""
|
||||
if user_flow := (
|
||||
session.query(Flow)
|
||||
.filter(Flow.id == flow_id)
|
||||
.filter(Flow.user_id == current_user.id)
|
||||
.first()
|
||||
):
|
||||
if user_flow := (session.exec(select(Flow).where(Flow.id == flow_id, Flow.user_id == current_user.id)).first()):
|
||||
return user_flow
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Flow not found")
|
||||
|
|
@ -90,12 +79,13 @@ def update_flow(
|
|||
db_flow = read_flow(session=session, flow_id=flow_id, current_user=current_user)
|
||||
if not db_flow:
|
||||
raise HTTPException(status_code=404, detail="Flow not found")
|
||||
flow_data = flow.dict(exclude_unset=True)
|
||||
flow_data = flow.model_dump(exclude_unset=True)
|
||||
if settings_service.settings.REMOVE_API_KEYS:
|
||||
flow_data = remove_api_keys(flow_data)
|
||||
for key, value in flow_data.items():
|
||||
if value is not None:
|
||||
setattr(db_flow, key, value)
|
||||
db_flow.updated_at = datetime.utcnow()
|
||||
session.add(db_flow)
|
||||
session.commit()
|
||||
session.refresh(db_flow)
|
||||
|
|
@ -169,5 +159,5 @@ async def download_file(
|
|||
current_user: User = Depends(get_current_active_user),
|
||||
):
|
||||
"""Download all flows as a file."""
|
||||
flows = read_flows(session=session, current_user=current_user)
|
||||
flows = read_flows(current_user=current_user)
|
||||
return FlowListRead(flows=flows)
|
||||
|
|
|
|||
|
|
@ -1,18 +1,15 @@
|
|||
from sqlmodel import Session
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlmodel import Session
|
||||
|
||||
from langflow.services.getters import get_session
|
||||
from langflow.api.v1.schemas import Token
|
||||
from langflow.services.auth.utils import (
|
||||
authenticate_user,
|
||||
create_user_tokens,
|
||||
create_refresh_token,
|
||||
create_user_longterm_token,
|
||||
get_current_active_user,
|
||||
create_user_tokens,
|
||||
)
|
||||
|
||||
from langflow.services.getters import get_settings_service
|
||||
from langflow.services.deps import get_session, get_settings_service
|
||||
|
||||
router = APIRouter(tags=["Login"])
|
||||
|
||||
|
|
@ -44,9 +41,7 @@ async def login_to_get_access_token(
|
|||
|
||||
|
||||
@router.get("/auto_login")
|
||||
async def auto_login(
|
||||
db: Session = Depends(get_session), settings_service=Depends(get_settings_service)
|
||||
):
|
||||
async def auto_login(db: Session = Depends(get_session), settings_service=Depends(get_settings_service)):
|
||||
if settings_service.auth_settings.AUTO_LOGIN:
|
||||
return create_user_longterm_token(db)
|
||||
|
||||
|
|
@ -60,9 +55,7 @@ async def auto_login(
|
|||
|
||||
|
||||
@router.post("/refresh")
|
||||
async def refresh_token(
|
||||
token: str, current_user: Session = Depends(get_current_active_user)
|
||||
):
|
||||
async def refresh_token(token: str):
|
||||
if token:
|
||||
return create_refresh_token(token)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -2,12 +2,13 @@ from enum import Enum
|
|||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from uuid import UUID
|
||||
from langflow.services.database.models.api_key.api_key import ApiKeyRead
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from langflow.services.database.models.api_key.model import ApiKeyRead
|
||||
from langflow.services.database.models.base import orjson_dumps
|
||||
from langflow.services.database.models.flow import FlowCreate, FlowRead
|
||||
from langflow.services.database.models.user import UserRead
|
||||
from langflow.services.database.models.base import orjson_dumps
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
|
||||
class BuildStatus(Enum):
|
||||
|
|
@ -91,7 +92,8 @@ class ChatResponse(ChatMessage):
|
|||
is_bot: bool = True
|
||||
files: list = []
|
||||
|
||||
@validator("type")
|
||||
@field_validator("type")
|
||||
@classmethod
|
||||
def validate_message_type(cls, v):
|
||||
if v not in ["start", "stream", "end", "error", "info", "file"]:
|
||||
raise ValueError("type must be start, stream, end, error, info, or file")
|
||||
|
|
@ -109,12 +111,13 @@ class PromptResponse(ChatMessage):
|
|||
class FileResponse(ChatMessage):
|
||||
"""File response schema."""
|
||||
|
||||
data: Any
|
||||
data: Any = None
|
||||
data_type: str
|
||||
type: str = "file"
|
||||
is_bot: bool = True
|
||||
|
||||
@validator("data_type")
|
||||
@field_validator("data_type")
|
||||
@classmethod
|
||||
def validate_data_type(cls, v):
|
||||
if v not in ["image", "csv"]:
|
||||
raise ValueError("data_type must be image or csv")
|
||||
|
|
@ -149,13 +152,13 @@ class StreamData(BaseModel):
|
|||
data: dict
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"event: {self.event}\ndata: {orjson_dumps(self.data, indent_2=False)}\n\n"
|
||||
)
|
||||
return f"event: {self.event}\ndata: {orjson_dumps(self.data, indent_2=False)}\n\n"
|
||||
|
||||
|
||||
class CustomComponentCode(BaseModel):
|
||||
code: str
|
||||
field: Optional[str] = None
|
||||
frontend_node: Optional[dict] = None
|
||||
|
||||
|
||||
class CustomComponentResponseError(BaseModel):
|
||||
|
|
@ -198,3 +201,7 @@ class Token(BaseModel):
|
|||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str
|
||||
|
||||
|
||||
class ApiKeyCreateRequest(BaseModel):
|
||||
api_key: str
|
||||
|
|
|
|||
193
src/backend/langflow/api/v1/store.py
Normal file
193
src/backend/langflow/api/v1/store.py
Normal file
|
|
@ -0,0 +1,193 @@
|
|||
from typing import Annotated, List, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
|
||||
from langflow.api.utils import check_langflow_version
|
||||
from langflow.services.auth import utils as auth_utils
|
||||
from langflow.services.database.models.user.model import User
|
||||
from langflow.services.deps import get_settings_service, get_store_service
|
||||
from langflow.services.store.exceptions import CustomException
|
||||
from langflow.services.store.schema import (
|
||||
CreateComponentResponse,
|
||||
DownloadComponentResponse,
|
||||
ListComponentResponseModel,
|
||||
StoreComponentCreate,
|
||||
TagResponse,
|
||||
UsersLikesResponse,
|
||||
)
|
||||
from langflow.services.store.service import StoreService
|
||||
|
||||
router = APIRouter(prefix="/store", tags=["Components Store"])
|
||||
|
||||
|
||||
def get_user_store_api_key(
|
||||
user: User = Depends(auth_utils.get_current_active_user),
|
||||
settings_service=Depends(get_settings_service),
|
||||
):
|
||||
if not user.store_api_key:
|
||||
raise HTTPException(status_code=400, detail="You must have a store API key set.")
|
||||
decrypted = auth_utils.decrypt_api_key(user.store_api_key, settings_service)
|
||||
return decrypted
|
||||
|
||||
|
||||
def get_optional_user_store_api_key(
|
||||
user: User = Depends(auth_utils.get_current_active_user),
|
||||
settings_service=Depends(get_settings_service),
|
||||
):
|
||||
if not user.store_api_key:
|
||||
return None
|
||||
decrypted = auth_utils.decrypt_api_key(user.store_api_key, settings_service)
|
||||
return decrypted
|
||||
|
||||
|
||||
@router.get("/check/")
|
||||
def check_if_store_is_enabled(
|
||||
settings_service=Depends(get_settings_service),
|
||||
):
|
||||
return {
|
||||
"enabled": settings_service.settings.STORE,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/check/api_key")
|
||||
async def check_if_store_has_api_key(
|
||||
api_key: Optional[str] = Depends(get_optional_user_store_api_key),
|
||||
store_service: StoreService = Depends(get_store_service),
|
||||
):
|
||||
if api_key is None:
|
||||
return {"has_api_key": False, "is_valid": False}
|
||||
|
||||
try:
|
||||
is_valid = await store_service.check_api_key(api_key)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
return {"has_api_key": api_key is not None, "is_valid": is_valid}
|
||||
|
||||
|
||||
@router.post("/components/", response_model=CreateComponentResponse, status_code=201)
|
||||
async def share_component(
|
||||
component: StoreComponentCreate,
|
||||
store_service: StoreService = Depends(get_store_service),
|
||||
store_api_key: str = Depends(get_user_store_api_key),
|
||||
):
|
||||
try:
|
||||
await check_langflow_version(component)
|
||||
result = await store_service.upload(store_api_key, component)
|
||||
return result
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
|
||||
|
||||
@router.patch("/components/{component_id}", response_model=CreateComponentResponse, status_code=201)
|
||||
async def update_shared_component(
|
||||
component_id: UUID,
|
||||
component: StoreComponentCreate,
|
||||
store_service: StoreService = Depends(get_store_service),
|
||||
store_api_key: str = Depends(get_user_store_api_key),
|
||||
):
|
||||
try:
|
||||
await check_langflow_version(component)
|
||||
result = await store_service.update(store_api_key, component_id, component)
|
||||
return result
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
|
||||
|
||||
@router.get("/components/", response_model=ListComponentResponseModel)
|
||||
async def get_components(
|
||||
component_id: Annotated[Optional[str], Query()] = None,
|
||||
search: Annotated[Optional[str], Query()] = None,
|
||||
private: Annotated[Optional[bool], Query()] = None,
|
||||
is_component: Annotated[Optional[bool], Query()] = None,
|
||||
tags: Annotated[Optional[list[str]], Query()] = None,
|
||||
sort: Annotated[Union[list[str], None], Query()] = None,
|
||||
liked: Annotated[bool, Query()] = False,
|
||||
filter_by_user: Annotated[bool, Query()] = False,
|
||||
fields: Annotated[Optional[list[str]], Query()] = None,
|
||||
page: int = 1,
|
||||
limit: int = 10,
|
||||
store_service: StoreService = Depends(get_store_service),
|
||||
store_api_key: Optional[str] = Depends(get_optional_user_store_api_key),
|
||||
):
|
||||
try:
|
||||
return await store_service.get_list_component_response_model(
|
||||
component_id=component_id,
|
||||
search=search,
|
||||
private=private,
|
||||
is_component=is_component,
|
||||
fields=fields,
|
||||
tags=tags,
|
||||
sort=sort,
|
||||
liked=liked,
|
||||
filter_by_user=filter_by_user,
|
||||
page=page,
|
||||
limit=limit,
|
||||
store_api_key=store_api_key,
|
||||
)
|
||||
except CustomException as exc:
|
||||
raise HTTPException(status_code=exc.status_code, detail=str(exc)) from exc
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@router.get("/components/{component_id}", response_model=DownloadComponentResponse)
|
||||
async def download_component(
|
||||
component_id: UUID,
|
||||
store_service: StoreService = Depends(get_store_service),
|
||||
store_api_key: str = Depends(get_user_store_api_key),
|
||||
):
|
||||
try:
|
||||
component = await store_service.download(store_api_key, component_id)
|
||||
except CustomException as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
||||
if component is None:
|
||||
raise HTTPException(status_code=400, detail="Component not found")
|
||||
|
||||
return component
|
||||
|
||||
|
||||
@router.get("/tags", response_model=List[TagResponse])
|
||||
async def get_tags(
|
||||
store_service: StoreService = Depends(get_store_service),
|
||||
):
|
||||
try:
|
||||
return await store_service.get_tags()
|
||||
except CustomException as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.get("/users/likes", response_model=List[UsersLikesResponse])
|
||||
async def get_list_of_components_liked_by_user(
|
||||
store_service: StoreService = Depends(get_store_service),
|
||||
store_api_key: str = Depends(get_user_store_api_key),
|
||||
):
|
||||
try:
|
||||
return await store_service.get_user_likes(store_api_key)
|
||||
except CustomException as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/users/likes/{component_id}", response_model=UsersLikesResponse)
|
||||
async def like_component(
|
||||
component_id: UUID,
|
||||
store_service: StoreService = Depends(get_store_service),
|
||||
store_api_key: str = Depends(get_user_store_api_key),
|
||||
):
|
||||
try:
|
||||
result = await store_service.like_component(store_api_key, str(component_id))
|
||||
likes_count = await store_service.get_component_likes_count(str(component_id), store_api_key)
|
||||
|
||||
return UsersLikesResponse(likes_count=likes_count, liked_by_user=result)
|
||||
except CustomException as exc:
|
||||
raise HTTPException(status_code=exc.status_code, detail=str(exc)) from exc
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
|
@ -1,29 +1,20 @@
|
|||
from uuid import UUID
|
||||
from langflow.api.v1.schemas import UsersResponse
|
||||
from langflow.services.database.models.user import (
|
||||
User,
|
||||
UserCreate,
|
||||
UserRead,
|
||||
UserUpdate,
|
||||
)
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from sqlmodel import Session, select
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from langflow.services.getters import get_session, get_settings_service
|
||||
from langflow.api.v1.schemas import UsersResponse
|
||||
from langflow.services.auth.utils import (
|
||||
get_current_active_superuser,
|
||||
get_current_active_user,
|
||||
get_password_hash,
|
||||
verify_password,
|
||||
)
|
||||
from langflow.services.database.models.user.crud import (
|
||||
get_user_by_id,
|
||||
update_user,
|
||||
)
|
||||
from langflow.services.database.models.user import User, UserCreate, UserRead, UserUpdate
|
||||
from langflow.services.database.models.user.crud import get_user_by_id, update_user
|
||||
from langflow.services.deps import get_session, get_settings_service
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlmodel import Session, select
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
|
||||
router = APIRouter(tags=["Users"], prefix="/users")
|
||||
|
||||
|
|
@ -46,9 +37,7 @@ def add_user(
|
|||
session.refresh(new_user)
|
||||
except IntegrityError as e:
|
||||
session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=400, detail="This username is unavailable."
|
||||
) from e
|
||||
raise HTTPException(status_code=400, detail="This username is unavailable.") from e
|
||||
|
||||
return new_user
|
||||
|
||||
|
|
@ -73,15 +62,15 @@ def read_all_users(
|
|||
"""
|
||||
Retrieve a list of users from the database with pagination.
|
||||
"""
|
||||
query = select(User).offset(skip).limit(limit)
|
||||
users = session.execute(query).fetchall()
|
||||
query: SelectOfScalar = select(User).offset(skip).limit(limit)
|
||||
users = session.exec(query).fetchall()
|
||||
|
||||
count_query = select(func.count()).select_from(User) # type: ignore
|
||||
total_count = session.execute(count_query).scalar()
|
||||
total_count = session.exec(count_query).first()
|
||||
|
||||
return UsersResponse(
|
||||
total_count=total_count, # type: ignore
|
||||
users=[UserRead(**dict(user.User)) for user in users],
|
||||
users=[UserRead(**user.model_dump()) for user in users],
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -96,14 +85,10 @@ def patch_user(
|
|||
Update an existing user's data.
|
||||
"""
|
||||
if not user.is_superuser and user.id != user_id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have the permission to update this user"
|
||||
)
|
||||
raise HTTPException(status_code=403, detail="You don't have the permission to update this user")
|
||||
if user_update.password:
|
||||
if not user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="You can't change your password here"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="You can't change your password here")
|
||||
user_update.password = get_password_hash(user_update.password)
|
||||
|
||||
if user_db := get_user_by_id(session, user_id):
|
||||
|
|
@ -123,16 +108,12 @@ def reset_password(
|
|||
Reset a user's password.
|
||||
"""
|
||||
if user_id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="You can't change another user's password"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="You can't change another user's password")
|
||||
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
if verify_password(user_update.password, user.password):
|
||||
raise HTTPException(
|
||||
status_code=400, detail="You can't use your current password"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="You can't use your current password")
|
||||
new_password = get_password_hash(user_update.password)
|
||||
user.password = new_password
|
||||
session.commit()
|
||||
|
|
@ -151,15 +132,11 @@ def delete_user(
|
|||
Delete a user from the database.
|
||||
"""
|
||||
if current_user.id == user_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="You can't delete your own user account"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="You can't delete your own user account")
|
||||
elif not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have the permission to delete this user"
|
||||
)
|
||||
raise HTTPException(status_code=403, detail="You don't have the permission to delete this user")
|
||||
|
||||
user_db = session.query(User).filter(User.id == user_id).first()
|
||||
user_db = session.exec(select(User).where(User.id == user_id)).first()
|
||||
if not user_db:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,15 +1,14 @@
|
|||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from langflow.api.v1.base import (
|
||||
Code,
|
||||
CodeValidationResponse,
|
||||
ValidatePromptRequest,
|
||||
PromptValidationResponse,
|
||||
ValidatePromptRequest,
|
||||
validate_prompt,
|
||||
)
|
||||
from langflow.template.field.base import TemplateField
|
||||
from loguru import logger
|
||||
from langflow.utils.validate import validate_code
|
||||
from loguru import logger
|
||||
|
||||
# build router
|
||||
router = APIRouter(prefix="/validate", tags=["Validate"])
|
||||
|
|
@ -41,9 +40,7 @@ def post_validate_prompt(prompt_request: ValidatePromptRequest):
|
|||
|
||||
add_new_variables_to_template(input_variables, prompt_request)
|
||||
|
||||
remove_old_variables_from_template(
|
||||
old_custom_fields, input_variables, prompt_request
|
||||
)
|
||||
remove_old_variables_from_template(old_custom_fields, input_variables, prompt_request)
|
||||
|
||||
update_input_variables_field(input_variables, prompt_request)
|
||||
|
||||
|
|
@ -58,19 +55,16 @@ def post_validate_prompt(prompt_request: ValidatePromptRequest):
|
|||
|
||||
def get_old_custom_fields(prompt_request):
|
||||
try:
|
||||
if (
|
||||
len(prompt_request.frontend_node.custom_fields) == 1
|
||||
and prompt_request.name == ""
|
||||
):
|
||||
if len(prompt_request.frontend_node.custom_fields) == 1 and prompt_request.name == "":
|
||||
# If there is only one custom field and the name is empty string
|
||||
# then we are dealing with the first prompt request after the node was created
|
||||
prompt_request.name = list(
|
||||
prompt_request.frontend_node.custom_fields.keys()
|
||||
)[0]
|
||||
prompt_request.name = list(prompt_request.frontend_node.custom_fields.keys())[0]
|
||||
|
||||
old_custom_fields = prompt_request.frontend_node.custom_fields[
|
||||
prompt_request.name
|
||||
].copy()
|
||||
old_custom_fields = prompt_request.frontend_node.custom_fields[prompt_request.name]
|
||||
if old_custom_fields is None:
|
||||
old_custom_fields = []
|
||||
|
||||
old_custom_fields = old_custom_fields.copy()
|
||||
except KeyError:
|
||||
old_custom_fields = []
|
||||
prompt_request.frontend_node.custom_fields[prompt_request.name] = []
|
||||
|
|
@ -92,40 +86,26 @@ def add_new_variables_to_template(input_variables, prompt_request):
|
|||
)
|
||||
if variable in prompt_request.frontend_node.template:
|
||||
# Set the new field with the old value
|
||||
template_field.value = prompt_request.frontend_node.template[variable][
|
||||
"value"
|
||||
]
|
||||
template_field.value = prompt_request.frontend_node.template[variable]["value"]
|
||||
|
||||
prompt_request.frontend_node.template[variable] = template_field.to_dict()
|
||||
|
||||
# Check if variable is not already in the list before appending
|
||||
if (
|
||||
variable
|
||||
not in prompt_request.frontend_node.custom_fields[prompt_request.name]
|
||||
):
|
||||
prompt_request.frontend_node.custom_fields[prompt_request.name].append(
|
||||
variable
|
||||
)
|
||||
if variable not in prompt_request.frontend_node.custom_fields[prompt_request.name]:
|
||||
prompt_request.frontend_node.custom_fields[prompt_request.name].append(variable)
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
||||
|
||||
def remove_old_variables_from_template(
|
||||
old_custom_fields, input_variables, prompt_request
|
||||
):
|
||||
def remove_old_variables_from_template(old_custom_fields, input_variables, prompt_request):
|
||||
for variable in old_custom_fields:
|
||||
if variable not in input_variables:
|
||||
try:
|
||||
# Remove the variable from custom_fields associated with the given name
|
||||
if (
|
||||
variable
|
||||
in prompt_request.frontend_node.custom_fields[prompt_request.name]
|
||||
):
|
||||
prompt_request.frontend_node.custom_fields[
|
||||
prompt_request.name
|
||||
].remove(variable)
|
||||
if variable in prompt_request.frontend_node.custom_fields[prompt_request.name]:
|
||||
prompt_request.frontend_node.custom_fields[prompt_request.name].remove(variable)
|
||||
|
||||
# Remove the variable from the template
|
||||
prompt_request.frontend_node.template.pop(variable, None)
|
||||
|
|
@ -137,6 +117,4 @@ def remove_old_variables_from_template(
|
|||
|
||||
def update_input_variables_field(input_variables, prompt_request):
|
||||
if "input_variables" in prompt_request.frontend_node.template:
|
||||
prompt_request.frontend_node.template["input_variables"][
|
||||
"value"
|
||||
] = input_variables
|
||||
prompt_request.frontend_node.template["input_variables"]["value"] = input_variables
|
||||
|
|
|
|||
37
src/backend/langflow/components/agents/AgentInitializer.py
Normal file
37
src/backend/langflow/components/agents/AgentInitializer.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
from typing import Callable, List, Union
|
||||
|
||||
from langchain.agents import AgentExecutor, AgentType, initialize_agent, types
|
||||
|
||||
from langflow import CustomComponent
|
||||
from langflow.field_typing import BaseChatMemory, BaseLanguageModel, Tool
|
||||
|
||||
|
||||
class AgentInitializerComponent(CustomComponent):
|
||||
display_name: str = "Agent Initializer"
|
||||
description: str = "Initialize a Langchain Agent."
|
||||
documentation: str = "https://python.langchain.com/docs/modules/agents/agent_types/"
|
||||
|
||||
def build_config(self):
|
||||
agents = list(types.AGENT_TO_CLASS.keys())
|
||||
# field_type and required are optional
|
||||
return {
|
||||
"agent": {"options": agents, "value": agents[0], "display_name": "Agent Type"},
|
||||
"max_iterations": {"display_name": "Max Iterations", "value": 10},
|
||||
"memory": {"display_name": "Memory"},
|
||||
"tools": {"display_name": "Tools"},
|
||||
"llm": {"display_name": "Language Model"},
|
||||
}
|
||||
|
||||
def build(
|
||||
self, agent: str, llm: BaseLanguageModel, memory: BaseChatMemory, tools: List[Tool], max_iterations: int
|
||||
) -> Union[AgentExecutor, Callable]:
|
||||
agent = AgentType(agent)
|
||||
return initialize_agent(
|
||||
tools=tools,
|
||||
llm=llm,
|
||||
agent=agent,
|
||||
memory=memory,
|
||||
return_intermediate_steps=True,
|
||||
handle_parsing_errors=True,
|
||||
max_iterations=max_iterations,
|
||||
)
|
||||
|
|
@ -1,17 +1,15 @@
|
|||
from langflow import CustomComponent
|
||||
from typing import Optional
|
||||
from langchain.prompts import SystemMessagePromptTemplate
|
||||
from langchain.tools import Tool
|
||||
from langchain.schema.memory import BaseMemory
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.agents.agent_toolkits.conversational_retrieval.openai_functions import _get_default_system_message
|
||||
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.memory.token_buffer import ConversationTokenBufferMemory
|
||||
from langchain.prompts import SystemMessagePromptTemplate
|
||||
from langchain.prompts.chat import MessagesPlaceholder
|
||||
from langchain.agents.agent_toolkits.conversational_retrieval.openai_functions import (
|
||||
_get_default_system_message,
|
||||
)
|
||||
from langchain.schema.memory import BaseMemory
|
||||
from langchain.tools import Tool
|
||||
from langflow import CustomComponent
|
||||
|
||||
|
||||
class ConversationalAgent(CustomComponent):
|
||||
|
|
@ -20,13 +18,14 @@ class ConversationalAgent(CustomComponent):
|
|||
|
||||
def build_config(self):
|
||||
openai_function_models = [
|
||||
"gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k-0613",
|
||||
"gpt-4-1106-preview",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-16k",
|
||||
"gpt-4",
|
||||
"gpt-4-32k",
|
||||
]
|
||||
return {
|
||||
"tools": {"is_list": True, "display_name": "Tools"},
|
||||
"tools": {"display_name": "Tools"},
|
||||
"memory": {"display_name": "Memory"},
|
||||
"system_message": {"display_name": "System Message"},
|
||||
"max_token_limit": {"display_name": "Max Token Limit"},
|
||||
|
|
@ -42,7 +41,7 @@ class ConversationalAgent(CustomComponent):
|
|||
self,
|
||||
model_name: str,
|
||||
openai_api_key: str,
|
||||
tools: Tool,
|
||||
tools: List[Tool],
|
||||
openai_api_base: Optional[str] = None,
|
||||
memory: Optional[BaseMemory] = None,
|
||||
system_message: Optional[SystemMessagePromptTemplate] = None,
|
||||
|
|
@ -50,8 +49,8 @@ class ConversationalAgent(CustomComponent):
|
|||
) -> AgentExecutor:
|
||||
llm = ChatOpenAI(
|
||||
model=model_name,
|
||||
openai_api_key=openai_api_key,
|
||||
openai_api_base=openai_api_base,
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
if not memory:
|
||||
memory_key = "chat_history"
|
||||
|
|
@ -71,7 +70,9 @@ class ConversationalAgent(CustomComponent):
|
|||
extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key)],
|
||||
)
|
||||
agent = OpenAIFunctionsAgent(
|
||||
llm=llm, tools=tools, prompt=prompt # type: ignore
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
prompt=prompt, # type: ignore
|
||||
)
|
||||
return AgentExecutor(
|
||||
agent=agent,
|
||||
|
|
|
|||
29
src/backend/langflow/components/chains/ConversationChain.py
Normal file
29
src/backend/langflow/components/chains/ConversationChain.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
from langflow import CustomComponent
|
||||
from langchain.chains import ConversationChain
|
||||
from typing import Optional, Union, Callable
|
||||
from langflow.field_typing import BaseLanguageModel, BaseMemory, Chain
|
||||
|
||||
|
||||
class ConversationChainComponent(CustomComponent):
|
||||
display_name = "ConversationChain"
|
||||
description = "Chain to have a conversation and load context from memory."
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"prompt": {"display_name": "Prompt"},
|
||||
"llm": {"display_name": "LLM"},
|
||||
"memory": {
|
||||
"display_name": "Memory",
|
||||
"info": "Memory to load context from. If none is provided, a ConversationBufferMemory will be used.",
|
||||
},
|
||||
"code": {"show": False},
|
||||
}
|
||||
|
||||
def build(
|
||||
self,
|
||||
llm: BaseLanguageModel,
|
||||
memory: Optional[BaseMemory] = None,
|
||||
) -> Union[Chain, Callable]:
|
||||
if memory is None:
|
||||
return ConversationChain(llm=llm)
|
||||
return ConversationChain(llm=llm, memory=memory)
|
||||
32
src/backend/langflow/components/chains/LLMChain.py
Normal file
32
src/backend/langflow/components/chains/LLMChain.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
from typing import Callable, Optional, Union
|
||||
|
||||
from langchain.chains import LLMChain
|
||||
|
||||
from langflow import CustomComponent
|
||||
from langflow.field_typing import (
|
||||
BaseLanguageModel,
|
||||
BaseMemory,
|
||||
BasePromptTemplate,
|
||||
Chain,
|
||||
)
|
||||
|
||||
|
||||
class LLMChainComponent(CustomComponent):
|
||||
display_name = "LLMChain"
|
||||
description = "Chain to run queries against LLMs"
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"prompt": {"display_name": "Prompt"},
|
||||
"llm": {"display_name": "LLM"},
|
||||
"memory": {"display_name": "Memory"},
|
||||
"code": {"show": False},
|
||||
}
|
||||
|
||||
def build(
|
||||
self,
|
||||
prompt: BasePromptTemplate,
|
||||
llm: BaseLanguageModel,
|
||||
memory: Optional[BaseMemory] = None,
|
||||
) -> Union[Chain, Callable]:
|
||||
return LLMChain(prompt=prompt, llm=llm, memory=memory)
|
||||
|
|
@ -8,7 +8,7 @@ from langchain.schema import Document
|
|||
class PromptRunner(CustomComponent):
|
||||
display_name: str = "Prompt Runner"
|
||||
description: str = "Run a Chain with the given PromptTemplate"
|
||||
beta = True
|
||||
beta: bool = True
|
||||
field_config = {
|
||||
"llm": {"display_name": "LLM"},
|
||||
"prompt": {
|
||||
|
|
@ -18,9 +18,7 @@ class PromptRunner(CustomComponent):
|
|||
"code": {"show": False},
|
||||
}
|
||||
|
||||
def build(
|
||||
self, llm: BaseLLM, prompt: PromptTemplate, inputs: dict = {}
|
||||
) -> Document:
|
||||
def build(self, llm: BaseLLM, prompt: PromptTemplate, inputs: dict = {}) -> Document:
|
||||
chain = prompt | llm
|
||||
# The input is an empty dict because the prompt is already filled
|
||||
result = chain.invoke(input=inputs)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,12 @@
|
|||
from langflow import CustomComponent
|
||||
from langflow.field_typing import Data
|
||||
|
||||
|
||||
class Component(CustomComponent):
|
||||
documentation: str = "http://docs.langflow.org/components/custom"
|
||||
|
||||
def build_config(self):
|
||||
return {"param": {"display_name": "Parameter"}}
|
||||
|
||||
def build(self, param: Data) -> Data:
|
||||
return param
|
||||
114
src/backend/langflow/components/documentloaders/FileLoader.py
Normal file
114
src/backend/langflow/components/documentloaders/FileLoader.py
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
from langchain.schema import Document
|
||||
|
||||
from langflow import CustomComponent
|
||||
from langflow.utils.constants import LOADERS_INFO
|
||||
|
||||
|
||||
class FileLoaderComponent(CustomComponent):
|
||||
display_name: str = "File Loader"
|
||||
description: str = "Generic File Loader"
|
||||
beta = True
|
||||
|
||||
def build_config(self):
|
||||
loader_options = ["Automatic"] + [loader_info["name"] for loader_info in LOADERS_INFO]
|
||||
|
||||
file_types = []
|
||||
suffixes = []
|
||||
|
||||
for loader_info in LOADERS_INFO:
|
||||
if "allowedTypes" in loader_info:
|
||||
file_types.extend(loader_info["allowedTypes"])
|
||||
suffixes.extend([f".{ext}" for ext in loader_info["allowedTypes"]])
|
||||
|
||||
return {
|
||||
"file_path": {
|
||||
"display_name": "File Path",
|
||||
"required": True,
|
||||
"field_type": "file",
|
||||
"file_types": [
|
||||
"json",
|
||||
"txt",
|
||||
"csv",
|
||||
"jsonl",
|
||||
"html",
|
||||
"htm",
|
||||
"conllu",
|
||||
"enex",
|
||||
"msg",
|
||||
"pdf",
|
||||
"srt",
|
||||
"eml",
|
||||
"md",
|
||||
"pptx",
|
||||
"docx",
|
||||
],
|
||||
"suffixes": [
|
||||
".json",
|
||||
".txt",
|
||||
".csv",
|
||||
".jsonl",
|
||||
".html",
|
||||
".htm",
|
||||
".conllu",
|
||||
".enex",
|
||||
".msg",
|
||||
".pdf",
|
||||
".srt",
|
||||
".eml",
|
||||
".md",
|
||||
".pptx",
|
||||
".docx",
|
||||
],
|
||||
# "file_types" : file_types,
|
||||
# "suffixes": suffixes,
|
||||
},
|
||||
"loader": {
|
||||
"display_name": "Loader",
|
||||
"is_list": True,
|
||||
"required": True,
|
||||
"options": loader_options,
|
||||
"value": "Automatic",
|
||||
},
|
||||
"code": {"show": False},
|
||||
}
|
||||
|
||||
def build(self, file_path: str, loader: str) -> Document:
|
||||
file_type = file_path.split(".")[-1]
|
||||
|
||||
# Mapeie o nome do loader selecionado para suas informações
|
||||
selected_loader_info = None
|
||||
for loader_info in LOADERS_INFO:
|
||||
if loader_info["name"] == loader:
|
||||
selected_loader_info = loader_info
|
||||
break
|
||||
|
||||
if selected_loader_info is None and loader != "Automatic":
|
||||
raise ValueError(f"Loader {loader} not found in the loader info list")
|
||||
|
||||
if loader == "Automatic":
|
||||
# Determine o loader automaticamente com base na extensão do arquivo
|
||||
default_loader_info = None
|
||||
for info in LOADERS_INFO:
|
||||
if "defaultFor" in info and file_type in info["defaultFor"]:
|
||||
default_loader_info = info
|
||||
break
|
||||
|
||||
if default_loader_info is None:
|
||||
raise ValueError(f"No default loader found for file type: {file_type}")
|
||||
|
||||
selected_loader_info = default_loader_info
|
||||
if isinstance(selected_loader_info, dict):
|
||||
loader_import: str = selected_loader_info["import"]
|
||||
else:
|
||||
raise ValueError(f"Loader info for {loader} is not a dict\nLoader info:\n{selected_loader_info}")
|
||||
module_name, class_name = loader_import.rsplit(".", 1)
|
||||
|
||||
try:
|
||||
# Importe o loader dinamicamente
|
||||
loader_module = __import__(module_name, fromlist=[class_name])
|
||||
loader_instance = getattr(loader_module, class_name)
|
||||
except ImportError as e:
|
||||
raise ValueError(f"Loader {loader} could not be imported\nLoader info:\n{selected_loader_info}") from e
|
||||
|
||||
result = loader_instance(file_path=file_path)
|
||||
return result.load()
|
||||
46
src/backend/langflow/components/documentloaders/UrlLoader.py
Normal file
46
src/backend/langflow/components/documentloaders/UrlLoader.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
from typing import List
|
||||
|
||||
from langchain import document_loaders
|
||||
from langchain.schema import Document
|
||||
from langflow import CustomComponent
|
||||
|
||||
|
||||
class UrlLoaderComponent(CustomComponent):
|
||||
display_name: str = "Url Loader"
|
||||
description: str = "Generic Url Loader Component"
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"web_path": {
|
||||
"display_name": "Url",
|
||||
"required": True,
|
||||
},
|
||||
"loader": {
|
||||
"display_name": "Loader",
|
||||
"is_list": True,
|
||||
"required": True,
|
||||
"options": [
|
||||
"AZLyricsLoader",
|
||||
"CollegeConfidentialLoader",
|
||||
"GitbookLoader",
|
||||
"HNLoader",
|
||||
"IFixitLoader",
|
||||
"IMSDbLoader",
|
||||
"WebBaseLoader",
|
||||
],
|
||||
"value": "WebBaseLoader",
|
||||
},
|
||||
"code": {"show": False},
|
||||
}
|
||||
|
||||
def build(self, web_path: str, loader: str) -> List[Document]:
|
||||
try:
|
||||
loader_instance = getattr(document_loaders, loader)(web_path=web_path)
|
||||
except Exception as e:
|
||||
raise ValueError(f"No loader found for: {web_path}") from e
|
||||
docs = loader_instance.load()
|
||||
avg_length = sum(len(doc.page_content) for doc in docs if hasattr(doc, "page_content")) / len(docs)
|
||||
self.status = f"""{len(docs)} documents)
|
||||
\nAvg. Document Length (characters): {int(avg_length)}
|
||||
Documents: {docs[:3]}..."""
|
||||
return docs
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
from typing import Optional
|
||||
|
||||
from langchain.embeddings import BedrockEmbeddings
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langflow import CustomComponent
|
||||
|
||||
|
||||
class AmazonBedrockEmeddingsComponent(CustomComponent):
|
||||
"""
|
||||
A custom component for implementing an Embeddings Model using Amazon Bedrock.
|
||||
"""
|
||||
|
||||
display_name: str = "Amazon Bedrock Embeddings"
|
||||
description: str = "Embeddings model from Amazon Bedrock."
|
||||
documentation = "https://python.langchain.com/docs/modules/data_connection/text_embedding/integrations/bedrock"
|
||||
beta = True
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"model_id": {
|
||||
"display_name": "Model Id",
|
||||
"options": ["amazon.titan-embed-text-v1"],
|
||||
},
|
||||
"credentials_profile_name": {"display_name": "Credentials Profile Name"},
|
||||
"endpoint_url": {"display_name": "Bedrock Endpoint URL"},
|
||||
"region_name": {"display_name": "AWS Region"},
|
||||
"code": {"show": False},
|
||||
}
|
||||
|
||||
def build(
|
||||
self,
|
||||
model_id: str = "amazon.titan-embed-text-v1",
|
||||
credentials_profile_name: Optional[str] = None,
|
||||
endpoint_url: Optional[str] = None,
|
||||
region_name: Optional[str] = None,
|
||||
) -> Embeddings:
|
||||
try:
|
||||
output = BedrockEmbeddings(
|
||||
credentials_profile_name=credentials_profile_name,
|
||||
model_id=model_id,
|
||||
endpoint_url=endpoint_url,
|
||||
region_name=region_name,
|
||||
) # type: ignore
|
||||
except Exception as e:
|
||||
raise ValueError("Could not connect to AmazonBedrock API.") from e
|
||||
return output
|
||||
0
src/backend/langflow/components/embeddings/__init__.py
Normal file
0
src/backend/langflow/components/embeddings/__init__.py
Normal file
|
|
@ -1,7 +1,10 @@
|
|||
from typing import Optional
|
||||
from langflow import CustomComponent
|
||||
|
||||
from langchain.chat_models.anthropic import ChatAnthropic
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.llms.base import BaseLanguageModel
|
||||
from pydantic.v1 import SecretStr
|
||||
|
||||
from langflow import CustomComponent
|
||||
|
||||
|
||||
class AnthropicLLM(CustomComponent):
|
||||
|
|
@ -53,16 +56,16 @@ class AnthropicLLM(CustomComponent):
|
|||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
api_endpoint: Optional[str] = None,
|
||||
) -> BaseLLM:
|
||||
) -> BaseLanguageModel:
|
||||
# Set default API endpoint if not provided
|
||||
if not api_endpoint:
|
||||
api_endpoint = "https://api.anthropic.com"
|
||||
|
||||
try:
|
||||
output = ChatAnthropic(
|
||||
model=model,
|
||||
anthropic_api_key=anthropic_api_key,
|
||||
max_tokens_to_sample=max_tokens,
|
||||
model_name=model,
|
||||
anthropic_api_key=SecretStr(anthropic_api_key) if anthropic_api_key else None,
|
||||
max_tokens_to_sample=max_tokens, # type: ignore
|
||||
temperature=temperature,
|
||||
anthropic_api_url=api_endpoint,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,95 @@
|
|||
from typing import Optional
|
||||
|
||||
from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
|
||||
from langchain.llms.base import BaseLLM
|
||||
from pydantic.v1 import SecretStr
|
||||
|
||||
from langflow import CustomComponent
|
||||
|
||||
|
||||
class QianfanChatEndpointComponent(CustomComponent):
|
||||
display_name: str = "QianfanChatEndpoint"
|
||||
description: str = (
|
||||
"Baidu Qianfan chat models. Get more detail from "
|
||||
"https://python.langchain.com/docs/integrations/chat/baidu_qianfan_endpoint."
|
||||
)
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"model": {
|
||||
"display_name": "Model Name",
|
||||
"options": [
|
||||
"ERNIE-Bot",
|
||||
"ERNIE-Bot-turbo",
|
||||
"BLOOMZ-7B",
|
||||
"Llama-2-7b-chat",
|
||||
"Llama-2-13b-chat",
|
||||
"Llama-2-70b-chat",
|
||||
"Qianfan-BLOOMZ-7B-compressed",
|
||||
"Qianfan-Chinese-Llama-2-7B",
|
||||
"ChatGLM2-6B-32K",
|
||||
"AquilaChat-7B",
|
||||
],
|
||||
"info": "https://python.langchain.com/docs/integrations/chat/baidu_qianfan_endpoint",
|
||||
"required": True,
|
||||
},
|
||||
"qianfan_ak": {
|
||||
"display_name": "Qianfan Ak",
|
||||
"required": True,
|
||||
"password": True,
|
||||
"info": "which you could get from https://cloud.baidu.com/product/wenxinworkshop",
|
||||
},
|
||||
"qianfan_sk": {
|
||||
"display_name": "Qianfan Sk",
|
||||
"required": True,
|
||||
"password": True,
|
||||
"info": "which you could get from https://cloud.baidu.com/product/wenxinworkshop",
|
||||
},
|
||||
"top_p": {
|
||||
"display_name": "Top p",
|
||||
"field_type": "float",
|
||||
"info": "Model params, only supported in ERNIE-Bot and ERNIE-Bot-turbo",
|
||||
"value": 0.8,
|
||||
},
|
||||
"temperature": {
|
||||
"display_name": "Temperature",
|
||||
"field_type": "float",
|
||||
"info": "Model params, only supported in ERNIE-Bot and ERNIE-Bot-turbo",
|
||||
"value": 0.95,
|
||||
},
|
||||
"penalty_score": {
|
||||
"display_name": "Penalty Score",
|
||||
"field_type": "float",
|
||||
"info": "Model params, only supported in ERNIE-Bot and ERNIE-Bot-turbo",
|
||||
"value": 1.0,
|
||||
},
|
||||
"endpoint": {
|
||||
"display_name": "Endpoint",
|
||||
"info": "Endpoint of the Qianfan LLM, required if custom model used.",
|
||||
},
|
||||
"code": {"show": False},
|
||||
}
|
||||
|
||||
def build(
|
||||
self,
|
||||
model: str = "ERNIE-Bot-turbo",
|
||||
qianfan_ak: Optional[str] = None,
|
||||
qianfan_sk: Optional[str] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None,
|
||||
penalty_score: Optional[float] = None,
|
||||
endpoint: Optional[str] = None,
|
||||
) -> BaseLLM:
|
||||
try:
|
||||
output = QianfanChatEndpoint( # type: ignore
|
||||
model=model,
|
||||
qianfan_ak=SecretStr(qianfan_ak) if qianfan_ak else None,
|
||||
qianfan_sk=SecretStr(qianfan_sk) if qianfan_sk else None,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
penalty_score=penalty_score,
|
||||
endpoint=endpoint,
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError("Could not connect to Baidu Qianfan API.") from e
|
||||
return output # type: ignore
|
||||
|
|
@ -18,9 +18,7 @@ class MetalRetrieverComponent(CustomComponent):
|
|||
"code": {"show": False},
|
||||
}
|
||||
|
||||
def build(
|
||||
self, api_key: str, client_id: str, index_id: str, params: Optional[dict] = None
|
||||
) -> BaseRetriever:
|
||||
def build(self, api_key: str, client_id: str, index_id: str, params: Optional[dict] = None) -> BaseRetriever:
|
||||
try:
|
||||
metal = Metal(api_key=api_key, client_id=client_id, index_id=index_id)
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -1,19 +1,18 @@
|
|||
from typing import List, Union
|
||||
from langflow import CustomComponent
|
||||
|
||||
from metaphor_python import Metaphor # type: ignore
|
||||
from langchain.tools import Tool
|
||||
from langchain.agents import tool
|
||||
from langchain.agents.agent_toolkits.base import BaseToolkit
|
||||
from langchain.tools import Tool
|
||||
from metaphor_python import Metaphor # type: ignore
|
||||
|
||||
from langflow import CustomComponent
|
||||
|
||||
|
||||
class MetaphorToolkit(CustomComponent):
|
||||
display_name: str = "Metaphor"
|
||||
description: str = "Metaphor Toolkit"
|
||||
documentation = (
|
||||
"https://python.langchain.com/docs/integrations/tools/metaphor_search"
|
||||
)
|
||||
beta = True
|
||||
documentation = "https://python.langchain.com/docs/integrations/tools/metaphor_search"
|
||||
beta: bool = True
|
||||
# api key should be password = True
|
||||
field_config = {
|
||||
"metaphor_api_key": {"display_name": "Metaphor API Key", "password": True},
|
||||
|
|
@ -33,9 +32,7 @@ class MetaphorToolkit(CustomComponent):
|
|||
@tool
|
||||
def search(query: str):
|
||||
"""Call search engine with a query."""
|
||||
return client.search(
|
||||
query, use_autoprompt=use_autoprompt, num_results=search_num_results
|
||||
)
|
||||
return client.search(query, use_autoprompt=use_autoprompt, num_results=search_num_results)
|
||||
|
||||
@tool
|
||||
def get_contents(ids: List[str]):
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ class GetRequest(CustomComponent):
|
|||
description: str = "Make a GET request to the given URL."
|
||||
output_types: list[str] = ["Document"]
|
||||
documentation: str = "https://docs.langflow.org/components/utilities#get-request"
|
||||
beta = True
|
||||
beta: bool = True
|
||||
field_config = {
|
||||
"url": {
|
||||
"display_name": "URL",
|
||||
|
|
@ -30,9 +30,7 @@ class GetRequest(CustomComponent):
|
|||
},
|
||||
}
|
||||
|
||||
def get_document(
|
||||
self, session: requests.Session, url: str, headers: Optional[dict], timeout: int
|
||||
) -> Document:
|
||||
def get_document(self, session: requests.Session, url: str, headers: Optional[dict], timeout: int) -> Document:
|
||||
try:
|
||||
response = session.get(url, headers=headers, timeout=int(timeout))
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -11,8 +11,8 @@
|
|||
|
||||
# - **Document:** The Document containing the JSON object.
|
||||
|
||||
from langflow import CustomComponent
|
||||
from langchain.schema import Document
|
||||
from langflow import CustomComponent
|
||||
from langflow.services.database.models.base import orjson_dumps
|
||||
|
||||
|
||||
|
|
@ -21,9 +21,7 @@ class JSONDocumentBuilder(CustomComponent):
|
|||
description: str = "Build a Document containing a JSON object using a key and another Document page content."
|
||||
output_types: list[str] = ["Document"]
|
||||
beta = True
|
||||
documentation: str = (
|
||||
"https://docs.langflow.org/components/utilities#json-document-builder"
|
||||
)
|
||||
documentation: str = "https://docs.langflow.org/components/utilities#json-document-builder"
|
||||
|
||||
field_config = {
|
||||
"key": {"display_name": "Key"},
|
||||
|
|
@ -38,18 +36,11 @@ class JSONDocumentBuilder(CustomComponent):
|
|||
documents = None
|
||||
if isinstance(document, list):
|
||||
documents = [
|
||||
Document(
|
||||
page_content=orjson_dumps({key: doc.page_content}, indent_2=False)
|
||||
)
|
||||
for doc in document
|
||||
Document(page_content=orjson_dumps({key: doc.page_content}, indent_2=False)) for doc in document
|
||||
]
|
||||
elif isinstance(document, Document):
|
||||
documents = Document(
|
||||
page_content=orjson_dumps({key: document.page_content}, indent_2=False)
|
||||
)
|
||||
documents = Document(page_content=orjson_dumps({key: document.page_content}, indent_2=False))
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Expected Document or list of Documents, got {type(document)}"
|
||||
)
|
||||
raise TypeError(f"Expected Document or list of Documents, got {type(document)}")
|
||||
self.repr_value = documents
|
||||
return documents
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ class PostRequest(CustomComponent):
|
|||
description: str = "Make a POST request to the given URL."
|
||||
output_types: list[str] = ["Document"]
|
||||
documentation: str = "https://docs.langflow.org/components/utilities#post-request"
|
||||
beta = True
|
||||
beta: bool = True
|
||||
field_config = {
|
||||
"url": {"display_name": "URL", "info": "The URL to make the request to."},
|
||||
"headers": {
|
||||
|
|
@ -65,16 +65,12 @@ class PostRequest(CustomComponent):
|
|||
|
||||
if not isinstance(document, list) and isinstance(document, Document):
|
||||
documents: list[Document] = [document]
|
||||
elif isinstance(document, list) and all(
|
||||
isinstance(doc, Document) for doc in document
|
||||
):
|
||||
elif isinstance(document, list) and all(isinstance(doc, Document) for doc in document):
|
||||
documents = document
|
||||
else:
|
||||
raise ValueError("document must be a Document or a list of Documents")
|
||||
|
||||
with requests.Session() as session:
|
||||
documents = [
|
||||
self.post_document(session, doc, url, headers) for doc in documents
|
||||
]
|
||||
documents = [self.post_document(session, doc, url, headers) for doc in documents]
|
||||
self.repr_value = documents
|
||||
return documents
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ class UpdateRequest(CustomComponent):
|
|||
description: str = "Make a PATCH request to the given URL."
|
||||
output_types: list[str] = ["Document"]
|
||||
documentation: str = "https://docs.langflow.org/components/utilities#update-request"
|
||||
beta = True
|
||||
beta: bool = True
|
||||
field_config = {
|
||||
"url": {"display_name": "URL", "info": "The URL to make the request to."},
|
||||
"headers": {
|
||||
|
|
@ -39,9 +39,7 @@ class UpdateRequest(CustomComponent):
|
|||
) -> Document:
|
||||
try:
|
||||
if method == "PATCH":
|
||||
response = session.patch(
|
||||
url, headers=headers, data=document.page_content
|
||||
)
|
||||
response = session.patch(url, headers=headers, data=document.page_content)
|
||||
elif method == "PUT":
|
||||
response = session.put(url, headers=headers, data=document.page_content)
|
||||
else:
|
||||
|
|
@ -78,17 +76,12 @@ class UpdateRequest(CustomComponent):
|
|||
|
||||
if not isinstance(document, list) and isinstance(document, Document):
|
||||
documents: list[Document] = [document]
|
||||
elif isinstance(document, list) and all(
|
||||
isinstance(doc, Document) for doc in document
|
||||
):
|
||||
elif isinstance(document, list) and all(isinstance(doc, Document) for doc in document):
|
||||
documents = document
|
||||
else:
|
||||
raise ValueError("document must be a Document or a list of Documents")
|
||||
|
||||
with requests.Session() as session:
|
||||
documents = [
|
||||
self.update_document(session, doc, url, headers, method)
|
||||
for doc in documents
|
||||
]
|
||||
documents = [self.update_document(session, doc, url, headers, method) for doc in documents]
|
||||
self.repr_value = documents
|
||||
return documents
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
from typing import Optional, Union
|
||||
from langflow import CustomComponent
|
||||
|
||||
from langchain.vectorstores import Chroma
|
||||
from langchain.schema import Document
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.schema import BaseRetriever
|
||||
from langchain.embeddings.base import Embeddings
|
||||
import chromadb # type: ignore
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.vectorstores import Chroma
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
from langflow import CustomComponent
|
||||
|
||||
|
||||
class ChromaComponent(CustomComponent):
|
||||
|
|
@ -17,7 +17,7 @@ class ChromaComponent(CustomComponent):
|
|||
display_name: str = "Chroma (Custom Component)"
|
||||
description: str = "Implementation of Vector Store using Chroma"
|
||||
documentation = "https://python.langchain.com/docs/integrations/vectorstores/chroma"
|
||||
beta = True
|
||||
beta: bool = True
|
||||
|
||||
def build_config(self):
|
||||
"""
|
||||
|
|
@ -53,9 +53,9 @@ class ChromaComponent(CustomComponent):
|
|||
self,
|
||||
collection_name: str,
|
||||
persist: bool,
|
||||
embedding: Embeddings,
|
||||
chroma_server_ssl_enabled: bool,
|
||||
persist_directory: Optional[str] = None,
|
||||
embedding: Optional[Embeddings] = None,
|
||||
documents: Optional[Document] = None,
|
||||
chroma_server_cors_allow_origins: Optional[str] = None,
|
||||
chroma_server_host: Optional[str] = None,
|
||||
|
|
@ -86,8 +86,7 @@ class ChromaComponent(CustomComponent):
|
|||
|
||||
if chroma_server_host is not None:
|
||||
chroma_settings = chromadb.config.Settings(
|
||||
chroma_server_cors_allow_origins=chroma_server_cors_allow_origins
|
||||
or None,
|
||||
chroma_server_cors_allow_origins=chroma_server_cors_allow_origins or None,
|
||||
chroma_server_host=chroma_server_host,
|
||||
chroma_server_port=chroma_server_port or None,
|
||||
chroma_server_grpc_port=chroma_server_grpc_port or None,
|
||||
|
|
@ -96,6 +95,8 @@ class ChromaComponent(CustomComponent):
|
|||
|
||||
# If documents, then we need to create a Chroma instance using .from_documents
|
||||
if documents is not None and embedding is not None:
|
||||
if len(documents) == 0:
|
||||
raise ValueError("If documents are provided, there must be at least one document.")
|
||||
return Chroma.from_documents(
|
||||
documents=documents, # type: ignore
|
||||
persist_directory=persist_directory if persist else None,
|
||||
|
|
@ -104,6 +105,4 @@ class ChromaComponent(CustomComponent):
|
|||
client_settings=chroma_settings,
|
||||
)
|
||||
|
||||
return Chroma(
|
||||
persist_directory=persist_directory, client_settings=chroma_settings
|
||||
)
|
||||
return Chroma(persist_directory=persist_directory, client_settings=chroma_settings)
|
||||
|
|
|
|||
64
src/backend/langflow/components/vectorstores/Redis.py
Normal file
64
src/backend/langflow/components/vectorstores/Redis.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
from typing import Optional
|
||||
from langflow import CustomComponent
|
||||
|
||||
from langchain.vectorstores.redis import Redis
|
||||
from langchain.schema import Document
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
|
||||
class RedisComponent(CustomComponent):
|
||||
"""
|
||||
A custom component for implementing a Vector Store using Redis.
|
||||
"""
|
||||
|
||||
display_name: str = "Redis"
|
||||
description: str = "Implementation of Vector Store using Redis"
|
||||
documentation = "https://python.langchain.com/docs/integrations/vectorstores/redis"
|
||||
beta = True
|
||||
|
||||
def build_config(self):
|
||||
"""
|
||||
Builds the configuration for the component.
|
||||
|
||||
Returns:
|
||||
- dict: A dictionary containing the configuration options for the component.
|
||||
"""
|
||||
return {
|
||||
"index_name": {"display_name": "Index Name", "value": "your_index"},
|
||||
"code": {"show": False, "display_name": "Code"},
|
||||
"documents": {"display_name": "Documents", "is_list": True},
|
||||
"embedding": {"display_name": "Embedding"},
|
||||
"redis_server_url": {
|
||||
"display_name": "Redis Server Connection String",
|
||||
"advanced": False,
|
||||
},
|
||||
"redis_index_name": {"display_name": "Redis Index", "advanced": False},
|
||||
}
|
||||
|
||||
def build(
|
||||
self,
|
||||
embedding: Embeddings,
|
||||
redis_server_url: str,
|
||||
redis_index_name: str,
|
||||
documents: Optional[Document] = None,
|
||||
) -> VectorStore:
|
||||
"""
|
||||
Builds the Vector Store or BaseRetriever object.
|
||||
|
||||
Args:
|
||||
- embedding (Embeddings): The embeddings to use for the Vector Store.
|
||||
- documents (Optional[Document]): The documents to use for the Vector Store.
|
||||
- redis_index_name (str): The name of the Redis index.
|
||||
- redis_server_url (str): The URL for the Redis server.
|
||||
|
||||
Returns:
|
||||
- VectorStore: The Vector Store object.
|
||||
"""
|
||||
|
||||
return Redis.from_documents(
|
||||
documents=documents, # type: ignore
|
||||
embedding=embedding,
|
||||
redis_url=redis_server_url,
|
||||
index_name=redis_index_name,
|
||||
)
|
||||
|
|
@ -1,19 +1,17 @@
|
|||
from typing import Optional, Union
|
||||
from langflow import CustomComponent
|
||||
|
||||
from langchain.vectorstores import Vectara
|
||||
from langchain.schema import Document
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.schema import BaseRetriever
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.vectorstores import Vectara
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
from langflow import CustomComponent
|
||||
|
||||
|
||||
class VectaraComponent(CustomComponent):
|
||||
display_name: str = "Vectara"
|
||||
description: str = "Implementation of Vector Store using Vectara"
|
||||
documentation = (
|
||||
"https://python.langchain.com/docs/integrations/vectorstores/vectara"
|
||||
)
|
||||
documentation = "https://python.langchain.com/docs/integrations/vectorstores/vectara"
|
||||
beta = True
|
||||
# api key should be password = True
|
||||
field_config = {
|
||||
|
|
|
|||
74
src/backend/langflow/components/vectorstores/pgvector.py
Normal file
74
src/backend/langflow/components/vectorstores/pgvector.py
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
from typing import Optional, List
|
||||
from langflow import CustomComponent
|
||||
|
||||
from langchain.vectorstores.pgvector import PGVector
|
||||
from langchain.schema import Document
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
|
||||
class PostgresqlVectorComponent(CustomComponent):
|
||||
"""
|
||||
A custom component for implementing a Vector Store using PostgreSQL.
|
||||
"""
|
||||
|
||||
display_name: str = "PGVector"
|
||||
description: str = "Implementation of Vector Store using PostgreSQL"
|
||||
documentation = "https://python.langchain.com/docs/integrations/vectorstores/pgvector"
|
||||
beta = True
|
||||
|
||||
def build_config(self):
|
||||
"""
|
||||
Builds the configuration for the component.
|
||||
|
||||
Returns:
|
||||
- dict: A dictionary containing the configuration options for the component.
|
||||
"""
|
||||
return {
|
||||
"index_name": {"display_name": "Index Name", "value": "your_index"},
|
||||
"code": {"show": True, "display_name": "Code"},
|
||||
"documents": {"display_name": "Documents", "is_list": True},
|
||||
"embedding": {"display_name": "Embedding"},
|
||||
"pg_server_url": {
|
||||
"display_name": "PostgreSQL Server Connection String",
|
||||
"advanced": False,
|
||||
},
|
||||
"collection_name": {"display_name": "Table", "advanced": False},
|
||||
}
|
||||
|
||||
def build(
|
||||
self,
|
||||
embedding: Embeddings,
|
||||
pg_server_url: str,
|
||||
collection_name: str,
|
||||
documents: Optional[List[Document]] = None,
|
||||
) -> VectorStore:
|
||||
"""
|
||||
Builds the Vector Store or BaseRetriever object.
|
||||
|
||||
Args:
|
||||
- embedding (Embeddings): The embeddings to use for the Vector Store.
|
||||
- documents (Optional[Document]): The documents to use for the Vector Store.
|
||||
- collection_name (str): The name of the PG table.
|
||||
- pg_server_url (str): The URL for the PG server.
|
||||
|
||||
Returns:
|
||||
- VectorStore: The Vector Store object.
|
||||
"""
|
||||
|
||||
try:
|
||||
if documents is None:
|
||||
return PGVector.from_existing_index(
|
||||
embedding=embedding,
|
||||
collection_name=collection_name,
|
||||
connection_string=pg_server_url,
|
||||
)
|
||||
|
||||
return PGVector.from_documents(
|
||||
embedding=embedding,
|
||||
documents=documents,
|
||||
collection_name=collection_name,
|
||||
connection_string=pg_server_url,
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to build PGVector: {e}")
|
||||
|
|
@ -5,8 +5,6 @@ agents:
|
|||
documentation: "https://python.langchain.com/docs/modules/agents/toolkits/openapi"
|
||||
CSVAgent:
|
||||
documentation: "https://python.langchain.com/docs/modules/agents/toolkits/csv"
|
||||
AgentInitializer:
|
||||
documentation: "https://python.langchain.com/docs/modules/agents/agent_types/"
|
||||
VectorStoreAgent:
|
||||
documentation: ""
|
||||
VectorStoreRouterAgent:
|
||||
|
|
@ -14,14 +12,14 @@ agents:
|
|||
SQLAgent:
|
||||
documentation: ""
|
||||
chains:
|
||||
LLMChain:
|
||||
documentation: "https://python.langchain.com/docs/modules/chains/foundational/llm_chain"
|
||||
# LLMChain:
|
||||
# documentation: "https://python.langchain.com/docs/modules/chains/foundational/llm_chain"
|
||||
LLMMathChain:
|
||||
documentation: "https://python.langchain.com/docs/modules/chains/additional/llm_math"
|
||||
LLMCheckerChain:
|
||||
documentation: "https://python.langchain.com/docs/modules/chains/additional/llm_checker"
|
||||
ConversationChain:
|
||||
documentation: ""
|
||||
# ConversationChain:
|
||||
# documentation: ""
|
||||
SeriesCharacterChain:
|
||||
documentation: ""
|
||||
MidJourneyPromptChain:
|
||||
|
|
@ -106,6 +104,9 @@ embeddings:
|
|||
documentation: "https://python.langchain.com/docs/modules/data_connection/text_embedding/integrations/cohere"
|
||||
VertexAIEmbeddings:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/text_embedding/integrations/google_vertex_ai_palm"
|
||||
AmazonBedrockEmbeddings:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/text_embedding/integrations/bedrock"
|
||||
|
||||
llms:
|
||||
OpenAI:
|
||||
documentation: "https://python.langchain.com/docs/modules/model_io/models/llms/integrations/openai"
|
||||
|
|
@ -294,4 +295,4 @@ output_parsers:
|
|||
documentation: "https://python.langchain.com/docs/modules/model_io/output_parsers/structured"
|
||||
custom_components:
|
||||
CustomComponent:
|
||||
documentation: ""
|
||||
documentation: "https://docs.langflow.org/guidelines/custom-component"
|
||||
|
|
|
|||
|
|
@ -3,12 +3,16 @@ import os
|
|||
|
||||
langflow_redis_host = os.environ.get("LANGFLOW_REDIS_HOST")
|
||||
langflow_redis_port = os.environ.get("LANGFLOW_REDIS_PORT")
|
||||
if "BROKER_URL" in os.environ and "RESULT_BACKEND" in os.environ:
|
||||
# RabbitMQ
|
||||
broker_url = os.environ.get("BROKER_URL", "amqp://localhost")
|
||||
result_backend = os.environ.get("RESULT_BACKEND", "redis://localhost:6379/0")
|
||||
elif langflow_redis_host and langflow_redis_port:
|
||||
# broker default user
|
||||
|
||||
if langflow_redis_host and langflow_redis_port:
|
||||
broker_url = f"redis://{langflow_redis_host}:{langflow_redis_port}/0"
|
||||
result_backend = f"redis://{langflow_redis_host}:{langflow_redis_port}/0"
|
||||
else:
|
||||
# RabbitMQ
|
||||
mq_user = os.environ.get("RABBITMQ_DEFAULT_USER", "langflow")
|
||||
mq_password = os.environ.get("RABBITMQ_DEFAULT_PASS", "langflow")
|
||||
broker_url = os.environ.get("BROKER_URL", f"amqp://{mq_user}:{mq_password}@localhost:5672//")
|
||||
result_backend = os.environ.get("RESULT_BACKEND", "redis://localhost:6379/0")
|
||||
# tasks should be json or pickle
|
||||
accept_content = ["json", "pickle"]
|
||||
|
|
|
|||
|
|
@ -1,3 +1,56 @@
|
|||
from .base import NestedDict
|
||||
from typing import Any
|
||||
|
||||
__all__ = ["NestedDict"]
|
||||
from .constants import (AgentExecutor, BaseChatMemory, BaseLanguageModel,
|
||||
BaseLLM, BaseLoader, BaseMemory, BaseOutputParser,
|
||||
BasePromptTemplate, BaseRetriever, Callable, Chain,
|
||||
ChatPromptTemplate, Data, Document, Embeddings,
|
||||
NestedDict, Object, Prompt, PromptTemplate,
|
||||
TextSplitter, Tool, VectorStore)
|
||||
from .range_spec import RangeSpec
|
||||
|
||||
|
||||
def _import_template_field():
|
||||
from langflow.template.field.base import TemplateField
|
||||
|
||||
return TemplateField
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
# This is to avoid circular imports
|
||||
if name == "TemplateField":
|
||||
return _import_template_field()
|
||||
elif name == "RangeSpec":
|
||||
return RangeSpec
|
||||
# The other names should work as if they were imported from constants
|
||||
# Import the constants module langflow.field_typing.constants
|
||||
from . import constants
|
||||
|
||||
return getattr(constants, name)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"NestedDict",
|
||||
"Data",
|
||||
"Tool",
|
||||
"PromptTemplate",
|
||||
"Chain",
|
||||
"BaseChatMemory",
|
||||
"BaseLLM",
|
||||
"BaseLanguageModel",
|
||||
"BaseLoader",
|
||||
"BaseMemory",
|
||||
"BaseOutputParser",
|
||||
"BaseRetriever",
|
||||
"VectorStore",
|
||||
"Embeddings",
|
||||
"TextSplitter",
|
||||
"Document",
|
||||
"AgentExecutor",
|
||||
"Object",
|
||||
"Callable",
|
||||
"BasePromptTemplate",
|
||||
"ChatPromptTemplate",
|
||||
"Prompt",
|
||||
"RangeSpec",
|
||||
"TemplateField",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,23 +1,44 @@
|
|||
from langchain.prompts import PromptTemplate
|
||||
from typing import Callable, Dict, Union
|
||||
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.prompts import BasePromptTemplate, ChatPromptTemplate, PromptTemplate
|
||||
from langchain.schema import BaseOutputParser, BaseRetriever, Document
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.memory import BaseMemory
|
||||
from langchain.text_splitter import TextSplitter
|
||||
from langchain.tools import Tool
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.schema import BaseOutputParser
|
||||
from langchain.schema.memory import BaseMemory
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
|
||||
# Type alias for more complex dicts
|
||||
NestedDict = Dict[str, Union[str, Dict]]
|
||||
|
||||
|
||||
class Object:
|
||||
pass
|
||||
|
||||
|
||||
class Data:
|
||||
pass
|
||||
|
||||
|
||||
class Prompt:
|
||||
pass
|
||||
|
||||
|
||||
LANGCHAIN_BASE_TYPES = {
|
||||
"Chain": Chain,
|
||||
"AgentExecutor": AgentExecutor,
|
||||
"Tool": Tool,
|
||||
"BaseLLM": BaseLLM,
|
||||
"BaseLanguageModel": BaseLanguageModel,
|
||||
"PromptTemplate": PromptTemplate,
|
||||
"ChatPromptTemplate": ChatPromptTemplate,
|
||||
"BasePromptTemplate": BasePromptTemplate,
|
||||
"BaseLoader": BaseLoader,
|
||||
"Document": Document,
|
||||
"TextSplitter": TextSplitter,
|
||||
|
|
@ -28,38 +49,12 @@ LANGCHAIN_BASE_TYPES = {
|
|||
"BaseMemory": BaseMemory,
|
||||
"BaseChatMemory": BaseChatMemory,
|
||||
}
|
||||
|
||||
# Langchain base types plus Python base types
|
||||
CUSTOM_COMPONENT_SUPPORTED_TYPES = {
|
||||
**LANGCHAIN_BASE_TYPES,
|
||||
"str": str,
|
||||
"int": int,
|
||||
"float": float,
|
||||
"bool": bool,
|
||||
"list": list,
|
||||
"dict": dict,
|
||||
"NestedDict": NestedDict,
|
||||
"Data": Data,
|
||||
"Object": Object,
|
||||
"Callable": Callable,
|
||||
"Prompt": Prompt,
|
||||
}
|
||||
|
||||
|
||||
DEFAULT_CUSTOM_COMPONENT_CODE = """from langflow import CustomComponent
|
||||
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.schema import Document
|
||||
|
||||
import requests
|
||||
|
||||
class YourComponent(CustomComponent):
|
||||
display_name: str = "Custom Component"
|
||||
description: str = "Create any custom component you want!"
|
||||
|
||||
def build_config(self):
|
||||
return { "url": { "multiline": True, "required": True } }
|
||||
|
||||
def build(self, url: str, llm: BaseLLM, prompt: PromptTemplate) -> Document:
|
||||
response = requests.get(url)
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
result = chain.run(response.text[:300])
|
||||
return Document(page_content=str(result))
|
||||
"""
|
||||
21
src/backend/langflow/field_typing/range_spec.py
Normal file
21
src/backend/langflow/field_typing/range_spec.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
||||
class RangeSpec(BaseModel):
|
||||
min: float = -1.0
|
||||
max: float = 1.0
|
||||
step: float = 0.1
|
||||
|
||||
@field_validator("max")
|
||||
@classmethod
|
||||
def max_must_be_greater_than_min(cls, v, values, **kwargs):
|
||||
if "min" in values.data and v <= values.data["min"]:
|
||||
raise ValueError("max must be greater than min")
|
||||
return v
|
||||
|
||||
@field_validator("step")
|
||||
@classmethod
|
||||
def step_must_be_positive(cls, v):
|
||||
if v <= 0:
|
||||
raise ValueError("step must be positive")
|
||||
return v
|
||||
|
|
@ -1,46 +1,77 @@
|
|||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from typing import TYPE_CHECKING
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
|
||||
|
||||
class SourceHandle(BaseModel):
|
||||
baseClasses: List[str] = Field(..., description="List of base classes for the source handle.")
|
||||
dataType: str = Field(..., description="Data type for the source handle.")
|
||||
id: str = Field(..., description="Unique identifier for the source handle.")
|
||||
|
||||
|
||||
class TargetHandle(BaseModel):
|
||||
fieldName: str = Field(..., description="Field name for the target handle.")
|
||||
id: str = Field(..., description="Unique identifier for the target handle.")
|
||||
inputTypes: Optional[List[str]] = Field(None, description="List of input types for the target handle.")
|
||||
type: str = Field(..., description="Type of the target handle.")
|
||||
|
||||
|
||||
class Edge:
|
||||
def __init__(self, source: "Vertex", target: "Vertex", edge: dict):
|
||||
self.source: "Vertex" = source
|
||||
self.target: "Vertex" = target
|
||||
self.source_handle = edge.get("sourceHandle", "")
|
||||
self.target_handle = edge.get("targetHandle", "")
|
||||
# 'BaseLoader;BaseOutputParser|documents|PromptTemplate-zmTlD'
|
||||
# target_param is documents
|
||||
self.target_param = self.target_handle.split("|")[1]
|
||||
self.source_id: str = source.id if source else ""
|
||||
self.target_id: str = target.id if target else ""
|
||||
if data := edge.get("data", {}):
|
||||
self._source_handle = data.get("sourceHandle", {})
|
||||
self._target_handle = data.get("targetHandle", {})
|
||||
self.source_handle: SourceHandle = SourceHandle(**self._source_handle)
|
||||
self.target_handle: TargetHandle = TargetHandle(**self._target_handle)
|
||||
self.target_param = self.target_handle.fieldName
|
||||
# validate handles
|
||||
self.validate_handles(source, target)
|
||||
else:
|
||||
# Logging here because this is a breaking change
|
||||
logger.error("Edge data is empty")
|
||||
self._source_handle = edge.get("sourceHandle", "")
|
||||
self._target_handle = edge.get("targetHandle", "")
|
||||
# 'BaseLoader;BaseOutputParser|documents|PromptTemplate-zmTlD'
|
||||
# target_param is documents
|
||||
self.target_param = self._target_handle.split("|")[1]
|
||||
# Validate in __init__ to fail fast
|
||||
self.validate_edge(source, target)
|
||||
|
||||
self.validate_edge()
|
||||
def validate_handles(self, source, target) -> None:
|
||||
if self.target_handle.inputTypes is None:
|
||||
self.valid_handles = self.target_handle.type in self.source_handle.baseClasses
|
||||
else:
|
||||
self.valid_handles = (
|
||||
any(baseClass in self.target_handle.inputTypes for baseClass in self.source_handle.baseClasses)
|
||||
or self.target_handle.type in self.source_handle.baseClasses
|
||||
)
|
||||
if not self.valid_handles:
|
||||
logger.debug(self.source_handle)
|
||||
logger.debug(self.target_handle)
|
||||
raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has invalid handles")
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.source = state["source"]
|
||||
self.target = state["target"]
|
||||
self.source_id = state["source_id"]
|
||||
self.target_id = state["target_id"]
|
||||
self.target_param = state["target_param"]
|
||||
self.source_handle = state["source_handle"]
|
||||
self.target_handle = state["target_handle"]
|
||||
|
||||
def reset(self) -> None:
|
||||
self.source._build_params()
|
||||
self.target._build_params()
|
||||
|
||||
def validate_edge(self) -> None:
|
||||
def validate_edge(self, source, target) -> None:
|
||||
# Validate that the outputs of the source node are valid inputs
|
||||
# for the target node
|
||||
self.source_types = self.source.output
|
||||
self.target_reqs = self.target.required_inputs + self.target.optional_inputs
|
||||
self.source_types = source.output
|
||||
self.target_reqs = target.required_inputs + target.optional_inputs
|
||||
# Both lists contain strings and sometimes a string contains the value we are
|
||||
# looking for e.g. comgin_out=["Chain"] and target_reqs=["LLMChain"]
|
||||
# so we need to check if any of the strings in source_types is in target_reqs
|
||||
self.valid = any(
|
||||
output in target_req
|
||||
for output in self.source_types
|
||||
for target_req in self.target_reqs
|
||||
)
|
||||
self.valid = any(output in target_req for output in self.source_types for target_req in self.target_reqs)
|
||||
# Get what type of input the target node is expecting
|
||||
|
||||
self.matched_type = next(
|
||||
|
|
@ -51,14 +82,11 @@ class Edge:
|
|||
if no_matched_type:
|
||||
logger.debug(self.source_types)
|
||||
logger.debug(self.target_reqs)
|
||||
raise ValueError(
|
||||
f"Edge between {self.source.vertex_type} and {self.target.vertex_type} "
|
||||
f"has no matched type"
|
||||
)
|
||||
raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has no matched type")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"Edge(source={self.source.id}, target={self.target.id}, target_param={self.target_param}"
|
||||
f"Edge(source={self.source_id}, target={self.target_id}, target_param={self.target_param}"
|
||||
f", matched_type={self.matched_type})"
|
||||
)
|
||||
|
||||
|
|
@ -66,8 +94,4 @@ class Edge:
|
|||
return hash(self.__repr__())
|
||||
|
||||
def __eq__(self, __value: object) -> bool:
|
||||
return (
|
||||
self.__repr__() == __value.__repr__()
|
||||
if isinstance(__value, Edge)
|
||||
else False
|
||||
)
|
||||
return self.__repr__() == __value.__repr__() if isinstance(__value, Edge) else False
|
||||
|
|
|
|||
|
|
@ -1,36 +1,45 @@
|
|||
from typing import Dict, Generator, List, Type, Union
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from loguru import logger
|
||||
|
||||
from langflow.graph.edge.base import Edge
|
||||
from langflow.graph.graph.constants import lazy_load_vertex_dict
|
||||
from langflow.graph.graph.utils import process_flow
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.graph.vertex.types import (
|
||||
FileToolVertex,
|
||||
LLMVertex,
|
||||
ToolkitVertex,
|
||||
)
|
||||
from langflow.graph.vertex.types import (FileToolVertex, LLMVertex,
|
||||
ToolkitVertex)
|
||||
from langflow.interface.tools.constants import FILE_TOOLS
|
||||
from langflow.utils import payload
|
||||
from loguru import logger
|
||||
from langchain.chains.base import Chain
|
||||
|
||||
|
||||
class Graph:
|
||||
"""A class representing a graph of nodes and edges."""
|
||||
"""A class representing a graph of vertices and edges."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
nodes: List[Dict[str, Union[str, Dict[str, Union[str, List[str]]]]]],
|
||||
edges: List[Dict[str, str]],
|
||||
) -> None:
|
||||
self._nodes = nodes
|
||||
self._vertices = nodes
|
||||
self._edges = edges
|
||||
self.raw_graph_data = {"nodes": nodes, "edges": edges}
|
||||
|
||||
self.top_level_vertices = []
|
||||
for vertex in self._vertices:
|
||||
if vertex_id := vertex.get("id"):
|
||||
self.top_level_vertices.append(vertex_id)
|
||||
self._graph_data = process_flow(self.raw_graph_data)
|
||||
|
||||
self._vertices = self._graph_data["nodes"]
|
||||
self._edges = self._graph_data["edges"]
|
||||
self._build_graph()
|
||||
|
||||
def __getstate__(self):
|
||||
return self.raw_graph_data
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__.update(state)
|
||||
for edge in self.edges:
|
||||
edge.reset()
|
||||
edge.validate_edge()
|
||||
self.__init__(**state)
|
||||
|
||||
@classmethod
|
||||
def from_payload(cls, payload: Dict) -> "Graph":
|
||||
|
|
@ -46,9 +55,9 @@ class Graph:
|
|||
if "data" in payload:
|
||||
payload = payload["data"]
|
||||
try:
|
||||
nodes = payload["nodes"]
|
||||
vertices = payload["nodes"]
|
||||
edges = payload["edges"]
|
||||
return cls(nodes, edges)
|
||||
return cls(vertices, edges)
|
||||
except KeyError as exc:
|
||||
raise ValueError(
|
||||
f"Invalid payload. Expected keys 'nodes' and 'edges'. Found {list(payload.keys())}"
|
||||
|
|
@ -60,65 +69,69 @@ class Graph:
|
|||
return self.__repr__() == other.__repr__()
|
||||
|
||||
def _build_graph(self) -> None:
|
||||
"""Builds the graph from the nodes and edges."""
|
||||
self.nodes = self._build_vertices()
|
||||
"""Builds the graph from the vertices and edges."""
|
||||
self.vertices = self._build_vertices()
|
||||
self.vertex_map = {vertex.id: vertex for vertex in self.vertices}
|
||||
self.edges = self._build_edges()
|
||||
for edge in self.edges:
|
||||
edge.source.add_edge(edge)
|
||||
edge.target.add_edge(edge)
|
||||
|
||||
# This is a hack to make sure that the LLM node is sent to
|
||||
# the toolkit node
|
||||
self._build_node_params()
|
||||
# remove invalid nodes
|
||||
self._validate_nodes()
|
||||
# This is a hack to make sure that the LLM vertex is sent to
|
||||
# the toolkit vertex
|
||||
self._build_vertex_params()
|
||||
# remove invalid vertices
|
||||
self._validate_vertices()
|
||||
|
||||
def _build_node_params(self) -> None:
|
||||
"""Identifies and handles the LLM node within the graph."""
|
||||
llm_node = None
|
||||
for node in self.nodes:
|
||||
node._build_params()
|
||||
if isinstance(node, LLMVertex):
|
||||
llm_node = node
|
||||
def _build_vertex_params(self) -> None:
|
||||
"""Identifies and handles the LLM vertex within the graph."""
|
||||
llm_vertex = None
|
||||
for vertex in self.vertices:
|
||||
vertex._build_params()
|
||||
if isinstance(vertex, LLMVertex):
|
||||
llm_vertex = vertex
|
||||
|
||||
if llm_node:
|
||||
for node in self.nodes:
|
||||
if isinstance(node, ToolkitVertex):
|
||||
node.params["llm"] = llm_node
|
||||
if llm_vertex:
|
||||
for vertex in self.vertices:
|
||||
if isinstance(vertex, ToolkitVertex):
|
||||
vertex.params["llm"] = llm_vertex
|
||||
|
||||
def _validate_nodes(self) -> None:
|
||||
"""Check that all nodes have edges"""
|
||||
if len(self.nodes) == 1:
|
||||
def _validate_vertices(self) -> None:
|
||||
"""Check that all vertices have edges"""
|
||||
if len(self.vertices) == 1:
|
||||
return
|
||||
for node in self.nodes:
|
||||
if not self._validate_node(node):
|
||||
raise ValueError(
|
||||
f"{node.vertex_type} is not connected to any other components"
|
||||
)
|
||||
for vertex in self.vertices:
|
||||
if not self._validate_vertex(vertex):
|
||||
raise ValueError(f"{vertex.vertex_type} is not connected to any other components")
|
||||
|
||||
def _validate_node(self, node: Vertex) -> bool:
|
||||
"""Validates a node."""
|
||||
# All nodes that do not have edges are invalid
|
||||
return len(node.edges) > 0
|
||||
def _validate_vertex(self, vertex: Vertex) -> bool:
|
||||
"""Validates a vertex."""
|
||||
# All vertices that do not have edges are invalid
|
||||
return len(self.get_vertex_edges(vertex.id)) > 0
|
||||
|
||||
def get_node(self, node_id: str) -> Union[None, Vertex]:
|
||||
"""Returns a node by id."""
|
||||
return next((node for node in self.nodes if node.id == node_id), None)
|
||||
def get_vertex(self, vertex_id: str) -> Union[None, Vertex]:
|
||||
"""Returns a vertex by id."""
|
||||
return self.vertex_map.get(vertex_id)
|
||||
|
||||
def get_nodes_with_target(self, node: Vertex) -> List[Vertex]:
|
||||
"""Returns the nodes connected to a node."""
|
||||
connected_nodes: List[Vertex] = [
|
||||
edge.source for edge in self.edges if edge.target == node
|
||||
]
|
||||
return connected_nodes
|
||||
def get_vertex_edges(self, vertex_id: str) -> List[Edge]:
|
||||
"""Returns a list of edges for a given vertex."""
|
||||
return [edge for edge in self.edges if edge.source_id == vertex_id or edge.target_id == vertex_id]
|
||||
|
||||
def build(self) -> Chain:
|
||||
def get_vertices_with_target(self, vertex_id: str) -> List[Vertex]:
|
||||
"""Returns the vertices connected to a vertex."""
|
||||
vertices: List[Vertex] = []
|
||||
for edge in self.edges:
|
||||
if edge.target_id == vertex_id:
|
||||
vertex = self.get_vertex(edge.source_id)
|
||||
if vertex is None:
|
||||
continue
|
||||
vertices.append(vertex)
|
||||
return vertices
|
||||
|
||||
async def build(self) -> Chain:
|
||||
"""Builds the graph."""
|
||||
# Get root node
|
||||
root_node = payload.get_root_node(self)
|
||||
if root_node is None:
|
||||
raise ValueError("No root node found")
|
||||
return root_node.build()
|
||||
# Get root vertex
|
||||
root_vertex = payload.get_root_vertex(self)
|
||||
if root_vertex is None:
|
||||
raise ValueError("No root vertex found")
|
||||
return await root_vertex.build()
|
||||
|
||||
def topological_sort(self) -> List[Vertex]:
|
||||
"""
|
||||
|
|
@ -131,27 +144,25 @@ class Graph:
|
|||
ValueError: If the graph contains a cycle.
|
||||
"""
|
||||
# States: 0 = unvisited, 1 = visiting, 2 = visited
|
||||
state = {node: 0 for node in self.nodes}
|
||||
state = {vertex: 0 for vertex in self.vertices}
|
||||
sorted_vertices = []
|
||||
|
||||
def dfs(node):
|
||||
if state[node] == 1:
|
||||
def dfs(vertex):
|
||||
if state[vertex] == 1:
|
||||
# We have a cycle
|
||||
raise ValueError(
|
||||
"Graph contains a cycle, cannot perform topological sort"
|
||||
)
|
||||
if state[node] == 0:
|
||||
state[node] = 1
|
||||
for edge in node.edges:
|
||||
if edge.source == node:
|
||||
dfs(edge.target)
|
||||
state[node] = 2
|
||||
sorted_vertices.append(node)
|
||||
raise ValueError("Graph contains a cycle, cannot perform topological sort")
|
||||
if state[vertex] == 0:
|
||||
state[vertex] = 1
|
||||
for edge in vertex.edges:
|
||||
if edge.source_id == vertex.id:
|
||||
dfs(self.get_vertex(edge.target_id))
|
||||
state[vertex] = 2
|
||||
sorted_vertices.append(vertex)
|
||||
|
||||
# Visit each node
|
||||
for node in self.nodes:
|
||||
if state[node] == 0:
|
||||
dfs(node)
|
||||
# Visit each vertex
|
||||
for vertex in self.vertices:
|
||||
if state[vertex] == 0:
|
||||
dfs(vertex)
|
||||
|
||||
return list(reversed(sorted_vertices))
|
||||
|
||||
|
|
@ -161,17 +172,21 @@ class Graph:
|
|||
logger.debug("There are %s vertices in the graph", len(sorted_vertices))
|
||||
yield from sorted_vertices
|
||||
|
||||
def get_node_neighbors(self, node: Vertex) -> Dict[Vertex, int]:
|
||||
"""Returns the neighbors of a node."""
|
||||
def get_vertex_neighbors(self, vertex: Vertex) -> Dict[Vertex, int]:
|
||||
"""Returns the neighbors of a vertex."""
|
||||
neighbors: Dict[Vertex, int] = {}
|
||||
for edge in self.edges:
|
||||
if edge.source == node:
|
||||
neighbor = edge.target
|
||||
if edge.source_id == vertex.id:
|
||||
neighbor = self.get_vertex(edge.target_id)
|
||||
if neighbor is None:
|
||||
continue
|
||||
if neighbor not in neighbors:
|
||||
neighbors[neighbor] = 0
|
||||
neighbors[neighbor] += 1
|
||||
elif edge.target == node:
|
||||
neighbor = edge.source
|
||||
elif edge.target_id == vertex.id:
|
||||
neighbor = self.get_vertex(edge.source_id)
|
||||
if neighbor is None:
|
||||
continue
|
||||
if neighbor not in neighbors:
|
||||
neighbors[neighbor] = 0
|
||||
neighbors[neighbor] += 1
|
||||
|
|
@ -179,59 +194,59 @@ class Graph:
|
|||
|
||||
def _build_edges(self) -> List[Edge]:
|
||||
"""Builds the edges of the graph."""
|
||||
# Edge takes two nodes as arguments, so we need to build the nodes first
|
||||
# Edge takes two vertices as arguments, so we need to build the vertices first
|
||||
# and then build the edges
|
||||
# if we can't find a node, we raise an error
|
||||
# if we can't find a vertex, we raise an error
|
||||
|
||||
edges: List[Edge] = []
|
||||
for edge in self._edges:
|
||||
source = self.get_node(edge["source"])
|
||||
target = self.get_node(edge["target"])
|
||||
source = self.get_vertex(edge["source"])
|
||||
target = self.get_vertex(edge["target"])
|
||||
if source is None:
|
||||
raise ValueError(f"Source node {edge['source']} not found")
|
||||
raise ValueError(f"Source vertex {edge['source']} not found")
|
||||
if target is None:
|
||||
raise ValueError(f"Target node {edge['target']} not found")
|
||||
raise ValueError(f"Target vertex {edge['target']} not found")
|
||||
edges.append(Edge(source, target, edge))
|
||||
return edges
|
||||
|
||||
def _get_vertex_class(self, node_type: str, node_lc_type: str) -> Type[Vertex]:
|
||||
"""Returns the node class based on the node type."""
|
||||
if node_type in FILE_TOOLS:
|
||||
def _get_vertex_class(self, vertex_type: str, vertex_lc_type: str) -> Type[Vertex]:
|
||||
"""Returns the vertex class based on the vertex type."""
|
||||
if vertex_type in FILE_TOOLS:
|
||||
return FileToolVertex
|
||||
if node_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP:
|
||||
return lazy_load_vertex_dict.VERTEX_TYPE_MAP[node_type]
|
||||
if vertex_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP:
|
||||
return lazy_load_vertex_dict.VERTEX_TYPE_MAP[vertex_type]
|
||||
return (
|
||||
lazy_load_vertex_dict.VERTEX_TYPE_MAP[node_lc_type]
|
||||
if node_lc_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP
|
||||
lazy_load_vertex_dict.VERTEX_TYPE_MAP[vertex_lc_type]
|
||||
if vertex_lc_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP
|
||||
else Vertex
|
||||
)
|
||||
|
||||
def _build_vertices(self) -> List[Vertex]:
|
||||
"""Builds the vertices of the graph."""
|
||||
nodes: List[Vertex] = []
|
||||
for node in self._nodes:
|
||||
node_data = node["data"]
|
||||
node_type: str = node_data["type"] # type: ignore
|
||||
node_lc_type: str = node_data["node"]["template"]["_type"] # type: ignore
|
||||
vertices: List[Vertex] = []
|
||||
for vertex in self._vertices:
|
||||
vertex_data = vertex["data"]
|
||||
vertex_type: str = vertex_data["type"] # type: ignore
|
||||
vertex_lc_type: str = vertex_data["node"]["template"]["_type"] # type: ignore
|
||||
|
||||
VertexClass = self._get_vertex_class(node_type, node_lc_type)
|
||||
nodes.append(VertexClass(node))
|
||||
VertexClass = self._get_vertex_class(vertex_type, vertex_lc_type)
|
||||
vertex_instance = VertexClass(vertex, graph=self)
|
||||
vertex_instance.set_top_level(self.top_level_vertices)
|
||||
vertices.append(vertex_instance)
|
||||
|
||||
return nodes
|
||||
return vertices
|
||||
|
||||
def get_children_by_node_type(self, node: Vertex, node_type: str) -> List[Vertex]:
|
||||
"""Returns the children of a node based on the node type."""
|
||||
def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> List[Vertex]:
|
||||
"""Returns the children of a vertex based on the vertex type."""
|
||||
children = []
|
||||
node_types = [node.data["type"]]
|
||||
if "node" in node.data:
|
||||
node_types += node.data["node"]["base_classes"]
|
||||
if node_type in node_types:
|
||||
children.append(node)
|
||||
vertex_types = [vertex.data["type"]]
|
||||
if "node" in vertex.data:
|
||||
vertex_types += vertex.data["node"]["base_classes"]
|
||||
if vertex_type in vertex_types:
|
||||
children.append(vertex)
|
||||
return children
|
||||
|
||||
def __repr__(self):
|
||||
node_ids = [node.id for node in self.nodes]
|
||||
edges_repr = "\n".join(
|
||||
[f"{edge.source.id} --> {edge.target.id}" for edge in self.edges]
|
||||
)
|
||||
return f"Graph:\nNodes: {node_ids}\nConnections:\n{edges_repr}"
|
||||
vertex_ids = [vertex.id for vertex in self.vertices]
|
||||
edges_repr = "\n".join([f"{edge.source_id} --> {edge.target_id}" for edge in self.edges])
|
||||
return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}"
|
||||
|
|
|
|||
|
|
@ -47,10 +47,7 @@ class VertexTypesDict(LazyLoadDictBase):
|
|||
**{t: types.DocumentLoaderVertex for t in documentloader_creator.to_list()},
|
||||
**{t: types.TextSplitterVertex for t in textsplitter_creator.to_list()},
|
||||
**{t: types.OutputParserVertex for t in output_parser_creator.to_list()},
|
||||
**{
|
||||
t: types.CustomComponentVertex
|
||||
for t in custom_component_creator.to_list()
|
||||
},
|
||||
**{t: types.CustomComponentVertex for t in custom_component_creator.to_list()},
|
||||
**{t: types.RetrieverVertex for t in retriever_creator.to_list()},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,247 @@
|
|||
import copy
|
||||
from collections import deque
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
def find_last_node(nodes, edges):
|
||||
"""
|
||||
This function receives a flow and returns the last node.
|
||||
"""
|
||||
return next((n for n in nodes if all(e["source"] != n["id"] for e in edges)), None)
|
||||
|
||||
|
||||
def add_parent_node_id(nodes, parent_node_id):
|
||||
"""
|
||||
This function receives a list of nodes and adds a parent_node_id to each node.
|
||||
"""
|
||||
for node in nodes:
|
||||
node["parent_node_id"] = parent_node_id
|
||||
|
||||
|
||||
def ungroup_node(group_node_data, base_flow):
|
||||
template, flow = (
|
||||
group_node_data["node"]["template"],
|
||||
group_node_data["node"]["flow"],
|
||||
)
|
||||
parent_node_id = group_node_data["id"]
|
||||
g_nodes = flow["data"]["nodes"]
|
||||
add_parent_node_id(g_nodes, parent_node_id)
|
||||
g_edges = flow["data"]["edges"]
|
||||
|
||||
# Redirect edges to the correct proxy node
|
||||
updated_edges = get_updated_edges(base_flow, g_nodes, g_edges, group_node_data["id"])
|
||||
|
||||
# Update template values
|
||||
update_template(template, g_nodes)
|
||||
|
||||
nodes = [n for n in base_flow["nodes"] if n["id"] != group_node_data["id"]] + g_nodes
|
||||
edges = (
|
||||
[e for e in base_flow["edges"] if e["target"] != group_node_data["id"] and e["source"] != group_node_data["id"]]
|
||||
+ g_edges
|
||||
+ updated_edges
|
||||
)
|
||||
|
||||
base_flow["nodes"] = nodes
|
||||
base_flow["edges"] = edges
|
||||
|
||||
return nodes
|
||||
|
||||
|
||||
def raw_topological_sort(nodes, edges) -> List[Dict]:
|
||||
# Redefine the above function but using the nodes and self._edges
|
||||
# which are dicts instead of Vertex and Edge objects
|
||||
# nodes have an id, edges have a source and target keys
|
||||
# return a list of node ids in topological order
|
||||
|
||||
# States: 0 = unvisited, 1 = visiting, 2 = visited
|
||||
state = {node["id"]: 0 for node in nodes}
|
||||
nodes_dict = {node["id"]: node for node in nodes}
|
||||
sorted_vertices = []
|
||||
|
||||
def dfs(node):
|
||||
if state[node] == 1:
|
||||
# We have a cycle
|
||||
raise ValueError("Graph contains a cycle, cannot perform topological sort")
|
||||
if state[node] == 0:
|
||||
state[node] = 1
|
||||
for edge in edges:
|
||||
if edge["source"] == node:
|
||||
dfs(edge["target"])
|
||||
state[node] = 2
|
||||
sorted_vertices.append(node)
|
||||
|
||||
# Visit each node
|
||||
for node in nodes:
|
||||
if state[node["id"]] == 0:
|
||||
dfs(node["id"])
|
||||
|
||||
reverse_sorted = list(reversed(sorted_vertices))
|
||||
return [nodes_dict[node_id] for node_id in reverse_sorted]
|
||||
|
||||
|
||||
def process_flow(flow_object):
|
||||
cloned_flow = copy.deepcopy(flow_object)
|
||||
processed_nodes = set() # To keep track of processed nodes
|
||||
|
||||
def process_node(node):
|
||||
node_id = node.get("id")
|
||||
|
||||
# If node already processed, skip
|
||||
if node_id in processed_nodes:
|
||||
return
|
||||
|
||||
if node.get("data") and node["data"].get("node") and node["data"]["node"].get("flow"):
|
||||
process_flow(node["data"]["node"]["flow"]["data"])
|
||||
new_nodes = ungroup_node(node["data"], cloned_flow)
|
||||
# Add new nodes to the queue for future processing
|
||||
nodes_to_process.extend(new_nodes)
|
||||
|
||||
# Mark node as processed
|
||||
processed_nodes.add(node_id)
|
||||
|
||||
sorted_nodes_list = raw_topological_sort(cloned_flow["nodes"], cloned_flow["edges"])
|
||||
nodes_to_process = deque(sorted_nodes_list)
|
||||
|
||||
while nodes_to_process:
|
||||
node = nodes_to_process.popleft()
|
||||
process_node(node)
|
||||
|
||||
return cloned_flow
|
||||
|
||||
|
||||
def update_template(template, g_nodes):
|
||||
"""
|
||||
Updates the template of a node in a graph with the given template.
|
||||
|
||||
Args:
|
||||
template (dict): The new template to update the node with.
|
||||
g_nodes (list): The list of nodes in the graph.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for _, value in template.items():
|
||||
if not value.get("proxy"):
|
||||
continue
|
||||
proxy_dict = value["proxy"]
|
||||
field, id_ = proxy_dict["field"], proxy_dict["id"]
|
||||
node_index = next((i for i, n in enumerate(g_nodes) if n["id"] == id_), -1)
|
||||
if node_index != -1:
|
||||
display_name = None
|
||||
show = g_nodes[node_index]["data"]["node"]["template"][field]["show"]
|
||||
advanced = g_nodes[node_index]["data"]["node"]["template"][field]["advanced"]
|
||||
if "display_name" in g_nodes[node_index]["data"]["node"]["template"][field]:
|
||||
display_name = g_nodes[node_index]["data"]["node"]["template"][field]["display_name"]
|
||||
else:
|
||||
display_name = g_nodes[node_index]["data"]["node"]["template"][field]["name"]
|
||||
|
||||
g_nodes[node_index]["data"]["node"]["template"][field] = value
|
||||
g_nodes[node_index]["data"]["node"]["template"][field]["show"] = show
|
||||
g_nodes[node_index]["data"]["node"]["template"][field]["advanced"] = advanced
|
||||
g_nodes[node_index]["data"]["node"]["template"][field]["display_name"] = display_name
|
||||
|
||||
|
||||
def update_target_handle(
|
||||
new_edge,
|
||||
g_nodes,
|
||||
group_node_id,
|
||||
):
|
||||
"""
|
||||
Updates the target handle of a given edge if it is a proxy node.
|
||||
|
||||
Args:
|
||||
new_edge (dict): The edge to update.
|
||||
g_nodes (list): The list of nodes in the graph.
|
||||
group_node_id (str): The ID of the group node.
|
||||
|
||||
Returns:
|
||||
dict: The updated edge.
|
||||
"""
|
||||
target_handle = new_edge["data"]["targetHandle"]
|
||||
if target_handle.get("proxy"):
|
||||
proxy_id = target_handle["proxy"]["id"]
|
||||
if node := next((n for n in g_nodes if n["id"] == proxy_id), None):
|
||||
set_new_target_handle(proxy_id, new_edge, target_handle, node)
|
||||
else:
|
||||
raise ValueError(f"Group node {group_node_id} has an invalid target proxy node {proxy_id}")
|
||||
return new_edge
|
||||
|
||||
|
||||
def set_new_target_handle(proxy_id, new_edge, target_handle, node):
|
||||
"""
|
||||
Sets a new target handle for a given edge.
|
||||
|
||||
Args:
|
||||
proxy_id (str): The ID of the proxy.
|
||||
new_edge (dict): The new edge to be created.
|
||||
target_handle (dict): The target handle of the edge.
|
||||
node (dict): The node containing the edge.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
new_edge["target"] = proxy_id
|
||||
_type = target_handle.get("type")
|
||||
if _type is None:
|
||||
raise KeyError("The 'type' key must be present in target_handle.")
|
||||
|
||||
field = target_handle["proxy"]["field"]
|
||||
new_target_handle = {
|
||||
"fieldName": field,
|
||||
"type": _type,
|
||||
"id": proxy_id,
|
||||
}
|
||||
if node["data"]["node"].get("flow"):
|
||||
new_target_handle["proxy"] = {
|
||||
"field": node["data"]["node"]["template"][field]["proxy"]["field"],
|
||||
"id": node["data"]["node"]["template"][field]["proxy"]["id"],
|
||||
}
|
||||
if input_types := target_handle.get("inputTypes"):
|
||||
new_target_handle["inputTypes"] = input_types
|
||||
new_edge["data"]["targetHandle"] = new_target_handle
|
||||
|
||||
|
||||
def update_source_handle(new_edge, g_nodes, g_edges):
|
||||
"""
|
||||
Updates the source handle of a given edge to the last node in the flow data.
|
||||
|
||||
Args:
|
||||
new_edge (dict): The edge to update.
|
||||
flow_data (dict): The flow data containing the nodes and edges.
|
||||
|
||||
Returns:
|
||||
dict: The updated edge with the new source handle.
|
||||
"""
|
||||
last_node = copy.deepcopy(find_last_node(g_nodes, g_edges))
|
||||
new_edge["source"] = last_node["id"]
|
||||
new_source_handle = new_edge["data"]["sourceHandle"]
|
||||
new_source_handle["id"] = last_node["id"]
|
||||
new_edge["data"]["sourceHandle"] = new_source_handle
|
||||
return new_edge
|
||||
|
||||
|
||||
def get_updated_edges(base_flow, g_nodes, g_edges, group_node_id):
|
||||
"""
|
||||
Given a base flow, a list of graph nodes and a group node id, returns a list of updated edges.
|
||||
An updated edge is an edge that has its target or source handle updated based on the group node id.
|
||||
|
||||
Args:
|
||||
base_flow (dict): The base flow containing a list of edges.
|
||||
g_nodes (list): A list of graph nodes.
|
||||
group_node_id (str): The id of the group node.
|
||||
|
||||
Returns:
|
||||
list: A list of updated edges.
|
||||
"""
|
||||
updated_edges = []
|
||||
for edge in base_flow["edges"]:
|
||||
new_edge = copy.deepcopy(edge)
|
||||
if new_edge["target"] == group_node_id:
|
||||
new_edge = update_target_handle(new_edge, g_nodes, group_node_id)
|
||||
|
||||
if new_edge["source"] == group_node_id:
|
||||
new_edge = update_source_handle(new_edge, g_nodes, g_edges)
|
||||
|
||||
if edge["target"] == group_node_id or edge["source"] == group_node_id:
|
||||
updated_edges.append(new_edge)
|
||||
return updated_edges
|
||||
|
|
@ -1,35 +1,33 @@
|
|||
import ast
|
||||
import pickle
|
||||
import inspect
|
||||
import types
|
||||
from typing import TYPE_CHECKING, Any, Coroutine, Dict, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from langflow.graph.utils import UnbuiltObject
|
||||
from langflow.graph.vertex.utils import is_basic_type
|
||||
from langflow.interface.initialize import loading
|
||||
from langflow.interface.listing import lazy_load_dict
|
||||
from langflow.utils.constants import DIRECT_TYPES
|
||||
from loguru import logger
|
||||
from langflow.utils.util import sync_to_async
|
||||
|
||||
|
||||
import inspect
|
||||
import types
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.graph.edge.base import Edge
|
||||
from langflow.graph.graph.base import Graph
|
||||
|
||||
|
||||
class Vertex:
|
||||
def __init__(
|
||||
self,
|
||||
data: Dict,
|
||||
graph: "Graph",
|
||||
base_type: Optional[str] = None,
|
||||
is_task: bool = False,
|
||||
params: Optional[Dict] = None,
|
||||
) -> None:
|
||||
self.graph = graph
|
||||
self.id: str = data["id"]
|
||||
self._data = data
|
||||
self.edges: List["Edge"] = []
|
||||
self.base_type: Optional[str] = base_type
|
||||
self._parse_data()
|
||||
self._built_object = UnbuiltObject()
|
||||
|
|
@ -39,45 +37,28 @@ class Vertex:
|
|||
self.is_task = is_task
|
||||
self.params = params or {}
|
||||
|
||||
def reset_params(self):
|
||||
for edge in self.edges:
|
||||
if edge.source != self:
|
||||
target_param = edge.target_param
|
||||
if target_param in ["document", "texts"]:
|
||||
# this means they got data and have already ingested it
|
||||
# so we continue after removing the param
|
||||
self.params.pop(target_param, None)
|
||||
continue
|
||||
|
||||
if target_param in self.params and not is_basic_type(
|
||||
self.params[target_param]
|
||||
):
|
||||
# edge.source.params = {}
|
||||
edge.source._build_params()
|
||||
edge.source._built_object = UnbuiltObject()
|
||||
edge.source._built = False
|
||||
|
||||
self.params[target_param] = edge.source
|
||||
@property
|
||||
def edges(self) -> List["Edge"]:
|
||||
return self.graph.get_vertex_edges(self.id)
|
||||
|
||||
def __getstate__(self):
|
||||
state_dict = self.__dict__.copy()
|
||||
try:
|
||||
# try pickling the built object
|
||||
# if it fails, then we need to delete it
|
||||
# and build it again
|
||||
pickle.dumps(state_dict["_built_object"])
|
||||
except Exception:
|
||||
self.reset_params()
|
||||
del state_dict["_built_object"]
|
||||
del state_dict["_built"]
|
||||
return state_dict
|
||||
return {
|
||||
"_data": self._data,
|
||||
"params": {},
|
||||
"base_type": self.base_type,
|
||||
"is_task": self.is_task,
|
||||
"id": self.id,
|
||||
"_built_object": UnbuiltObject(),
|
||||
"_built": False,
|
||||
"parent_node_id": self.parent_node_id,
|
||||
"parent_is_top_level": self.parent_is_top_level,
|
||||
}
|
||||
|
||||
def __setstate__(self, state):
|
||||
self._data = state["_data"]
|
||||
self.params = state["params"]
|
||||
self.base_type = state["base_type"]
|
||||
self.is_task = state["is_task"]
|
||||
self.edges = state["edges"]
|
||||
self.id = state["id"]
|
||||
self._parse_data()
|
||||
if "_built_object" in state:
|
||||
|
|
@ -88,33 +69,26 @@ class Vertex:
|
|||
self._built = False
|
||||
self.artifacts: Dict[str, Any] = {}
|
||||
self.task_id: Optional[str] = None
|
||||
self.parent_node_id = state["parent_node_id"]
|
||||
self.parent_is_top_level = state["parent_is_top_level"]
|
||||
|
||||
def set_top_level(self, top_level_vertices: List[str]) -> None:
|
||||
self.parent_is_top_level = self.parent_node_id in top_level_vertices
|
||||
|
||||
def _parse_data(self) -> None:
|
||||
self.data = self._data["data"]
|
||||
self.output = self.data["node"]["base_classes"]
|
||||
template_dicts = {
|
||||
key: value
|
||||
for key, value in self.data["node"]["template"].items()
|
||||
if isinstance(value, dict)
|
||||
}
|
||||
template_dicts = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)}
|
||||
|
||||
self.required_inputs = [
|
||||
template_dicts[key]["type"]
|
||||
for key, value in template_dicts.items()
|
||||
if value["required"]
|
||||
template_dicts[key]["type"] for key, value in template_dicts.items() if value["required"]
|
||||
]
|
||||
self.optional_inputs = [
|
||||
template_dicts[key]["type"]
|
||||
for key, value in template_dicts.items()
|
||||
if not value["required"]
|
||||
template_dicts[key]["type"] for key, value in template_dicts.items() if not value["required"]
|
||||
]
|
||||
# Add the template_dicts[key]["input_types"] to the optional_inputs
|
||||
self.optional_inputs.extend(
|
||||
[
|
||||
input_type
|
||||
for value in template_dicts.values()
|
||||
for input_type in value.get("input_types", [])
|
||||
]
|
||||
[input_type for value in template_dicts.values() for input_type in value.get("input_types", [])]
|
||||
)
|
||||
|
||||
template_dict = self.data["node"]["template"]
|
||||
|
|
@ -153,11 +127,11 @@ class Vertex:
|
|||
# and use that as the value for the param
|
||||
# If the type is "str", then we need to get the value of the "value" key
|
||||
# and use that as the value for the param
|
||||
template_dict = {
|
||||
key: value
|
||||
for key, value in self.data["node"]["template"].items()
|
||||
if isinstance(value, dict)
|
||||
}
|
||||
|
||||
if self.graph is None:
|
||||
raise ValueError("Graph not found")
|
||||
|
||||
template_dict = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)}
|
||||
params = self.params.copy() if self.params else {}
|
||||
|
||||
for edge in self.edges:
|
||||
|
|
@ -168,9 +142,9 @@ class Vertex:
|
|||
if template_dict[param_key]["list"]:
|
||||
if param_key not in params:
|
||||
params[param_key] = []
|
||||
params[param_key].append(edge.source)
|
||||
elif edge.target.id == self.id:
|
||||
params[param_key] = edge.source
|
||||
params[param_key].append(self.graph.get_vertex(edge.source_id))
|
||||
elif edge.target_id == self.id:
|
||||
params[param_key] = self.graph.get_vertex(edge.source_id)
|
||||
|
||||
for key, value in template_dict.items():
|
||||
if key in params:
|
||||
|
|
@ -182,7 +156,7 @@ class Vertex:
|
|||
# If the type is not transformable to a python base class
|
||||
# then we need to get the edge that connects to this node
|
||||
if value.get("type") == "file":
|
||||
# Load the type in value.get('suffixes') using
|
||||
# Load the type in value.get('fileTypes') using
|
||||
# what is inside value.get('content')
|
||||
# value.get('value') is the file name
|
||||
if file_path := value.get("file_path"):
|
||||
|
|
@ -190,27 +164,33 @@ class Vertex:
|
|||
else:
|
||||
raise ValueError(f"File path not found for {self.vertex_type}")
|
||||
elif value.get("type") in DIRECT_TYPES and params.get(key) is None:
|
||||
val = value.get("value")
|
||||
if value.get("type") == "code":
|
||||
try:
|
||||
params[key] = ast.literal_eval(value.get("value"))
|
||||
params[key] = ast.literal_eval(val) if val else None
|
||||
except Exception as exc:
|
||||
logger.debug(f"Error parsing code: {exc}")
|
||||
params[key] = value.get("value")
|
||||
params[key] = val
|
||||
elif value.get("type") in ["dict", "NestedDict"]:
|
||||
# When dict comes from the frontend it comes as a
|
||||
# list of dicts, so we need to convert it to a dict
|
||||
# before passing it to the build method
|
||||
_value = value.get("value")
|
||||
if isinstance(_value, list):
|
||||
params[key] = {
|
||||
k: v
|
||||
for item in value.get("value", [])
|
||||
for k, v in item.items()
|
||||
}
|
||||
elif isinstance(_value, dict):
|
||||
params[key] = _value
|
||||
else:
|
||||
params[key] = value.get("value")
|
||||
if isinstance(val, list):
|
||||
params[key] = {k: v for item in value.get("value", []) for k, v in item.items()}
|
||||
elif isinstance(val, dict):
|
||||
params[key] = val
|
||||
elif value.get("type") == "int" and val is not None:
|
||||
try:
|
||||
params[key] = int(val)
|
||||
except ValueError:
|
||||
params[key] = val
|
||||
elif value.get("type") == "float" and val is not None:
|
||||
try:
|
||||
params[key] = float(val)
|
||||
except ValueError:
|
||||
params[key] = val
|
||||
elif val is not None and val != "":
|
||||
params[key] = val
|
||||
|
||||
if not value.get("required") and params.get(key) is None:
|
||||
if value.get("default"):
|
||||
|
|
@ -221,18 +201,18 @@ class Vertex:
|
|||
self._raw_params = params
|
||||
self.params = params
|
||||
|
||||
def _build(self, user_id=None):
|
||||
async def _build(self, user_id=None):
|
||||
"""
|
||||
Initiate the build process.
|
||||
"""
|
||||
logger.debug(f"Building {self.vertex_type}")
|
||||
self._build_each_node_in_params_dict(user_id)
|
||||
self._get_and_instantiate_class(user_id)
|
||||
await self._build_each_node_in_params_dict(user_id)
|
||||
await self._get_and_instantiate_class(user_id)
|
||||
self._validate_built_object()
|
||||
|
||||
self._built = True
|
||||
|
||||
def _build_each_node_in_params_dict(self, user_id=None):
|
||||
async def _build_each_node_in_params_dict(self, user_id=None):
|
||||
"""
|
||||
Iterates over each node in the params dictionary and builds it.
|
||||
"""
|
||||
|
|
@ -241,9 +221,9 @@ class Vertex:
|
|||
if value == self:
|
||||
del self.params[key]
|
||||
continue
|
||||
self._build_node_and_update_params(key, value, user_id)
|
||||
await self._build_node_and_update_params(key, value, user_id)
|
||||
elif isinstance(value, list) and self._is_list_of_nodes(value):
|
||||
self._build_list_of_nodes_and_update_params(key, value, user_id)
|
||||
await self._build_list_of_nodes_and_update_params(key, value, user_id)
|
||||
|
||||
def _is_node(self, value):
|
||||
"""
|
||||
|
|
@ -257,14 +237,17 @@ class Vertex:
|
|||
"""
|
||||
return all(self._is_node(node) for node in value)
|
||||
|
||||
def get_result(self, user_id=None, timeout=None) -> Any:
|
||||
async def get_result(self, user_id=None, timeout=None) -> Any:
|
||||
# Check if the Vertex was built already
|
||||
if self._built:
|
||||
return self._built_object
|
||||
|
||||
if self.is_task and self.task_id is not None:
|
||||
task = self.get_task()
|
||||
|
||||
result = task.get(timeout=timeout)
|
||||
if isinstance(result, Coroutine):
|
||||
result = await result
|
||||
if result is not None: # If result is ready
|
||||
self._update_built_object_and_artifacts(result)
|
||||
return self._built_object
|
||||
|
|
@ -273,29 +256,27 @@ class Vertex:
|
|||
pass
|
||||
|
||||
# If there's no task_id, build the vertex locally
|
||||
self.build(user_id)
|
||||
await self.build(user_id=user_id)
|
||||
return self._built_object
|
||||
|
||||
def _build_node_and_update_params(self, key, node, user_id=None):
|
||||
async def _build_node_and_update_params(self, key, node, user_id=None):
|
||||
"""
|
||||
Builds a given node and updates the params dictionary accordingly.
|
||||
"""
|
||||
|
||||
result = node.get_result(user_id)
|
||||
result = await node.get_result(user_id)
|
||||
self._handle_func(key, result)
|
||||
if isinstance(result, list):
|
||||
self._extend_params_list_with_result(key, result)
|
||||
self.params[key] = result
|
||||
|
||||
def _build_list_of_nodes_and_update_params(
|
||||
self, key, nodes: List["Vertex"], user_id=None
|
||||
):
|
||||
async def _build_list_of_nodes_and_update_params(self, key, nodes: List["Vertex"], user_id=None):
|
||||
"""
|
||||
Iterates over a list of nodes, builds each and updates the params dictionary.
|
||||
"""
|
||||
self.params[key] = []
|
||||
for node in nodes:
|
||||
built = node.get_result(user_id)
|
||||
built = await node.get_result(user_id)
|
||||
if isinstance(built, list):
|
||||
if key not in self.params:
|
||||
self.params[key] = []
|
||||
|
|
@ -325,14 +306,14 @@ class Vertex:
|
|||
if isinstance(self.params[key], list):
|
||||
self.params[key].extend(result)
|
||||
|
||||
def _get_and_instantiate_class(self, user_id=None):
|
||||
async def _get_and_instantiate_class(self, user_id=None):
|
||||
"""
|
||||
Gets the class from a dictionary and instantiates it with the params.
|
||||
"""
|
||||
if self.base_type is None:
|
||||
raise ValueError(f"Base type for node {self.vertex_type} not found")
|
||||
try:
|
||||
result = loading.instantiate_class(
|
||||
result = await loading.instantiate_class(
|
||||
node_type=self.vertex_type,
|
||||
base_type=self.base_type,
|
||||
params=self.params,
|
||||
|
|
@ -341,9 +322,7 @@ class Vertex:
|
|||
self._update_built_object_and_artifacts(result)
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
raise ValueError(
|
||||
f"Error building node {self.vertex_type}: {str(exc)}"
|
||||
) from exc
|
||||
raise ValueError(f"Error building node {self.vertex_type}(ID:{self.id}): {str(exc)}") from exc
|
||||
|
||||
def _update_built_object_and_artifacts(self, result):
|
||||
"""
|
||||
|
|
@ -365,11 +344,11 @@ class Vertex:
|
|||
if self.base_type == "custom_components":
|
||||
message += " Make sure your build method returns a component."
|
||||
|
||||
raise ValueError(message)
|
||||
logger.warning(message)
|
||||
|
||||
def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
|
||||
async def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
|
||||
if not self._built or force:
|
||||
self._build(user_id, *args, **kwargs)
|
||||
await self._build(user_id, *args, **kwargs)
|
||||
|
||||
return self._built_object
|
||||
|
||||
|
|
@ -391,8 +370,4 @@ class Vertex:
|
|||
|
||||
def _built_object_repr(self):
|
||||
# Add a message with an emoji, stars for sucess,
|
||||
return (
|
||||
"Built sucessfully ✨"
|
||||
if self._built_object is not None
|
||||
else "Failed to build 😵💫"
|
||||
)
|
||||
return "Built sucessfully ✨" if self._built_object is not None else "Failed to build 😵💫"
|
||||
|
|
|
|||
|
|
@ -1,14 +1,14 @@
|
|||
import ast
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langflow.graph.utils import UnbuiltObject, flatten_list
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.graph.utils import flatten_list
|
||||
from langflow.interface.utils import extract_input_variables_from_prompt
|
||||
|
||||
|
||||
class AgentVertex(Vertex):
|
||||
def __init__(self, data: Dict, params: Optional[Dict] = None):
|
||||
super().__init__(data, base_type="agents", params=params)
|
||||
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
|
||||
super().__init__(data, graph=graph, base_type="agents", params=params)
|
||||
|
||||
self.tools: List[Union[ToolkitVertex, ToolVertex]] = []
|
||||
self.chains: List[ChainVertex] = []
|
||||
|
|
@ -26,49 +26,54 @@ class AgentVertex(Vertex):
|
|||
|
||||
def _set_tools_and_chains(self) -> None:
|
||||
for edge in self.edges:
|
||||
if not hasattr(edge, "source"):
|
||||
if not hasattr(edge, "source_id"):
|
||||
continue
|
||||
source_node = edge.source
|
||||
source_node = self.graph.get_vertex(edge.source_id)
|
||||
if isinstance(source_node, (ToolVertex, ToolkitVertex)):
|
||||
self.tools.append(source_node)
|
||||
elif isinstance(source_node, ChainVertex):
|
||||
self.chains.append(source_node)
|
||||
|
||||
def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
|
||||
async def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
|
||||
if not self._built or force:
|
||||
self._set_tools_and_chains()
|
||||
# First, build the tools
|
||||
for tool_node in self.tools:
|
||||
tool_node.build(user_id=user_id)
|
||||
await tool_node.build(user_id=user_id)
|
||||
|
||||
# Next, build the chains and the rest
|
||||
for chain_node in self.chains:
|
||||
chain_node.build(tools=self.tools, user_id=user_id)
|
||||
await chain_node.build(tools=self.tools, user_id=user_id)
|
||||
|
||||
self._build(user_id=user_id)
|
||||
await self._build(user_id=user_id)
|
||||
|
||||
return self._built_object
|
||||
|
||||
|
||||
class ToolVertex(Vertex):
|
||||
def __init__(self, data: Dict, params: Optional[Dict] = None):
|
||||
super().__init__(data, base_type="tools", params=params)
|
||||
def __init__(
|
||||
self,
|
||||
data: Dict,
|
||||
graph,
|
||||
params: Optional[Dict] = None,
|
||||
):
|
||||
super().__init__(data, graph=graph, base_type="tools", params=params)
|
||||
|
||||
|
||||
class LLMVertex(Vertex):
|
||||
built_node_type = None
|
||||
class_built_object = None
|
||||
|
||||
def __init__(self, data: Dict, params: Optional[Dict] = None):
|
||||
super().__init__(data, base_type="llms", params=params)
|
||||
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
|
||||
super().__init__(data, graph=graph, base_type="llms", params=params)
|
||||
|
||||
def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
|
||||
async def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
|
||||
# LLM is different because some models might take up too much memory
|
||||
# or time to load. So we only load them when we need them.ß
|
||||
if self.vertex_type == self.built_node_type:
|
||||
return self.class_built_object
|
||||
if not self._built or force:
|
||||
self._build(user_id=user_id)
|
||||
await self._build(user_id=user_id)
|
||||
self.built_node_type = self.vertex_type
|
||||
self.class_built_object = self._built_object
|
||||
# Avoid deepcopying the LLM
|
||||
|
|
@ -77,41 +82,39 @@ class LLMVertex(Vertex):
|
|||
|
||||
|
||||
class ToolkitVertex(Vertex):
|
||||
def __init__(self, data: Dict, params=None):
|
||||
super().__init__(data, base_type="toolkits", params=params)
|
||||
def __init__(self, data: Dict, graph, params=None):
|
||||
super().__init__(data, graph=graph, base_type="toolkits", params=params)
|
||||
|
||||
|
||||
class FileToolVertex(ToolVertex):
|
||||
def __init__(self, data: Dict, params=None):
|
||||
super().__init__(data, params=params)
|
||||
def __init__(self, data: Dict, graph, params=None):
|
||||
super().__init__(data, graph=graph, params=params)
|
||||
|
||||
|
||||
class WrapperVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="wrappers")
|
||||
def __init__(self, data: Dict, graph):
|
||||
super().__init__(data, graph=graph, base_type="wrappers")
|
||||
|
||||
def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
|
||||
async def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
|
||||
if not self._built or force:
|
||||
if "headers" in self.params:
|
||||
self.params["headers"] = ast.literal_eval(self.params["headers"])
|
||||
self._build(user_id=user_id)
|
||||
await self._build(user_id=user_id)
|
||||
return self._built_object
|
||||
|
||||
|
||||
class DocumentLoaderVertex(Vertex):
|
||||
def __init__(self, data: Dict, params: Optional[Dict] = None):
|
||||
super().__init__(data, base_type="documentloaders", params=params)
|
||||
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
|
||||
super().__init__(data, graph=graph, base_type="documentloaders", params=params)
|
||||
|
||||
def _built_object_repr(self):
|
||||
# This built_object is a list of documents. Maybe we should
|
||||
# show how many documents are in the list?
|
||||
|
||||
if self._built_object:
|
||||
avg_length = sum(
|
||||
len(doc.page_content)
|
||||
for doc in self._built_object
|
||||
if hasattr(doc, "page_content")
|
||||
) / len(self._built_object)
|
||||
if self._built_object and not isinstance(self._built_object, UnbuiltObject):
|
||||
avg_length = sum(len(doc.page_content) for doc in self._built_object if hasattr(doc, "page_content")) / len(
|
||||
self._built_object
|
||||
)
|
||||
return f"""{self.vertex_type}({len(self._built_object)} documents)
|
||||
\nAvg. Document Length (characters): {int(avg_length)}
|
||||
Documents: {self._built_object[:3]}..."""
|
||||
|
|
@ -119,28 +122,19 @@ class DocumentLoaderVertex(Vertex):
|
|||
|
||||
|
||||
class EmbeddingVertex(Vertex):
|
||||
def __init__(self, data: Dict, params: Optional[Dict] = None):
|
||||
super().__init__(data, base_type="embeddings", params=params)
|
||||
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
|
||||
super().__init__(data, graph=graph, base_type="embeddings", params=params)
|
||||
|
||||
|
||||
class VectorStoreVertex(Vertex):
|
||||
def __init__(self, data: Dict, params=None):
|
||||
super().__init__(data, base_type="vectorstores")
|
||||
def __init__(self, data: Dict, graph, params=None):
|
||||
super().__init__(data, graph=graph, base_type="vectorstores")
|
||||
|
||||
self.params = params or {}
|
||||
|
||||
# VectorStores may contain databse connections
|
||||
# so we need to define the __reduce__ method and the __setstate__ method
|
||||
# to avoid pickling errors
|
||||
def clean_edges_for_pickling(self):
|
||||
# for each edge that has self as source
|
||||
# we need to clear the _built_object of the target
|
||||
# so that we don't try to pickle a database connection
|
||||
for edge in self.edges:
|
||||
if edge.source == self:
|
||||
edge.target._built_object = None
|
||||
edge.target._built = False
|
||||
edge.target.params[edge.target_param] = self
|
||||
|
||||
def remove_docs_and_texts_from_params(self):
|
||||
# remove documents and texts from params
|
||||
|
|
@ -148,17 +142,16 @@ class VectorStoreVertex(Vertex):
|
|||
self.params.pop("documents", None)
|
||||
self.params.pop("texts", None)
|
||||
|
||||
def __getstate__(self):
|
||||
# We want to save the params attribute
|
||||
# and if "documents" or "texts" are in the params
|
||||
# we want to remove them because they have already
|
||||
# been processed.
|
||||
params = self.params.copy()
|
||||
params.pop("documents", None)
|
||||
params.pop("texts", None)
|
||||
self.clean_edges_for_pickling()
|
||||
# def __getstate__(self):
|
||||
# # We want to save the params attribute
|
||||
# # and if "documents" or "texts" are in the params
|
||||
# # we want to remove them because they have already
|
||||
# # been processed.
|
||||
# params = self.params.copy()
|
||||
# params.pop("documents", None)
|
||||
# params.pop("texts", None)
|
||||
|
||||
return super().__getstate__()
|
||||
# return super().__getstate__()
|
||||
|
||||
def __setstate__(self, state):
|
||||
super().__setstate__(state)
|
||||
|
|
@ -166,27 +159,25 @@ class VectorStoreVertex(Vertex):
|
|||
|
||||
|
||||
class MemoryVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="memory")
|
||||
def __init__(self, data: Dict, graph):
|
||||
super().__init__(data, graph=graph, base_type="memory")
|
||||
|
||||
|
||||
class RetrieverVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="retrievers")
|
||||
def __init__(self, data: Dict, graph):
|
||||
super().__init__(data, graph=graph, base_type="retrievers")
|
||||
|
||||
|
||||
class TextSplitterVertex(Vertex):
|
||||
def __init__(self, data: Dict, params: Optional[Dict] = None):
|
||||
super().__init__(data, base_type="textsplitters", params=params)
|
||||
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
|
||||
super().__init__(data, graph=graph, base_type="textsplitters", params=params)
|
||||
|
||||
def _built_object_repr(self):
|
||||
# This built_object is a list of documents. Maybe we should
|
||||
# show how many documents are in the list?
|
||||
|
||||
if self._built_object:
|
||||
avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(
|
||||
self._built_object
|
||||
)
|
||||
if self._built_object and not isinstance(self._built_object, UnbuiltObject):
|
||||
avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(self._built_object)
|
||||
return f"""{self.vertex_type}({len(self._built_object)} documents)
|
||||
\nAvg. Document Length (characters): {int(avg_length)}
|
||||
\nDocuments: {self._built_object[:3]}..."""
|
||||
|
|
@ -194,10 +185,10 @@ class TextSplitterVertex(Vertex):
|
|||
|
||||
|
||||
class ChainVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="chains")
|
||||
def __init__(self, data: Dict, graph):
|
||||
super().__init__(data, graph=graph, base_type="chains")
|
||||
|
||||
def build(
|
||||
async def build(
|
||||
self,
|
||||
force: bool = False,
|
||||
user_id=None,
|
||||
|
|
@ -205,6 +196,8 @@ class ChainVertex(Vertex):
|
|||
**kwargs,
|
||||
) -> Any:
|
||||
if not self._built or force:
|
||||
# Temporarily remove the code from the params
|
||||
self.params.pop("code", None)
|
||||
# Check if the chain requires a PromptVertex
|
||||
|
||||
# Temporarily remove the code from the params
|
||||
|
|
@ -214,18 +207,18 @@ class ChainVertex(Vertex):
|
|||
if isinstance(value, PromptVertex):
|
||||
# Build the PromptVertex, passing the tools if available
|
||||
tools = kwargs.get("tools", None)
|
||||
self.params[key] = value.build(tools=tools, force=force)
|
||||
self.params[key] = await value.build(tools=tools, force=force)
|
||||
|
||||
self._build(user_id=user_id)
|
||||
await self._build(user_id=user_id)
|
||||
|
||||
return self._built_object
|
||||
|
||||
|
||||
class PromptVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="prompts")
|
||||
def __init__(self, data: Dict, graph):
|
||||
super().__init__(data, graph=graph, base_type="prompts")
|
||||
|
||||
def build(
|
||||
async def build(
|
||||
self,
|
||||
force: bool = False,
|
||||
user_id=None,
|
||||
|
|
@ -234,27 +227,18 @@ class PromptVertex(Vertex):
|
|||
**kwargs,
|
||||
) -> Any:
|
||||
if not self._built or force:
|
||||
if (
|
||||
"input_variables" not in self.params
|
||||
or self.params["input_variables"] is None
|
||||
):
|
||||
if "input_variables" not in self.params or self.params["input_variables"] is None:
|
||||
self.params["input_variables"] = []
|
||||
# Check if it is a ZeroShotPrompt and needs a tool
|
||||
if "ShotPrompt" in self.vertex_type:
|
||||
tools = (
|
||||
[tool_node.build(user_id=user_id) for tool_node in tools]
|
||||
if tools is not None
|
||||
else []
|
||||
)
|
||||
tools = [await tool_node.build(user_id=user_id) for tool_node in tools] if tools is not None else []
|
||||
# flatten the list of tools if it is a list of lists
|
||||
# first check if it is a list
|
||||
if tools and isinstance(tools, list) and isinstance(tools[0], list):
|
||||
tools = flatten_list(tools)
|
||||
self.params["tools"] = tools
|
||||
prompt_params = [
|
||||
key
|
||||
for key, value in self.params.items()
|
||||
if isinstance(value, str) and key != "format_instructions"
|
||||
key for key, value in self.params.items() if isinstance(value, str) and key != "format_instructions"
|
||||
]
|
||||
else:
|
||||
prompt_params = ["template"]
|
||||
|
|
@ -264,21 +248,15 @@ class PromptVertex(Vertex):
|
|||
prompt_text = self.params[param]
|
||||
variables = extract_input_variables_from_prompt(prompt_text)
|
||||
self.params["input_variables"].extend(variables)
|
||||
self.params["input_variables"] = list(
|
||||
set(self.params["input_variables"])
|
||||
)
|
||||
self.params["input_variables"] = list(set(self.params["input_variables"]))
|
||||
elif isinstance(self.params, dict):
|
||||
self.params.pop("input_variables", None)
|
||||
|
||||
self._build(user_id=user_id)
|
||||
await self._build(user_id=user_id)
|
||||
return self._built_object
|
||||
|
||||
def _built_object_repr(self):
|
||||
if (
|
||||
not self.artifacts
|
||||
or self._built_object is None
|
||||
or not hasattr(self._built_object, "format")
|
||||
):
|
||||
if not self.artifacts or self._built_object is None or not hasattr(self._built_object, "format"):
|
||||
return super()._built_object_repr()
|
||||
# We'll build the prompt with the artifacts
|
||||
# to show the user what the prompt looks like
|
||||
|
|
@ -288,33 +266,31 @@ class PromptVertex(Vertex):
|
|||
# so the prompt format doesn't break
|
||||
artifacts.pop("handle_keys", None)
|
||||
try:
|
||||
if not hasattr(self._built_object, "template") and hasattr(
|
||||
self._built_object, "prompt"
|
||||
if (
|
||||
not hasattr(self._built_object, "template")
|
||||
and hasattr(self._built_object, "prompt")
|
||||
and not isinstance(self._built_object, UnbuiltObject)
|
||||
):
|
||||
template = self._built_object.prompt.template
|
||||
else:
|
||||
elif not isinstance(self._built_object, UnbuiltObject) and hasattr(self._built_object, "template"):
|
||||
template = self._built_object.template
|
||||
for key, value in artifacts.items():
|
||||
if value:
|
||||
replace_key = "{" + key + "}"
|
||||
template = template.replace(replace_key, value)
|
||||
return (
|
||||
template
|
||||
if isinstance(template, str)
|
||||
else f"{self.vertex_type}({template})"
|
||||
)
|
||||
return template if isinstance(template, str) else f"{self.vertex_type}({template})"
|
||||
except KeyError:
|
||||
return str(self._built_object)
|
||||
|
||||
|
||||
class OutputParserVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="output_parsers")
|
||||
def __init__(self, data: Dict, graph):
|
||||
super().__init__(data, graph=graph, base_type="output_parsers")
|
||||
|
||||
|
||||
class CustomComponentVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="custom_components", is_task=True)
|
||||
def __init__(self, data: Dict, graph):
|
||||
super().__init__(data, graph=graph, base_type="custom_components", is_task=False)
|
||||
|
||||
def _built_object_repr(self):
|
||||
if self.task_id and self.is_task:
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
from typing import Dict, List, Optional
|
||||
from typing import ClassVar, Dict, List, Optional
|
||||
|
||||
from langchain.agents import types
|
||||
|
||||
from langflow.custom.customs import get_custom_nodes
|
||||
from langflow.interface.agents.custom import CUSTOM_AGENTS
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.services.getters import get_settings_service
|
||||
from langflow.services.deps import get_settings_service
|
||||
|
||||
from langflow.template.frontend_node.agents import AgentFrontendNode
|
||||
from loguru import logger
|
||||
|
|
@ -15,7 +15,7 @@ from langflow.utils.util import build_template_from_class, build_template_from_m
|
|||
class AgentCreator(LangChainTypeCreator):
|
||||
type_name: str = "agents"
|
||||
|
||||
from_method_nodes = {"ZeroShotAgent": "from_llm_and_tools"}
|
||||
from_method_nodes: ClassVar[Dict] = {"ZeroShotAgent": "from_llm_and_tools"}
|
||||
|
||||
@property
|
||||
def frontend_node_class(self) -> type[AgentFrontendNode]:
|
||||
|
|
@ -42,9 +42,7 @@ class AgentCreator(LangChainTypeCreator):
|
|||
add_function=True,
|
||||
method_name=self.from_method_nodes[name],
|
||||
)
|
||||
return build_template_from_class(
|
||||
name, self.type_to_loader_dict, add_function=True
|
||||
)
|
||||
return build_template_from_class(name, self.type_to_loader_dict, add_function=True)
|
||||
except ValueError as exc:
|
||||
raise ValueError("Agent not found") from exc
|
||||
except AttributeError as exc:
|
||||
|
|
@ -56,15 +54,8 @@ class AgentCreator(LangChainTypeCreator):
|
|||
names = []
|
||||
settings_service = get_settings_service()
|
||||
for _, agent in self.type_to_loader_dict.items():
|
||||
agent_name = (
|
||||
agent.function_name()
|
||||
if hasattr(agent, "function_name")
|
||||
else agent.__name__
|
||||
)
|
||||
if (
|
||||
agent_name in settings_service.settings.AGENTS
|
||||
or settings_service.settings.DEV
|
||||
):
|
||||
agent_name = agent.function_name() if hasattr(agent, "function_name") else agent.__name__
|
||||
if agent_name in settings_service.settings.AGENTS or settings_service.settings.DEV:
|
||||
names.append(agent_name)
|
||||
return names
|
||||
|
||||
|
|
|
|||
|
|
@ -1,38 +1,31 @@
|
|||
from typing import Any, List, Optional
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.agents import (
|
||||
AgentExecutor,
|
||||
Tool,
|
||||
ZeroShotAgent,
|
||||
initialize_agent,
|
||||
AgentType,
|
||||
)
|
||||
from langchain.agents.agent_toolkits import (
|
||||
SQLDatabaseToolkit,
|
||||
VectorStoreInfo,
|
||||
VectorStoreRouterToolkit,
|
||||
VectorStoreToolkit,
|
||||
)
|
||||
from langchain.agents.agent_toolkits.json.prompt import JSON_PREFIX, JSON_SUFFIX
|
||||
from langchain.agents import (AgentExecutor, AgentType, Tool, ZeroShotAgent,
|
||||
initialize_agent)
|
||||
from langchain.agents.agent_toolkits import (SQLDatabaseToolkit,
|
||||
VectorStoreInfo,
|
||||
VectorStoreRouterToolkit,
|
||||
VectorStoreToolkit)
|
||||
from langchain.agents.agent_toolkits.json.prompt import (JSON_PREFIX,
|
||||
JSON_SUFFIX)
|
||||
from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit
|
||||
from langchain.agents.agent_toolkits.pandas.prompt import PREFIX as PANDAS_PREFIX
|
||||
from langchain.agents.agent_toolkits.pandas.prompt import (
|
||||
SUFFIX_WITH_DF as PANDAS_SUFFIX,
|
||||
)
|
||||
from langchain.agents.agent_toolkits.sql.prompt import SQL_PREFIX, SQL_SUFFIX
|
||||
from langchain.agents.agent_toolkits.vectorstore.prompt import (
|
||||
PREFIX as VECTORSTORE_PREFIX,
|
||||
)
|
||||
from langchain.agents.agent_toolkits.vectorstore.prompt import (
|
||||
ROUTER_PREFIX as VECTORSTORE_ROUTER_PREFIX,
|
||||
)
|
||||
from langchain.agents.agent_toolkits.vectorstore.prompt import \
|
||||
PREFIX as VECTORSTORE_PREFIX
|
||||
from langchain.agents.agent_toolkits.vectorstore.prompt import \
|
||||
ROUTER_PREFIX as VECTORSTORE_ROUTER_PREFIX
|
||||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.sql_database import SQLDatabase
|
||||
from langchain.tools.python.tool import PythonAstREPLTool
|
||||
from langchain.tools.sql_database.prompt import QUERY_CHECKER
|
||||
from langchain_experimental.agents.agent_toolkits.pandas.prompt import \
|
||||
PREFIX as PANDAS_PREFIX
|
||||
from langchain_experimental.agents.agent_toolkits.pandas.prompt import \
|
||||
SUFFIX_WITH_DF as PANDAS_SUFFIX
|
||||
from langchain_experimental.tools.python.tool import PythonAstREPLTool
|
||||
|
||||
from langflow.interface.base import CustomAgentExecutor
|
||||
|
||||
|
||||
|
|
@ -53,7 +46,7 @@ class JsonAgent(CustomAgentExecutor):
|
|||
@classmethod
|
||||
def from_toolkit_and_llm(cls, toolkit: JsonToolkit, llm: BaseLanguageModel):
|
||||
tools = toolkit if isinstance(toolkit, list) else toolkit.get_tools()
|
||||
tool_names = {tool.name for tool in tools}
|
||||
tool_names = list({tool.name for tool in tools})
|
||||
prompt = ZeroShotAgent.create_prompt(
|
||||
tools,
|
||||
prefix=JSON_PREFIX,
|
||||
|
|
@ -66,7 +59,8 @@ class JsonAgent(CustomAgentExecutor):
|
|||
prompt=prompt,
|
||||
)
|
||||
agent = ZeroShotAgent(
|
||||
llm_chain=llm_chain, allowed_tools=tool_names # type: ignore
|
||||
llm_chain=llm_chain,
|
||||
allowed_tools=tool_names, # type: ignore
|
||||
)
|
||||
return cls.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
|
||||
|
||||
|
|
@ -90,11 +84,7 @@ class CSVAgent(CustomAgentExecutor):
|
|||
|
||||
@classmethod
|
||||
def from_toolkit_and_llm(
|
||||
cls,
|
||||
path: str,
|
||||
llm: BaseLanguageModel,
|
||||
pandas_kwargs: Optional[dict] = None,
|
||||
**kwargs: Any
|
||||
cls, path: str, llm: BaseLanguageModel, pandas_kwargs: Optional[dict] = None, **kwargs: Any
|
||||
):
|
||||
import pandas as pd # type: ignore
|
||||
|
||||
|
|
@ -106,16 +96,18 @@ class CSVAgent(CustomAgentExecutor):
|
|||
tools,
|
||||
prefix=PANDAS_PREFIX,
|
||||
suffix=PANDAS_SUFFIX,
|
||||
input_variables=["df", "input", "agent_scratchpad"],
|
||||
input_variables=["df_head", "input", "agent_scratchpad"],
|
||||
)
|
||||
partial_prompt = prompt.partial(df=str(df.head()))
|
||||
partial_prompt = prompt.partial(df_head=str(df.head()))
|
||||
llm_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=partial_prompt,
|
||||
)
|
||||
tool_names = {tool.name for tool in tools}
|
||||
tool_names = list({tool.name for tool in tools})
|
||||
agent = ZeroShotAgent(
|
||||
llm_chain=llm_chain, allowed_tools=tool_names, **kwargs # type: ignore
|
||||
llm_chain=llm_chain,
|
||||
allowed_tools=tool_names,
|
||||
**kwargs, # type: ignore
|
||||
)
|
||||
|
||||
return cls.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
|
||||
|
|
@ -139,9 +131,7 @@ class VectorStoreAgent(CustomAgentExecutor):
|
|||
super().__init__(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_toolkit_and_llm(
|
||||
cls, llm: BaseLanguageModel, vectorstoreinfo: VectorStoreInfo, **kwargs: Any
|
||||
):
|
||||
def from_toolkit_and_llm(cls, llm: BaseLanguageModel, vectorstoreinfo: VectorStoreInfo, **kwargs: Any):
|
||||
"""Construct a vectorstore agent from an LLM and tools."""
|
||||
|
||||
toolkit = VectorStoreToolkit(vectorstore_info=vectorstoreinfo, llm=llm)
|
||||
|
|
@ -152,13 +142,13 @@ class VectorStoreAgent(CustomAgentExecutor):
|
|||
llm=llm,
|
||||
prompt=prompt,
|
||||
)
|
||||
tool_names = {tool.name for tool in tools}
|
||||
tool_names = list({tool.name for tool in tools})
|
||||
agent = ZeroShotAgent(
|
||||
llm_chain=llm_chain, allowed_tools=tool_names, **kwargs # type: ignore
|
||||
)
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent, tools=tools, verbose=True, handle_parsing_errors=True
|
||||
llm_chain=llm_chain,
|
||||
allowed_tools=tool_names,
|
||||
**kwargs, # type: ignore
|
||||
)
|
||||
return AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
return super().run(*args, **kwargs)
|
||||
|
|
@ -179,9 +169,7 @@ class SQLAgent(CustomAgentExecutor):
|
|||
super().__init__(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_toolkit_and_llm(
|
||||
cls, llm: BaseLanguageModel, database_uri: str, **kwargs: Any
|
||||
):
|
||||
def from_toolkit_and_llm(cls, llm: BaseLanguageModel, database_uri: str, **kwargs: Any):
|
||||
"""Construct an SQL agent from an LLM and tools."""
|
||||
db = SQLDatabase.from_uri(database_uri)
|
||||
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
|
||||
|
|
@ -190,18 +178,14 @@ class SQLAgent(CustomAgentExecutor):
|
|||
# related to `OPENAI_API_KEY`
|
||||
# return create_sql_agent(llm=llm, toolkit=toolkit, verbose=True)
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.tools.sql_database.tool import (
|
||||
InfoSQLDatabaseTool,
|
||||
ListSQLDatabaseTool,
|
||||
QuerySQLCheckerTool,
|
||||
QuerySQLDataBaseTool,
|
||||
)
|
||||
from langchain.tools.sql_database.tool import (InfoSQLDatabaseTool,
|
||||
ListSQLDatabaseTool,
|
||||
QuerySQLCheckerTool,
|
||||
QuerySQLDataBaseTool)
|
||||
|
||||
llmchain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=PromptTemplate(
|
||||
template=QUERY_CHECKER, input_variables=["query", "dialect"]
|
||||
),
|
||||
prompt=PromptTemplate(template=QUERY_CHECKER, input_variables=["query", "dialect"]),
|
||||
)
|
||||
|
||||
tools = [
|
||||
|
|
@ -222,9 +206,11 @@ class SQLAgent(CustomAgentExecutor):
|
|||
llm=llm,
|
||||
prompt=prompt,
|
||||
)
|
||||
tool_names = {tool.name for tool in tools} # type: ignore
|
||||
tool_names = list({tool.name for tool in tools}) # type: ignore
|
||||
agent = ZeroShotAgent(
|
||||
llm_chain=llm_chain, allowed_tools=tool_names, **kwargs # type: ignore
|
||||
llm_chain=llm_chain,
|
||||
allowed_tools=tool_names,
|
||||
**kwargs, # type: ignore
|
||||
)
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent,
|
||||
|
|
@ -255,10 +241,7 @@ class VectorStoreRouterAgent(CustomAgentExecutor):
|
|||
|
||||
@classmethod
|
||||
def from_toolkit_and_llm(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
vectorstoreroutertoolkit: VectorStoreRouterToolkit,
|
||||
**kwargs: Any
|
||||
cls, llm: BaseLanguageModel, vectorstoreroutertoolkit: VectorStoreRouterToolkit, **kwargs: Any
|
||||
):
|
||||
"""Construct a vector store router agent from an LLM and tools."""
|
||||
|
||||
|
|
@ -272,13 +255,13 @@ class VectorStoreRouterAgent(CustomAgentExecutor):
|
|||
llm=llm,
|
||||
prompt=prompt,
|
||||
)
|
||||
tool_names = {tool.name for tool in tools}
|
||||
tool_names = list({tool.name for tool in tools})
|
||||
agent = ZeroShotAgent(
|
||||
llm_chain=llm_chain, allowed_tools=tool_names, **kwargs # type: ignore
|
||||
)
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent, tools=tools, verbose=True, handle_parsing_errors=True
|
||||
llm_chain=llm_chain,
|
||||
allowed_tools=tool_names,
|
||||
**kwargs, # type: ignore
|
||||
)
|
||||
return AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
return super().run(*args, **kwargs)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
|
|||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.agents import AgentExecutor
|
||||
from langflow.services.getters import get_settings_service
|
||||
from langflow.services.deps import get_settings_service
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langflow.template.field.base import TemplateField
|
||||
|
|
@ -30,13 +30,8 @@ class LangChainTypeCreator(BaseModel, ABC):
|
|||
settings_service = get_settings_service()
|
||||
if self.name_docs_dict is None:
|
||||
try:
|
||||
type_settings = getattr(
|
||||
settings_service.settings, self.type_name.upper()
|
||||
)
|
||||
self.name_docs_dict = {
|
||||
name: value_dict["documentation"]
|
||||
for name, value_dict in type_settings.items()
|
||||
}
|
||||
type_settings = getattr(settings_service.settings, self.type_name.upper())
|
||||
self.name_docs_dict = {name: value_dict["documentation"] for name, value_dict in type_settings.items()}
|
||||
except AttributeError as exc:
|
||||
logger.error(f"Error getting settings for {self.type_name}: {exc}")
|
||||
|
||||
|
|
@ -88,7 +83,6 @@ class LangChainTypeCreator(BaseModel, ABC):
|
|||
show=value.get("show", True),
|
||||
multiline=value.get("multiline", False),
|
||||
value=value.get("value", None),
|
||||
suffixes=value.get("suffixes", []),
|
||||
file_types=value.get("fileTypes", []),
|
||||
file_path=value.get("file_path", None),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,15 +1,15 @@
|
|||
from typing import Any, Dict, List, Optional, Type
|
||||
from typing import Any, ClassVar, Dict, List, Optional, Type
|
||||
|
||||
from langflow.custom.customs import get_custom_nodes
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.importing.utils import import_class
|
||||
from langflow.services.getters import get_settings_service
|
||||
from langflow.services.deps import get_settings_service
|
||||
|
||||
from langflow.template.frontend_node.chains import ChainFrontendNode
|
||||
from loguru import logger
|
||||
from langflow.utils.util import build_template_from_class, build_template_from_method
|
||||
from langchain import chains
|
||||
from langchain_experimental.sql import SQLDatabaseChain # type: ignore
|
||||
from langchain_experimental.sql import SQLDatabaseChain
|
||||
|
||||
# Assuming necessary imports for Field, Template, and FrontendNode classes
|
||||
|
||||
|
|
@ -22,7 +22,7 @@ class ChainCreator(LangChainTypeCreator):
|
|||
return ChainFrontendNode
|
||||
|
||||
#! We need to find a better solution for this
|
||||
from_method_nodes = {
|
||||
from_method_nodes: ClassVar[Dict] = {
|
||||
"ConversationalRetrievalChain": "from_llm",
|
||||
"LLMCheckerChain": "from_llm",
|
||||
"SQLDatabaseChain": "from_llm",
|
||||
|
|
@ -33,8 +33,7 @@ class ChainCreator(LangChainTypeCreator):
|
|||
if self.type_dict is None:
|
||||
settings_service = get_settings_service()
|
||||
self.type_dict: dict[str, Any] = {
|
||||
chain_name: import_class(f"langchain.chains.{chain_name}")
|
||||
for chain_name in chains.__all__
|
||||
chain_name: import_class(f"langchain.chains.{chain_name}") for chain_name in chains.__all__
|
||||
}
|
||||
from langflow.interface.chains.custom import CUSTOM_CHAINS
|
||||
|
||||
|
|
@ -45,8 +44,7 @@ class ChainCreator(LangChainTypeCreator):
|
|||
self.type_dict = {
|
||||
name: chain
|
||||
for name, chain in self.type_dict.items()
|
||||
if name in settings_service.settings.CHAINS
|
||||
or settings_service.settings.DEV
|
||||
if name in settings_service.settings.CHAINS or settings_service.settings.DEV
|
||||
}
|
||||
return self.type_dict
|
||||
|
||||
|
|
@ -61,9 +59,7 @@ class ChainCreator(LangChainTypeCreator):
|
|||
method_name=self.from_method_nodes[name],
|
||||
add_function=True,
|
||||
)
|
||||
return build_template_from_class(
|
||||
name, self.type_to_loader_dict, add_function=True
|
||||
)
|
||||
return build_template_from_class(name, self.type_to_loader_dict, add_function=True)
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"Chain {name} not found: {exc}") from exc
|
||||
except AttributeError as exc:
|
||||
|
|
@ -73,11 +69,7 @@ class ChainCreator(LangChainTypeCreator):
|
|||
def to_list(self) -> List[str]:
|
||||
names = []
|
||||
for _, chain in self.type_to_loader_dict.items():
|
||||
chain_name = (
|
||||
chain.function_name()
|
||||
if hasattr(chain, "function_name")
|
||||
else chain.__name__
|
||||
)
|
||||
chain_name = chain.function_name() if hasattr(chain, "function_name") else chain.__name__
|
||||
names.append(chain_name)
|
||||
return names
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from langchain.chains import ConversationChain
|
|||
from langchain.memory.buffer import ConversationBufferMemory
|
||||
from langchain.schema import BaseMemory
|
||||
from langflow.interface.base import CustomChain
|
||||
from pydantic import Field, root_validator
|
||||
from pydantic.v1 import Field, root_validator
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
from langflow.interface.utils import extract_input_variables_from_prompt
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
|
|
@ -41,9 +41,7 @@ class BaseCustomConversationChain(ConversationChain):
|
|||
values["template"] = values["template"].format(**format_dict)
|
||||
|
||||
values["template"] = values["template"]
|
||||
values["input_variables"] = extract_input_variables_from_prompt(
|
||||
values["template"]
|
||||
)
|
||||
values["input_variables"] = extract_input_variables_from_prompt(values["template"])
|
||||
values["prompt"].template = values["template"]
|
||||
values["prompt"].input_variables = values["input_variables"]
|
||||
return values
|
||||
|
|
@ -54,9 +52,7 @@ class SeriesCharacterChain(BaseCustomConversationChain):
|
|||
|
||||
character: str
|
||||
series: str
|
||||
template: Optional[
|
||||
str
|
||||
] = """I want you to act like {character} from {series}.
|
||||
template: Optional[str] = """I want you to act like {character} from {series}.
|
||||
I want you to respond and answer like {character}. do not write any explanations. only answer like {character}.
|
||||
You must know all of the knowledge of {character}.
|
||||
Current conversation:
|
||||
|
|
@ -71,9 +67,7 @@ Human: {input}
|
|||
class MidJourneyPromptChain(BaseCustomConversationChain):
|
||||
"""MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts."""
|
||||
|
||||
template: Optional[
|
||||
str
|
||||
] = """I want you to act as a prompt generator for Midjourney's artificial intelligence program.
|
||||
template: Optional[str] = """I want you to act as a prompt generator for Midjourney's artificial intelligence program.
|
||||
Your job is to provide detailed and creative descriptions that will inspire unique and interesting images from the AI.
|
||||
Keep in mind that the AI is capable of understanding a wide range of language and can interpret abstract concepts, so feel free to be as imaginative and descriptive as possible.
|
||||
For example, you could describe a scene from a futuristic city, or a surreal landscape filled with strange creatures.
|
||||
|
|
@ -87,9 +81,7 @@ class MidJourneyPromptChain(BaseCustomConversationChain):
|
|||
|
||||
|
||||
class TimeTravelGuideChain(BaseCustomConversationChain):
|
||||
template: Optional[
|
||||
str
|
||||
] = """I want you to act as my time travel guide. You are helpful and creative. I will provide you with the historical period or future time I want to visit and you will suggest the best events, sights, or people to experience. Provide the suggestions and any necessary information.
|
||||
template: Optional[str] = """I want you to act as my time travel guide. You are helpful and creative. I will provide you with the historical period or future time I want to visit and you will suggest the best events, sights, or people to experience. Provide the suggestions and any necessary information.
|
||||
Current conversation:
|
||||
{history}
|
||||
Human: {input}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
import ast
|
||||
import inspect
|
||||
import operator
|
||||
import traceback
|
||||
from typing import Any, Dict, List, Type, Union
|
||||
|
||||
from typing import Dict, Any, List, Type, Union
|
||||
from cachetools import TTLCache, cachedmethod, keys
|
||||
from fastapi import HTTPException
|
||||
|
||||
from langflow.interface.custom.schema import CallableCodeDetails, ClassCodeDetails
|
||||
|
||||
|
||||
|
|
@ -11,6 +14,19 @@ class CodeSyntaxError(HTTPException):
|
|||
pass
|
||||
|
||||
|
||||
def get_data_type():
|
||||
from langflow.field_typing import Data
|
||||
|
||||
return Data
|
||||
|
||||
|
||||
def imports_key(*args, **kwargs):
|
||||
imports = kwargs.pop("imports")
|
||||
key = keys.methodkey(*args, **kwargs)
|
||||
key += tuple(imports)
|
||||
return key
|
||||
|
||||
|
||||
class CodeParser:
|
||||
"""
|
||||
A parser for Python source code, extracting code details.
|
||||
|
|
@ -20,6 +36,7 @@ class CodeParser:
|
|||
"""
|
||||
Initializes the parser with the provided code.
|
||||
"""
|
||||
self.cache: TTLCache = TTLCache(maxsize=1024, ttl=60)
|
||||
if isinstance(code, type):
|
||||
if not inspect.isclass(code):
|
||||
raise ValueError("The provided code must be a class.")
|
||||
|
|
@ -65,14 +82,20 @@ class CodeParser:
|
|||
|
||||
def parse_imports(self, node: Union[ast.Import, ast.ImportFrom]) -> None:
|
||||
"""
|
||||
Extracts "imports" from the code.
|
||||
Extracts "imports" from the code, including aliases.
|
||||
"""
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
self.data["imports"].append(alias.name)
|
||||
if alias.asname:
|
||||
self.data["imports"].append(f"{alias.name} as {alias.asname}")
|
||||
else:
|
||||
self.data["imports"].append(alias.name)
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
for alias in node.names:
|
||||
self.data["imports"].append((node.module, alias.name))
|
||||
if alias.asname:
|
||||
self.data["imports"].append((node.module, f"{alias.name} as {alias.asname}"))
|
||||
else:
|
||||
self.data["imports"].append((node.module, alias.name))
|
||||
|
||||
def parse_functions(self, node: ast.FunctionDef) -> None:
|
||||
"""
|
||||
|
|
@ -89,22 +112,54 @@ class CodeParser:
|
|||
arg_dict["type"] = ast.unparse(arg.annotation)
|
||||
return arg_dict
|
||||
|
||||
@cachedmethod(operator.attrgetter("cache"))
|
||||
def construct_eval_env(self, return_type_str: str, imports) -> dict:
|
||||
"""
|
||||
Constructs an evaluation environment with the necessary imports for the return type,
|
||||
taking into account module aliases.
|
||||
"""
|
||||
eval_env: dict = {}
|
||||
for import_entry in imports:
|
||||
if isinstance(import_entry, tuple): # from module import name
|
||||
module, name = import_entry
|
||||
if name in return_type_str:
|
||||
exec(f"import {module}", eval_env)
|
||||
exec(f"from {module} import {name}", eval_env)
|
||||
else: # import module
|
||||
module = import_entry
|
||||
alias = None
|
||||
if " as " in module:
|
||||
module, alias = module.split(" as ")
|
||||
if module in return_type_str or (alias and alias in return_type_str):
|
||||
exec(f"import {module} as {alias if alias else module}", eval_env)
|
||||
return eval_env
|
||||
|
||||
@cachedmethod(cache=operator.attrgetter("cache"))
|
||||
def parse_callable_details(self, node: ast.FunctionDef) -> Dict[str, Any]:
|
||||
"""
|
||||
Extracts details from a single function or method node.
|
||||
"""
|
||||
return_type = None
|
||||
if node.returns:
|
||||
return_type_str = ast.unparse(node.returns)
|
||||
eval_env = self.construct_eval_env(return_type_str, tuple(self.data["imports"]))
|
||||
|
||||
try:
|
||||
return_type = eval(return_type_str, eval_env)
|
||||
except NameError:
|
||||
# Handle cases where the type is not found in the constructed environment
|
||||
pass
|
||||
|
||||
func = CallableCodeDetails(
|
||||
name=node.name,
|
||||
doc=ast.get_docstring(node),
|
||||
args=[],
|
||||
body=[],
|
||||
return_type=ast.unparse(node.returns) if node.returns else None,
|
||||
args=self.parse_function_args(node),
|
||||
body=self.parse_function_body(node),
|
||||
return_type=return_type or get_data_type(),
|
||||
has_return=self.parse_return_statement(node),
|
||||
)
|
||||
|
||||
func.args = self.parse_function_args(node)
|
||||
func.body = self.parse_function_body(node)
|
||||
|
||||
return func.dict()
|
||||
return func.model_dump()
|
||||
|
||||
def parse_function_args(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
|
|
@ -115,7 +170,9 @@ class CodeParser:
|
|||
args += self.parse_positional_args(node)
|
||||
args += self.parse_varargs(node)
|
||||
args += self.parse_keyword_args(node)
|
||||
args += self.parse_kwargs(node)
|
||||
# Commented out because we don't want kwargs
|
||||
# showing up as fields in the frontend
|
||||
# args += self.parse_kwargs(node)
|
||||
|
||||
return args
|
||||
|
||||
|
|
@ -127,22 +184,14 @@ class CodeParser:
|
|||
num_defaults = len(node.args.defaults)
|
||||
num_missing_defaults = num_args - num_defaults
|
||||
missing_defaults = [None] * num_missing_defaults
|
||||
default_values = [
|
||||
ast.unparse(default).strip("'") if default else None
|
||||
for default in node.args.defaults
|
||||
]
|
||||
default_values = [ast.unparse(default).strip("'") if default else None for default in node.args.defaults]
|
||||
# Now check all default values to see if there
|
||||
# are any "None" values in the middle
|
||||
default_values = [
|
||||
None if value == "None" else value for value in default_values
|
||||
]
|
||||
default_values = [None if value == "None" else value for value in default_values]
|
||||
|
||||
defaults = missing_defaults + default_values
|
||||
|
||||
args = [
|
||||
self.parse_arg(arg, default)
|
||||
for arg, default in zip(node.args.args, defaults)
|
||||
]
|
||||
args = [self.parse_arg(arg, default) for arg, default in zip(node.args.args, defaults)]
|
||||
return args
|
||||
|
||||
def parse_varargs(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:
|
||||
|
|
@ -160,17 +209,11 @@ class CodeParser:
|
|||
"""
|
||||
Parses the keyword-only arguments of a function or method node.
|
||||
"""
|
||||
kw_defaults = [None] * (
|
||||
len(node.args.kwonlyargs) - len(node.args.kw_defaults)
|
||||
) + [
|
||||
ast.unparse(default) if default else None
|
||||
for default in node.args.kw_defaults
|
||||
kw_defaults = [None] * (len(node.args.kwonlyargs) - len(node.args.kw_defaults)) + [
|
||||
ast.unparse(default) if default else None for default in node.args.kw_defaults
|
||||
]
|
||||
|
||||
args = [
|
||||
self.parse_arg(arg, default)
|
||||
for arg, default in zip(node.args.kwonlyargs, kw_defaults)
|
||||
]
|
||||
args = [self.parse_arg(arg, default) for arg, default in zip(node.args.kwonlyargs, kw_defaults)]
|
||||
return args
|
||||
|
||||
def parse_kwargs(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:
|
||||
|
|
@ -190,6 +233,13 @@ class CodeParser:
|
|||
"""
|
||||
return [ast.unparse(line) for line in node.body]
|
||||
|
||||
def parse_return_statement(self, node: ast.FunctionDef) -> bool:
|
||||
"""
|
||||
Parses the return statement of a function or method node.
|
||||
"""
|
||||
|
||||
return any(isinstance(n, ast.Return) for n in node.body)
|
||||
|
||||
def parse_assign(self, stmt):
|
||||
"""
|
||||
Parses an Assign statement and returns a dictionary
|
||||
|
|
@ -240,23 +290,21 @@ class CodeParser:
|
|||
elif isinstance(stmt, ast.AnnAssign):
|
||||
if attr := self.parse_ann_assign(stmt):
|
||||
class_details.attributes.append(attr)
|
||||
elif isinstance(stmt, ast.FunctionDef):
|
||||
elif isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
method, is_init = self.parse_function_def(stmt)
|
||||
if is_init:
|
||||
class_details.init = method
|
||||
else:
|
||||
class_details.methods.append(method)
|
||||
|
||||
self.data["classes"].append(class_details.dict())
|
||||
self.data["classes"].append(class_details.model_dump())
|
||||
|
||||
def parse_global_vars(self, node: ast.Assign) -> None:
|
||||
"""
|
||||
Extracts global variables from the code.
|
||||
"""
|
||||
global_var = {
|
||||
"targets": [
|
||||
t.id if hasattr(t, "id") else ast.dump(t) for t in node.targets
|
||||
],
|
||||
"targets": [t.id if hasattr(t, "id") else ast.dump(t) for t in node.targets],
|
||||
"value": ast.unparse(node.value),
|
||||
}
|
||||
self.data["global_vars"].append(global_var)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,13 @@
|
|||
import ast
|
||||
from typing import Any, Optional
|
||||
from pydantic import BaseModel
|
||||
import operator
|
||||
import warnings
|
||||
from typing import Any, ClassVar, Optional
|
||||
|
||||
from cachetools import TTLCache, cachedmethod
|
||||
from fastapi import HTTPException
|
||||
|
||||
from langflow.utils import validate
|
||||
from langflow.interface.custom.code_parser import CodeParser
|
||||
from langflow.utils import validate
|
||||
|
||||
|
||||
class ComponentCodeNullError(HTTPException):
|
||||
|
|
@ -15,19 +18,29 @@ class ComponentFunctionEntrypointNameNullError(HTTPException):
|
|||
pass
|
||||
|
||||
|
||||
class Component(BaseModel):
|
||||
ERROR_CODE_NULL = "Python code must be provided."
|
||||
ERROR_FUNCTION_ENTRYPOINT_NAME_NULL = (
|
||||
"The name of the entrypoint function must be provided."
|
||||
)
|
||||
class Component:
|
||||
ERROR_CODE_NULL: ClassVar[str] = "Python code must be provided."
|
||||
ERROR_FUNCTION_ENTRYPOINT_NAME_NULL: ClassVar[str] = "The name of the entrypoint function must be provided."
|
||||
|
||||
code: Optional[str]
|
||||
function_entrypoint_name = "build"
|
||||
code: Optional[str] = None
|
||||
_function_entrypoint_name: str = "build"
|
||||
field_config: dict = {}
|
||||
_user_id: Optional[str]
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
self.cache = TTLCache(maxsize=1024, ttl=60)
|
||||
for key, value in data.items():
|
||||
if key == "user_id":
|
||||
setattr(self, "_user_id", value)
|
||||
else:
|
||||
setattr(self, key, value)
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if key == "_user_id" and hasattr(self, "_user_id"):
|
||||
warnings.warn("user_id is immutable and cannot be changed.")
|
||||
super().__setattr__(key, value)
|
||||
|
||||
@cachedmethod(cache=operator.attrgetter("cache"))
|
||||
def get_code_tree(self, code: str):
|
||||
parser = CodeParser(code)
|
||||
return parser.parse_code()
|
||||
|
|
@ -39,7 +52,7 @@ class Component(BaseModel):
|
|||
detail={"error": self.ERROR_CODE_NULL, "traceback": ""},
|
||||
)
|
||||
|
||||
if not self.function_entrypoint_name:
|
||||
if not self._function_entrypoint_name:
|
||||
raise ComponentFunctionEntrypointNameNullError(
|
||||
status_code=400,
|
||||
detail={
|
||||
|
|
@ -48,7 +61,7 @@ class Component(BaseModel):
|
|||
},
|
||||
)
|
||||
|
||||
return validate.create_function(self.code, self.function_entrypoint_name)
|
||||
return validate.create_function(self.code, self._function_entrypoint_name)
|
||||
|
||||
def build_template_config(self, attributes) -> dict:
|
||||
template_config = {}
|
||||
|
|
|
|||
|
|
@ -1,34 +1,42 @@
|
|||
from typing import Any, Callable, List, Optional, Union
|
||||
import operator
|
||||
from typing import Any, Callable, ClassVar, List, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
import yaml
|
||||
from cachetools import TTLCache, cachedmethod
|
||||
from fastapi import HTTPException
|
||||
from langflow.interface.custom.constants import CUSTOM_COMPONENT_SUPPORTED_TYPES
|
||||
|
||||
from langflow.interface.custom.component import Component
|
||||
from langflow.interface.custom.directory_reader import DirectoryReader
|
||||
from langflow.services.getters import get_db_service
|
||||
from langflow.interface.custom.utils import extract_inner_type
|
||||
|
||||
from langflow.interface.custom.utils import (
|
||||
extract_inner_type_from_generic_alias,
|
||||
extract_union_types_from_generic_alias)
|
||||
from langflow.services.database.models.flow import Flow
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import get_credential_service, get_db_service
|
||||
from langflow.utils import validate
|
||||
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.database.models.flow import Flow
|
||||
from pydantic import Extra
|
||||
import yaml
|
||||
|
||||
|
||||
class CustomComponent(Component, extra=Extra.allow):
|
||||
code: Optional[str]
|
||||
class CustomComponent(Component):
|
||||
display_name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
code: Optional[str] = None
|
||||
field_config: dict = {}
|
||||
code_class_base_inheritance = "CustomComponent"
|
||||
function_entrypoint_name = "build"
|
||||
code_class_base_inheritance: ClassVar[str] = "CustomComponent"
|
||||
function_entrypoint_name: ClassVar[str] = "build"
|
||||
function: Optional[Callable] = None
|
||||
return_type_valid_list = list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())
|
||||
repr_value: Optional[Any] = ""
|
||||
user_id: Optional[Union[UUID, str]] = None
|
||||
status: Optional[Any] = None
|
||||
_tree: Optional[dict] = None
|
||||
|
||||
def __init__(self, **data):
|
||||
self.cache = TTLCache(maxsize=1024, ttl=60)
|
||||
super().__init__(**data)
|
||||
|
||||
def custom_repr(self):
|
||||
if self.repr_value == "":
|
||||
self.repr_value = self.status
|
||||
if isinstance(self.repr_value, dict):
|
||||
return yaml.dump(self.repr_value)
|
||||
if isinstance(self.repr_value, str):
|
||||
|
|
@ -53,47 +61,28 @@ class CustomComponent(Component, extra=Extra.allow):
|
|||
reader = DirectoryReader("", False)
|
||||
|
||||
for type_hint in TYPE_HINT_LIST:
|
||||
if reader._is_type_hint_used_in_args(
|
||||
if reader._is_type_hint_used_in_args(type_hint, code) and not reader._is_type_hint_imported(
|
||||
type_hint, code
|
||||
) and not reader._is_type_hint_imported(type_hint, code):
|
||||
):
|
||||
error_detail = {
|
||||
"error": "Type hint Error",
|
||||
"traceback": f"Type hint '{type_hint}' is used but not imported in the code.",
|
||||
}
|
||||
raise HTTPException(status_code=400, detail=error_detail)
|
||||
return True
|
||||
|
||||
def is_check_valid(self) -> bool:
|
||||
def validate(self) -> bool:
|
||||
return self._class_template_validation(self.code) if self.code else False
|
||||
|
||||
def get_code_tree(self, code: str):
|
||||
return super().get_code_tree(code)
|
||||
@property
|
||||
def tree(self):
|
||||
return self.get_code_tree(self.code)
|
||||
|
||||
@property
|
||||
def get_function_entrypoint_args(self) -> str:
|
||||
if not self.code:
|
||||
return ""
|
||||
tree = self.get_code_tree(self.code)
|
||||
|
||||
component_classes = [
|
||||
cls
|
||||
for cls in tree["classes"]
|
||||
if self.code_class_base_inheritance in cls["bases"]
|
||||
]
|
||||
if not component_classes:
|
||||
return ""
|
||||
|
||||
# Assume the first Component class is the one we're interested in
|
||||
component_class = component_classes[0]
|
||||
build_methods = [
|
||||
method
|
||||
for method in component_class["methods"]
|
||||
if method["name"] == self.function_entrypoint_name
|
||||
]
|
||||
|
||||
if not build_methods:
|
||||
return ""
|
||||
|
||||
build_method = build_methods[0]
|
||||
def get_function_entrypoint_args(self) -> list:
|
||||
build_method = self.get_build_method()
|
||||
if not build_method:
|
||||
return []
|
||||
|
||||
args = build_method["args"]
|
||||
for arg in args:
|
||||
|
|
@ -103,65 +92,69 @@ class CustomComponent(Component, extra=Extra.allow):
|
|||
detail={
|
||||
"error": "Type hint Error",
|
||||
"traceback": (
|
||||
"Prompt type is not supported in the build method."
|
||||
" Try using PromptTemplate instead."
|
||||
"Prompt type is not supported in the build method." " Try using PromptTemplate instead."
|
||||
),
|
||||
},
|
||||
)
|
||||
elif not arg.get("type") and arg.get("name") != "self":
|
||||
# Set the type to Data
|
||||
arg["type"] = "Data"
|
||||
return args
|
||||
|
||||
@property
|
||||
def get_function_entrypoint_return_type(self) -> List[str]:
|
||||
@cachedmethod(operator.attrgetter("cache"))
|
||||
def get_build_method(self):
|
||||
if not self.code:
|
||||
return []
|
||||
tree = self.get_code_tree(self.code)
|
||||
|
||||
component_classes = [
|
||||
cls
|
||||
for cls in tree["classes"]
|
||||
if self.code_class_base_inheritance in cls["bases"]
|
||||
]
|
||||
component_classes = [cls for cls in self.tree["classes"] if self.code_class_base_inheritance in cls["bases"]]
|
||||
if not component_classes:
|
||||
return []
|
||||
|
||||
# Assume the first Component class is the one we're interested in
|
||||
component_class = component_classes[0]
|
||||
build_methods = [
|
||||
method
|
||||
for method in component_class["methods"]
|
||||
if method["name"] == self.function_entrypoint_name
|
||||
method for method in component_class["methods"] if method["name"] == self.function_entrypoint_name
|
||||
]
|
||||
|
||||
if not build_methods:
|
||||
return []
|
||||
|
||||
build_method = build_methods[0]
|
||||
return_type = build_method["return_type"]
|
||||
if not return_type:
|
||||
return build_methods[0]
|
||||
|
||||
@property
|
||||
def get_function_entrypoint_return_type(self) -> List[Any]:
|
||||
build_method = self.get_build_method()
|
||||
if not build_method:
|
||||
return []
|
||||
elif not build_method["has_return"]:
|
||||
return []
|
||||
|
||||
return_type = build_method["return_type"]
|
||||
|
||||
# If list or List is in the return type, then we remove it and return the inner type
|
||||
if return_type.startswith("list") or return_type.startswith("List"):
|
||||
return_type = extract_inner_type(return_type)
|
||||
if hasattr(return_type, "__origin__") and return_type.__origin__ in [list, List]:
|
||||
return_type = extract_inner_type_from_generic_alias(return_type)
|
||||
|
||||
# If the return type is not a Union, then we just return it as a list
|
||||
if "Union" not in return_type:
|
||||
return [return_type] if return_type in self.return_type_valid_list else []
|
||||
if not hasattr(return_type, "__origin__") or return_type.__origin__ != Union:
|
||||
if isinstance(return_type, list):
|
||||
return return_type
|
||||
return [return_type]
|
||||
|
||||
# If the return type is a Union, then we need to parse it
|
||||
return_type = return_type.replace("Union", "").replace("[", "").replace("]", "")
|
||||
return_type = return_type.split(",")
|
||||
return_type = [item.strip() for item in return_type]
|
||||
return [item for item in return_type if item in self.return_type_valid_list]
|
||||
# If the return type is a Union, then we need to parse itx
|
||||
return_type = extract_union_types_from_generic_alias(return_type)
|
||||
return return_type
|
||||
|
||||
@property
|
||||
def get_main_class_name(self):
|
||||
tree = self.get_code_tree(self.code)
|
||||
if not self.code:
|
||||
return ""
|
||||
|
||||
base_name = self.code_class_base_inheritance
|
||||
method_name = self.function_entrypoint_name
|
||||
|
||||
classes = []
|
||||
for item in tree.get("classes"):
|
||||
for item in self.tree.get("classes", []):
|
||||
if base_name in item["bases"]:
|
||||
method_names = [method["name"] for method in item["methods"]]
|
||||
if method_name in method_names:
|
||||
|
|
@ -171,12 +164,16 @@ class CustomComponent(Component, extra=Extra.allow):
|
|||
return next(iter(classes), "")
|
||||
|
||||
@property
|
||||
def template_config(self):
|
||||
return self.build_template_config()
|
||||
|
||||
def build_template_config(self):
|
||||
tree = self.get_code_tree(self.code)
|
||||
if not self.code:
|
||||
return {}
|
||||
|
||||
attributes = [
|
||||
main_class["attributes"]
|
||||
for main_class in tree.get("classes")
|
||||
for main_class in self.tree.get("classes", [])
|
||||
if main_class["name"] == self.get_main_class_name
|
||||
]
|
||||
# Get just the first item
|
||||
|
|
@ -184,13 +181,44 @@ class CustomComponent(Component, extra=Extra.allow):
|
|||
|
||||
return super().build_template_config(attributes)
|
||||
|
||||
@property
|
||||
def keys(self):
|
||||
def get_credential(name: str):
|
||||
if hasattr(self, "_user_id") and not self._user_id:
|
||||
raise ValueError(f"User id is not set for {self.__class__.__name__}")
|
||||
credential_service = get_credential_service() # Get service instance
|
||||
# Retrieve and decrypt the credential by name for the current user
|
||||
db_service = get_db_service()
|
||||
with session_getter(db_service) as session:
|
||||
return credential_service.get_credential(user_id=self._user_id or "", name=name, session=session)
|
||||
|
||||
return get_credential
|
||||
|
||||
def list_key_names(self):
|
||||
if hasattr(self, "_user_id") and not self._user_id:
|
||||
raise ValueError(f"User id is not set for {self.__class__.__name__}")
|
||||
credential_service = get_credential_service()
|
||||
db_service = get_db_service()
|
||||
with session_getter(db_service) as session:
|
||||
return credential_service.list_credentials(user_id=self._user_id, session=session)
|
||||
|
||||
def index(self, value: int = 0):
|
||||
"""Returns a function that returns the value at the given index in the iterable."""
|
||||
|
||||
def get_index(iterable: List[Any]):
|
||||
if iterable:
|
||||
return iterable[value]
|
||||
return iterable
|
||||
|
||||
return get_index
|
||||
|
||||
@property
|
||||
def get_function(self):
|
||||
return validate.create_function(self.code, self.function_entrypoint_name)
|
||||
|
||||
def load_flow(self, flow_id: str, tweaks: Optional[dict] = None) -> Any:
|
||||
from langflow.processing.process import build_sorted_vertices
|
||||
from langflow.processing.process import process_tweaks
|
||||
async def load_flow(self, flow_id: str, tweaks: Optional[dict] = None) -> Any:
|
||||
from langflow.processing.process import (build_sorted_vertices,
|
||||
process_tweaks)
|
||||
|
||||
db_service = get_db_service()
|
||||
with session_getter(db_service) as session:
|
||||
|
|
@ -199,10 +227,10 @@ class CustomComponent(Component, extra=Extra.allow):
|
|||
raise ValueError(f"Flow {flow_id} not found")
|
||||
if tweaks:
|
||||
graph_data = process_tweaks(graph_data=graph_data, tweaks=tweaks)
|
||||
return build_sorted_vertices(graph_data)
|
||||
return await build_sorted_vertices(graph_data, self.user_id)
|
||||
|
||||
def list_flows(self, *, get_session: Optional[Callable] = None) -> List[Flow]:
|
||||
if not self.user_id:
|
||||
if not self._user_id:
|
||||
raise ValueError("Session is invalid")
|
||||
try:
|
||||
get_session = get_session or session_getter
|
||||
|
|
@ -213,7 +241,7 @@ class CustomComponent(Component, extra=Extra.allow):
|
|||
except Exception as e:
|
||||
raise ValueError("Session is invalid") from e
|
||||
|
||||
def get_flow(
|
||||
async def get_flow(
|
||||
self,
|
||||
*,
|
||||
flow_name: Optional[str] = None,
|
||||
|
|
@ -227,17 +255,13 @@ class CustomComponent(Component, extra=Extra.allow):
|
|||
if flow_id:
|
||||
flow = session.query(Flow).get(flow_id)
|
||||
elif flow_name:
|
||||
flow = (
|
||||
session.query(Flow)
|
||||
.filter(Flow.name == flow_name)
|
||||
.filter(Flow.user_id == self.user_id)
|
||||
).first()
|
||||
flow = (session.query(Flow).filter(Flow.name == flow_name).filter(Flow.user_id == self.user_id)).first()
|
||||
else:
|
||||
raise ValueError("Either flow_name or flow_id must be provided")
|
||||
|
||||
if not flow:
|
||||
raise ValueError(f"Flow {flow_name or flow_id} not found")
|
||||
return self.load_flow(flow.id, tweaks)
|
||||
return await self.load_flow(flow.id, tweaks)
|
||||
|
||||
def build(self, *args: Any, **kwargs: Any) -> Any:
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
|
|
@ -76,9 +76,7 @@ class DirectoryReader:
|
|||
for menu in data["menu"]
|
||||
]
|
||||
filtered = [menu for menu in items if menu["components"]]
|
||||
logger.debug(
|
||||
f'Filtered components {"with errors" if with_errors else ""}: {len(filtered)}'
|
||||
)
|
||||
logger.debug(f'Filtered components {"with errors" if with_errors else ""}: {len(filtered)}')
|
||||
return {"menu": filtered}
|
||||
|
||||
def validate_code(self, file_content):
|
||||
|
|
@ -111,9 +109,7 @@ class DirectoryReader:
|
|||
Walk through the directory path and return a list of all .py files.
|
||||
"""
|
||||
if not (safe_path := self.get_safe_path()):
|
||||
raise CustomComponentPathValueError(
|
||||
f"The path needs to start with '{self.base_path}'."
|
||||
)
|
||||
raise CustomComponentPathValueError(f"The path needs to start with '{self.base_path}'.")
|
||||
|
||||
file_list = []
|
||||
for root, _, files in os.walk(safe_path):
|
||||
|
|
@ -158,9 +154,7 @@ class DirectoryReader:
|
|||
for node in ast.walk(module):
|
||||
if isinstance(node, ast.FunctionDef):
|
||||
for arg in node.args.args:
|
||||
if self._is_type_hint_in_arg_annotation(
|
||||
arg.annotation, type_hint_name
|
||||
):
|
||||
if self._is_type_hint_in_arg_annotation(arg.annotation, type_hint_name):
|
||||
return True
|
||||
except SyntaxError:
|
||||
# Returns False if the code is not valid Python
|
||||
|
|
@ -178,16 +172,14 @@ class DirectoryReader:
|
|||
and annotation.value.id == type_hint_name
|
||||
)
|
||||
|
||||
def is_type_hint_used_but_not_imported(
|
||||
self, type_hint_name: str, code: str
|
||||
) -> bool:
|
||||
def is_type_hint_used_but_not_imported(self, type_hint_name: str, code: str) -> bool:
|
||||
"""
|
||||
Check if a type hint is used but not imported in the given code.
|
||||
"""
|
||||
try:
|
||||
return self._is_type_hint_used_in_args(
|
||||
return self._is_type_hint_used_in_args(type_hint_name, code) and not self._is_type_hint_imported(
|
||||
type_hint_name, code
|
||||
) and not self._is_type_hint_imported(type_hint_name, code)
|
||||
)
|
||||
except SyntaxError:
|
||||
# Returns True if there's something wrong with the code
|
||||
# TODO : Find a better way to handle this
|
||||
|
|
@ -208,9 +200,9 @@ class DirectoryReader:
|
|||
return False, "Syntax error"
|
||||
elif not self.validate_build(file_content):
|
||||
return False, "Missing build function"
|
||||
elif self._is_type_hint_used_in_args(
|
||||
elif self._is_type_hint_used_in_args("Optional", file_content) and not self._is_type_hint_imported(
|
||||
"Optional", file_content
|
||||
) and not self._is_type_hint_imported("Optional", file_content):
|
||||
):
|
||||
return (
|
||||
False,
|
||||
"Type hint 'Optional' is used but not imported in the code.",
|
||||
|
|
@ -226,9 +218,7 @@ class DirectoryReader:
|
|||
from the .py files in the directory.
|
||||
"""
|
||||
response = {"menu": []}
|
||||
logger.debug(
|
||||
"-------------------- Building component menu list --------------------"
|
||||
)
|
||||
logger.debug("-------------------- Building component menu list --------------------")
|
||||
|
||||
for file_path in file_paths:
|
||||
menu_name = os.path.basename(os.path.dirname(file_path))
|
||||
|
|
@ -248,9 +238,7 @@ class DirectoryReader:
|
|||
|
||||
# first check if it's already CamelCase
|
||||
if "_" in component_name:
|
||||
component_name_camelcase = " ".join(
|
||||
word.title() for word in component_name.split("_")
|
||||
)
|
||||
component_name_camelcase = " ".join(word.title() for word in component_name.split("_"))
|
||||
else:
|
||||
component_name_camelcase = component_name
|
||||
|
||||
|
|
@ -266,7 +254,5 @@ class DirectoryReader:
|
|||
logger.debug(f"Component info: {component_info}")
|
||||
if menu_result not in response["menu"]:
|
||||
response["menu"].append(menu_result)
|
||||
logger.debug(
|
||||
"-------------------- Component menu list built --------------------"
|
||||
)
|
||||
logger.debug("-------------------- Component menu list built --------------------")
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -1,16 +1,15 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class ClassCodeDetails(BaseModel):
|
||||
"""
|
||||
A dataclass for storing details about a class.
|
||||
"""
|
||||
|
||||
name: str
|
||||
doc: Optional[str]
|
||||
doc: Optional[str] = None
|
||||
bases: list
|
||||
attributes: list
|
||||
methods: list
|
||||
|
|
@ -23,7 +22,8 @@ class CallableCodeDetails(BaseModel):
|
|||
"""
|
||||
|
||||
name: str
|
||||
doc: Optional[str]
|
||||
doc: Optional[str] = None
|
||||
args: list
|
||||
body: list
|
||||
return_type: Optional[str]
|
||||
return_type: Optional[Any] = None
|
||||
has_return: bool = False
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
import re
|
||||
from types import GenericAlias
|
||||
from typing import Any
|
||||
|
||||
|
||||
def extract_inner_type(return_type: str) -> str:
|
||||
|
|
@ -8,3 +10,31 @@ def extract_inner_type(return_type: str) -> str:
|
|||
if match := re.match(r"list\[(.*)\]", return_type, re.IGNORECASE):
|
||||
return match[1]
|
||||
return return_type
|
||||
|
||||
|
||||
def extract_inner_type_from_generic_alias(return_type: GenericAlias) -> Any:
|
||||
"""
|
||||
Extracts the inner type from a type hint that is a list.
|
||||
"""
|
||||
if return_type.__origin__ == list:
|
||||
return list(return_type.__args__)
|
||||
|
||||
return return_type
|
||||
|
||||
|
||||
def extract_union_types_from_generic_alias(return_type: GenericAlias) -> list:
|
||||
"""
|
||||
Extracts the inner type from a type hint that is a Union.
|
||||
"""
|
||||
return list(return_type.__args__)
|
||||
|
||||
|
||||
def extract_union_types(return_type: str) -> list[str]:
|
||||
"""
|
||||
Extracts the inner type from a type hint that is a list.
|
||||
"""
|
||||
# If the return type is a Union, then we need to parse it
|
||||
return_type = return_type.replace("Union", "").replace("[", "").replace("]", "")
|
||||
return_types = return_type.split(",")
|
||||
return_types = [item.strip() for item in return_types]
|
||||
return return_types
|
||||
|
|
|
|||
|
|
@ -1,28 +1,21 @@
|
|||
import inspect
|
||||
from typing import Any
|
||||
|
||||
from langchain import (
|
||||
document_loaders,
|
||||
embeddings,
|
||||
llms,
|
||||
memory,
|
||||
requests,
|
||||
text_splitter,
|
||||
)
|
||||
from langchain import document_loaders, embeddings, llms, memory, requests, text_splitter
|
||||
from langchain.agents import agent_toolkits
|
||||
from langchain.chat_models import (
|
||||
AzureChatOpenAI,
|
||||
ChatOpenAI,
|
||||
ChatVertexAI,
|
||||
ChatAnthropic,
|
||||
)
|
||||
|
||||
from langflow.interface.importing.utils import import_class
|
||||
from langchain.chat_models import AzureChatOpenAI, ChatAnthropic, ChatOpenAI, ChatVertexAI
|
||||
from langflow.interface.agents.custom import CUSTOM_AGENTS
|
||||
from langflow.interface.chains.custom import CUSTOM_CHAINS
|
||||
from langflow.interface.importing.utils import import_class
|
||||
|
||||
# LLMs
|
||||
llm_type_to_cls_dict = llms.type_to_cls_dict
|
||||
llm_type_to_cls_dict = {}
|
||||
|
||||
for k, v in llms.get_type_to_cls_dict().items():
|
||||
try:
|
||||
llm_type_to_cls_dict[k] = v()
|
||||
except Exception:
|
||||
pass
|
||||
llm_type_to_cls_dict["anthropic-chat"] = ChatAnthropic # type: ignore
|
||||
llm_type_to_cls_dict["azure-chat"] = AzureChatOpenAI # type: ignore
|
||||
llm_type_to_cls_dict["openai-chat"] = ChatOpenAI # type: ignore
|
||||
|
|
@ -46,34 +39,26 @@ toolkit_type_to_cls_dict: dict[str, Any] = {
|
|||
|
||||
# Memories
|
||||
memory_type_to_cls_dict: dict[str, Any] = {
|
||||
memory_name: import_class(f"langchain.memory.{memory_name}")
|
||||
for memory_name in memory.__all__
|
||||
memory_name: import_class(f"langchain.memory.{memory_name}") for memory_name in memory.__all__
|
||||
}
|
||||
|
||||
# Wrappers
|
||||
wrapper_type_to_cls_dict: dict[str, Any] = {
|
||||
wrapper.__name__: wrapper for wrapper in [requests.RequestsWrapper]
|
||||
}
|
||||
wrapper_type_to_cls_dict: dict[str, Any] = {wrapper.__name__: wrapper for wrapper in [requests.RequestsWrapper]}
|
||||
|
||||
# Embeddings
|
||||
embedding_type_to_cls_dict: dict[str, Any] = {
|
||||
embedding_name: import_class(f"langchain.embeddings.{embedding_name}")
|
||||
for embedding_name in embeddings.__all__
|
||||
embedding_name: import_class(f"langchain.embeddings.{embedding_name}") for embedding_name in embeddings.__all__
|
||||
}
|
||||
|
||||
|
||||
# Document Loaders
|
||||
documentloaders_type_to_cls_dict: dict[str, Any] = {
|
||||
documentloader_name: import_class(
|
||||
f"langchain.document_loaders.{documentloader_name}"
|
||||
)
|
||||
documentloader_name: import_class(f"langchain.document_loaders.{documentloader_name}")
|
||||
for documentloader_name in document_loaders.__all__
|
||||
}
|
||||
|
||||
# Text Splitters
|
||||
textsplitter_type_to_cls_dict: dict[str, Any] = dict(
|
||||
inspect.getmembers(text_splitter, inspect.isclass)
|
||||
)
|
||||
textsplitter_type_to_cls_dict: dict[str, Any] = dict(inspect.getmembers(text_splitter, inspect.isclass))
|
||||
|
||||
# merge CUSTOM_AGENTS and CUSTOM_CHAINS
|
||||
CUSTOM_NODES = {**CUSTOM_AGENTS, **CUSTOM_CHAINS} # type: ignore
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Dict, List, Optional, Type
|
||||
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.services.getters import get_settings_service
|
||||
from langflow.services.deps import get_settings_service
|
||||
from langflow.template.frontend_node.documentloaders import DocumentLoaderFrontNode
|
||||
from langflow.interface.custom_lists import documentloaders_type_to_cls_dict
|
||||
|
||||
|
|
@ -35,8 +35,7 @@ class DocumentLoaderCreator(LangChainTypeCreator):
|
|||
return [
|
||||
documentloader.__name__
|
||||
for documentloader in self.type_to_loader_dict.values()
|
||||
if documentloader.__name__ in settings_service.settings.DOCUMENTLOADERS
|
||||
or settings_service.settings.DEV
|
||||
if documentloader.__name__ in settings_service.settings.DOCUMENTLOADERS or settings_service.settings.DEV
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from typing import Dict, List, Optional, Type
|
|||
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.custom_lists import embedding_type_to_cls_dict
|
||||
from langflow.services.getters import get_settings_service
|
||||
from langflow.services.deps import get_settings_service
|
||||
|
||||
from langflow.template.frontend_node.base import FrontendNode
|
||||
from langflow.template.frontend_node.embeddings import EmbeddingFrontendNode
|
||||
|
|
@ -37,8 +37,7 @@ class EmbeddingCreator(LangChainTypeCreator):
|
|||
return [
|
||||
embedding.__name__
|
||||
for embedding in self.type_to_loader_dict.values()
|
||||
if embedding.__name__ in settings_service.settings.EMBEDDINGS
|
||||
or settings_service.settings.DEV
|
||||
if embedding.__name__ in settings_service.settings.EMBEDDINGS or settings_service.settings.DEV
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3,15 +3,15 @@
|
|||
import importlib
|
||||
from typing import Any, Type
|
||||
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.agents import Agent
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.tools import BaseTool
|
||||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
from langflow.utils import validate
|
||||
from langflow.interface.wrappers.base import wrapper_creator
|
||||
from langflow.utils import validate
|
||||
|
||||
|
||||
def import_module(module_path: str) -> Any:
|
||||
|
|
@ -104,10 +104,7 @@ def import_prompt(prompt: str) -> Type[PromptTemplate]:
|
|||
|
||||
def import_wrapper(wrapper: str) -> Any:
|
||||
"""Import wrapper from wrapper name"""
|
||||
if (
|
||||
isinstance(wrapper_creator.type_dict, dict)
|
||||
and wrapper in wrapper_creator.type_dict
|
||||
):
|
||||
if isinstance(wrapper_creator.type_dict, dict) and wrapper in wrapper_creator.type_dict:
|
||||
return wrapper_creator.type_dict.get(wrapper)
|
||||
|
||||
|
||||
|
|
@ -183,6 +180,7 @@ def get_function(code):
|
|||
return validate.create_function(code, function_name)
|
||||
|
||||
|
||||
def get_function_custom(code):
|
||||
def eval_custom_component_code(code: str) -> Type[CustomComponent]:
|
||||
"""Evaluate custom component code"""
|
||||
class_name = validate.extract_class_name(code)
|
||||
return validate.create_class(code, class_name)
|
||||
|
|
|
|||
|
|
@ -2,8 +2,6 @@ def initialize_vertexai(class_object, params):
|
|||
if credentials_path := params.get("credentials"):
|
||||
from google.oauth2 import service_account # type: ignore
|
||||
|
||||
credentials_object = service_account.Credentials.from_service_account_file(
|
||||
filename=credentials_path
|
||||
)
|
||||
credentials_object = service_account.Credentials.from_service_account_file(filename=credentials_path)
|
||||
params["credentials"] = credentials_object
|
||||
return class_object(**params)
|
||||
|
|
|
|||
|
|
@ -1,40 +1,30 @@
|
|||
import inspect
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Sequence, Type
|
||||
|
||||
import orjson
|
||||
from typing import Any, Callable, Dict, Sequence, Type, TYPE_CHECKING
|
||||
from langchain.schema import Document
|
||||
from langchain.agents import agent as agent_module
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.agents.agent_toolkits.base import BaseToolkit
|
||||
from langchain.agents.tools import BaseTool
|
||||
from langflow.interface.initialize.llm import initialize_vertexai
|
||||
from langflow.interface.initialize.utils import (
|
||||
handle_format_kwargs,
|
||||
handle_node_type,
|
||||
handle_partial_variables,
|
||||
)
|
||||
|
||||
from langflow.interface.initialize.vector_store import vecstore_initializer
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.schema import Document
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
|
||||
from langflow.interface.importing.utils import (
|
||||
get_function,
|
||||
get_function_custom,
|
||||
import_by_type,
|
||||
)
|
||||
from langflow.interface.custom_lists import CUSTOM_NODES
|
||||
from langflow.interface.agents.base import agent_creator
|
||||
from langflow.interface.toolkits.base import toolkits_creator
|
||||
from langflow.interface.chains.base import chain_creator
|
||||
from langflow.interface.importing.utils import eval_custom_component_code, get_function, import_by_type
|
||||
from langflow.interface.initialize.llm import initialize_vertexai
|
||||
from langflow.interface.initialize.utils import handle_format_kwargs, handle_node_type, handle_partial_variables
|
||||
from langflow.interface.initialize.vector_store import vecstore_initializer
|
||||
from langflow.interface.output_parsers.base import output_parser_creator
|
||||
from langflow.interface.retrievers.base import retriever_creator
|
||||
from langflow.interface.wrappers.base import wrapper_creator
|
||||
from langflow.interface.toolkits.base import toolkits_creator
|
||||
from langflow.interface.utils import load_file_into_dict
|
||||
from langflow.interface.wrappers.base import wrapper_creator
|
||||
from langflow.utils import validate
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from loguru import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow import CustomComponent
|
||||
|
|
@ -44,15 +34,10 @@ def build_vertex_in_params(params: Dict) -> Dict:
|
|||
from langflow.graph.vertex.base import Vertex
|
||||
|
||||
# If any of the values in params is a Vertex, we will build it
|
||||
return {
|
||||
key: value.build() if isinstance(value, Vertex) else value
|
||||
for key, value in params.items()
|
||||
}
|
||||
return {key: value.build() if isinstance(value, Vertex) else value for key, value in params.items()}
|
||||
|
||||
|
||||
def instantiate_class(
|
||||
node_type: str, base_type: str, params: Dict, user_id=None
|
||||
) -> Any:
|
||||
async def instantiate_class(node_type: str, base_type: str, params: Dict, user_id=None) -> Any:
|
||||
"""Instantiate class from module type and key, and params"""
|
||||
params = convert_params_to_sets(params)
|
||||
params = convert_kwargs(params)
|
||||
|
|
@ -64,9 +49,7 @@ def instantiate_class(
|
|||
return custom_node(**params)
|
||||
logger.debug(f"Instantiating {node_type} of type {base_type}")
|
||||
class_object = import_by_type(_type=base_type, name=node_type)
|
||||
return instantiate_based_on_type(
|
||||
class_object, base_type, node_type, params, user_id=user_id
|
||||
)
|
||||
return await instantiate_based_on_type(class_object, base_type, node_type, params, user_id=user_id)
|
||||
|
||||
|
||||
def convert_params_to_sets(params):
|
||||
|
|
@ -93,7 +76,7 @@ def convert_kwargs(params):
|
|||
return params
|
||||
|
||||
|
||||
def instantiate_based_on_type(class_object, base_type, node_type, params, user_id):
|
||||
async def instantiate_based_on_type(class_object, base_type, node_type, params, user_id):
|
||||
if base_type == "agents":
|
||||
return instantiate_agent(node_type, class_object, params)
|
||||
elif base_type == "prompts":
|
||||
|
|
@ -127,20 +110,31 @@ def instantiate_based_on_type(class_object, base_type, node_type, params, user_i
|
|||
elif base_type == "memory":
|
||||
return instantiate_memory(node_type, class_object, params)
|
||||
elif base_type == "custom_components":
|
||||
return instantiate_custom_component(node_type, class_object, params, user_id)
|
||||
return await instantiate_custom_component(node_type, class_object, params, user_id)
|
||||
elif base_type == "wrappers":
|
||||
return instantiate_wrapper(node_type, class_object, params)
|
||||
else:
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
def instantiate_custom_component(node_type, class_object, params, user_id):
|
||||
# we need to make a copy of the params because we will be
|
||||
# modifying it
|
||||
async def instantiate_custom_component(node_type, class_object, params, user_id):
|
||||
params_copy = params.copy()
|
||||
class_object: "CustomComponent" = get_function_custom(params_copy.pop("code"))
|
||||
class_object: "CustomComponent" = eval_custom_component_code(params_copy.pop("code"))
|
||||
custom_component = class_object(user_id=user_id)
|
||||
built_object = custom_component.build(**params_copy)
|
||||
|
||||
if "retriever" in params_copy and hasattr(params_copy["retriever"], "as_retriever"):
|
||||
params_copy["retriever"] = params_copy["retriever"].as_retriever()
|
||||
|
||||
# Determine if the build method is asynchronous
|
||||
is_async = inspect.iscoroutinefunction(custom_component.build)
|
||||
|
||||
if is_async:
|
||||
# Await the build method directly if it's async
|
||||
built_object = await custom_component.build(**params_copy)
|
||||
else:
|
||||
# Call the build method directly if it's sync
|
||||
built_object = custom_component.build(**params_copy)
|
||||
|
||||
return built_object, {"repr": custom_component.custom_repr()}
|
||||
|
||||
|
||||
|
|
@ -194,9 +188,7 @@ def instantiate_memory(node_type, class_object, params):
|
|||
# I want to catch a specific attribute error that happens
|
||||
# when the object does not have a cursor attribute
|
||||
except Exception as exc:
|
||||
if "object has no attribute 'cursor'" in str(
|
||||
exc
|
||||
) or 'object has no field "conn"' in str(exc):
|
||||
if "object has no attribute 'cursor'" in str(exc) or 'object has no field "conn"' in str(exc):
|
||||
raise AttributeError(
|
||||
(
|
||||
"Failed to build connection to database."
|
||||
|
|
@ -218,6 +210,8 @@ def instantiate_retriever(node_type, class_object, params):
|
|||
|
||||
|
||||
def instantiate_chains(node_type, class_object: Type[Chain], params: Dict):
|
||||
from langflow.interface.chains.base import chain_creator
|
||||
|
||||
if "retriever" in params and hasattr(params["retriever"], "as_retriever"):
|
||||
params["retriever"] = params["retriever"].as_retriever()
|
||||
if node_type in chain_creator.from_method_nodes:
|
||||
|
|
@ -230,14 +224,14 @@ def instantiate_chains(node_type, class_object: Type[Chain], params: Dict):
|
|||
|
||||
|
||||
def instantiate_agent(node_type, class_object: Type[agent_module.Agent], params: Dict):
|
||||
from langflow.interface.agents.base import agent_creator
|
||||
|
||||
if node_type in agent_creator.from_method_nodes:
|
||||
method = agent_creator.from_method_nodes[node_type]
|
||||
if class_method := getattr(class_object, method, None):
|
||||
agent = class_method(**params)
|
||||
tools = params.get("tools", [])
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent, tools=tools, handle_parsing_errors=True
|
||||
)
|
||||
return AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, handle_parsing_errors=True)
|
||||
return load_agent_executor(class_object, params)
|
||||
|
||||
|
||||
|
|
@ -287,14 +281,13 @@ def instantiate_embedding(node_type, class_object, params: Dict):
|
|||
if "VertexAI" in node_type:
|
||||
return initialize_vertexai(class_object=class_object, params=params)
|
||||
|
||||
if "OpenAIEmbedding" in node_type:
|
||||
params["disallowed_special"] = ()
|
||||
|
||||
try:
|
||||
return class_object(**params)
|
||||
except ValidationError:
|
||||
params = {
|
||||
key: value
|
||||
for key, value in params.items()
|
||||
if key in class_object.__fields__
|
||||
}
|
||||
params = {key: value for key, value in params.items() if key in class_object.model_fields}
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
|
|
@ -306,9 +299,7 @@ def instantiate_vectorstore(class_object: Type[VectorStore], params: Dict):
|
|||
if "texts" in params:
|
||||
params["documents"] = params.pop("texts")
|
||||
if "documents" in params:
|
||||
params["documents"] = [
|
||||
doc for doc in params["documents"] if isinstance(doc, Document)
|
||||
]
|
||||
params["documents"] = [doc for doc in params["documents"] if isinstance(doc, Document)]
|
||||
if initializer := vecstore_initializer.get(class_object.__name__):
|
||||
vecstore = initializer(class_object, params)
|
||||
else:
|
||||
|
|
@ -323,9 +314,7 @@ def instantiate_vectorstore(class_object: Type[VectorStore], params: Dict):
|
|||
return vecstore
|
||||
|
||||
|
||||
def instantiate_documentloader(
|
||||
node_type: str, class_object: Type[BaseLoader], params: Dict
|
||||
):
|
||||
def instantiate_documentloader(node_type: str, class_object: Type[BaseLoader], params: Dict):
|
||||
if "file_filter" in params:
|
||||
# file_filter will be a string but we need a function
|
||||
# that will be used to filter the files using file_filter
|
||||
|
|
@ -334,17 +323,13 @@ def instantiate_documentloader(
|
|||
# in x and if it is, we will return True
|
||||
file_filter = params.pop("file_filter")
|
||||
extensions = file_filter.split(",")
|
||||
params["file_filter"] = lambda x: any(
|
||||
extension.strip() in x for extension in extensions
|
||||
)
|
||||
params["file_filter"] = lambda x: any(extension.strip() in x for extension in extensions)
|
||||
metadata = params.pop("metadata", None)
|
||||
if metadata and isinstance(metadata, str):
|
||||
try:
|
||||
metadata = orjson.loads(metadata)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise ValueError(
|
||||
"The metadata you provided is not a valid JSON string."
|
||||
) from exc
|
||||
raise ValueError("The metadata you provided is not a valid JSON string.") from exc
|
||||
|
||||
if node_type == "WebBaseLoader":
|
||||
if web_path := params.pop("web_path", None):
|
||||
|
|
@ -377,16 +362,12 @@ def instantiate_textsplitter(
|
|||
"Try changing the chunk_size of the Text Splitter."
|
||||
) from exc
|
||||
|
||||
if (
|
||||
"separator_type" in params and params["separator_type"] == "Text"
|
||||
) or "separator_type" not in params:
|
||||
if ("separator_type" in params and params["separator_type"] == "Text") or "separator_type" not in params:
|
||||
params.pop("separator_type", None)
|
||||
# separators might come in as an escaped string like \\n
|
||||
# so we need to convert it to a string
|
||||
if "separators" in params:
|
||||
params["separators"] = (
|
||||
params["separators"].encode().decode("unicode-escape")
|
||||
)
|
||||
params["separators"] = params["separators"].encode().decode("unicode-escape")
|
||||
text_splitter = class_object(**params)
|
||||
else:
|
||||
from langchain.text_splitter import Language
|
||||
|
|
@ -413,8 +394,7 @@ def replace_zero_shot_prompt_with_prompt_template(nodes):
|
|||
tools = [
|
||||
tool
|
||||
for tool in nodes
|
||||
if tool["type"] != "chatOutputNode"
|
||||
and "Tool" in tool["data"]["node"]["base_classes"]
|
||||
if tool["type"] != "chatOutputNode" and "Tool" in tool["data"]["node"]["base_classes"]
|
||||
]
|
||||
node["data"] = build_prompt_template(prompt=node["data"], tools=tools)
|
||||
break
|
||||
|
|
@ -428,9 +408,7 @@ def load_agent_executor(agent_class: type[agent_module.Agent], params, **kwargs)
|
|||
# agent has hidden args for memory. might need to be support
|
||||
# memory = params["memory"]
|
||||
# if allowed_tools is not a list or set, make it a list
|
||||
if not isinstance(allowed_tools, (list, set)) and isinstance(
|
||||
allowed_tools, BaseTool
|
||||
):
|
||||
if not isinstance(allowed_tools, (list, set)) and isinstance(allowed_tools, BaseTool):
|
||||
allowed_tools = [allowed_tools]
|
||||
tool_names = [tool.name for tool in allowed_tools]
|
||||
# Agent class requires an output_parser but Agent classes
|
||||
|
|
@ -458,10 +436,7 @@ def build_prompt_template(prompt, tools):
|
|||
format_instructions = prompt["node"]["template"]["format_instructions"]["value"]
|
||||
|
||||
tool_strings = "\n".join(
|
||||
[
|
||||
f"{tool['data']['node']['name']}: {tool['data']['node']['description']}"
|
||||
for tool in tools
|
||||
]
|
||||
[f"{tool['data']['node']['name']}: {tool['data']['node']['description']}" for tool in tools]
|
||||
)
|
||||
tool_names = ", ".join([tool["data"]["node"]["name"] for tool in tools])
|
||||
format_instructions = format_instructions.format(tool_names=tool_names)
|
||||
|
|
|
|||
|
|
@ -1,13 +1,11 @@
|
|||
import contextlib
|
||||
import json
|
||||
from langflow.services.database.models.base import orjson_dumps
|
||||
import orjson
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import orjson
|
||||
from langchain.agents import ZeroShotAgent
|
||||
|
||||
|
||||
from langchain.schema import Document, BaseOutputParser
|
||||
from langchain.schema import BaseOutputParser, Document
|
||||
from langflow.services.database.models.base import orjson_dumps
|
||||
|
||||
|
||||
def handle_node_type(node_type, class_object, params: Dict):
|
||||
|
|
@ -30,9 +28,7 @@ def check_tools_in_params(params: Dict):
|
|||
|
||||
|
||||
def instantiate_from_template(class_object, params: Dict):
|
||||
from_template_params = {
|
||||
"template": params.pop("prompt", params.pop("template", ""))
|
||||
}
|
||||
from_template_params = {"template": params.pop("prompt", params.pop("template", ""))}
|
||||
if not from_template_params.get("template"):
|
||||
raise ValueError("Prompt template is required")
|
||||
return class_object.from_template(**from_template_params)
|
||||
|
|
@ -48,9 +44,7 @@ def handle_format_kwargs(prompt, params: Dict):
|
|||
|
||||
def handle_partial_variables(prompt, format_kwargs: Dict):
|
||||
partial_variables = format_kwargs.copy()
|
||||
partial_variables = {
|
||||
key: value for key, value in partial_variables.items() if value
|
||||
}
|
||||
partial_variables = {key: value for key, value in partial_variables.items() if value}
|
||||
# Remove handle_keys otherwise LangChain raises an error
|
||||
partial_variables.pop("handle_keys", None)
|
||||
if partial_variables and hasattr(prompt, "partial"):
|
||||
|
|
@ -62,9 +56,7 @@ def handle_variable(params: Dict, input_variable: str, format_kwargs: Dict):
|
|||
variable = params[input_variable]
|
||||
if isinstance(variable, str):
|
||||
format_kwargs[input_variable] = variable
|
||||
elif isinstance(variable, BaseOutputParser) and hasattr(
|
||||
variable, "get_format_instructions"
|
||||
):
|
||||
elif isinstance(variable, BaseOutputParser) and hasattr(variable, "get_format_instructions"):
|
||||
format_kwargs[input_variable] = variable.get_format_instructions()
|
||||
elif is_instance_of_list_or_document(variable):
|
||||
format_kwargs = format_document(variable, input_variable, format_kwargs)
|
||||
|
|
@ -91,8 +83,10 @@ def format_document(variable, input_variable: str, format_kwargs: Dict):
|
|||
def format_content(variable):
|
||||
if len(variable) > 1:
|
||||
return "\n".join([item.page_content for item in variable if item.page_content])
|
||||
content = variable[0].page_content
|
||||
return try_to_load_json(content)
|
||||
elif len(variable) == 1:
|
||||
content = variable[0].page_content
|
||||
return try_to_load_json(content)
|
||||
return ""
|
||||
|
||||
|
||||
def try_to_load_json(content):
|
||||
|
|
@ -107,8 +101,7 @@ def try_to_load_json(content):
|
|||
|
||||
def needs_handle_keys(variable):
|
||||
return is_instance_of_list_or_document(variable) or (
|
||||
isinstance(variable, BaseOutputParser)
|
||||
and hasattr(variable, "get_format_instructions")
|
||||
isinstance(variable, BaseOutputParser) and hasattr(variable, "get_format_instructions")
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -17,9 +17,7 @@ from langchain.vectorstores import (
|
|||
def docs_in_params(params: dict) -> bool:
|
||||
"""Check if params has documents OR texts and one of them is not an empty list,
|
||||
If any of them is not an empty list, return True, else return False"""
|
||||
return ("documents" in params and params["documents"]) or (
|
||||
"texts" in params and params["texts"]
|
||||
)
|
||||
return ("documents" in params and params["documents"]) or ("texts" in params and params["texts"])
|
||||
|
||||
|
||||
def initialize_mongodb(class_object: Type[MongoDBAtlasVectorSearch], params: dict):
|
||||
|
|
@ -31,9 +29,7 @@ def initialize_mongodb(class_object: Type[MongoDBAtlasVectorSearch], params: dic
|
|||
import certifi
|
||||
from pymongo import MongoClient
|
||||
|
||||
client: MongoClient = MongoClient(
|
||||
MONGODB_ATLAS_CLUSTER_URI, tlsCAFile=certifi.where()
|
||||
)
|
||||
client: MongoClient = MongoClient(MONGODB_ATLAS_CLUSTER_URI, tlsCAFile=certifi.where())
|
||||
db_name = params.pop("db_name", None)
|
||||
collection_name = params.pop("collection_name", None)
|
||||
if not db_name or not collection_name:
|
||||
|
|
@ -142,9 +138,7 @@ def initialize_pinecone(class_object: Type[Pinecone], params: dict):
|
|||
pinecone_env = os.getenv("PINECONE_ENV")
|
||||
|
||||
if pinecone_api_key is None or pinecone_env is None:
|
||||
raise ValueError(
|
||||
"Pinecone API key and environment must be provided in the params"
|
||||
)
|
||||
raise ValueError("Pinecone API key and environment must be provided in the params")
|
||||
|
||||
# initialize pinecone
|
||||
pinecone.init(
|
||||
|
|
@ -178,26 +172,20 @@ def initialize_chroma(class_object: Type[Chroma], params: dict):
|
|||
import chromadb # type: ignore
|
||||
|
||||
settings_params = {
|
||||
key: params[key]
|
||||
for key, value_ in params.items()
|
||||
if key.startswith("chroma_server_") and value_
|
||||
key: params[key] for key, value_ in params.items() if key.startswith("chroma_server_") and value_
|
||||
}
|
||||
chroma_settings = chromadb.config.Settings(**settings_params)
|
||||
params["client_settings"] = chroma_settings
|
||||
else:
|
||||
# remove all chroma_server_ keys from params
|
||||
params = {
|
||||
key: value
|
||||
for key, value in params.items()
|
||||
if not key.startswith("chroma_server_")
|
||||
}
|
||||
params = {key: value for key, value in params.items() if not key.startswith("chroma_server_")}
|
||||
|
||||
persist = params.pop("persist", False)
|
||||
if not docs_in_params(params):
|
||||
params.pop("documents", None)
|
||||
params.pop("texts", None)
|
||||
params["embedding_function"] = params.pop("embedding")
|
||||
chromadb = class_object(**params)
|
||||
chromadb_instance = class_object(**params)
|
||||
else:
|
||||
if "texts" in params:
|
||||
params["documents"] = params.pop("texts")
|
||||
|
|
@ -212,10 +200,10 @@ def initialize_chroma(class_object: Type[Chroma], params: dict):
|
|||
if value is None:
|
||||
doc.metadata[key] = ""
|
||||
|
||||
chromadb = class_object.from_documents(**params)
|
||||
chromadb_instance = class_object.from_documents(**params)
|
||||
if persist:
|
||||
chromadb.persist()
|
||||
return chromadb
|
||||
chromadb_instance.persist()
|
||||
return chromadb_instance
|
||||
|
||||
|
||||
def initialize_qdrant(class_object: Type[Qdrant], params: dict):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from langflow.services.getters import get_settings_service
|
||||
from langflow.services.deps import get_settings_service
|
||||
from langflow.utils.lazy_load import LazyLoadDictBase
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from typing import Dict, List, Optional, Type
|
|||
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.custom_lists import llm_type_to_cls_dict
|
||||
from langflow.services.getters import get_settings_service
|
||||
from langflow.services.deps import get_settings_service
|
||||
|
||||
from langflow.template.frontend_node.llms import LLMFrontendNode
|
||||
from loguru import logger
|
||||
|
|
@ -38,8 +38,7 @@ class LLMCreator(LangChainTypeCreator):
|
|||
return [
|
||||
llm.__name__
|
||||
for llm in self.type_to_loader_dict.values()
|
||||
if llm.__name__ in settings_service.settings.LLMS
|
||||
or settings_service.settings.DEV
|
||||
if llm.__name__ in settings_service.settings.LLMS or settings_service.settings.DEV
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from typing import Dict, List, Optional, Type
|
||||
from typing import ClassVar, Dict, List, Optional, Type
|
||||
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.custom_lists import memory_type_to_cls_dict
|
||||
from langflow.services.getters import get_settings_service
|
||||
from langflow.services.deps import get_settings_service
|
||||
|
||||
from langflow.template.frontend_node.base import FrontendNode
|
||||
from langflow.template.frontend_node.memories import MemoryFrontendNode
|
||||
|
|
@ -14,7 +14,7 @@ from langflow.custom.customs import get_custom_nodes
|
|||
class MemoryCreator(LangChainTypeCreator):
|
||||
type_name: str = "memories"
|
||||
|
||||
from_method_nodes = {
|
||||
from_method_nodes: ClassVar[Dict] = {
|
||||
"ZepChatMessageHistory": "__init__",
|
||||
"SQLiteEntityStore": "__init__",
|
||||
}
|
||||
|
|
@ -53,8 +53,7 @@ class MemoryCreator(LangChainTypeCreator):
|
|||
return [
|
||||
memory.__name__
|
||||
for memory in self.type_to_loader_dict.values()
|
||||
if memory.__name__ in settings_service.settings.MEMORIES
|
||||
or settings_service.settings.DEV
|
||||
if memory.__name__ in settings_service.settings.MEMORIES or settings_service.settings.DEV
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from typing import Dict, List, Optional, Type
|
||||
from typing import ClassVar, Dict, List, Optional, Type
|
||||
|
||||
from langchain import output_parsers
|
||||
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.importing.utils import import_class
|
||||
from langflow.services.getters import get_settings_service
|
||||
from langflow.services.deps import get_settings_service
|
||||
|
||||
from langflow.template.frontend_node.output_parsers import OutputParserFrontendNode
|
||||
from loguru import logger
|
||||
|
|
@ -13,7 +13,7 @@ from langflow.utils.util import build_template_from_class, build_template_from_m
|
|||
|
||||
class OutputParserCreator(LangChainTypeCreator):
|
||||
type_name: str = "output_parsers"
|
||||
from_method_nodes = {
|
||||
from_method_nodes: ClassVar[Dict] = {
|
||||
"StructuredOutputParser": "from_response_schemas",
|
||||
}
|
||||
|
||||
|
|
@ -26,17 +26,14 @@ class OutputParserCreator(LangChainTypeCreator):
|
|||
if self.type_dict is None:
|
||||
settings_service = get_settings_service()
|
||||
self.type_dict = {
|
||||
output_parser_name: import_class(
|
||||
f"langchain.output_parsers.{output_parser_name}"
|
||||
)
|
||||
output_parser_name: import_class(f"langchain.output_parsers.{output_parser_name}")
|
||||
# if output_parser_name is not lower case it is a class
|
||||
for output_parser_name in output_parsers.__all__
|
||||
}
|
||||
self.type_dict = {
|
||||
name: output_parser
|
||||
for name, output_parser in self.type_dict.items()
|
||||
if name in settings_service.settings.OUTPUT_PARSERS
|
||||
or settings_service.settings.DEV
|
||||
if name in settings_service.settings.OUTPUT_PARSERS or settings_service.settings.DEV
|
||||
}
|
||||
return self.type_dict
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from langchain import prompts
|
|||
from langflow.custom.customs import get_custom_nodes
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.importing.utils import import_class
|
||||
from langflow.services.getters import get_settings_service
|
||||
from langflow.services.deps import get_settings_service
|
||||
|
||||
from langflow.template.frontend_node.prompts import PromptFrontendNode
|
||||
from loguru import logger
|
||||
|
|
@ -36,8 +36,7 @@ class PromptCreator(LangChainTypeCreator):
|
|||
self.type_dict = {
|
||||
name: prompt
|
||||
for name, prompt in self.type_dict.items()
|
||||
if name in settings_service.settings.PROMPTS
|
||||
or settings_service.settings.DEV
|
||||
if name in settings_service.settings.PROMPTS or settings_service.settings.DEV
|
||||
}
|
||||
return self.type_dict
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Dict, List, Optional, Type
|
||||
|
||||
from langchain.prompts import PromptTemplate
|
||||
from pydantic import root_validator
|
||||
from pydantic.v1 import root_validator
|
||||
|
||||
from langflow.interface.utils import extract_input_variables_from_prompt
|
||||
|
||||
|
|
@ -42,17 +42,13 @@ class BaseCustomPrompt(PromptTemplate):
|
|||
values["template"] = values["template"].format(**format_dict)
|
||||
|
||||
values["template"] = values["template"]
|
||||
values["input_variables"] = extract_input_variables_from_prompt(
|
||||
values["template"]
|
||||
)
|
||||
values["input_variables"] = extract_input_variables_from_prompt(values["template"])
|
||||
return values
|
||||
|
||||
|
||||
class SeriesCharacterPrompt(BaseCustomPrompt):
|
||||
# Add a very descriptive description for the prompt generator
|
||||
description: Optional[
|
||||
str
|
||||
] = "A prompt that asks the AI to act like a character from a series."
|
||||
description: Optional[str] = "A prompt that asks the AI to act like a character from a series."
|
||||
character: str
|
||||
series: str
|
||||
template: str = """I want you to act like {character} from {series}.
|
||||
|
|
@ -68,6 +64,4 @@ Human: {input}
|
|||
input_variables: List[str] = ["character", "series"]
|
||||
|
||||
|
||||
CUSTOM_PROMPTS: Dict[str, Type[BaseCustomPrompt]] = {
|
||||
"SeriesCharacterPrompt": SeriesCharacterPrompt
|
||||
}
|
||||
CUSTOM_PROMPTS: Dict[str, Type[BaseCustomPrompt]] = {"SeriesCharacterPrompt": SeriesCharacterPrompt}
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from typing import Any, Dict, List, Optional, Type
|
||||
from typing import Any, ClassVar, Dict, List, Optional, Type
|
||||
|
||||
from langchain import retrievers
|
||||
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.importing.utils import import_class
|
||||
from langflow.services.getters import get_settings_service
|
||||
from langflow.services.deps import get_settings_service
|
||||
|
||||
from langflow.template.frontend_node.retrievers import RetrieverFrontendNode
|
||||
from loguru import logger
|
||||
|
|
@ -14,7 +14,10 @@ from langflow.utils.util import build_template_from_method, build_template_from_
|
|||
class RetrieverCreator(LangChainTypeCreator):
|
||||
type_name: str = "retrievers"
|
||||
|
||||
from_method_nodes = {"MultiQueryRetriever": "from_llm", "ZepRetriever": "__init__"}
|
||||
from_method_nodes: ClassVar[Dict] = {
|
||||
"MultiQueryRetriever": "from_llm",
|
||||
"ZepRetriever": "__init__",
|
||||
}
|
||||
|
||||
@property
|
||||
def frontend_node_class(self) -> Type[RetrieverFrontendNode]:
|
||||
|
|
@ -39,9 +42,7 @@ class RetrieverCreator(LangChainTypeCreator):
|
|||
method_name=self.from_method_nodes[name],
|
||||
)
|
||||
else:
|
||||
return build_template_from_class(
|
||||
name, type_to_cls_dict=self.type_to_loader_dict
|
||||
)
|
||||
return build_template_from_class(name, type_to_cls_dict=self.type_to_loader_dict)
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"Retriever {name} not found") from exc
|
||||
except AttributeError as exc:
|
||||
|
|
@ -53,8 +54,7 @@ class RetrieverCreator(LangChainTypeCreator):
|
|||
return [
|
||||
retriever
|
||||
for retriever in self.type_to_loader_dict.keys()
|
||||
if retriever in settings_service.settings.RETRIEVERS
|
||||
or settings_service.settings.DEV
|
||||
if retriever in settings_service.settings.RETRIEVERS or settings_service.settings.DEV
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
from typing import Dict, Tuple
|
||||
from langflow.graph import Graph
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from langflow.graph import Graph
|
||||
|
||||
def build_sorted_vertices(data_graph) -> Tuple[Graph, Dict]:
|
||||
|
||||
async def build_sorted_vertices(data_graph, user_id: Optional[Union[str, UUID]] = None) -> Tuple[Graph, Dict]:
|
||||
"""
|
||||
Build langchain object from data_graph.
|
||||
"""
|
||||
|
|
@ -13,28 +16,12 @@ def build_sorted_vertices(data_graph) -> Tuple[Graph, Dict]:
|
|||
sorted_vertices = graph.topological_sort()
|
||||
artifacts = {}
|
||||
for vertex in sorted_vertices:
|
||||
vertex.build()
|
||||
await vertex.build(user_id=user_id)
|
||||
if vertex.artifacts:
|
||||
artifacts.update(vertex.artifacts)
|
||||
return graph, artifacts
|
||||
|
||||
|
||||
def build_langchain_object(data_graph):
|
||||
"""
|
||||
Build langchain object from data_graph.
|
||||
"""
|
||||
|
||||
logger.debug("Building langchain object")
|
||||
nodes = data_graph["nodes"]
|
||||
# Add input variables
|
||||
# nodes = payload.extract_input_variables(nodes)
|
||||
# Nodes, edges and root node
|
||||
edges = data_graph["edges"]
|
||||
graph = Graph(nodes, edges)
|
||||
|
||||
return graph.build()
|
||||
|
||||
|
||||
def get_memory_key(langchain_object):
|
||||
"""
|
||||
Given a LangChain object, this function retrieves the current memory key from the object's memory attribute.
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Dict, List, Optional, Type
|
||||
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.services.getters import get_settings_service
|
||||
from langflow.services.deps import get_settings_service
|
||||
from langflow.template.frontend_node.textsplitters import TextSplittersFrontendNode
|
||||
from langflow.interface.custom_lists import textsplitter_type_to_cls_dict
|
||||
|
||||
|
|
@ -35,8 +35,7 @@ class TextSplitterCreator(LangChainTypeCreator):
|
|||
return [
|
||||
textsplitter.__name__
|
||||
for textsplitter in self.type_to_loader_dict.values()
|
||||
if textsplitter.__name__ in settings_service.settings.TEXTSPLITTERS
|
||||
or settings_service.settings.DEV
|
||||
if textsplitter.__name__ in settings_service.settings.TEXTSPLITTERS or settings_service.settings.DEV
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from langchain.agents import agent_toolkits
|
|||
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.importing.utils import import_class, import_module
|
||||
from langflow.services.getters import get_settings_service
|
||||
from langflow.services.deps import get_settings_service
|
||||
|
||||
from loguru import logger
|
||||
from langflow.utils.util import build_template_from_class
|
||||
|
|
@ -32,13 +32,10 @@ class ToolkitCreator(LangChainTypeCreator):
|
|||
if self.type_dict is None:
|
||||
settings_service = get_settings_service()
|
||||
self.type_dict = {
|
||||
toolkit_name: import_class(
|
||||
f"langchain.agents.agent_toolkits.{toolkit_name}"
|
||||
)
|
||||
toolkit_name: import_class(f"langchain.agents.agent_toolkits.{toolkit_name}")
|
||||
# if toolkit_name is not lower case it is a class
|
||||
for toolkit_name in agent_toolkits.__all__
|
||||
if not toolkit_name.islower()
|
||||
and toolkit_name in settings_service.settings.TOOLKITS
|
||||
if not toolkit_name.islower() and toolkit_name in settings_service.settings.TOOLKITS
|
||||
}
|
||||
|
||||
return self.type_dict
|
||||
|
|
@ -61,9 +58,7 @@ class ToolkitCreator(LangChainTypeCreator):
|
|||
|
||||
def get_create_function(self, name: str) -> Callable:
|
||||
if loader_name := self.create_functions.get(name):
|
||||
return import_module(
|
||||
f"from langchain.agents.agent_toolkits import {loader_name[0]}"
|
||||
)
|
||||
return import_module(f"from langchain.agents.agent_toolkits import {loader_name[0]}")
|
||||
else:
|
||||
raise ValueError("Toolkit not found")
|
||||
|
||||
|
|
|
|||
|
|
@ -16,12 +16,13 @@ from langflow.interface.tools.constants import (
|
|||
OTHER_TOOLS,
|
||||
)
|
||||
from langflow.interface.tools.util import get_tool_params
|
||||
from langflow.services.getters import get_settings_service
|
||||
from langflow.services.deps import get_settings_service
|
||||
|
||||
from langflow.template.field.base import TemplateField
|
||||
from langflow.template.template.base import Template
|
||||
from langflow.utils import util
|
||||
from langflow.utils.util import build_template_from_class
|
||||
from langflow.utils.logger import logger
|
||||
|
||||
TOOL_INPUTS = {
|
||||
"str": TemplateField(
|
||||
|
|
@ -32,11 +33,9 @@ TOOL_INPUTS = {
|
|||
placeholder="",
|
||||
value="",
|
||||
),
|
||||
"llm": TemplateField(
|
||||
field_type="BaseLanguageModel", required=True, is_list=False, show=True
|
||||
),
|
||||
"llm": TemplateField(field_type="BaseLanguageModel", required=True, is_list=False, show=True),
|
||||
"func": TemplateField(
|
||||
field_type="function",
|
||||
field_type="Callable",
|
||||
required=True,
|
||||
is_list=False,
|
||||
show=True,
|
||||
|
|
@ -56,8 +55,7 @@ TOOL_INPUTS = {
|
|||
is_list=False,
|
||||
show=True,
|
||||
value="",
|
||||
suffixes=[".json", ".yaml", ".yml"],
|
||||
file_types=["json", "yaml", "yml"],
|
||||
file_types=[".json", ".yaml", ".yml"],
|
||||
),
|
||||
}
|
||||
|
||||
|
|
@ -73,14 +71,15 @@ class ToolCreator(LangChainTypeCreator):
|
|||
all_tools = {}
|
||||
|
||||
for tool, tool_fcn in ALL_TOOLS_NAMES.items():
|
||||
tool_params = get_tool_params(tool_fcn)
|
||||
try:
|
||||
tool_params = get_tool_params(tool_fcn)
|
||||
except Exception:
|
||||
logger.error(f"Error getting params for tool {tool}")
|
||||
continue
|
||||
|
||||
tool_name = tool_params.get("name") or tool
|
||||
|
||||
if (
|
||||
tool_name in settings_service.settings.TOOLS
|
||||
or settings_service.settings.DEV
|
||||
):
|
||||
if tool_name in settings_service.settings.TOOLS or settings_service.settings.DEV:
|
||||
if tool_name == "JsonSpec":
|
||||
tool_params["path"] = tool_params.pop("dict_") # type: ignore
|
||||
all_tools[tool_name] = {
|
||||
|
|
@ -122,7 +121,7 @@ class ToolCreator(LangChainTypeCreator):
|
|||
elif tool_type in CUSTOM_TOOLS:
|
||||
# Get custom tool params
|
||||
params = self.type_to_loader_dict[name]["params"] # type: ignore
|
||||
base_classes = ["function"]
|
||||
base_classes = ["Callable"]
|
||||
if node := customs.get_custom_nodes("tools").get(tool_type):
|
||||
return node
|
||||
elif tool_type in FILE_TOOLS:
|
||||
|
|
@ -132,10 +131,15 @@ class ToolCreator(LangChainTypeCreator):
|
|||
tool_dict = build_template_from_class(tool_type, OTHER_TOOLS)
|
||||
fields = tool_dict["template"]
|
||||
|
||||
# _type is the only key in fields
|
||||
# return None
|
||||
if len(fields) == 1 and "_type" in fields:
|
||||
return None
|
||||
|
||||
# Pop unnecessary fields and add name
|
||||
fields.pop("_type") # type: ignore
|
||||
fields.pop("return_direct") # type: ignore
|
||||
fields.pop("verbose") # type: ignore
|
||||
fields.pop("return_direct", None) # type: ignore
|
||||
fields.pop("verbose", None) # type: ignore
|
||||
|
||||
tool_params = {
|
||||
"name": fields.pop("name")["value"], # type: ignore
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Callable, Optional
|
||||
from langflow.interface.importing.utils import get_function
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
from pydantic.v1 import BaseModel, validator
|
||||
|
||||
from langflow.utils import validate
|
||||
from langchain.agents.tools import Tool
|
||||
|
|
|
|||
|
|
@ -21,16 +21,12 @@ def get_func_tool_params(func, **kwargs) -> Union[Dict, None]:
|
|||
for keyword in tool.keywords:
|
||||
if keyword.arg == "name":
|
||||
try:
|
||||
tool_params["name"] = ast.literal_eval(
|
||||
keyword.value
|
||||
)
|
||||
tool_params["name"] = ast.literal_eval(keyword.value)
|
||||
except ValueError:
|
||||
break
|
||||
elif keyword.arg == "description":
|
||||
try:
|
||||
tool_params["description"] = ast.literal_eval(
|
||||
keyword.value
|
||||
)
|
||||
tool_params["description"] = ast.literal_eval(keyword.value)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
|
|
@ -43,9 +39,7 @@ def get_func_tool_params(func, **kwargs) -> Union[Dict, None]:
|
|||
else:
|
||||
# get the class object from the return statement
|
||||
try:
|
||||
class_obj = eval(
|
||||
compile(ast.Expression(tool), "<string>", "eval")
|
||||
)
|
||||
class_obj = eval(compile(ast.Expression(tool), "<string>", "eval"))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -1,43 +1,40 @@
|
|||
import ast
|
||||
import contextlib
|
||||
from typing import Any, List
|
||||
from langflow.api.utils import get_new_key
|
||||
import re
|
||||
import traceback
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from cachetools import LRUCache, cached
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from langflow.field_typing.range_spec import RangeSpec
|
||||
from langflow.interface.agents.base import agent_creator
|
||||
from langflow.interface.chains.base import chain_creator
|
||||
from langflow.interface.custom.constants import CUSTOM_COMPONENT_SUPPORTED_TYPES
|
||||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
from langflow.interface.custom.directory_reader import DirectoryReader
|
||||
from langflow.interface.custom.utils import extract_inner_type
|
||||
from langflow.interface.document_loaders.base import documentloader_creator
|
||||
from langflow.interface.embeddings.base import embedding_creator
|
||||
from langflow.interface.importing.utils import get_function_custom
|
||||
from langflow.interface.importing.utils import eval_custom_component_code
|
||||
from langflow.interface.llms.base import llm_creator
|
||||
from langflow.interface.memories.base import memory_creator
|
||||
from langflow.interface.output_parsers.base import output_parser_creator
|
||||
from langflow.interface.prompts.base import prompt_creator
|
||||
from langflow.interface.retrievers.base import retriever_creator
|
||||
from langflow.interface.text_splitters.base import textsplitter_creator
|
||||
from langflow.interface.toolkits.base import toolkits_creator
|
||||
from langflow.interface.tools.base import tool_creator
|
||||
from langflow.interface.utilities.base import utility_creator
|
||||
from langflow.interface.vector_store.base import vectorstore_creator
|
||||
from langflow.interface.wrappers.base import wrapper_creator
|
||||
from langflow.interface.output_parsers.base import output_parser_creator
|
||||
from langflow.interface.custom.base import custom_component_creator
|
||||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
|
||||
from langflow.template.field.base import TemplateField
|
||||
from langflow.template.frontend_node.constants import CLASSES_TO_REMOVE
|
||||
from langflow.template.frontend_node.custom_components import (
|
||||
CustomComponentFrontendNode,
|
||||
)
|
||||
from langflow.interface.retrievers.base import retriever_creator
|
||||
|
||||
from langflow.interface.custom.directory_reader import DirectoryReader
|
||||
from loguru import logger
|
||||
from langflow.template.frontend_node.custom_components import \
|
||||
CustomComponentFrontendNode
|
||||
from langflow.utils.util import get_base_classes
|
||||
|
||||
import re
|
||||
import warnings
|
||||
import traceback
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
# Used to get the base_classes list
|
||||
def get_type_list():
|
||||
|
|
@ -52,6 +49,7 @@ def get_type_list():
|
|||
return all_types
|
||||
|
||||
|
||||
@cached(LRUCache(maxsize=1))
|
||||
def build_langchain_types_dict(): # sourcery skip: dict-assign-update-to-union
|
||||
"""Build a dictionary of all langchain types"""
|
||||
all_types = {}
|
||||
|
|
@ -72,7 +70,6 @@ def build_langchain_types_dict(): # sourcery skip: dict-assign-update-to-union
|
|||
utility_creator,
|
||||
output_parser_creator,
|
||||
retriever_creator,
|
||||
custom_component_creator,
|
||||
]
|
||||
|
||||
all_types = {}
|
||||
|
|
@ -92,7 +89,7 @@ def process_type(field_type: str):
|
|||
|
||||
# TODO: Move to correct place
|
||||
def add_new_custom_field(
|
||||
template,
|
||||
frontend_node: CustomComponentFrontendNode,
|
||||
field_name: str,
|
||||
field_type: str,
|
||||
field_value: Any,
|
||||
|
|
@ -114,16 +111,10 @@ def add_new_custom_field(
|
|||
# If options is a list, then it's a dropdown
|
||||
# If options is None, then it's a list of strings
|
||||
is_list = isinstance(field_config.get("options"), list)
|
||||
field_config["is_list"] = (
|
||||
is_list or field_config.get("is_list", False) or field_contains_list
|
||||
)
|
||||
field_config["is_list"] = is_list or field_config.get("is_list", False) or field_contains_list
|
||||
|
||||
if "name" in field_config:
|
||||
warnings.warn(
|
||||
"The 'name' key in field_config is used to build the object and can't be changed."
|
||||
)
|
||||
field_config.pop("name", None)
|
||||
|
||||
warnings.warn("The 'name' key in field_config is used to build the object and can't be changed.")
|
||||
required = field_config.pop("required", field_required)
|
||||
placeholder = field_config.pop("placeholder", "")
|
||||
|
||||
|
|
@ -136,36 +127,39 @@ def add_new_custom_field(
|
|||
advanced=field_advanced,
|
||||
placeholder=placeholder,
|
||||
display_name=display_name,
|
||||
**field_config,
|
||||
**sanitize_field_config(field_config),
|
||||
)
|
||||
template.get("template")[field_name] = new_field.to_dict()
|
||||
template.get("custom_fields")[field_name] = None
|
||||
frontend_node.template.upsert_field(field_name, new_field)
|
||||
if isinstance(frontend_node.custom_fields, dict):
|
||||
frontend_node.custom_fields[field_name] = None
|
||||
|
||||
return template
|
||||
return frontend_node
|
||||
|
||||
|
||||
def sanitize_field_config(field_config: Dict):
|
||||
# If any of the already existing keys are in field_config, remove them
|
||||
for key in ["name", "field_type", "value", "required", "placeholder", "display_name", "advanced", "show"]:
|
||||
field_config.pop(key, None)
|
||||
return field_config
|
||||
|
||||
|
||||
# TODO: Move to correct place
|
||||
def add_code_field(template, raw_code, field_config):
|
||||
# Field with the Python code to allow update
|
||||
def add_code_field(frontend_node: CustomComponentFrontendNode, raw_code, field_config):
|
||||
code_field = TemplateField(
|
||||
dynamic=True,
|
||||
required=True,
|
||||
placeholder="",
|
||||
multiline=True,
|
||||
value=raw_code,
|
||||
password=False,
|
||||
name="code",
|
||||
advanced=field_config.pop("advanced", False),
|
||||
field_type="code",
|
||||
is_list=False,
|
||||
)
|
||||
frontend_node.template.add_field(code_field)
|
||||
|
||||
code_field = {
|
||||
"code": {
|
||||
"dynamic": True,
|
||||
"required": True,
|
||||
"placeholder": "",
|
||||
"show": field_config.pop("show", True),
|
||||
"multiline": True,
|
||||
"value": raw_code,
|
||||
"password": False,
|
||||
"name": "code",
|
||||
"advanced": field_config.pop("advanced", False),
|
||||
"type": "code",
|
||||
"list": False,
|
||||
}
|
||||
}
|
||||
template.get("template")["code"] = code_field.get("code")
|
||||
|
||||
return template
|
||||
return frontend_node
|
||||
|
||||
|
||||
def extract_type_from_optional(field_type):
|
||||
|
|
@ -182,51 +176,108 @@ def extract_type_from_optional(field_type):
|
|||
return match[1] if match else None
|
||||
|
||||
|
||||
def build_frontend_node(custom_component: CustomComponent):
|
||||
def build_frontend_node(template_config):
|
||||
"""Build a frontend node for a custom component"""
|
||||
try:
|
||||
return (
|
||||
CustomComponentFrontendNode().to_dict().get(type(custom_component).__name__)
|
||||
)
|
||||
|
||||
sanitized_template_config = sanitize_template_config(template_config)
|
||||
return CustomComponentFrontendNode(**sanitized_template_config)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error while building base frontend node: {exc}")
|
||||
return None
|
||||
raise exc
|
||||
|
||||
|
||||
def update_attributes(frontend_node, template_config):
|
||||
"""Update the display name and description of a frontend node"""
|
||||
attributes = [
|
||||
def sanitize_template_config(template_config):
|
||||
"""Sanitize the template config"""
|
||||
attributes = {
|
||||
"display_name",
|
||||
"description",
|
||||
"beta",
|
||||
"documentation",
|
||||
"output_types",
|
||||
]
|
||||
for attribute in attributes:
|
||||
if attribute in template_config:
|
||||
frontend_node[attribute] = template_config[attribute]
|
||||
}
|
||||
for key in template_config.copy():
|
||||
if key not in attributes:
|
||||
template_config.pop(key, None)
|
||||
|
||||
return template_config
|
||||
|
||||
|
||||
def build_field_config(custom_component: CustomComponent):
|
||||
def build_field_config(
|
||||
custom_component: CustomComponent, user_id: Optional[Union[str, UUID]] = None, update_field=None
|
||||
):
|
||||
"""Build the field configuration for a custom component"""
|
||||
|
||||
try:
|
||||
custom_class = get_function_custom(custom_component.code)
|
||||
if custom_component.code is None:
|
||||
return {}
|
||||
elif isinstance(custom_component.code, str):
|
||||
custom_class = eval_custom_component_code(custom_component.code)
|
||||
else:
|
||||
raise ValueError("Invalid code type")
|
||||
except Exception as exc:
|
||||
logger.error(f"Error while getting custom function: {str(exc)}")
|
||||
return {}
|
||||
logger.error(f"Error while evaluating custom component code: {str(exc)}")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": ("Invalid type convertion. Please check your code and try again."),
|
||||
"traceback": traceback.format_exc(),
|
||||
},
|
||||
) from exc
|
||||
|
||||
try:
|
||||
return custom_class().build_config()
|
||||
build_config: Dict = custom_class(user_id=user_id).build_config()
|
||||
|
||||
for field_name, field in build_config.items():
|
||||
# Allow user to build TemplateField as well
|
||||
# as a dict with the same keys as TemplateField
|
||||
field_dict = get_field_dict(field)
|
||||
if update_field is not None and field_name != update_field:
|
||||
continue
|
||||
try:
|
||||
update_field_dict(field_dict)
|
||||
build_config[field_name] = field_dict
|
||||
except Exception as exc:
|
||||
logger.error(f"Error while getting build_config: {str(exc)}")
|
||||
|
||||
return build_config
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Error while building field config: {str(exc)}")
|
||||
return {}
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": ("Invalid type convertion. Please check your code and try again."),
|
||||
"traceback": traceback.format_exc(),
|
||||
},
|
||||
) from exc
|
||||
|
||||
|
||||
def get_field_dict(field):
|
||||
"""Get the field dictionary from a TemplateField or a dict"""
|
||||
if isinstance(field, TemplateField):
|
||||
return field.model_dump(by_alias=True, exclude_none=True)
|
||||
return field
|
||||
|
||||
|
||||
def update_field_dict(field_dict):
|
||||
"""Update the field dictionary by calling options() or value() if they are callable"""
|
||||
if "options" in field_dict and callable(field_dict["options"]):
|
||||
field_dict["options"] = field_dict["options"]()
|
||||
# Also update the "refresh" key
|
||||
field_dict["refresh"] = True
|
||||
|
||||
if "value" in field_dict and callable(field_dict["value"]):
|
||||
field_dict["value"] = field_dict["value"](field_dict.get("options", []))
|
||||
field_dict["refresh"] = True
|
||||
|
||||
# Let's check if "range_spec" is a RangeSpec object
|
||||
if "rangeSpec" in field_dict and isinstance(field_dict["rangeSpec"], RangeSpec):
|
||||
field_dict["rangeSpec"] = field_dict["rangeSpec"].model_dump()
|
||||
|
||||
|
||||
def add_extra_fields(frontend_node, field_config, function_args):
|
||||
"""Add extra fields to the frontend node"""
|
||||
if function_args is None or function_args == "":
|
||||
if not function_args:
|
||||
return
|
||||
|
||||
# sort function_args which is a list of dicts
|
||||
|
|
@ -236,9 +287,7 @@ def add_extra_fields(frontend_node, field_config, function_args):
|
|||
if "name" not in extra_field or extra_field["name"] == "self":
|
||||
continue
|
||||
|
||||
field_name, field_type, field_value, field_required = get_field_properties(
|
||||
extra_field
|
||||
)
|
||||
field_name, field_type, field_value, field_required = get_field_properties(extra_field)
|
||||
config = field_config.get(field_name, {})
|
||||
frontend_node = add_new_custom_field(
|
||||
frontend_node,
|
||||
|
|
@ -259,72 +308,83 @@ def get_field_properties(extra_field):
|
|||
|
||||
if not field_required:
|
||||
field_type = extract_type_from_optional(field_type)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
field_value = ast.literal_eval(field_value)
|
||||
if field_value is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
field_value = ast.literal_eval(field_value)
|
||||
return field_name, field_type, field_value, field_required
|
||||
|
||||
|
||||
def add_base_classes(frontend_node, return_types: List[str]):
|
||||
def add_base_classes(frontend_node: CustomComponentFrontendNode, return_types: List[str]):
|
||||
"""Add base classes to the frontend node"""
|
||||
for return_type in return_types:
|
||||
if return_type not in CUSTOM_COMPONENT_SUPPORTED_TYPES or return_type is None:
|
||||
for return_type_instance in return_types:
|
||||
if return_type_instance is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": (
|
||||
"Invalid return type should be one of: "
|
||||
f"{list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())}"
|
||||
),
|
||||
"error": ("Invalid return type. Please check your code and try again."),
|
||||
"traceback": traceback.format_exc(),
|
||||
},
|
||||
)
|
||||
|
||||
return_type_instance = CUSTOM_COMPONENT_SUPPORTED_TYPES.get(return_type)
|
||||
base_classes = get_base_classes(return_type_instance)
|
||||
|
||||
for base_class in base_classes:
|
||||
if base_class not in CLASSES_TO_REMOVE:
|
||||
frontend_node.get("base_classes").append(base_class)
|
||||
frontend_node.add_base_class(base_class)
|
||||
|
||||
|
||||
def build_langchain_template_custom_component(custom_component: CustomComponent):
|
||||
def add_output_types(frontend_node: CustomComponentFrontendNode, return_types: List[str]):
|
||||
"""Add output types to the frontend node"""
|
||||
for return_type in return_types:
|
||||
if return_type is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": ("Invalid return type. Please check your code and try again."),
|
||||
"traceback": traceback.format_exc(),
|
||||
},
|
||||
)
|
||||
if hasattr(return_type, "__name__"):
|
||||
return_type = return_type.__name__
|
||||
elif hasattr(return_type, "__class__"):
|
||||
return_type = return_type.__class__.__name__
|
||||
else:
|
||||
return_type = str(return_type)
|
||||
|
||||
frontend_node.add_output_type(return_type)
|
||||
|
||||
|
||||
def build_custom_component_template(
|
||||
custom_component: CustomComponent,
|
||||
user_id: Optional[Union[str, UUID]] = None,
|
||||
update_field: Optional[str] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Build a custom component template for the langchain"""
|
||||
try:
|
||||
logger.debug("Building custom component template")
|
||||
frontend_node = build_frontend_node(custom_component)
|
||||
frontend_node = build_frontend_node(custom_component.template_config)
|
||||
|
||||
if frontend_node is None:
|
||||
return None
|
||||
logger.debug("Built base frontend node")
|
||||
template_config = custom_component.build_template_config
|
||||
|
||||
update_attributes(frontend_node, template_config)
|
||||
logger.debug("Updated attributes")
|
||||
field_config = build_field_config(custom_component)
|
||||
field_config = build_field_config(custom_component, user_id=user_id, update_field=update_field)
|
||||
logger.debug("Built field config")
|
||||
entrypoint_args = custom_component.get_function_entrypoint_args
|
||||
|
||||
add_extra_fields(frontend_node, field_config, entrypoint_args)
|
||||
logger.debug("Added extra fields")
|
||||
frontend_node = add_code_field(
|
||||
frontend_node, custom_component.code, field_config.get("code", {})
|
||||
)
|
||||
frontend_node = add_code_field(frontend_node, custom_component.code, field_config.get("code", {}))
|
||||
logger.debug("Added code field")
|
||||
add_base_classes(
|
||||
frontend_node, custom_component.get_function_entrypoint_return_type
|
||||
)
|
||||
add_base_classes(frontend_node, custom_component.get_function_entrypoint_return_type)
|
||||
add_output_types(frontend_node, custom_component.get_function_entrypoint_return_type)
|
||||
logger.debug("Added base classes")
|
||||
return frontend_node
|
||||
return frontend_node.to_dict(add_name=False)
|
||||
except Exception as exc:
|
||||
if isinstance(exc, HTTPException):
|
||||
raise exc
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": (
|
||||
"Invalid type convertion. Please check your code and try again."
|
||||
),
|
||||
"error": ("Invalid type convertion. Please check your code and try again."),
|
||||
"traceback": traceback.format_exc(),
|
||||
},
|
||||
) from exc
|
||||
|
|
@ -348,119 +408,147 @@ def build_and_validate_all_files(reader: DirectoryReader, file_list):
|
|||
|
||||
|
||||
def build_valid_menu(valid_components):
|
||||
"""Build the valid menu"""
|
||||
"""Build the valid menu."""
|
||||
valid_menu = {}
|
||||
logger.debug("------------------- VALID COMPONENTS -------------------")
|
||||
for menu_item in valid_components["menu"]:
|
||||
menu_name = menu_item["name"]
|
||||
valid_menu[menu_name] = {}
|
||||
|
||||
for component in menu_item["components"]:
|
||||
logger.debug(
|
||||
f"Building component: {component.get('name'), component.get('output_types')}"
|
||||
)
|
||||
try:
|
||||
component_name = component["name"]
|
||||
component_code = component["code"]
|
||||
component_output_types = component["output_types"]
|
||||
|
||||
component_extractor = CustomComponent(code=component_code)
|
||||
component_extractor.is_check_valid()
|
||||
|
||||
component_template = build_langchain_template_custom_component(
|
||||
component_extractor
|
||||
)
|
||||
component_template["output_types"] = component_output_types
|
||||
if len(component_output_types) == 1:
|
||||
component_name = component_output_types[0]
|
||||
else:
|
||||
file_name = component.get("file").split(".")[0]
|
||||
if "_" in file_name:
|
||||
# turn .py file into camelcase
|
||||
component_name = "".join(
|
||||
[word.capitalize() for word in file_name.split("_")]
|
||||
)
|
||||
else:
|
||||
component_name = file_name
|
||||
|
||||
valid_menu[menu_name][component_name] = component_template
|
||||
logger.debug(f"Added {component_name} to valid menu to {menu_name}")
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Error loading Component: {component['output_types']}")
|
||||
logger.exception(
|
||||
f"Error while building custom component {component_output_types}: {exc}"
|
||||
)
|
||||
|
||||
valid_menu[menu_name] = build_menu_items(menu_item)
|
||||
return valid_menu
|
||||
|
||||
|
||||
def build_menu_items(menu_item):
|
||||
"""Build menu items for a given menu."""
|
||||
menu_items = {}
|
||||
for component in menu_item["components"]:
|
||||
try:
|
||||
component_name, component_template = build_component(component)
|
||||
menu_items[component_name] = component_template
|
||||
logger.debug(f"Added {component_name} to valid menu.")
|
||||
except Exception as exc:
|
||||
logger.error(f"Error loading Component: {component['output_types']}")
|
||||
logger.exception(f"Error while building custom component {component['output_types']}: {exc}")
|
||||
return menu_items
|
||||
|
||||
|
||||
def build_component(component):
|
||||
"""Build a single component."""
|
||||
logger.debug(f"Building component: {component.get('name'), component.get('output_types')}")
|
||||
component_name = determine_component_name(component)
|
||||
component_template = create_component_template(component)
|
||||
return component_name, component_template
|
||||
|
||||
|
||||
def determine_component_name(component):
|
||||
"""Determine the name of the component."""
|
||||
component_output_types = component["output_types"]
|
||||
if len(component_output_types) == 1:
|
||||
return component_output_types[0]
|
||||
else:
|
||||
file_name = component.get("file").split(".")[0]
|
||||
return "".join(word.capitalize() for word in file_name.split("_")) if "_" in file_name else file_name
|
||||
|
||||
|
||||
def create_component_template(component):
|
||||
"""Create a template for a component."""
|
||||
component_code = component["code"]
|
||||
component_output_types = component["output_types"]
|
||||
|
||||
component_extractor = CustomComponent(code=component_code)
|
||||
component_extractor.validate()
|
||||
|
||||
component_template = build_custom_component_template(component_extractor)
|
||||
component_template["output_types"] = component_output_types
|
||||
return component_template
|
||||
|
||||
|
||||
def build_invalid_menu(invalid_components):
|
||||
"""Build the invalid menu"""
|
||||
if invalid_components.get("menu"):
|
||||
logger.debug("------------------- INVALID COMPONENTS -------------------")
|
||||
"""Build the invalid menu."""
|
||||
if not invalid_components.get("menu"):
|
||||
return {}
|
||||
|
||||
logger.debug("------------------- INVALID COMPONENTS -------------------")
|
||||
invalid_menu = {}
|
||||
for menu_item in invalid_components["menu"]:
|
||||
menu_name = menu_item["name"]
|
||||
invalid_menu[menu_name] = {}
|
||||
|
||||
for component in menu_item["components"]:
|
||||
try:
|
||||
component_name = component["name"]
|
||||
component_code = component["code"]
|
||||
|
||||
component_template = (
|
||||
CustomComponentFrontendNode(
|
||||
description="ERROR - Check your Python Code",
|
||||
display_name=f"ERROR - {component_name}",
|
||||
)
|
||||
.to_dict()
|
||||
.get(type(CustomComponent()).__name__)
|
||||
)
|
||||
|
||||
component_template["error"] = component.get("error", None)
|
||||
logger.debug(component)
|
||||
logger.debug(f"Component Path: {component.get('path', None)}")
|
||||
logger.debug(f"Component Error: {component.get('error', None)}")
|
||||
component_template.get("template").get("code")["value"] = component_code
|
||||
|
||||
invalid_menu[menu_name][component_name] = component_template
|
||||
logger.debug(f"Added {component_name} to invalid menu to {menu_name}")
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
f"Error while creating custom component [{component_name}]: {str(exc)}"
|
||||
)
|
||||
|
||||
invalid_menu[menu_name] = build_invalid_menu_items(menu_item)
|
||||
return invalid_menu
|
||||
|
||||
|
||||
def build_invalid_menu_items(menu_item):
|
||||
"""Build invalid menu items for a given menu."""
|
||||
menu_items = {}
|
||||
for component in menu_item["components"]:
|
||||
try:
|
||||
component_name, component_template = build_invalid_component(component)
|
||||
menu_items[component_name] = component_template
|
||||
logger.debug(f"Added {component_name} to invalid menu.")
|
||||
except Exception as exc:
|
||||
logger.exception(f"Error while creating custom component [{component_name}]: {str(exc)}")
|
||||
return menu_items
|
||||
|
||||
|
||||
def build_invalid_component(component):
|
||||
"""Build a single invalid component."""
|
||||
component_name = component["name"]
|
||||
component_template = create_invalid_component_template(component, component_name)
|
||||
log_invalid_component_details(component)
|
||||
return component_name, component_template
|
||||
|
||||
|
||||
def create_invalid_component_template(component, component_name):
|
||||
"""Create a template for an invalid component."""
|
||||
component_code = component["code"]
|
||||
component_template = (
|
||||
CustomComponentFrontendNode(
|
||||
description="ERROR - Check your Python Code",
|
||||
display_name=f"ERROR - {component_name}",
|
||||
)
|
||||
.to_dict()
|
||||
.get(type(CustomComponent()).__name__)
|
||||
)
|
||||
|
||||
component_template["error"] = component.get("error", None)
|
||||
component_template.get("template").get("code")["value"] = component_code
|
||||
return component_template
|
||||
|
||||
|
||||
def log_invalid_component_details(component):
|
||||
"""Log details of an invalid component."""
|
||||
logger.debug(component)
|
||||
logger.debug(f"Component Path: {component.get('path', None)}")
|
||||
logger.debug(f"Component Error: {component.get('error', None)}")
|
||||
|
||||
|
||||
def get_new_key(dictionary, original_key):
|
||||
counter = 1
|
||||
new_key = original_key + " (" + str(counter) + ")"
|
||||
while new_key in dictionary:
|
||||
counter += 1
|
||||
new_key = original_key + " (" + str(counter) + ")"
|
||||
return new_key
|
||||
|
||||
|
||||
def merge_nested_dicts_with_renaming(dict1, dict2):
|
||||
for key, value in dict2.items():
|
||||
if (
|
||||
key in dict1
|
||||
and isinstance(value, dict)
|
||||
and isinstance(dict1.get(key), dict)
|
||||
):
|
||||
if key in dict1 and isinstance(value, dict) and isinstance(dict1.get(key), dict):
|
||||
for sub_key, sub_value in value.items():
|
||||
if sub_key in dict1[key]:
|
||||
new_key = get_new_key(dict1[key], sub_key)
|
||||
dict1[key][new_key] = sub_value
|
||||
else:
|
||||
dict1[key][sub_key] = sub_value
|
||||
# if sub_key in dict1[key]:
|
||||
# new_key = get_new_key(dict1[key], sub_key)
|
||||
# dict1[key][new_key] = sub_value
|
||||
# else:
|
||||
dict1[key][sub_key] = sub_value
|
||||
else:
|
||||
dict1[key] = value
|
||||
return dict1
|
||||
|
||||
|
||||
def build_langchain_custom_component_list_from_path(path: str):
|
||||
def build_custom_component_list_from_path(path: str):
|
||||
"""Build a list of custom components for the langchain from a given path"""
|
||||
file_list = load_files_from_path(path)
|
||||
reader = DirectoryReader(path, False)
|
||||
|
||||
valid_components, invalid_components = build_and_validate_all_files(
|
||||
reader, file_list
|
||||
)
|
||||
valid_components, invalid_components = build_and_validate_all_files(reader, file_list)
|
||||
|
||||
valid_menu = build_valid_menu(valid_components)
|
||||
invalid_menu = build_invalid_menu(invalid_components)
|
||||
|
|
@ -469,42 +557,35 @@ def build_langchain_custom_component_list_from_path(path: str):
|
|||
|
||||
|
||||
def get_all_types_dict(settings_service):
|
||||
"""Get all types dictionary combining native and custom components."""
|
||||
native_components = build_langchain_types_dict()
|
||||
# custom_components is a list of dicts
|
||||
# need to merge all the keys into one dict
|
||||
custom_components_from_file: dict[str, Any] = {}
|
||||
if settings_service.settings.COMPONENTS_PATH:
|
||||
logger.info(
|
||||
f"Building custom components from {settings_service.settings.COMPONENTS_PATH}"
|
||||
)
|
||||
custom_components_from_file = build_custom_components(settings_service)
|
||||
return merge_nested_dicts_with_renaming(native_components, custom_components_from_file)
|
||||
|
||||
custom_component_dicts = []
|
||||
processed_paths = []
|
||||
for path in settings_service.settings.COMPONENTS_PATH:
|
||||
if str(path) in processed_paths:
|
||||
continue
|
||||
custom_component_dict = build_langchain_custom_component_list_from_path(
|
||||
str(path)
|
||||
)
|
||||
custom_component_dicts.append(custom_component_dict)
|
||||
processed_paths.append(str(path))
|
||||
|
||||
logger.info(f"Loading {len(custom_component_dicts)} category(ies)")
|
||||
for custom_component_dict in custom_component_dicts:
|
||||
# custom_component_dict is a dict of dicts
|
||||
if not custom_component_dict:
|
||||
continue
|
||||
category = list(custom_component_dict.keys())[0]
|
||||
logger.info(
|
||||
f"Loading {len(custom_component_dict[category])} component(s) from category {category}"
|
||||
)
|
||||
def build_custom_components(settings_service):
|
||||
"""Build custom components from the specified paths."""
|
||||
if not settings_service.settings.COMPONENTS_PATH:
|
||||
return {}
|
||||
|
||||
logger.info(f"Building custom components from {settings_service.settings.COMPONENTS_PATH}")
|
||||
custom_components_from_file = {}
|
||||
processed_paths = set()
|
||||
for path in settings_service.settings.COMPONENTS_PATH:
|
||||
path_str = str(path)
|
||||
if path_str in processed_paths:
|
||||
continue
|
||||
|
||||
custom_component_dict = build_custom_component_list_from_path(path_str)
|
||||
if custom_component_dict:
|
||||
category = next(iter(custom_component_dict))
|
||||
logger.info(f"Loading {len(custom_component_dict[category])} component(s) from category {category}")
|
||||
custom_components_from_file = merge_nested_dicts_with_renaming(
|
||||
custom_components_from_file, custom_component_dict
|
||||
)
|
||||
processed_paths.add(path_str)
|
||||
|
||||
return merge_nested_dicts_with_renaming(
|
||||
native_components, custom_components_from_file
|
||||
)
|
||||
return custom_components_from_file
|
||||
|
||||
|
||||
def merge_nested_dicts(dict1, dict2):
|
||||
|
|
@ -514,3 +595,9 @@ def merge_nested_dicts(dict1, dict2):
|
|||
else:
|
||||
dict1[key] = value
|
||||
return dict1
|
||||
|
||||
|
||||
def create_and_validate_component(code: str) -> CustomComponent:
|
||||
component = CustomComponent(code=code)
|
||||
component.validate()
|
||||
return component
|
||||
|
|
|
|||
|
|
@ -1,14 +1,13 @@
|
|||
from typing import Dict, List, Optional, Type
|
||||
|
||||
from langchain import utilities
|
||||
from loguru import logger
|
||||
|
||||
from langflow.custom.customs import get_custom_nodes
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.importing.utils import import_class
|
||||
from langflow.services.getters import get_settings_service
|
||||
|
||||
from langflow.services.deps import get_settings_service
|
||||
from langflow.template.frontend_node.utilities import UtilitiesFrontendNode
|
||||
from loguru import logger
|
||||
from langflow.utils.util import build_template_from_class
|
||||
|
||||
|
||||
|
|
@ -41,8 +40,7 @@ class UtilityCreator(LangChainTypeCreator):
|
|||
self.type_dict = {
|
||||
name: utility
|
||||
for name, utility in self.type_dict.items()
|
||||
if name in settings_service.settings.UTILITIES
|
||||
or settings_service.settings.DEV
|
||||
if name in settings_service.settings.UTILITIES or settings_service.settings.DEV
|
||||
}
|
||||
|
||||
return self.type_dict
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from langchain.base_language import BaseLanguageModel
|
|||
from PIL.Image import Image
|
||||
from loguru import logger
|
||||
from langflow.services.chat.config import ChatConfig
|
||||
from langflow.services.getters import get_settings_service
|
||||
from langflow.services.deps import get_settings_service
|
||||
|
||||
|
||||
def load_file_into_dict(file_path: str) -> dict:
|
||||
|
|
@ -43,9 +43,7 @@ def try_setting_streaming_options(langchain_object):
|
|||
llm = None
|
||||
if hasattr(langchain_object, "llm"):
|
||||
llm = langchain_object.llm
|
||||
elif hasattr(langchain_object, "llm_chain") and hasattr(
|
||||
langchain_object.llm_chain, "llm"
|
||||
):
|
||||
elif hasattr(langchain_object, "llm_chain") and hasattr(langchain_object.llm_chain, "llm"):
|
||||
llm = langchain_object.llm_chain.llm
|
||||
|
||||
if isinstance(llm, BaseLanguageModel):
|
||||
|
|
@ -74,17 +72,15 @@ def setup_llm_caching():
|
|||
|
||||
|
||||
def set_langchain_cache(settings):
|
||||
import langchain
|
||||
from langchain.globals import set_llm_cache
|
||||
from langflow.interface.importing.utils import import_class
|
||||
|
||||
if cache_type := os.getenv("LANGFLOW_LANGCHAIN_CACHE"):
|
||||
try:
|
||||
cache_class = import_class(
|
||||
f"langchain.cache.{cache_type or settings.LANGCHAIN_CACHE}"
|
||||
)
|
||||
cache_class = import_class(f"langchain.cache.{cache_type or settings.LANGCHAIN_CACHE}")
|
||||
|
||||
logger.debug(f"Setting up LLM caching with {cache_class.__name__}")
|
||||
langchain.llm_cache = cache_class()
|
||||
set_llm_cache(cache_class())
|
||||
logger.info(f"LLM caching setup with {cache_class.__name__}")
|
||||
except ImportError:
|
||||
logger.warning(f"Could not import {cache_type}. ")
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from langchain import vectorstores
|
|||
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.importing.utils import import_class
|
||||
from langflow.services.getters import get_settings_service
|
||||
from langflow.services.deps import get_settings_service
|
||||
|
||||
from langflow.template.frontend_node.vectorstores import VectorStoreFrontendNode
|
||||
from loguru import logger
|
||||
|
|
@ -22,9 +22,7 @@ class VectorstoreCreator(LangChainTypeCreator):
|
|||
def type_to_loader_dict(self) -> Dict:
|
||||
if self.type_dict is None:
|
||||
self.type_dict: dict[str, Any] = {
|
||||
vectorstore_name: import_class(
|
||||
f"langchain.vectorstores.{vectorstore_name}"
|
||||
)
|
||||
vectorstore_name: import_class(f"langchain.vectorstores.{vectorstore_name}")
|
||||
for vectorstore_name in vectorstores.__all__
|
||||
}
|
||||
return self.type_dict
|
||||
|
|
@ -48,8 +46,7 @@ class VectorstoreCreator(LangChainTypeCreator):
|
|||
return [
|
||||
vectorstore
|
||||
for vectorstore in self.type_to_loader_dict.keys()
|
||||
if vectorstore in settings_service.settings.VECTORSTORES
|
||||
or settings_service.settings.DEV
|
||||
if vectorstore in settings_service.settings.VECTORSTORES or settings_service.settings.DEV
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Dict, List, Optional
|
||||
from typing import ClassVar, Dict, List, Optional
|
||||
|
||||
from langchain.utilities import requests, sql_database
|
||||
|
||||
|
|
@ -10,14 +10,13 @@ from langflow.utils.util import build_template_from_class, build_template_from_m
|
|||
class WrapperCreator(LangChainTypeCreator):
|
||||
type_name: str = "wrappers"
|
||||
|
||||
from_method_nodes = {"SQLDatabase": "from_uri"}
|
||||
from_method_nodes: ClassVar[Dict] = {"SQLDatabase": "from_uri"}
|
||||
|
||||
@property
|
||||
def type_to_loader_dict(self) -> Dict:
|
||||
if self.type_dict is None:
|
||||
self.type_dict = {
|
||||
wrapper.__name__: wrapper
|
||||
for wrapper in [requests.TextRequestsWrapper, sql_database.SQLDatabase]
|
||||
wrapper.__name__: wrapper for wrapper in [requests.TextRequestsWrapper, sql_database.SQLDatabase]
|
||||
}
|
||||
return self.type_dict
|
||||
|
||||
|
|
|
|||
|
|
@ -1,29 +1,34 @@
|
|||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from fastapi import FastAPI
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from langflow.api import router
|
||||
|
||||
|
||||
from langflow.interface.utils import setup_llm_caching
|
||||
from langflow.services.utils import initialize_services
|
||||
from langflow.services.plugins.langfuse import LangfuseInstance
|
||||
from langflow.services.utils import (
|
||||
teardown_services,
|
||||
)
|
||||
from langflow.services.plugins.langfuse_plugin import LangfuseInstance
|
||||
from langflow.services.utils import initialize_services, teardown_services
|
||||
from langflow.utils.logger import configure
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
initialize_services()
|
||||
setup_llm_caching()
|
||||
LangfuseInstance.update()
|
||||
yield
|
||||
teardown_services()
|
||||
|
||||
|
||||
def create_app():
|
||||
"""Create the FastAPI app and include the router."""
|
||||
|
||||
configure()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
origins = ["*"]
|
||||
|
||||
app.add_middleware(
|
||||
|
|
@ -34,19 +39,22 @@ def create_app():
|
|||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
@app.middleware("http")
|
||||
async def flatten_query_string_lists(request: Request, call_next):
|
||||
flattened: list[tuple[str, str]] = []
|
||||
for key, value in request.query_params.multi_items():
|
||||
flattened.extend((key, entry) for entry in value.split(","))
|
||||
|
||||
request.scope["query_string"] = urlencode(flattened, doseq=True).encode("utf-8")
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
app.include_router(router)
|
||||
|
||||
app.on_event("startup")(initialize_services)
|
||||
app.on_event("startup")(setup_llm_caching)
|
||||
app.on_event("startup")(LangfuseInstance.update)
|
||||
|
||||
app.on_event("shutdown")(teardown_services)
|
||||
app.on_event("shutdown")(LangfuseInstance.teardown)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
|
|
@ -78,9 +86,7 @@ def get_static_files_dir():
|
|||
return frontend_path / "frontend"
|
||||
|
||||
|
||||
def setup_app(
|
||||
static_files_dir: Optional[Path] = None, backend_only: bool = False
|
||||
) -> FastAPI:
|
||||
def setup_app(static_files_dir: Optional[Path] = None, backend_only: bool = False) -> FastAPI:
|
||||
"""Setup the FastAPI app."""
|
||||
# get the directory of the current file
|
||||
if not static_files_dir:
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
from typing import List, Union, TYPE_CHECKING
|
||||
from langflow.api.v1.callback import (
|
||||
AsyncStreamingLLMCallbackHandler,
|
||||
StreamingLLMCallbackHandler,
|
||||
)
|
||||
from langflow.processing.process import fix_memory_inputs, format_actions
|
||||
from loguru import logger
|
||||
from typing import TYPE_CHECKING, List, Union
|
||||
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from loguru import logger
|
||||
|
||||
from langflow.api.v1.callback import (AsyncStreamingLLMCallbackHandler,
|
||||
StreamingLLMCallbackHandler)
|
||||
from langflow.processing.process import fix_memory_inputs, format_actions
|
||||
from langflow.services.deps import get_plugins_service
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langfuse.callback import CallbackHandler # type: ignore
|
||||
|
|
@ -20,21 +21,23 @@ def setup_callbacks(sync, trace_id, **kwargs):
|
|||
else:
|
||||
callbacks.append(AsyncStreamingLLMCallbackHandler(**kwargs))
|
||||
|
||||
if langfuse_callback := get_langfuse_callback(trace_id=trace_id):
|
||||
logger.debug("Langfuse callback loaded")
|
||||
callbacks.append(langfuse_callback)
|
||||
plugin_service = get_plugins_service()
|
||||
plugin_callbacks = plugin_service.get_callbacks(_id=trace_id)
|
||||
if plugin_callbacks:
|
||||
callbacks.extend(plugin_callbacks)
|
||||
return callbacks
|
||||
|
||||
|
||||
def get_langfuse_callback(trace_id):
|
||||
from langflow.services.plugins.langfuse import LangfuseInstance
|
||||
from langfuse.callback import CreateTrace
|
||||
|
||||
from langflow.services.deps import get_plugins_service
|
||||
|
||||
logger.debug("Initializing langfuse callback")
|
||||
if langfuse := LangfuseInstance.get():
|
||||
if langfuse := get_plugins_service().get("langfuse"):
|
||||
logger.debug("Langfuse credentials found")
|
||||
try:
|
||||
trace = langfuse.trace(CreateTrace(id=trace_id))
|
||||
trace = langfuse.trace(CreateTrace(name="langflow-" + trace_id, id=trace_id))
|
||||
return trace.getNewHandler()
|
||||
except Exception as exc:
|
||||
logger.error(f"Error initializing langfuse callback: {exc}")
|
||||
|
|
@ -42,14 +45,12 @@ def get_langfuse_callback(trace_id):
|
|||
return None
|
||||
|
||||
|
||||
def flush_langfuse_callback_if_present(
|
||||
callbacks: List[Union[BaseCallbackHandler, "CallbackHandler"]]
|
||||
):
|
||||
def flush_langfuse_callback_if_present(callbacks: List[Union[BaseCallbackHandler, "CallbackHandler"]]):
|
||||
"""
|
||||
If langfuse callback is present, run callback.langfuse.flush()
|
||||
"""
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "langfuse"):
|
||||
if hasattr(callback, "langfuse") and hasattr(callback.langfuse, "flush"):
|
||||
callback.langfuse.flush()
|
||||
break
|
||||
|
||||
|
|
@ -86,15 +87,9 @@ async def get_result_and_steps(langchain_object, inputs: Union[dict, str], **kwa
|
|||
# if langfuse callback is present, run callback.langfuse.flush()
|
||||
flush_langfuse_callback_if_present(callbacks)
|
||||
|
||||
intermediate_steps = (
|
||||
output.get("intermediate_steps", []) if isinstance(output, dict) else []
|
||||
)
|
||||
intermediate_steps = output.get("intermediate_steps", []) if isinstance(output, dict) else []
|
||||
|
||||
result = (
|
||||
output.get(langchain_object.output_keys[0])
|
||||
if isinstance(output, dict)
|
||||
else output
|
||||
)
|
||||
result = output.get(langchain_object.output_keys[0]) if isinstance(output, dict) else output
|
||||
try:
|
||||
thought = format_actions(intermediate_steps) if intermediate_steps else ""
|
||||
except Exception as exc:
|
||||
|
|
@ -103,4 +98,4 @@ async def get_result_and_steps(langchain_object, inputs: Union[dict, str], **kwa
|
|||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
raise ValueError(f"Error: {str(exc)}") from exc
|
||||
return result, thought
|
||||
return result, thought, output
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue