fix: Remove sync Graph.run() (#4287)

Remove sync Graph.run()
This commit is contained in:
Christophe Bornet 2024-10-27 15:44:43 +01:00 committed by GitHub
commit eccdb3a566
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 48 additions and 66 deletions

View file

@ -707,58 +707,6 @@ class Graph:
return vertex_outputs
def run(
self,
inputs: list[dict[str, str]],
*,
input_components: list[list[str]] | None = None,
types: list[InputType | None] | None = None,
outputs: list[str] | None = None,
session_id: str | None = None,
stream: bool = False,
fallback_to_env_vars: bool = False,
) -> list[RunOutputs]:
"""Run the graph with the given inputs and return the outputs.
Args:
inputs (Dict[str, str]): A dictionary of input values.
input_components (Optional[list[str]]): A list of input components.
types (Optional[list[str]]): A list of types.
outputs (Optional[list[str]]): A list of output components.
session_id (Optional[str]): The session ID.
stream (bool): Whether to stream the outputs.
fallback_to_env_vars (bool): Whether to fallback to environment variables.
Returns:
List[RunOutputs]: A list of RunOutputs objects representing the outputs.
"""
# run the async function in a sync way
# this could be used in a FastAPI endpoint
# so we should take care of the event loop
coro = self.arun(
inputs=inputs,
inputs_components=input_components,
types=types,
outputs=outputs,
session_id=session_id,
stream=stream,
fallback_to_env_vars=fallback_to_env_vars,
)
try:
# Attempt to get the running event loop; if none, an exception is raised
loop = asyncio.get_running_loop()
except RuntimeError:
# If there's no running event loop, use asyncio.run
return asyncio.run(coro)
# If the event loop is closed, use asyncio.run
if loop.is_closed():
return asyncio.run(coro)
# If there's an existing, open event loop, use it to run the async function
return loop.run_until_complete(coro)
async def arun(
self,
inputs: list[dict[str, str]],

View file

@ -1,8 +1,8 @@
import asyncio
import json
from pathlib import Path
from dotenv import load_dotenv
from loguru import logger
from langflow.graph import Graph
from langflow.graph.schema import RunOutputs
@ -69,7 +69,7 @@ def load_flow_from_json(
return Graph.from_payload(graph_data)
def run_flow_from_json(
async def arun_flow_from_json(
flow: Path | str | dict,
input_value: str,
*,
@ -106,17 +106,11 @@ def run_flow_from_json(
Returns:
List[RunOutputs]: A list of RunOutputs objects representing the results of running the flow.
"""
# Set all streaming to false
try:
import nest_asyncio
nest_asyncio.apply()
except Exception: # noqa: BLE001
logger.opt(exception=True).warning("Could not apply nest_asyncio")
if tweaks is None:
tweaks = {}
tweaks["stream"] = False
graph = load_flow_from_json(
graph = await asyncio.to_thread(
load_flow_from_json,
flow=flow,
tweaks=tweaks,
log_level=log_level,
@ -125,7 +119,7 @@ def run_flow_from_json(
cache=cache,
disable_logs=disable_logs,
)
return run_graph(
return await run_graph(
graph=graph,
session_id=session_id,
input_value=input_value,
@ -134,3 +128,43 @@ def run_flow_from_json(
output_component=output_component,
fallback_to_env_vars=fallback_to_env_vars,
)
def run_flow_from_json(
flow: Path | str | dict,
input_value: str,
*,
session_id: str | None = None,
tweaks: dict | None = None,
input_type: str = "chat",
output_type: str = "chat",
output_component: str | None = None,
log_level: str | None = None,
log_file: str | None = None,
env_file: str | None = None,
cache: str | None = None,
disable_logs: bool | None = True,
fallback_to_env_vars: bool = False,
) -> list[RunOutputs]:
coro = arun_flow_from_json(
flow,
input_value,
session_id=session_id,
tweaks=tweaks,
input_type=input_type,
output_type=output_type,
output_component=output_component,
log_level=log_level,
log_file=log_file,
env_file=env_file,
cache=cache,
disable_logs=disable_logs,
fallback_to_env_vars=fallback_to_env_vars,
)
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return asyncio.run(coro)
return loop.run_until_complete(coro)

View file

@ -58,7 +58,7 @@ async def run_graph_internal(
return run_outputs, session_id_str
def run_graph(
async def run_graph(
graph: Graph,
input_value: str,
input_type: str,
@ -104,9 +104,9 @@ def run_graph(
components.append(input_value_request.components or [])
inputs_list.append({INPUT_FIELD_NAME: input_value_request.input_value})
types.append(input_value_request.type)
return graph.run(
return await graph.arun(
inputs_list,
input_components=components,
inputs_components=components,
types=types,
outputs=outputs or [],
stream=False,