Refactor process.py: Format code and add type hints
This commit is contained in:
parent
985cf9e532
commit
0fb0b34423
1 changed files with 42 additions and 13 deletions
|
|
@ -10,7 +10,11 @@ from langchain_core.runnables.base import Runnable
|
|||
from langflow.graph.graph.base import Graph
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
from langflow.interface.run import build_sorted_vertices, get_memory_key, update_memory_keys
|
||||
from langflow.interface.run import (
|
||||
build_sorted_vertices,
|
||||
get_memory_key,
|
||||
update_memory_keys,
|
||||
)
|
||||
from langflow.services.deps import get_session_service
|
||||
from langflow.services.session.service import SessionService
|
||||
from loguru import logger
|
||||
|
|
@ -110,7 +114,8 @@ def get_build_result(data_graph, session_id):
|
|||
|
||||
|
||||
def process_inputs(
|
||||
inputs: Optional[Union[dict, List[dict]]] = None, artifacts: Optional[Dict[str, Any]] = None
|
||||
inputs: Optional[Union[dict, List[dict]]] = None,
|
||||
artifacts: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[dict, List[dict]]:
|
||||
if inputs is None:
|
||||
inputs = {}
|
||||
|
|
@ -141,7 +146,9 @@ async def process_runnable(runnable: Runnable, inputs: Union[dict, List[dict]]):
|
|||
elif isinstance(inputs, dict) and hasattr(runnable, "ainvoke"):
|
||||
result = await runnable.ainvoke(inputs)
|
||||
else:
|
||||
raise ValueError(f"Runnable {runnable} does not support inputs of type {type(inputs)}")
|
||||
raise ValueError(
|
||||
f"Runnable {runnable} does not support inputs of type {type(inputs)}"
|
||||
)
|
||||
# Check if the result is a list of AIMessages
|
||||
if isinstance(result, list) and all(isinstance(r, AIMessage) for r in result):
|
||||
result = [r.content for r in result]
|
||||
|
|
@ -150,7 +157,9 @@ async def process_runnable(runnable: Runnable, inputs: Union[dict, List[dict]]):
|
|||
return result
|
||||
|
||||
|
||||
async def process_inputs_dict(built_object: Union[Chain, VectorStore, Runnable], inputs: dict):
|
||||
async def process_inputs_dict(
|
||||
built_object: Union[Chain, VectorStore, Runnable], inputs: dict
|
||||
):
|
||||
if isinstance(built_object, Chain):
|
||||
if inputs is None:
|
||||
raise ValueError("Inputs must be provided for a Chain")
|
||||
|
|
@ -185,7 +194,9 @@ async def process_inputs_list(built_object: Runnable, inputs: List[dict]):
|
|||
return await process_runnable(built_object, inputs)
|
||||
|
||||
|
||||
async def generate_result(built_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]]):
|
||||
async def generate_result(
|
||||
built_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]]
|
||||
):
|
||||
if isinstance(inputs, dict):
|
||||
result = await process_inputs_dict(built_object, inputs)
|
||||
elif isinstance(inputs, List) and isinstance(built_object, Runnable):
|
||||
|
|
@ -217,7 +228,9 @@ async def process_graph_cached(
|
|||
if clear_cache:
|
||||
session_service.clear_session(session_id)
|
||||
if session_id is None:
|
||||
session_id = session_service.generate_key(session_id=session_id, data_graph=data_graph)
|
||||
session_id = session_service.generate_key(
|
||||
session_id=session_id, data_graph=data_graph
|
||||
)
|
||||
# Load the graph using SessionService
|
||||
session = await session_service.load_session(session_id, data_graph)
|
||||
graph, artifacts = session if session else (None, None)
|
||||
|
|
@ -225,7 +238,11 @@ async def process_graph_cached(
|
|||
raise ValueError("Graph not found in the session")
|
||||
|
||||
result = await build_graph_and_generate_result(
|
||||
graph=graph, session_id=session_id, inputs=inputs, artifacts=artifacts, session_service=session_service
|
||||
graph=graph,
|
||||
session_id=session_id,
|
||||
inputs=inputs,
|
||||
artifacts=artifacts,
|
||||
session_service=session_service,
|
||||
)
|
||||
|
||||
return result
|
||||
|
|
@ -249,14 +266,18 @@ async def build_graph_and_generate_result(
|
|||
return Result(result=result, session_id=session_id)
|
||||
|
||||
|
||||
def validate_input(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
def validate_input(
|
||||
graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
if not isinstance(graph_data, dict) or not isinstance(tweaks, dict):
|
||||
raise ValueError("graph_data and tweaks should be dictionaries")
|
||||
|
||||
nodes = graph_data.get("data", {}).get("nodes") or graph_data.get("nodes")
|
||||
|
||||
if not isinstance(nodes, list):
|
||||
raise ValueError("graph_data should contain a list of nodes under 'data' key or directly under 'nodes' key")
|
||||
raise ValueError(
|
||||
"graph_data should contain a list of nodes under 'data' key or directly under 'nodes' key"
|
||||
)
|
||||
|
||||
return nodes
|
||||
|
||||
|
|
@ -265,7 +286,9 @@ def apply_tweaks(node: Dict[str, Any], node_tweaks: Dict[str, Any]) -> None:
|
|||
template_data = node.get("data", {}).get("node", {}).get("template")
|
||||
|
||||
if not isinstance(template_data, dict):
|
||||
logger.warning(f"Template data for node {node.get('id')} should be a dictionary")
|
||||
logger.warning(
|
||||
f"Template data for node {node.get('id')} should be a dictionary"
|
||||
)
|
||||
return
|
||||
|
||||
for tweak_name, tweak_value in node_tweaks.items():
|
||||
|
|
@ -280,7 +303,9 @@ def apply_tweaks_on_vertex(vertex: Vertex, node_tweaks: Dict[str, Any]) -> None:
|
|||
vertex.params[tweak_name] = tweak_value
|
||||
|
||||
|
||||
def process_tweaks(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
|
||||
def process_tweaks(
|
||||
graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
This function is used to tweak the graph data using the node id and the tweaks dict.
|
||||
|
||||
|
|
@ -301,7 +326,9 @@ def process_tweaks(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]
|
|||
if node_tweaks := tweaks.get(node_id):
|
||||
apply_tweaks(node, node_tweaks)
|
||||
else:
|
||||
logger.warning("Each node should be a dictionary with an 'id' key of type str")
|
||||
logger.warning(
|
||||
"Each node should be a dictionary with an 'id' key of type str"
|
||||
)
|
||||
|
||||
return graph_data
|
||||
|
||||
|
|
@ -313,6 +340,8 @@ def process_tweaks_on_graph(graph: Graph, tweaks: Dict[str, Dict[str, Any]]):
|
|||
if node_tweaks := tweaks.get(node_id):
|
||||
apply_tweaks_on_vertex(vertex, node_tweaks)
|
||||
else:
|
||||
logger.warning("Each node should be a Vertex with an 'id' attribute of type str")
|
||||
logger.warning(
|
||||
"Each node should be a Vertex with an 'id' attribute of type str"
|
||||
)
|
||||
|
||||
return graph
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue