Merge remote-tracking branch 'origin/dev' into chatImg

This commit is contained in:
anovazzi1 2024-04-12 13:53:34 -03:00
commit 601a9eff47
58 changed files with 1691 additions and 2539 deletions

View file

@ -9,12 +9,6 @@ import click
import httpx
import typer
from dotenv import load_dotenv
from langflow.main import setup_app
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service
from langflow.services.utils import initialize_services
from langflow.utils.logger import configure, logger
from langflow.utils.util import update_settings
from multiprocess import Process, cpu_count # type: ignore
from packaging import version as pkg_version
from rich import box
@ -23,6 +17,13 @@ from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from langflow.main import setup_app
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service
from langflow.services.utils import initialize_services
from langflow.utils.logger import configure, logger
from langflow.utils.util import update_settings
console = Console()
app = typer.Typer(no_args_is_help=True)
@ -151,17 +152,21 @@ def run(
# Define an env variable to know if we are just testing the server
if "pytest" in sys.modules:
return
if platform.system() in ["Windows"]:
# Run using uvicorn on MacOS and Windows
# Windows doesn't support gunicorn
# MacOS requires an env variable to be set to use gunicorn
run_on_windows(host, port, log_level, options, app)
else:
# Run using gunicorn on Linux
run_on_mac_or_linux(host, port, log_level, options, app)
if open_browser:
click.launch(f"http://{host}:{port}")
try:
if platform.system() in ["Windows"]:
# Run using uvicorn on MacOS and Windows
# Windows doesn't support gunicorn
# MacOS requires an env variable to be set to use gunicorn
process = run_on_windows(host, port, log_level, options, app)
else:
# Run using gunicorn on Linux
process = run_on_mac_or_linux(host, port, log_level, options, app)
if open_browser:
click.launch(f"http://{host}:{port}")
if process:
process.join()
except KeyboardInterrupt:
pass
def wait_for_server_ready(host, port):
@ -182,6 +187,7 @@ def run_on_mac_or_linux(host, port, log_level, options, app):
wait_for_server_ready(host, port)
print_banner(host, port)
return webapp_process
def run_on_windows(host, port, log_level, options, app):
@ -190,6 +196,7 @@ def run_on_windows(host, port, log_level, options, app):
"""
print_banner(host, port)
run_langflow(host, port, log_level, options, app)
return None
def is_port_in_use(port, host="localhost"):

View file

@ -0,0 +1,99 @@
"""Change datetime type
Revision ID: 79e675cb6752
Revises: e3bc869fa272
Create Date: 2024-04-11 19:23:10.697335
"""
from calendar import c
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
from sqlalchemy.engine.reflection import Inspector
# revision identifiers, used by Alembic.
revision: str = "79e675cb6752"
down_revision: Union[str, None] = "e3bc869fa272"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
conn = op.get_bind()
inspector = Inspector.from_engine(conn) # type: ignore
table_names = inspector.get_table_names()
# ### commands auto generated by Alembic - please adjust! ###
if "apikey" in table_names:
columns = inspector.get_columns("apikey")
created_at_column = next((column for column in columns if column["name"] == "created_at"), None)
if created_at_column is not None and created_at_column["type"] == postgresql.TIMESTAMP():
with op.batch_alter_table("apikey", schema=None) as batch_op:
batch_op.alter_column(
"created_at",
existing_type=postgresql.TIMESTAMP(),
type_=sa.DateTime(timezone=True),
existing_nullable=False,
)
if "variable" in table_names:
columns = inspector.get_columns("variable")
created_at_column = next((column for column in columns if column["name"] == "created_at"), None)
updated_at_column = next((column for column in columns if column["name"] == "updated_at"), None)
with op.batch_alter_table("variable", schema=None) as batch_op:
if created_at_column is not None and created_at_column["type"] == postgresql.TIMESTAMP():
batch_op.alter_column(
"created_at",
existing_type=postgresql.TIMESTAMP(),
type_=sa.DateTime(timezone=True),
existing_nullable=True,
)
if updated_at_column is not None and updated_at_column["type"] == postgresql.TIMESTAMP():
batch_op.alter_column(
"updated_at",
existing_type=postgresql.TIMESTAMP(),
type_=sa.DateTime(timezone=True),
existing_nullable=True,
)
# ### end Alembic commands ###
def downgrade() -> None:
conn = op.get_bind()
inspector = Inspector.from_engine(conn) # type: ignore
table_names = inspector.get_table_names()
# ### commands auto generated by Alembic - please adjust! ###
if "variable" in table_names:
columns = inspector.get_columns("variable")
created_at_column = next((column for column in columns if column["name"] == "created_at"), None)
updated_at_column = next((column for column in columns if column["name"] == "updated_at"), None)
with op.batch_alter_table("variable", schema=None) as batch_op:
if updated_at_column is not None and updated_at_column["type"] == sa.DateTime(timezone=True):
batch_op.alter_column(
"updated_at",
existing_type=sa.DateTime(timezone=True),
type_=postgresql.TIMESTAMP(),
existing_nullable=True,
)
if created_at_column is not None and created_at_column["type"] == sa.DateTime(timezone=True):
batch_op.alter_column(
"created_at",
existing_type=sa.DateTime(timezone=True),
type_=postgresql.TIMESTAMP(),
existing_nullable=True,
)
if "apikey" in table_names:
columns = inspector.get_columns("apikey")
created_at_column = next((column for column in columns if column["name"] == "created_at"), None)
if created_at_column is not None and created_at_column["type"] == sa.DateTime(timezone=True):
with op.batch_alter_table("apikey", schema=None) as batch_op:
batch_op.alter_column(
"created_at",
existing_type=sa.DateTime(timezone=True),
type_=postgresql.TIMESTAMP(),
existing_nullable=False,
)
# ### end Alembic commands ###

View file

@ -0,0 +1,68 @@
"""Fix nullable
Revision ID: e3bc869fa272
Revises: 1a110b568907
Create Date: 2024-04-10 19:17:22.820455
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.engine.reflection import Inspector
# revision identifiers, used by Alembic.
revision: str = "e3bc869fa272"
down_revision: Union[str, None] = "1a110b568907"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
conn = op.get_bind()
inspector = Inspector.from_engine(conn) # type: ignore
table_names = inspector.get_table_names()
# ### commands auto generated by Alembic - please adjust! ###
if "variable" not in table_names:
return
columns = [column for column in inspector.get_columns("variable")]
column_names = [column["name"] for column in columns]
with op.batch_alter_table("variable", schema=None) as batch_op:
if "created_at" in column_names:
created_at_colunmn = next(column for column in columns if column["name"] == "created_at")
if created_at_colunmn["nullable"] is False:
batch_op.alter_column(
"created_at",
existing_type=sa.TIMESTAMP(timezone=True),
nullable=True,
# existing_server_default expects str | bool | Identity | Computed | None
# sa.text("now()") is not a valid value for existing_server_default
existing_server_default=False,
)
# ### end Alembic commands ###
def downgrade() -> None:
conn = op.get_bind()
inspector = Inspector.from_engine(conn) # type: ignore
table_names = inspector.get_table_names()
# ### commands auto generated by Alembic - please adjust! ###
if "variable" not in table_names:
return
columns = [column for column in inspector.get_columns("variable")]
column_names = [column["name"] for column in columns]
with op.batch_alter_table("variable", schema=None) as batch_op:
if "created_at" in column_names:
created_at_colunmn = next(column for column in columns if column["name"] == "created_at")
if created_at_colunmn["nullable"] is True:
batch_op.alter_column(
"created_at",
existing_type=sa.TIMESTAMP(timezone=True),
nullable=False,
existing_server_default=False,
)
# ### end Alembic commands ###

View file

@ -4,7 +4,7 @@ from fastapi import APIRouter, HTTPException
from loguru import logger
from langflow.api.v1.base import Code, CodeValidationResponse, PromptValidationResponse, ValidatePromptRequest
from langflow.base.prompts.utils import (
from langflow.base.prompts.api_utils import (
add_new_variables_to_template,
get_old_custom_fields,
remove_old_variables_from_template,

View file

@ -0,0 +1,83 @@
from fastapi import HTTPException
from langchain.prompts import PromptTemplate
from loguru import logger
from langflow.api.v1.base import INVALID_NAMES, check_input_variables
from langflow.interface.utils import extract_input_variables_from_prompt
from langflow.template.field.prompt import DefaultPromptField
def validate_prompt(prompt_template: str, silent_errors: bool = False) -> list[str]:
input_variables = extract_input_variables_from_prompt(prompt_template)
# Check if there are invalid characters in the input_variables
input_variables = check_input_variables(input_variables)
if any(var in INVALID_NAMES for var in input_variables):
raise ValueError(f"Invalid input variables. None of the variables can be named {', '.join(input_variables)}. ")
try:
PromptTemplate(template=prompt_template, input_variables=input_variables)
except Exception as exc:
logger.error(f"Invalid prompt: {exc}")
if not silent_errors:
raise ValueError(f"Invalid prompt: {exc}") from exc
return input_variables
def get_old_custom_fields(custom_fields, name):
try:
if len(custom_fields) == 1 and name == "":
# If there is only one custom field and the name is empty string
# then we are dealing with the first prompt request after the node was created
name = list(custom_fields.keys())[0]
old_custom_fields = custom_fields[name]
if not old_custom_fields:
old_custom_fields = []
old_custom_fields = old_custom_fields.copy()
except KeyError:
old_custom_fields = []
custom_fields[name] = []
return old_custom_fields
def add_new_variables_to_template(input_variables, custom_fields, template, name):
for variable in input_variables:
try:
template_field = DefaultPromptField(name=variable, display_name=variable)
if variable in template:
# Set the new field with the old value
template_field.value = template[variable]["value"]
template[variable] = template_field.to_dict()
# Check if variable is not already in the list before appending
if variable not in custom_fields[name]:
custom_fields[name].append(variable)
except Exception as exc:
logger.exception(exc)
raise HTTPException(status_code=500, detail=str(exc)) from exc
def remove_old_variables_from_template(old_custom_fields, input_variables, custom_fields, template, name):
for variable in old_custom_fields:
if variable not in input_variables:
try:
# Remove the variable from custom_fields associated with the given name
if variable in custom_fields[name]:
custom_fields[name].remove(variable)
# Remove the variable from the template
template.pop(variable, None)
except Exception as exc:
logger.exception(exc)
raise HTTPException(status_code=500, detail=str(exc)) from exc
def update_input_variables_field(input_variables, template):
if "input_variables" in template:
template["input_variables"]["value"] = input_variables

View file

@ -1,12 +1,19 @@
from fastapi import HTTPException
from langchain.prompts import PromptTemplate
from langchain_core.documents import Document
from loguru import logger
from langflow.api.v1.base import INVALID_NAMES, check_input_variables
from langflow.interface.utils import extract_input_variables_from_prompt
from langflow.schema import Record
from langflow.template.field.prompt import DefaultPromptField
def record_to_string(record: Record) -> str:
"""
Convert a record to a string.
Args:
record (Record): The record to convert.
Returns:
str: The record as a string.
"""
return record.get_text()
def dict_values_to_string(d: dict) -> dict:
@ -35,19 +42,6 @@ def dict_values_to_string(d: dict) -> dict:
return d
def record_to_string(record: Record) -> str:
"""
Convert a record to a string.
Args:
record (Record): The record to convert.
Returns:
str: The record as a string.
"""
return record.get_text()
def document_to_string(document: Document) -> str:
"""
Convert a document to a string.
@ -59,79 +53,3 @@ def document_to_string(document: Document) -> str:
str: The document as a string.
"""
return document.page_content
def validate_prompt(prompt_template: str, silent_errors: bool = False) -> list[str]:
input_variables = extract_input_variables_from_prompt(prompt_template)
# Check if there are invalid characters in the input_variables
input_variables = check_input_variables(input_variables)
if any(var in INVALID_NAMES for var in input_variables):
raise ValueError(f"Invalid input variables. None of the variables can be named {', '.join(input_variables)}. ")
try:
PromptTemplate(template=prompt_template, input_variables=input_variables)
except Exception as exc:
logger.error(f"Invalid prompt: {exc}")
if not silent_errors:
raise ValueError(f"Invalid prompt: {exc}") from exc
return input_variables
def get_old_custom_fields(custom_fields, name):
try:
if len(custom_fields) == 1 and name == "":
# If there is only one custom field and the name is empty string
# then we are dealing with the first prompt request after the node was created
name = list(custom_fields.keys())[0]
old_custom_fields = custom_fields[name]
if not old_custom_fields:
old_custom_fields = []
old_custom_fields = old_custom_fields.copy()
except KeyError:
old_custom_fields = []
custom_fields[name] = []
return old_custom_fields
def add_new_variables_to_template(input_variables, custom_fields, template, name):
for variable in input_variables:
try:
template_field = DefaultPromptField(name=variable, display_name=variable)
if variable in template:
# Set the new field with the old value
template_field.value = template[variable]["value"]
template[variable] = template_field.to_dict()
# Check if variable is not already in the list before appending
if variable not in custom_fields[name]:
custom_fields[name].append(variable)
except Exception as exc:
logger.exception(exc)
raise HTTPException(status_code=500, detail=str(exc)) from exc
def remove_old_variables_from_template(old_custom_fields, input_variables, custom_fields, template, name):
for variable in old_custom_fields:
if variable not in input_variables:
try:
# Remove the variable from custom_fields associated with the given name
if variable in custom_fields[name]:
custom_fields[name].remove(variable)
# Remove the variable from the template
template.pop(variable, None)
except Exception as exc:
logger.exception(exc)
raise HTTPException(status_code=500, detail=str(exc)) from exc
def update_input_variables_field(input_variables, template):
if "input_variables" in template:
template["input_variables"]["value"] = input_variables

View file

@ -1,9 +1,10 @@
from typing import Optional
from langchain.chains import LLMChain
from langchain.chains.llm import LLMChain
from langflow.field_typing import BaseLanguageModel, BaseMemory, BasePromptTemplate, Text
from langflow.field_typing import BaseLanguageModel, BaseMemory, Text
from langflow.interface.custom.custom_component import CustomComponent
from langchain_core.prompts import PromptTemplate
class LLMChainComponent(CustomComponent):
@ -19,10 +20,11 @@ class LLMChainComponent(CustomComponent):
def build(
self,
prompt: BasePromptTemplate,
template: Text,
llm: BaseLanguageModel,
memory: Optional[BaseMemory] = None,
) -> Text:
prompt = PromptTemplate.from_template(template)
runnable = LLMChain(prompt=prompt, llm=llm, memory=memory)
result_dict = runnable.invoke({})
output_key = runnable.output_key

View file

@ -71,6 +71,8 @@ class PineconeSearchComponent(PineconeComponent, LCVectorStoreComponent):
)
if not vector_store:
raise ValueError("Failed to load the Pinecone index.")
if search_kwargs is None:
search_kwargs = {}
return self.search_with_vector_store(
vector_store=vector_store,

View file

@ -86,12 +86,13 @@ class QdrantSearchComponent(QdrantComponent, LCVectorStoreComponent):
port=port,
prefer_grpc=prefer_grpc,
prefix=prefix,
search_kwargs=search_kwargs,
timeout=timeout,
url=url,
)
if not vector_store:
raise ValueError("Failed to load the Qdrant index.")
if search_kwargs is None:
search_kwargs = {}
return self.search_with_vector_store(
vector_store=vector_store,

View file

@ -105,7 +105,7 @@ class AstraDBVectorStoreComponent(CustomComponent):
bulk_insert_batch_concurrency: Optional[int] = None,
bulk_insert_overwrite_concurrency: Optional[int] = None,
bulk_delete_concurrency: Optional[int] = None,
setup_mode: str = "Async",
setup_mode: str = "Sync",
pre_delete_collection: bool = False,
metadata_indexing_include: Optional[List[str]] = None,
metadata_indexing_exclude: Optional[List[str]] = None,

View file

@ -1,7 +1,7 @@
from typing import List, Optional
from langchain_community.vectorstores.mongodb_atlas import MongoDBAtlasVectorSearch
from langflow.field_typing import Embeddings, NestedDict
from langflow.field_typing import Embeddings
from langflow.interface.custom.custom_component import CustomComponent
from langflow.schema.schema import Record

View file

@ -4,7 +4,7 @@ from langchain.schema import BaseRetriever
from langchain_community.vectorstores import VectorStore
from langchain_community.vectorstores.qdrant import Qdrant
from langflow.field_typing import Embeddings, NestedDict
from langflow.field_typing import Embeddings
from langflow.interface.custom.custom_component import CustomComponent
from langflow.schema.schema import Record

View file

@ -5,7 +5,7 @@ from langchain_community.vectorstores import VectorStore
from langchain_community.vectorstores.supabase import SupabaseVectorStore
from supabase.client import Client, create_client
from langflow.field_typing import Embeddings, NestedDict
from langflow.field_typing import Embeddings
from langflow.interface.custom.custom_component import CustomComponent
from langflow.schema.schema import Record

View file

@ -918,7 +918,7 @@ class Graph:
return ChatVertex
elif node_name in ["ShouldRunNext"]:
return RoutingVertex
elif node_name in ["SharedState", "Notify", "GetNotified"]:
elif node_name in ["SharedState", "Notify", "Listen"]:
return StateVertex
elif node_base_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP:
return lazy_load_vertex_dict.VERTEX_TYPE_MAP[node_base_type]
@ -1130,6 +1130,34 @@ class Graph:
return vertices_layers
def sort_layer_by_dependency(self, vertices_layers: List[List[str]]) -> List[List[str]]:
"""Sorts the vertices in each layer by dependency, ensuring no vertex depends on a subsequent vertex."""
sorted_layers = []
for layer in vertices_layers:
sorted_layer = self._sort_single_layer_by_dependency(layer)
sorted_layers.append(sorted_layer)
return sorted_layers
def _sort_single_layer_by_dependency(self, layer: List[str]) -> List[str]:
"""Sorts a single layer by dependency using a stable sorting method."""
# Build a map of each vertex to its index in the layer for quick lookup.
index_map = {vertex: index for index, vertex in enumerate(layer)}
# Create a sorted copy of the layer based on dependency order.
sorted_layer = sorted(layer, key=lambda vertex: self._max_dependency_index(vertex, index_map), reverse=True)
return sorted_layer
def _max_dependency_index(self, vertex_id: str, index_map: Dict[str, int]) -> int:
"""Finds the highest index a given vertex's dependencies occupy in the same layer."""
vertex = self.get_vertex(vertex_id)
max_index = -1
for successor in vertex.successors: # Assuming vertex.successors is a list of successor vertex identifiers.
if successor.id in index_map:
max_index = max(max_index, index_map[successor.id])
return max_index
def sort_vertices(
self,
stop_component_id: Optional[str] = None,
@ -1151,6 +1179,9 @@ class Graph:
vertices_layers = self.layered_topological_sort(vertices)
vertices_layers = self.sort_by_avg_build_time(vertices_layers)
# vertices_layers = self.sort_chat_inputs_first(vertices_layers)
# Now we should sort each layer in a way that we make sure
# vertex V does not depend on vertex V+1
vertices_layers = self.sort_layer_by_dependency(vertices_layers)
self.increment_run_count()
self._sorted_vertices_layers = vertices_layers
first_layer = vertices_layers[0]

File diff suppressed because one or more lines are too long

View file

@ -5,17 +5,6 @@ from loguru import logger
from langflow.graph import Graph
async def build_sorted_vertices(data_graph, flow_id: str) -> Tuple[Graph, Dict]:
"""
Build langchain object from data_graph.
"""
logger.debug("Building langchain object")
graph = Graph.from_payload(data_graph, flow_id=flow_id)
return graph, {}
def get_memory_key(langchain_object):
"""
Given a LangChain object, this function retrieves the current memory key from the object's memory attribute.

View file

@ -10,6 +10,7 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from loguru import logger
from rich import print as rprint
from langflow.api import router
from langflow.initial_setup.setup import create_or_update_starter_projects
@ -28,6 +29,8 @@ def get_lifespan(fix_migration=False, socketio_server=None):
LangfuseInstance.update()
create_or_update_starter_projects()
yield
# Shutdown message
rprint("[bold red]Shutting down Langflow...[/bold red]")
teardown_services()
return lifespan

View file

@ -1,4 +1,6 @@
import asyncio
import logging
import signal
from gunicorn import glogging # type: ignore
from gunicorn.app.base import BaseApplication # type: ignore
@ -10,6 +12,21 @@ from langflow.utils.logger import InterceptHandler # type: ignore
class LangflowUvicornWorker(UvicornWorker):
CONFIG_KWARGS = {"loop": "asyncio"}
def _install_sigint_handler(self) -> None:
"""Install a SIGQUIT handler on workers.
- https://github.com/encode/uvicorn/issues/1116
- https://github.com/benoitc/gunicorn/issues/2604
"""
loop = asyncio.get_running_loop()
loop.add_signal_handler(signal.SIGINT, self.handle_exit, signal.SIGINT, None)
async def _serve(self) -> None:
# We do this to not log the "Worker (pid:XXXXX) was sent SIGINT"
self._install_sigint_handler()
await super()._serve()
class Logger(glogging.Logger):
"""Implements and overrides the gunicorn logging interface.

View file

@ -211,7 +211,7 @@ def create_super_user(
return super_user
def create_user_longterm_token(db: Session = Depends(get_session)) -> dict:
def create_user_longterm_token(db: Session = Depends(get_session)) -> tuple[UUID, dict]:
settings_service = get_settings_service()
username = settings_service.auth_settings.SUPERUSER
password = settings_service.auth_settings.SUPERUSER_PASSWORD

View file

@ -1,8 +1,8 @@
from datetime import datetime
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Optional
from uuid import UUID, uuid4
from pydantic import validator
from pydantic import field_validator, validator
from sqlmodel import Field, Relationship, SQLModel, Column, func, DateTime
if TYPE_CHECKING:
@ -11,7 +11,9 @@ if TYPE_CHECKING:
class ApiKeyBase(SQLModel):
name: Optional[str] = Field(index=True, nullable=True, default=None)
created_at: datetime = Field(sa_column=Column(DateTime(timezone=True), server_default=func.now(), nullable=False))
created_at: datetime = Field(
default=None, sa_column=Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
)
last_used_at: Optional[datetime] = Field(default=None, nullable=True)
total_uses: int = Field(default=0)
is_active: bool = Field(default=True)
@ -33,6 +35,10 @@ class ApiKeyCreate(ApiKeyBase):
api_key: Optional[str] = None
user_id: Optional[UUID] = None
@field_validator("created_at", mode="before")
def set_created_at(cls, v):
return v or datetime.now(timezone.utc)
class UnmaskedApiKeyRead(ApiKeyBase):
id: UUID

View file

@ -26,11 +26,13 @@ class Variable(VariableBase, table=True):
)
# name is unique per user
created_at: datetime = Field(
sa_column=Column(DateTime(timezone=True), server_default=func.now(), nullable=False),
default=None,
sa_column=Column(DateTime(timezone=True), server_default=func.now(), nullable=True),
description="Creation time of the variable",
)
updated_at: Optional[datetime] = Field(
sa_column=Column(DateTime(timezone=True)),
default=None,
sa_column=Column(DateTime(timezone=True), nullable=True),
description="Last update time of the variable",
)
# foreign key to user table
@ -39,7 +41,9 @@ class Variable(VariableBase, table=True):
class VariableCreate(VariableBase):
type: Optional[str] = Field(None, description="Type of the variable")
created_at: Optional[datetime] = Field(default_factory=utc_now, description="Creation time of the variable")
updated_at: Optional[datetime] = Field(default_factory=utc_now, description="Creation time of the variable")
class VariableRead(SQLModel):

View file

@ -1,21 +1,22 @@
from datetime import datetime
import time
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING
import sqlalchemy as sa
from alembic import command, util
from alembic.config import Config
from loguru import logger
from sqlalchemy import inspect
from sqlalchemy.exc import OperationalError
from sqlmodel import Session, SQLModel, create_engine, select, text
from langflow.services.base import Service
from langflow.services.database import models # noqa
from langflow.services.database.models.user.crud import get_user_by_username
from langflow.services.database.utils import Result, TableResults
from langflow.services.deps import get_settings_service
from langflow.services.utils import teardown_superuser
from loguru import logger
from sqlalchemy import inspect
from sqlalchemy.exc import OperationalError
from sqlmodel import Session, SQLModel, create_engine, select, text
if TYPE_CHECKING:
from sqlalchemy.engine import Engine
@ -195,7 +196,10 @@ class DatabaseService(Service):
# This method is used for testing purposes only
# We will check that all models are in the database
# and that the database is up to date with all columns
sql_models = [models.Flow, models.User, models.ApiKey]
# get all models that are subclasses of SQLModel
sql_models = [
model for model in models.__dict__.values() if isinstance(model, type) and issubclass(model, SQLModel)
]
return [TableResults(sql_model.__tablename__, self.check_table(sql_model)) for sql_model in sql_models]
def check_table(self, model):

View file

@ -1,7 +1,7 @@
from contextlib import contextmanager
from typing import TYPE_CHECKING, Generator
from langflow.services import ServiceType, service_manager
from langflow.services.schema import ServiceType
if TYPE_CHECKING:
from sqlmodel import Session
@ -21,7 +21,7 @@ if TYPE_CHECKING:
from langflow.services.variable.service import VariableService
def get_service(service_type: ServiceType):
def get_service(service_type: ServiceType, default=None):
"""
Retrieves the service instance for the given service type.
@ -32,7 +32,13 @@ def get_service(service_type: ServiceType):
Any: The service instance.
"""
return service_manager.get(service_type) # type: ignore
from langflow.services.manager import service_manager
if not service_manager.factories:
#! This is a workaround to ensure that the service manager is initialized
#! Not optimal, but it works for now
service_manager.register_factories()
return service_manager.get(service_type, default) # type: ignore
def get_state_service() -> "StateService":
@ -42,7 +48,9 @@ def get_state_service() -> "StateService":
Returns:
The StateService instance.
"""
return service_manager.get(ServiceType.STATE_SERVICE) # type: ignore
from langflow.services.state.factory import StateServiceFactory
return get_service(ServiceType.STATE_SERVICE, StateServiceFactory()) # type: ignore
def get_socket_service() -> "SocketIOService":
@ -52,7 +60,7 @@ def get_socket_service() -> "SocketIOService":
Returns:
SocketIOService: The SocketIOService instance.
"""
return service_manager.get(ServiceType.SOCKETIO_SERVICE) # type: ignore
return get_service(ServiceType.SOCKETIO_SERVICE) # type: ignore
def get_storage_service() -> "StorageService":
@ -62,7 +70,9 @@ def get_storage_service() -> "StorageService":
Returns:
The storage service instance.
"""
return service_manager.get(ServiceType.STORAGE_SERVICE) # type: ignore
from langflow.services.storage.factory import StorageServiceFactory
return get_service(ServiceType.STORAGE_SERVICE, default=StorageServiceFactory()) # type: ignore
def get_variable_service() -> "VariableService":
@ -73,7 +83,9 @@ def get_variable_service() -> "VariableService":
The VariableService instance.
"""
return service_manager.get(ServiceType.VARIABLE_SERVICE) # type: ignore
from langflow.services.variable.factory import VariableServiceFactory
return get_service(ServiceType.VARIABLE_SERVICE, VariableServiceFactory()) # type: ignore
def get_plugins_service() -> "PluginService":
@ -83,7 +95,7 @@ def get_plugins_service() -> "PluginService":
Returns:
PluginService: The PluginService instance.
"""
return service_manager.get(ServiceType.PLUGIN_SERVICE) # type: ignore
return get_service(ServiceType.PLUGIN_SERVICE) # type: ignore
def get_settings_service() -> "SettingsService":
@ -98,14 +110,9 @@ def get_settings_service() -> "SettingsService":
Raises:
ValueError: If the service cannot be retrieved or initialized.
"""
try:
return service_manager.get(ServiceType.SETTINGS_SERVICE) # type: ignore
except ValueError:
# initialize settings service
from langflow.services.manager import initialize_settings_service
from langflow.services.settings.factory import SettingsServiceFactory
initialize_settings_service()
return service_manager.get(ServiceType.SETTINGS_SERVICE) # type: ignore
return get_service(ServiceType.SETTINGS_SERVICE, SettingsServiceFactory()) # type: ignore
def get_db_service() -> "DatabaseService":
@ -116,7 +123,9 @@ def get_db_service() -> "DatabaseService":
The DatabaseService instance.
"""
return service_manager.get(ServiceType.DATABASE_SERVICE) # type: ignore
from langflow.services.database.factory import DatabaseServiceFactory
return get_service(ServiceType.DATABASE_SERVICE, DatabaseServiceFactory()) # type: ignore
def get_session() -> Generator["Session", None, None]:
@ -165,7 +174,9 @@ def get_cache_service() -> "CacheService":
Returns:
The cache service instance.
"""
return service_manager.get(ServiceType.CACHE_SERVICE) # type: ignore
from langflow.services.cache.factory import CacheServiceFactory
return get_service(ServiceType.CACHE_SERVICE, CacheServiceFactory()) # type: ignore
def get_session_service() -> "SessionService":
@ -175,7 +186,9 @@ def get_session_service() -> "SessionService":
Returns:
The session service instance.
"""
return service_manager.get(ServiceType.SESSION_SERVICE) # type: ignore
from langflow.services.session.factory import SessionServiceFactory
return get_service(ServiceType.SESSION_SERVICE, SessionServiceFactory()) # type: ignore
def get_monitor_service() -> "MonitorService":
@ -185,7 +198,9 @@ def get_monitor_service() -> "MonitorService":
Returns:
MonitorService: The MonitorService instance.
"""
return service_manager.get(ServiceType.MONITOR_SERVICE) # type: ignore
from langflow.services.monitor.factory import MonitorServiceFactory
return get_service(ServiceType.MONITOR_SERVICE, MonitorServiceFactory()) # type: ignore
def get_task_service() -> "TaskService":
@ -196,7 +211,9 @@ def get_task_service() -> "TaskService":
The TaskService instance.
"""
return service_manager.get(ServiceType.TASK_SERVICE) # type: ignore
from langflow.services.task.factory import TaskServiceFactory
return get_service(ServiceType.TASK_SERVICE, TaskServiceFactory()) # type: ignore
def get_chat_service() -> "ChatService":
@ -206,7 +223,7 @@ def get_chat_service() -> "ChatService":
Returns:
ChatService: The chat service instance.
"""
return service_manager.get(ServiceType.CHAT_SERVICE) # type: ignore
return get_service(ServiceType.CHAT_SERVICE) # type: ignore
def get_store_service() -> "StoreService":
@ -216,4 +233,4 @@ def get_store_service() -> "StoreService":
Returns:
StoreService: The StoreService instance.
"""
return service_manager.get(ServiceType.STORE_SERVICE) # type: ignore
return get_service(ServiceType.STORE_SERVICE) # type: ignore

View file

@ -1,13 +1,20 @@
from typing import TYPE_CHECKING, Dict
import importlib
import inspect
from typing import TYPE_CHECKING, Dict, Optional
from loguru import logger
if TYPE_CHECKING:
from langflow.services.base import Service
from langflow.services.factory import ServiceFactory
from langflow.services.schema import ServiceType
class NoFactoryRegisteredError(Exception):
pass
class ServiceManager:
"""
Manages the creation of different services.
@ -16,6 +23,15 @@ class ServiceManager:
def __init__(self):
self.services: Dict[str, "Service"] = {}
self.factories = {}
self.register_factories()
def register_factories(self):
for factory in self.get_factories():
try:
self.register_factory(factory)
except Exception as exc:
logger.exception(exc)
logger.error(f"Error initializing {factory}: {exc}")
def register_factory(
self,
@ -28,24 +44,28 @@ class ServiceManager:
service_name = service_factory.service_class.name
self.factories[service_name] = service_factory
def get(self, service_name: "ServiceType") -> "Service":
def get(self, service_name: "ServiceType", default: Optional["ServiceFactory"] = None) -> "Service":
"""
Get (or create) a service by its name.
"""
if service_name not in self.services:
self._create_service(service_name)
self._create_service(service_name, default)
return self.services[service_name]
def _create_service(self, service_name: "ServiceType"):
def _create_service(self, service_name: "ServiceType", default: Optional["ServiceFactory"] = None):
"""
Create a new service given its name, handling dependencies.
"""
logger.debug(f"Create service {service_name}")
self._validate_service_creation(service_name)
self._validate_service_creation(service_name, default)
# Create dependencies first
factory = self.factories.get(service_name)
if factory is None and default is not None:
self.register_factory(default)
factory = default
for dependency in factory.dependencies:
if dependency not in self.services:
self._create_service(dependency)
@ -57,12 +77,12 @@ class ServiceManager:
self.services[service_name] = self.factories[service_name].create(**dependent_services)
self.services[service_name].set_ready()
def _validate_service_creation(self, service_name: "ServiceType"):
def _validate_service_creation(self, service_name: "ServiceType", default: Optional["ServiceFactory"] = None):
"""
Validate whether the service can be created.
"""
if service_name not in self.factories:
raise ValueError(f"No factory registered for the service class '{service_name.name}'")
if service_name not in self.factories and default is None:
raise NoFactoryRegisteredError(f"No factory registered for the service class '{service_name.name}'")
def update(self, service_name: "ServiceType"):
"""
@ -88,6 +108,34 @@ class ServiceManager:
self.services = {}
self.factories = {}
@staticmethod
def get_factories():
from langflow.services.factory import ServiceFactory
from langflow.services.schema import ServiceType
service_names = [ServiceType(service_type).value.replace("_service", "") for service_type in ServiceType]
base_module = "langflow.services"
factories = []
for name in service_names:
try:
module_name = f"{base_module}.{name}.factory"
module = importlib.import_module(module_name)
# Find all classes in the module that are subclasses of ServiceFactory
for name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, ServiceFactory) and obj is not ServiceFactory:
factories.append(obj())
break
except Exception as exc:
logger.exception(exc)
raise RuntimeError(
f"Could not initialize services. Please check your settings. Error in {name}."
) from exc
return factories
service_manager = ServiceManager()
@ -106,9 +154,7 @@ def initialize_session_service():
Initialize the session manager.
"""
from langflow.services.cache import factory as cache_factory
from langflow.services.session import (
factory as session_service_factory,
) # type: ignore
from langflow.services.session import factory as session_service_factory # type: ignore
initialize_settings_service()

View file

@ -19,5 +19,5 @@ class ServiceType(str, Enum):
VARIABLE_SERVICE = "variable_service"
STORAGE_SERVICE = "storage_service"
MONITOR_SERVICE = "monitor_service"
SOCKETIO_SERVICE = "socket_service"
# SOCKETIO_SERVICE = "socket_service"
STATE_SERVICE = "state_service"

View file

@ -1,6 +1,5 @@
from typing import Coroutine, Optional
from langflow.interface.run import build_sorted_vertices
from langflow.services.base import Service
from langflow.services.cache.base import CacheService
from langflow.services.session.utils import compute_dict_hash, session_id_generator
@ -25,8 +24,10 @@ class SessionService(Service):
if data_graph is None:
return (None, None)
# If not cached, build the graph and cache it
graph, artifacts = await build_sorted_vertices(data_graph, flow_id)
from langflow.graph.graph.base import Graph
graph = Graph.from_payload(data_graph, flow_id=flow_id)
artifacts: dict = {}
await self.cache_service.set(key, (graph, artifacts))
return graph, artifacts

View file

@ -7,12 +7,13 @@ from typing import Any, List, Optional, Tuple, Type
import orjson
import yaml
from langflow.services.settings.constants import VARIABLES_TO_GET_FROM_ENVIRONMENT
from loguru import logger
from pydantic import field_validator, validator
from pydantic.fields import FieldInfo
from pydantic_settings import BaseSettings, EnvSettingsSource, PydanticBaseSettingsSource, SettingsConfigDict
from langflow.services.settings.constants import VARIABLES_TO_GET_FROM_ENVIRONMENT
# BASE_COMPONENTS_PATH = str(Path(__file__).parent / "components")
BASE_COMPONENTS_PATH = str(Path(__file__).parent.parent.parent / "components")
@ -27,9 +28,16 @@ def is_list_of_any(field: FieldInfo) -> bool:
Returns:
bool: True if the field is a list or a list of any type, False otherwise.
"""
if field.annotation is None:
return False
try:
if hasattr(field.annotation, "__args__"):
union_args = field.annotation.__args__
else:
union_args = []
return field.annotation.__origin__ == list or any(
arg.__origin__ == list for arg in field.annotation.__args__ if hasattr(arg, "__origin__")
arg.__origin__ == list for arg in union_args if hasattr(arg, "__origin__")
)
except AttributeError:
return False
@ -92,9 +100,9 @@ class Settings(BaseSettings):
STORE: Optional[bool] = True
STORE_URL: Optional[str] = "https://api.langflow.store"
DOWNLOAD_WEBHOOK_URL: Optional[
str
] = "https://api.langflow.store/flows/trigger/ec611a61-8460-4438-b187-a4f65e5559d4"
DOWNLOAD_WEBHOOK_URL: Optional[str] = (
"https://api.langflow.store/flows/trigger/ec611a61-8460-4438-b187-a4f65e5559d4"
)
LIKE_WEBHOOK_URL: Optional[str] = "https://api.langflow.store/flows/trigger/64275852-ec00-45c1-984e-3bff814732da"
STORAGE_TYPE: str = "local"
@ -143,21 +151,50 @@ class Settings(BaseSettings):
# if there is a database in that location
if not values["CONFIG_DIR"]:
raise ValueError("CONFIG_DIR not set, please set it or provide a DATABASE_URL")
from langflow.version import is_pre_release # type: ignore
new_path = f"{values['CONFIG_DIR']}/langflow.db"
if Path("./langflow.db").exists():
pre_db_file_name = "langflow-pre.db"
db_file_name = "langflow.db"
new_pre_path = f"{values['CONFIG_DIR']}/{pre_db_file_name}"
new_path = f"{values['CONFIG_DIR']}/{db_file_name}"
final_path = None
if is_pre_release:
if Path(new_pre_path).exists():
final_path = new_pre_path
elif Path(new_path).exists():
# We need to copy the current db to the new location
logger.debug("Copying existing database to new location")
copy2(new_path, new_pre_path)
logger.debug(f"Copied existing database to {new_pre_path}")
elif Path(f"./{db_file_name}").exists():
logger.debug("Copying existing database to new location")
copy2(f"./{db_file_name}", new_pre_path)
logger.debug(f"Copied existing database to {new_pre_path}")
else:
logger.debug(f"Database already exists at {new_pre_path}, using it")
final_path = new_pre_path
else:
if Path(new_path).exists():
logger.debug(f"Database already exists at {new_path}, using it")
else:
final_path = new_path
elif Path("./{db_file_name}").exists():
try:
logger.debug("Copying existing database to new location")
copy2("./langflow.db", new_path)
copy2("./{db_file_name}", new_path)
logger.debug(f"Copied existing database to {new_path}")
except Exception:
logger.error("Failed to copy database, using default path")
new_path = "./langflow.db"
new_path = "./{db_file_name}"
else:
final_path = new_path
value = f"sqlite:///{new_path}"
if final_path is None:
if is_pre_release:
final_path = new_pre_path
else:
final_path = new_path
value = f"sqlite:///{final_path}"
return value

View file

@ -1,41 +1,13 @@
import importlib
import inspect
from loguru import logger
from sqlmodel import Session, select
from langflow.services.auth.utils import create_super_user, verify_password
from langflow.services.cache.factory import CacheServiceFactory
from langflow.services.database.utils import initialize_database
from langflow.services.factory import ServiceFactory
from langflow.services.manager import service_manager
from langflow.services.schema import ServiceType
from langflow.services.settings.constants import DEFAULT_SUPERUSER, DEFAULT_SUPERUSER_PASSWORD
from langflow.services.socket.utils import set_socketio_server
from .deps import get_db_service, get_session, get_settings_service
def get_factories():
service_names = [ServiceType(service_type).value.replace("_service", "") for service_type in ServiceType]
base_module = "langflow.services"
factories = []
for name in service_names:
try:
module_name = f"{base_module}.{name}.factory"
module = importlib.import_module(module_name)
# Find all classes in the module that are subclasses of ServiceFactory
for name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, ServiceFactory) and obj is not ServiceFactory:
factories.append(obj())
break
except Exception as exc:
logger.exception(exc)
raise RuntimeError(f"Could not initialize services. Please check your settings. Error in {name}.") from exc
return factories
from .deps import get_db_service, get_service, get_session, get_settings_service
def get_or_create_super_user(session: Session, username, password, is_default):
@ -145,6 +117,8 @@ def teardown_services():
except Exception as exc:
logger.exception(exc)
try:
from langflow.services.manager import service_manager
service_manager.teardown()
except Exception as exc:
logger.exception(exc)
@ -156,7 +130,7 @@ def initialize_settings_service():
"""
from langflow.services.settings import factory as settings_factory
service_manager.register_factory(settings_factory.SettingsServiceFactory())
get_service(ServiceType.SETTINGS_SERVICE, settings_factory.SettingsServiceFactory())
def initialize_session_service():
@ -168,11 +142,13 @@ def initialize_session_service():
initialize_settings_service()
service_manager.register_factory(
get_service(
ServiceType.CACHE_SERVICE,
cache_factory.CacheServiceFactory(),
)
service_manager.register_factory(
get_service(
ServiceType.SESSION_SERVICE,
session_service_factory.SessionServiceFactory(),
)
@ -181,27 +157,17 @@ def initialize_services(fix_migration: bool = False, socketio_server=None):
"""
Initialize all the services needed.
"""
for factory in get_factories():
try:
service_manager.register_factory(factory)
except Exception as exc:
logger.exception(exc)
logger.error(f"Error initializing {factory}: {exc}")
# Test cache connection
service_manager.get(ServiceType.CACHE_SERVICE)
get_service(ServiceType.CACHE_SERVICE, default=CacheServiceFactory())
# Setup the superuser
try:
initialize_database(fix_migration=fix_migration)
except Exception as exc:
logger.error(exc)
raise exc
setup_superuser(service_manager.get(ServiceType.SETTINGS_SERVICE), next(get_session()))
setup_superuser(get_service(ServiceType.SETTINGS_SERVICE), next(get_session()))
try:
get_db_service().migrate_flows_if_auto_login()
except Exception as exc:
logger.error(f"Error migrating flows: {exc}")
raise RuntimeError("Error migrating flows") from exc
# Initialize the SocketIO service
set_socketio_server(socketio_server)

View file

@ -3,13 +3,14 @@ from typing import TYPE_CHECKING, Optional, Union
from uuid import UUID
from fastapi import Depends
from langflow.services.auth import utils as auth_utils
from langflow.services.base import Service
from langflow.services.database.models.variable.model import Variable
from langflow.services.deps import get_session
from loguru import logger
from sqlmodel import Session, select
from langflow.services.auth import utils as auth_utils
from langflow.services.base import Service
from langflow.services.database.models.variable.model import Variable, VariableCreate
from langflow.services.deps import get_session
if TYPE_CHECKING:
from langflow.services.settings.service import SettingsService
@ -93,14 +94,13 @@ class VariableService(Service):
_type: str = "Generic",
session: Session = Depends(get_session),
):
variable = Variable(
user_id=user_id,
variable_base = VariableCreate(
name=name,
type=_type,
value=auth_utils.encrypt_api_key(value, settings_service=self.settings_service),
)
variable = Variable.model_validate(variable_base, from_attributes=True, update={"user_id": user_id})
session.add(variable)
session.commit()
session.refresh(variable)
return variable
return variable

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "langflow-base"
version = "0.0.27"
version = "0.0.30"
description = "A Python package with a built-in web application"
authors = ["Logspace <contact@logspace.ai>"]
maintainers = [
@ -62,30 +62,6 @@ cryptography = "^42.0.5"
asyncer = "^0.0.5"
[tool.poetry.group.dev.dependencies]
pytest-asyncio = "^0.23.1"
types-redis = "^4.6.0.5"
ipykernel = "^6.26.0"
mypy = "^1.7.1"
ruff = "^0.3.5"
httpx = "*"
pytest = "^8.1.1"
types-requests = "^2.31.0"
requests = "^2.31.0"
pytest-cov = "^5.0.0"
pandas-stubs = "^2.2.1.230412"
types-pillow = "^10.2.0.0"
types-pyyaml = "^6.0.12.8"
types-python-jose = "^3.3.4.8"
types-passlib = "^1.7.7.13"
locust = "^2.24.1"
pytest-mock = "^3.14.0"
pytest-xdist = "^3.5.0"
types-pywin32 = "^306.0.0.4"
types-google-cloud-ndb = "^2.3.0.0"
pytest-sugar = "^1.0.0"
[tool.poetry.extras]
deploy = ["celery", "redis", "flower"]
local = ["llama-cpp-python", "sentence-transformers", "ctransformers"]

View file

@ -1 +1 @@
from .version import __version__ # noqa: F401
from .version import __version__, is_pre_release # noqa: F401

View file

@ -2,6 +2,9 @@ from importlib import metadata
try:
__version__ = metadata.version("langflow")
# Check if the version is a pre-release version
is_pre_release = any(label in __version__ for label in ["a", "b", "rc", "dev", "post"])
except metadata.PackageNotFoundError:
__version__ = ""
is_pre_release = False
del metadata

View file

@ -1,4 +1,4 @@
import { useContext, useEffect } from "react";
import { useContext } from "react";
import { FaDiscord, FaGithub } from "react-icons/fa";
import { RiTwitterXFill } from "react-icons/ri";
import { Link, useLocation, useNavigate, useParams } from "react-router-dom";

View file

@ -164,7 +164,12 @@ export default function ChatMessage({
{chat.thought && chat.thought !== "" && !hidden && <br></br>}
<div className="flex w-full flex-col">
<div className="flex w-full flex-col dark:text-white">
<div data-testid={"chat-message-"+chat.sender_name+"-"+chatMessage} className="flex w-full flex-col">
<div
data-testid={
"chat-message-" + chat.sender_name + "-" + chatMessage
}
className="flex w-full flex-col"
>
{useMemo(
() =>
chatMessage === "" && lockChat ? (
@ -313,7 +318,13 @@ dark:prose-invert"
</span>
</>
) : (
<span data-testid={"chat-message-"+chat.sender_name+"-"+chatMessage}>{chatMessage}</span>
<span
data-testid={
"chat-message-" + chat.sender_name + "-" + chatMessage
}
>
{chatMessage}
</span>
)}
</div>
)}

File diff suppressed because one or more lines are too long

View file

@ -1,4 +1,4 @@
import { test,expect } from "@playwright/test";
import { expect, test } from "@playwright/test";
import { readFileSync } from "fs";
test("chat_io_teste", async ({ page }) => {