ref: Auto-fix some ruff ANN rules (#4210)
* Auto-fix some ruff ANN rules * Fix mypy errors * Changes following review * Fix ServiceFactory
This commit is contained in:
parent
c85cd91e09
commit
507eda997a
138 changed files with 588 additions and 580 deletions
|
|
@ -45,7 +45,7 @@ def get_number_of_workers(workers=None):
|
|||
return workers
|
||||
|
||||
|
||||
def display_results(results):
|
||||
def display_results(results) -> None:
|
||||
"""Display the results of the migration."""
|
||||
for table_results in results:
|
||||
table = Table(title=f"Migration {table_results.table_name}")
|
||||
|
|
@ -62,7 +62,7 @@ def display_results(results):
|
|||
console.print() # Print a new line
|
||||
|
||||
|
||||
def set_var_for_macos_issue():
|
||||
def set_var_for_macos_issue() -> None:
|
||||
# OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
|
||||
# we need to set this var is we are running on MacOS
|
||||
# otherwise we get an error when running gunicorn
|
||||
|
|
@ -146,7 +146,7 @@ def run(
|
|||
help="Defines the maximum file size for the upload in MB.",
|
||||
show_default=False,
|
||||
),
|
||||
):
|
||||
) -> None:
|
||||
"""Run Langflow."""
|
||||
configure(log_level=log_level, log_file=log_file)
|
||||
set_var_for_macos_issue()
|
||||
|
|
@ -202,7 +202,7 @@ def run(
|
|||
# Run using uvicorn on MacOS and Windows
|
||||
# Windows doesn't support gunicorn
|
||||
# MacOS requires an env variable to be set to use gunicorn
|
||||
process = run_on_windows(host, port, log_level, options, app)
|
||||
run_on_windows(host, port, log_level, options, app)
|
||||
else:
|
||||
# Run using gunicorn on Linux
|
||||
process = run_on_mac_or_linux(host, port, log_level, options, app)
|
||||
|
|
@ -219,7 +219,7 @@ def run(
|
|||
sys.exit(1)
|
||||
|
||||
|
||||
def wait_for_server_ready(host, port):
|
||||
def wait_for_server_ready(host, port) -> None:
|
||||
"""Wait for the server to become ready by polling the health endpoint."""
|
||||
status_code = 0
|
||||
while status_code != httpx.codes.OK:
|
||||
|
|
@ -241,7 +241,7 @@ def run_on_mac_or_linux(host, port, log_level, options, app):
|
|||
return webapp_process
|
||||
|
||||
|
||||
def run_on_windows(host, port, log_level, options, app):
|
||||
def run_on_windows(host, port, log_level, options, app) -> None:
|
||||
"""Run the Langflow server on Windows."""
|
||||
print_banner(host, port)
|
||||
run_langflow(host, port, log_level, options, app)
|
||||
|
|
@ -275,7 +275,7 @@ def get_free_port(port):
|
|||
return port
|
||||
|
||||
|
||||
def get_letter_from_version(version: str):
|
||||
def get_letter_from_version(version: str) -> str | None:
|
||||
"""Get the letter from a pre-release version."""
|
||||
if "a" in version:
|
||||
return "a"
|
||||
|
|
@ -294,7 +294,7 @@ def build_version_notice(current_version: str, package_name: str) -> str:
|
|||
return ""
|
||||
|
||||
|
||||
def generate_pip_command(package_names, is_pre_release):
|
||||
def generate_pip_command(package_names, is_pre_release) -> str:
|
||||
"""Generate the pip install command based on the packages and whether it's a pre-release."""
|
||||
base_command = "pip install"
|
||||
if is_pre_release:
|
||||
|
|
@ -309,7 +309,7 @@ def stylize_text(text: str, to_style: str, *, is_prerelease: bool) -> str:
|
|||
return text.replace(to_style, styled_text)
|
||||
|
||||
|
||||
def print_banner(host: str, port: int):
|
||||
def print_banner(host: str, port: int) -> None:
|
||||
notices = []
|
||||
package_names = [] # Track package names for pip install instructions
|
||||
is_pre_release = False # Track if any package is a pre-release
|
||||
|
|
@ -355,7 +355,7 @@ def print_banner(host: str, port: int):
|
|||
rprint(panel)
|
||||
|
||||
|
||||
def run_langflow(host, port, log_level, options, app):
|
||||
def run_langflow(host, port, log_level, options, app) -> None:
|
||||
"""Run Langflow server on localhost."""
|
||||
if platform.system() == "Windows":
|
||||
# Run using uvicorn on MacOS and Windows
|
||||
|
|
@ -381,7 +381,7 @@ def superuser(
|
|||
username: str = typer.Option(..., prompt=True, help="Username for the superuser."),
|
||||
password: str = typer.Option(..., prompt=True, hide_input=True, help="Password for the superuser."),
|
||||
log_level: str = typer.Option("error", help="Logging level.", envvar="LANGFLOW_LOG_LEVEL"),
|
||||
):
|
||||
) -> None:
|
||||
"""Create a superuser."""
|
||||
configure(log_level=log_level)
|
||||
initialize_services()
|
||||
|
|
@ -413,7 +413,7 @@ def superuser(
|
|||
# command to copy the langflow database from the cache to the current directory
|
||||
# because now the database is stored per installation
|
||||
@app.command()
|
||||
def copy_db():
|
||||
def copy_db() -> None:
|
||||
"""Copy the database files to the current directory.
|
||||
|
||||
This function copies the 'langflow.db' and 'langflow-pre.db' files from the cache directory to the current
|
||||
|
|
@ -452,7 +452,7 @@ def migration(
|
|||
default=False,
|
||||
help="Fix migrations. This is a destructive operation, and should only be used if you know what you are doing.",
|
||||
),
|
||||
):
|
||||
) -> None:
|
||||
"""Run or test migrations."""
|
||||
if fix and not typer.confirm(
|
||||
"This will delete all data necessary to fix migrations. Are you sure you want to continue?"
|
||||
|
|
@ -470,7 +470,7 @@ def migration(
|
|||
@app.command()
|
||||
def api_key(
|
||||
log_level: str = typer.Option("error", help="Logging level."),
|
||||
):
|
||||
) -> None:
|
||||
"""Creates an API key for the default superuser if AUTO_LOGIN is enabled.
|
||||
|
||||
Args:
|
||||
|
|
@ -510,7 +510,7 @@ def api_key(
|
|||
api_key_banner(unmasked_api_key)
|
||||
|
||||
|
||||
def api_key_banner(unmasked_api_key):
|
||||
def api_key_banner(unmasked_api_key) -> None:
|
||||
is_mac = platform.system() == "Darwin"
|
||||
import pyperclip
|
||||
|
||||
|
|
@ -529,7 +529,7 @@ def api_key_banner(unmasked_api_key):
|
|||
console.print(panel)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
app()
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ def get_is_component_from_data(data: dict):
|
|||
return data.get("is_component")
|
||||
|
||||
|
||||
async def check_langflow_version(component: StoreComponentCreate):
|
||||
async def check_langflow_version(component: StoreComponentCreate) -> None:
|
||||
from langflow.utils.version import get_version_info
|
||||
|
||||
__version__ = get_version_info()["version"]
|
||||
|
|
@ -259,7 +259,7 @@ def parse_value(value: Any, input_type: str) -> Any:
|
|||
return value
|
||||
|
||||
|
||||
async def cascade_delete_flow(session: Session, flow: Flow):
|
||||
async def cascade_delete_flow(session: Session, flow: Flow) -> None:
|
||||
try:
|
||||
session.exec(delete(TransactionTable).where(TransactionTable.flow_id == flow.id))
|
||||
session.exec(delete(VertexBuildTable).where(VertexBuildTable.flow_id == flow.id))
|
||||
|
|
|
|||
|
|
@ -110,7 +110,7 @@ class AsyncStreamingLLMCallbackHandleSIO(AsyncCallbackHandler):
|
|||
@override
|
||||
async def on_agent_action( # type: ignore[misc]
|
||||
self, action: AgentAction, **kwargs: Any
|
||||
):
|
||||
) -> None:
|
||||
log = f"Thought: {action.log}"
|
||||
# if there are line breaks, split them and send them
|
||||
# as separate messages
|
||||
|
|
|
|||
|
|
@ -419,7 +419,7 @@ async def build_flow(
|
|||
event_manager = create_default_event_manager(queue=asyncio_queue)
|
||||
main_task = asyncio.create_task(event_generator(event_manager, asyncio_queue_client_consumed))
|
||||
|
||||
def on_disconnect():
|
||||
def on_disconnect() -> None:
|
||||
logger.debug("Client disconnected, closing tasks")
|
||||
main_task.cancel()
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ from langflow.api.v1.schemas import (
|
|||
CustomComponentRequest,
|
||||
CustomComponentResponse,
|
||||
InputValueRequest,
|
||||
ProcessResponse,
|
||||
RunResponse,
|
||||
SidebarCategoriesResponse,
|
||||
SimplifiedAPIRequest,
|
||||
|
|
@ -68,7 +67,7 @@ async def get_all():
|
|||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
||||
|
||||
def validate_input_and_tweaks(input_request: SimplifiedAPIRequest):
|
||||
def validate_input_and_tweaks(input_request: SimplifiedAPIRequest) -> None:
|
||||
# If the input_value is not None and the input_type is "chat"
|
||||
# then we need to check the tweaks if the ChatInput component is present
|
||||
# and if its input_value is not None
|
||||
|
|
@ -483,15 +482,13 @@ async def experimental_run_flow(
|
|||
|
||||
@router.post(
|
||||
"/predict/{flow_id}",
|
||||
response_model=ProcessResponse,
|
||||
dependencies=[Depends(api_key_security)],
|
||||
)
|
||||
@router.post(
|
||||
"/process/{flow_id}",
|
||||
response_model=ProcessResponse,
|
||||
dependencies=[Depends(api_key_security)],
|
||||
)
|
||||
async def process():
|
||||
async def process() -> None:
|
||||
"""Endpoint to process an input with a given flow_id."""
|
||||
# Raise a depreciation warning
|
||||
logger.warning(
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ async def get_vertex_builds(
|
|||
async def delete_vertex_builds(
|
||||
flow_id: Annotated[UUID, Query()],
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
):
|
||||
) -> None:
|
||||
try:
|
||||
delete_vertex_builds_by_flow_id(session, flow_id)
|
||||
except Exception as e:
|
||||
|
|
@ -75,7 +75,7 @@ async def get_messages(
|
|||
async def delete_messages(
|
||||
message_ids: list[UUID],
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
):
|
||||
) -> None:
|
||||
try:
|
||||
session.exec(delete(MessageTable).where(MessageTable.id.in_(message_ids))) # type: ignore[attr-defined]
|
||||
session.commit()
|
||||
|
|
|
|||
|
|
@ -95,7 +95,7 @@ def delete_variable(
|
|||
variable_id: UUID,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
variable_service: VariableService = Depends(get_variable_service),
|
||||
):
|
||||
) -> None:
|
||||
"""Delete a variable."""
|
||||
try:
|
||||
variable_service.delete_variable_by_id(user_id=current_user.id, variable_id=variable_id, session=session)
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ class LCAgentComponent(Component):
|
|||
self.status = message
|
||||
return message
|
||||
|
||||
def _validate_outputs(self):
|
||||
def _validate_outputs(self) -> None:
|
||||
required_output_methods = ["build_agent"]
|
||||
output_names = [output.name for output in self.outputs]
|
||||
for method_name in required_output_methods:
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ class BaseCrewComponent(Component):
|
|||
def get_task_callback(
|
||||
self,
|
||||
) -> Callable:
|
||||
def task_callback(task_output: TaskOutput):
|
||||
def task_callback(task_output: TaskOutput) -> None:
|
||||
vertex_id = self._vertex.id if self._vertex else self.display_name or self.__class__.__name__
|
||||
self.log(task_output.model_dump(), name=f"Task (Agent: {task_output.agent}) - {vertex_id}")
|
||||
|
||||
|
|
@ -61,7 +61,7 @@ class BaseCrewComponent(Component):
|
|||
def get_step_callback(
|
||||
self,
|
||||
) -> Callable:
|
||||
def step_callback(agent_output: AgentFinish | list[tuple[AgentAction, str]]):
|
||||
def step_callback(agent_output: AgentFinish | list[tuple[AgentAction, str]]) -> None:
|
||||
_id = self._vertex.id if self._vertex else self.display_name
|
||||
if isinstance(agent_output, AgentFinish):
|
||||
messages = agent_output.messages
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ tool_names = []
|
|||
tools_and_names = {}
|
||||
|
||||
|
||||
def tools_from_package(your_package):
|
||||
def tools_from_package(your_package) -> None:
|
||||
# Iterate over all modules in the package
|
||||
package_name = your_package.__name__
|
||||
for module_info in pkgutil.iter_modules(your_package.__path__):
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ class LCChainComponent(Component):
|
|||
|
||||
outputs = [Output(display_name="Text", name="text", method="invoke_chain")]
|
||||
|
||||
def _validate_outputs(self):
|
||||
def _validate_outputs(self) -> None:
|
||||
required_output_methods = ["invoke_chain"]
|
||||
output_names = [output.name for output in self.outputs]
|
||||
for method_name in required_output_methods:
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ class LCEmbeddingsModel(Component):
|
|||
Output(display_name="Embeddings", name="embeddings", method="build_embeddings"),
|
||||
]
|
||||
|
||||
def _validate_outputs(self):
|
||||
def _validate_outputs(self) -> None:
|
||||
required_output_methods = ["build_embeddings"]
|
||||
output_names = [output.name for output in self.outputs]
|
||||
for method_name in required_output_methods:
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ class ChatComponent(Component):
|
|||
self.status = stored_message
|
||||
return stored_message
|
||||
|
||||
def _send_message_event(self, message: Message):
|
||||
def _send_message_event(self, message: Message) -> None:
|
||||
if hasattr(self, "_event_manager") and self._event_manager:
|
||||
self._event_manager.on_message(data=message.data)
|
||||
|
||||
|
|
@ -107,7 +107,7 @@ class ChatComponent(Component):
|
|||
return Message.from_data(input_value)
|
||||
return Message(text=input_value, sender=sender, sender_name=sender_name, files=files, session_id=session_id)
|
||||
|
||||
def _send_messages_events(self, messages):
|
||||
def _send_messages_events(self, messages) -> None:
|
||||
if hasattr(self, "_event_manager") and self._event_manager:
|
||||
for stored_message in messages:
|
||||
self._event_manager.on_message(data=stored_message.data)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ class LCToolComponent(Component):
|
|||
Output(name="api_build_tool", display_name="Tool", method="build_tool"),
|
||||
]
|
||||
|
||||
def _validate_outputs(self):
|
||||
def _validate_outputs(self) -> None:
|
||||
required_output_methods = ["run_model", "build_tool"]
|
||||
output_names = [output.name for output in self.outputs]
|
||||
for method_name in required_output_methods:
|
||||
|
|
|
|||
|
|
@ -45,5 +45,5 @@ class BaseMemoryComponent(CustomComponent):
|
|||
|
||||
def add_message(
|
||||
self, sender: str, sender_name: str, text: str, session_id: str, metadata: dict | None = None, **kwargs
|
||||
):
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ class LCChatMemoryComponent(Component):
|
|||
)
|
||||
]
|
||||
|
||||
def _validate_outputs(self):
|
||||
def _validate_outputs(self) -> None:
|
||||
required_output_methods = ["build_message_history"]
|
||||
output_names = [output.name for output in self.outputs]
|
||||
for method_name in required_output_methods:
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ class LCModelComponent(Component):
|
|||
def _get_exception_message(self, e: Exception):
|
||||
return str(e)
|
||||
|
||||
def _validate_outputs(self):
|
||||
def _validate_outputs(self) -> None:
|
||||
# At least these two outputs must be defined
|
||||
required_output_methods = ["text_response", "build_model"]
|
||||
output_names = [output.name for output in self.outputs]
|
||||
|
|
|
|||
|
|
@ -86,7 +86,7 @@ def _check_variable(var, invalid_chars, wrong_variables, empty_variables):
|
|||
return wrong_variables, empty_variables
|
||||
|
||||
|
||||
def _check_for_errors(input_variables, fixed_variables, wrong_variables, empty_variables):
|
||||
def _check_for_errors(input_variables, fixed_variables, wrong_variables, empty_variables) -> None:
|
||||
if any(var for var in input_variables if var not in fixed_variables):
|
||||
error_message = (
|
||||
f"Error: Input variables contain invalid characters or formats. \n"
|
||||
|
|
@ -159,7 +159,7 @@ def get_old_custom_fields(custom_fields, name):
|
|||
return old_custom_fields
|
||||
|
||||
|
||||
def add_new_variables_to_template(input_variables, custom_fields, template, name):
|
||||
def add_new_variables_to_template(input_variables, custom_fields, template, name) -> None:
|
||||
for variable in input_variables:
|
||||
try:
|
||||
template_field = DefaultPromptField(name=variable, display_name=variable)
|
||||
|
|
@ -177,7 +177,7 @@ def add_new_variables_to_template(input_variables, custom_fields, template, name
|
|||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
||||
|
||||
def remove_old_variables_from_template(old_custom_fields, input_variables, custom_fields, template, name):
|
||||
def remove_old_variables_from_template(old_custom_fields, input_variables, custom_fields, template, name) -> None:
|
||||
for variable in old_custom_fields:
|
||||
if variable not in input_variables:
|
||||
try:
|
||||
|
|
@ -192,7 +192,7 @@ def remove_old_variables_from_template(old_custom_fields, input_variables, custo
|
|||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
||||
|
||||
def update_input_variables_field(input_variables, template):
|
||||
def update_input_variables_field(input_variables, template) -> None:
|
||||
if "input_variables" in template:
|
||||
template["input_variables"]["value"] = input_variables
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from langflow.base.document_transformers.model import LCDocumentTransformerCompo
|
|||
class LCTextSplitterComponent(LCDocumentTransformerComponent):
|
||||
trace_type = "text_splitter"
|
||||
|
||||
def _validate_outputs(self):
|
||||
def _validate_outputs(self) -> None:
|
||||
required_output_methods = ["text_splitter"]
|
||||
output_names = [output.name for output in self.outputs]
|
||||
for method_name in required_output_methods:
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ def _get_input_type(_input: InputTypes):
|
|||
return _input.field_type
|
||||
|
||||
|
||||
def build_description(component: Component, output: Output):
|
||||
def build_description(component: Component, output: Output) -> str:
|
||||
if not output.required_inputs:
|
||||
logger.warning(f"Output {output.name} does not have required inputs defined")
|
||||
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ class LCVectorStoreComponent(Component):
|
|||
),
|
||||
]
|
||||
|
||||
def _validate_outputs(self):
|
||||
def _validate_outputs(self) -> None:
|
||||
# At least these three outputs must be defined
|
||||
required_output_methods = [
|
||||
"build_base_retriever",
|
||||
|
|
|
|||
|
|
@ -67,39 +67,39 @@ class AstraAssistantManager(ComponentWithCache):
|
|||
Output(display_name="Assistant Id", name="output_assistant_id", method="get_assistant_id"),
|
||||
]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.lock = asyncio.Lock()
|
||||
self.initialized = False
|
||||
self.assistant_response = None
|
||||
self.tool_output = None
|
||||
self.thread_id = None
|
||||
self.assistant_id = None
|
||||
self.initialized: bool = False
|
||||
self._assistant_response: Message = None # type: ignore[assignment]
|
||||
self._tool_output: Message = None # type: ignore[assignment]
|
||||
self._thread_id: Message = None # type: ignore[assignment]
|
||||
self._assistant_id: Message = None # type: ignore[assignment]
|
||||
self.client = get_patched_openai_client(self._shared_component_cache)
|
||||
|
||||
async def get_assistant_response(self) -> Message:
|
||||
await self.initialize()
|
||||
return self.assistant_response
|
||||
return self._assistant_response
|
||||
|
||||
async def get_tool_output(self) -> Message:
|
||||
await self.initialize()
|
||||
return self.tool_output
|
||||
return self._tool_output
|
||||
|
||||
async def get_thread_id(self) -> Message:
|
||||
await self.initialize()
|
||||
return self.thread_id
|
||||
return self._thread_id
|
||||
|
||||
async def get_assistant_id(self) -> Message:
|
||||
await self.initialize()
|
||||
return self.assistant_id
|
||||
return self._assistant_id
|
||||
|
||||
async def initialize(self):
|
||||
async def initialize(self) -> None:
|
||||
async with self.lock:
|
||||
if not self.initialized:
|
||||
await self.process_inputs()
|
||||
self.initialized = True
|
||||
|
||||
async def process_inputs(self):
|
||||
async def process_inputs(self) -> None:
|
||||
logger.info(f"env_set is {self.env_set}")
|
||||
logger.info(self.tool)
|
||||
tools = []
|
||||
|
|
@ -126,10 +126,10 @@ class AstraAssistantManager(ComponentWithCache):
|
|||
|
||||
content = self.user_message
|
||||
result = await assistant_manager.run_thread(content=content, tool=tool_obj)
|
||||
self.assistant_response = Message(text=result["text"])
|
||||
self._assistant_response = Message(text=result["text"])
|
||||
if "decision" in result:
|
||||
self.tool_output = Message(text=str(result["decision"].is_complete))
|
||||
self._tool_output = Message(text=str(result["decision"].is_complete))
|
||||
else:
|
||||
self.tool_output = Message(text=result["text"])
|
||||
self.thread_id = Message(text=assistant_manager.thread.id)
|
||||
self.assistant_id = Message(text=assistant_manager.assistant.id)
|
||||
self._tool_output = Message(text=result["text"])
|
||||
self._thread_id = Message(text=assistant_manager.thread.id)
|
||||
self._assistant_id = Message(text=assistant_manager.assistant.id)
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ class AssistantsCreateAssistant(ComponentWithCache):
|
|||
Output(display_name="Assistant ID", name="assistant_id", method="process_inputs"),
|
||||
]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.client = get_patched_openai_client(self._shared_component_cache)
|
||||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ class AssistantsCreateThread(ComponentWithCache):
|
|||
Output(display_name="Thread ID", name="thread_id", method="process_inputs"),
|
||||
]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.client = get_patched_openai_client(self._shared_component_cache)
|
||||
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ class AssistantsGetAssistantName(ComponentWithCache):
|
|||
Output(display_name="Assistant Name", name="assistant_name", method="process_inputs"),
|
||||
]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.client = get_patched_openai_client(self._shared_component_cache)
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ class AssistantsListAssistants(ComponentWithCache):
|
|||
Output(display_name="Assistants", name="assistants", method="process_inputs"),
|
||||
]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.client = get_patched_openai_client(self._shared_component_cache)
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ class AssistantsRun(ComponentWithCache):
|
|||
display_name = "Run Assistant"
|
||||
description = "Executes an Assistant Run against a thread"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.client = get_patched_openai_client(self._shared_component_cache)
|
||||
self.thread_id = None
|
||||
|
|
@ -26,7 +26,7 @@ class AssistantsRun(ComponentWithCache):
|
|||
build_config: dotdict,
|
||||
field_value: Any,
|
||||
field_name: str | None = None,
|
||||
):
|
||||
) -> None:
|
||||
if field_name == "thread_id":
|
||||
if field_value is None:
|
||||
thread = self.client.beta.threads.create()
|
||||
|
|
@ -75,7 +75,7 @@ class AssistantsRun(ComponentWithCache):
|
|||
self.client.beta.threads.messages.create(thread_id=self.thread_id, role="user", content=self.user_message)
|
||||
|
||||
class EventHandler(AssistantEventHandler):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def on_exception(self, exception: Exception) -> None:
|
||||
|
|
|
|||
|
|
@ -88,7 +88,7 @@ class GoogleDriveSearchComponent(Component):
|
|||
|
||||
return query
|
||||
|
||||
def on_inputs_changed(self):
|
||||
def on_inputs_changed(self) -> None:
|
||||
# Automatically regenerate the query string when inputs change
|
||||
self.generate_query_string()
|
||||
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ class GoogleGenerativeAIEmbeddingsComponent(Component):
|
|||
raise ValueError(msg)
|
||||
|
||||
class HotaGoogleGenerativeAIEmbeddings(GoogleGenerativeAIEmbeddings):
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super(GoogleGenerativeAIEmbeddings, self).__init__(*args, **kwargs)
|
||||
|
||||
def embed_documents(
|
||||
|
|
|
|||
|
|
@ -99,7 +99,7 @@ class CreateDataComponent(Component):
|
|||
data.update(_value_dict)
|
||||
return data
|
||||
|
||||
def validate_text_key(self):
|
||||
def validate_text_key(self) -> None:
|
||||
"""This function validates that the Text Key is one of the keys in the Data."""
|
||||
data_keys = self.get_data().keys()
|
||||
if self.text_key not in data_keys and self.text_key != "":
|
||||
|
|
|
|||
|
|
@ -106,7 +106,7 @@ class UpdateDataComponent(Component):
|
|||
data.update(_value_dict)
|
||||
return data
|
||||
|
||||
def validate_text_key(self, data: Data):
|
||||
def validate_text_key(self, data: Data) -> None:
|
||||
"""This function validates that the Text Key is one of the keys in the Data."""
|
||||
data_keys = data.data.keys()
|
||||
if self.text_key not in data_keys and self.text_key != "":
|
||||
|
|
|
|||
|
|
@ -434,7 +434,7 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
|
|||
|
||||
return vector_store
|
||||
|
||||
def _add_documents_to_vector_store(self, vector_store):
|
||||
def _add_documents_to_vector_store(self, vector_store) -> None:
|
||||
documents = []
|
||||
for _input in self.ingest_data or []:
|
||||
if isinstance(_input, Data):
|
||||
|
|
@ -453,7 +453,7 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
|
|||
else:
|
||||
logger.debug("No documents to add to the Vector Store.")
|
||||
|
||||
def _map_search_type(self):
|
||||
def _map_search_type(self) -> str:
|
||||
if self.search_type == "Similarity with score threshold":
|
||||
return "similarity_score_threshold"
|
||||
if self.search_type == "MMR (Max Marginal Relevance)":
|
||||
|
|
|
|||
|
|
@ -206,7 +206,7 @@ class CassandraVectorStoreComponent(LCVectorStoreComponent):
|
|||
)
|
||||
return table
|
||||
|
||||
def _map_search_type(self):
|
||||
def _map_search_type(self) -> str:
|
||||
if self.search_type == "Similarity with score threshold":
|
||||
return "similarity_score_threshold"
|
||||
if self.search_type == "MMR (Max Marginal Relevance)":
|
||||
|
|
|
|||
|
|
@ -181,7 +181,7 @@ class CassandraGraphVectorStoreComponent(LCVectorStoreComponent):
|
|||
)
|
||||
return store
|
||||
|
||||
def _map_search_type(self):
|
||||
def _map_search_type(self) -> str:
|
||||
if self.search_type == "Similarity":
|
||||
return "similarity"
|
||||
if self.search_type == "Similarity with score threshold":
|
||||
|
|
|
|||
|
|
@ -253,7 +253,7 @@ class HCDVectorStoreComponent(LCVectorStoreComponent):
|
|||
self._add_documents_to_vector_store(vector_store)
|
||||
return vector_store
|
||||
|
||||
def _add_documents_to_vector_store(self, vector_store):
|
||||
def _add_documents_to_vector_store(self, vector_store) -> None:
|
||||
documents = []
|
||||
for _input in self.ingest_data or []:
|
||||
if isinstance(_input, Data):
|
||||
|
|
@ -272,7 +272,7 @@ class HCDVectorStoreComponent(LCVectorStoreComponent):
|
|||
else:
|
||||
logger.debug("No documents to add to the Vector Store.")
|
||||
|
||||
def _map_search_type(self):
|
||||
def _map_search_type(self) -> str:
|
||||
if self.search_type == "Similarity with score threshold":
|
||||
return "similarity_score_threshold"
|
||||
if self.search_type == "MMR (Max Marginal Relevance)":
|
||||
|
|
|
|||
|
|
@ -318,7 +318,7 @@ class CodeParser:
|
|||
self.process_class_node(_node, class_details)
|
||||
self.data["classes"].append(class_details.model_dump())
|
||||
|
||||
def process_class_node(self, node, class_details):
|
||||
def process_class_node(self, node, class_details) -> None:
|
||||
for stmt in node.body:
|
||||
if isinstance(stmt, ast.Assign):
|
||||
if attr := self.parse_assign(stmt):
|
||||
|
|
|
|||
|
|
@ -31,15 +31,15 @@ class BaseComponent:
|
|||
_user_id: str | UUID | None = None
|
||||
_template_config: dict = {}
|
||||
|
||||
def __init__(self, **data):
|
||||
self.cache = TTLCache(maxsize=1024, ttl=60)
|
||||
def __init__(self, **data) -> None:
|
||||
self.cache: TTLCache = TTLCache(maxsize=1024, ttl=60)
|
||||
for key, value in data.items():
|
||||
if key == "user_id":
|
||||
self._user_id = value
|
||||
else:
|
||||
setattr(self, key, value)
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
def __setattr__(self, key, value) -> None:
|
||||
if key == "_user_id" and self._user_id is not None:
|
||||
logger.warning("user_id is immutable and cannot be changed.")
|
||||
super().__setattr__(key, value)
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ class Component(CustomComponent):
|
|||
_current_output: str = ""
|
||||
_metadata: dict = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, **kwargs) -> None:
|
||||
# if key starts with _ it is a config
|
||||
# else it is an input
|
||||
inputs = {}
|
||||
|
|
@ -107,10 +107,10 @@ class Component(CustomComponent):
|
|||
self.set_class_code()
|
||||
self._set_output_required_inputs()
|
||||
|
||||
def set_event_manager(self, event_manager: EventManager | None = None):
|
||||
def set_event_manager(self, event_manager: EventManager | None = None) -> None:
|
||||
self._event_manager = event_manager
|
||||
|
||||
def _reset_all_output_values(self):
|
||||
def _reset_all_output_values(self) -> None:
|
||||
if isinstance(self._outputs_map, dict):
|
||||
for output in self._outputs_map.values():
|
||||
output.value = UNDEFINED
|
||||
|
|
@ -153,7 +153,7 @@ class Component(CustomComponent):
|
|||
memo[id(self)] = new_component
|
||||
return new_component
|
||||
|
||||
def set_class_code(self):
|
||||
def set_class_code(self) -> None:
|
||||
# Get the source code of the calling class
|
||||
if self._code:
|
||||
return
|
||||
|
|
@ -200,7 +200,7 @@ class Component(CustomComponent):
|
|||
"""
|
||||
return await self._run()
|
||||
|
||||
def set_vertex(self, vertex: Vertex):
|
||||
def set_vertex(self, vertex: Vertex) -> None:
|
||||
"""Sets the vertex for the component.
|
||||
|
||||
Args:
|
||||
|
|
@ -245,7 +245,7 @@ class Component(CustomComponent):
|
|||
msg = f"Output {name} not found in {self.__class__.__name__}"
|
||||
raise ValueError(msg)
|
||||
|
||||
def set_on_output(self, name: str, **kwargs):
|
||||
def set_on_output(self, name: str, **kwargs) -> None:
|
||||
output = self.get_output(name)
|
||||
for key, value in kwargs.items():
|
||||
if not hasattr(output, key):
|
||||
|
|
@ -253,14 +253,14 @@ class Component(CustomComponent):
|
|||
raise ValueError(msg)
|
||||
setattr(output, key, value)
|
||||
|
||||
def set_output_value(self, name: str, value: Any):
|
||||
def set_output_value(self, name: str, value: Any) -> None:
|
||||
if name in self._outputs_map:
|
||||
self._outputs_map[name].value = value
|
||||
else:
|
||||
msg = f"Output {name} not found in {self.__class__.__name__}"
|
||||
raise ValueError(msg)
|
||||
|
||||
def map_outputs(self, outputs: list[Output]):
|
||||
def map_outputs(self, outputs: list[Output]) -> None:
|
||||
"""Maps the given list of outputs to the component.
|
||||
|
||||
Args:
|
||||
|
|
@ -280,7 +280,7 @@ class Component(CustomComponent):
|
|||
# allows each instance of each component to modify its own output
|
||||
self._outputs_map[output.name] = deepcopy(output)
|
||||
|
||||
def map_inputs(self, inputs: list[InputTypes]):
|
||||
def map_inputs(self, inputs: list[InputTypes]) -> None:
|
||||
"""Maps the given inputs to the component.
|
||||
|
||||
Args:
|
||||
|
|
@ -296,7 +296,7 @@ class Component(CustomComponent):
|
|||
raise ValueError(msg)
|
||||
self._inputs[input_.name] = deepcopy(input_)
|
||||
|
||||
def validate(self, params: dict):
|
||||
def validate(self, params: dict) -> None:
|
||||
"""Validates the component parameters.
|
||||
|
||||
Args:
|
||||
|
|
@ -309,13 +309,16 @@ class Component(CustomComponent):
|
|||
self._validate_inputs(params)
|
||||
self._validate_outputs()
|
||||
|
||||
def _set_output_types(self):
|
||||
def _set_output_types(self) -> None:
|
||||
for output in self._outputs_map.values():
|
||||
if output.method is None:
|
||||
msg = f"Output {output.name} does not have a method"
|
||||
raise ValueError(msg)
|
||||
return_types = self._get_method_return_type(output.method)
|
||||
output.add_types(return_types)
|
||||
output.set_selected()
|
||||
|
||||
def _set_output_required_inputs(self):
|
||||
def _set_output_required_inputs(self) -> None:
|
||||
for output in self.outputs:
|
||||
if not output.method:
|
||||
continue
|
||||
|
|
@ -326,8 +329,7 @@ class Component(CustomComponent):
|
|||
source_code = inspect.getsource(method)
|
||||
ast_tree = ast.parse(dedent(source_code))
|
||||
except Exception: # noqa: BLE001
|
||||
source_code = self._code
|
||||
ast_tree = ast.parse(dedent(source_code))
|
||||
ast_tree = ast.parse(dedent(self._code or ""))
|
||||
|
||||
visitor = RequiredInputsVisitor(self._inputs)
|
||||
visitor.visit(ast_tree)
|
||||
|
|
@ -420,7 +422,7 @@ class Component(CustomComponent):
|
|||
raise TypeError(msg)
|
||||
return getattr(value, output.method)
|
||||
|
||||
def _process_connection_or_parameter(self, key, value):
|
||||
def _process_connection_or_parameter(self, key, value) -> None:
|
||||
_input = self._get_or_create_input(key)
|
||||
# We need to check if callable AND if it is a method from a class that inherits from Component
|
||||
if isinstance(value, Component):
|
||||
|
|
@ -438,7 +440,7 @@ class Component(CustomComponent):
|
|||
else:
|
||||
self._set_parameter_or_attribute(key, value)
|
||||
|
||||
def _process_connection_or_parameters(self, key, value):
|
||||
def _process_connection_or_parameters(self, key, value) -> None:
|
||||
# if value is a list of components, we need to process each component
|
||||
if isinstance(value, list):
|
||||
for val in value:
|
||||
|
|
@ -455,13 +457,13 @@ class Component(CustomComponent):
|
|||
self.inputs.append(_input)
|
||||
return _input
|
||||
|
||||
def _connect_to_component(self, key, value, _input):
|
||||
def _connect_to_component(self, key, value, _input) -> None:
|
||||
component = value.__self__
|
||||
self._components.append(component)
|
||||
output = component.get_output_by_method(value)
|
||||
self._add_edge(component, key, output, _input)
|
||||
|
||||
def _add_edge(self, component, key, output, _input):
|
||||
def _add_edge(self, component, key, output, _input) -> None:
|
||||
self._edges.append(
|
||||
{
|
||||
"source": component._id,
|
||||
|
|
@ -483,7 +485,7 @@ class Component(CustomComponent):
|
|||
}
|
||||
)
|
||||
|
||||
def _set_parameter_or_attribute(self, key, value):
|
||||
def _set_parameter_or_attribute(self, key, value) -> None:
|
||||
if isinstance(value, Component):
|
||||
methods = ", ".join([f"'{output.method}'" for output in value.outputs])
|
||||
msg = (
|
||||
|
|
@ -527,7 +529,7 @@ class Component(CustomComponent):
|
|||
msg = f"{name} not found in {self.__class__.__name__}"
|
||||
raise AttributeError(msg)
|
||||
|
||||
def _set_input_value(self, name: str, value: Any):
|
||||
def _set_input_value(self, name: str, value: Any) -> None:
|
||||
if name in self._inputs:
|
||||
input_value = self._inputs[name].value
|
||||
if isinstance(input_value, Component):
|
||||
|
|
@ -547,15 +549,15 @@ class Component(CustomComponent):
|
|||
msg = f"Input {name} not found in {self.__class__.__name__}"
|
||||
raise ValueError(msg)
|
||||
|
||||
def _validate_outputs(self):
|
||||
def _validate_outputs(self) -> None:
|
||||
# Raise Error if some rule isn't met
|
||||
pass
|
||||
|
||||
def _map_parameters_on_frontend_node(self, frontend_node: ComponentFrontendNode):
|
||||
def _map_parameters_on_frontend_node(self, frontend_node: ComponentFrontendNode) -> None:
|
||||
for name, value in self._parameters.items():
|
||||
frontend_node.set_field_value_in_template(name, value)
|
||||
|
||||
def _map_parameters_on_template(self, template: dict):
|
||||
def _map_parameters_on_template(self, template: dict) -> None:
|
||||
for name, value in self._parameters.items():
|
||||
try:
|
||||
template[name]["value"] = value
|
||||
|
|
@ -625,7 +627,7 @@ class Component(CustomComponent):
|
|||
"id": self._id,
|
||||
}
|
||||
|
||||
def _validate_inputs(self, params: dict):
|
||||
def _validate_inputs(self, params: dict) -> None:
|
||||
# Params keys are the `name` attribute of the Input objects
|
||||
for key, value in params.copy().items():
|
||||
if key not in self._inputs:
|
||||
|
|
@ -636,7 +638,7 @@ class Component(CustomComponent):
|
|||
input_.value = value
|
||||
params[input_.name] = input_.value
|
||||
|
||||
def set_attributes(self, params: dict):
|
||||
def set_attributes(self, params: dict) -> None:
|
||||
self._validate_inputs(params)
|
||||
_attributes = {}
|
||||
for key, value in params.items():
|
||||
|
|
@ -652,7 +654,7 @@ class Component(CustomComponent):
|
|||
_attributes[key] = input_obj.value or None
|
||||
self._attributes = _attributes
|
||||
|
||||
def _set_outputs(self, outputs: list[dict]):
|
||||
def _set_outputs(self, outputs: list[dict]) -> None:
|
||||
self.outputs = [Output(**output) for output in outputs]
|
||||
for output in self.outputs:
|
||||
setattr(self, output.name, output)
|
||||
|
|
@ -794,7 +796,7 @@ class Component(CustomComponent):
|
|||
except KeyError:
|
||||
return []
|
||||
|
||||
def build(self, **kwargs):
|
||||
def build(self, **kwargs) -> None:
|
||||
self.set_attributes(kwargs)
|
||||
|
||||
def _get_fallback_input(self, **kwargs):
|
||||
|
|
@ -809,7 +811,7 @@ class Component(CustomComponent):
|
|||
return self._tracing_service.project_name
|
||||
return "Langflow"
|
||||
|
||||
def log(self, message: LoggableType | list[LoggableType], name: str | None = None):
|
||||
def log(self, message: LoggableType | list[LoggableType], name: str | None = None) -> None:
|
||||
"""Logs a message.
|
||||
|
||||
Args:
|
||||
|
|
@ -828,6 +830,6 @@ class Component(CustomComponent):
|
|||
data["component_id"] = self._id
|
||||
self._event_manager.on_log(data=data)
|
||||
|
||||
def _append_tool_output(self):
|
||||
def _append_tool_output(self) -> None:
|
||||
if next((output for output in self.outputs if output.name == TOOL_OUTPUT_NAME), None) is None:
|
||||
self.outputs.append(Output(name=TOOL_OUTPUT_NAME, display_name="Tool", method="to_toolkit", types=["Tool"]))
|
||||
|
|
|
|||
|
|
@ -3,6 +3,6 @@ from langflow.services.deps import get_shared_component_cache_service
|
|||
|
||||
|
||||
class ComponentWithCache(Component):
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
self._shared_component_cache = get_shared_component_cache_service()
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ class CustomComponent(BaseComponent):
|
|||
_tracing_service: TracingService | None = None
|
||||
_tree: dict | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
def __init__(self, **data) -> None:
|
||||
"""Initializes a new instance of the CustomComponent class.
|
||||
|
||||
Args:
|
||||
|
|
@ -93,22 +93,25 @@ class CustomComponent(BaseComponent):
|
|||
"""
|
||||
self.cache = TTLCache(maxsize=1024, ttl=60)
|
||||
self._logs = []
|
||||
self._results = {}
|
||||
self._artifacts = {}
|
||||
self._results: dict = {}
|
||||
self._artifacts: dict = {}
|
||||
super().__init__(**data)
|
||||
|
||||
def set_attributes(self, parameters: dict):
|
||||
def set_attributes(self, parameters: dict) -> None:
|
||||
pass
|
||||
|
||||
def set_parameters(self, parameters: dict):
|
||||
def set_parameters(self, parameters: dict) -> None:
|
||||
self._parameters = parameters
|
||||
self.set_attributes(self._parameters)
|
||||
|
||||
@property
|
||||
def trace_name(self):
|
||||
def trace_name(self) -> str:
|
||||
if self._vertex is None:
|
||||
msg = "Vertex is not set"
|
||||
raise ValueError(msg)
|
||||
return f"{self.display_name} ({self._vertex.id})"
|
||||
|
||||
def update_state(self, name: str, value: Any):
|
||||
def update_state(self, name: str, value: Any) -> None:
|
||||
if not self._vertex:
|
||||
msg = "Vertex is not set"
|
||||
raise ValueError(msg)
|
||||
|
|
@ -118,7 +121,7 @@ class CustomComponent(BaseComponent):
|
|||
msg = f"Error updating state: {e}"
|
||||
raise ValueError(msg) from e
|
||||
|
||||
def stop(self, output_name: str | None = None):
|
||||
def stop(self, output_name: str | None = None) -> None:
|
||||
if not output_name and self._vertex and len(self._vertex.outputs) == 1:
|
||||
output_name = self._vertex.outputs[0]["name"]
|
||||
elif not output_name:
|
||||
|
|
@ -133,7 +136,7 @@ class CustomComponent(BaseComponent):
|
|||
msg = f"Error stopping {self.display_name}: {e}"
|
||||
raise ValueError(msg) from e
|
||||
|
||||
def append_state(self, name: str, value: Any):
|
||||
def append_state(self, name: str, value: Any) -> None:
|
||||
if not self._vertex:
|
||||
msg = "Vertex is not set"
|
||||
raise ValueError(msg)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ class CustomComponentPathValueError(ValueError):
|
|||
|
||||
|
||||
class StringCompressor:
|
||||
def __init__(self, input_string):
|
||||
def __init__(self, input_string) -> None:
|
||||
"""Initialize StringCompressor with a string to compress."""
|
||||
self.input_string = input_string
|
||||
|
||||
|
|
@ -39,7 +39,7 @@ class DirectoryReader:
|
|||
# the custom components from this directory.
|
||||
base_path = ""
|
||||
|
||||
def __init__(self, directory_path, *, compress_code_field=False):
|
||||
def __init__(self, directory_path, *, compress_code_field=False) -> None:
|
||||
"""Initialize DirectoryReader with a directory path and a flag indicating whether to compress the code."""
|
||||
self.directory_path = directory_path
|
||||
self.compress_code_field = compress_code_field
|
||||
|
|
@ -76,7 +76,7 @@ class DirectoryReader:
|
|||
logger.debug(f'Filtered components {"with errors" if with_errors else ""}: {len(filtered)}')
|
||||
return {"menu": filtered}
|
||||
|
||||
def validate_code(self, file_content):
|
||||
def validate_code(self, file_content) -> bool:
|
||||
"""Validate the Python code by trying to parse it with ast.parse."""
|
||||
try:
|
||||
ast.parse(file_content)
|
||||
|
|
|
|||
|
|
@ -109,7 +109,7 @@ def create_invalid_component_template(component, component_name):
|
|||
return component_frontend_node.model_dump(by_alias=True, exclude_none=True)
|
||||
|
||||
|
||||
def log_invalid_component_details(component):
|
||||
def log_invalid_component_details(component) -> None:
|
||||
"""Log details of an invalid component."""
|
||||
logger.debug(component)
|
||||
logger.debug(f"Component Path: {component.get('path', None)}")
|
||||
|
|
|
|||
|
|
@ -28,5 +28,5 @@ class CallableCodeDetails(BaseModel):
|
|||
class MissingDefault:
|
||||
"""A class to represent a missing default value."""
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return "MISSING"
|
||||
|
|
|
|||
|
|
@ -1,15 +1,16 @@
|
|||
import ast
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class RequiredInputsVisitor(ast.NodeVisitor):
|
||||
def __init__(self, inputs):
|
||||
self.inputs = inputs
|
||||
self.required_inputs = set()
|
||||
def __init__(self, inputs: dict[str, Any]):
|
||||
self.inputs: dict[str, Any] = inputs
|
||||
self.required_inputs: set[str] = set()
|
||||
|
||||
@override
|
||||
def visit_Attribute(self, node):
|
||||
def visit_Attribute(self, node) -> None:
|
||||
if isinstance(node.value, ast.Name) and node.value.id == "self" and node.attr in self.inputs:
|
||||
self.required_inputs.add(node.attr)
|
||||
self.generic_visit(node)
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ class UpdateBuildConfigError(Exception):
|
|||
pass
|
||||
|
||||
|
||||
def add_output_types(frontend_node: CustomComponentFrontendNode, return_types: list[str]):
|
||||
def add_output_types(frontend_node: CustomComponentFrontendNode, return_types: list[str]) -> None:
|
||||
"""Add output types to the frontend node."""
|
||||
for return_type in return_types:
|
||||
if return_type is None:
|
||||
|
|
@ -55,7 +55,7 @@ def add_output_types(frontend_node: CustomComponentFrontendNode, return_types: l
|
|||
frontend_node.add_output_type(_return_type)
|
||||
|
||||
|
||||
def reorder_fields(frontend_node: CustomComponentFrontendNode, field_order: list[str]):
|
||||
def reorder_fields(frontend_node: CustomComponentFrontendNode, field_order: list[str]) -> None:
|
||||
"""Reorder fields in the frontend node based on the specified field_order."""
|
||||
if not field_order:
|
||||
return
|
||||
|
|
@ -69,7 +69,7 @@ def reorder_fields(frontend_node: CustomComponentFrontendNode, field_order: list
|
|||
frontend_node.field_order = field_order
|
||||
|
||||
|
||||
def add_base_classes(frontend_node: CustomComponentFrontendNode, return_types: list[str]):
|
||||
def add_base_classes(frontend_node: CustomComponentFrontendNode, return_types: list[str]) -> None:
|
||||
"""Add base classes to the frontend node."""
|
||||
for return_type_instance in return_types:
|
||||
if return_type_instance is None:
|
||||
|
|
@ -196,7 +196,7 @@ def add_new_custom_field(
|
|||
return frontend_node
|
||||
|
||||
|
||||
def add_extra_fields(frontend_node, field_config, function_args):
|
||||
def add_extra_fields(frontend_node, field_config, function_args) -> None:
|
||||
"""Add extra fields to the frontend node."""
|
||||
if not function_args:
|
||||
return
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ class EventManager:
|
|||
self.events: dict[str, PartialEventCallback] = {}
|
||||
|
||||
@staticmethod
|
||||
def _validate_callback(callback: EventCallback):
|
||||
def _validate_callback(callback: EventCallback) -> None:
|
||||
if not callable(callback):
|
||||
msg = "Callback must be callable"
|
||||
raise TypeError(msg)
|
||||
|
|
@ -39,7 +39,7 @@ class EventManager:
|
|||
msg = "Callback must have exactly 3 parameters: manager, event_type, and data"
|
||||
raise ValueError(msg)
|
||||
|
||||
def register_event(self, name: str, event_type: str, callback: EventCallback | None = None):
|
||||
def register_event(self, name: str, event_type: str, callback: EventCallback | None = None) -> None:
|
||||
if not name:
|
||||
msg = "Event name cannot be empty"
|
||||
raise ValueError(msg)
|
||||
|
|
@ -52,14 +52,14 @@ class EventManager:
|
|||
_callback = partial(callback, manager=self, event_type=event_type)
|
||||
self.events[name] = _callback
|
||||
|
||||
def send_event(self, *, event_type: str, data: LoggableType):
|
||||
def send_event(self, *, event_type: str, data: LoggableType) -> None:
|
||||
jsonable_data = jsonable_encoder(data)
|
||||
json_data = {"event": event_type, "data": jsonable_data}
|
||||
event_id = uuid.uuid4()
|
||||
str_data = json.dumps(json_data) + "\n\n"
|
||||
self.queue.put_nowait((event_id, str_data.encode("utf-8"), time.time()))
|
||||
|
||||
def noop(self, *, data: LoggableType):
|
||||
def noop(self, *, data: LoggableType) -> None:
|
||||
pass
|
||||
|
||||
def __getattr__(self, name: str) -> PartialEventCallback:
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ class VertexViewer:
|
|||
|
||||
HEIGHT = 3 # top and bottom box edges + text
|
||||
|
||||
def __init__(self, name):
|
||||
def __init__(self, name) -> None:
|
||||
self._h = self.HEIGHT # top and bottom box edges + text
|
||||
self._w = len(name) + 2 # right and left bottom edges + text
|
||||
|
||||
|
|
@ -40,7 +40,7 @@ class VertexViewer:
|
|||
class AsciiCanvas:
|
||||
"""Class for drawing in ASCII."""
|
||||
|
||||
def __init__(self, cols, lines):
|
||||
def __init__(self, cols, lines) -> None:
|
||||
assert cols > 1
|
||||
assert lines > 1
|
||||
self.cols = cols
|
||||
|
|
@ -53,19 +53,19 @@ class AsciiCanvas:
|
|||
def draws(self):
|
||||
return "\n".join(self.get_lines())
|
||||
|
||||
def draw(self):
|
||||
def draw(self) -> None:
|
||||
"""Draws ASCII canvas on the screen."""
|
||||
lines = self.get_lines()
|
||||
print("\n".join(lines)) # noqa: T201
|
||||
|
||||
def point(self, x, y, char):
|
||||
def point(self, x, y, char) -> None:
|
||||
"""Create a point on ASCII canvas."""
|
||||
assert len(char) == 1
|
||||
assert 0 <= x < self.cols
|
||||
assert 0 <= y < self.lines
|
||||
self.canvas[y][x] = char
|
||||
|
||||
def line(self, x0, y0, x1, y1, char):
|
||||
def line(self, x0, y0, x1, y1, char) -> None:
|
||||
"""Create a line on ASCII canvas."""
|
||||
if x0 > x1:
|
||||
x1, x0 = x0, x1
|
||||
|
|
@ -85,12 +85,12 @@ class AsciiCanvas:
|
|||
x = x0 + int(round((y - y0) * dx / float(dy))) if dy else x0
|
||||
self.point(x, y, char)
|
||||
|
||||
def text(self, x, y, text):
|
||||
def text(self, x, y, text) -> None:
|
||||
"""Print a text on ASCII canvas."""
|
||||
for i, char in enumerate(text):
|
||||
self.point(x + i, y, char)
|
||||
|
||||
def box(self, x0, y0, width, height):
|
||||
def box(self, x0, y0, width, height) -> None:
|
||||
"""Create a box on ASCII canvas."""
|
||||
assert width > 1
|
||||
assert height > 1
|
||||
|
|
|
|||
|
|
@ -201,7 +201,7 @@ class Graph:
|
|||
graph_dict["endpoint_name"] = str(endpoint_name)
|
||||
return graph_dict
|
||||
|
||||
def add_nodes_and_edges(self, nodes: list[NodeData], edges: list[EdgeData]):
|
||||
def add_nodes_and_edges(self, nodes: list[NodeData], edges: list[EdgeData]) -> None:
|
||||
self._vertices = nodes
|
||||
self._edges = edges
|
||||
self.raw_graph_data = {"nodes": nodes, "edges": edges}
|
||||
|
|
@ -238,7 +238,7 @@ class Graph:
|
|||
|
||||
return component_id
|
||||
|
||||
def _set_start_and_end(self, start: Component, end: Component):
|
||||
def _set_start_and_end(self, start: Component, end: Component) -> None:
|
||||
if not hasattr(start, "to_frontend_node"):
|
||||
msg = f"start must be a Component. Got {type(start)}"
|
||||
raise TypeError(msg)
|
||||
|
|
@ -248,7 +248,7 @@ class Graph:
|
|||
self.add_component(start, start._id)
|
||||
self.add_component(end, end._id)
|
||||
|
||||
def add_component_edge(self, source_id: str, output_input_tuple: tuple[str, str], target_id: str):
|
||||
def add_component_edge(self, source_id: str, output_input_tuple: tuple[str, str], target_id: str) -> None:
|
||||
source_vertex = self.get_vertex(source_id)
|
||||
if not isinstance(source_vertex, ComponentVertex):
|
||||
msg = f"Source vertex {source_id} is not a component vertex."
|
||||
|
|
@ -337,7 +337,7 @@ class Graph:
|
|||
"run_manager": copy.deepcopy(self.run_manager.to_dict()),
|
||||
}
|
||||
|
||||
def __apply_config(self, config: StartConfigDict):
|
||||
def __apply_config(self, config: StartConfigDict) -> None:
|
||||
for vertex in self.vertices:
|
||||
if vertex._custom_component is None:
|
||||
continue
|
||||
|
|
@ -373,7 +373,7 @@ class Graph:
|
|||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
def _add_edge(self, edge: EdgeData):
|
||||
def _add_edge(self, edge: EdgeData) -> None:
|
||||
self.add_edge(edge)
|
||||
source_id = edge["data"]["sourceHandle"]["id"]
|
||||
target_id = edge["data"]["targetHandle"]["id"]
|
||||
|
|
@ -382,16 +382,16 @@ class Graph:
|
|||
self.in_degree_map[target_id] += 1
|
||||
self.parent_child_map[source_id].append(target_id)
|
||||
|
||||
def add_node(self, node: NodeData):
|
||||
def add_node(self, node: NodeData) -> None:
|
||||
self._vertices.append(node)
|
||||
|
||||
def add_edge(self, edge: EdgeData):
|
||||
def add_edge(self, edge: EdgeData) -> None:
|
||||
# Check if the edge already exists
|
||||
if edge in self._edges:
|
||||
return
|
||||
self._edges.append(edge)
|
||||
|
||||
def initialize(self):
|
||||
def initialize(self) -> None:
|
||||
self._build_graph()
|
||||
self.build_graph_maps(self.edges)
|
||||
self.define_vertices_lists()
|
||||
|
|
@ -424,7 +424,7 @@ class Graph:
|
|||
|
||||
self.state_manager.update_state(name, record, run_id=self._run_id)
|
||||
|
||||
def activate_state_vertices(self, name: str, caller: str):
|
||||
def activate_state_vertices(self, name: str, caller: str) -> None:
|
||||
"""Activates the state vertices in the graph with the given name and caller.
|
||||
|
||||
Args:
|
||||
|
|
@ -473,7 +473,7 @@ class Graph:
|
|||
vertices_to_run=self.vertices_to_run,
|
||||
)
|
||||
|
||||
def reset_activated_vertices(self):
|
||||
def reset_activated_vertices(self) -> None:
|
||||
"""Resets the activated vertices in the graph."""
|
||||
self.activated_vertices = []
|
||||
|
||||
|
|
@ -490,7 +490,7 @@ class Graph:
|
|||
|
||||
self.state_manager.append_state(name, record, run_id=self._run_id)
|
||||
|
||||
def validate_stream(self):
|
||||
def validate_stream(self) -> None:
|
||||
"""Validates the stream configuration of the graph.
|
||||
|
||||
If there are two vertices in the same graph (connected by edges)
|
||||
|
|
@ -548,7 +548,7 @@ class Graph:
|
|||
raise ValueError(msg)
|
||||
return self._run_id
|
||||
|
||||
def set_run_id(self, run_id: uuid.UUID | None = None):
|
||||
def set_run_id(self, run_id: uuid.UUID | None = None) -> None:
|
||||
"""Sets the ID of the current run.
|
||||
|
||||
Args:
|
||||
|
|
@ -564,7 +564,7 @@ class Graph:
|
|||
if self.tracing_service:
|
||||
self.tracing_service.set_run_id(run_id)
|
||||
|
||||
def set_run_name(self):
|
||||
def set_run_name(self) -> None:
|
||||
# Given a flow name, flow_id
|
||||
if not self.tracing_service:
|
||||
return
|
||||
|
|
@ -573,16 +573,16 @@ class Graph:
|
|||
self.set_run_id()
|
||||
self.tracing_service.set_run_name(name)
|
||||
|
||||
async def initialize_run(self):
|
||||
async def initialize_run(self) -> None:
|
||||
if self.tracing_service:
|
||||
await self.tracing_service.initialize_tracers()
|
||||
|
||||
def _end_all_traces_async(self, outputs: dict[str, Any] | None = None, error: Exception | None = None):
|
||||
def _end_all_traces_async(self, outputs: dict[str, Any] | None = None, error: Exception | None = None) -> None:
|
||||
task = asyncio.create_task(self.end_all_traces(outputs, error))
|
||||
self._end_trace_tasks.add(task)
|
||||
task.add_done_callback(self._end_trace_tasks.discard)
|
||||
|
||||
async def end_all_traces(self, outputs: dict[str, Any] | None = None, error: Exception | None = None):
|
||||
async def end_all_traces(self, outputs: dict[str, Any] | None = None, error: Exception | None = None) -> None:
|
||||
if not self.tracing_service:
|
||||
return
|
||||
self._end_time = datetime.now(timezone.utc)
|
||||
|
|
@ -602,7 +602,7 @@ class Graph:
|
|||
self.sort_vertices()
|
||||
return self._sorted_vertices_layers
|
||||
|
||||
def define_vertices_lists(self):
|
||||
def define_vertices_lists(self) -> None:
|
||||
"""Defines the lists of vertices that are inputs, outputs, and have session_id."""
|
||||
attributes = ["is_input", "is_output", "has_session_id", "is_state"]
|
||||
for vertex in self.vertices:
|
||||
|
|
@ -610,7 +610,7 @@ class Graph:
|
|||
if getattr(vertex, attribute):
|
||||
getattr(self, f"_{attribute}_vertices").append(vertex.id)
|
||||
|
||||
def _set_inputs(self, input_components: list[str], inputs: dict[str, str], input_type: InputType | None):
|
||||
def _set_inputs(self, input_components: list[str], inputs: dict[str, str], input_type: InputType | None) -> None:
|
||||
for vertex_id in self._is_input_vertices:
|
||||
vertex = self.get_vertex(vertex_id)
|
||||
# If the vertex is not in the input_components list
|
||||
|
|
@ -838,7 +838,7 @@ class Graph:
|
|||
"flow_name": self.flow_name,
|
||||
}
|
||||
|
||||
def build_graph_maps(self, edges: list[CycleEdge] | None = None, vertices: list[Vertex] | None = None):
|
||||
def build_graph_maps(self, edges: list[CycleEdge] | None = None, vertices: list[Vertex] | None = None) -> None:
|
||||
"""Builds the adjacency maps for the graph."""
|
||||
if edges is None:
|
||||
edges = self.edges
|
||||
|
|
@ -851,26 +851,28 @@ class Graph:
|
|||
self.in_degree_map = self.build_in_degree(edges)
|
||||
self.parent_child_map = self.build_parent_child_map(vertices)
|
||||
|
||||
def reset_inactivated_vertices(self):
|
||||
def reset_inactivated_vertices(self) -> None:
|
||||
"""Resets the inactivated vertices in the graph."""
|
||||
for vertex_id in self.inactivated_vertices.copy():
|
||||
self.mark_vertex(vertex_id, "ACTIVE")
|
||||
self.inactivated_vertices = []
|
||||
self.inactivated_vertices = set()
|
||||
self.inactivated_vertices = set()
|
||||
|
||||
def mark_all_vertices(self, state: str):
|
||||
def mark_all_vertices(self, state: str) -> None:
|
||||
"""Marks all vertices in the graph."""
|
||||
for vertex in self.vertices:
|
||||
vertex.set_state(state)
|
||||
|
||||
def mark_vertex(self, vertex_id: str, state: str):
|
||||
def mark_vertex(self, vertex_id: str, state: str) -> None:
|
||||
"""Marks a vertex in the graph."""
|
||||
vertex = self.get_vertex(vertex_id)
|
||||
vertex.set_state(state)
|
||||
if state == VertexStates.INACTIVE:
|
||||
self.run_manager.remove_from_predecessors(vertex_id)
|
||||
|
||||
def _mark_branch(self, vertex_id: str, state: str, visited: set | None = None, output_name: str | None = None):
|
||||
def _mark_branch(
|
||||
self, vertex_id: str, state: str, visited: set | None = None, output_name: str | None = None
|
||||
) -> None:
|
||||
"""Marks a branch of the graph."""
|
||||
if visited is None:
|
||||
visited = set()
|
||||
|
|
@ -889,7 +891,7 @@ class Graph:
|
|||
continue
|
||||
self._mark_branch(child_id, state, visited)
|
||||
|
||||
def mark_branch(self, vertex_id: str, state: str, output_name: str | None = None):
|
||||
def mark_branch(self, vertex_id: str, state: str, output_name: str | None = None) -> None:
|
||||
self._mark_branch(vertex_id=vertex_id, state=state, output_name=output_name)
|
||||
new_predecessor_map, _ = self.build_adjacency_maps(self.edges)
|
||||
self.run_manager.update_run_state(
|
||||
|
|
@ -910,10 +912,10 @@ class Graph:
|
|||
parent_child_map[vertex.id] = [child.id for child in self.get_successors(vertex)]
|
||||
return parent_child_map
|
||||
|
||||
def increment_run_count(self):
|
||||
def increment_run_count(self) -> None:
|
||||
self._runs += 1
|
||||
|
||||
def increment_update_count(self):
|
||||
def increment_update_count(self) -> None:
|
||||
self._updates += 1
|
||||
|
||||
def __getstate__(self):
|
||||
|
|
@ -1239,7 +1241,7 @@ class Graph:
|
|||
return None
|
||||
return self._run_queue.popleft()
|
||||
|
||||
def extend_run_queue(self, vertices: list[str]):
|
||||
def extend_run_queue(self, vertices: list[str]) -> None:
|
||||
self._run_queue.extend(vertices)
|
||||
|
||||
async def astep(
|
||||
|
|
@ -1292,7 +1294,7 @@ class Graph:
|
|||
}
|
||||
)
|
||||
|
||||
def _record_snapshot(self, vertex_id: str | None = None):
|
||||
def _record_snapshot(self, vertex_id: str | None = None) -> None:
|
||||
self._snapshots.append(self.get_snapshot())
|
||||
if vertex_id:
|
||||
self._call_order.append(vertex_id)
|
||||
|
|
@ -1557,7 +1559,7 @@ class Graph:
|
|||
state = dict.fromkeys(self.vertices, 0)
|
||||
sorted_vertices = []
|
||||
|
||||
def dfs(vertex):
|
||||
def dfs(vertex) -> None:
|
||||
if state[vertex] == 1:
|
||||
# We have a cycle
|
||||
msg = "Graph contains a cycle, cannot perform topological sort"
|
||||
|
|
@ -1770,7 +1772,7 @@ class Graph:
|
|||
children.append(vertex)
|
||||
return children
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
vertex_ids = [vertex.id for vertex in self.vertices]
|
||||
edges_repr = "\n".join([f" {edge.source_id} --> {edge.target_id}" for edge in self.edges])
|
||||
|
||||
|
|
@ -2009,7 +2011,7 @@ class Graph:
|
|||
is_active = self.get_vertex(vertex_id).is_active()
|
||||
return self.run_manager.is_vertex_runnable(vertex_id, is_active=is_active)
|
||||
|
||||
def build_run_map(self):
|
||||
def build_run_map(self) -> None:
|
||||
"""Builds the run map for the graph.
|
||||
|
||||
This method is responsible for building the run map for the graph,
|
||||
|
|
@ -2036,7 +2038,7 @@ class Graph:
|
|||
runnable_vertices = []
|
||||
visited = set()
|
||||
|
||||
def find_runnable_predecessors(predecessor: Vertex):
|
||||
def find_runnable_predecessors(predecessor: Vertex) -> None:
|
||||
predecessor_id = predecessor.id
|
||||
if predecessor_id in visited:
|
||||
return
|
||||
|
|
@ -2052,10 +2054,10 @@ class Graph:
|
|||
find_runnable_predecessors(self.get_vertex(predecessor_id))
|
||||
return runnable_vertices
|
||||
|
||||
def remove_from_predecessors(self, vertex_id: str):
|
||||
def remove_from_predecessors(self, vertex_id: str) -> None:
|
||||
self.run_manager.remove_from_predecessors(vertex_id)
|
||||
|
||||
def remove_vertex_from_runnables(self, vertex_id: str):
|
||||
def remove_vertex_from_runnables(self, vertex_id: str) -> None:
|
||||
self.run_manager.remove_vertex_from_runnables(vertex_id)
|
||||
|
||||
def get_top_level_vertices(self, vertices_ids):
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from langflow.utils.lazy_load import LazyLoadDictBase
|
|||
|
||||
|
||||
class Finish:
|
||||
def __bool__(self):
|
||||
def __bool__(self) -> bool:
|
||||
return True
|
||||
|
||||
def __eq__(self, other):
|
||||
|
|
@ -17,7 +17,7 @@ def _import_vertex_types():
|
|||
|
||||
|
||||
class VertexTypesDict(LazyLoadDictBase):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self._all_types_dict = None
|
||||
self._types = _import_vertex_types
|
||||
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@ from collections import defaultdict
|
|||
|
||||
|
||||
class RunnableVerticesManager:
|
||||
def __init__(self):
|
||||
self.run_map = defaultdict(list) # Tracks successors of each vertex
|
||||
self.run_predecessors = defaultdict(set) # Tracks predecessors for each vertex
|
||||
self.vertices_to_run = set() # Set of vertices that are ready to run
|
||||
self.vertices_being_run = set() # Set of vertices that are currently running
|
||||
def __init__(self) -> None:
|
||||
self.run_map: dict[str, list[str]] = defaultdict(list) # Tracks successors of each vertex
|
||||
self.run_predecessors: dict[str, set[str]] = defaultdict(set) # Tracks predecessors for each vertex
|
||||
self.vertices_to_run: set[str] = set() # Set of vertices that are ready to run
|
||||
self.vertices_being_run: set[str] = set() # Set of vertices that are currently running
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
|
|
@ -42,7 +42,7 @@ class RunnableVerticesManager:
|
|||
def all_predecessors_are_fulfilled(self) -> bool:
|
||||
return all(not value for value in self.run_predecessors.values())
|
||||
|
||||
def update_run_state(self, run_predecessors: dict, vertices_to_run: set):
|
||||
def update_run_state(self, run_predecessors: dict, vertices_to_run: set) -> None:
|
||||
self.run_predecessors.update(run_predecessors)
|
||||
self.vertices_to_run.update(vertices_to_run)
|
||||
self.build_run_map(self.run_predecessors, self.vertices_to_run)
|
||||
|
|
@ -60,14 +60,14 @@ class RunnableVerticesManager:
|
|||
def are_all_predecessors_fulfilled(self, vertex_id: str) -> bool:
|
||||
return not any(self.run_predecessors.get(vertex_id, []))
|
||||
|
||||
def remove_from_predecessors(self, vertex_id: str):
|
||||
def remove_from_predecessors(self, vertex_id: str) -> None:
|
||||
"""Removes a vertex from the predecessor list of its successors."""
|
||||
predecessors = self.run_map.get(vertex_id, [])
|
||||
for predecessor in predecessors:
|
||||
if vertex_id in self.run_predecessors[predecessor]:
|
||||
self.run_predecessors[predecessor].remove(vertex_id)
|
||||
|
||||
def build_run_map(self, predecessor_map, vertices_to_run):
|
||||
def build_run_map(self, predecessor_map, vertices_to_run) -> None:
|
||||
"""Builds a map of vertices and their runnable successors."""
|
||||
self.run_map = defaultdict(list)
|
||||
for vertex_id, predecessors in predecessor_map.items():
|
||||
|
|
@ -76,16 +76,16 @@ class RunnableVerticesManager:
|
|||
self.run_predecessors = predecessor_map.copy()
|
||||
self.vertices_to_run = vertices_to_run
|
||||
|
||||
def update_vertex_run_state(self, vertex_id: str, *, is_runnable: bool):
|
||||
def update_vertex_run_state(self, vertex_id: str, *, is_runnable: bool) -> None:
|
||||
"""Updates the runnable state of a vertex."""
|
||||
if is_runnable:
|
||||
self.vertices_to_run.add(vertex_id)
|
||||
else:
|
||||
self.vertices_being_run.discard(vertex_id)
|
||||
|
||||
def remove_vertex_from_runnables(self, v_id):
|
||||
def remove_vertex_from_runnables(self, v_id) -> None:
|
||||
self.update_vertex_run_state(v_id, is_runnable=False)
|
||||
self.remove_from_predecessors(v_id)
|
||||
|
||||
def add_to_vertices_being_run(self, v_id):
|
||||
def add_to_vertices_being_run(self, v_id) -> None:
|
||||
self.vertices_being_run.add(v_id)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ if TYPE_CHECKING:
|
|||
|
||||
|
||||
class GraphStateManager:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
try:
|
||||
self.state_service: StateService = get_state_service()
|
||||
except Exception: # noqa: BLE001
|
||||
|
|
@ -22,26 +22,14 @@ class GraphStateManager:
|
|||
|
||||
self.state_service = InMemoryStateService(get_settings_service())
|
||||
|
||||
def append_state(self, key, new_state, run_id: str):
|
||||
def append_state(self, key, new_state, run_id: str) -> None:
|
||||
self.state_service.append_state(key, new_state, run_id)
|
||||
|
||||
def update_state(self, key, new_state, run_id: str):
|
||||
def update_state(self, key, new_state, run_id: str) -> None:
|
||||
self.state_service.update_state(key, new_state, run_id)
|
||||
|
||||
def get_state(self, key, run_id: str):
|
||||
return self.state_service.get_state(key, run_id)
|
||||
|
||||
def subscribe(self, key, observer: Callable):
|
||||
def subscribe(self, key, observer: Callable) -> None:
|
||||
self.state_service.subscribe(key, observer)
|
||||
|
||||
def notify_observers(self, key, new_state):
|
||||
for callback in self.observers[key]:
|
||||
callback(key, new_state, append=False)
|
||||
|
||||
def notify_append_observers(self, key, new_state):
|
||||
for callback in self.observers[key]:
|
||||
try:
|
||||
callback(key, new_state, append=True)
|
||||
except Exception: # noqa: BLE001
|
||||
logger.exception(f"Error in observer {callback} for key {key}")
|
||||
logger.warning("Callbacks not implemented yet")
|
||||
|
|
|
|||
|
|
@ -27,13 +27,13 @@ def find_last_node(nodes, edges):
|
|||
return next((n for n in nodes if all(e["source"] != n["id"] for e in edges)), None)
|
||||
|
||||
|
||||
def add_parent_node_id(nodes, parent_node_id):
|
||||
def add_parent_node_id(nodes, parent_node_id) -> None:
|
||||
"""This function receives a list of nodes and adds a parent_node_id to each node."""
|
||||
for node in nodes:
|
||||
node["parent_node_id"] = parent_node_id
|
||||
|
||||
|
||||
def add_frozen(nodes, frozen):
|
||||
def add_frozen(nodes, frozen) -> None:
|
||||
"""This function receives a list of nodes and adds a frozen to each node."""
|
||||
for node in nodes:
|
||||
node["data"]["node"]["frozen"] = frozen
|
||||
|
|
@ -75,7 +75,7 @@ def process_flow(flow_object):
|
|||
cloned_flow = copy.deepcopy(flow_object)
|
||||
processed_nodes = set() # To keep track of processed nodes
|
||||
|
||||
def process_node(node):
|
||||
def process_node(node) -> None:
|
||||
node_id = node.get("id")
|
||||
|
||||
# If node already processed, skip
|
||||
|
|
@ -100,7 +100,7 @@ def process_flow(flow_object):
|
|||
return cloned_flow
|
||||
|
||||
|
||||
def update_template(template, g_nodes):
|
||||
def update_template(template, g_nodes) -> None:
|
||||
"""Updates the template of a node in a graph with the given template.
|
||||
|
||||
Args:
|
||||
|
|
@ -149,7 +149,7 @@ def update_target_handle(new_edge, g_nodes):
|
|||
return new_edge
|
||||
|
||||
|
||||
def set_new_target_handle(proxy_id, new_edge, target_handle, node):
|
||||
def set_new_target_handle(proxy_id, new_edge, target_handle, node) -> None:
|
||||
"""Sets a new target handle for a given edge.
|
||||
|
||||
Args:
|
||||
|
|
@ -330,7 +330,7 @@ def has_cycle(vertex_ids: list[str], edges: list[tuple[str, str]]) -> bool:
|
|||
graph[u].append(v)
|
||||
|
||||
# Utility function to perform DFS
|
||||
def dfs(v, visited, rec_stack):
|
||||
def dfs(v, visited, rec_stack) -> bool:
|
||||
visited.add(v)
|
||||
rec_stack.add(v)
|
||||
|
||||
|
|
|
|||
|
|
@ -126,10 +126,10 @@ def build_output_setter(method: Callable, *, validate: bool = True) -> Callable:
|
|||
>>> print(component.get_output_by_method(component.set_message).value) # Prints "New message"
|
||||
"""
|
||||
|
||||
def output_setter(self, value): # noqa: ARG001
|
||||
def output_setter(self, value) -> None: # noqa: ARG001
|
||||
if validate:
|
||||
__validate_method(method)
|
||||
methods_class = method.__self__
|
||||
methods_class = method.__self__ # type: ignore[attr-defined]
|
||||
output = methods_class.get_output_by_method(method)
|
||||
output.value = value
|
||||
|
||||
|
|
|
|||
|
|
@ -172,7 +172,7 @@ def log_vertex_build(
|
|||
params: Any,
|
||||
data: ResultDataResponse,
|
||||
artifacts: dict | None = None,
|
||||
):
|
||||
) -> None:
|
||||
try:
|
||||
if not get_settings_service().settings.vertex_builds_storage_enabled:
|
||||
return
|
||||
|
|
|
|||
|
|
@ -75,8 +75,8 @@ class Vertex:
|
|||
self.base_type: str | None = base_type
|
||||
self.outputs: list[dict] = []
|
||||
self._parse_data()
|
||||
self._built_object = UnbuiltObject()
|
||||
self._built_result = None
|
||||
self._built_object: Any = UnbuiltObject()
|
||||
self._built_result: Any = None
|
||||
self._built = False
|
||||
self._successors_ids: list[str] | None = None
|
||||
self.artifacts: dict[str, Any] = {}
|
||||
|
|
@ -106,7 +106,7 @@ class Vertex:
|
|||
self.state = VertexStates.ACTIVE
|
||||
self.log_transaction_tasks: set[asyncio.Task] = set()
|
||||
|
||||
def set_input_value(self, name: str, value: Any):
|
||||
def set_input_value(self, name: str, value: Any) -> None:
|
||||
if self._custom_component is None:
|
||||
msg = f"Vertex {self.id} does not have a component instance."
|
||||
raise ValueError(msg)
|
||||
|
|
@ -115,20 +115,20 @@ class Vertex:
|
|||
def to_data(self):
|
||||
return self._data
|
||||
|
||||
def add_component_instance(self, component_instance: Component):
|
||||
def add_component_instance(self, component_instance: Component) -> None:
|
||||
component_instance.set_vertex(self)
|
||||
self._custom_component = component_instance
|
||||
|
||||
def add_result(self, name: str, result: Any):
|
||||
def add_result(self, name: str, result: Any) -> None:
|
||||
self.results[name] = result
|
||||
|
||||
def update_graph_state(self, key, new_state, *, append: bool):
|
||||
def update_graph_state(self, key, new_state, *, append: bool) -> None:
|
||||
if append:
|
||||
self.graph.append_state(key, new_state, caller=self.id)
|
||||
else:
|
||||
self.graph.update_state(key, new_state, caller=self.id)
|
||||
|
||||
def set_state(self, state: str):
|
||||
def set_state(self, state: str) -> None:
|
||||
self.state = VertexStates[state]
|
||||
if self.state == VertexStates.INACTIVE and self.graph.in_degree_map[self.id] <= 1:
|
||||
# If the vertex is inactive and has only one in degree
|
||||
|
|
@ -144,7 +144,7 @@ class Vertex:
|
|||
def avg_build_time(self):
|
||||
return sum(self.build_times) / len(self.build_times) if self.build_times else 0
|
||||
|
||||
def add_build_time(self, time):
|
||||
def add_build_time(self, time) -> None:
|
||||
self.build_times.append(time)
|
||||
|
||||
def set_result(self, result: ResultData) -> None:
|
||||
|
|
@ -300,7 +300,7 @@ class Vertex:
|
|||
params[param_key] = self.graph.get_vertex(edge.source_id)
|
||||
return params
|
||||
|
||||
def _build_params(self):
|
||||
def _build_params(self) -> None:
|
||||
# sourcery skip: merge-list-append, remove-redundant-if
|
||||
# Some params are required, some are optional
|
||||
# but most importantly, some params are python base classes
|
||||
|
|
@ -326,7 +326,7 @@ class Vertex:
|
|||
return
|
||||
|
||||
template_dict = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)}
|
||||
params = {}
|
||||
params: dict = {}
|
||||
|
||||
for edge in self.edges:
|
||||
if not hasattr(edge, "target_param"):
|
||||
|
|
@ -438,7 +438,7 @@ class Vertex:
|
|||
self.load_from_db_fields = load_from_db_fields
|
||||
self._raw_params = params.copy()
|
||||
|
||||
def update_raw_params(self, new_params: Mapping[str, str | list[str]], *, overwrite: bool = False):
|
||||
def update_raw_params(self, new_params: Mapping[str, str | list[str]], *, overwrite: bool = False) -> None:
|
||||
"""Update the raw parameters of the vertex with the given new parameters.
|
||||
|
||||
Args:
|
||||
|
|
@ -466,7 +466,7 @@ class Vertex:
|
|||
"""Checks if the vertex has any cycle edges."""
|
||||
return self._has_cycle_edges
|
||||
|
||||
async def instantiate_component(self, user_id=None):
|
||||
async def instantiate_component(self, user_id=None) -> None:
|
||||
if not self._custom_component:
|
||||
self._custom_component, _ = await initialize.loading.instantiate_class(
|
||||
user_id=user_id,
|
||||
|
|
@ -478,7 +478,7 @@ class Vertex:
|
|||
fallback_to_env_vars,
|
||||
user_id=None,
|
||||
event_manager: EventManager | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""Initiate the build process."""
|
||||
logger.debug(f"Building {self.display_name}")
|
||||
await self._build_each_vertex_in_params_dict()
|
||||
|
|
@ -500,6 +500,7 @@ class Vertex:
|
|||
custom_component=custom_component,
|
||||
custom_params=custom_params,
|
||||
fallback_to_env_vars=fallback_to_env_vars,
|
||||
base_type=self.base_type,
|
||||
)
|
||||
|
||||
self._validate_built_object()
|
||||
|
|
@ -545,7 +546,7 @@ class Vertex:
|
|||
|
||||
return messages
|
||||
|
||||
def _finalize_build(self):
|
||||
def _finalize_build(self) -> None:
|
||||
result_dict = self.get_built_result()
|
||||
# We need to set the artifacts to pass information
|
||||
# to the frontend
|
||||
|
|
@ -563,7 +564,7 @@ class Vertex:
|
|||
)
|
||||
self.set_result(result_dict)
|
||||
|
||||
async def _build_each_vertex_in_params_dict(self):
|
||||
async def _build_each_vertex_in_params_dict(self) -> None:
|
||||
"""Iterates over each vertex in the params dictionary and builds it."""
|
||||
for key, value in self._raw_params.items():
|
||||
if self._is_vertex(value):
|
||||
|
|
@ -588,7 +589,7 @@ class Vertex:
|
|||
self,
|
||||
key,
|
||||
vertices_dict: dict[str, Vertex],
|
||||
):
|
||||
) -> None:
|
||||
"""Iterates over a dictionary of vertices, builds each and updates the params dictionary."""
|
||||
for sub_key, value in vertices_dict.items():
|
||||
if not self._is_vertex(value):
|
||||
|
|
@ -647,7 +648,7 @@ class Vertex:
|
|||
self._log_transaction_async(str(flow_id), source=self, target=requester, status="success")
|
||||
return result
|
||||
|
||||
async def _build_vertex_and_update_params(self, key, vertex: Vertex):
|
||||
async def _build_vertex_and_update_params(self, key, vertex: Vertex) -> None:
|
||||
"""Builds a given vertex and updates the params dictionary accordingly."""
|
||||
result = await vertex.get_result(self, target_handle_name=key)
|
||||
self._handle_func(key, result)
|
||||
|
|
@ -659,7 +660,7 @@ class Vertex:
|
|||
self,
|
||||
key,
|
||||
vertices: list[Vertex],
|
||||
):
|
||||
) -> None:
|
||||
"""Iterates over a list of vertices, builds each and updates the params dictionary."""
|
||||
self.params[key] = []
|
||||
for vertex in vertices:
|
||||
|
|
@ -685,7 +686,7 @@ class Vertex:
|
|||
)
|
||||
raise ValueError(msg) from e
|
||||
|
||||
def _handle_func(self, key, result):
|
||||
def _handle_func(self, key, result) -> None:
|
||||
"""Handles 'func' key by checking if the result is a function and setting it as coroutine."""
|
||||
if key == "func":
|
||||
if not isinstance(result, types.FunctionType):
|
||||
|
|
@ -698,19 +699,21 @@ class Vertex:
|
|||
else:
|
||||
self.params["coroutine"] = sync_to_async(result)
|
||||
|
||||
def _extend_params_list_with_result(self, key, result):
|
||||
def _extend_params_list_with_result(self, key, result) -> None:
|
||||
"""Extends a list in the params dictionary with the given result if it exists."""
|
||||
if isinstance(self.params[key], list):
|
||||
self.params[key].extend(result)
|
||||
|
||||
async def _build_results(self, custom_component, custom_params, *, fallback_to_env_vars=False):
|
||||
async def _build_results(
|
||||
self, custom_component, custom_params, base_type: str, *, fallback_to_env_vars=False
|
||||
) -> None:
|
||||
try:
|
||||
result = await initialize.loading.get_instance_results(
|
||||
custom_component=custom_component,
|
||||
custom_params=custom_params,
|
||||
vertex=self,
|
||||
fallback_to_env_vars=fallback_to_env_vars,
|
||||
base_type=self.base_type,
|
||||
base_type=base_type,
|
||||
)
|
||||
|
||||
self.outputs_logs = build_output_logs(self, result)
|
||||
|
|
@ -722,7 +725,7 @@ class Vertex:
|
|||
msg = f"Error building Component {self.display_name}: \n\n{exc}"
|
||||
raise ComponentBuildError(msg, tb) from exc
|
||||
|
||||
def _update_built_object_and_artifacts(self, result: Any | tuple[Any, dict] | tuple[Component, Any, dict]):
|
||||
def _update_built_object_and_artifacts(self, result: Any | tuple[Any, dict] | tuple[Component, Any, dict]) -> None:
|
||||
"""Updates the built object and its artifacts."""
|
||||
if isinstance(result, tuple):
|
||||
if len(result) == 2: # noqa: PLR2004
|
||||
|
|
@ -738,7 +741,7 @@ class Vertex:
|
|||
else:
|
||||
self._built_object = result
|
||||
|
||||
def _validate_built_object(self):
|
||||
def _validate_built_object(self) -> None:
|
||||
"""Checks if the built object is None and raises a ValueError if so."""
|
||||
if isinstance(self._built_object, UnbuiltObject):
|
||||
msg = f"{self.display_name}: {self._built_object_repr()}"
|
||||
|
|
@ -754,7 +757,7 @@ class Vertex:
|
|||
msg = f"You are trying to stream to a {self.display_name}. Try using a Chat Output instead."
|
||||
raise ValueError(msg)
|
||||
|
||||
def _reset(self):
|
||||
def _reset(self) -> None:
|
||||
self._built = False
|
||||
self._built_object = UnbuiltObject()
|
||||
self._built_result = UnbuiltResult()
|
||||
|
|
@ -762,10 +765,10 @@ class Vertex:
|
|||
self.steps_ran = []
|
||||
self._build_params()
|
||||
|
||||
def _is_chat_input(self):
|
||||
def _is_chat_input(self) -> bool:
|
||||
return False
|
||||
|
||||
def build_inactive(self):
|
||||
def build_inactive(self) -> None:
|
||||
# Just set the results to None
|
||||
self._built = True
|
||||
self._built_object = None
|
||||
|
|
@ -865,11 +868,11 @@ class Vertex:
|
|||
def __hash__(self) -> int:
|
||||
return id(self)
|
||||
|
||||
def _built_object_repr(self):
|
||||
def _built_object_repr(self) -> str:
|
||||
# Add a message with an emoji, stars for sucess,
|
||||
return "Built successfully ✨" if self._built_object is not None else "Failed to build 😵💫"
|
||||
|
||||
def apply_on_outputs(self, func: Callable[[Any], Any]):
|
||||
def apply_on_outputs(self, func: Callable[[Any], Any]) -> None:
|
||||
"""Applies a function to the outputs of the vertex."""
|
||||
if not self._custom_component or not self._custom_component.outputs:
|
||||
return
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ class ComponentVertex(Vertex):
|
|||
return self.artifacts["repr"] or super()._built_object_repr()
|
||||
return None
|
||||
|
||||
def _update_built_object_and_artifacts(self, result):
|
||||
def _update_built_object_and_artifacts(self, result) -> None:
|
||||
"""Updates the built object and its artifacts."""
|
||||
if isinstance(result, tuple):
|
||||
if len(result) == 2: # noqa: PLR2004
|
||||
|
|
@ -182,7 +182,7 @@ class ComponentVertex(Vertex):
|
|||
)
|
||||
return messages
|
||||
|
||||
def _finalize_build(self):
|
||||
def _finalize_build(self) -> None:
|
||||
result_dict = self.get_built_result()
|
||||
# We need to set the artifacts to pass information
|
||||
# to the frontend
|
||||
|
|
@ -206,7 +206,7 @@ class InterfaceVertex(ComponentVertex):
|
|||
self.steps = [self._build, self._run]
|
||||
self.is_interface_component = True
|
||||
|
||||
def build_stream_url(self):
|
||||
def build_stream_url(self) -> str:
|
||||
return f"/api/v1/build/{self.graph.flow_id}/{self.id}/stream"
|
||||
|
||||
def _built_object_repr(self):
|
||||
|
|
@ -352,21 +352,17 @@ class InterfaceVertex(ComponentVertex):
|
|||
self.artifacts = DataOutputResponse(data=artifacts)
|
||||
return self._built_object
|
||||
|
||||
async def _run(self, *args, **kwargs):
|
||||
if self.is_interface_component:
|
||||
if self.vertex_type in CHAT_COMPONENTS:
|
||||
message = self._process_chat_component()
|
||||
elif self.vertex_type in RECORDS_COMPONENTS:
|
||||
message = self._process_data_component()
|
||||
if isinstance(self._built_object, AsyncIterator | Iterator):
|
||||
if self.params.get("return_data", False):
|
||||
self._built_object = Data(text=message, data=self.artifacts)
|
||||
else:
|
||||
self._built_object = message
|
||||
self._built_result = self._built_object
|
||||
|
||||
else:
|
||||
await super()._run(*args, **kwargs)
|
||||
async def _run(self, *args, **kwargs) -> None: # noqa: ARG002
|
||||
if self.vertex_type in CHAT_COMPONENTS:
|
||||
message = self._process_chat_component()
|
||||
elif self.vertex_type in RECORDS_COMPONENTS:
|
||||
message = self._process_data_component()
|
||||
if isinstance(self._built_object, AsyncIterator | Iterator):
|
||||
if self.params.get("return_data", False):
|
||||
self._built_object = Data(text=message, data=self.artifacts)
|
||||
else:
|
||||
self._built_object = message
|
||||
self._built_result = self._built_object
|
||||
|
||||
async def stream(self):
|
||||
iterator = self.params.get(INPUT_FIELD_NAME, None)
|
||||
|
|
@ -452,7 +448,7 @@ class InterfaceVertex(ComponentVertex):
|
|||
self._validate_built_object()
|
||||
self._built = True
|
||||
|
||||
async def consume_async_generator(self):
|
||||
async def consume_async_generator(self) -> None:
|
||||
async for _ in self.stream():
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from uuid import UUID
|
|||
import orjson
|
||||
from emoji import demojize, purely_emoji
|
||||
from loguru import logger
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
from sqlmodel import select
|
||||
|
||||
from langflow.base.constants import (
|
||||
|
|
@ -340,7 +341,7 @@ def update_edges_with_latest_component_versions(project_data):
|
|||
return project_data_copy
|
||||
|
||||
|
||||
def log_node_changes(node_changes_log):
|
||||
def log_node_changes(node_changes_log) -> None:
|
||||
# The idea here is to log the changes that were made to the nodes in debug
|
||||
# Something like:
|
||||
# Node: "Node Name" was updated with the following changes:
|
||||
|
|
@ -377,8 +378,11 @@ def load_starter_projects(retries=3, delay=1) -> list[tuple[Path, dict]]:
|
|||
return starter_projects
|
||||
|
||||
|
||||
def copy_profile_pictures():
|
||||
def copy_profile_pictures() -> None:
|
||||
config_dir = get_storage_service().settings_service.settings.config_dir
|
||||
if config_dir is None:
|
||||
msg = "Config dir is not set in the settings"
|
||||
raise ValueError(msg)
|
||||
origin = Path(__file__).parent / "profile_pictures"
|
||||
target = Path(config_dir) / "profile_pictures"
|
||||
|
||||
|
|
@ -425,7 +429,7 @@ def get_project_data(project):
|
|||
)
|
||||
|
||||
|
||||
def update_project_file(project_path: Path, project: dict, updated_project_data):
|
||||
def update_project_file(project_path: Path, project: dict, updated_project_data) -> None:
|
||||
project["data"] = updated_project_data
|
||||
project_path.write_text(orjson.dumps(project, option=ORJSON_OPTIONS).decode(), encoding="utf-8")
|
||||
logger.info(f"Updated starter project {project['name']} file")
|
||||
|
|
@ -440,7 +444,7 @@ def update_existing_project(
|
|||
project_data,
|
||||
project_icon,
|
||||
project_icon_bg_color,
|
||||
):
|
||||
) -> None:
|
||||
logger.info(f"Updating starter project {project_name}")
|
||||
existing_project.data = project_data
|
||||
existing_project.folder = STARTER_FOLDER_NAME
|
||||
|
|
@ -463,7 +467,7 @@ def create_new_project(
|
|||
project_icon,
|
||||
project_icon_bg_color,
|
||||
new_folder_id,
|
||||
):
|
||||
) -> None:
|
||||
logger.debug(f"Creating starter project {project_name}")
|
||||
new_project = FlowCreate(
|
||||
name=project_name,
|
||||
|
|
@ -485,7 +489,7 @@ def get_all_flows_similar_to_project(session, folder_id):
|
|||
return session.exec(select(Folder).where(Folder.id == folder_id)).first().flows
|
||||
|
||||
|
||||
def delete_start_projects(session, folder_id):
|
||||
def delete_start_projects(session, folder_id) -> None:
|
||||
flows = session.exec(select(Folder).where(Folder.id == folder_id)).first().flows
|
||||
for flow in flows:
|
||||
session.delete(flow)
|
||||
|
|
@ -516,7 +520,7 @@ def _is_valid_uuid(val):
|
|||
return str(uuid_obj) == val
|
||||
|
||||
|
||||
def load_flows_from_directory():
|
||||
def load_flows_from_directory() -> None:
|
||||
"""On langflow startup, this loads all flows from the directory specified in the settings.
|
||||
|
||||
All flows are uploaded into the default folder for the superuser.
|
||||
|
|
@ -531,7 +535,11 @@ def load_flows_from_directory():
|
|||
return
|
||||
|
||||
with session_scope() as session:
|
||||
user_id = get_user_by_username(session, settings_service.auth_settings.SUPERUSER).id
|
||||
user = get_user_by_username(session, settings_service.auth_settings.SUPERUSER)
|
||||
if user is None:
|
||||
msg = "Superuser not found in the database"
|
||||
raise NoResultFound(msg)
|
||||
user_id = user.id
|
||||
_flows_path = Path(flows_path)
|
||||
files = [f for f in _flows_path.iterdir() if f.is_file()]
|
||||
for f in files:
|
||||
|
|
@ -592,7 +600,7 @@ def find_existing_flow(session, flow_id, flow_endpoint_name):
|
|||
return None
|
||||
|
||||
|
||||
async def create_or_update_starter_projects(get_all_components_coro: Awaitable[dict]):
|
||||
async def create_or_update_starter_projects(get_all_components_coro: Awaitable[dict]) -> None:
|
||||
try:
|
||||
all_types_dict = await get_all_components_coro
|
||||
except Exception:
|
||||
|
|
@ -647,7 +655,7 @@ async def create_or_update_starter_projects(get_all_components_coro: Awaitable[d
|
|||
)
|
||||
|
||||
|
||||
def initialize_super_user_if_needed():
|
||||
def initialize_super_user_if_needed() -> None:
|
||||
settings_service = get_settings_service()
|
||||
if not settings_service.auth_settings.AUTO_LOGIN:
|
||||
return
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from langflow.utils.lazy_load import LazyLoadDictBase
|
|||
|
||||
|
||||
class AllTypesDict(LazyLoadDictBase):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self._all_types_dict = None
|
||||
|
||||
def _build_dict(self):
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ def get_memory_key(langchain_object):
|
|||
return None # or some other default value or action
|
||||
|
||||
|
||||
def update_memory_keys(langchain_object, possible_new_mem_key):
|
||||
def update_memory_keys(langchain_object, possible_new_mem_key) -> None:
|
||||
"""Update the memory keys in the LangChain object's memory attribute.
|
||||
|
||||
Given a LangChain object and a possible new memory key, this function updates the input and output keys in the
|
||||
|
|
|
|||
|
|
@ -89,7 +89,7 @@ def extract_input_variables_from_prompt(prompt: str) -> list[str]:
|
|||
return variables
|
||||
|
||||
|
||||
def setup_llm_caching():
|
||||
def setup_llm_caching() -> None:
|
||||
"""Setup LLM caching."""
|
||||
settings_service = get_settings_service()
|
||||
try:
|
||||
|
|
@ -100,7 +100,7 @@ def setup_llm_caching():
|
|||
logger.opt(exception=True).warning("Could not setup LLM caching.")
|
||||
|
||||
|
||||
def set_langchain_cache(settings):
|
||||
def set_langchain_cache(settings) -> None:
|
||||
from langchain.globals import set_llm_cache
|
||||
|
||||
from langflow.interface.importing.utils import import_class
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ class SizedLogBuffer:
|
|||
def get_write_lock(self) -> Lock:
|
||||
return self._wlock
|
||||
|
||||
def write(self, message: str):
|
||||
def write(self, message: str) -> None:
|
||||
record = json.loads(message)
|
||||
log_entry = record["text"]
|
||||
epoch = int(record["record"]["time"]["timestamp"] * 1000)
|
||||
|
|
@ -52,7 +52,7 @@ class SizedLogBuffer:
|
|||
self.buffer.popleft()
|
||||
self.buffer.append((epoch, log_entry))
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return len(self.buffer)
|
||||
|
||||
def get_after_timestamp(self, timestamp: int, lines: int = 5) -> dict[int, str]:
|
||||
|
|
@ -123,7 +123,7 @@ def serialize_log(record):
|
|||
return orjson.dumps(subset)
|
||||
|
||||
|
||||
def patching(record):
|
||||
def patching(record) -> None:
|
||||
record["extra"]["serialized"] = serialize_log(record)
|
||||
if DEV is False:
|
||||
record.pop("exception", None)
|
||||
|
|
@ -142,7 +142,7 @@ def configure(
|
|||
log_file: Path | None = None,
|
||||
disable: bool | None = False,
|
||||
log_env: str | None = None,
|
||||
):
|
||||
) -> None:
|
||||
if disable and log_level is None and log_file is None:
|
||||
logger.disable("langflow")
|
||||
if os.getenv("LANGFLOW_LOG_LEVEL", "").upper() in VALID_LOG_LEVELS and log_level is None:
|
||||
|
|
@ -205,14 +205,14 @@ def configure(
|
|||
setup_gunicorn_logger()
|
||||
|
||||
|
||||
def setup_uvicorn_logger():
|
||||
def setup_uvicorn_logger() -> None:
|
||||
loggers = (logging.getLogger(name) for name in logging.root.manager.loggerDict if name.startswith("uvicorn."))
|
||||
for uvicorn_logger in loggers:
|
||||
uvicorn_logger.handlers = []
|
||||
logging.getLogger("uvicorn").handlers = [InterceptHandler()]
|
||||
|
||||
|
||||
def setup_gunicorn_logger():
|
||||
def setup_gunicorn_logger() -> None:
|
||||
logging.getLogger("gunicorn.error").handlers = [InterceptHandler()]
|
||||
logging.getLogger("gunicorn.access").handlers = [InterceptHandler()]
|
||||
|
||||
|
|
@ -223,7 +223,7 @@ class InterceptHandler(logging.Handler):
|
|||
See https://loguru.readthedocs.io/en/stable/overview.html#entirely-compatible-with-standard-logging.
|
||||
"""
|
||||
|
||||
def emit(self, record):
|
||||
def emit(self, record) -> None:
|
||||
# Get corresponding Loguru level if it exists
|
||||
try:
|
||||
level = logger.level(record.levelname).name
|
||||
|
|
@ -232,7 +232,7 @@ class InterceptHandler(logging.Handler):
|
|||
|
||||
# Find caller from where originated the logged message
|
||||
frame, depth = logging.currentframe(), 2
|
||||
while frame.f_code.co_filename == logging.__file__:
|
||||
while frame.f_code.co_filename == logging.__file__ and frame.f_back:
|
||||
frame = frame.f_back
|
||||
depth += 1
|
||||
|
||||
|
|
|
|||
|
|
@ -3,14 +3,14 @@ from loguru import logger
|
|||
LOGGING_CONFIGURED = False
|
||||
|
||||
|
||||
def disable_logging():
|
||||
def disable_logging() -> None:
|
||||
global LOGGING_CONFIGURED # noqa: PLW0603
|
||||
if not LOGGING_CONFIGURED:
|
||||
logger.disable("langflow")
|
||||
LOGGING_CONFIGURED = True
|
||||
|
||||
|
||||
def enable_logging():
|
||||
def enable_logging() -> None:
|
||||
global LOGGING_CONFIGURED # noqa: PLW0603
|
||||
logger.enable("langflow")
|
||||
LOGGING_CONFIGURED = True
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ MAX_PORT = 65535
|
|||
|
||||
|
||||
class RequestCancelledMiddleware(BaseHTTPMiddleware):
|
||||
def __init__(self, app):
|
||||
def __init__(self, app) -> None:
|
||||
super().__init__(app)
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
|
|
@ -224,7 +224,7 @@ def create_app():
|
|||
return app
|
||||
|
||||
|
||||
def setup_sentry(app: FastAPI):
|
||||
def setup_sentry(app: FastAPI) -> None:
|
||||
settings = get_settings_service().settings
|
||||
if settings.sentry_dsn:
|
||||
import sentry_sdk
|
||||
|
|
@ -238,7 +238,7 @@ def setup_sentry(app: FastAPI):
|
|||
app.add_middleware(SentryAsgiMiddleware)
|
||||
|
||||
|
||||
def setup_static_files(app: FastAPI, static_files_dir: Path):
|
||||
def setup_static_files(app: FastAPI, static_files_dir: Path) -> None:
|
||||
"""Setup the static files directory.
|
||||
|
||||
Args:
|
||||
|
|
|
|||
|
|
@ -86,7 +86,7 @@ def add_messagetables(messages: list[MessageTable], session: Session):
|
|||
return [MessageRead.model_validate(message, from_attributes=True) for message in messages]
|
||||
|
||||
|
||||
def delete_messages(session_id: str):
|
||||
def delete_messages(session_id: str) -> None:
|
||||
"""Delete messages from the monitor service based on the provided session ID.
|
||||
|
||||
Args:
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ def get_langfuse_callback(trace_id):
|
|||
return None
|
||||
|
||||
|
||||
def flush_langfuse_callback_if_present(callbacks: list[BaseCallbackHandler | CallbackHandler]):
|
||||
def flush_langfuse_callback_if_present(callbacks: list[BaseCallbackHandler | CallbackHandler]) -> None:
|
||||
"""If langfuse callback is present, run callback.langfuse.flush()."""
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "langfuse") and hasattr(callback.langfuse, "flush"):
|
||||
|
|
|
|||
|
|
@ -165,7 +165,7 @@ class Data(BaseModel):
|
|||
msg = f"'{type(self).__name__}' object has no attribute '{key}'"
|
||||
raise AttributeError(msg) from e
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
def __setattr__(self, key, value) -> None:
|
||||
"""Set attribute-like values in the data dictionary.
|
||||
|
||||
Allows attribute-like setting of values in the data dictionary.
|
||||
|
|
@ -179,7 +179,7 @@ class Data(BaseModel):
|
|||
else:
|
||||
self.data[key] = value
|
||||
|
||||
def __delattr__(self, key):
|
||||
def __delattr__(self, key) -> None:
|
||||
"""Allows attribute-like deletion from the data dictionary."""
|
||||
if key in {"data", "text_key"} or key.startswith("_"):
|
||||
super().__delattr__(key)
|
||||
|
|
@ -204,7 +204,7 @@ class Data(BaseModel):
|
|||
logger.opt(exception=True).debug("Error converting Data to JSON")
|
||||
return str(self.data)
|
||||
|
||||
def __contains__(self, key):
|
||||
def __contains__(self, key) -> bool:
|
||||
return key in self.data
|
||||
|
||||
def __eq__(self, other):
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ class dotdict(dict): # noqa: N801
|
|||
else:
|
||||
return value
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
def __setattr__(self, key, value) -> None:
|
||||
"""Override attribute setting to work as dictionary item assignment.
|
||||
|
||||
Args:
|
||||
|
|
@ -44,7 +44,7 @@ class dotdict(dict): # noqa: N801
|
|||
value = dotdict(value)
|
||||
self[key] = value
|
||||
|
||||
def __delattr__(self, key):
|
||||
def __delattr__(self, key) -> None:
|
||||
"""Override attribute deletion to work as dictionary item deletion.
|
||||
|
||||
Args:
|
||||
|
|
|
|||
|
|
@ -37,10 +37,10 @@ class Tweaks(RootModel):
|
|||
def __getitem__(self, key):
|
||||
return self.root[key]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
def __setitem__(self, key, value) -> None:
|
||||
self.root[key] = value
|
||||
|
||||
def __delitem__(self, key):
|
||||
def __delitem__(self, key) -> None:
|
||||
del self.root[key]
|
||||
|
||||
def items(self):
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from langflow.services.deps import get_storage_service
|
|||
IMAGE_ENDPOINT = "/files/images/"
|
||||
|
||||
|
||||
def is_image_file(file_path):
|
||||
def is_image_file(file_path) -> bool:
|
||||
try:
|
||||
with PILImage.open(file_path) as img:
|
||||
img.verify() # Verify that it is, in fact, an image
|
||||
|
|
@ -61,5 +61,5 @@ class Image(BaseModel):
|
|||
"image_url": self.to_base64(),
|
||||
}
|
||||
|
||||
def get_url(self):
|
||||
def get_url(self) -> str:
|
||||
return f"{IMAGE_ENDPOINT}{self.path}"
|
||||
|
|
|
|||
|
|
@ -94,7 +94,7 @@ class Message(Data):
|
|||
if "timestamp" not in self.data:
|
||||
self.data["timestamp"] = self.timestamp
|
||||
|
||||
def set_flow_id(self, flow_id: str):
|
||||
def set_flow_id(self, flow_id: str) -> None:
|
||||
self.flow_id = flow_id
|
||||
|
||||
def to_lc_message(
|
||||
|
|
|
|||
|
|
@ -35,14 +35,14 @@ class Logger(glogging.Logger):
|
|||
gunicorn logs to loguru.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg):
|
||||
def __init__(self, cfg) -> None:
|
||||
super().__init__(cfg)
|
||||
logging.getLogger("gunicorn.error").handlers = [InterceptHandler()]
|
||||
logging.getLogger("gunicorn.access").handlers = [InterceptHandler()]
|
||||
|
||||
|
||||
class LangflowApplication(BaseApplication):
|
||||
def __init__(self, app, options=None):
|
||||
def __init__(self, app, options=None) -> None:
|
||||
self.options = options or {}
|
||||
|
||||
self.options["worker_class"] = "langflow.server.LangflowUvicornWorker"
|
||||
|
|
@ -50,7 +50,7 @@ class LangflowApplication(BaseApplication):
|
|||
self.application = app
|
||||
super().__init__()
|
||||
|
||||
def load_config(self):
|
||||
def load_config(self) -> None:
|
||||
config = {key: value for key, value in self.options.items() if key in self.cfg.settings and value is not None}
|
||||
for key, value in config.items():
|
||||
self.cfg.set(key.lower(), value)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from langflow.services.factory import ServiceFactory
|
|||
class AuthServiceFactory(ServiceFactory):
|
||||
name = "auth_service"
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(AuthService)
|
||||
|
||||
def create(self, settings_service):
|
||||
|
|
|
|||
|
|
@ -21,8 +21,8 @@ class Service(ABC):
|
|||
}
|
||||
return schema
|
||||
|
||||
async def teardown(self):
|
||||
async def teardown(self) -> None:
|
||||
return
|
||||
|
||||
def set_ready(self):
|
||||
def set_ready(self) -> None:
|
||||
self.ready = True
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ class CacheService(Service, Generic[LockType]):
|
|||
"""Clear all items from the cache."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def __contains__(self, key):
|
||||
def __contains__(self, key) -> bool:
|
||||
"""Check if the key is in the cache.
|
||||
|
||||
Args:
|
||||
|
|
@ -79,7 +79,7 @@ class CacheService(Service, Generic[LockType]):
|
|||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def __setitem__(self, key, value):
|
||||
def __setitem__(self, key, value) -> None:
|
||||
"""Add an item to the cache using the square bracket notation.
|
||||
|
||||
Args:
|
||||
|
|
@ -88,7 +88,7 @@ class CacheService(Service, Generic[LockType]):
|
|||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def __delitem__(self, key):
|
||||
def __delitem__(self, key) -> None:
|
||||
"""Remove an item from the cache using the square bracket notation.
|
||||
|
||||
Args:
|
||||
|
|
@ -147,7 +147,7 @@ class AsyncBaseCacheService(Service, Generic[AsyncLockType]):
|
|||
"""Clear all items from the cache."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def __contains__(self, key):
|
||||
def __contains__(self, key) -> bool:
|
||||
"""Check if the key is in the cache.
|
||||
|
||||
Args:
|
||||
|
|
|
|||
22
src/backend/base/langflow/services/cache/disk.py
vendored
22
src/backend/base/langflow/services/cache/disk.py
vendored
|
|
@ -11,7 +11,7 @@ from langflow.services.cache.utils import CACHE_MISS
|
|||
|
||||
|
||||
class AsyncDiskCache(AsyncBaseCacheService, Generic[AsyncLockType]):
|
||||
def __init__(self, cache_dir, max_size=None, expiration_time=3600):
|
||||
def __init__(self, cache_dir, max_size=None, expiration_time=3600) -> None:
|
||||
self.cache = Cache(cache_dir)
|
||||
# Let's clear the cache for now to maintain a similar
|
||||
# behavior as the in-memory cache
|
||||
|
|
@ -40,56 +40,56 @@ class AsyncDiskCache(AsyncBaseCacheService, Generic[AsyncLockType]):
|
|||
await self._delete(key) # Log before deleting the expired item
|
||||
return CACHE_MISS
|
||||
|
||||
async def set(self, key, value, lock: asyncio.Lock | None = None):
|
||||
async def set(self, key, value, lock: asyncio.Lock | None = None) -> None:
|
||||
if not lock:
|
||||
async with self.lock:
|
||||
await self._set(key, value)
|
||||
else:
|
||||
await self._set(key, value)
|
||||
|
||||
async def _set(self, key, value):
|
||||
async def _set(self, key, value) -> None:
|
||||
if self.max_size and len(self.cache) >= self.max_size:
|
||||
await asyncio.to_thread(self.cache.cull)
|
||||
item = {"value": pickle.dumps(value) if not isinstance(value, str | bytes) else value, "time": time.time()}
|
||||
await asyncio.to_thread(self.cache.set, key, item)
|
||||
|
||||
async def delete(self, key, lock: asyncio.Lock | None = None):
|
||||
async def delete(self, key, lock: asyncio.Lock | None = None) -> None:
|
||||
if not lock:
|
||||
async with self.lock:
|
||||
await self._delete(key)
|
||||
else:
|
||||
await self._delete(key)
|
||||
|
||||
async def _delete(self, key):
|
||||
async def _delete(self, key) -> None:
|
||||
await asyncio.to_thread(self.cache.delete, key)
|
||||
|
||||
async def clear(self, lock: asyncio.Lock | None = None):
|
||||
async def clear(self, lock: asyncio.Lock | None = None) -> None:
|
||||
if not lock:
|
||||
async with self.lock:
|
||||
await self._clear()
|
||||
else:
|
||||
await self._clear()
|
||||
|
||||
async def _clear(self):
|
||||
async def _clear(self) -> None:
|
||||
await asyncio.to_thread(self.cache.clear)
|
||||
|
||||
async def upsert(self, key, value, lock: asyncio.Lock | None = None):
|
||||
async def upsert(self, key, value, lock: asyncio.Lock | None = None) -> None:
|
||||
if not lock:
|
||||
async with self.lock:
|
||||
await self._upsert(key, value)
|
||||
else:
|
||||
await self._upsert(key, value)
|
||||
|
||||
async def _upsert(self, key, value):
|
||||
async def _upsert(self, key, value) -> None:
|
||||
existing_value = await self.get(key)
|
||||
if existing_value is not CACHE_MISS and isinstance(existing_value, dict) and isinstance(value, dict):
|
||||
existing_value.update(value)
|
||||
value = existing_value
|
||||
await self.set(key, value)
|
||||
|
||||
def __contains__(self, key):
|
||||
def __contains__(self, key) -> bool:
|
||||
return asyncio.run(asyncio.to_thread(self.cache.__contains__, key))
|
||||
|
||||
async def teardown(self):
|
||||
async def teardown(self) -> None:
|
||||
# Clean up the cache directory
|
||||
self.cache.clear(retry=True)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ if TYPE_CHECKING:
|
|||
|
||||
|
||||
class CacheServiceFactory(ServiceFactory):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(CacheService)
|
||||
|
||||
def create(self, settings_service: SettingsService):
|
||||
|
|
|
|||
|
|
@ -36,14 +36,14 @@ class ThreadingInMemoryCache(CacheService, Generic[LockType]):
|
|||
b = cache["b"]
|
||||
"""
|
||||
|
||||
def __init__(self, max_size=None, expiration_time=60 * 60):
|
||||
def __init__(self, max_size=None, expiration_time=60 * 60) -> None:
|
||||
"""Initialize a new InMemoryCache instance.
|
||||
|
||||
Args:
|
||||
max_size (int, optional): Maximum number of items to store in the cache.
|
||||
expiration_time (int, optional): Time in seconds after which a cached item expires. Default is 1 hour.
|
||||
"""
|
||||
self._cache = OrderedDict()
|
||||
self._cache: OrderedDict = OrderedDict()
|
||||
self._lock = threading.RLock()
|
||||
self.max_size = max_size
|
||||
self.expiration_time = expiration_time
|
||||
|
|
@ -72,7 +72,7 @@ class ThreadingInMemoryCache(CacheService, Generic[LockType]):
|
|||
self.delete(key)
|
||||
return None
|
||||
|
||||
def set(self, key, value, lock: Union[threading.Lock, None] = None): # noqa: UP007
|
||||
def set(self, key, value, lock: Union[threading.Lock, None] = None) -> None: # noqa: UP007
|
||||
"""Add an item to the cache.
|
||||
|
||||
If the cache is full, the least recently used item is evicted.
|
||||
|
|
@ -93,7 +93,7 @@ class ThreadingInMemoryCache(CacheService, Generic[LockType]):
|
|||
|
||||
self._cache[key] = {"value": value, "time": time.time()}
|
||||
|
||||
def upsert(self, key, value, lock: Union[threading.Lock, None] = None): # noqa: UP007
|
||||
def upsert(self, key, value, lock: Union[threading.Lock, None] = None) -> None: # noqa: UP007
|
||||
"""Inserts or updates a value in the cache.
|
||||
|
||||
If the existing value and the new value are both dictionaries, they are merged.
|
||||
|
|
@ -130,16 +130,16 @@ class ThreadingInMemoryCache(CacheService, Generic[LockType]):
|
|||
self.set(key, value)
|
||||
return value
|
||||
|
||||
def delete(self, key, lock: Union[threading.Lock, None] = None): # noqa: UP007
|
||||
def delete(self, key, lock: Union[threading.Lock, None] = None) -> None: # noqa: UP007
|
||||
with lock or self._lock:
|
||||
self._cache.pop(key, None)
|
||||
|
||||
def clear(self, lock: Union[threading.Lock, None] = None): # noqa: UP007
|
||||
def clear(self, lock: Union[threading.Lock, None] = None) -> None: # noqa: UP007
|
||||
"""Clear all items from the cache."""
|
||||
with lock or self._lock:
|
||||
self._cache.clear()
|
||||
|
||||
def __contains__(self, key):
|
||||
def __contains__(self, key) -> bool:
|
||||
"""Check if the key is in the cache."""
|
||||
return key in self._cache
|
||||
|
||||
|
|
@ -147,19 +147,19 @@ class ThreadingInMemoryCache(CacheService, Generic[LockType]):
|
|||
"""Retrieve an item from the cache using the square bracket notation."""
|
||||
return self.get(key)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
def __setitem__(self, key, value) -> None:
|
||||
"""Add an item to the cache using the square bracket notation."""
|
||||
self.set(key, value)
|
||||
|
||||
def __delitem__(self, key):
|
||||
def __delitem__(self, key) -> None:
|
||||
"""Remove an item from the cache using the square bracket notation."""
|
||||
self.delete(key)
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of items in the cache."""
|
||||
return len(self._cache)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
"""Return a string representation of the InMemoryCache instance."""
|
||||
return f"InMemoryCache(max_size={self.max_size}, expiration_time={self.expiration_time})"
|
||||
|
||||
|
|
@ -185,7 +185,7 @@ class RedisCache(AsyncBaseCacheService, Generic[LockType]):
|
|||
b = cache["b"]
|
||||
"""
|
||||
|
||||
def __init__(self, host="localhost", port=6379, db=0, url=None, expiration_time=60 * 60):
|
||||
def __init__(self, host="localhost", port=6379, db=0, url=None, expiration_time=60 * 60) -> None:
|
||||
"""Initialize a new RedisCache instance.
|
||||
|
||||
Args:
|
||||
|
|
@ -215,7 +215,7 @@ class RedisCache(AsyncBaseCacheService, Generic[LockType]):
|
|||
self.expiration_time = expiration_time
|
||||
|
||||
# check connection
|
||||
def is_connected(self):
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if the Redis client is connected."""
|
||||
import redis
|
||||
|
||||
|
|
@ -234,7 +234,7 @@ class RedisCache(AsyncBaseCacheService, Generic[LockType]):
|
|||
return pickle.loads(value) if value else None
|
||||
|
||||
@override
|
||||
async def set(self, key, value, lock=None):
|
||||
async def set(self, key, value, lock=None) -> None:
|
||||
try:
|
||||
if pickled := pickle.dumps(value):
|
||||
result = await self._client.setex(str(key), self.expiration_time, pickled)
|
||||
|
|
@ -246,7 +246,7 @@ class RedisCache(AsyncBaseCacheService, Generic[LockType]):
|
|||
raise TypeError(msg) from exc
|
||||
|
||||
@override
|
||||
async def upsert(self, key, value, lock=None):
|
||||
async def upsert(self, key, value, lock=None) -> None:
|
||||
"""Inserts or updates a value in the cache.
|
||||
|
||||
If the existing value and the new value are both dictionaries, they are merged.
|
||||
|
|
@ -266,28 +266,28 @@ class RedisCache(AsyncBaseCacheService, Generic[LockType]):
|
|||
await self.set(key, value)
|
||||
|
||||
@override
|
||||
async def delete(self, key, lock=None):
|
||||
async def delete(self, key, lock=None) -> None:
|
||||
await self._client.delete(key)
|
||||
|
||||
@override
|
||||
async def clear(self, lock=None):
|
||||
async def clear(self, lock=None) -> None:
|
||||
"""Clear all items from the cache."""
|
||||
await self._client.flushdb()
|
||||
|
||||
def __contains__(self, key):
|
||||
def __contains__(self, key) -> bool:
|
||||
"""Check if the key is in the cache."""
|
||||
if key is None:
|
||||
return False
|
||||
return asyncio.run(self._client.exists(str(key)))
|
||||
return bool(asyncio.run(self._client.exists(str(key))))
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
"""Return a string representation of the RedisCache instance."""
|
||||
return f"RedisCache(expiration_time={self.expiration_time})"
|
||||
|
||||
|
||||
class AsyncInMemoryCache(AsyncBaseCacheService, Generic[AsyncLockType]):
|
||||
def __init__(self, max_size=None, expiration_time=3600):
|
||||
self.cache = OrderedDict()
|
||||
def __init__(self, max_size=None, expiration_time=3600) -> None:
|
||||
self.cache: OrderedDict = OrderedDict()
|
||||
|
||||
self.lock = asyncio.Lock()
|
||||
self.max_size = max_size
|
||||
|
|
@ -310,7 +310,7 @@ class AsyncInMemoryCache(AsyncBaseCacheService, Generic[AsyncLockType]):
|
|||
await self._delete(key) # Log before deleting the expired item
|
||||
return CACHE_MISS
|
||||
|
||||
async def set(self, key, value, lock: asyncio.Lock | None = None):
|
||||
async def set(self, key, value, lock: asyncio.Lock | None = None) -> None:
|
||||
if not lock:
|
||||
async with self.lock:
|
||||
await self._set(
|
||||
|
|
@ -323,46 +323,46 @@ class AsyncInMemoryCache(AsyncBaseCacheService, Generic[AsyncLockType]):
|
|||
value,
|
||||
)
|
||||
|
||||
async def _set(self, key, value):
|
||||
async def _set(self, key, value) -> None:
|
||||
if self.max_size and len(self.cache) >= self.max_size:
|
||||
self.cache.popitem(last=False)
|
||||
self.cache[key] = {"value": value, "time": time.time()}
|
||||
self.cache.move_to_end(key)
|
||||
|
||||
async def delete(self, key, lock: asyncio.Lock | None = None):
|
||||
async def delete(self, key, lock: asyncio.Lock | None = None) -> None:
|
||||
if not lock:
|
||||
async with self.lock:
|
||||
await self._delete(key)
|
||||
else:
|
||||
await self._delete(key)
|
||||
|
||||
async def _delete(self, key):
|
||||
async def _delete(self, key) -> None:
|
||||
if key in self.cache:
|
||||
del self.cache[key]
|
||||
|
||||
async def clear(self, lock: asyncio.Lock | None = None):
|
||||
async def clear(self, lock: asyncio.Lock | None = None) -> None:
|
||||
if not lock:
|
||||
async with self.lock:
|
||||
await self._clear()
|
||||
else:
|
||||
await self._clear()
|
||||
|
||||
async def _clear(self):
|
||||
async def _clear(self) -> None:
|
||||
self.cache.clear()
|
||||
|
||||
async def upsert(self, key, value, lock: asyncio.Lock | None = None):
|
||||
async def upsert(self, key, value, lock: asyncio.Lock | None = None) -> None:
|
||||
if not lock:
|
||||
async with self.lock:
|
||||
await self._upsert(key, value)
|
||||
else:
|
||||
await self._upsert(key, value)
|
||||
|
||||
async def _upsert(self, key, value):
|
||||
async def _upsert(self, key, value) -> None:
|
||||
existing_value = await self.get(key)
|
||||
if existing_value is not None and isinstance(existing_value, dict) and isinstance(value, dict):
|
||||
existing_value.update(value)
|
||||
value = existing_value
|
||||
await self.set(key, value)
|
||||
|
||||
def __contains__(self, key):
|
||||
def __contains__(self, key) -> bool:
|
||||
return key in self.cache
|
||||
|
|
|
|||
|
|
@ -19,10 +19,10 @@ PREFIX = "langflow_cache"
|
|||
|
||||
|
||||
class CacheMiss:
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return "<CACHE_MISS>"
|
||||
|
||||
def __bool__(self):
|
||||
def __bool__(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
|
|
@ -40,7 +40,7 @@ def create_cache_folder(func):
|
|||
|
||||
|
||||
@create_cache_folder
|
||||
def clear_old_cache_files(max_cache_size: int = 3):
|
||||
def clear_old_cache_files(max_cache_size: int = 3) -> None:
|
||||
cache_dir = Path(tempfile.gettempdir()) / PREFIX
|
||||
cache_files = list(cache_dir.glob("*.dill"))
|
||||
|
||||
|
|
@ -155,7 +155,7 @@ def save_uploaded_file(file: UploadFile, folder_name):
|
|||
return file_path
|
||||
|
||||
|
||||
def update_build_status(cache_service, flow_id: str, status: "BuildStatus"):
|
||||
def update_build_status(cache_service, flow_id: str, status: "BuildStatus") -> None:
|
||||
cached_flow = cache_service[flow_id]
|
||||
if cached_flow is None:
|
||||
msg = f"Flow {flow_id} not found in cache"
|
||||
|
|
|
|||
|
|
@ -11,18 +11,18 @@ from langflow.services.base import Service
|
|||
class Subject:
|
||||
"""Base class for implementing the observer pattern."""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.observers: list[Callable[[], None]] = []
|
||||
|
||||
def attach(self, observer: Callable[[], None]):
|
||||
def attach(self, observer: Callable[[], None]) -> None:
|
||||
"""Attach an observer to the subject."""
|
||||
self.observers.append(observer)
|
||||
|
||||
def detach(self, observer: Callable[[], None]):
|
||||
def detach(self, observer: Callable[[], None]) -> None:
|
||||
"""Detach an observer from the subject."""
|
||||
self.observers.remove(observer)
|
||||
|
||||
def notify(self):
|
||||
def notify(self) -> None:
|
||||
"""Notify all observers about an event."""
|
||||
for observer in self.observers:
|
||||
if observer is None:
|
||||
|
|
@ -33,18 +33,18 @@ class Subject:
|
|||
class AsyncSubject:
|
||||
"""Base class for implementing the async observer pattern."""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.observers: list[Callable[[], Awaitable]] = []
|
||||
|
||||
def attach(self, observer: Callable[[], Awaitable]):
|
||||
def attach(self, observer: Callable[[], Awaitable]) -> None:
|
||||
"""Attach an observer to the subject."""
|
||||
self.observers.append(observer)
|
||||
|
||||
def detach(self, observer: Callable[[], Awaitable]):
|
||||
def detach(self, observer: Callable[[], Awaitable]) -> None:
|
||||
"""Detach an observer from the subject."""
|
||||
self.observers.remove(observer)
|
||||
|
||||
async def notify(self):
|
||||
async def notify(self) -> None:
|
||||
"""Notify all observers about an event."""
|
||||
for observer in self.observers:
|
||||
if observer is None:
|
||||
|
|
@ -57,11 +57,11 @@ class CacheService(Subject, Service):
|
|||
|
||||
name = "cache_service"
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._cache = {}
|
||||
self.current_client_id = None
|
||||
self.current_cache = {}
|
||||
self._cache: dict[str, Any] = {}
|
||||
self.current_client_id: str | None = None
|
||||
self.current_cache: dict[str, Any] = {}
|
||||
|
||||
@contextmanager
|
||||
def set_client_id(self, client_id: str):
|
||||
|
|
@ -77,9 +77,9 @@ class CacheService(Subject, Service):
|
|||
yield
|
||||
finally:
|
||||
self.current_client_id = previous_client_id
|
||||
self.current_cache = self._cache.get(self.current_client_id, {})
|
||||
self.current_cache = self._cache.setdefault(previous_client_id, {}) if previous_client_id else {}
|
||||
|
||||
def add(self, name: str, obj: Any, obj_type: str, extension: str | None = None):
|
||||
def add(self, name: str, obj: Any, obj_type: str, extension: str | None = None) -> None:
|
||||
"""Add an object to the current client's cache.
|
||||
|
||||
Args:
|
||||
|
|
@ -100,7 +100,7 @@ class CacheService(Subject, Service):
|
|||
}
|
||||
self.notify()
|
||||
|
||||
def add_pandas(self, name: str, obj: Any):
|
||||
def add_pandas(self, name: str, obj: Any) -> None:
|
||||
"""Add a pandas DataFrame or Series to the current client's cache.
|
||||
|
||||
Args:
|
||||
|
|
@ -113,7 +113,7 @@ class CacheService(Subject, Service):
|
|||
msg = "Object is not a pandas DataFrame or Series"
|
||||
raise TypeError(msg)
|
||||
|
||||
def add_image(self, name: str, obj: Any, extension: str = "png"):
|
||||
def add_image(self, name: str, obj: Any, extension: str = "png") -> None:
|
||||
"""Add a PIL Image to the current client's cache.
|
||||
|
||||
Args:
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from langflow.services.factory import ServiceFactory
|
|||
|
||||
|
||||
class ChatServiceFactory(ServiceFactory):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(ChatService)
|
||||
|
||||
def create(self):
|
||||
|
|
|
|||
|
|
@ -13,9 +13,9 @@ class ChatService(Service):
|
|||
|
||||
name = "chat_service"
|
||||
|
||||
def __init__(self):
|
||||
self._async_cache_locks = defaultdict(asyncio.Lock)
|
||||
self._sync_cache_locks = defaultdict(RLock)
|
||||
def __init__(self) -> None:
|
||||
self._async_cache_locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||
self._sync_cache_locks: dict[str, RLock] = defaultdict(RLock)
|
||||
self.cache_service = get_cache_service()
|
||||
|
||||
def _get_lock(self, key: str):
|
||||
|
|
@ -101,7 +101,7 @@ class ChatService(Service):
|
|||
"""
|
||||
return await self._perform_cache_operation("get", key, lock=lock or self._get_lock(key))
|
||||
|
||||
async def clear_cache(self, key: str, lock: asyncio.Lock | None = None):
|
||||
async def clear_cache(self, key: str, lock: asyncio.Lock | None = None) -> None:
|
||||
"""Clear the cache for a client.
|
||||
|
||||
Args:
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ if TYPE_CHECKING:
|
|||
|
||||
|
||||
class DatabaseServiceFactory(ServiceFactory):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(DatabaseService)
|
||||
|
||||
def create(self, settings_service: SettingsService):
|
||||
|
|
|
|||
|
|
@ -59,6 +59,6 @@ class ApiKeyRead(ApiKeyBase):
|
|||
|
||||
@field_validator("api_key")
|
||||
@classmethod
|
||||
def mask_api_key(cls, v):
|
||||
def mask_api_key(cls, v) -> str:
|
||||
# This validator will always run, and will mask the API key
|
||||
return f"{v[:8]}{'*' * (len(v) - 8)}"
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ class DatabaseService(Service):
|
|||
self.alembic_cfg_path = langflow_dir / "alembic.ini"
|
||||
self.engine = self._create_engine()
|
||||
|
||||
def reload_engine(self):
|
||||
def reload_engine(self) -> None:
|
||||
self.engine = self._create_engine()
|
||||
|
||||
def _create_engine(self) -> Engine:
|
||||
|
|
@ -82,13 +82,13 @@ class DatabaseService(Service):
|
|||
msg = "Error creating database engine"
|
||||
raise RuntimeError(msg) from exc
|
||||
|
||||
def on_connection(self, dbapi_connection, _connection_record):
|
||||
def on_connection(self, dbapi_connection, _connection_record) -> None:
|
||||
from sqlite3 import Connection as sqliteConnection
|
||||
|
||||
if isinstance(dbapi_connection, sqliteConnection):
|
||||
pragmas: dict | None = self.settings_service.settings.sqlite_pragmas
|
||||
pragmas: dict = self.settings_service.settings.sqlite_pragmas or {}
|
||||
pragmas_list = []
|
||||
for key, val in pragmas.items() or {}:
|
||||
for key, val in pragmas.items():
|
||||
pragmas_list.append(f"PRAGMA {key} = {val}")
|
||||
logger.info(f"sqlite connection, setting pragmas: {pragmas_list}")
|
||||
if pragmas_list:
|
||||
|
|
@ -107,7 +107,7 @@ class DatabaseService(Service):
|
|||
with Session(self.engine) as session:
|
||||
yield session
|
||||
|
||||
def migrate_flows_if_auto_login(self):
|
||||
def migrate_flows_if_auto_login(self) -> None:
|
||||
# if auto_login is enabled, we need to migrate the flows
|
||||
# to the default superuser if they don't have a user id
|
||||
# associated with them
|
||||
|
|
@ -161,14 +161,14 @@ class DatabaseService(Service):
|
|||
|
||||
return True
|
||||
|
||||
def init_alembic(self, alembic_cfg):
|
||||
def init_alembic(self, alembic_cfg) -> None:
|
||||
logger.info("Initializing alembic")
|
||||
command.ensure_version(alembic_cfg)
|
||||
# alembic_cfg.attributes["connection"].commit()
|
||||
command.upgrade(alembic_cfg, "head")
|
||||
logger.info("Alembic initialized")
|
||||
|
||||
def run_migrations(self, *, fix=False):
|
||||
def run_migrations(self, *, fix=False) -> None:
|
||||
# First we need to check if alembic has been initialized
|
||||
# If not, we need to initialize it
|
||||
# if not self.script_location.exists(): # this is not the correct way to check if alembic has been initialized
|
||||
|
|
@ -227,7 +227,7 @@ class DatabaseService(Service):
|
|||
if fix:
|
||||
self.try_downgrade_upgrade_until_success(alembic_cfg)
|
||||
|
||||
def try_downgrade_upgrade_until_success(self, alembic_cfg, retries=5):
|
||||
def try_downgrade_upgrade_until_success(self, alembic_cfg, retries=5) -> None:
|
||||
# Try -1 then head, if it fails, try -2 then head, etc.
|
||||
# until we reach the number of retries
|
||||
for i in range(1, retries + 1):
|
||||
|
|
@ -273,7 +273,7 @@ class DatabaseService(Service):
|
|||
results.append(Result(name=column, type="column", success=True))
|
||||
return results
|
||||
|
||||
def create_db_and_tables(self):
|
||||
def create_db_and_tables(self) -> None:
|
||||
from sqlalchemy import inspect
|
||||
|
||||
inspector = inspect(self.engine)
|
||||
|
|
@ -308,7 +308,7 @@ class DatabaseService(Service):
|
|||
|
||||
logger.debug("Database and tables created successfully")
|
||||
|
||||
async def teardown(self):
|
||||
async def teardown(self) -> None:
|
||||
logger.debug("Tearing down database")
|
||||
try:
|
||||
settings_service = get_settings_service()
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ if TYPE_CHECKING:
|
|||
from langflow.services.database.service import DatabaseService
|
||||
|
||||
|
||||
def initialize_database(*, fix_migration: bool = False):
|
||||
def initialize_database(*, fix_migration: bool = False) -> None:
|
||||
logger.debug("Initializing database")
|
||||
from langflow.services.deps import get_db_service
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ class ServiceFactory:
|
|||
def __init__(
|
||||
self,
|
||||
service_class,
|
||||
):
|
||||
) -> None:
|
||||
self.service_class = service_class
|
||||
self.dependencies = infer_service_types(self, import_all_services_into_a_dict())
|
||||
|
||||
|
|
@ -23,7 +23,7 @@ class ServiceFactory:
|
|||
raise self.service_class(*args, **kwargs)
|
||||
|
||||
|
||||
def hash_factory(factory: type[ServiceFactory]) -> str:
|
||||
def hash_factory(factory: ServiceFactory) -> str:
|
||||
return factory.service_class.__name__
|
||||
|
||||
|
||||
|
|
@ -31,15 +31,15 @@ def hash_dict(d: dict) -> str:
|
|||
return str(d)
|
||||
|
||||
|
||||
def hash_infer_service_types_args(factory_class: type[ServiceFactory], available_services=None) -> str:
|
||||
factory_hash = hash_factory(factory_class)
|
||||
def hash_infer_service_types_args(factory: ServiceFactory, available_services=None) -> str:
|
||||
factory_hash = hash_factory(factory)
|
||||
services_hash = hash_dict(available_services)
|
||||
return f"{factory_hash}_{services_hash}"
|
||||
|
||||
|
||||
@cached(cache=LRUCache(maxsize=10), key=hash_infer_service_types_args)
|
||||
def infer_service_types(factory_class: type[ServiceFactory], available_services=None) -> list["ServiceType"]:
|
||||
create_method = factory_class.create
|
||||
def infer_service_types(factory: ServiceFactory, available_services=None) -> list["ServiceType"]:
|
||||
create_method = factory.create
|
||||
type_hints = get_type_hints(create_method, globalns=available_services)
|
||||
service_types = []
|
||||
for param_name, param_type in type_hints.items():
|
||||
|
|
|
|||
|
|
@ -22,13 +22,13 @@ class NoFactoryRegisteredError(Exception):
|
|||
class ServiceManager:
|
||||
"""Manages the creation of different services."""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.services: dict[str, Service] = {}
|
||||
self.factories = {}
|
||||
self.factories: dict[str, ServiceFactory] = {}
|
||||
self.register_factories()
|
||||
self.keyed_lock = KeyedMemoryLockManager()
|
||||
|
||||
def register_factories(self):
|
||||
def register_factories(self) -> None:
|
||||
for factory in self.get_factories():
|
||||
try:
|
||||
self.register_factory(factory)
|
||||
|
|
@ -38,7 +38,7 @@ class ServiceManager:
|
|||
def register_factory(
|
||||
self,
|
||||
service_factory: ServiceFactory,
|
||||
):
|
||||
) -> None:
|
||||
"""Registers a new factory with dependencies."""
|
||||
service_name = service_factory.service_class.name
|
||||
self.factories[service_name] = service_factory
|
||||
|
|
@ -51,7 +51,7 @@ class ServiceManager:
|
|||
|
||||
return self.services[service_name]
|
||||
|
||||
def _create_service(self, service_name: ServiceType, default: ServiceFactory | None = None):
|
||||
def _create_service(self, service_name: ServiceType, default: ServiceFactory | None = None) -> None:
|
||||
"""Create a new service given its name, handling dependencies."""
|
||||
logger.debug(f"Create service {service_name}")
|
||||
self._validate_service_creation(service_name, default)
|
||||
|
|
@ -61,6 +61,9 @@ class ServiceManager:
|
|||
if factory is None and default is not None:
|
||||
self.register_factory(default)
|
||||
factory = default
|
||||
if factory is None:
|
||||
msg = f"No factory registered for {service_name}"
|
||||
raise NoFactoryRegisteredError(msg)
|
||||
for dependency in factory.dependencies:
|
||||
if dependency not in self.services:
|
||||
self._create_service(dependency)
|
||||
|
|
@ -72,20 +75,20 @@ class ServiceManager:
|
|||
self.services[service_name] = self.factories[service_name].create(**dependent_services)
|
||||
self.services[service_name].set_ready()
|
||||
|
||||
def _validate_service_creation(self, service_name: ServiceType, default: ServiceFactory | None = None):
|
||||
def _validate_service_creation(self, service_name: ServiceType, default: ServiceFactory | None = None) -> None:
|
||||
"""Validate whether the service can be created."""
|
||||
if service_name not in self.factories and default is None:
|
||||
msg = f"No factory registered for the service class '{service_name.name}'"
|
||||
raise NoFactoryRegisteredError(msg)
|
||||
|
||||
def update(self, service_name: ServiceType):
|
||||
def update(self, service_name: ServiceType) -> None:
|
||||
"""Update a service by its name."""
|
||||
if service_name in self.services:
|
||||
logger.debug(f"Update service {service_name}")
|
||||
self.services.pop(service_name, None)
|
||||
self.get(service_name)
|
||||
|
||||
async def teardown(self):
|
||||
async def teardown(self) -> None:
|
||||
"""Teardown all the services."""
|
||||
for service in self.services.values():
|
||||
if service is None:
|
||||
|
|
@ -131,14 +134,14 @@ class ServiceManager:
|
|||
service_manager = ServiceManager()
|
||||
|
||||
|
||||
def initialize_settings_service():
|
||||
def initialize_settings_service() -> None:
|
||||
"""Initialize the settings manager."""
|
||||
from langflow.services.settings import factory as settings_factory
|
||||
|
||||
service_manager.register_factory(settings_factory.SettingsServiceFactory())
|
||||
|
||||
|
||||
def initialize_session_service():
|
||||
def initialize_session_service() -> None:
|
||||
"""Initialize the session manager."""
|
||||
from langflow.services.cache import factory as cache_factory
|
||||
from langflow.services.session import factory as session_service_factory
|
||||
|
|
|
|||
|
|
@ -2,10 +2,10 @@ from typing import Any
|
|||
|
||||
|
||||
class BasePlugin:
|
||||
def initialize(self):
|
||||
def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
def teardown(self):
|
||||
def teardown(self) -> None:
|
||||
pass
|
||||
|
||||
def get(self) -> Any:
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from langflow.services.plugins.service import PluginService
|
|||
|
||||
|
||||
class PluginServiceFactory(ServiceFactory):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(PluginService)
|
||||
|
||||
def create(self):
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ class LangfuseInstance:
|
|||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def create(cls):
|
||||
def create(cls) -> None:
|
||||
try:
|
||||
logger.debug("Creating Langfuse instance")
|
||||
from langfuse import Langfuse
|
||||
|
|
@ -44,13 +44,13 @@ class LangfuseInstance:
|
|||
cls._instance = None
|
||||
|
||||
@classmethod
|
||||
def update(cls):
|
||||
def update(cls) -> None:
|
||||
logger.debug("Updating Langfuse instance")
|
||||
cls._instance = None
|
||||
cls.create()
|
||||
|
||||
@classmethod
|
||||
def teardown(cls):
|
||||
def teardown(cls) -> None:
|
||||
logger.debug("Tearing down Langfuse instance")
|
||||
if cls._instance is not None:
|
||||
cls._instance.flush()
|
||||
|
|
@ -58,10 +58,10 @@ class LangfuseInstance:
|
|||
|
||||
|
||||
class LangfusePlugin(CallbackPlugin):
|
||||
def initialize(self):
|
||||
def initialize(self) -> None:
|
||||
LangfuseInstance.create()
|
||||
|
||||
def teardown(self):
|
||||
def teardown(self) -> None:
|
||||
LangfuseInstance.teardown()
|
||||
|
||||
def get(self):
|
||||
|
|
|
|||
|
|
@ -13,13 +13,13 @@ from langflow.services.plugins.base import BasePlugin, CallbackPlugin
|
|||
class PluginService(Service):
|
||||
name = "plugin_service"
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.plugins: dict[str, BasePlugin] = {}
|
||||
self.plugin_dir = Path(__file__).parent
|
||||
self.plugins_base_module = "langflow.services.plugins"
|
||||
self.load_plugins()
|
||||
|
||||
def load_plugins(self):
|
||||
def load_plugins(self) -> None:
|
||||
base_files = ["base.py", "service.py", "factory.py", "__init__.py"]
|
||||
for module in self.plugin_dir.iterdir():
|
||||
if module.suffix == ".py" and module.name not in base_files:
|
||||
|
|
@ -38,7 +38,7 @@ class PluginService(Service):
|
|||
except Exception: # noqa: BLE001
|
||||
logger.exception(f"Error loading plugin {plugin_name}")
|
||||
|
||||
def register_plugin(self, plugin_name, plugin_instance):
|
||||
def register_plugin(self, plugin_name, plugin_instance) -> None:
|
||||
self.plugins[plugin_name] = plugin_instance
|
||||
plugin_instance.initialize()
|
||||
|
||||
|
|
@ -50,7 +50,7 @@ class PluginService(Service):
|
|||
return plugin.get()
|
||||
return None
|
||||
|
||||
async def teardown(self):
|
||||
async def teardown(self) -> None:
|
||||
for plugin in self.plugins.values():
|
||||
plugin.teardown()
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ if TYPE_CHECKING:
|
|||
|
||||
|
||||
class SessionServiceFactory(ServiceFactory):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(SessionService)
|
||||
|
||||
def create(self, cache_service: "CacheService"):
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ if TYPE_CHECKING:
|
|||
class SessionService(Service):
|
||||
name = "session_service"
|
||||
|
||||
def __init__(self, cache_service):
|
||||
def __init__(self, cache_service) -> None:
|
||||
self.cache_service: CacheService = cache_service
|
||||
|
||||
async def load_session(self, key, flow_id: str, data_graph: dict | None = None):
|
||||
|
|
@ -35,7 +35,7 @@ class SessionService(Service):
|
|||
|
||||
return graph, artifacts
|
||||
|
||||
def build_key(self, session_id, data_graph):
|
||||
def build_key(self, session_id, data_graph) -> str:
|
||||
json_hash = compute_dict_hash(data_graph)
|
||||
return f"{session_id}{':' if session_id else ''}{json_hash}"
|
||||
|
||||
|
|
@ -46,13 +46,13 @@ class SessionService(Service):
|
|||
session_id = session_id_generator()
|
||||
return self.build_key(session_id, data_graph=data_graph)
|
||||
|
||||
async def update_session(self, session_id, value):
|
||||
async def update_session(self, session_id, value) -> None:
|
||||
result = self.cache_service.set(session_id, value)
|
||||
# if it is a coroutine, await it
|
||||
if isinstance(result, Coroutine):
|
||||
await result
|
||||
|
||||
async def clear_session(self, session_id):
|
||||
async def clear_session(self, session_id) -> None:
|
||||
result = self.cache_service.delete(session_id)
|
||||
# if it is a coroutine, await it
|
||||
if isinstance(result, Coroutine):
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ class AuthSettings(BaseSettings):
|
|||
extra = "ignore"
|
||||
env_prefix = "LANGFLOW_"
|
||||
|
||||
def reset_credentials(self):
|
||||
def reset_credentials(self) -> None:
|
||||
self.SUPERUSER = DEFAULT_SUPERUSER
|
||||
self.SUPERUSER_PASSWORD = DEFAULT_SUPERUSER_PASSWORD
|
||||
|
||||
|
|
|
|||
|
|
@ -326,12 +326,12 @@ class Settings(BaseSettings):
|
|||
|
||||
model_config = SettingsConfigDict(validate_assignment=True, extra="ignore", env_prefix="LANGFLOW_")
|
||||
|
||||
def update_from_yaml(self, file_path: str, *, dev: bool = False):
|
||||
def update_from_yaml(self, file_path: str, *, dev: bool = False) -> None:
|
||||
new_settings = load_settings_from_yaml(file_path)
|
||||
self.components_path = new_settings.components_path or []
|
||||
self.dev = dev
|
||||
|
||||
def update_settings(self, **kwargs):
|
||||
def update_settings(self, **kwargs) -> None:
|
||||
logger.debug("Updating settings")
|
||||
for key, value in kwargs.items():
|
||||
# value may contain sensitive information, so we don't want to log it
|
||||
|
|
@ -374,7 +374,7 @@ class Settings(BaseSettings):
|
|||
return (MyCustomSource(settings_cls),)
|
||||
|
||||
|
||||
def save_settings_to_yaml(settings: Settings, file_path: str):
|
||||
def save_settings_to_yaml(settings: Settings, file_path: str) -> None:
|
||||
with Path(file_path).open("w", encoding="utf-8") as f:
|
||||
settings_dict = settings.model_dump()
|
||||
yaml.dump(settings_dict, f)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ class SettingsServiceFactory(ServiceFactory):
|
|||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(SettingsService)
|
||||
|
||||
def create(self):
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from pathlib import Path
|
|||
from loguru import logger
|
||||
|
||||
|
||||
def set_secure_permissions(file_path: Path):
|
||||
def set_secure_permissions(file_path: Path) -> None:
|
||||
if platform.system() in {"Linux", "Darwin"}: # Unix/Linux/Mac
|
||||
file_path.chmod(0o600)
|
||||
elif platform.system() == "Windows":
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ if TYPE_CHECKING:
|
|||
|
||||
|
||||
class SharedComponentCacheServiceFactory(ServiceFactory):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(SharedComponentCacheService)
|
||||
|
||||
def create(self, settings_service: "SettingsService"):
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ if TYPE_CHECKING:
|
|||
|
||||
|
||||
class SocketIOFactory(ServiceFactory):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
service_class=SocketIOService,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ class SocketIOService(Service):
|
|||
def __init__(self, cache_service: "CacheService"):
|
||||
self.cache_service = cache_service
|
||||
|
||||
def init(self, sio: socketio.AsyncServer):
|
||||
def init(self, sio: socketio.AsyncServer) -> None:
|
||||
# Registering event handlers
|
||||
self.sio = sio
|
||||
if self.sio:
|
||||
|
|
@ -28,32 +28,32 @@ class SocketIOService(Service):
|
|||
self.sio.on("build_vertex")(self.on_build_vertex)
|
||||
self.sessions = {} # type: dict[str, dict]
|
||||
|
||||
async def emit_error(self, sid, error):
|
||||
async def emit_error(self, sid, error) -> None:
|
||||
await self.sio.emit("error", to=sid, data=error)
|
||||
|
||||
async def connect(self, sid, environ):
|
||||
async def connect(self, sid, environ) -> None:
|
||||
logger.info(f"Socket connected: {sid}")
|
||||
self.sessions[sid] = environ
|
||||
|
||||
async def disconnect(self, sid):
|
||||
async def disconnect(self, sid) -> None:
|
||||
logger.info(f"Socket disconnected: {sid}")
|
||||
self.sessions.pop(sid, None)
|
||||
|
||||
async def message(self, sid, data=None):
|
||||
async def message(self, sid, data=None) -> None:
|
||||
# Logic for handling messages
|
||||
await self.emit_message(to=sid, data=data or {"foo": "bar", "baz": [1, 2, 3]})
|
||||
|
||||
async def emit_message(self, to, data):
|
||||
async def emit_message(self, to, data) -> None:
|
||||
# Abstracting sio.emit
|
||||
await self.sio.emit("message", to=to, data=data)
|
||||
|
||||
async def emit_token(self, to, data):
|
||||
async def emit_token(self, to, data) -> None:
|
||||
await self.sio.emit("token", to=to, data=data)
|
||||
|
||||
async def on_get_vertices(self, sid, flow_id):
|
||||
async def on_get_vertices(self, sid, flow_id) -> None:
|
||||
await get_vertices(self.sio, sid, flow_id, get_chat_service())
|
||||
|
||||
async def on_build_vertex(self, sid, flow_id, vertex_id):
|
||||
async def on_build_vertex(self, sid, flow_id, vertex_id) -> None:
|
||||
await build_vertex(
|
||||
sio=self.sio,
|
||||
sid=sid,
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue