🔧 fix(process.py): refactor process_graph_cached function for better readability and maintainability

 feat(process.py): add clear_caches_if_needed function to clear cache if clear_cache flag is set
 feat(process.py): add load_langchain_object function to load langchain_object and artifacts from data_graph
 feat(process.py): add process_inputs function to process inputs and add artifacts to inputs
 feat(process.py): add generate_result function to generate result and thought based on langchain_object and inputs
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-08-16 21:29:17 -03:00
commit 199ea3b1b8

View file

@ -101,45 +101,39 @@ def get_build_result(data_graph, session_id):
return build_sorted_vertices_with_caching(data_graph)
def process_graph_cached(
data_graph: Dict[str, Any],
inputs: Optional[dict] = None,
clear_cache=False,
session_id=None,
):
"""
Process graph by extracting input variables and replacing ZeroShotPrompt
with PromptTemplate,then run the graph and return the result and thought.
"""
# Load langchain object
def clear_caches_if_needed(clear_cache: bool):
if clear_cache:
build_sorted_vertices_with_caching.clear_cache()
logger.debug("Cleared cache")
def load_langchain_object(
data_graph: Dict[str, Any], session_id: str
) -> Tuple[Union[Chain, VectorStore], Dict[str, Any]]:
langchain_object, artifacts = get_build_result(data_graph, session_id)
session_id = build_sorted_vertices_with_caching.hash
logger.debug("Loaded LangChain object")
if inputs is None:
inputs = {}
# Add artifacts to inputs
# artifacts can be documents loaded when building
# the flow
for (
key,
value,
) in artifacts.items():
if key not in inputs or not inputs[key]:
inputs[key] = value
if langchain_object is None:
# Raise user facing error
raise ValueError(
"There was an error loading the langchain_object. Please, check all the nodes and try again."
)
# Generate result and thought
return langchain_object, artifacts, session_id
def process_inputs(inputs: Optional[dict], artifacts: Dict[str, Any]) -> dict:
if inputs is None:
inputs = {}
for key, value in artifacts.items():
if key not in inputs or not inputs[key]:
inputs[key] = value
return inputs
def generate_result(langchain_object: Union[Chain, VectorStore], inputs: dict):
if isinstance(langchain_object, Chain):
if inputs is None:
raise ValueError("Inputs must be provided for a Chain")
@ -152,6 +146,25 @@ def process_graph_cached(
raise ValueError(
f"Unknown langchain_object type: {type(langchain_object).__name__}"
)
return result
def process_graph_cached(
data_graph: Dict[str, Any],
inputs: Optional[dict] = None,
clear_cache=False,
session_id=None,
) -> Tuple[Any, str]:
clear_caches_if_needed(clear_cache)
# If session_id is provided, load the langchain_object from the session
# else build the graph and return the result and the new session_id
langchain_object, artifacts, session_id = load_langchain_object(
data_graph, session_id
)
processed_inputs = process_inputs(inputs, artifacts)
result = generate_result(langchain_object, processed_inputs)
return result, session_id