ref: Make list_flows async (#5222)

* Make list_flows async

* Keep sync list_flows method for backward compatibility
This commit is contained in:
Christophe Bornet 2024-12-12 17:02:43 +01:00 committed by GitHub
commit 2e19403660
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 42 additions and 38 deletions

View file

@ -12,9 +12,9 @@ class ListFlowsComponent(CustomComponent):
def build_config(self):
return {}
def build(
async def build(
self,
) -> list[Data]:
flows = self.list_flows()
flows = await self.alist_flows()
self.status = flows
return flows

View file

@ -24,28 +24,28 @@ class SubFlowComponent(CustomComponent):
field_order = ["flow_name"]
name = "SubFlow"
def get_flow_names(self) -> list[str]:
flow_datas = self.list_flows()
async def get_flow_names(self) -> list[str]:
flow_datas = await self.alist_flows()
return [flow_data.data["name"] for flow_data in flow_datas]
def get_flow(self, flow_name: str) -> Data | None:
flow_datas = self.list_flows()
async def get_flow(self, flow_name: str) -> Data | None:
flow_datas = await self.alist_flows()
for flow_data in flow_datas:
if flow_data.data["name"] == flow_name:
return flow_data
return None
def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
async def aupdate_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
logger.debug(f"Updating build config with field value {field_value} and field name {field_name}")
if field_name == "flow_name":
build_config["flow_name"]["options"] = self.get_flow_names()
build_config["flow_name"]["options"] = await self.get_flow_names()
# Clean up the build config
for key in list(build_config.keys()):
if key not in {*self.field_order, "code", "_type", "get_final_results_only"}:
del build_config[key]
if field_value is not None and field_name == "flow_name":
try:
flow_data = self.get_flow(field_value)
flow_data = await self.get_flow(field_value)
except Exception: # noqa: BLE001
logger.exception(f"Error getting flow {field_value}")
else:

View file

@ -22,11 +22,11 @@ class FlowToolComponent(LCToolComponent):
beta = True
icon = "hammer"
def get_flow_names(self) -> list[str]:
flow_datas = self.list_flows()
async def get_flow_names(self) -> list[str]:
flow_datas = await self.alist_flows()
return [flow_data.data["name"] for flow_data in flow_datas]
def get_flow(self, flow_name: str) -> Data | None:
async def get_flow(self, flow_name: str) -> Data | None:
"""Retrieves a flow by its name.
Args:
@ -35,14 +35,14 @@ class FlowToolComponent(LCToolComponent):
Returns:
Optional[Text]: The flow record if found, None otherwise.
"""
flow_datas = self.list_flows()
flow_datas = await self.alist_flows()
for flow_data in flow_datas:
if flow_data.data["name"] == flow_name:
return flow_data
return None
@override
def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
async def aupdate_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
if field_name == "flow_name":
build_config["flow_name"]["options"] = self.get_flow_names()
@ -74,13 +74,13 @@ class FlowToolComponent(LCToolComponent):
Output(name="api_build_tool", display_name="Tool", method="build_tool"),
]
def build_tool(self) -> Tool:
async def build_tool(self) -> Tool:
FlowTool.model_rebuild()
if "flow_name" not in self._attributes or not self._attributes["flow_name"]:
msg = "Flow name is required"
raise ValueError(msg)
flow_name = self._attributes["flow_name"]
flow_data = self.get_flow(flow_name)
flow_data = await self.get_flow(flow_name)
if not flow_data:
msg = "Flow not found."
raise ValueError(msg)

View file

@ -18,14 +18,14 @@ class RunFlowComponent(Component):
legacy: bool = True
icon = "workflow"
def get_flow_names(self) -> list[str]:
flow_data = self.list_flows()
async def get_flow_names(self) -> list[str]:
flow_data = await self.alist_flows()
return [flow_data.data["name"] for flow_data in flow_data]
@override
def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
async def aupdate_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
if field_name == "flow_name":
build_config["flow_name"]["options"] = self.get_flow_names()
build_config["flow_name"]["options"] = await self.get_flow_names()
return build_config

View file

@ -18,27 +18,27 @@ class SubFlowComponent(Component):
beta: bool = True
icon = "Workflow"
def get_flow_names(self) -> list[str]:
flow_data = self.list_flows()
async def get_flow_names(self) -> list[str]:
flow_data = await self.alist_flows()
return [flow_data.data["name"] for flow_data in flow_data]
def get_flow(self, flow_name: str) -> Data | None:
flow_datas = self.list_flows()
async def get_flow(self, flow_name: str) -> Data | None:
flow_datas = await self.alist_flows()
for flow_data in flow_datas:
if flow_data.data["name"] == flow_name:
return flow_data
return None
def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
async def aupdate_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
if field_name == "flow_name":
build_config["flow_name"]["options"] = self.get_flow_names()
build_config["flow_name"]["options"] = await self.get_flow_names()
for key in list(build_config.keys()):
if key not in [x.name for x in self.inputs] + ["code", "_type", "get_final_results_only"]:
del build_config[key]
if field_value is not None and field_name == "flow_name":
try:
flow_data = self.get_flow(field_value)
flow_data = await self.get_flow(field_value)
except Exception: # noqa: BLE001
logger.exception(f"Error getting flow {field_value}")
else:

View file

@ -19,6 +19,7 @@ from langflow.services.storage.service import StorageService
from langflow.template.utils import update_frontend_node_with_template_values
from langflow.type_extraction.type_extraction import post_process_type
from langflow.utils import validate
from langflow.utils.async_helpers import run_until_complete
if TYPE_CHECKING:
from langchain.callbacks.base import BaseCallbackHandler
@ -509,11 +510,15 @@ class CustomComponent(BaseComponent):
)
def list_flows(self) -> list[Data]:
"""This is kept for backward compatibility. Using alist_flows instead is recommended."""
return run_until_complete(self.alist_flows())
async def alist_flows(self) -> list[Data]:
if not self.user_id:
msg = "Session is invalid"
raise ValueError(msg)
try:
return list_flows(user_id=str(self.user_id))
return await list_flows(user_id=str(self.user_id))
except Exception as e:
msg = f"Error listing flows: {e}"
raise ValueError(msg) from e

View file

@ -10,7 +10,7 @@ from sqlmodel import select
from langflow.schema.schema import INPUT_FIELD_NAME
from langflow.services.database.models.flow import Flow
from langflow.services.database.models.flow.model import FlowRead
from langflow.services.deps import async_session_scope, get_settings_service, session_scope
from langflow.services.deps import async_session_scope, get_settings_service
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
@ -27,16 +27,15 @@ INPUT_TYPE_MAP = {
}
def list_flows(*, user_id: str | None = None) -> list[Data]:
async def list_flows(*, user_id: str | None = None) -> list[Data]:
if not user_id:
msg = "Session is invalid"
raise ValueError(msg)
try:
with session_scope() as session:
async with async_session_scope() as session:
uuid_user_id = UUID(user_id) if isinstance(user_id, str) else user_id
flows = session.exec(
select(Flow).where(Flow.user_id == uuid_user_id).where(Flow.is_component == False) # noqa: E712
).all()
stmt = select(Flow).where(Flow.user_id == uuid_user_id).where(Flow.is_component == False) # noqa: E712
flows = (await session.exec(stmt)).all()
return [flow.to_data() for flow in flows]
except Exception as e:

View file

@ -20,13 +20,13 @@ def component(
)
def test_list_flows_flow_objects(component):
flows = component.list_flows()
async def test_list_flows_flow_objects(component):
flows = await component.alist_flows()
are_flows = [isinstance(flow, Data) for flow in flows]
flow_types = [type(flow) for flow in flows]
assert all(are_flows), f"Expected all flows to be Data objects, got {flow_types}"
def test_list_flows_return_type(component):
flows = component.list_flows()
async def test_list_flows_return_type(component):
flows = await component.alist_flows()
assert isinstance(flows, list)