Refactor graph run method and custom component usage

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-22 23:07:07 -03:00
commit 6c7d06b2f9
3 changed files with 43 additions and 18 deletions

View file

@ -261,7 +261,45 @@ class Graph:
return vertex_outputs
async def run(
def run(
self,
inputs: Dict[str, str],
input_components: Optional[list[str]] = None,
types: Optional[list[str]] = None,
outputs: Optional[list[str]] = None,
session_id: Optional[str] = None,
stream: bool = False,
) -> List[RunOutputs]:
"""
Run the graph with the given inputs and return the outputs.
Args:
inputs (Dict[str, str]): A dictionary of input values.
input_components (Optional[list[str]]): A list of input components.
types (Optional[list[str]]): A list of types.
outputs (Optional[list[str]]): A list of output components.
session_id (Optional[str]): The session ID.
stream (bool): Whether to stream the outputs.
Returns:
List[RunOutputs]: A list of RunOutputs objects representing the outputs.
"""
# run the async function in a sync way
# this could be used in a FastAPI endpoint
# so we should take care of the event loop
loop = asyncio.get_event_loop()
return loop.run_until_complete(
self.arun(
inputs=inputs,
inputs_components=input_components,
types=types,
outputs=outputs,
session_id=session_id,
stream=stream,
)
)
async def arun(
self,
inputs: list[Dict[str, str]],
inputs_components: Optional[list[list[str]]] = None,

View file

@ -1,15 +1,6 @@
import operator
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
List,
Optional,
Sequence,
Union,
)
from typing import TYPE_CHECKING, Any, Callable, ClassVar, List, Optional, Sequence, Union
from uuid import UUID
import yaml
@ -27,11 +18,7 @@ from langflow.schema import Record
from langflow.schema.dotdict import dotdict
from langflow.services.database.models.flow import Flow
from langflow.services.database.utils import session_getter
from langflow.services.deps import (
get_credential_service,
get_db_service,
get_storage_service,
)
from langflow.services.deps import get_credential_service, get_db_service, get_storage_service
from langflow.services.storage.service import StorageService
from langflow.utils import validate
@ -370,7 +357,7 @@ class CustomComponent(Component):
input_value = [input_value]
graph = await self.load_flow(flow_id, tweaks)
input_value_dict = [{"input_value": input_val} for input_val in input_value]
return await graph.run(input_value_dict, stream=False)
return await graph.arun(input_value_dict, stream=False)
def list_flows(self, *, get_session: Optional[Callable] = None) -> List[Record]:
if not self._user_id:

View file

@ -226,7 +226,7 @@ async def run_graph(
inputs_list.append({INPUT_FIELD_NAME: input_value_request.input_value})
types.append(input_value_request.type)
run_outputs = await graph.run(
run_outputs = await graph.arun(
inputs_list,
components,
types,