Refactor process_inputs function to handle both dict and list inputs

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-12-22 10:46:35 -03:00
commit 75493bbdea

View file

@ -7,12 +7,11 @@ from langchain.schema import AgentAction, Document
from langchain.vectorstores.base import VectorStore
from langchain_core.messages import AIMessage
from langchain_core.runnables.base import Runnable
from loguru import logger
from pydantic import BaseModel
from langflow.interface.custom.custom_component import CustomComponent
from langflow.interface.run import build_sorted_vertices, get_memory_key, update_memory_keys
from langflow.services.deps import get_session_service
from loguru import logger
from pydantic import BaseModel
def fix_memory_inputs(langchain_object):
@ -107,10 +106,19 @@ def get_build_result(data_graph, session_id):
return build_sorted_vertices(data_graph)
def process_inputs(inputs: Optional[dict], artifacts: Dict[str, Any]) -> dict:
def process_inputs(inputs: Union[dict, List[dict]], artifacts: Dict[str, Any]) -> dict:
if inputs is None:
inputs = {}
if isinstance(inputs, dict):
inputs = update_inputs_dict(inputs, artifacts)
elif isinstance(inputs, List):
inputs = [update_inputs_dict(inp, artifacts) for inp in inputs]
return inputs
def update_inputs_dict(inputs: dict, artifacts: Dict[str, Any]) -> dict:
for key, value in artifacts.items():
if key == "repr":
continue