From 6c7d06b2f9e4cff72938798067fee508cfe1833d Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Fri, 22 Mar 2024 23:07:07 -0300 Subject: [PATCH] Refactor graph run method and custom component usage --- src/backend/langflow/graph/graph/base.py | 40 ++++++++++++++++++- .../custom_component/custom_component.py | 19 ++------- src/backend/langflow/processing/process.py | 2 +- 3 files changed, 43 insertions(+), 18 deletions(-) diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 5b50e56cd..3621d539d 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -261,7 +261,45 @@ class Graph: return vertex_outputs - async def run( + def run( + self, + inputs: Dict[str, str], + input_components: Optional[list[str]] = None, + types: Optional[list[str]] = None, + outputs: Optional[list[str]] = None, + session_id: Optional[str] = None, + stream: bool = False, + ) -> List[RunOutputs]: + """ + Run the graph with the given inputs and return the outputs. + + Args: + inputs (Dict[str, str]): A dictionary of input values. + input_components (Optional[list[str]]): A list of input components. + types (Optional[list[str]]): A list of types. + outputs (Optional[list[str]]): A list of output components. + session_id (Optional[str]): The session ID. + stream (bool): Whether to stream the outputs. + + Returns: + List[RunOutputs]: A list of RunOutputs objects representing the outputs. + """ + # run the async function in a sync way + # this could be used in a FastAPI endpoint + # so we should take care of the event loop + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.arun( + inputs=inputs, + inputs_components=input_components, + types=types, + outputs=outputs, + session_id=session_id, + stream=stream, + ) + ) + + async def arun( self, inputs: list[Dict[str, str]], inputs_components: Optional[list[list[str]]] = None, diff --git a/src/backend/langflow/interface/custom/custom_component/custom_component.py b/src/backend/langflow/interface/custom/custom_component/custom_component.py index b19fbbf30..7d90263ef 100644 --- a/src/backend/langflow/interface/custom/custom_component/custom_component.py +++ b/src/backend/langflow/interface/custom/custom_component/custom_component.py @@ -1,15 +1,6 @@ import operator from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, - Callable, - ClassVar, - List, - Optional, - Sequence, - Union, -) +from typing import TYPE_CHECKING, Any, Callable, ClassVar, List, Optional, Sequence, Union from uuid import UUID import yaml @@ -27,11 +18,7 @@ from langflow.schema import Record from langflow.schema.dotdict import dotdict from langflow.services.database.models.flow import Flow from langflow.services.database.utils import session_getter -from langflow.services.deps import ( - get_credential_service, - get_db_service, - get_storage_service, -) +from langflow.services.deps import get_credential_service, get_db_service, get_storage_service from langflow.services.storage.service import StorageService from langflow.utils import validate @@ -370,7 +357,7 @@ class CustomComponent(Component): input_value = [input_value] graph = await self.load_flow(flow_id, tweaks) input_value_dict = [{"input_value": input_val} for input_val in input_value] - return await graph.run(input_value_dict, stream=False) + return await graph.arun(input_value_dict, stream=False) def list_flows(self, *, get_session: Optional[Callable] = None) -> List[Record]: if not self._user_id: diff --git a/src/backend/langflow/processing/process.py b/src/backend/langflow/processing/process.py index cd2c5f3e0..ff8c97bb5 100644 --- a/src/backend/langflow/processing/process.py +++ b/src/backend/langflow/processing/process.py @@ -226,7 +226,7 @@ async def run_graph( inputs_list.append({INPUT_FIELD_NAME: input_value_request.input_value}) types.append(input_value_request.type) - run_outputs = await graph.run( + run_outputs = await graph.arun( inputs_list, components, types,