Add async/await to generate_result function
This commit is contained in:
parent
84f9c34bd1
commit
c4daf5095a
1 changed files with 8 additions and 4 deletions
|
|
@ -118,7 +118,7 @@ def process_inputs(inputs: Optional[dict], artifacts: Dict[str, Any]) -> dict:
|
|||
return inputs
|
||||
|
||||
|
||||
def generate_result(langchain_object: Union[Chain, VectorStore], inputs: dict):
|
||||
async def generate_result(langchain_object: Union[Chain, VectorStore], inputs: dict):
|
||||
if isinstance(langchain_object, Chain):
|
||||
if inputs is None:
|
||||
raise ValueError("Inputs must be provided for a Chain")
|
||||
|
|
@ -131,7 +131,11 @@ def generate_result(langchain_object: Union[Chain, VectorStore], inputs: dict):
|
|||
elif isinstance(langchain_object, Document):
|
||||
result = langchain_object.dict()
|
||||
elif isinstance(langchain_object, Runnable):
|
||||
result = langchain_object.invoke(inputs)
|
||||
if isinstance(inputs, List):
|
||||
call_func = langchain_object.abatch
|
||||
elif isinstance(inputs, dict):
|
||||
call_func = langchain_object.ainvoke
|
||||
result = await call_func(inputs)
|
||||
result = result.content if hasattr(result, "content") else result
|
||||
elif hasattr(langchain_object, "run") and isinstance(langchain_object, CustomComponent):
|
||||
result = langchain_object.run(inputs)
|
||||
|
|
@ -152,7 +156,7 @@ class Result(BaseModel):
|
|||
|
||||
async def process_graph_cached(
|
||||
data_graph: Dict[str, Any],
|
||||
inputs: Optional[dict] = None,
|
||||
inputs: Optional[Union[dict, List[dict]]] = None,
|
||||
clear_cache=False,
|
||||
session_id=None,
|
||||
) -> Result:
|
||||
|
|
@ -168,7 +172,7 @@ async def process_graph_cached(
|
|||
raise ValueError("Graph not found in the session")
|
||||
built_object = await graph.build()
|
||||
processed_inputs = process_inputs(inputs, artifacts or {})
|
||||
result = generate_result(built_object, processed_inputs)
|
||||
result = await generate_result(built_object, processed_inputs)
|
||||
# langchain_object is now updated with the new memory
|
||||
# we need to update the cache with the updated langchain_object
|
||||
session_service.update_session(session_id, (graph, artifacts))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue