Add async/await to generate_result function

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-12-20 14:18:48 -03:00
commit c4daf5095a

View file

@ -118,7 +118,7 @@ def process_inputs(inputs: Optional[dict], artifacts: Dict[str, Any]) -> dict:
return inputs
def generate_result(langchain_object: Union[Chain, VectorStore], inputs: dict):
async def generate_result(langchain_object: Union[Chain, VectorStore], inputs: dict):
if isinstance(langchain_object, Chain):
if inputs is None:
raise ValueError("Inputs must be provided for a Chain")
@ -131,7 +131,11 @@ def generate_result(langchain_object: Union[Chain, VectorStore], inputs: dict):
elif isinstance(langchain_object, Document):
result = langchain_object.dict()
elif isinstance(langchain_object, Runnable):
result = langchain_object.invoke(inputs)
if isinstance(inputs, List):
call_func = langchain_object.abatch
elif isinstance(inputs, dict):
call_func = langchain_object.ainvoke
result = await call_func(inputs)
result = result.content if hasattr(result, "content") else result
elif hasattr(langchain_object, "run") and isinstance(langchain_object, CustomComponent):
result = langchain_object.run(inputs)
@ -152,7 +156,7 @@ class Result(BaseModel):
async def process_graph_cached(
data_graph: Dict[str, Any],
inputs: Optional[dict] = None,
inputs: Optional[Union[dict, List[dict]]] = None,
clear_cache=False,
session_id=None,
) -> Result:
@ -168,7 +172,7 @@ async def process_graph_cached(
raise ValueError("Graph not found in the session")
built_object = await graph.build()
processed_inputs = process_inputs(inputs, artifacts or {})
result = generate_result(built_object, processed_inputs)
result = await generate_result(built_object, processed_inputs)
# langchain_object is now updated with the new memory
# we need to update the cache with the updated langchain_object
session_service.update_session(session_id, (graph, artifacts))