ref: Make list_flows async (#5222)
* Make list_flows async * Keep sync list_flows method for backward compatibility
This commit is contained in:
parent
d73bf214a8
commit
2e19403660
8 changed files with 42 additions and 38 deletions
|
|
@ -12,9 +12,9 @@ class ListFlowsComponent(CustomComponent):
|
|||
def build_config(self):
|
||||
return {}
|
||||
|
||||
def build(
|
||||
async def build(
|
||||
self,
|
||||
) -> list[Data]:
|
||||
flows = self.list_flows()
|
||||
flows = await self.alist_flows()
|
||||
self.status = flows
|
||||
return flows
|
||||
|
|
|
|||
|
|
@ -24,28 +24,28 @@ class SubFlowComponent(CustomComponent):
|
|||
field_order = ["flow_name"]
|
||||
name = "SubFlow"
|
||||
|
||||
def get_flow_names(self) -> list[str]:
|
||||
flow_datas = self.list_flows()
|
||||
async def get_flow_names(self) -> list[str]:
|
||||
flow_datas = await self.alist_flows()
|
||||
return [flow_data.data["name"] for flow_data in flow_datas]
|
||||
|
||||
def get_flow(self, flow_name: str) -> Data | None:
|
||||
flow_datas = self.list_flows()
|
||||
async def get_flow(self, flow_name: str) -> Data | None:
|
||||
flow_datas = await self.alist_flows()
|
||||
for flow_data in flow_datas:
|
||||
if flow_data.data["name"] == flow_name:
|
||||
return flow_data
|
||||
return None
|
||||
|
||||
def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
|
||||
async def aupdate_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
|
||||
logger.debug(f"Updating build config with field value {field_value} and field name {field_name}")
|
||||
if field_name == "flow_name":
|
||||
build_config["flow_name"]["options"] = self.get_flow_names()
|
||||
build_config["flow_name"]["options"] = await self.get_flow_names()
|
||||
# Clean up the build config
|
||||
for key in list(build_config.keys()):
|
||||
if key not in {*self.field_order, "code", "_type", "get_final_results_only"}:
|
||||
del build_config[key]
|
||||
if field_value is not None and field_name == "flow_name":
|
||||
try:
|
||||
flow_data = self.get_flow(field_value)
|
||||
flow_data = await self.get_flow(field_value)
|
||||
except Exception: # noqa: BLE001
|
||||
logger.exception(f"Error getting flow {field_value}")
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -22,11 +22,11 @@ class FlowToolComponent(LCToolComponent):
|
|||
beta = True
|
||||
icon = "hammer"
|
||||
|
||||
def get_flow_names(self) -> list[str]:
|
||||
flow_datas = self.list_flows()
|
||||
async def get_flow_names(self) -> list[str]:
|
||||
flow_datas = await self.alist_flows()
|
||||
return [flow_data.data["name"] for flow_data in flow_datas]
|
||||
|
||||
def get_flow(self, flow_name: str) -> Data | None:
|
||||
async def get_flow(self, flow_name: str) -> Data | None:
|
||||
"""Retrieves a flow by its name.
|
||||
|
||||
Args:
|
||||
|
|
@ -35,14 +35,14 @@ class FlowToolComponent(LCToolComponent):
|
|||
Returns:
|
||||
Optional[Text]: The flow record if found, None otherwise.
|
||||
"""
|
||||
flow_datas = self.list_flows()
|
||||
flow_datas = await self.alist_flows()
|
||||
for flow_data in flow_datas:
|
||||
if flow_data.data["name"] == flow_name:
|
||||
return flow_data
|
||||
return None
|
||||
|
||||
@override
|
||||
def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
|
||||
async def aupdate_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
|
||||
if field_name == "flow_name":
|
||||
build_config["flow_name"]["options"] = self.get_flow_names()
|
||||
|
||||
|
|
@ -74,13 +74,13 @@ class FlowToolComponent(LCToolComponent):
|
|||
Output(name="api_build_tool", display_name="Tool", method="build_tool"),
|
||||
]
|
||||
|
||||
def build_tool(self) -> Tool:
|
||||
async def build_tool(self) -> Tool:
|
||||
FlowTool.model_rebuild()
|
||||
if "flow_name" not in self._attributes or not self._attributes["flow_name"]:
|
||||
msg = "Flow name is required"
|
||||
raise ValueError(msg)
|
||||
flow_name = self._attributes["flow_name"]
|
||||
flow_data = self.get_flow(flow_name)
|
||||
flow_data = await self.get_flow(flow_name)
|
||||
if not flow_data:
|
||||
msg = "Flow not found."
|
||||
raise ValueError(msg)
|
||||
|
|
|
|||
|
|
@ -18,14 +18,14 @@ class RunFlowComponent(Component):
|
|||
legacy: bool = True
|
||||
icon = "workflow"
|
||||
|
||||
def get_flow_names(self) -> list[str]:
|
||||
flow_data = self.list_flows()
|
||||
async def get_flow_names(self) -> list[str]:
|
||||
flow_data = await self.alist_flows()
|
||||
return [flow_data.data["name"] for flow_data in flow_data]
|
||||
|
||||
@override
|
||||
def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
|
||||
async def aupdate_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
|
||||
if field_name == "flow_name":
|
||||
build_config["flow_name"]["options"] = self.get_flow_names()
|
||||
build_config["flow_name"]["options"] = await self.get_flow_names()
|
||||
|
||||
return build_config
|
||||
|
||||
|
|
|
|||
|
|
@ -18,27 +18,27 @@ class SubFlowComponent(Component):
|
|||
beta: bool = True
|
||||
icon = "Workflow"
|
||||
|
||||
def get_flow_names(self) -> list[str]:
|
||||
flow_data = self.list_flows()
|
||||
async def get_flow_names(self) -> list[str]:
|
||||
flow_data = await self.alist_flows()
|
||||
return [flow_data.data["name"] for flow_data in flow_data]
|
||||
|
||||
def get_flow(self, flow_name: str) -> Data | None:
|
||||
flow_datas = self.list_flows()
|
||||
async def get_flow(self, flow_name: str) -> Data | None:
|
||||
flow_datas = await self.alist_flows()
|
||||
for flow_data in flow_datas:
|
||||
if flow_data.data["name"] == flow_name:
|
||||
return flow_data
|
||||
return None
|
||||
|
||||
def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
|
||||
async def aupdate_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
|
||||
if field_name == "flow_name":
|
||||
build_config["flow_name"]["options"] = self.get_flow_names()
|
||||
build_config["flow_name"]["options"] = await self.get_flow_names()
|
||||
|
||||
for key in list(build_config.keys()):
|
||||
if key not in [x.name for x in self.inputs] + ["code", "_type", "get_final_results_only"]:
|
||||
del build_config[key]
|
||||
if field_value is not None and field_name == "flow_name":
|
||||
try:
|
||||
flow_data = self.get_flow(field_value)
|
||||
flow_data = await self.get_flow(field_value)
|
||||
except Exception: # noqa: BLE001
|
||||
logger.exception(f"Error getting flow {field_value}")
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from langflow.services.storage.service import StorageService
|
|||
from langflow.template.utils import update_frontend_node_with_template_values
|
||||
from langflow.type_extraction.type_extraction import post_process_type
|
||||
from langflow.utils import validate
|
||||
from langflow.utils.async_helpers import run_until_complete
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
|
|
@ -509,11 +510,15 @@ class CustomComponent(BaseComponent):
|
|||
)
|
||||
|
||||
def list_flows(self) -> list[Data]:
|
||||
"""This is kept for backward compatibility. Using alist_flows instead is recommended."""
|
||||
return run_until_complete(self.alist_flows())
|
||||
|
||||
async def alist_flows(self) -> list[Data]:
|
||||
if not self.user_id:
|
||||
msg = "Session is invalid"
|
||||
raise ValueError(msg)
|
||||
try:
|
||||
return list_flows(user_id=str(self.user_id))
|
||||
return await list_flows(user_id=str(self.user_id))
|
||||
except Exception as e:
|
||||
msg = f"Error listing flows: {e}"
|
||||
raise ValueError(msg) from e
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from sqlmodel import select
|
|||
from langflow.schema.schema import INPUT_FIELD_NAME
|
||||
from langflow.services.database.models.flow import Flow
|
||||
from langflow.services.database.models.flow.model import FlowRead
|
||||
from langflow.services.deps import async_session_scope, get_settings_service, session_scope
|
||||
from langflow.services.deps import async_session_scope, get_settings_service
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
|
@ -27,16 +27,15 @@ INPUT_TYPE_MAP = {
|
|||
}
|
||||
|
||||
|
||||
def list_flows(*, user_id: str | None = None) -> list[Data]:
|
||||
async def list_flows(*, user_id: str | None = None) -> list[Data]:
|
||||
if not user_id:
|
||||
msg = "Session is invalid"
|
||||
raise ValueError(msg)
|
||||
try:
|
||||
with session_scope() as session:
|
||||
async with async_session_scope() as session:
|
||||
uuid_user_id = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
flows = session.exec(
|
||||
select(Flow).where(Flow.user_id == uuid_user_id).where(Flow.is_component == False) # noqa: E712
|
||||
).all()
|
||||
stmt = select(Flow).where(Flow.user_id == uuid_user_id).where(Flow.is_component == False) # noqa: E712
|
||||
flows = (await session.exec(stmt)).all()
|
||||
|
||||
return [flow.to_data() for flow in flows]
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -20,13 +20,13 @@ def component(
|
|||
)
|
||||
|
||||
|
||||
def test_list_flows_flow_objects(component):
|
||||
flows = component.list_flows()
|
||||
async def test_list_flows_flow_objects(component):
|
||||
flows = await component.alist_flows()
|
||||
are_flows = [isinstance(flow, Data) for flow in flows]
|
||||
flow_types = [type(flow) for flow in flows]
|
||||
assert all(are_flows), f"Expected all flows to be Data objects, got {flow_types}"
|
||||
|
||||
|
||||
def test_list_flows_return_type(component):
|
||||
flows = component.list_flows()
|
||||
async def test_list_flows_return_type(component):
|
||||
flows = await component.alist_flows()
|
||||
assert isinstance(flows, list)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue