Update process.py with import statements and formatting improvements

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-27 13:44:38 -03:00
commit 5a39af29a3

View file

@ -7,6 +7,9 @@ from langchain.schema import AgentAction, Document
from langchain_community.vectorstores import VectorStore
from langchain_core.messages import AIMessage
from langchain_core.runnables.base import Runnable
from loguru import logger
from pydantic import BaseModel
from langflow.graph.graph.base import Graph
from langflow.graph.vertex.base import Vertex
from langflow.interface.custom.custom_component import CustomComponent
@ -17,8 +20,6 @@ from langflow.interface.run import (
)
from langflow.services.deps import get_session_service
from langflow.services.session.service import SessionService
from loguru import logger
from pydantic import BaseModel
def fix_memory_inputs(langchain_object):
@ -146,7 +147,9 @@ async def process_runnable(runnable: Runnable, inputs: Union[dict, List[dict]]):
elif isinstance(inputs, dict) and hasattr(runnable, "ainvoke"):
result = await runnable.ainvoke(inputs)
else:
raise ValueError(f"Runnable {runnable} does not support inputs of type {type(inputs)}")
raise ValueError(
f"Runnable {runnable} does not support inputs of type {type(inputs)}"
)
# Check if the result is a list of AIMessages
if isinstance(result, list) and all(isinstance(r, AIMessage) for r in result):
result = [r.content for r in result]
@ -155,7 +158,9 @@ async def process_runnable(runnable: Runnable, inputs: Union[dict, List[dict]]):
return result
async def process_inputs_dict(built_object: Union[Chain, VectorStore, Runnable], inputs: dict):
async def process_inputs_dict(
built_object: Union[Chain, VectorStore, Runnable], inputs: dict
):
if isinstance(built_object, Chain):
if inputs is None:
raise ValueError("Inputs must be provided for a Chain")
@ -190,7 +195,9 @@ async def process_inputs_list(built_object: Runnable, inputs: List[dict]):
return await process_runnable(built_object, inputs)
async def generate_result(built_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]]):
async def generate_result(
built_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]]
):
if isinstance(inputs, dict):
result = await process_inputs_dict(built_object, inputs)
elif isinstance(inputs, List) and isinstance(built_object, Runnable):
@ -222,7 +229,9 @@ async def process_graph_cached(
if clear_cache:
session_service.clear_session(session_id)
if session_id is None:
session_id = session_service.generate_key(session_id=session_id, data_graph=data_graph)
session_id = session_service.generate_key(
session_id=session_id, data_graph=data_graph
)
# Load the graph using SessionService
session = await session_service.load_session(session_id, data_graph)
graph, artifacts = session if session else (None, None)
@ -258,14 +267,34 @@ async def build_graph_and_generate_result(
return Result(result=result, session_id=session_id)
def validate_input(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
async def run_graph(
    graph: Union["Graph", dict],
    session_id: str,
    inputs: Optional[Union[dict, List[dict]]] = None,
    artifacts: Optional[Dict[str, Any]] = None,
    session_service: Optional[SessionService] = None,
):
    """Run the graph and generate the result"""
    # A raw payload dict is first materialized into a Graph instance;
    # an already-built Graph is used as-is.
    built_graph = Graph.from_payload(graph) if isinstance(graph, dict) else graph
    run_outputs = await built_graph.run(inputs)
    # Persist the (graph, artifacts) pair only when both a session id and a
    # session service are available.
    if session_id and session_service:
        session_service.update_session(session_id, (built_graph, artifacts))
    return run_outputs
def validate_input(
    graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Validate a graph payload and return its list of nodes.

    Raises:
        ValueError: if either argument is not a dict, or if no list of nodes
            is found under ``data.nodes`` or top-level ``nodes``.
    """
    if not (isinstance(graph_data, dict) and isinstance(tweaks, dict)):
        raise ValueError("graph_data and tweaks should be dictionaries")

    # The node list may live under "data" -> "nodes" or directly at top level.
    node_list = graph_data.get("data", {}).get("nodes") or graph_data.get("nodes")
    if isinstance(node_list, list):
        return node_list

    raise ValueError(
        "graph_data should contain a list of nodes under 'data' key or directly under 'nodes' key"
    )
@ -274,7 +303,9 @@ def apply_tweaks(node: Dict[str, Any], node_tweaks: Dict[str, Any]) -> None:
template_data = node.get("data", {}).get("node", {}).get("template")
if not isinstance(template_data, dict):
logger.warning(f"Template data for node {node.get('id')} should be a dictionary")
logger.warning(
f"Template data for node {node.get('id')} should be a dictionary"
)
return
for tweak_name, tweak_value in node_tweaks.items():
@ -289,7 +320,9 @@ def apply_tweaks_on_vertex(vertex: Vertex, node_tweaks: Dict[str, Any]) -> None:
vertex.params[tweak_name] = tweak_value
def process_tweaks(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
def process_tweaks(
graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]
) -> Dict[str, Any]:
"""
This function is used to tweak the graph data using the node id and the tweaks dict.
@ -310,7 +343,9 @@ def process_tweaks(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]
if node_tweaks := tweaks.get(node_id):
apply_tweaks(node, node_tweaks)
else:
logger.warning("Each node should be a dictionary with an 'id' key of type str")
logger.warning(
"Each node should be a dictionary with an 'id' key of type str"
)
return graph_data
@ -322,6 +357,8 @@ def process_tweaks_on_graph(graph: Graph, tweaks: Dict[str, Dict[str, Any]]):
if node_tweaks := tweaks.get(node_id):
apply_tweaks_on_vertex(vertex, node_tweaks)
else:
logger.warning("Each node should be a Vertex with an 'id' attribute of type str")
logger.warning(
"Each node should be a Vertex with an 'id' attribute of type str"
)
return graph