diff --git a/tests/test_graph.py b/tests/test_graph.py index f3efe3614..f88d89a1b 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,17 +1,22 @@ +import json import os from pathlib import Path +import pickle from typing import Type, Union +from langflow import graph from langflow.graph.edge.base import Edge from langflow.graph.vertex.base import Vertex - +from langchain.agents import AgentExecutor import pytest from langchain.chains.base import Chain from langchain.llms.fake import FakeListLLM from langflow.graph import Graph from langflow.graph.vertex.types import ( + AgentVertex, FileToolVertex, LLMVertex, ToolkitVertex, + VectorStoreVertex, ) from langflow.processing.process import get_result_and_thought from langflow.utils.payload import get_root_node @@ -185,7 +190,7 @@ def test_build_edges(basic_graph): assert isinstance(edge.target, Vertex) -def test_get_root_node(basic_graph, complex_graph): +def test_get_root_node(client, basic_graph, complex_graph): """Test getting root node""" assert isinstance(basic_graph, Graph) root = get_root_node(basic_graph) @@ -261,7 +266,7 @@ def test_llm_node_build(basic_graph): assert built_object is not None -def test_toolkit_node_build(openapi_graph): +def test_toolkit_node_build(client, openapi_graph): # Write a file to the disk file_path = "api-with-examples.yaml" with open(file_path, "w") as f: @@ -276,7 +281,7 @@ def test_toolkit_node_build(openapi_graph): assert not Path(file_path).exists() -def test_file_tool_node_build(openapi_graph): +def test_file_tool_node_build(client, openapi_graph): file_path = "api-with-examples.yaml" with open(file_path, "w") as f: f.write("openapi: 3.0.0") @@ -318,3 +323,29 @@ def test_get_result_and_thought(basic_graph): # Get the result and thought result = get_result_and_thought(langchain_object, message) assert isinstance(result, dict) + + +def test_pickle_graph(json_vector_store): + loaded_json = json.loads(json_vector_store) + graph = Graph.from_payload(loaded_json) + assert isinstance(graph, Graph) + first_result = graph.build() + assert isinstance(first_result, AgentExecutor) + pickled = pickle.dumps(graph) + assert pickled is not None + unpickled = pickle.loads(pickled) + assert unpickled is not None + result = unpickled.build() + assert isinstance(result, AgentExecutor) + + +def test_pickle_each_vertex(json_vector_store): + loaded_json = json.loads(json_vector_store) + graph = Graph.from_payload(loaded_json) + assert isinstance(graph, Graph) + for vertex in graph.nodes: + vertex.build() + pickled = pickle.dumps(vertex) + assert pickled is not None + unpickled = pickle.loads(pickled) + assert unpickled is not None diff --git a/tests/test_process.py b/tests/test_process.py index bb0147616..775d17145 100644 --- a/tests/test_process.py +++ b/tests/test_process.py @@ -201,15 +201,11 @@ def test_load_langchain_object_with_cached_session(client, basic_graph_data): # Provide a non-existent session_id session_service = get_session_service() session_id1 = "non-existent-session-id" - langchain_object1, artifacts1 = session_service.load_session( - session_id1, basic_graph_data - ) + graph1, artifacts1 = session_service.load_session(session_id1, basic_graph_data) # Use the new session_id to get the langchain_object again - langchain_object2, artifacts2 = session_service.load_session( - session_id1, basic_graph_data - ) + graph2, artifacts2 = session_service.load_session(session_id1, basic_graph_data) - assert id(langchain_object1) == id(langchain_object2) + assert graph1 == graph2 assert artifacts1 == artifacts2 @@ -218,31 +214,22 @@ def test_load_langchain_object_with_no_cached_session(client, basic_graph_data): session_service = get_session_service() session_id1 = "non-existent-session-id" session_id = session_service.build_key(session_id1, basic_graph_data) - langchain_object1, artifacts1 = session_service.load_session( - session_id, basic_graph_data - ) + graph1, artifacts1 = session_service.load_session(session_id, basic_graph_data) # Clear the cache session_service.clear_session(session_id) # Use the new session_id to get the langchain_object again - langchain_object2, artifacts2 = session_service.load_session( - session_id, basic_graph_data - ) + graph2, artifacts2 = session_service.load_session(session_id, basic_graph_data) - assert id(langchain_object1) != id( - langchain_object2 - ) # Since the cache was cleared, objects should be different + assert id(graph1) != id(graph2) + # Since the cache was cleared, objects should be different def test_load_langchain_object_without_session_id(client, basic_graph_data): # Provide a non-existent session_id session_service = get_session_service() session_id1 = None - langchain_object1, artifacts1 = session_service.load_session( - session_id1, basic_graph_data - ) + graph1, artifacts1 = session_service.load_session(session_id1, basic_graph_data) # Use the new session_id to get the langchain_object again - langchain_object2, artifacts2 = session_service.load_session( - session_id1, basic_graph_data - ) + graph2, artifacts2 = session_service.load_session(session_id1, basic_graph_data) - assert id(langchain_object1) == id(langchain_object2) + assert graph1 == graph2