From b912a71e02811cc658ab6d2475f62988e4d9b86d Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 28 May 2024 07:18:32 -0700 Subject: [PATCH] Fixed Sub Flow, Run Flow and Flow as Tool components (#1986) * feat(langflow): add utility functions to build records from run outputs and result data for better code organization and reusability * chore: Generate dynamic flow function with user ID parameter for better flow customization and tracking * chore: Refactor build_records_from_run_outputs and build_records_from_result_data for better code organization and reusability * chore: Update FlowToolComponent to include user ID parameter in build_function_and_schema method call * chore: Add conditional check for result_data in build_records_from_run_outputs * chore: Generate dynamic flow function with optional user ID parameter for better flow customization and tracking * feat: Add user ID parameter to Graph.from_payload method * chore: Add FlowTool class for flow processing and customization * chore: Update FlowToolComponent to use get_flow_inputs instead of build_function_and_schema * chore: Update FlowTool to handle optional user ID parameter --- .../langflow/base/flow_processing/__init__.py | 0 .../langflow/base/flow_processing/utils.py | 67 ++++++++++ .../base/langflow/base/tools/flow_tool.py | 117 ++++++++++++++++++ .../components/experimental/FlowTool.py | 20 +-- .../components/experimental/RunFlow.py | 20 +-- .../components/experimental/SubFlow.py | 20 +-- src/backend/base/langflow/helpers/flow.py | 53 ++++++-- 7 files changed, 247 insertions(+), 50 deletions(-) create mode 100644 src/backend/base/langflow/base/flow_processing/__init__.py create mode 100644 src/backend/base/langflow/base/flow_processing/utils.py create mode 100644 src/backend/base/langflow/base/tools/flow_tool.py diff --git a/src/backend/base/langflow/base/flow_processing/__init__.py b/src/backend/base/langflow/base/flow_processing/__init__.py new file mode 100644 index 
000000000..e69de29bb diff --git a/src/backend/base/langflow/base/flow_processing/utils.py b/src/backend/base/langflow/base/flow_processing/utils.py new file mode 100644 index 000000000..4e121f128 --- /dev/null +++ b/src/backend/base/langflow/base/flow_processing/utils.py @@ -0,0 +1,67 @@ +from typing import List + +from langflow.graph.schema import ResultData, RunOutputs +from langflow.schema.schema import Record + + +def build_records_from_run_outputs(run_outputs: RunOutputs) -> List[Record]: + """ + Build a list of records from the given RunOutputs. + + Args: + run_outputs (RunOutputs): The RunOutputs object containing the output data. + + Returns: + List[Record]: A list of records built from the RunOutputs. + + """ + if not run_outputs: + return [] + records = [] + for result_data in run_outputs.outputs: + if result_data: + records.extend(build_records_from_result_data(result_data)) + return records + + +def build_records_from_result_data(result_data: ResultData, get_final_results_only: bool = True) -> List[Record]: + """ + Build a list of records from the given ResultData. + + Args: + result_data (ResultData): The ResultData object containing the result data. + get_final_results_only (bool, optional): Whether to include only final results. Defaults to True. + + Returns: + List[Record]: A list of records built from the ResultData. + + """ + messages = result_data.messages + if not messages: + return [] + records = [] + for message in messages: + message_dict = message if isinstance(message, dict) else message.model_dump() + if get_final_results_only: + result_data_dict = result_data.model_dump() + results = result_data_dict.get("results", {}) + inner_result = results.get("result", {}) + record = Record(data={"result": inner_result, "message": message_dict}, text_key="result") + records.append(record) + return records + + +def format_flow_output_records(records: List[Record]) -> str: + """ + Format the flow output records into a string. 
+ + Args: + records (List[Record]): The list of records to format. + + Returns: + str: The formatted flow output records. + + """ + result = "Flow run output:\n" + results = "\n".join([record.result for record in records if record.data["message"]]) + return result + results diff --git a/src/backend/base/langflow/base/tools/flow_tool.py b/src/backend/base/langflow/base/tools/flow_tool.py new file mode 100644 index 000000000..d0993bd99 --- /dev/null +++ b/src/backend/base/langflow/base/tools/flow_tool.py @@ -0,0 +1,117 @@ +from typing import Any, List, Optional, Type + +from asyncer import syncify +from langchain.tools import BaseTool +from langchain_core.runnables import RunnableConfig +from langchain_core.tools import ToolException +from pydantic.v1 import BaseModel + +from langflow.base.flow_processing.utils import build_records_from_result_data, format_flow_output_records +from langflow.graph.graph.base import Graph +from langflow.graph.vertex.base import Vertex +from langflow.helpers.flow import build_schema_from_inputs, get_arg_names, get_flow_inputs, run_flow + + +class FlowTool(BaseTool): + name: str + description: str + graph: Optional[Graph] = None + flow_id: Optional[str] = None + user_id: Optional[str] = None + inputs: List["Vertex"] = [] + get_final_results_only: bool = True + + @property + def args(self) -> dict: + schema = self.get_input_schema() + return schema.schema()["properties"] + + def get_input_schema(self, config: Optional[RunnableConfig] = None) -> Type[BaseModel]: + """The tool's input schema.""" + if self.args_schema is not None: + return self.args_schema + elif self.graph is not None: + return build_schema_from_inputs(self.name, get_flow_inputs(self.graph)) + else: + raise ToolException("No input schema available.") + + def _run( + self, + *args: Any, + **kwargs: Any, + ) -> str: + """Use the tool.""" + args_names = get_arg_names(self.inputs) + if len(args_names) == len(args): + kwargs = {arg["arg_name"]: arg_value for arg, arg_value in 
zip(args_names, args)} + elif len(args_names) != len(args) and len(args) != 0: + raise ToolException( + "Number of arguments does not match the number of inputs. Pass keyword arguments instead." + ) + tweaks = {arg["component_name"]: kwargs[arg["arg_name"]] for arg in args_names} + + run_outputs = syncify(run_flow, raise_sync_error=False)( + tweaks={key: {"input_value": value} for key, value in tweaks.items()}, + flow_id=self.flow_id, + user_id=self.user_id, + ) + if not run_outputs: + return "No output" + run_output = run_outputs[0] + + records = [] + if run_output is not None: + for output in run_output.outputs: + if output: + records.extend( + build_records_from_result_data(output, get_final_results_only=self.get_final_results_only) + ) + return format_flow_output_records(records) + + def validate_inputs(self, args_names: List[dict[str, str]], args: Any, kwargs: Any): + """Validate the inputs.""" + + if len(args) > 0 and len(args) != len(args_names): + raise ToolException( + "Number of positional arguments does not match the number of inputs. Pass keyword arguments instead." 
+ ) + + if len(args) == len(args_names): + kwargs = {arg_name["arg_name"]: arg_value for arg_name, arg_value in zip(args_names, args)} + + missing_args = [arg["arg_name"] for arg in args_names if arg["arg_name"] not in kwargs] + if missing_args: + raise ToolException(f"Missing required arguments: {', '.join(missing_args)}") + + return kwargs + + def build_tweaks_dict(self, args, kwargs): + args_names = get_arg_names(self.inputs) + kwargs = self.validate_inputs(args_names=args_names, args=args, kwargs=kwargs) + tweaks = {arg["component_name"]: kwargs[arg["arg_name"]] for arg in args_names} + return tweaks + + async def _arun( + self, + *args: Any, + **kwargs: Any, + ) -> str: + """Use the tool asynchronously.""" + tweaks = self.build_tweaks_dict(args, kwargs) + run_outputs = await run_flow( + tweaks={key: {"input_value": value} for key, value in tweaks.items()}, + flow_id=self.flow_id, + user_id=self.user_id, + ) + if not run_outputs: + return "No output" + run_output = run_outputs[0] + + records = [] + if run_output is not None: + for output in run_output.outputs: + if output: + records.extend( + build_records_from_result_data(output, get_final_results_only=self.get_final_results_only) + ) + return format_flow_output_records(records) diff --git a/src/backend/base/langflow/components/experimental/FlowTool.py b/src/backend/base/langflow/components/experimental/FlowTool.py index 07f3b0e38..fa81f6351 100644 --- a/src/backend/base/langflow/components/experimental/FlowTool.py +++ b/src/backend/base/langflow/components/experimental/FlowTool.py @@ -1,14 +1,14 @@ from typing import Any, List, Optional -from asyncer import syncify -from langchain_core.tools import StructuredTool +from loguru import logger + +from langflow.base.tools.flow_tool import FlowTool from langflow.custom import CustomComponent from langflow.field_typing import Tool from langflow.graph.graph.base import Graph -from langflow.helpers.flow import build_function_and_schema +from langflow.helpers.flow 
import get_flow_inputs from langflow.schema.dotdict import dotdict from langflow.schema.schema import Record -from loguru import logger class FlowToolComponent(CustomComponent): @@ -68,18 +68,20 @@ class FlowToolComponent(CustomComponent): } async def build(self, flow_name: str, name: str, description: str, return_direct: bool = False) -> Tool: + FlowTool.update_forward_refs() flow_record = self.get_flow(flow_name) if not flow_record: raise ValueError("Flow not found.") graph = Graph.from_payload(flow_record.data["data"]) - dynamic_flow_function, schema = build_function_and_schema(flow_record, graph) - tool = StructuredTool.from_function( - func=syncify(dynamic_flow_function, raise_sync_error=False), # type: ignore - coroutine=dynamic_flow_function, + inputs = get_flow_inputs(graph) + tool = FlowTool( name=name, description=description, + graph=graph, return_direct=return_direct, - args_schema=schema, + inputs=inputs, + flow_id=str(flow_record.id), + user_id=str(self._user_id), ) description_repr = repr(tool.description).strip("'") args_str = "\n".join([f"- {arg_name}: {arg_data['description']}" for arg_name, arg_data in tool.args.items()]) diff --git a/src/backend/base/langflow/components/experimental/RunFlow.py b/src/backend/base/langflow/components/experimental/RunFlow.py index d3769de7a..d2e7dd285 100644 --- a/src/backend/base/langflow/components/experimental/RunFlow.py +++ b/src/backend/base/langflow/components/experimental/RunFlow.py @@ -1,8 +1,9 @@ from typing import Any, List, Optional +from langflow.base.flow_processing.utils import build_records_from_run_outputs from langflow.custom import CustomComponent from langflow.field_typing import NestedDict, Text -from langflow.graph.schema import ResultData +from langflow.graph.schema import RunOutputs from langflow.schema import Record, dotdict @@ -39,28 +40,17 @@ class RunFlowComponent(CustomComponent): }, } - def build_records_from_result_data(self, result_data: ResultData) -> List[Record]: - messages = 
result_data.messages - if not messages: - return [] - records = [] - for message in messages: - message_dict = message if isinstance(message, dict) else message.model_dump() - record = Record(text=message_dict.get("text", ""), data={"result": result_data}) - records.append(record) - return records - async def build(self, input_value: Text, flow_name: str, tweaks: NestedDict) -> List[Record]: - results: List[Optional[ResultData]] = await self.run_flow( + results: List[Optional[RunOutputs]] = await self.run_flow( inputs={"input_value": input_value}, flow_name=flow_name, tweaks=tweaks ) if isinstance(results, list): records = [] for result in results: if result: - records.extend(self.build_records_from_result_data(result)) + records.extend(build_records_from_run_outputs(result)) else: - records = self.build_records_from_result_data(results) + records = build_records_from_run_outputs(results) self.status = records return records diff --git a/src/backend/base/langflow/components/experimental/SubFlow.py b/src/backend/base/langflow/components/experimental/SubFlow.py index 80e15c6ad..76a9538a4 100644 --- a/src/backend/base/langflow/components/experimental/SubFlow.py +++ b/src/backend/base/langflow/components/experimental/SubFlow.py @@ -2,9 +2,10 @@ from typing import Any, List, Optional from loguru import logger +from langflow.base.flow_processing.utils import build_records_from_result_data from langflow.custom import CustomComponent from langflow.graph.graph.base import Graph -from langflow.graph.schema import ResultData, RunOutputs +from langflow.graph.schema import RunOutputs from langflow.graph.vertex.base import Vertex from langflow.helpers.flow import get_flow_inputs from langflow.schema import Record @@ -92,21 +93,6 @@ class SubFlowComponent(CustomComponent): }, } - def build_records_from_result_data(self, result_data: ResultData, get_final_results_only: bool) -> List[Record]: - messages = result_data.messages - if not messages: - return [] - records = [] - for 
message in messages: - message_dict = message if isinstance(message, dict) else message.model_dump() - if get_final_results_only: - result_data_dict = result_data.model_dump() - results = result_data_dict.get("results", {}) - inner_result = results.get("result", {}) - record = Record(data={"result": inner_result, "message": message_dict}, text_key="result") - records.append(record) - return records - async def build(self, flow_name: str, get_final_results_only: bool = True, **kwargs) -> List[Record]: tweaks = {key: {"input_value": value} for key, value in kwargs.items()} run_outputs: List[Optional[RunOutputs]] = await self.run_flow( @@ -121,7 +107,7 @@ class SubFlowComponent(CustomComponent): if run_output is not None: for output in run_output.outputs: if output: - records.extend(self.build_records_from_result_data(output, get_final_results_only)) + records.extend(build_records_from_result_data(output, get_final_results_only)) self.status = records logger.debug(records) diff --git a/src/backend/base/langflow/helpers/flow.py b/src/backend/base/langflow/helpers/flow.py index 36b852a99..a20462f3d 100644 --- a/src/backend/base/langflow/helpers/flow.py +++ b/src/backend/base/langflow/helpers/flow.py @@ -1,10 +1,13 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Optional, Tuple, Type, Union, cast +from uuid import UUID +from pydantic.v1 import BaseModel, Field, create_model +from sqlmodel import select + +from langflow.graph.schema import RunOutputs from langflow.schema.schema import INPUT_FIELD_NAME, Record from langflow.services.database.models.flow.model import Flow from langflow.services.deps import session_scope -from pydantic.v1 import BaseModel, Field, create_model -from sqlmodel import select if TYPE_CHECKING: from langflow.graph.graph.base import Graph @@ -51,7 +54,7 @@ async def load_flow( raise ValueError(f"Flow {flow_id} not found") if tweaks: graph_data = process_tweaks(graph_data=graph_data, tweaks=tweaks) - graph = 
Graph.from_payload(graph_data, flow_id=flow_id) + graph = Graph.from_payload(graph_data, flow_id=flow_id, user_id=user_id) return graph @@ -67,25 +70,29 @@ async def run_flow( flow_id: Optional[str] = None, flow_name: Optional[str] = None, user_id: Optional[str] = None, -) -> Any: +) -> List[RunOutputs]: if user_id is None: raise ValueError("Session is invalid") graph = await load_flow(user_id, flow_id, flow_name, tweaks) if inputs is None: inputs = [] + if isinstance(inputs, dict): + inputs = [inputs] inputs_list = [] inputs_components = [] types = [] for input_dict in inputs: inputs_list.append({INPUT_FIELD_NAME: cast(str, input_dict.get("input_value"))}) inputs_components.append(input_dict.get("components", [])) - types.append(input_dict.get("type", [])) + types.append(input_dict.get("type", "chat")) return await graph.arun(inputs_list, inputs_components=inputs_components, types=types) -def generate_function_for_flow(inputs: List["Vertex"], flow_id: str) -> Callable[..., Awaitable[Any]]: +def generate_function_for_flow( + inputs: List["Vertex"], flow_id: str, user_id: str | UUID | None +) -> Callable[..., Awaitable[Any]]: """ Generate a dynamic flow function based on the given inputs and flow ID. 
@@ -129,11 +136,23 @@ async def flow_function({func_args}): tweaks = {{ {arg_mappings} }} from langflow.helpers.flow import run_flow from langchain_core.tools import ToolException + from langflow.base.flow_processing.utils import build_records_from_result_data, format_flow_output_records try: - return await run_flow( + run_outputs = await run_flow( tweaks={{key: {{'input_value': value}} for key, value in tweaks.items()}}, flow_id="{flow_id}", + user_id="{user_id}" ) + if not run_outputs: + return [] + run_output = run_outputs[0] + + records = [] + if run_output is not None: + for output in run_output.outputs: + if output: + records.extend(build_records_from_result_data(output, get_final_results_only=True)) + return format_flow_output_records(records) except Exception as e: raise ToolException('Error running flow: ' + str(e)) """ @@ -145,7 +164,7 @@ async def flow_function({func_args}): def build_function_and_schema( - flow_record: Record, graph: "Graph" + flow_record: Record, graph: "Graph", user_id: str | UUID | None ) -> Tuple[Callable[..., Awaitable[Any]], Type[BaseModel]]: """ Builds a dynamic function and schema for a given flow. @@ -159,7 +178,7 @@ def build_function_and_schema( """ flow_id = flow_record.id inputs = get_flow_inputs(graph) - dynamic_flow_function = generate_function_for_flow(inputs, flow_id) + dynamic_flow_function = generate_function_for_flow(inputs, flow_id, user_id=user_id) schema = build_schema_from_inputs(flow_record.name, inputs) return dynamic_flow_function, schema @@ -200,3 +219,19 @@ def build_schema_from_inputs(name: str, inputs: List["Vertex"]) -> Type[BaseMode description = input_.description fields[field_name] = (str, Field(default="", description=description)) return create_model(name, **fields) # type: ignore + + + def get_arg_names(inputs: List["Vertex"]) -> List[dict[str, str]]: + """ + Returns a list of dictionaries containing the component name and its corresponding argument name. 
+ + Args: + inputs (List[Vertex]): A list of Vertex objects representing the inputs. + + Returns: + List[dict[str, str]]: A list of dictionaries, where each dictionary contains the component name and its argument name. + """ + return [ + {"component_name": input_.display_name, "arg_name": input_.display_name.lower().replace(" ", "_")} + for input_ in inputs + ]