Refactor graph run method and custom component usage
This commit is contained in:
parent
d0c3a5b30b
commit
6c7d06b2f9
3 changed files with 43 additions and 18 deletions
|
|
@ -261,7 +261,45 @@ class Graph:
|
|||
|
||||
return vertex_outputs
|
||||
|
||||
async def run(
|
||||
def run(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
input_components: Optional[list[str]] = None,
|
||||
types: Optional[list[str]] = None,
|
||||
outputs: Optional[list[str]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
stream: bool = False,
|
||||
) -> List[RunOutputs]:
|
||||
"""
|
||||
Run the graph with the given inputs and return the outputs.
|
||||
|
||||
Args:
|
||||
inputs (Dict[str, str]): A dictionary of input values.
|
||||
input_components (Optional[list[str]]): A list of input components.
|
||||
types (Optional[list[str]]): A list of types.
|
||||
outputs (Optional[list[str]]): A list of output components.
|
||||
session_id (Optional[str]): The session ID.
|
||||
stream (bool): Whether to stream the outputs.
|
||||
|
||||
Returns:
|
||||
List[RunOutputs]: A list of RunOutputs objects representing the outputs.
|
||||
"""
|
||||
# run the async function in a sync way
|
||||
# this could be used in a FastAPI endpoint
|
||||
# so we should take care of the event loop
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(
|
||||
self.arun(
|
||||
inputs=inputs,
|
||||
inputs_components=input_components,
|
||||
types=types,
|
||||
outputs=outputs,
|
||||
session_id=session_id,
|
||||
stream=stream,
|
||||
)
|
||||
)
|
||||
|
||||
async def arun(
|
||||
self,
|
||||
inputs: list[Dict[str, str]],
|
||||
inputs_components: Optional[list[list[str]]] = None,
|
||||
|
|
|
|||
|
|
@ -1,15 +1,6 @@
|
|||
import operator
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, List, Optional, Sequence, Union
|
||||
from uuid import UUID
|
||||
|
||||
import yaml
|
||||
|
|
@ -27,11 +18,7 @@ from langflow.schema import Record
|
|||
from langflow.schema.dotdict import dotdict
|
||||
from langflow.services.database.models.flow import Flow
|
||||
from langflow.services.database.utils import session_getter
|
||||
from langflow.services.deps import (
|
||||
get_credential_service,
|
||||
get_db_service,
|
||||
get_storage_service,
|
||||
)
|
||||
from langflow.services.deps import get_credential_service, get_db_service, get_storage_service
|
||||
from langflow.services.storage.service import StorageService
|
||||
from langflow.utils import validate
|
||||
|
||||
|
|
@ -370,7 +357,7 @@ class CustomComponent(Component):
|
|||
input_value = [input_value]
|
||||
graph = await self.load_flow(flow_id, tweaks)
|
||||
input_value_dict = [{"input_value": input_val} for input_val in input_value]
|
||||
return await graph.run(input_value_dict, stream=False)
|
||||
return await graph.arun(input_value_dict, stream=False)
|
||||
|
||||
def list_flows(self, *, get_session: Optional[Callable] = None) -> List[Record]:
|
||||
if not self._user_id:
|
||||
|
|
|
|||
|
|
@ -226,7 +226,7 @@ async def run_graph(
|
|||
inputs_list.append({INPUT_FIELD_NAME: input_value_request.input_value})
|
||||
types.append(input_value_request.type)
|
||||
|
||||
run_outputs = await graph.run(
|
||||
run_outputs = await graph.arun(
|
||||
inputs_list,
|
||||
components,
|
||||
types,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue