fix: Add session_id and graph parameters to flow execution functions and update graph init in Flow as Tool (#4558)

* Add session_id and graph parameters to flow function and update graph initialization logic

* Enhance `FlowTool` with user and session context in graph creation

* Add session_id and graph parameters to flow execution functions
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-11-13 00:45:38 -03:00 committed by GitHub
commit f978223744
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 19 additions and 3 deletions

View file

@ -23,6 +23,7 @@ class FlowTool(BaseTool):
graph: Graph | None = None
flow_id: str | None = None
user_id: str | None = None
session_id: str | None = None
inputs: list[Vertex] = []
get_final_results_only: bool = True
@ -59,9 +60,11 @@ class FlowTool(BaseTool):
run_outputs = run_until_complete(
run_flow(
graph=self.graph,
tweaks={key: {"input_value": value} for key, value in tweaks.items()},
flow_id=self.flow_id,
user_id=self.user_id,
session_id=self.session_id,
)
)
if not run_outputs:
@ -113,6 +116,8 @@ class FlowTool(BaseTool):
flow_id=self.flow_id,
user_id=self.user_id,
run_id=run_id,
session_id=self.session_id,
graph=self.graph,
)
if not run_outputs:
return "No output"

View file

@ -75,7 +75,7 @@ class FlowToolComponent(LCToolComponent):
]
def build_tool(self) -> Tool:
FlowTool.update_forward_refs()
FlowTool.model_rebuild()
if "flow_name" not in self._attributes or not self._attributes["flow_name"]:
msg = "Flow name is required"
raise ValueError(msg)
@ -84,7 +84,10 @@ class FlowToolComponent(LCToolComponent):
if not flow_data:
msg = "Flow not found."
raise ValueError(msg)
graph = Graph.from_payload(flow_data.data["data"])
graph = Graph.from_payload(
flow_data.data["data"],
user_id=str(self.user_id),
)
try:
graph.set_run_id(self.graph.run_id)
except Exception: # noqa: BLE001
@ -98,6 +101,7 @@ class FlowToolComponent(LCToolComponent):
inputs=inputs,
flow_id=str(flow_data.id),
user_id=str(self.user_id),
session_id=self.graph.session_id if hasattr(self, "graph") else None,
)
description_repr = repr(tool.description).strip("'")
args_str = "\n".join([f"- {arg_name}: {arg_data['description']}" for arg_name, arg_data in tool.args.items()])

View file

@ -82,13 +82,20 @@ async def run_flow(
output_type: str | None = "chat",
user_id: str | None = None,
run_id: str | None = None,
session_id: str | None = None,
graph: Graph | None = None,
) -> list[RunOutputs]:
if user_id is None:
msg = "Session is invalid"
raise ValueError(msg)
graph = await load_flow(user_id, flow_id, flow_name, tweaks)
if graph is None:
graph = await load_flow(user_id, flow_id, flow_name, tweaks)
if run_id:
graph.set_run_id(UUID(run_id))
if session_id:
graph.session_id = session_id
if user_id:
graph.user_id = user_id
if inputs is None:
inputs = []