From dd9347a18693149f90bc8224335e8a203504ce1d Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Sat, 23 Mar 2024 00:21:27 -0300 Subject: [PATCH] Refactor process.py and schema.py --- src/backend/langflow/processing/process.py | 80 ++-------------------- src/backend/langflow/schema/schema.py | 3 + 2 files changed, 7 insertions(+), 76 deletions(-) diff --git a/src/backend/langflow/processing/process.py b/src/backend/langflow/processing/process.py index ff8c97bb5..563815b78 100644 --- a/src/backend/langflow/processing/process.py +++ b/src/backend/langflow/processing/process.py @@ -1,20 +1,15 @@ -import asyncio -from typing import TYPE_CHECKING, Any, Coroutine, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from langchain.agents import AgentExecutor -from langchain.chains.base import Chain -from langchain.schema import AgentAction, Document -from langchain_community.vectorstores import VectorStore -from langchain_core.messages import AIMessage -from langchain_core.runnables.base import Runnable +from langchain.schema import AgentAction from loguru import logger from pydantic import BaseModel from langflow.graph.graph.base import Graph -from langflow.graph.schema import INPUT_FIELD_NAME, RunOutputs +from langflow.graph.schema import RunOutputs from langflow.graph.vertex.base import Vertex -from langflow.interface.custom.custom_component import CustomComponent from langflow.interface.run import get_memory_key, update_memory_keys +from langflow.schema.schema import INPUT_FIELD_NAME from langflow.services.session.service import SessionService if TYPE_CHECKING: @@ -124,73 +119,6 @@ def update_inputs_dict(inputs: dict, artifacts: Dict[str, Any]) -> dict: return inputs -async def process_runnable(runnable: Runnable, inputs: Union[dict, List[dict]]): - if isinstance(inputs, List) and hasattr(runnable, "abatch"): - result = await runnable.abatch(inputs) - elif isinstance(inputs, dict) and hasattr(runnable, "ainvoke"): - result = await runnable.ainvoke(inputs) - else: - raise ValueError(f"Runnable {runnable} does not support inputs of type {type(inputs)}") - # Check if the result is a list of AIMessages - if isinstance(result, list) and all(isinstance(r, AIMessage) for r in result): - result = [r.content for r in result] - elif isinstance(result, AIMessage): - result = result.content - return result - - -async def process_inputs_dict(built_object: Union[Chain, VectorStore, Runnable], inputs: dict): - if isinstance(built_object, Chain): - if inputs is None: - raise ValueError("Inputs must be provided for a Chain") - logger.debug("Generating result and thought") - result = get_result_and_thought(built_object, inputs) - - logger.debug("Generated result and thought") - elif isinstance(built_object, VectorStore) and "query" in inputs: - if isinstance(inputs, dict) and "search_type" not in inputs: - inputs["search_type"] = "similarity" - logger.info("search_type not provided, using default value: similarity") - result = built_object.search(**inputs) - elif isinstance(built_object, Document): - result = built_object.dict() - elif isinstance(built_object, Runnable): - result = await process_runnable(built_object, inputs) - if isinstance(result, list): - result = [r.content if hasattr(r, "content") else r for r in result] - elif hasattr(result, "content"): - result = result.content - else: - result = result - elif hasattr(built_object, "run") and isinstance(built_object, CustomComponent): - result = built_object.run(inputs) - else: - result = None - - return result - - -async def process_inputs_list(built_object: Runnable, inputs: List[dict]): - return await process_runnable(built_object, inputs) - - -async def generate_result(built_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]]): - if isinstance(inputs, dict): - result = await process_inputs_dict(built_object, inputs) - elif isinstance(inputs, List) and isinstance(built_object, Runnable): - result = await process_inputs_list(built_object, inputs) - else: - raise ValueError(f"Invalid inputs type: {type(inputs)}") - - if result is None: - logger.warning(f"Unknown built_object type: {type(built_object)}") - if isinstance(built_object, Coroutine): - result = asyncio.run(built_object) - result = built_object - - return result - - class Result(BaseModel): result: Any session_id: str diff --git a/src/backend/langflow/schema/schema.py b/src/backend/langflow/schema/schema.py index 4079f90ac..93a733038 100644 --- a/src/backend/langflow/schema/schema.py +++ b/src/backend/langflow/schema/schema.py @@ -120,3 +120,6 @@ class Record(BaseModel): # check which attributes the Record has by checking the keys in the data dictionary def __dir__(self): return super().__dir__() + list(self.data.keys()) + + +INPUT_FIELD_NAME = "input_value"