Refactor process.py to improve code structure and readability

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-12-20 18:56:32 -03:00
commit 4a9c7506ea

View file

@ -6,12 +6,11 @@ from langchain.chains.base import Chain
from langchain.schema import AgentAction, Document
from langchain.vectorstores.base import VectorStore
from langchain_core.runnables.base import Runnable
from loguru import logger
from pydantic import BaseModel
from langflow.components.custom_components import CustomComponent
from langflow.interface.custom.custom_component import CustomComponent
from langflow.interface.run import build_sorted_vertices, get_memory_key, update_memory_keys
from langflow.services.deps import get_session_service
from loguru import logger
from pydantic import BaseModel
def fix_memory_inputs(langchain_object):
@ -119,42 +118,64 @@ def process_inputs(inputs: Optional[dict], artifacts: Dict[str, Any]) -> dict:
return inputs
async def generate_result(langchain_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]]):
if isinstance(langchain_object, Chain):
async def process_runnable(runnable: Runnable, inputs: Union[dict, List[dict]]):
if isinstance(inputs, List) and hasattr(runnable, "abatch"):
result = await runnable.abatch(inputs)
elif isinstance(inputs, dict) and hasattr(runnable, "ainvoke"):
result = await runnable.ainvoke(inputs)
else:
raise ValueError(f"Runnable {runnable} does not support inputs of type {type(inputs)}")
return result
async def process_inputs_dict(built_object: Union[Chain, VectorStore, Runnable], inputs: dict):
if isinstance(built_object, Chain):
if inputs is None:
raise ValueError("Inputs must be provided for a Chain")
logger.debug("Generating result and thought")
result = get_result_and_thought(langchain_object, inputs)
result = get_result_and_thought(built_object, inputs)
logger.debug("Generated result and thought")
elif isinstance(langchain_object, VectorStore):
result = langchain_object.search(**inputs)
elif isinstance(langchain_object, Document):
result = langchain_object.dict()
elif isinstance(langchain_object, Runnable):
# Define call_method as a coroutine function
# by default
if isinstance(inputs, List) and hasattr(langchain_object, "abatch"):
call_method = langchain_object.abatch
elif isinstance(inputs, dict) and hasattr(langchain_object, "ainvoke"):
call_method = langchain_object.ainvoke
else:
raise ValueError("Inputs must be provided for a Runnable")
result = await call_method(inputs)
elif isinstance(built_object, VectorStore) and "query" in inputs:
if isinstance(inputs, dict) and "search_type" not in inputs:
inputs["search_type"] = "similarity"
logger.info("search_type not provided, using default value: similarity")
result = built_object.search(**inputs)
elif isinstance(built_object, Document):
result = built_object.dict()
elif isinstance(built_object, Runnable):
result = await process_runnable(built_object, inputs)
if isinstance(result, list):
result = [r.content if hasattr(r, "content") else r for r in result]
elif hasattr(result, "content"):
result = result.content
else:
result = result
elif hasattr(langchain_object, "run") and isinstance(langchain_object, CustomComponent):
result = langchain_object.run(inputs)
elif hasattr(built_object, "run") and isinstance(built_object, CustomComponent):
result = built_object.run(inputs)
else:
logger.warning(f"Unknown langchain_object type: {type(langchain_object)}")
if isinstance(langchain_object, Coroutine):
result = asyncio.run(langchain_object)
result = langchain_object
result = None
return result
async def process_inputs_list(built_object: Runnable, inputs: List[dict]):
return await process_runnable(built_object, inputs)
async def generate_result(built_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]]):
if isinstance(inputs, dict):
result = await process_inputs_dict(built_object, inputs)
elif isinstance(inputs, List) and isinstance(built_object, Runnable):
result = await process_inputs_list(built_object, inputs)
else:
raise ValueError(f"Invalid inputs type: {type(inputs)}")
if result is None:
logger.warning(f"Unknown built_object type: {type(built_object)}")
if isinstance(built_object, Coroutine):
result = asyncio.run(built_object)
result = built_object
return result