Refactor process.py: Add import statements and update generate_result function

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-12-20 14:53:15 -03:00
commit 3586ba8469

View file

@ -6,11 +6,12 @@ from langchain.chains.base import Chain
from langchain.schema import AgentAction, Document
from langchain.vectorstores.base import VectorStore
from langchain_core.runnables.base import Runnable
from loguru import logger
from pydantic import BaseModel
from langflow.components.custom_components import CustomComponent
from langflow.interface.run import build_sorted_vertices, get_memory_key, update_memory_keys
from langflow.services.deps import get_session_service
from loguru import logger
from pydantic import BaseModel
def fix_memory_inputs(langchain_object):
@ -118,7 +119,7 @@ def process_inputs(inputs: Optional[dict], artifacts: Dict[str, Any]) -> dict:
return inputs
async def generate_result(langchain_object: Union[Chain, VectorStore], inputs: dict):
async def generate_result(langchain_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]]):
if isinstance(langchain_object, Chain):
if inputs is None:
raise ValueError("Inputs must be provided for a Chain")
@ -131,11 +132,15 @@ async def generate_result(langchain_object: Union[Chain, VectorStore], inputs: d
elif isinstance(langchain_object, Document):
result = langchain_object.dict()
elif isinstance(langchain_object, Runnable):
if isinstance(inputs, List):
call_func = langchain_object.abatch
elif isinstance(inputs, dict):
call_func = langchain_object.ainvoke
result = await call_func(inputs)
# Define call_method as a coroutine function
# by default
if isinstance(inputs, List) and hasattr(langchain_object, "abatch"):
call_method = langchain_object.abatch
elif isinstance(inputs, dict) and hasattr(langchain_object, "ainvoke"):
call_method = langchain_object.ainvoke
else:
raise ValueError("Inputs must be provided for a Runnable")
result = await call_method(inputs)
if isinstance(result, list):
result = [r.content if hasattr(r, "content") else r for r in result]
elif hasattr(result, "content"):