Refactor RunFlowComponent and CustomComponent classes

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-04 10:30:41 -03:00
commit 0094732218
2 changed files with 27 additions and 13 deletions

View file

@ -12,7 +12,6 @@ class RunFlowComponent(CustomComponent):
def get_flow_names(self) -> List[str]:
flow_records = self.list_flows()
self.flow_records = flow_records
return [flow_record.data["name"] for flow_record in flow_records]
def build_config(self):
@ -21,7 +20,7 @@ class RunFlowComponent(CustomComponent):
"display_name": "Input Value",
"multiline": True,
},
"flow": {
"flow_name": {
"display_name": "Flow ID",
"info": "The ID of the flow to run.",
"options": self.get_flow_names,
@ -32,19 +31,13 @@ class RunFlowComponent(CustomComponent):
},
}
async def build(self, input_value: Text, flow: str, tweaks: NestedDict) -> Record:
async def build(
self, input_value: Text, flow_name: str, tweaks: NestedDict
) -> Record:
input_dict = {"input_value": input_value}
flow_ids = [
flow_record.data["id"]
for flow_record in self.flow_records
if flow_record.data["name"] == flow
]
if not flow_ids:
raise ValueError(f"Flow {flow} not found.")
flow_id = flow_ids[0]
result: List[Optional[ResultData]] = await self.run_flow(
input_value=input_dict, flow_id=flow_id, tweaks=tweaks
input_value=input_dict, flow_name=flow_name, tweaks=tweaks
)
record = Record(data=result)
self.status = record

View file

@ -74,6 +74,7 @@ class CustomComponent(Component):
user_id: Optional[Union[UUID, str]] = None
status: Optional[Any] = None
"""The status of the component. This is displayed on the frontend. Defaults to None."""
_flows_records: Optional[List[Record]] = None
_tree: Optional[dict] = None
@ -334,9 +335,28 @@ class CustomComponent(Component):
async def run_flow(
self,
input_value: Union[str, list[str]],
flow_id: str,
flow_id: Optional[str] = None,
flow_name: Optional[str] = None,
tweaks: Optional[dict] = None,
) -> Any:
if not flow_id and not flow_name:
raise ValueError("Flow ID or Flow Name is required")
if not flow_id and self._flows_records:
flow_ids = [
flow.data["id"]
for flow in self._flows_records
if flow.data["name"] == flow_name
]
if not flow_ids:
raise ValueError(f"Flow {flow_name} not found")
elif len(flow_ids) > 1:
raise ValueError(f"Multiple flows found with the name {flow_name}")
flow_id = flow_ids[0]
if not flow_id:
raise ValueError(f"Flow {flow_name} not found")
graph = await self.load_flow(flow_id, tweaks)
input_value_dict = {"input_value": input_value}
return await graph.run(input_value_dict, stream=False)
@ -355,6 +375,7 @@ class CustomComponent(Component):
).all()
flows_records = [flow.to_record() for flow in flows]
self._flows_records = flows_records
return flows_records
except Exception as e:
raise ValueError(f"Error listing flows: {e}")