ref: Auto-fix some ruff ANN rules (#4210)

* Auto-fix some ruff ANN rules

* Fix mypy errors

* Changes following review

* Fix ServiceFactory
This commit is contained in:
Christophe Bornet 2024-10-22 13:55:48 +02:00 committed by GitHub
commit 507eda997a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
138 changed files with 588 additions and 580 deletions

View file

@ -45,7 +45,7 @@ def get_number_of_workers(workers=None):
return workers
def display_results(results):
def display_results(results) -> None:
"""Display the results of the migration."""
for table_results in results:
table = Table(title=f"Migration {table_results.table_name}")
@ -62,7 +62,7 @@ def display_results(results):
console.print() # Print a new line
def set_var_for_macos_issue():
def set_var_for_macos_issue() -> None:
# OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
# we need to set this var is we are running on MacOS
# otherwise we get an error when running gunicorn
@ -146,7 +146,7 @@ def run(
help="Defines the maximum file size for the upload in MB.",
show_default=False,
),
):
) -> None:
"""Run Langflow."""
configure(log_level=log_level, log_file=log_file)
set_var_for_macos_issue()
@ -202,7 +202,7 @@ def run(
# Run using uvicorn on MacOS and Windows
# Windows doesn't support gunicorn
# MacOS requires an env variable to be set to use gunicorn
process = run_on_windows(host, port, log_level, options, app)
run_on_windows(host, port, log_level, options, app)
else:
# Run using gunicorn on Linux
process = run_on_mac_or_linux(host, port, log_level, options, app)
@ -219,7 +219,7 @@ def run(
sys.exit(1)
def wait_for_server_ready(host, port):
def wait_for_server_ready(host, port) -> None:
"""Wait for the server to become ready by polling the health endpoint."""
status_code = 0
while status_code != httpx.codes.OK:
@ -241,7 +241,7 @@ def run_on_mac_or_linux(host, port, log_level, options, app):
return webapp_process
def run_on_windows(host, port, log_level, options, app):
def run_on_windows(host, port, log_level, options, app) -> None:
"""Run the Langflow server on Windows."""
print_banner(host, port)
run_langflow(host, port, log_level, options, app)
@ -275,7 +275,7 @@ def get_free_port(port):
return port
def get_letter_from_version(version: str):
def get_letter_from_version(version: str) -> str | None:
"""Get the letter from a pre-release version."""
if "a" in version:
return "a"
@ -294,7 +294,7 @@ def build_version_notice(current_version: str, package_name: str) -> str:
return ""
def generate_pip_command(package_names, is_pre_release):
def generate_pip_command(package_names, is_pre_release) -> str:
"""Generate the pip install command based on the packages and whether it's a pre-release."""
base_command = "pip install"
if is_pre_release:
@ -309,7 +309,7 @@ def stylize_text(text: str, to_style: str, *, is_prerelease: bool) -> str:
return text.replace(to_style, styled_text)
def print_banner(host: str, port: int):
def print_banner(host: str, port: int) -> None:
notices = []
package_names = [] # Track package names for pip install instructions
is_pre_release = False # Track if any package is a pre-release
@ -355,7 +355,7 @@ def print_banner(host: str, port: int):
rprint(panel)
def run_langflow(host, port, log_level, options, app):
def run_langflow(host, port, log_level, options, app) -> None:
"""Run Langflow server on localhost."""
if platform.system() == "Windows":
# Run using uvicorn on MacOS and Windows
@ -381,7 +381,7 @@ def superuser(
username: str = typer.Option(..., prompt=True, help="Username for the superuser."),
password: str = typer.Option(..., prompt=True, hide_input=True, help="Password for the superuser."),
log_level: str = typer.Option("error", help="Logging level.", envvar="LANGFLOW_LOG_LEVEL"),
):
) -> None:
"""Create a superuser."""
configure(log_level=log_level)
initialize_services()
@ -413,7 +413,7 @@ def superuser(
# command to copy the langflow database from the cache to the current directory
# because now the database is stored per installation
@app.command()
def copy_db():
def copy_db() -> None:
"""Copy the database files to the current directory.
This function copies the 'langflow.db' and 'langflow-pre.db' files from the cache directory to the current
@ -452,7 +452,7 @@ def migration(
default=False,
help="Fix migrations. This is a destructive operation, and should only be used if you know what you are doing.",
),
):
) -> None:
"""Run or test migrations."""
if fix and not typer.confirm(
"This will delete all data necessary to fix migrations. Are you sure you want to continue?"
@ -470,7 +470,7 @@ def migration(
@app.command()
def api_key(
log_level: str = typer.Option("error", help="Logging level."),
):
) -> None:
"""Creates an API key for the default superuser if AUTO_LOGIN is enabled.
Args:
@ -510,7 +510,7 @@ def api_key(
api_key_banner(unmasked_api_key)
def api_key_banner(unmasked_api_key):
def api_key_banner(unmasked_api_key) -> None:
is_mac = platform.system() == "Darwin"
import pyperclip
@ -529,7 +529,7 @@ def api_key_banner(unmasked_api_key):
console.print(panel)
def main():
def main() -> None:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
app()

View file

@ -93,7 +93,7 @@ def get_is_component_from_data(data: dict):
return data.get("is_component")
async def check_langflow_version(component: StoreComponentCreate):
async def check_langflow_version(component: StoreComponentCreate) -> None:
from langflow.utils.version import get_version_info
__version__ = get_version_info()["version"]
@ -259,7 +259,7 @@ def parse_value(value: Any, input_type: str) -> Any:
return value
async def cascade_delete_flow(session: Session, flow: Flow):
async def cascade_delete_flow(session: Session, flow: Flow) -> None:
try:
session.exec(delete(TransactionTable).where(TransactionTable.flow_id == flow.id))
session.exec(delete(VertexBuildTable).where(VertexBuildTable.flow_id == flow.id))

View file

@ -110,7 +110,7 @@ class AsyncStreamingLLMCallbackHandleSIO(AsyncCallbackHandler):
@override
async def on_agent_action( # type: ignore[misc]
self, action: AgentAction, **kwargs: Any
):
) -> None:
log = f"Thought: {action.log}"
# if there are line breaks, split them and send them
# as separate messages

View file

@ -419,7 +419,7 @@ async def build_flow(
event_manager = create_default_event_manager(queue=asyncio_queue)
main_task = asyncio.create_task(event_generator(event_manager, asyncio_queue_client_consumed))
def on_disconnect():
def on_disconnect() -> None:
logger.debug("Client disconnected, closing tasks")
main_task.cancel()

View file

@ -16,7 +16,6 @@ from langflow.api.v1.schemas import (
CustomComponentRequest,
CustomComponentResponse,
InputValueRequest,
ProcessResponse,
RunResponse,
SidebarCategoriesResponse,
SimplifiedAPIRequest,
@ -68,7 +67,7 @@ async def get_all():
raise HTTPException(status_code=500, detail=str(exc)) from exc
def validate_input_and_tweaks(input_request: SimplifiedAPIRequest):
def validate_input_and_tweaks(input_request: SimplifiedAPIRequest) -> None:
# If the input_value is not None and the input_type is "chat"
# then we need to check the tweaks if the ChatInput component is present
# and if its input_value is not None
@ -483,15 +482,13 @@ async def experimental_run_flow(
@router.post(
"/predict/{flow_id}",
response_model=ProcessResponse,
dependencies=[Depends(api_key_security)],
)
@router.post(
"/process/{flow_id}",
response_model=ProcessResponse,
dependencies=[Depends(api_key_security)],
)
async def process():
async def process() -> None:
"""Endpoint to process an input with a given flow_id."""
# Raise a depreciation warning
logger.warning(

View file

@ -36,7 +36,7 @@ async def get_vertex_builds(
async def delete_vertex_builds(
flow_id: Annotated[UUID, Query()],
session: Annotated[Session, Depends(get_session)],
):
) -> None:
try:
delete_vertex_builds_by_flow_id(session, flow_id)
except Exception as e:
@ -75,7 +75,7 @@ async def get_messages(
async def delete_messages(
message_ids: list[UUID],
session: Annotated[Session, Depends(get_session)],
):
) -> None:
try:
session.exec(delete(MessageTable).where(MessageTable.id.in_(message_ids))) # type: ignore[attr-defined]
session.commit()

View file

@ -95,7 +95,7 @@ def delete_variable(
variable_id: UUID,
current_user: User = Depends(get_current_active_user),
variable_service: VariableService = Depends(get_variable_service),
):
) -> None:
"""Delete a variable."""
try:
variable_service.delete_variable_by_id(user_id=current_user.id, variable_id=variable_id, session=session)

View file

@ -64,7 +64,7 @@ class LCAgentComponent(Component):
self.status = message
return message
def _validate_outputs(self):
def _validate_outputs(self) -> None:
required_output_methods = ["build_agent"]
output_names = [output.name for output in self.outputs]
for method_name in required_output_methods:

View file

@ -52,7 +52,7 @@ class BaseCrewComponent(Component):
def get_task_callback(
self,
) -> Callable:
def task_callback(task_output: TaskOutput):
def task_callback(task_output: TaskOutput) -> None:
vertex_id = self._vertex.id if self._vertex else self.display_name or self.__class__.__name__
self.log(task_output.model_dump(), name=f"Task (Agent: {task_output.agent}) - {vertex_id}")
@ -61,7 +61,7 @@ class BaseCrewComponent(Component):
def get_step_callback(
self,
) -> Callable:
def step_callback(agent_output: AgentFinish | list[tuple[AgentAction, str]]):
def step_callback(agent_output: AgentFinish | list[tuple[AgentAction, str]]) -> None:
_id = self._vertex.id if self._vertex else self.display_name
if isinstance(agent_output, AgentFinish):
messages = agent_output.messages

View file

@ -36,7 +36,7 @@ tool_names = []
tools_and_names = {}
def tools_from_package(your_package):
def tools_from_package(your_package) -> None:
# Iterate over all modules in the package
package_name = your_package.__name__
for module_info in pkgutil.iter_modules(your_package.__path__):

View file

@ -7,7 +7,7 @@ class LCChainComponent(Component):
outputs = [Output(display_name="Text", name="text", method="invoke_chain")]
def _validate_outputs(self):
def _validate_outputs(self) -> None:
required_output_methods = ["invoke_chain"]
output_names = [output.name for output in self.outputs]
for method_name in required_output_methods:

View file

@ -10,7 +10,7 @@ class LCEmbeddingsModel(Component):
Output(display_name="Embeddings", name="embeddings", method="build_embeddings"),
]
def _validate_outputs(self):
def _validate_outputs(self) -> None:
required_output_methods = ["build_embeddings"]
output_names = [output.name for output in self.outputs]
for method_name in required_output_methods:

View file

@ -29,7 +29,7 @@ class ChatComponent(Component):
self.status = stored_message
return stored_message
def _send_message_event(self, message: Message):
def _send_message_event(self, message: Message) -> None:
if hasattr(self, "_event_manager") and self._event_manager:
self._event_manager.on_message(data=message.data)
@ -107,7 +107,7 @@ class ChatComponent(Component):
return Message.from_data(input_value)
return Message(text=input_value, sender=sender, sender_name=sender_name, files=files, session_id=session_id)
def _send_messages_events(self, messages):
def _send_messages_events(self, messages) -> None:
if hasattr(self, "_event_manager") and self._event_manager:
for stored_message in messages:
self._event_manager.on_message(data=stored_message.data)

View file

@ -14,7 +14,7 @@ class LCToolComponent(Component):
Output(name="api_build_tool", display_name="Tool", method="build_tool"),
]
def _validate_outputs(self):
def _validate_outputs(self) -> None:
required_output_methods = ["run_model", "build_tool"]
output_names = [output.name for output in self.outputs]
for method_name in required_output_methods:

View file

@ -45,5 +45,5 @@ class BaseMemoryComponent(CustomComponent):
def add_message(
self, sender: str, sender_name: str, text: str, session_id: str, metadata: dict | None = None, **kwargs
):
) -> None:
raise NotImplementedError

View file

@ -17,7 +17,7 @@ class LCChatMemoryComponent(Component):
)
]
def _validate_outputs(self):
def _validate_outputs(self) -> None:
required_output_methods = ["build_message_history"]
output_names = [output.name for output in self.outputs]
for method_name in required_output_methods:

View file

@ -42,7 +42,7 @@ class LCModelComponent(Component):
def _get_exception_message(self, e: Exception):
return str(e)
def _validate_outputs(self):
def _validate_outputs(self) -> None:
# At least these two outputs must be defined
required_output_methods = ["text_response", "build_model"]
output_names = [output.name for output in self.outputs]

View file

@ -86,7 +86,7 @@ def _check_variable(var, invalid_chars, wrong_variables, empty_variables):
return wrong_variables, empty_variables
def _check_for_errors(input_variables, fixed_variables, wrong_variables, empty_variables):
def _check_for_errors(input_variables, fixed_variables, wrong_variables, empty_variables) -> None:
if any(var for var in input_variables if var not in fixed_variables):
error_message = (
f"Error: Input variables contain invalid characters or formats. \n"
@ -159,7 +159,7 @@ def get_old_custom_fields(custom_fields, name):
return old_custom_fields
def add_new_variables_to_template(input_variables, custom_fields, template, name):
def add_new_variables_to_template(input_variables, custom_fields, template, name) -> None:
for variable in input_variables:
try:
template_field = DefaultPromptField(name=variable, display_name=variable)
@ -177,7 +177,7 @@ def add_new_variables_to_template(input_variables, custom_fields, template, name
raise HTTPException(status_code=500, detail=str(exc)) from exc
def remove_old_variables_from_template(old_custom_fields, input_variables, custom_fields, template, name):
def remove_old_variables_from_template(old_custom_fields, input_variables, custom_fields, template, name) -> None:
for variable in old_custom_fields:
if variable not in input_variables:
try:
@ -192,7 +192,7 @@ def remove_old_variables_from_template(old_custom_fields, input_variables, custo
raise HTTPException(status_code=500, detail=str(exc)) from exc
def update_input_variables_field(input_variables, template):
def update_input_variables_field(input_variables, template) -> None:
if "input_variables" in template:
template["input_variables"]["value"] = input_variables

View file

@ -9,7 +9,7 @@ from langflow.base.document_transformers.model import LCDocumentTransformerCompo
class LCTextSplitterComponent(LCDocumentTransformerComponent):
trace_type = "text_splitter"
def _validate_outputs(self):
def _validate_outputs(self) -> None:
required_output_methods = ["text_splitter"]
output_names = [output.name for output in self.outputs]
for method_name in required_output_methods:

View file

@ -27,7 +27,7 @@ def _get_input_type(_input: InputTypes):
return _input.field_type
def build_description(component: Component, output: Output):
def build_description(component: Component, output: Output) -> str:
if not output.required_inputs:
logger.warning(f"Output {output.name} does not have required inputs defined")

View file

@ -71,7 +71,7 @@ class LCVectorStoreComponent(Component):
),
]
def _validate_outputs(self):
def _validate_outputs(self) -> None:
# At least these three outputs must be defined
required_output_methods = [
"build_base_retriever",

View file

@ -67,39 +67,39 @@ class AstraAssistantManager(ComponentWithCache):
Output(display_name="Assistant Id", name="output_assistant_id", method="get_assistant_id"),
]
def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self.lock = asyncio.Lock()
self.initialized = False
self.assistant_response = None
self.tool_output = None
self.thread_id = None
self.assistant_id = None
self.initialized: bool = False
self._assistant_response: Message = None # type: ignore[assignment]
self._tool_output: Message = None # type: ignore[assignment]
self._thread_id: Message = None # type: ignore[assignment]
self._assistant_id: Message = None # type: ignore[assignment]
self.client = get_patched_openai_client(self._shared_component_cache)
async def get_assistant_response(self) -> Message:
await self.initialize()
return self.assistant_response
return self._assistant_response
async def get_tool_output(self) -> Message:
await self.initialize()
return self.tool_output
return self._tool_output
async def get_thread_id(self) -> Message:
await self.initialize()
return self.thread_id
return self._thread_id
async def get_assistant_id(self) -> Message:
await self.initialize()
return self.assistant_id
return self._assistant_id
async def initialize(self):
async def initialize(self) -> None:
async with self.lock:
if not self.initialized:
await self.process_inputs()
self.initialized = True
async def process_inputs(self):
async def process_inputs(self) -> None:
logger.info(f"env_set is {self.env_set}")
logger.info(self.tool)
tools = []
@ -126,10 +126,10 @@ class AstraAssistantManager(ComponentWithCache):
content = self.user_message
result = await assistant_manager.run_thread(content=content, tool=tool_obj)
self.assistant_response = Message(text=result["text"])
self._assistant_response = Message(text=result["text"])
if "decision" in result:
self.tool_output = Message(text=str(result["decision"].is_complete))
self._tool_output = Message(text=str(result["decision"].is_complete))
else:
self.tool_output = Message(text=result["text"])
self.thread_id = Message(text=assistant_manager.thread.id)
self.assistant_id = Message(text=assistant_manager.assistant.id)
self._tool_output = Message(text=result["text"])
self._thread_id = Message(text=assistant_manager.thread.id)
self._assistant_id = Message(text=assistant_manager.assistant.id)

View file

@ -45,7 +45,7 @@ class AssistantsCreateAssistant(ComponentWithCache):
Output(display_name="Assistant ID", name="assistant_id", method="process_inputs"),
]
def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self.client = get_patched_openai_client(self._shared_component_cache)

View file

@ -21,7 +21,7 @@ class AssistantsCreateThread(ComponentWithCache):
Output(display_name="Thread ID", name="thread_id", method="process_inputs"),
]
def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self.client = get_patched_openai_client(self._shared_component_cache)

View file

@ -26,7 +26,7 @@ class AssistantsGetAssistantName(ComponentWithCache):
Output(display_name="Assistant Name", name="assistant_name", method="process_inputs"),
]
def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self.client = get_patched_openai_client(self._shared_component_cache)

View file

@ -12,7 +12,7 @@ class AssistantsListAssistants(ComponentWithCache):
Output(display_name="Assistants", name="assistants", method="process_inputs"),
]
def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self.client = get_patched_openai_client(self._shared_component_cache)

View file

@ -16,7 +16,7 @@ class AssistantsRun(ComponentWithCache):
display_name = "Run Assistant"
description = "Executes an Assistant Run against a thread"
def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self.client = get_patched_openai_client(self._shared_component_cache)
self.thread_id = None
@ -26,7 +26,7 @@ class AssistantsRun(ComponentWithCache):
build_config: dotdict,
field_value: Any,
field_name: str | None = None,
):
) -> None:
if field_name == "thread_id":
if field_value is None:
thread = self.client.beta.threads.create()
@ -75,7 +75,7 @@ class AssistantsRun(ComponentWithCache):
self.client.beta.threads.messages.create(thread_id=self.thread_id, role="user", content=self.user_message)
class EventHandler(AssistantEventHandler):
def __init__(self):
def __init__(self) -> None:
super().__init__()
def on_exception(self, exception: Exception) -> None:

View file

@ -88,7 +88,7 @@ class GoogleDriveSearchComponent(Component):
return query
def on_inputs_changed(self):
def on_inputs_changed(self) -> None:
# Automatically regenerate the query string when inputs change
self.generate_query_string()

View file

@ -36,7 +36,7 @@ class GoogleGenerativeAIEmbeddingsComponent(Component):
raise ValueError(msg)
class HotaGoogleGenerativeAIEmbeddings(GoogleGenerativeAIEmbeddings):
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super(GoogleGenerativeAIEmbeddings, self).__init__(*args, **kwargs)
def embed_documents(

View file

@ -99,7 +99,7 @@ class CreateDataComponent(Component):
data.update(_value_dict)
return data
def validate_text_key(self):
def validate_text_key(self) -> None:
"""This function validates that the Text Key is one of the keys in the Data."""
data_keys = self.get_data().keys()
if self.text_key not in data_keys and self.text_key != "":

View file

@ -106,7 +106,7 @@ class UpdateDataComponent(Component):
data.update(_value_dict)
return data
def validate_text_key(self, data: Data):
def validate_text_key(self, data: Data) -> None:
"""This function validates that the Text Key is one of the keys in the Data."""
data_keys = data.data.keys()
if self.text_key not in data_keys and self.text_key != "":

View file

@ -434,7 +434,7 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
return vector_store
def _add_documents_to_vector_store(self, vector_store):
def _add_documents_to_vector_store(self, vector_store) -> None:
documents = []
for _input in self.ingest_data or []:
if isinstance(_input, Data):
@ -453,7 +453,7 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
else:
logger.debug("No documents to add to the Vector Store.")
def _map_search_type(self):
def _map_search_type(self) -> str:
if self.search_type == "Similarity with score threshold":
return "similarity_score_threshold"
if self.search_type == "MMR (Max Marginal Relevance)":

View file

@ -206,7 +206,7 @@ class CassandraVectorStoreComponent(LCVectorStoreComponent):
)
return table
def _map_search_type(self):
def _map_search_type(self) -> str:
if self.search_type == "Similarity with score threshold":
return "similarity_score_threshold"
if self.search_type == "MMR (Max Marginal Relevance)":

View file

@ -181,7 +181,7 @@ class CassandraGraphVectorStoreComponent(LCVectorStoreComponent):
)
return store
def _map_search_type(self):
def _map_search_type(self) -> str:
if self.search_type == "Similarity":
return "similarity"
if self.search_type == "Similarity with score threshold":

View file

@ -253,7 +253,7 @@ class HCDVectorStoreComponent(LCVectorStoreComponent):
self._add_documents_to_vector_store(vector_store)
return vector_store
def _add_documents_to_vector_store(self, vector_store):
def _add_documents_to_vector_store(self, vector_store) -> None:
documents = []
for _input in self.ingest_data or []:
if isinstance(_input, Data):
@ -272,7 +272,7 @@ class HCDVectorStoreComponent(LCVectorStoreComponent):
else:
logger.debug("No documents to add to the Vector Store.")
def _map_search_type(self):
def _map_search_type(self) -> str:
if self.search_type == "Similarity with score threshold":
return "similarity_score_threshold"
if self.search_type == "MMR (Max Marginal Relevance)":

View file

@ -318,7 +318,7 @@ class CodeParser:
self.process_class_node(_node, class_details)
self.data["classes"].append(class_details.model_dump())
def process_class_node(self, node, class_details):
def process_class_node(self, node, class_details) -> None:
for stmt in node.body:
if isinstance(stmt, ast.Assign):
if attr := self.parse_assign(stmt):

View file

@ -31,15 +31,15 @@ class BaseComponent:
_user_id: str | UUID | None = None
_template_config: dict = {}
def __init__(self, **data):
self.cache = TTLCache(maxsize=1024, ttl=60)
def __init__(self, **data) -> None:
self.cache: TTLCache = TTLCache(maxsize=1024, ttl=60)
for key, value in data.items():
if key == "user_id":
self._user_id = value
else:
setattr(self, key, value)
def __setattr__(self, key, value):
def __setattr__(self, key, value) -> None:
if key == "_user_id" and self._user_id is not None:
logger.warning("user_id is immutable and cannot be changed.")
super().__setattr__(key, value)

View file

@ -61,7 +61,7 @@ class Component(CustomComponent):
_current_output: str = ""
_metadata: dict = {}
def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
# if key starts with _ it is a config
# else it is an input
inputs = {}
@ -107,10 +107,10 @@ class Component(CustomComponent):
self.set_class_code()
self._set_output_required_inputs()
def set_event_manager(self, event_manager: EventManager | None = None):
def set_event_manager(self, event_manager: EventManager | None = None) -> None:
self._event_manager = event_manager
def _reset_all_output_values(self):
def _reset_all_output_values(self) -> None:
if isinstance(self._outputs_map, dict):
for output in self._outputs_map.values():
output.value = UNDEFINED
@ -153,7 +153,7 @@ class Component(CustomComponent):
memo[id(self)] = new_component
return new_component
def set_class_code(self):
def set_class_code(self) -> None:
# Get the source code of the calling class
if self._code:
return
@ -200,7 +200,7 @@ class Component(CustomComponent):
"""
return await self._run()
def set_vertex(self, vertex: Vertex):
def set_vertex(self, vertex: Vertex) -> None:
"""Sets the vertex for the component.
Args:
@ -245,7 +245,7 @@ class Component(CustomComponent):
msg = f"Output {name} not found in {self.__class__.__name__}"
raise ValueError(msg)
def set_on_output(self, name: str, **kwargs):
def set_on_output(self, name: str, **kwargs) -> None:
output = self.get_output(name)
for key, value in kwargs.items():
if not hasattr(output, key):
@ -253,14 +253,14 @@ class Component(CustomComponent):
raise ValueError(msg)
setattr(output, key, value)
def set_output_value(self, name: str, value: Any):
def set_output_value(self, name: str, value: Any) -> None:
if name in self._outputs_map:
self._outputs_map[name].value = value
else:
msg = f"Output {name} not found in {self.__class__.__name__}"
raise ValueError(msg)
def map_outputs(self, outputs: list[Output]):
def map_outputs(self, outputs: list[Output]) -> None:
"""Maps the given list of outputs to the component.
Args:
@ -280,7 +280,7 @@ class Component(CustomComponent):
# allows each instance of each component to modify its own output
self._outputs_map[output.name] = deepcopy(output)
def map_inputs(self, inputs: list[InputTypes]):
def map_inputs(self, inputs: list[InputTypes]) -> None:
"""Maps the given inputs to the component.
Args:
@ -296,7 +296,7 @@ class Component(CustomComponent):
raise ValueError(msg)
self._inputs[input_.name] = deepcopy(input_)
def validate(self, params: dict):
def validate(self, params: dict) -> None:
"""Validates the component parameters.
Args:
@ -309,13 +309,16 @@ class Component(CustomComponent):
self._validate_inputs(params)
self._validate_outputs()
def _set_output_types(self):
def _set_output_types(self) -> None:
for output in self._outputs_map.values():
if output.method is None:
msg = f"Output {output.name} does not have a method"
raise ValueError(msg)
return_types = self._get_method_return_type(output.method)
output.add_types(return_types)
output.set_selected()
def _set_output_required_inputs(self):
def _set_output_required_inputs(self) -> None:
for output in self.outputs:
if not output.method:
continue
@ -326,8 +329,7 @@ class Component(CustomComponent):
source_code = inspect.getsource(method)
ast_tree = ast.parse(dedent(source_code))
except Exception: # noqa: BLE001
source_code = self._code
ast_tree = ast.parse(dedent(source_code))
ast_tree = ast.parse(dedent(self._code or ""))
visitor = RequiredInputsVisitor(self._inputs)
visitor.visit(ast_tree)
@ -420,7 +422,7 @@ class Component(CustomComponent):
raise TypeError(msg)
return getattr(value, output.method)
def _process_connection_or_parameter(self, key, value):
def _process_connection_or_parameter(self, key, value) -> None:
_input = self._get_or_create_input(key)
# We need to check if callable AND if it is a method from a class that inherits from Component
if isinstance(value, Component):
@ -438,7 +440,7 @@ class Component(CustomComponent):
else:
self._set_parameter_or_attribute(key, value)
def _process_connection_or_parameters(self, key, value):
def _process_connection_or_parameters(self, key, value) -> None:
# if value is a list of components, we need to process each component
if isinstance(value, list):
for val in value:
@ -455,13 +457,13 @@ class Component(CustomComponent):
self.inputs.append(_input)
return _input
def _connect_to_component(self, key, value, _input):
def _connect_to_component(self, key, value, _input) -> None:
component = value.__self__
self._components.append(component)
output = component.get_output_by_method(value)
self._add_edge(component, key, output, _input)
def _add_edge(self, component, key, output, _input):
def _add_edge(self, component, key, output, _input) -> None:
self._edges.append(
{
"source": component._id,
@ -483,7 +485,7 @@ class Component(CustomComponent):
}
)
def _set_parameter_or_attribute(self, key, value):
def _set_parameter_or_attribute(self, key, value) -> None:
if isinstance(value, Component):
methods = ", ".join([f"'{output.method}'" for output in value.outputs])
msg = (
@ -527,7 +529,7 @@ class Component(CustomComponent):
msg = f"{name} not found in {self.__class__.__name__}"
raise AttributeError(msg)
def _set_input_value(self, name: str, value: Any):
def _set_input_value(self, name: str, value: Any) -> None:
if name in self._inputs:
input_value = self._inputs[name].value
if isinstance(input_value, Component):
@ -547,15 +549,15 @@ class Component(CustomComponent):
msg = f"Input {name} not found in {self.__class__.__name__}"
raise ValueError(msg)
def _validate_outputs(self):
def _validate_outputs(self) -> None:
# Raise Error if some rule isn't met
pass
def _map_parameters_on_frontend_node(self, frontend_node: ComponentFrontendNode):
def _map_parameters_on_frontend_node(self, frontend_node: ComponentFrontendNode) -> None:
for name, value in self._parameters.items():
frontend_node.set_field_value_in_template(name, value)
def _map_parameters_on_template(self, template: dict):
def _map_parameters_on_template(self, template: dict) -> None:
for name, value in self._parameters.items():
try:
template[name]["value"] = value
@ -625,7 +627,7 @@ class Component(CustomComponent):
"id": self._id,
}
def _validate_inputs(self, params: dict):
def _validate_inputs(self, params: dict) -> None:
# Params keys are the `name` attribute of the Input objects
for key, value in params.copy().items():
if key not in self._inputs:
@ -636,7 +638,7 @@ class Component(CustomComponent):
input_.value = value
params[input_.name] = input_.value
def set_attributes(self, params: dict):
def set_attributes(self, params: dict) -> None:
self._validate_inputs(params)
_attributes = {}
for key, value in params.items():
@ -652,7 +654,7 @@ class Component(CustomComponent):
_attributes[key] = input_obj.value or None
self._attributes = _attributes
def _set_outputs(self, outputs: list[dict]):
def _set_outputs(self, outputs: list[dict]) -> None:
self.outputs = [Output(**output) for output in outputs]
for output in self.outputs:
setattr(self, output.name, output)
@ -794,7 +796,7 @@ class Component(CustomComponent):
except KeyError:
return []
def build(self, **kwargs):
def build(self, **kwargs) -> None:
self.set_attributes(kwargs)
def _get_fallback_input(self, **kwargs):
@ -809,7 +811,7 @@ class Component(CustomComponent):
return self._tracing_service.project_name
return "Langflow"
def log(self, message: LoggableType | list[LoggableType], name: str | None = None):
def log(self, message: LoggableType | list[LoggableType], name: str | None = None) -> None:
"""Logs a message.
Args:
@ -828,6 +830,6 @@ class Component(CustomComponent):
data["component_id"] = self._id
self._event_manager.on_log(data=data)
def _append_tool_output(self):
def _append_tool_output(self) -> None:
if next((output for output in self.outputs if output.name == TOOL_OUTPUT_NAME), None) is None:
self.outputs.append(Output(name=TOOL_OUTPUT_NAME, display_name="Tool", method="to_toolkit", types=["Tool"]))

View file

@ -3,6 +3,6 @@ from langflow.services.deps import get_shared_component_cache_service
class ComponentWithCache(Component):
def __init__(self, **data):
def __init__(self, **data) -> None:
super().__init__(**data)
self._shared_component_cache = get_shared_component_cache_service()

View file

@ -85,7 +85,7 @@ class CustomComponent(BaseComponent):
_tracing_service: TracingService | None = None
_tree: dict | None = None
def __init__(self, **data):
def __init__(self, **data) -> None:
"""Initializes a new instance of the CustomComponent class.
Args:
@ -93,22 +93,25 @@ class CustomComponent(BaseComponent):
"""
self.cache = TTLCache(maxsize=1024, ttl=60)
self._logs = []
self._results = {}
self._artifacts = {}
self._results: dict = {}
self._artifacts: dict = {}
super().__init__(**data)
def set_attributes(self, parameters: dict):
def set_attributes(self, parameters: dict) -> None:
    """Hook for subclasses to map parameters onto attributes; intentionally a no-op here."""
def set_parameters(self, parameters: dict):
def set_parameters(self, parameters: dict) -> None:
    """Remember the raw parameters, then apply them through the set_attributes hook."""
    self._parameters = parameters
    self.set_attributes(parameters)
@property
def trace_name(self) -> str:
    """Identifier used for tracing: '<display name> (<vertex id>)'.

    Raises:
        ValueError: If no vertex has been attached to this component yet.
    """
    if self._vertex is None:
        msg = "Vertex is not set"
        raise ValueError(msg)
    return f"{self.display_name} ({self._vertex.id})"
def update_state(self, name: str, value: Any):
def update_state(self, name: str, value: Any) -> None:
if not self._vertex:
msg = "Vertex is not set"
raise ValueError(msg)
@ -118,7 +121,7 @@ class CustomComponent(BaseComponent):
msg = f"Error updating state: {e}"
raise ValueError(msg) from e
def stop(self, output_name: str | None = None):
def stop(self, output_name: str | None = None) -> None:
if not output_name and self._vertex and len(self._vertex.outputs) == 1:
output_name = self._vertex.outputs[0]["name"]
elif not output_name:
@ -133,7 +136,7 @@ class CustomComponent(BaseComponent):
msg = f"Error stopping {self.display_name}: {e}"
raise ValueError(msg) from e
def append_state(self, name: str, value: Any):
def append_state(self, name: str, value: Any) -> None:
if not self._vertex:
msg = "Vertex is not set"
raise ValueError(msg)

View file

@ -13,7 +13,7 @@ class CustomComponentPathValueError(ValueError):
class StringCompressor:
def __init__(self, input_string):
def __init__(self, input_string) -> None:
    """Initialize StringCompressor with a string to compress."""
    # NOTE(review): only stores the raw string here; compression presumably
    # happens in a later method not visible in this hunk — confirm.
    self.input_string = input_string
@ -39,7 +39,7 @@ class DirectoryReader:
# the custom components from this directory.
base_path = ""
def __init__(self, directory_path, *, compress_code_field=False):
def __init__(self, directory_path, *, compress_code_field=False) -> None:
"""Initialize DirectoryReader with a directory path and a flag indicating whether to compress the code."""
self.directory_path = directory_path
self.compress_code_field = compress_code_field
@ -76,7 +76,7 @@ class DirectoryReader:
logger.debug(f'Filtered components {"with errors" if with_errors else ""}: {len(filtered)}')
return {"menu": filtered}
def validate_code(self, file_content):
def validate_code(self, file_content) -> bool:
"""Validate the Python code by trying to parse it with ast.parse."""
try:
ast.parse(file_content)

View file

@ -109,7 +109,7 @@ def create_invalid_component_template(component, component_name):
return component_frontend_node.model_dump(by_alias=True, exclude_none=True)
def log_invalid_component_details(component):
def log_invalid_component_details(component) -> None:
"""Log details of an invalid component."""
logger.debug(component)
logger.debug(f"Component Path: {component.get('path', None)}")

View file

@ -28,5 +28,5 @@ class CallableCodeDetails(BaseModel):
class MissingDefault:
"""A class to represent a missing default value."""
def __repr__(self):
def __repr__(self) -> str:
return "MISSING"

View file

@ -1,15 +1,16 @@
import ast
from typing import Any
from typing_extensions import override
class RequiredInputsVisitor(ast.NodeVisitor):
    """AST visitor that collects which declared inputs a body reads via ``self.<name>``."""

    def __init__(self, inputs: dict[str, Any]) -> None:
        # `-> None` added for consistency with the file-wide ANN annotation pass.
        # Inputs declared on the component, keyed by input name.
        self.inputs: dict[str, Any] = inputs
        # Names of inputs actually referenced as `self.<name>` in the visited AST.
        self.required_inputs: set[str] = set()

    @override
    def visit_Attribute(self, node) -> None:
        # Only `self.<attr>` accesses whose attribute name matches a declared
        # input count as required; everything else is just traversed.
        if isinstance(node.value, ast.Name) and node.value.id == "self" and node.attr in self.inputs:
            self.required_inputs.add(node.attr)
        self.generic_visit(node)

View file

@ -32,7 +32,7 @@ class UpdateBuildConfigError(Exception):
pass
def add_output_types(frontend_node: CustomComponentFrontendNode, return_types: list[str]):
def add_output_types(frontend_node: CustomComponentFrontendNode, return_types: list[str]) -> None:
"""Add output types to the frontend node."""
for return_type in return_types:
if return_type is None:
@ -55,7 +55,7 @@ def add_output_types(frontend_node: CustomComponentFrontendNode, return_types: l
frontend_node.add_output_type(_return_type)
def reorder_fields(frontend_node: CustomComponentFrontendNode, field_order: list[str]):
def reorder_fields(frontend_node: CustomComponentFrontendNode, field_order: list[str]) -> None:
"""Reorder fields in the frontend node based on the specified field_order."""
if not field_order:
return
@ -69,7 +69,7 @@ def reorder_fields(frontend_node: CustomComponentFrontendNode, field_order: list
frontend_node.field_order = field_order
def add_base_classes(frontend_node: CustomComponentFrontendNode, return_types: list[str]):
def add_base_classes(frontend_node: CustomComponentFrontendNode, return_types: list[str]) -> None:
"""Add base classes to the frontend node."""
for return_type_instance in return_types:
if return_type_instance is None:
@ -196,7 +196,7 @@ def add_new_custom_field(
return frontend_node
def add_extra_fields(frontend_node, field_config, function_args):
def add_extra_fields(frontend_node, field_config, function_args) -> None:
"""Add extra fields to the frontend node."""
if not function_args:
return

View file

@ -25,7 +25,7 @@ class EventManager:
self.events: dict[str, PartialEventCallback] = {}
@staticmethod
def _validate_callback(callback: EventCallback):
def _validate_callback(callback: EventCallback) -> None:
if not callable(callback):
msg = "Callback must be callable"
raise TypeError(msg)
@ -39,7 +39,7 @@ class EventManager:
msg = "Callback must have exactly 3 parameters: manager, event_type, and data"
raise ValueError(msg)
def register_event(self, name: str, event_type: str, callback: EventCallback | None = None):
def register_event(self, name: str, event_type: str, callback: EventCallback | None = None) -> None:
if not name:
msg = "Event name cannot be empty"
raise ValueError(msg)
@ -52,14 +52,14 @@ class EventManager:
_callback = partial(callback, manager=self, event_type=event_type)
self.events[name] = _callback
def send_event(self, *, event_type: str, data: LoggableType):
def send_event(self, *, event_type: str, data: LoggableType) -> None:
jsonable_data = jsonable_encoder(data)
json_data = {"event": event_type, "data": jsonable_data}
event_id = uuid.uuid4()
str_data = json.dumps(json_data) + "\n\n"
self.queue.put_nowait((event_id, str_data.encode("utf-8"), time.time()))
def noop(self, *, data: LoggableType):
def noop(self, *, data: LoggableType) -> None:
    """Event callback that deliberately ignores its payload."""
def __getattr__(self, name: str) -> PartialEventCallback:

View file

@ -24,7 +24,7 @@ class VertexViewer:
HEIGHT = 3 # top and bottom box edges + text
def __init__(self, name):
def __init__(self, name) -> None:
self._h = self.HEIGHT # top and bottom box edges + text
self._w = len(name) + 2 # right and left bottom edges + text
@ -40,7 +40,7 @@ class VertexViewer:
class AsciiCanvas:
"""Class for drawing in ASCII."""
def __init__(self, cols, lines):
def __init__(self, cols, lines) -> None:
assert cols > 1
assert lines > 1
self.cols = cols
@ -53,19 +53,19 @@ class AsciiCanvas:
def draws(self):
return "\n".join(self.get_lines())
def draw(self):
def draw(self) -> None:
"""Draws ASCII canvas on the screen."""
lines = self.get_lines()
print("\n".join(lines)) # noqa: T201
def point(self, x, y, char):
def point(self, x, y, char) -> None:
    """Place a single character at (x, y) on the ASCII canvas."""
    # A point is exactly one character and must land inside the canvas bounds.
    assert len(char) == 1
    assert 0 <= y < self.lines
    assert 0 <= x < self.cols
    self.canvas[y][x] = char
def line(self, x0, y0, x1, y1, char):
def line(self, x0, y0, x1, y1, char) -> None:
"""Create a line on ASCII canvas."""
if x0 > x1:
x1, x0 = x0, x1
@ -85,12 +85,12 @@ class AsciiCanvas:
x = x0 + int(round((y - y0) * dx / float(dy))) if dy else x0
self.point(x, y, char)
def text(self, x, y, text):
def text(self, x, y, text) -> None:
"""Print a text on ASCII canvas."""
for i, char in enumerate(text):
self.point(x + i, y, char)
def box(self, x0, y0, width, height):
def box(self, x0, y0, width, height) -> None:
"""Create a box on ASCII canvas."""
assert width > 1
assert height > 1

View file

@ -201,7 +201,7 @@ class Graph:
graph_dict["endpoint_name"] = str(endpoint_name)
return graph_dict
def add_nodes_and_edges(self, nodes: list[NodeData], edges: list[EdgeData]):
def add_nodes_and_edges(self, nodes: list[NodeData], edges: list[EdgeData]) -> None:
self._vertices = nodes
self._edges = edges
self.raw_graph_data = {"nodes": nodes, "edges": edges}
@ -238,7 +238,7 @@ class Graph:
return component_id
def _set_start_and_end(self, start: Component, end: Component):
def _set_start_and_end(self, start: Component, end: Component) -> None:
if not hasattr(start, "to_frontend_node"):
msg = f"start must be a Component. Got {type(start)}"
raise TypeError(msg)
@ -248,7 +248,7 @@ class Graph:
self.add_component(start, start._id)
self.add_component(end, end._id)
def add_component_edge(self, source_id: str, output_input_tuple: tuple[str, str], target_id: str):
def add_component_edge(self, source_id: str, output_input_tuple: tuple[str, str], target_id: str) -> None:
source_vertex = self.get_vertex(source_id)
if not isinstance(source_vertex, ComponentVertex):
msg = f"Source vertex {source_id} is not a component vertex."
@ -337,7 +337,7 @@ class Graph:
"run_manager": copy.deepcopy(self.run_manager.to_dict()),
}
def __apply_config(self, config: StartConfigDict):
def __apply_config(self, config: StartConfigDict) -> None:
for vertex in self.vertices:
if vertex._custom_component is None:
continue
@ -373,7 +373,7 @@ class Graph:
except StopAsyncIteration:
break
def _add_edge(self, edge: EdgeData):
def _add_edge(self, edge: EdgeData) -> None:
self.add_edge(edge)
source_id = edge["data"]["sourceHandle"]["id"]
target_id = edge["data"]["targetHandle"]["id"]
@ -382,16 +382,16 @@ class Graph:
self.in_degree_map[target_id] += 1
self.parent_child_map[source_id].append(target_id)
def add_node(self, node: NodeData):
def add_node(self, node: NodeData) -> None:
    """Append a raw node payload to the graph's vertex list."""
    self._vertices.append(node)
def add_edge(self, edge: EdgeData):
def add_edge(self, edge: EdgeData) -> None:
    """Append an edge payload, skipping exact duplicates already present."""
    if edge not in self._edges:
        self._edges.append(edge)
def initialize(self):
def initialize(self) -> None:
    # Build vertex/edge objects first, then the adjacency maps, then the
    # input/output/session role lists that depend on both.
    self._build_graph()
    self.build_graph_maps(self.edges)
    self.define_vertices_lists()
@ -424,7 +424,7 @@ class Graph:
self.state_manager.update_state(name, record, run_id=self._run_id)
def activate_state_vertices(self, name: str, caller: str):
def activate_state_vertices(self, name: str, caller: str) -> None:
"""Activates the state vertices in the graph with the given name and caller.
Args:
@ -473,7 +473,7 @@ class Graph:
vertices_to_run=self.vertices_to_run,
)
def reset_activated_vertices(self):
def reset_activated_vertices(self) -> None:
"""Resets the activated vertices in the graph."""
self.activated_vertices = []
@ -490,7 +490,7 @@ class Graph:
self.state_manager.append_state(name, record, run_id=self._run_id)
def validate_stream(self):
def validate_stream(self) -> None:
"""Validates the stream configuration of the graph.
If there are two vertices in the same graph (connected by edges)
@ -548,7 +548,7 @@ class Graph:
raise ValueError(msg)
return self._run_id
def set_run_id(self, run_id: uuid.UUID | None = None):
def set_run_id(self, run_id: uuid.UUID | None = None) -> None:
"""Sets the ID of the current run.
Args:
@ -564,7 +564,7 @@ class Graph:
if self.tracing_service:
self.tracing_service.set_run_id(run_id)
def set_run_name(self):
def set_run_name(self) -> None:
# Given a flow name, flow_id
if not self.tracing_service:
return
@ -573,16 +573,16 @@ class Graph:
self.set_run_id()
self.tracing_service.set_run_name(name)
async def initialize_run(self):
async def initialize_run(self) -> None:
if self.tracing_service:
await self.tracing_service.initialize_tracers()
def _end_all_traces_async(self, outputs: dict[str, Any] | None = None, error: Exception | None = None):
def _end_all_traces_async(self, outputs: dict[str, Any] | None = None, error: Exception | None = None) -> None:
task = asyncio.create_task(self.end_all_traces(outputs, error))
self._end_trace_tasks.add(task)
task.add_done_callback(self._end_trace_tasks.discard)
async def end_all_traces(self, outputs: dict[str, Any] | None = None, error: Exception | None = None):
async def end_all_traces(self, outputs: dict[str, Any] | None = None, error: Exception | None = None) -> None:
if not self.tracing_service:
return
self._end_time = datetime.now(timezone.utc)
@ -602,7 +602,7 @@ class Graph:
self.sort_vertices()
return self._sorted_vertices_layers
def define_vertices_lists(self):
def define_vertices_lists(self) -> None:
"""Defines the lists of vertices that are inputs, outputs, and have session_id."""
attributes = ["is_input", "is_output", "has_session_id", "is_state"]
for vertex in self.vertices:
@ -610,7 +610,7 @@ class Graph:
if getattr(vertex, attribute):
getattr(self, f"_{attribute}_vertices").append(vertex.id)
def _set_inputs(self, input_components: list[str], inputs: dict[str, str], input_type: InputType | None):
def _set_inputs(self, input_components: list[str], inputs: dict[str, str], input_type: InputType | None) -> None:
for vertex_id in self._is_input_vertices:
vertex = self.get_vertex(vertex_id)
# If the vertex is not in the input_components list
@ -838,7 +838,7 @@ class Graph:
"flow_name": self.flow_name,
}
def build_graph_maps(self, edges: list[CycleEdge] | None = None, vertices: list[Vertex] | None = None):
def build_graph_maps(self, edges: list[CycleEdge] | None = None, vertices: list[Vertex] | None = None) -> None:
"""Builds the adjacency maps for the graph."""
if edges is None:
edges = self.edges
@ -851,26 +851,28 @@ class Graph:
self.in_degree_map = self.build_in_degree(edges)
self.parent_child_map = self.build_parent_child_map(vertices)
def reset_inactivated_vertices(self):
def reset_inactivated_vertices(self) -> None:
    """Return every inactivated vertex to the ACTIVE state and clear the set."""
    inactive_ids = self.inactivated_vertices.copy()
    for inactive_id in inactive_ids:
        self.mark_vertex(inactive_id, "ACTIVE")
    self.inactivated_vertices = set()
def mark_all_vertices(self, state: str):
def mark_all_vertices(self, state: str) -> None:
    """Set every vertex in the graph to the given state."""
    for graph_vertex in self.vertices:
        graph_vertex.set_state(state)
def mark_vertex(self, vertex_id: str, state: str):
def mark_vertex(self, vertex_id: str, state: str) -> None:
"""Marks a vertex in the graph."""
vertex = self.get_vertex(vertex_id)
vertex.set_state(state)
if state == VertexStates.INACTIVE:
self.run_manager.remove_from_predecessors(vertex_id)
def _mark_branch(self, vertex_id: str, state: str, visited: set | None = None, output_name: str | None = None):
def _mark_branch(
self, vertex_id: str, state: str, visited: set | None = None, output_name: str | None = None
) -> None:
"""Marks a branch of the graph."""
if visited is None:
visited = set()
@ -889,7 +891,7 @@ class Graph:
continue
self._mark_branch(child_id, state, visited)
def mark_branch(self, vertex_id: str, state: str, output_name: str | None = None):
def mark_branch(self, vertex_id: str, state: str, output_name: str | None = None) -> None:
self._mark_branch(vertex_id=vertex_id, state=state, output_name=output_name)
new_predecessor_map, _ = self.build_adjacency_maps(self.edges)
self.run_manager.update_run_state(
@ -910,10 +912,10 @@ class Graph:
parent_child_map[vertex.id] = [child.id for child in self.get_successors(vertex)]
return parent_child_map
def increment_run_count(self):
def increment_run_count(self) -> None:
    """Bump the number of completed runs by one."""
    self._runs = self._runs + 1
def increment_update_count(self):
def increment_update_count(self) -> None:
    """Bump the number of graph updates by one."""
    self._updates = self._updates + 1
def __getstate__(self):
@ -1239,7 +1241,7 @@ class Graph:
return None
return self._run_queue.popleft()
def extend_run_queue(self, vertices: list[str]):
def extend_run_queue(self, vertices: list[str]) -> None:
    """Append the given vertex ids to the end of the run queue."""
    self._run_queue.extend(vertices)
async def astep(
@ -1292,7 +1294,7 @@ class Graph:
}
)
def _record_snapshot(self, vertex_id: str | None = None):
def _record_snapshot(self, vertex_id: str | None = None) -> None:
self._snapshots.append(self.get_snapshot())
if vertex_id:
self._call_order.append(vertex_id)
@ -1557,7 +1559,7 @@ class Graph:
state = dict.fromkeys(self.vertices, 0)
sorted_vertices = []
def dfs(vertex):
def dfs(vertex) -> None:
if state[vertex] == 1:
# We have a cycle
msg = "Graph contains a cycle, cannot perform topological sort"
@ -1770,7 +1772,7 @@ class Graph:
children.append(vertex)
return children
def __repr__(self):
def __repr__(self) -> str:
vertex_ids = [vertex.id for vertex in self.vertices]
edges_repr = "\n".join([f" {edge.source_id} --> {edge.target_id}" for edge in self.edges])
@ -2009,7 +2011,7 @@ class Graph:
is_active = self.get_vertex(vertex_id).is_active()
return self.run_manager.is_vertex_runnable(vertex_id, is_active=is_active)
def build_run_map(self):
def build_run_map(self) -> None:
"""Builds the run map for the graph.
This method is responsible for building the run map for the graph,
@ -2036,7 +2038,7 @@ class Graph:
runnable_vertices = []
visited = set()
def find_runnable_predecessors(predecessor: Vertex):
def find_runnable_predecessors(predecessor: Vertex) -> None:
predecessor_id = predecessor.id
if predecessor_id in visited:
return
@ -2052,10 +2054,10 @@ class Graph:
find_runnable_predecessors(self.get_vertex(predecessor_id))
return runnable_vertices
def remove_from_predecessors(self, vertex_id: str):
def remove_from_predecessors(self, vertex_id: str) -> None:
    """Delegate predecessor cleanup for the vertex to the run manager."""
    self.run_manager.remove_from_predecessors(vertex_id)
def remove_vertex_from_runnables(self, vertex_id: str):
def remove_vertex_from_runnables(self, vertex_id: str) -> None:
self.run_manager.remove_vertex_from_runnables(vertex_id)
def get_top_level_vertices(self, vertices_ids):

View file

@ -3,7 +3,7 @@ from langflow.utils.lazy_load import LazyLoadDictBase
class Finish:
def __bool__(self):
def __bool__(self) -> bool:
return True
def __eq__(self, other):
@ -17,7 +17,7 @@ def _import_vertex_types():
class VertexTypesDict(LazyLoadDictBase):
def __init__(self):
def __init__(self) -> None:
self._all_types_dict = None
self._types = _import_vertex_types

View file

@ -2,11 +2,11 @@ from collections import defaultdict
class RunnableVerticesManager:
def __init__(self):
self.run_map = defaultdict(list) # Tracks successors of each vertex
self.run_predecessors = defaultdict(set) # Tracks predecessors for each vertex
self.vertices_to_run = set() # Set of vertices that are ready to run
self.vertices_being_run = set() # Set of vertices that are currently running
def __init__(self) -> None:
self.run_map: dict[str, list[str]] = defaultdict(list) # Tracks successors of each vertex
self.run_predecessors: dict[str, set[str]] = defaultdict(set) # Tracks predecessors for each vertex
self.vertices_to_run: set[str] = set() # Set of vertices that are ready to run
self.vertices_being_run: set[str] = set() # Set of vertices that are currently running
def to_dict(self) -> dict:
return {
@ -42,7 +42,7 @@ class RunnableVerticesManager:
def all_predecessors_are_fulfilled(self) -> bool:
return all(not value for value in self.run_predecessors.values())
def update_run_state(self, run_predecessors: dict, vertices_to_run: set):
def update_run_state(self, run_predecessors: dict, vertices_to_run: set) -> None:
self.run_predecessors.update(run_predecessors)
self.vertices_to_run.update(vertices_to_run)
self.build_run_map(self.run_predecessors, self.vertices_to_run)
@ -60,14 +60,14 @@ class RunnableVerticesManager:
def are_all_predecessors_fulfilled(self, vertex_id: str) -> bool:
return not any(self.run_predecessors.get(vertex_id, []))
def remove_from_predecessors(self, vertex_id: str):
def remove_from_predecessors(self, vertex_id: str) -> None:
    """Erase `vertex_id` from the predecessor set of each of its successors."""
    for successor in self.run_map.get(vertex_id, []):
        # discard == remove-if-present, mirroring the original membership check
        self.run_predecessors[successor].discard(vertex_id)
def build_run_map(self, predecessor_map, vertices_to_run):
def build_run_map(self, predecessor_map, vertices_to_run) -> None:
"""Builds a map of vertices and their runnable successors."""
self.run_map = defaultdict(list)
for vertex_id, predecessors in predecessor_map.items():
@ -76,16 +76,16 @@ class RunnableVerticesManager:
self.run_predecessors = predecessor_map.copy()
self.vertices_to_run = vertices_to_run
def update_vertex_run_state(self, vertex_id: str, *, is_runnable: bool):
def update_vertex_run_state(self, vertex_id: str, *, is_runnable: bool) -> None:
    """Mark a vertex as ready to run, or drop it from the currently-running set."""
    if not is_runnable:
        # No longer runnable: ensure it is not tracked as currently running.
        self.vertices_being_run.discard(vertex_id)
    else:
        self.vertices_to_run.add(vertex_id)
def remove_vertex_from_runnables(self, v_id):
def remove_vertex_from_runnables(self, v_id) -> None:
    """Drop a vertex from the runnable pool and from its successors' predecessor sets."""
    self.update_vertex_run_state(v_id, is_runnable=False)
    self.remove_from_predecessors(v_id)
def add_to_vertices_being_run(self, v_id):
def add_to_vertices_being_run(self, v_id) -> None:
    """Record that a vertex has started running."""
    self.vertices_being_run.add(v_id)

View file

@ -13,7 +13,7 @@ if TYPE_CHECKING:
class GraphStateManager:
def __init__(self):
def __init__(self) -> None:
try:
self.state_service: StateService = get_state_service()
except Exception: # noqa: BLE001
@ -22,26 +22,14 @@ class GraphStateManager:
self.state_service = InMemoryStateService(get_settings_service())
def append_state(self, key, new_state, run_id: str):
def append_state(self, key, new_state, run_id: str) -> None:
self.state_service.append_state(key, new_state, run_id)
def update_state(self, key, new_state, run_id: str):
def update_state(self, key, new_state, run_id: str) -> None:
self.state_service.update_state(key, new_state, run_id)
def get_state(self, key, run_id: str):
return self.state_service.get_state(key, run_id)
def subscribe(self, key, observer: Callable):
def subscribe(self, key, observer: Callable) -> None:
self.state_service.subscribe(key, observer)
def notify_observers(self, key, new_state):
for callback in self.observers[key]:
callback(key, new_state, append=False)
def notify_append_observers(self, key, new_state):
for callback in self.observers[key]:
try:
callback(key, new_state, append=True)
except Exception: # noqa: BLE001
logger.exception(f"Error in observer {callback} for key {key}")
logger.warning("Callbacks not implemented yet")

View file

@ -27,13 +27,13 @@ def find_last_node(nodes, edges):
return next((n for n in nodes if all(e["source"] != n["id"] for e in edges)), None)
def add_parent_node_id(nodes, parent_node_id):
def add_parent_node_id(nodes, parent_node_id) -> None:
    """Stamp each node dict in `nodes` with the given parent_node_id."""
    for child in nodes:
        child["parent_node_id"] = parent_node_id
def add_frozen(nodes, frozen):
def add_frozen(nodes, frozen) -> None:
"""This function receives a list of nodes and adds a frozen to each node."""
for node in nodes:
node["data"]["node"]["frozen"] = frozen
@ -75,7 +75,7 @@ def process_flow(flow_object):
cloned_flow = copy.deepcopy(flow_object)
processed_nodes = set() # To keep track of processed nodes
def process_node(node):
def process_node(node) -> None:
node_id = node.get("id")
# If node already processed, skip
@ -100,7 +100,7 @@ def process_flow(flow_object):
return cloned_flow
def update_template(template, g_nodes):
def update_template(template, g_nodes) -> None:
"""Updates the template of a node in a graph with the given template.
Args:
@ -149,7 +149,7 @@ def update_target_handle(new_edge, g_nodes):
return new_edge
def set_new_target_handle(proxy_id, new_edge, target_handle, node):
def set_new_target_handle(proxy_id, new_edge, target_handle, node) -> None:
"""Sets a new target handle for a given edge.
Args:
@ -330,7 +330,7 @@ def has_cycle(vertex_ids: list[str], edges: list[tuple[str, str]]) -> bool:
graph[u].append(v)
# Utility function to perform DFS
def dfs(v, visited, rec_stack):
def dfs(v, visited, rec_stack) -> bool:
visited.add(v)
rec_stack.add(v)

View file

@ -126,10 +126,10 @@ def build_output_setter(method: Callable, *, validate: bool = True) -> Callable:
>>> print(component.get_output_by_method(component.set_message).value) # Prints "New message"
"""
def output_setter(self, value): # noqa: ARG001
def output_setter(self, value) -> None: # noqa: ARG001
if validate:
__validate_method(method)
methods_class = method.__self__
methods_class = method.__self__ # type: ignore[attr-defined]
output = methods_class.get_output_by_method(method)
output.value = value

View file

@ -172,7 +172,7 @@ def log_vertex_build(
params: Any,
data: ResultDataResponse,
artifacts: dict | None = None,
):
) -> None:
try:
if not get_settings_service().settings.vertex_builds_storage_enabled:
return

View file

@ -75,8 +75,8 @@ class Vertex:
self.base_type: str | None = base_type
self.outputs: list[dict] = []
self._parse_data()
self._built_object = UnbuiltObject()
self._built_result = None
self._built_object: Any = UnbuiltObject()
self._built_result: Any = None
self._built = False
self._successors_ids: list[str] | None = None
self.artifacts: dict[str, Any] = {}
@ -106,7 +106,7 @@ class Vertex:
self.state = VertexStates.ACTIVE
self.log_transaction_tasks: set[asyncio.Task] = set()
def set_input_value(self, name: str, value: Any):
def set_input_value(self, name: str, value: Any) -> None:
if self._custom_component is None:
msg = f"Vertex {self.id} does not have a component instance."
raise ValueError(msg)
@ -115,20 +115,20 @@ class Vertex:
def to_data(self):
return self._data
def add_component_instance(self, component_instance: Component):
def add_component_instance(self, component_instance: Component) -> None:
    """Attach a component instance to this vertex, wiring the back-reference first."""
    component_instance.set_vertex(self)
    self._custom_component = component_instance
def add_result(self, name: str, result: Any):
def add_result(self, name: str, result: Any) -> None:
    """Store a named build result on the vertex."""
    self.results[name] = result
def update_graph_state(self, key, new_state, *, append: bool):
def update_graph_state(self, key, new_state, *, append: bool) -> None:
if append:
self.graph.append_state(key, new_state, caller=self.id)
else:
self.graph.update_state(key, new_state, caller=self.id)
def set_state(self, state: str):
def set_state(self, state: str) -> None:
self.state = VertexStates[state]
if self.state == VertexStates.INACTIVE and self.graph.in_degree_map[self.id] <= 1:
# If the vertex is inactive and has only one in degree
@ -144,7 +144,7 @@ class Vertex:
def avg_build_time(self):
return sum(self.build_times) / len(self.build_times) if self.build_times else 0
def add_build_time(self, time):
def add_build_time(self, time) -> None:
    """Record one build-duration sample for this vertex."""
    self.build_times.append(time)
def set_result(self, result: ResultData) -> None:
@ -300,7 +300,7 @@ class Vertex:
params[param_key] = self.graph.get_vertex(edge.source_id)
return params
def _build_params(self):
def _build_params(self) -> None:
# sourcery skip: merge-list-append, remove-redundant-if
# Some params are required, some are optional
# but most importantly, some params are python base classes
@ -326,7 +326,7 @@ class Vertex:
return
template_dict = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)}
params = {}
params: dict = {}
for edge in self.edges:
if not hasattr(edge, "target_param"):
@ -438,7 +438,7 @@ class Vertex:
self.load_from_db_fields = load_from_db_fields
self._raw_params = params.copy()
def update_raw_params(self, new_params: Mapping[str, str | list[str]], *, overwrite: bool = False):
def update_raw_params(self, new_params: Mapping[str, str | list[str]], *, overwrite: bool = False) -> None:
"""Update the raw parameters of the vertex with the given new parameters.
Args:
@ -466,7 +466,7 @@ class Vertex:
"""Checks if the vertex has any cycle edges."""
return self._has_cycle_edges
async def instantiate_component(self, user_id=None):
async def instantiate_component(self, user_id=None) -> None:
if not self._custom_component:
self._custom_component, _ = await initialize.loading.instantiate_class(
user_id=user_id,
@ -478,7 +478,7 @@ class Vertex:
fallback_to_env_vars,
user_id=None,
event_manager: EventManager | None = None,
):
) -> None:
"""Initiate the build process."""
logger.debug(f"Building {self.display_name}")
await self._build_each_vertex_in_params_dict()
@ -500,6 +500,7 @@ class Vertex:
custom_component=custom_component,
custom_params=custom_params,
fallback_to_env_vars=fallback_to_env_vars,
base_type=self.base_type,
)
self._validate_built_object()
@ -545,7 +546,7 @@ class Vertex:
return messages
def _finalize_build(self):
def _finalize_build(self) -> None:
result_dict = self.get_built_result()
# We need to set the artifacts to pass information
# to the frontend
@ -563,7 +564,7 @@ class Vertex:
)
self.set_result(result_dict)
async def _build_each_vertex_in_params_dict(self):
async def _build_each_vertex_in_params_dict(self) -> None:
"""Iterates over each vertex in the params dictionary and builds it."""
for key, value in self._raw_params.items():
if self._is_vertex(value):
@ -588,7 +589,7 @@ class Vertex:
self,
key,
vertices_dict: dict[str, Vertex],
):
) -> None:
"""Iterates over a dictionary of vertices, builds each and updates the params dictionary."""
for sub_key, value in vertices_dict.items():
if not self._is_vertex(value):
@ -647,7 +648,7 @@ class Vertex:
self._log_transaction_async(str(flow_id), source=self, target=requester, status="success")
return result
async def _build_vertex_and_update_params(self, key, vertex: Vertex):
async def _build_vertex_and_update_params(self, key, vertex: Vertex) -> None:
"""Builds a given vertex and updates the params dictionary accordingly."""
result = await vertex.get_result(self, target_handle_name=key)
self._handle_func(key, result)
@ -659,7 +660,7 @@ class Vertex:
self,
key,
vertices: list[Vertex],
):
) -> None:
"""Iterates over a list of vertices, builds each and updates the params dictionary."""
self.params[key] = []
for vertex in vertices:
@ -685,7 +686,7 @@ class Vertex:
)
raise ValueError(msg) from e
def _handle_func(self, key, result):
def _handle_func(self, key, result) -> None:
"""Handles 'func' key by checking if the result is a function and setting it as coroutine."""
if key == "func":
if not isinstance(result, types.FunctionType):
@ -698,19 +699,21 @@ class Vertex:
else:
self.params["coroutine"] = sync_to_async(result)
def _extend_params_list_with_result(self, key, result):
def _extend_params_list_with_result(self, key, result) -> None:
"""Extends a list in the params dictionary with the given result if it exists."""
if isinstance(self.params[key], list):
self.params[key].extend(result)
async def _build_results(self, custom_component, custom_params, *, fallback_to_env_vars=False):
async def _build_results(
self, custom_component, custom_params, base_type: str, *, fallback_to_env_vars=False
) -> None:
try:
result = await initialize.loading.get_instance_results(
custom_component=custom_component,
custom_params=custom_params,
vertex=self,
fallback_to_env_vars=fallback_to_env_vars,
base_type=self.base_type,
base_type=base_type,
)
self.outputs_logs = build_output_logs(self, result)
@ -722,7 +725,7 @@ class Vertex:
msg = f"Error building Component {self.display_name}: \n\n{exc}"
raise ComponentBuildError(msg, tb) from exc
def _update_built_object_and_artifacts(self, result: Any | tuple[Any, dict] | tuple[Component, Any, dict]):
def _update_built_object_and_artifacts(self, result: Any | tuple[Any, dict] | tuple[Component, Any, dict]) -> None:
"""Updates the built object and its artifacts."""
if isinstance(result, tuple):
if len(result) == 2: # noqa: PLR2004
@ -738,7 +741,7 @@ class Vertex:
else:
self._built_object = result
def _validate_built_object(self):
def _validate_built_object(self) -> None:
"""Checks if the built object is None and raises a ValueError if so."""
if isinstance(self._built_object, UnbuiltObject):
msg = f"{self.display_name}: {self._built_object_repr()}"
@ -754,7 +757,7 @@ class Vertex:
msg = f"You are trying to stream to a {self.display_name}. Try using a Chat Output instead."
raise ValueError(msg)
def _reset(self):
def _reset(self) -> None:
self._built = False
self._built_object = UnbuiltObject()
self._built_result = UnbuiltResult()
@ -762,10 +765,10 @@ class Vertex:
self.steps_ran = []
self._build_params()
def _is_chat_input(self):
def _is_chat_input(self) -> bool:
return False
def build_inactive(self):
def build_inactive(self) -> None:
# Just set the results to None
self._built = True
self._built_object = None
@ -865,11 +868,11 @@ class Vertex:
def __hash__(self) -> int:
return id(self)
def _built_object_repr(self):
def _built_object_repr(self) -> str:
# Add a message with an emoji, stars for success,
return "Built successfully ✨" if self._built_object is not None else "Failed to build 😵‍💫"
def apply_on_outputs(self, func: Callable[[Any], Any]):
def apply_on_outputs(self, func: Callable[[Any], Any]) -> None:
"""Applies a function to the outputs of the vertex."""
if not self._custom_component or not self._custom_component.outputs:
return

View file

@ -57,7 +57,7 @@ class ComponentVertex(Vertex):
return self.artifacts["repr"] or super()._built_object_repr()
return None
def _update_built_object_and_artifacts(self, result):
def _update_built_object_and_artifacts(self, result) -> None:
"""Updates the built object and its artifacts."""
if isinstance(result, tuple):
if len(result) == 2: # noqa: PLR2004
@ -182,7 +182,7 @@ class ComponentVertex(Vertex):
)
return messages
def _finalize_build(self):
def _finalize_build(self) -> None:
result_dict = self.get_built_result()
# We need to set the artifacts to pass information
# to the frontend
@ -206,7 +206,7 @@ class InterfaceVertex(ComponentVertex):
self.steps = [self._build, self._run]
self.is_interface_component = True
def build_stream_url(self):
def build_stream_url(self) -> str:
return f"/api/v1/build/{self.graph.flow_id}/{self.id}/stream"
def _built_object_repr(self):
@ -352,21 +352,17 @@ class InterfaceVertex(ComponentVertex):
self.artifacts = DataOutputResponse(data=artifacts)
return self._built_object
async def _run(self, *args, **kwargs):
if self.is_interface_component:
if self.vertex_type in CHAT_COMPONENTS:
message = self._process_chat_component()
elif self.vertex_type in RECORDS_COMPONENTS:
message = self._process_data_component()
if isinstance(self._built_object, AsyncIterator | Iterator):
if self.params.get("return_data", False):
self._built_object = Data(text=message, data=self.artifacts)
else:
self._built_object = message
self._built_result = self._built_object
else:
await super()._run(*args, **kwargs)
async def _run(self, *args, **kwargs) -> None: # noqa: ARG002
if self.vertex_type in CHAT_COMPONENTS:
message = self._process_chat_component()
elif self.vertex_type in RECORDS_COMPONENTS:
message = self._process_data_component()
if isinstance(self._built_object, AsyncIterator | Iterator):
if self.params.get("return_data", False):
self._built_object = Data(text=message, data=self.artifacts)
else:
self._built_object = message
self._built_result = self._built_object
async def stream(self):
iterator = self.params.get(INPUT_FIELD_NAME, None)
@ -452,7 +448,7 @@ class InterfaceVertex(ComponentVertex):
self._validate_built_object()
self._built = True
async def consume_async_generator(self):
async def consume_async_generator(self) -> None:
async for _ in self.stream():
pass

View file

@ -12,6 +12,7 @@ from uuid import UUID
import orjson
from emoji import demojize, purely_emoji
from loguru import logger
from sqlalchemy.exc import NoResultFound
from sqlmodel import select
from langflow.base.constants import (
@ -340,7 +341,7 @@ def update_edges_with_latest_component_versions(project_data):
return project_data_copy
def log_node_changes(node_changes_log):
def log_node_changes(node_changes_log) -> None:
# The idea here is to log the changes that were made to the nodes in debug
# Something like:
# Node: "Node Name" was updated with the following changes:
@ -377,8 +378,11 @@ def load_starter_projects(retries=3, delay=1) -> list[tuple[Path, dict]]:
return starter_projects
def copy_profile_pictures():
def copy_profile_pictures() -> None:
config_dir = get_storage_service().settings_service.settings.config_dir
if config_dir is None:
msg = "Config dir is not set in the settings"
raise ValueError(msg)
origin = Path(__file__).parent / "profile_pictures"
target = Path(config_dir) / "profile_pictures"
@ -425,7 +429,7 @@ def get_project_data(project):
)
def update_project_file(project_path: Path, project: dict, updated_project_data):
def update_project_file(project_path: Path, project: dict, updated_project_data) -> None:
project["data"] = updated_project_data
project_path.write_text(orjson.dumps(project, option=ORJSON_OPTIONS).decode(), encoding="utf-8")
logger.info(f"Updated starter project {project['name']} file")
@ -440,7 +444,7 @@ def update_existing_project(
project_data,
project_icon,
project_icon_bg_color,
):
) -> None:
logger.info(f"Updating starter project {project_name}")
existing_project.data = project_data
existing_project.folder = STARTER_FOLDER_NAME
@ -463,7 +467,7 @@ def create_new_project(
project_icon,
project_icon_bg_color,
new_folder_id,
):
) -> None:
logger.debug(f"Creating starter project {project_name}")
new_project = FlowCreate(
name=project_name,
@ -485,7 +489,7 @@ def get_all_flows_similar_to_project(session, folder_id):
return session.exec(select(Folder).where(Folder.id == folder_id)).first().flows
def delete_start_projects(session, folder_id):
def delete_start_projects(session, folder_id) -> None:
flows = session.exec(select(Folder).where(Folder.id == folder_id)).first().flows
for flow in flows:
session.delete(flow)
@ -516,7 +520,7 @@ def _is_valid_uuid(val):
return str(uuid_obj) == val
def load_flows_from_directory():
def load_flows_from_directory() -> None:
"""On langflow startup, this loads all flows from the directory specified in the settings.
All flows are uploaded into the default folder for the superuser.
@ -531,7 +535,11 @@ def load_flows_from_directory():
return
with session_scope() as session:
user_id = get_user_by_username(session, settings_service.auth_settings.SUPERUSER).id
user = get_user_by_username(session, settings_service.auth_settings.SUPERUSER)
if user is None:
msg = "Superuser not found in the database"
raise NoResultFound(msg)
user_id = user.id
_flows_path = Path(flows_path)
files = [f for f in _flows_path.iterdir() if f.is_file()]
for f in files:
@ -592,7 +600,7 @@ def find_existing_flow(session, flow_id, flow_endpoint_name):
return None
async def create_or_update_starter_projects(get_all_components_coro: Awaitable[dict]):
async def create_or_update_starter_projects(get_all_components_coro: Awaitable[dict]) -> None:
try:
all_types_dict = await get_all_components_coro
except Exception:
@ -647,7 +655,7 @@ async def create_or_update_starter_projects(get_all_components_coro: Awaitable[d
)
def initialize_super_user_if_needed():
def initialize_super_user_if_needed() -> None:
settings_service = get_settings_service()
if not settings_service.auth_settings.AUTO_LOGIN:
return

View file

@ -3,7 +3,7 @@ from langflow.utils.lazy_load import LazyLoadDictBase
class AllTypesDict(LazyLoadDictBase):
def __init__(self):
def __init__(self) -> None:
self._all_types_dict = None
def _build_dict(self):

View file

@ -19,7 +19,7 @@ def get_memory_key(langchain_object):
return None # or some other default value or action
def update_memory_keys(langchain_object, possible_new_mem_key):
def update_memory_keys(langchain_object, possible_new_mem_key) -> None:
"""Update the memory keys in the LangChain object's memory attribute.
Given a LangChain object and a possible new memory key, this function updates the input and output keys in the

View file

@ -89,7 +89,7 @@ def extract_input_variables_from_prompt(prompt: str) -> list[str]:
return variables
def setup_llm_caching():
def setup_llm_caching() -> None:
"""Setup LLM caching."""
settings_service = get_settings_service()
try:
@ -100,7 +100,7 @@ def setup_llm_caching():
logger.opt(exception=True).warning("Could not setup LLM caching.")
def set_langchain_cache(settings):
def set_langchain_cache(settings) -> None:
from langchain.globals import set_llm_cache
from langflow.interface.importing.utils import import_class

View file

@ -42,7 +42,7 @@ class SizedLogBuffer:
def get_write_lock(self) -> Lock:
return self._wlock
def write(self, message: str):
def write(self, message: str) -> None:
record = json.loads(message)
log_entry = record["text"]
epoch = int(record["record"]["time"]["timestamp"] * 1000)
@ -52,7 +52,7 @@ class SizedLogBuffer:
self.buffer.popleft()
self.buffer.append((epoch, log_entry))
def __len__(self):
def __len__(self) -> int:
return len(self.buffer)
def get_after_timestamp(self, timestamp: int, lines: int = 5) -> dict[int, str]:
@ -123,7 +123,7 @@ def serialize_log(record):
return orjson.dumps(subset)
def patching(record):
def patching(record) -> None:
record["extra"]["serialized"] = serialize_log(record)
if DEV is False:
record.pop("exception", None)
@ -142,7 +142,7 @@ def configure(
log_file: Path | None = None,
disable: bool | None = False,
log_env: str | None = None,
):
) -> None:
if disable and log_level is None and log_file is None:
logger.disable("langflow")
if os.getenv("LANGFLOW_LOG_LEVEL", "").upper() in VALID_LOG_LEVELS and log_level is None:
@ -205,14 +205,14 @@ def configure(
setup_gunicorn_logger()
def setup_uvicorn_logger():
def setup_uvicorn_logger() -> None:
loggers = (logging.getLogger(name) for name in logging.root.manager.loggerDict if name.startswith("uvicorn."))
for uvicorn_logger in loggers:
uvicorn_logger.handlers = []
logging.getLogger("uvicorn").handlers = [InterceptHandler()]
def setup_gunicorn_logger():
def setup_gunicorn_logger() -> None:
logging.getLogger("gunicorn.error").handlers = [InterceptHandler()]
logging.getLogger("gunicorn.access").handlers = [InterceptHandler()]
@ -223,7 +223,7 @@ class InterceptHandler(logging.Handler):
See https://loguru.readthedocs.io/en/stable/overview.html#entirely-compatible-with-standard-logging.
"""
def emit(self, record):
def emit(self, record) -> None:
# Get corresponding Loguru level if it exists
try:
level = logger.level(record.levelname).name
@ -232,7 +232,7 @@ class InterceptHandler(logging.Handler):
# Find caller from where originated the logged message
frame, depth = logging.currentframe(), 2
while frame.f_code.co_filename == logging.__file__:
while frame.f_code.co_filename == logging.__file__ and frame.f_back:
frame = frame.f_back
depth += 1

View file

@ -3,14 +3,14 @@ from loguru import logger
LOGGING_CONFIGURED = False
def disable_logging():
def disable_logging() -> None:
global LOGGING_CONFIGURED # noqa: PLW0603
if not LOGGING_CONFIGURED:
logger.disable("langflow")
LOGGING_CONFIGURED = True
def enable_logging():
def enable_logging() -> None:
global LOGGING_CONFIGURED # noqa: PLW0603
logger.enable("langflow")
LOGGING_CONFIGURED = True

View file

@ -40,7 +40,7 @@ MAX_PORT = 65535
class RequestCancelledMiddleware(BaseHTTPMiddleware):
def __init__(self, app):
def __init__(self, app) -> None:
super().__init__(app)
async def dispatch(self, request: Request, call_next):
@ -224,7 +224,7 @@ def create_app():
return app
def setup_sentry(app: FastAPI):
def setup_sentry(app: FastAPI) -> None:
settings = get_settings_service().settings
if settings.sentry_dsn:
import sentry_sdk
@ -238,7 +238,7 @@ def setup_sentry(app: FastAPI):
app.add_middleware(SentryAsgiMiddleware)
def setup_static_files(app: FastAPI, static_files_dir: Path):
def setup_static_files(app: FastAPI, static_files_dir: Path) -> None:
"""Setup the static files directory.
Args:

View file

@ -86,7 +86,7 @@ def add_messagetables(messages: list[MessageTable], session: Session):
return [MessageRead.model_validate(message, from_attributes=True) for message in messages]
def delete_messages(session_id: str):
def delete_messages(session_id: str) -> None:
"""Delete messages from the monitor service based on the provided session ID.
Args:

View file

@ -36,7 +36,7 @@ def get_langfuse_callback(trace_id):
return None
def flush_langfuse_callback_if_present(callbacks: list[BaseCallbackHandler | CallbackHandler]):
def flush_langfuse_callback_if_present(callbacks: list[BaseCallbackHandler | CallbackHandler]) -> None:
"""If langfuse callback is present, run callback.langfuse.flush()."""
for callback in callbacks:
if hasattr(callback, "langfuse") and hasattr(callback.langfuse, "flush"):

View file

@ -165,7 +165,7 @@ class Data(BaseModel):
msg = f"'{type(self).__name__}' object has no attribute '{key}'"
raise AttributeError(msg) from e
def __setattr__(self, key, value):
def __setattr__(self, key, value) -> None:
"""Set attribute-like values in the data dictionary.
Allows attribute-like setting of values in the data dictionary.
@ -179,7 +179,7 @@ class Data(BaseModel):
else:
self.data[key] = value
def __delattr__(self, key):
def __delattr__(self, key) -> None:
"""Allows attribute-like deletion from the data dictionary."""
if key in {"data", "text_key"} or key.startswith("_"):
super().__delattr__(key)
@ -204,7 +204,7 @@ class Data(BaseModel):
logger.opt(exception=True).debug("Error converting Data to JSON")
return str(self.data)
def __contains__(self, key):
def __contains__(self, key) -> bool:
return key in self.data
def __eq__(self, other):

View file

@ -33,7 +33,7 @@ class dotdict(dict): # noqa: N801
else:
return value
def __setattr__(self, key, value):
def __setattr__(self, key, value) -> None:
"""Override attribute setting to work as dictionary item assignment.
Args:
@ -44,7 +44,7 @@ class dotdict(dict): # noqa: N801
value = dotdict(value)
self[key] = value
def __delattr__(self, key):
def __delattr__(self, key) -> None:
"""Override attribute deletion to work as dictionary item deletion.
Args:

View file

@ -37,10 +37,10 @@ class Tweaks(RootModel):
def __getitem__(self, key):
return self.root[key]
def __setitem__(self, key, value):
def __setitem__(self, key, value) -> None:
self.root[key] = value
def __delitem__(self, key):
def __delitem__(self, key) -> None:
del self.root[key]
def items(self):

View file

@ -8,7 +8,7 @@ from langflow.services.deps import get_storage_service
IMAGE_ENDPOINT = "/files/images/"
def is_image_file(file_path):
def is_image_file(file_path) -> bool:
try:
with PILImage.open(file_path) as img:
img.verify() # Verify that it is, in fact, an image
@ -61,5 +61,5 @@ class Image(BaseModel):
"image_url": self.to_base64(),
}
def get_url(self):
def get_url(self) -> str:
return f"{IMAGE_ENDPOINT}{self.path}"

View file

@ -94,7 +94,7 @@ class Message(Data):
if "timestamp" not in self.data:
self.data["timestamp"] = self.timestamp
def set_flow_id(self, flow_id: str):
def set_flow_id(self, flow_id: str) -> None:
self.flow_id = flow_id
def to_lc_message(

View file

@ -35,14 +35,14 @@ class Logger(glogging.Logger):
gunicorn logs to loguru.
"""
def __init__(self, cfg):
def __init__(self, cfg) -> None:
super().__init__(cfg)
logging.getLogger("gunicorn.error").handlers = [InterceptHandler()]
logging.getLogger("gunicorn.access").handlers = [InterceptHandler()]
class LangflowApplication(BaseApplication):
def __init__(self, app, options=None):
def __init__(self, app, options=None) -> None:
self.options = options or {}
self.options["worker_class"] = "langflow.server.LangflowUvicornWorker"
@ -50,7 +50,7 @@ class LangflowApplication(BaseApplication):
self.application = app
super().__init__()
def load_config(self):
def load_config(self) -> None:
config = {key: value for key, value in self.options.items() if key in self.cfg.settings and value is not None}
for key, value in config.items():
self.cfg.set(key.lower(), value)

View file

@ -5,7 +5,7 @@ from langflow.services.factory import ServiceFactory
class AuthServiceFactory(ServiceFactory):
name = "auth_service"
def __init__(self):
def __init__(self) -> None:
super().__init__(AuthService)
def create(self, settings_service):

View file

@ -21,8 +21,8 @@ class Service(ABC):
}
return schema
async def teardown(self):
async def teardown(self) -> None:
return
def set_ready(self):
def set_ready(self) -> None:
self.ready = True

View file

@ -60,7 +60,7 @@ class CacheService(Service, Generic[LockType]):
"""Clear all items from the cache."""
@abc.abstractmethod
def __contains__(self, key):
def __contains__(self, key) -> bool:
"""Check if the key is in the cache.
Args:
@ -79,7 +79,7 @@ class CacheService(Service, Generic[LockType]):
"""
@abc.abstractmethod
def __setitem__(self, key, value):
def __setitem__(self, key, value) -> None:
"""Add an item to the cache using the square bracket notation.
Args:
@ -88,7 +88,7 @@ class CacheService(Service, Generic[LockType]):
"""
@abc.abstractmethod
def __delitem__(self, key):
def __delitem__(self, key) -> None:
"""Remove an item from the cache using the square bracket notation.
Args:
@ -147,7 +147,7 @@ class AsyncBaseCacheService(Service, Generic[AsyncLockType]):
"""Clear all items from the cache."""
@abc.abstractmethod
def __contains__(self, key):
def __contains__(self, key) -> bool:
"""Check if the key is in the cache.
Args:

View file

@ -11,7 +11,7 @@ from langflow.services.cache.utils import CACHE_MISS
class AsyncDiskCache(AsyncBaseCacheService, Generic[AsyncLockType]):
def __init__(self, cache_dir, max_size=None, expiration_time=3600):
def __init__(self, cache_dir, max_size=None, expiration_time=3600) -> None:
self.cache = Cache(cache_dir)
# Let's clear the cache for now to maintain a similar
# behavior as the in-memory cache
@ -40,56 +40,56 @@ class AsyncDiskCache(AsyncBaseCacheService, Generic[AsyncLockType]):
await self._delete(key) # Log before deleting the expired item
return CACHE_MISS
async def set(self, key, value, lock: asyncio.Lock | None = None):
async def set(self, key, value, lock: asyncio.Lock | None = None) -> None:
if not lock:
async with self.lock:
await self._set(key, value)
else:
await self._set(key, value)
async def _set(self, key, value):
async def _set(self, key, value) -> None:
if self.max_size and len(self.cache) >= self.max_size:
await asyncio.to_thread(self.cache.cull)
item = {"value": pickle.dumps(value) if not isinstance(value, str | bytes) else value, "time": time.time()}
await asyncio.to_thread(self.cache.set, key, item)
async def delete(self, key, lock: asyncio.Lock | None = None):
async def delete(self, key, lock: asyncio.Lock | None = None) -> None:
if not lock:
async with self.lock:
await self._delete(key)
else:
await self._delete(key)
async def _delete(self, key):
async def _delete(self, key) -> None:
await asyncio.to_thread(self.cache.delete, key)
async def clear(self, lock: asyncio.Lock | None = None):
async def clear(self, lock: asyncio.Lock | None = None) -> None:
if not lock:
async with self.lock:
await self._clear()
else:
await self._clear()
async def _clear(self):
async def _clear(self) -> None:
await asyncio.to_thread(self.cache.clear)
async def upsert(self, key, value, lock: asyncio.Lock | None = None):
async def upsert(self, key, value, lock: asyncio.Lock | None = None) -> None:
if not lock:
async with self.lock:
await self._upsert(key, value)
else:
await self._upsert(key, value)
async def _upsert(self, key, value):
async def _upsert(self, key, value) -> None:
existing_value = await self.get(key)
if existing_value is not CACHE_MISS and isinstance(existing_value, dict) and isinstance(value, dict):
existing_value.update(value)
value = existing_value
await self.set(key, value)
def __contains__(self, key):
def __contains__(self, key) -> bool:
return asyncio.run(asyncio.to_thread(self.cache.__contains__, key))
async def teardown(self):
async def teardown(self) -> None:
# Clean up the cache directory
self.cache.clear(retry=True)

View file

@ -12,7 +12,7 @@ if TYPE_CHECKING:
class CacheServiceFactory(ServiceFactory):
def __init__(self):
def __init__(self) -> None:
super().__init__(CacheService)
def create(self, settings_service: SettingsService):

View file

@ -36,14 +36,14 @@ class ThreadingInMemoryCache(CacheService, Generic[LockType]):
b = cache["b"]
"""
def __init__(self, max_size=None, expiration_time=60 * 60):
def __init__(self, max_size=None, expiration_time=60 * 60) -> None:
"""Initialize a new InMemoryCache instance.
Args:
max_size (int, optional): Maximum number of items to store in the cache.
expiration_time (int, optional): Time in seconds after which a cached item expires. Default is 1 hour.
"""
self._cache = OrderedDict()
self._cache: OrderedDict = OrderedDict()
self._lock = threading.RLock()
self.max_size = max_size
self.expiration_time = expiration_time
@ -72,7 +72,7 @@ class ThreadingInMemoryCache(CacheService, Generic[LockType]):
self.delete(key)
return None
def set(self, key, value, lock: Union[threading.Lock, None] = None): # noqa: UP007
def set(self, key, value, lock: Union[threading.Lock, None] = None) -> None: # noqa: UP007
"""Add an item to the cache.
If the cache is full, the least recently used item is evicted.
@ -93,7 +93,7 @@ class ThreadingInMemoryCache(CacheService, Generic[LockType]):
self._cache[key] = {"value": value, "time": time.time()}
def upsert(self, key, value, lock: Union[threading.Lock, None] = None): # noqa: UP007
def upsert(self, key, value, lock: Union[threading.Lock, None] = None) -> None: # noqa: UP007
"""Inserts or updates a value in the cache.
If the existing value and the new value are both dictionaries, they are merged.
@ -130,16 +130,16 @@ class ThreadingInMemoryCache(CacheService, Generic[LockType]):
self.set(key, value)
return value
def delete(self, key, lock: Union[threading.Lock, None] = None): # noqa: UP007
def delete(self, key, lock: Union[threading.Lock, None] = None) -> None: # noqa: UP007
with lock or self._lock:
self._cache.pop(key, None)
def clear(self, lock: Union[threading.Lock, None] = None): # noqa: UP007
def clear(self, lock: Union[threading.Lock, None] = None) -> None: # noqa: UP007
"""Clear all items from the cache."""
with lock or self._lock:
self._cache.clear()
def __contains__(self, key):
def __contains__(self, key) -> bool:
"""Check if the key is in the cache."""
return key in self._cache
@ -147,19 +147,19 @@ class ThreadingInMemoryCache(CacheService, Generic[LockType]):
"""Retrieve an item from the cache using the square bracket notation."""
return self.get(key)
def __setitem__(self, key, value):
def __setitem__(self, key, value) -> None:
"""Add an item to the cache using the square bracket notation."""
self.set(key, value)
def __delitem__(self, key):
def __delitem__(self, key) -> None:
"""Remove an item from the cache using the square bracket notation."""
self.delete(key)
def __len__(self):
def __len__(self) -> int:
"""Return the number of items in the cache."""
return len(self._cache)
def __repr__(self):
def __repr__(self) -> str:
"""Return a string representation of the InMemoryCache instance."""
return f"InMemoryCache(max_size={self.max_size}, expiration_time={self.expiration_time})"
@ -185,7 +185,7 @@ class RedisCache(AsyncBaseCacheService, Generic[LockType]):
b = cache["b"]
"""
def __init__(self, host="localhost", port=6379, db=0, url=None, expiration_time=60 * 60):
def __init__(self, host="localhost", port=6379, db=0, url=None, expiration_time=60 * 60) -> None:
"""Initialize a new RedisCache instance.
Args:
@ -215,7 +215,7 @@ class RedisCache(AsyncBaseCacheService, Generic[LockType]):
self.expiration_time = expiration_time
# check connection
def is_connected(self):
def is_connected(self) -> bool:
"""Check if the Redis client is connected."""
import redis
@ -234,7 +234,7 @@ class RedisCache(AsyncBaseCacheService, Generic[LockType]):
return pickle.loads(value) if value else None
@override
async def set(self, key, value, lock=None):
async def set(self, key, value, lock=None) -> None:
try:
if pickled := pickle.dumps(value):
result = await self._client.setex(str(key), self.expiration_time, pickled)
@ -246,7 +246,7 @@ class RedisCache(AsyncBaseCacheService, Generic[LockType]):
raise TypeError(msg) from exc
@override
async def upsert(self, key, value, lock=None):
async def upsert(self, key, value, lock=None) -> None:
"""Inserts or updates a value in the cache.
If the existing value and the new value are both dictionaries, they are merged.
@ -266,28 +266,28 @@ class RedisCache(AsyncBaseCacheService, Generic[LockType]):
await self.set(key, value)
@override
async def delete(self, key, lock=None):
async def delete(self, key, lock=None) -> None:
await self._client.delete(key)
@override
async def clear(self, lock=None):
async def clear(self, lock=None) -> None:
"""Clear all items from the cache."""
await self._client.flushdb()
def __contains__(self, key):
def __contains__(self, key) -> bool:
"""Check if the key is in the cache."""
if key is None:
return False
return asyncio.run(self._client.exists(str(key)))
return bool(asyncio.run(self._client.exists(str(key))))
def __repr__(self):
def __repr__(self) -> str:
"""Return a string representation of the RedisCache instance."""
return f"RedisCache(expiration_time={self.expiration_time})"
class AsyncInMemoryCache(AsyncBaseCacheService, Generic[AsyncLockType]):
def __init__(self, max_size=None, expiration_time=3600):
self.cache = OrderedDict()
def __init__(self, max_size=None, expiration_time=3600) -> None:
self.cache: OrderedDict = OrderedDict()
self.lock = asyncio.Lock()
self.max_size = max_size
@ -310,7 +310,7 @@ class AsyncInMemoryCache(AsyncBaseCacheService, Generic[AsyncLockType]):
await self._delete(key) # Log before deleting the expired item
return CACHE_MISS
async def set(self, key, value, lock: asyncio.Lock | None = None):
async def set(self, key, value, lock: asyncio.Lock | None = None) -> None:
if not lock:
async with self.lock:
await self._set(
@ -323,46 +323,46 @@ class AsyncInMemoryCache(AsyncBaseCacheService, Generic[AsyncLockType]):
value,
)
async def _set(self, key, value):
async def _set(self, key, value) -> None:
if self.max_size and len(self.cache) >= self.max_size:
self.cache.popitem(last=False)
self.cache[key] = {"value": value, "time": time.time()}
self.cache.move_to_end(key)
async def delete(self, key, lock: asyncio.Lock | None = None):
async def delete(self, key, lock: asyncio.Lock | None = None) -> None:
if not lock:
async with self.lock:
await self._delete(key)
else:
await self._delete(key)
async def _delete(self, key):
async def _delete(self, key) -> None:
if key in self.cache:
del self.cache[key]
async def clear(self, lock: asyncio.Lock | None = None):
async def clear(self, lock: asyncio.Lock | None = None) -> None:
if not lock:
async with self.lock:
await self._clear()
else:
await self._clear()
async def _clear(self):
async def _clear(self) -> None:
self.cache.clear()
async def upsert(self, key, value, lock: asyncio.Lock | None = None):
async def upsert(self, key, value, lock: asyncio.Lock | None = None) -> None:
if not lock:
async with self.lock:
await self._upsert(key, value)
else:
await self._upsert(key, value)
async def _upsert(self, key, value):
async def _upsert(self, key, value) -> None:
existing_value = await self.get(key)
if existing_value is not None and isinstance(existing_value, dict) and isinstance(value, dict):
existing_value.update(value)
value = existing_value
await self.set(key, value)
def __contains__(self, key):
def __contains__(self, key) -> bool:
return key in self.cache

View file

@ -19,10 +19,10 @@ PREFIX = "langflow_cache"
class CacheMiss:
def __repr__(self):
def __repr__(self) -> str:
return "<CACHE_MISS>"
def __bool__(self):
def __bool__(self) -> bool:
return False
@ -40,7 +40,7 @@ def create_cache_folder(func):
@create_cache_folder
def clear_old_cache_files(max_cache_size: int = 3):
def clear_old_cache_files(max_cache_size: int = 3) -> None:
cache_dir = Path(tempfile.gettempdir()) / PREFIX
cache_files = list(cache_dir.glob("*.dill"))
@ -155,7 +155,7 @@ def save_uploaded_file(file: UploadFile, folder_name):
return file_path
def update_build_status(cache_service, flow_id: str, status: "BuildStatus"):
def update_build_status(cache_service, flow_id: str, status: "BuildStatus") -> None:
cached_flow = cache_service[flow_id]
if cached_flow is None:
msg = f"Flow {flow_id} not found in cache"

View file

@ -11,18 +11,18 @@ from langflow.services.base import Service
class Subject:
"""Base class for implementing the observer pattern."""
def __init__(self):
def __init__(self) -> None:
self.observers: list[Callable[[], None]] = []
def attach(self, observer: Callable[[], None]):
def attach(self, observer: Callable[[], None]) -> None:
"""Attach an observer to the subject."""
self.observers.append(observer)
def detach(self, observer: Callable[[], None]):
def detach(self, observer: Callable[[], None]) -> None:
"""Detach an observer from the subject."""
self.observers.remove(observer)
def notify(self):
def notify(self) -> None:
"""Notify all observers about an event."""
for observer in self.observers:
if observer is None:
@ -33,18 +33,18 @@ class Subject:
class AsyncSubject:
"""Base class for implementing the async observer pattern."""
def __init__(self):
def __init__(self) -> None:
self.observers: list[Callable[[], Awaitable]] = []
def attach(self, observer: Callable[[], Awaitable]):
def attach(self, observer: Callable[[], Awaitable]) -> None:
"""Attach an observer to the subject."""
self.observers.append(observer)
def detach(self, observer: Callable[[], Awaitable]):
def detach(self, observer: Callable[[], Awaitable]) -> None:
"""Detach an observer from the subject."""
self.observers.remove(observer)
async def notify(self):
async def notify(self) -> None:
"""Notify all observers about an event."""
for observer in self.observers:
if observer is None:
@ -57,11 +57,11 @@ class CacheService(Subject, Service):
name = "cache_service"
def __init__(self):
def __init__(self) -> None:
super().__init__()
self._cache = {}
self.current_client_id = None
self.current_cache = {}
self._cache: dict[str, Any] = {}
self.current_client_id: str | None = None
self.current_cache: dict[str, Any] = {}
@contextmanager
def set_client_id(self, client_id: str):
@ -77,9 +77,9 @@ class CacheService(Subject, Service):
yield
finally:
self.current_client_id = previous_client_id
self.current_cache = self._cache.get(self.current_client_id, {})
self.current_cache = self._cache.setdefault(previous_client_id, {}) if previous_client_id else {}
def add(self, name: str, obj: Any, obj_type: str, extension: str | None = None):
def add(self, name: str, obj: Any, obj_type: str, extension: str | None = None) -> None:
"""Add an object to the current client's cache.
Args:
@ -100,7 +100,7 @@ class CacheService(Subject, Service):
}
self.notify()
def add_pandas(self, name: str, obj: Any):
def add_pandas(self, name: str, obj: Any) -> None:
"""Add a pandas DataFrame or Series to the current client's cache.
Args:
@ -113,7 +113,7 @@ class CacheService(Subject, Service):
msg = "Object is not a pandas DataFrame or Series"
raise TypeError(msg)
def add_image(self, name: str, obj: Any, extension: str = "png"):
def add_image(self, name: str, obj: Any, extension: str = "png") -> None:
"""Add a PIL Image to the current client's cache.
Args:

View file

@ -3,7 +3,7 @@ from langflow.services.factory import ServiceFactory
class ChatServiceFactory(ServiceFactory):
def __init__(self):
def __init__(self) -> None:
super().__init__(ChatService)
def create(self):

View file

@ -13,9 +13,9 @@ class ChatService(Service):
name = "chat_service"
def __init__(self):
self._async_cache_locks = defaultdict(asyncio.Lock)
self._sync_cache_locks = defaultdict(RLock)
def __init__(self) -> None:
self._async_cache_locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
self._sync_cache_locks: dict[str, RLock] = defaultdict(RLock)
self.cache_service = get_cache_service()
def _get_lock(self, key: str):
@ -101,7 +101,7 @@ class ChatService(Service):
"""
return await self._perform_cache_operation("get", key, lock=lock or self._get_lock(key))
async def clear_cache(self, key: str, lock: asyncio.Lock | None = None):
async def clear_cache(self, key: str, lock: asyncio.Lock | None = None) -> None:
"""Clear the cache for a client.
Args:

View file

@ -10,7 +10,7 @@ if TYPE_CHECKING:
class DatabaseServiceFactory(ServiceFactory):
def __init__(self):
def __init__(self) -> None:
super().__init__(DatabaseService)
def create(self, settings_service: SettingsService):

View file

@ -59,6 +59,6 @@ class ApiKeyRead(ApiKeyBase):
@field_validator("api_key")
@classmethod
def mask_api_key(cls, v):
def mask_api_key(cls, v) -> str:
# This validator will always run, and will mask the API key
return f"{v[:8]}{'*' * (len(v) - 8)}"

View file

@ -45,7 +45,7 @@ class DatabaseService(Service):
self.alembic_cfg_path = langflow_dir / "alembic.ini"
self.engine = self._create_engine()
def reload_engine(self):
def reload_engine(self) -> None:
self.engine = self._create_engine()
def _create_engine(self) -> Engine:
@ -82,13 +82,13 @@ class DatabaseService(Service):
msg = "Error creating database engine"
raise RuntimeError(msg) from exc
def on_connection(self, dbapi_connection, _connection_record):
def on_connection(self, dbapi_connection, _connection_record) -> None:
from sqlite3 import Connection as sqliteConnection
if isinstance(dbapi_connection, sqliteConnection):
pragmas: dict | None = self.settings_service.settings.sqlite_pragmas
pragmas: dict = self.settings_service.settings.sqlite_pragmas or {}
pragmas_list = []
for key, val in pragmas.items() or {}:
for key, val in pragmas.items():
pragmas_list.append(f"PRAGMA {key} = {val}")
logger.info(f"sqlite connection, setting pragmas: {pragmas_list}")
if pragmas_list:
@ -107,7 +107,7 @@ class DatabaseService(Service):
with Session(self.engine) as session:
yield session
def migrate_flows_if_auto_login(self):
def migrate_flows_if_auto_login(self) -> None:
# if auto_login is enabled, we need to migrate the flows
# to the default superuser if they don't have a user id
# associated with them
@ -161,14 +161,14 @@ class DatabaseService(Service):
return True
def init_alembic(self, alembic_cfg):
def init_alembic(self, alembic_cfg) -> None:
logger.info("Initializing alembic")
command.ensure_version(alembic_cfg)
# alembic_cfg.attributes["connection"].commit()
command.upgrade(alembic_cfg, "head")
logger.info("Alembic initialized")
def run_migrations(self, *, fix=False):
def run_migrations(self, *, fix=False) -> None:
# First we need to check if alembic has been initialized
# If not, we need to initialize it
# if not self.script_location.exists(): # this is not the correct way to check if alembic has been initialized
@ -227,7 +227,7 @@ class DatabaseService(Service):
if fix:
self.try_downgrade_upgrade_until_success(alembic_cfg)
def try_downgrade_upgrade_until_success(self, alembic_cfg, retries=5):
def try_downgrade_upgrade_until_success(self, alembic_cfg, retries=5) -> None:
# Try -1 then head, if it fails, try -2 then head, etc.
# until we reach the number of retries
for i in range(1, retries + 1):
@ -273,7 +273,7 @@ class DatabaseService(Service):
results.append(Result(name=column, type="column", success=True))
return results
def create_db_and_tables(self):
def create_db_and_tables(self) -> None:
from sqlalchemy import inspect
inspector = inspect(self.engine)
@ -308,7 +308,7 @@ class DatabaseService(Service):
logger.debug("Database and tables created successfully")
async def teardown(self):
async def teardown(self) -> None:
logger.debug("Tearing down database")
try:
settings_service = get_settings_service()

View file

@ -12,7 +12,7 @@ if TYPE_CHECKING:
from langflow.services.database.service import DatabaseService
def initialize_database(*, fix_migration: bool = False):
def initialize_database(*, fix_migration: bool = False) -> None:
logger.debug("Initializing database")
from langflow.services.deps import get_db_service

View file

@ -15,7 +15,7 @@ class ServiceFactory:
def __init__(
self,
service_class,
):
) -> None:
self.service_class = service_class
self.dependencies = infer_service_types(self, import_all_services_into_a_dict())
@ -23,7 +23,7 @@ class ServiceFactory:
raise self.service_class(*args, **kwargs)
def hash_factory(factory: type[ServiceFactory]) -> str:
def hash_factory(factory: ServiceFactory) -> str:
return factory.service_class.__name__
@ -31,15 +31,15 @@ def hash_dict(d: dict) -> str:
return str(d)
def hash_infer_service_types_args(factory_class: type[ServiceFactory], available_services=None) -> str:
factory_hash = hash_factory(factory_class)
def hash_infer_service_types_args(factory: ServiceFactory, available_services=None) -> str:
factory_hash = hash_factory(factory)
services_hash = hash_dict(available_services)
return f"{factory_hash}_{services_hash}"
@cached(cache=LRUCache(maxsize=10), key=hash_infer_service_types_args)
def infer_service_types(factory_class: type[ServiceFactory], available_services=None) -> list["ServiceType"]:
create_method = factory_class.create
def infer_service_types(factory: ServiceFactory, available_services=None) -> list["ServiceType"]:
create_method = factory.create
type_hints = get_type_hints(create_method, globalns=available_services)
service_types = []
for param_name, param_type in type_hints.items():

View file

@ -22,13 +22,13 @@ class NoFactoryRegisteredError(Exception):
class ServiceManager:
"""Manages the creation of different services."""
def __init__(self):
def __init__(self) -> None:
self.services: dict[str, Service] = {}
self.factories = {}
self.factories: dict[str, ServiceFactory] = {}
self.register_factories()
self.keyed_lock = KeyedMemoryLockManager()
def register_factories(self):
def register_factories(self) -> None:
for factory in self.get_factories():
try:
self.register_factory(factory)
@ -38,7 +38,7 @@ class ServiceManager:
def register_factory(
self,
service_factory: ServiceFactory,
):
) -> None:
"""Registers a new factory with dependencies."""
service_name = service_factory.service_class.name
self.factories[service_name] = service_factory
@ -51,7 +51,7 @@ class ServiceManager:
return self.services[service_name]
def _create_service(self, service_name: ServiceType, default: ServiceFactory | None = None):
def _create_service(self, service_name: ServiceType, default: ServiceFactory | None = None) -> None:
"""Create a new service given its name, handling dependencies."""
logger.debug(f"Create service {service_name}")
self._validate_service_creation(service_name, default)
@ -61,6 +61,9 @@ class ServiceManager:
if factory is None and default is not None:
self.register_factory(default)
factory = default
if factory is None:
msg = f"No factory registered for {service_name}"
raise NoFactoryRegisteredError(msg)
for dependency in factory.dependencies:
if dependency not in self.services:
self._create_service(dependency)
@ -72,20 +75,20 @@ class ServiceManager:
self.services[service_name] = self.factories[service_name].create(**dependent_services)
self.services[service_name].set_ready()
def _validate_service_creation(self, service_name: ServiceType, default: ServiceFactory | None = None):
def _validate_service_creation(self, service_name: ServiceType, default: ServiceFactory | None = None) -> None:
"""Validate whether the service can be created."""
if service_name not in self.factories and default is None:
msg = f"No factory registered for the service class '{service_name.name}'"
raise NoFactoryRegisteredError(msg)
def update(self, service_name: ServiceType):
def update(self, service_name: ServiceType) -> None:
"""Update a service by its name."""
if service_name in self.services:
logger.debug(f"Update service {service_name}")
self.services.pop(service_name, None)
self.get(service_name)
async def teardown(self):
async def teardown(self) -> None:
"""Teardown all the services."""
for service in self.services.values():
if service is None:
@ -131,14 +134,14 @@ class ServiceManager:
service_manager = ServiceManager()
def initialize_settings_service():
def initialize_settings_service() -> None:
"""Initialize the settings manager."""
from langflow.services.settings import factory as settings_factory
service_manager.register_factory(settings_factory.SettingsServiceFactory())
def initialize_session_service():
def initialize_session_service() -> None:
"""Initialize the session manager."""
from langflow.services.cache import factory as cache_factory
from langflow.services.session import factory as session_service_factory

View file

@ -2,10 +2,10 @@ from typing import Any
class BasePlugin:
def initialize(self):
def initialize(self) -> None:
pass
def teardown(self):
def teardown(self) -> None:
pass
def get(self) -> Any:

View file

@ -5,7 +5,7 @@ from langflow.services.plugins.service import PluginService
class PluginServiceFactory(ServiceFactory):
def __init__(self):
def __init__(self) -> None:
super().__init__(PluginService)
def create(self):

View file

@ -22,7 +22,7 @@ class LangfuseInstance:
return cls._instance
@classmethod
def create(cls):
def create(cls) -> None:
try:
logger.debug("Creating Langfuse instance")
from langfuse import Langfuse
@ -44,13 +44,13 @@ class LangfuseInstance:
cls._instance = None
@classmethod
def update(cls):
def update(cls) -> None:
logger.debug("Updating Langfuse instance")
cls._instance = None
cls.create()
@classmethod
def teardown(cls):
def teardown(cls) -> None:
logger.debug("Tearing down Langfuse instance")
if cls._instance is not None:
cls._instance.flush()
@ -58,10 +58,10 @@ class LangfuseInstance:
class LangfusePlugin(CallbackPlugin):
def initialize(self):
def initialize(self) -> None:
LangfuseInstance.create()
def teardown(self):
def teardown(self) -> None:
LangfuseInstance.teardown()
def get(self):

View file

@ -13,13 +13,13 @@ from langflow.services.plugins.base import BasePlugin, CallbackPlugin
class PluginService(Service):
name = "plugin_service"
def __init__(self):
def __init__(self) -> None:
self.plugins: dict[str, BasePlugin] = {}
self.plugin_dir = Path(__file__).parent
self.plugins_base_module = "langflow.services.plugins"
self.load_plugins()
def load_plugins(self):
def load_plugins(self) -> None:
base_files = ["base.py", "service.py", "factory.py", "__init__.py"]
for module in self.plugin_dir.iterdir():
if module.suffix == ".py" and module.name not in base_files:
@ -38,7 +38,7 @@ class PluginService(Service):
except Exception: # noqa: BLE001
logger.exception(f"Error loading plugin {plugin_name}")
def register_plugin(self, plugin_name, plugin_instance):
def register_plugin(self, plugin_name, plugin_instance) -> None:
self.plugins[plugin_name] = plugin_instance
plugin_instance.initialize()
@ -50,7 +50,7 @@ class PluginService(Service):
return plugin.get()
return None
async def teardown(self):
async def teardown(self) -> None:
for plugin in self.plugins.values():
plugin.teardown()

View file

@ -8,7 +8,7 @@ if TYPE_CHECKING:
class SessionServiceFactory(ServiceFactory):
def __init__(self):
def __init__(self) -> None:
super().__init__(SessionService)
def create(self, cache_service: "CacheService"):

View file

@ -11,7 +11,7 @@ if TYPE_CHECKING:
class SessionService(Service):
name = "session_service"
def __init__(self, cache_service):
def __init__(self, cache_service) -> None:
self.cache_service: CacheService = cache_service
async def load_session(self, key, flow_id: str, data_graph: dict | None = None):
@ -35,7 +35,7 @@ class SessionService(Service):
return graph, artifacts
def build_key(self, session_id, data_graph):
def build_key(self, session_id, data_graph) -> str:
json_hash = compute_dict_hash(data_graph)
return f"{session_id}{':' if session_id else ''}{json_hash}"
@ -46,13 +46,13 @@ class SessionService(Service):
session_id = session_id_generator()
return self.build_key(session_id, data_graph=data_graph)
async def update_session(self, session_id, value):
async def update_session(self, session_id, value) -> None:
result = self.cache_service.set(session_id, value)
# if it is a coroutine, await it
if isinstance(result, Coroutine):
await result
async def clear_session(self, session_id):
async def clear_session(self, session_id) -> None:
result = self.cache_service.delete(session_id)
# if it is a coroutine, await it
if isinstance(result, Coroutine):

View file

@ -57,7 +57,7 @@ class AuthSettings(BaseSettings):
extra = "ignore"
env_prefix = "LANGFLOW_"
def reset_credentials(self):
def reset_credentials(self) -> None:
self.SUPERUSER = DEFAULT_SUPERUSER
self.SUPERUSER_PASSWORD = DEFAULT_SUPERUSER_PASSWORD

View file

@ -326,12 +326,12 @@ class Settings(BaseSettings):
model_config = SettingsConfigDict(validate_assignment=True, extra="ignore", env_prefix="LANGFLOW_")
def update_from_yaml(self, file_path: str, *, dev: bool = False):
def update_from_yaml(self, file_path: str, *, dev: bool = False) -> None:
new_settings = load_settings_from_yaml(file_path)
self.components_path = new_settings.components_path or []
self.dev = dev
def update_settings(self, **kwargs):
def update_settings(self, **kwargs) -> None:
logger.debug("Updating settings")
for key, value in kwargs.items():
# value may contain sensitive information, so we don't want to log it
@ -374,7 +374,7 @@ class Settings(BaseSettings):
return (MyCustomSource(settings_cls),)
def save_settings_to_yaml(settings: Settings, file_path: str):
def save_settings_to_yaml(settings: Settings, file_path: str) -> None:
with Path(file_path).open("w", encoding="utf-8") as f:
settings_dict = settings.model_dump()
yaml.dump(settings_dict, f)

View file

@ -10,7 +10,7 @@ class SettingsServiceFactory(ServiceFactory):
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
def __init__(self) -> None:
super().__init__(SettingsService)
def create(self):

View file

@ -4,7 +4,7 @@ from pathlib import Path
from loguru import logger
def set_secure_permissions(file_path: Path):
def set_secure_permissions(file_path: Path) -> None:
if platform.system() in {"Linux", "Darwin"}: # Unix/Linux/Mac
file_path.chmod(0o600)
elif platform.system() == "Windows":

View file

@ -8,7 +8,7 @@ if TYPE_CHECKING:
class SharedComponentCacheServiceFactory(ServiceFactory):
def __init__(self):
def __init__(self) -> None:
super().__init__(SharedComponentCacheService)
def create(self, settings_service: "SettingsService"):

View file

@ -8,7 +8,7 @@ if TYPE_CHECKING:
class SocketIOFactory(ServiceFactory):
def __init__(self):
def __init__(self) -> None:
super().__init__(
service_class=SocketIOService,
)

View file

@ -17,7 +17,7 @@ class SocketIOService(Service):
def __init__(self, cache_service: "CacheService"):
self.cache_service = cache_service
def init(self, sio: socketio.AsyncServer):
def init(self, sio: socketio.AsyncServer) -> None:
# Registering event handlers
self.sio = sio
if self.sio:
@ -28,32 +28,32 @@ class SocketIOService(Service):
self.sio.on("build_vertex")(self.on_build_vertex)
self.sessions = {} # type: dict[str, dict]
async def emit_error(self, sid, error):
async def emit_error(self, sid, error) -> None:
await self.sio.emit("error", to=sid, data=error)
async def connect(self, sid, environ):
async def connect(self, sid, environ) -> None:
logger.info(f"Socket connected: {sid}")
self.sessions[sid] = environ
async def disconnect(self, sid):
async def disconnect(self, sid) -> None:
logger.info(f"Socket disconnected: {sid}")
self.sessions.pop(sid, None)
async def message(self, sid, data=None):
async def message(self, sid, data=None) -> None:
# Logic for handling messages
await self.emit_message(to=sid, data=data or {"foo": "bar", "baz": [1, 2, 3]})
async def emit_message(self, to, data):
async def emit_message(self, to, data) -> None:
# Abstracting sio.emit
await self.sio.emit("message", to=to, data=data)
async def emit_token(self, to, data):
async def emit_token(self, to, data) -> None:
await self.sio.emit("token", to=to, data=data)
async def on_get_vertices(self, sid, flow_id):
async def on_get_vertices(self, sid, flow_id) -> None:
await get_vertices(self.sio, sid, flow_id, get_chat_service())
async def on_build_vertex(self, sid, flow_id, vertex_id):
async def on_build_vertex(self, sid, flow_id, vertex_id) -> None:
await build_vertex(
sio=self.sio,
sid=sid,

Some files were not shown because too many files have changed in this diff Show more