From 3586ba8469782a38f774e56497be3700f93f485b Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Wed, 20 Dec 2023 14:53:15 -0300 Subject: [PATCH] Refactor process.py: Add import statements and update generate_result function --- src/backend/langflow/processing/process.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/backend/langflow/processing/process.py b/src/backend/langflow/processing/process.py index 766027bdf..982ccc6f1 100644 --- a/src/backend/langflow/processing/process.py +++ b/src/backend/langflow/processing/process.py @@ -6,11 +6,12 @@ from langchain.chains.base import Chain from langchain.schema import AgentAction, Document from langchain.vectorstores.base import VectorStore from langchain_core.runnables.base import Runnable +from loguru import logger +from pydantic import BaseModel + from langflow.components.custom_components import CustomComponent from langflow.interface.run import build_sorted_vertices, get_memory_key, update_memory_keys from langflow.services.deps import get_session_service -from loguru import logger -from pydantic import BaseModel def fix_memory_inputs(langchain_object): @@ -118,7 +119,7 @@ def process_inputs(inputs: Optional[dict], artifacts: Dict[str, Any]) -> dict: return inputs -async def generate_result(langchain_object: Union[Chain, VectorStore], inputs: dict): +async def generate_result(langchain_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]]): if isinstance(langchain_object, Chain): if inputs is None: raise ValueError("Inputs must be provided for a Chain") @@ -131,11 +132,15 @@ async def generate_result(langchain_object: Union[Chain, VectorStore], inputs: d elif isinstance(langchain_object, Document): result = langchain_object.dict() elif isinstance(langchain_object, Runnable): - if isinstance(inputs, List): - call_func = langchain_object.abatch - elif isinstance(inputs, dict): - call_func = langchain_object.ainvoke - result = await call_func(inputs) + # Define call_method as a coroutine function + # by default + if isinstance(inputs, List) and hasattr(langchain_object, "abatch"): + call_method = langchain_object.abatch + elif isinstance(inputs, dict) and hasattr(langchain_object, "ainvoke"): + call_method = langchain_object.ainvoke + else: + raise ValueError("Inputs must be provided for a Runnable") + result = await call_method(inputs) if isinstance(result, list): result = [r.content if hasattr(r, "content") else r for r in result] elif hasattr(result, "content"):