From 2e194036602871851969946b978cd8eb683a49d1 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Thu, 12 Dec 2024 17:02:43 +0100 Subject: [PATCH] ref: Make list_flows async (#5222) * Make list_flows async * Keep sync list_flows method for backward compatibility --- .../langflow/components/deactivated/list_flows.py | 4 ++-- .../langflow/components/deactivated/sub_flow.py | 14 +++++++------- .../base/langflow/components/logic/flow_tool.py | 14 +++++++------- .../base/langflow/components/logic/run_flow.py | 8 ++++---- .../base/langflow/components/logic/sub_flow.py | 14 +++++++------- .../custom/custom_component/custom_component.py | 7 ++++++- src/backend/base/langflow/helpers/flow.py | 11 +++++------ .../unit/test_custom_component_with_client.py | 8 ++++---- 8 files changed, 42 insertions(+), 38 deletions(-) diff --git a/src/backend/base/langflow/components/deactivated/list_flows.py b/src/backend/base/langflow/components/deactivated/list_flows.py index 5e18e9cfa..70d77d8e4 100644 --- a/src/backend/base/langflow/components/deactivated/list_flows.py +++ b/src/backend/base/langflow/components/deactivated/list_flows.py @@ -12,9 +12,9 @@ class ListFlowsComponent(CustomComponent): def build_config(self): return {} - def build( + async def build( self, ) -> list[Data]: - flows = self.list_flows() + flows = await self.alist_flows() self.status = flows return flows diff --git a/src/backend/base/langflow/components/deactivated/sub_flow.py b/src/backend/base/langflow/components/deactivated/sub_flow.py index d23abcfd2..056eab3a0 100644 --- a/src/backend/base/langflow/components/deactivated/sub_flow.py +++ b/src/backend/base/langflow/components/deactivated/sub_flow.py @@ -24,28 +24,28 @@ class SubFlowComponent(CustomComponent): field_order = ["flow_name"] name = "SubFlow" - def get_flow_names(self) -> list[str]: - flow_datas = self.list_flows() + async def get_flow_names(self) -> list[str]: + flow_datas = await self.alist_flows() return [flow_data.data["name"] for flow_data in flow_datas] - def get_flow(self, flow_name: str) -> Data | None: - flow_datas = self.list_flows() + async def get_flow(self, flow_name: str) -> Data | None: + flow_datas = await self.alist_flows() for flow_data in flow_datas: if flow_data.data["name"] == flow_name: return flow_data return None - def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None): + async def aupdate_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None): logger.debug(f"Updating build config with field value {field_value} and field name {field_name}") if field_name == "flow_name": - build_config["flow_name"]["options"] = self.get_flow_names() + build_config["flow_name"]["options"] = await self.get_flow_names() # Clean up the build config for key in list(build_config.keys()): if key not in {*self.field_order, "code", "_type", "get_final_results_only"}: del build_config[key] if field_value is not None and field_name == "flow_name": try: - flow_data = self.get_flow(field_value) + flow_data = await self.get_flow(field_value) except Exception: # noqa: BLE001 logger.exception(f"Error getting flow {field_value}") else: diff --git a/src/backend/base/langflow/components/logic/flow_tool.py b/src/backend/base/langflow/components/logic/flow_tool.py index a8655335d..a6b425d3f 100644 --- a/src/backend/base/langflow/components/logic/flow_tool.py +++ b/src/backend/base/langflow/components/logic/flow_tool.py @@ -22,11 +22,11 @@ class FlowToolComponent(LCToolComponent): beta = True icon = "hammer" - def get_flow_names(self) -> list[str]: - flow_datas = self.list_flows() + async def get_flow_names(self) -> list[str]: + flow_datas = await self.alist_flows() return [flow_data.data["name"] for flow_data in flow_datas] - def get_flow(self, flow_name: str) -> Data | None: + async def get_flow(self, flow_name: str) -> Data | None: """Retrieves a flow by its name. Args: @@ -35,14 +35,14 @@ class FlowToolComponent(LCToolComponent): Returns: Optional[Text]: The flow record if found, None otherwise. """ - flow_datas = self.list_flows() + flow_datas = await self.alist_flows() for flow_data in flow_datas: if flow_data.data["name"] == flow_name: return flow_data return None @override - def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None): + async def aupdate_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None): if field_name == "flow_name": build_config["flow_name"]["options"] = self.get_flow_names() @@ -74,13 +74,13 @@ class FlowToolComponent(LCToolComponent): Output(name="api_build_tool", display_name="Tool", method="build_tool"), ] - def build_tool(self) -> Tool: + async def build_tool(self) -> Tool: FlowTool.model_rebuild() if "flow_name" not in self._attributes or not self._attributes["flow_name"]: msg = "Flow name is required" raise ValueError(msg) flow_name = self._attributes["flow_name"] - flow_data = self.get_flow(flow_name) + flow_data = await self.get_flow(flow_name) if not flow_data: msg = "Flow not found." raise ValueError(msg) diff --git a/src/backend/base/langflow/components/logic/run_flow.py b/src/backend/base/langflow/components/logic/run_flow.py index 2fac48610..c452d5de1 100644 --- a/src/backend/base/langflow/components/logic/run_flow.py +++ b/src/backend/base/langflow/components/logic/run_flow.py @@ -18,14 +18,14 @@ class RunFlowComponent(Component): legacy: bool = True icon = "workflow" - def get_flow_names(self) -> list[str]: - flow_data = self.list_flows() + async def get_flow_names(self) -> list[str]: + flow_data = await self.alist_flows() return [flow_data.data["name"] for flow_data in flow_data] @override - def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None): + async def aupdate_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None): if field_name == "flow_name": - build_config["flow_name"]["options"] = self.get_flow_names() + build_config["flow_name"]["options"] = await self.get_flow_names() return build_config diff --git a/src/backend/base/langflow/components/logic/sub_flow.py b/src/backend/base/langflow/components/logic/sub_flow.py index 2fb0fe958..eef7529fe 100644 --- a/src/backend/base/langflow/components/logic/sub_flow.py +++ b/src/backend/base/langflow/components/logic/sub_flow.py @@ -18,27 +18,27 @@ class SubFlowComponent(Component): beta: bool = True icon = "Workflow" - def get_flow_names(self) -> list[str]: - flow_data = self.list_flows() + async def get_flow_names(self) -> list[str]: + flow_data = await self.alist_flows() return [flow_data.data["name"] for flow_data in flow_data] - def get_flow(self, flow_name: str) -> Data | None: - flow_datas = self.list_flows() + async def get_flow(self, flow_name: str) -> Data | None: + flow_datas = await self.alist_flows() for flow_data in flow_datas: if flow_data.data["name"] == flow_name: return flow_data return None - def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None): + async def aupdate_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None): if field_name == "flow_name": - build_config["flow_name"]["options"] = self.get_flow_names() + build_config["flow_name"]["options"] = await self.get_flow_names() for key in list(build_config.keys()): if key not in [x.name for x in self.inputs] + ["code", "_type", "get_final_results_only"]: del build_config[key] if field_value is not None and field_name == "flow_name": try: - flow_data = self.get_flow(field_value) + flow_data = await self.get_flow(field_value) except Exception: # noqa: BLE001 logger.exception(f"Error getting flow {field_value}") else: diff --git a/src/backend/base/langflow/custom/custom_component/custom_component.py b/src/backend/base/langflow/custom/custom_component/custom_component.py index ed3584693..0cc96e02d 100644 --- a/src/backend/base/langflow/custom/custom_component/custom_component.py +++ b/src/backend/base/langflow/custom/custom_component/custom_component.py @@ -19,6 +19,7 @@ from langflow.services.storage.service import StorageService from langflow.template.utils import update_frontend_node_with_template_values from langflow.type_extraction.type_extraction import post_process_type from langflow.utils import validate +from langflow.utils.async_helpers import run_until_complete if TYPE_CHECKING: from langchain.callbacks.base import BaseCallbackHandler @@ -509,11 +510,15 @@ class CustomComponent(BaseComponent): ) def list_flows(self) -> list[Data]: + """This is kept for backward compatibility. Using alist_flows instead is recommended.""" + return run_until_complete(self.alist_flows()) + + async def alist_flows(self) -> list[Data]: if not self.user_id: msg = "Session is invalid" raise ValueError(msg) try: - return list_flows(user_id=str(self.user_id)) + return await list_flows(user_id=str(self.user_id)) except Exception as e: msg = f"Error listing flows: {e}" raise ValueError(msg) from e diff --git a/src/backend/base/langflow/helpers/flow.py b/src/backend/base/langflow/helpers/flow.py index d3b3e5ffa..1e379c1c6 100644 --- a/src/backend/base/langflow/helpers/flow.py +++ b/src/backend/base/langflow/helpers/flow.py @@ -10,7 +10,7 @@ from sqlmodel import select from langflow.schema.schema import INPUT_FIELD_NAME from langflow.services.database.models.flow import Flow from langflow.services.database.models.flow.model import FlowRead -from langflow.services.deps import async_session_scope, get_settings_service, session_scope +from langflow.services.deps import async_session_scope, get_settings_service if TYPE_CHECKING: from collections.abc import Awaitable, Callable @@ -27,16 +27,15 @@ INPUT_TYPE_MAP = { } -def list_flows(*, user_id: str | None = None) -> list[Data]: +async def list_flows(*, user_id: str | None = None) -> list[Data]: if not user_id: msg = "Session is invalid" raise ValueError(msg) try: - with session_scope() as session: + async with async_session_scope() as session: uuid_user_id = UUID(user_id) if isinstance(user_id, str) else user_id - flows = session.exec( - select(Flow).where(Flow.user_id == uuid_user_id).where(Flow.is_component == False) # noqa: E712 - ).all() + stmt = select(Flow).where(Flow.user_id == uuid_user_id).where(Flow.is_component == False) # noqa: E712 + flows = (await session.exec(stmt)).all() return [flow.to_data() for flow in flows] except Exception as e: diff --git a/src/backend/tests/unit/test_custom_component_with_client.py b/src/backend/tests/unit/test_custom_component_with_client.py index 736c4913f..30d639aca 100644 --- a/src/backend/tests/unit/test_custom_component_with_client.py +++ b/src/backend/tests/unit/test_custom_component_with_client.py @@ -20,13 +20,13 @@ def component( ) -def test_list_flows_flow_objects(component): - flows = component.list_flows() +async def test_list_flows_flow_objects(component): + flows = await component.alist_flows() are_flows = [isinstance(flow, Data) for flow in flows] flow_types = [type(flow) for flow in flows] assert all(are_flows), f"Expected all flows to be Data objects, got {flow_types}" -def test_list_flows_return_type(component): - flows = component.list_flows() +async def test_list_flows_return_type(component): + flows = await component.alist_flows() assert isinstance(flows, list)