diff --git a/tests/conftest.py b/tests/conftest.py index af9bbe169..500419fe9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,48 +1,75 @@ +from contextlib import contextmanager import json - -# we need to import tmpdir -import tempfile -from contextlib import contextmanager, suppress +from contextlib import suppress from pathlib import Path -from typing import TYPE_CHECKING, AsyncGenerator - -import orjson -import pytest -from fastapi.testclient import TestClient -from httpx import AsyncClient -from sqlmodel import Session, SQLModel, create_engine -from sqlmodel.pool import StaticPool -from typer.testing import CliRunner +from typing import AsyncGenerator, TYPE_CHECKING from langflow.graph.graph.base import Graph from langflow.services.auth.utils import get_password_hash -from langflow.services.database.models.flow.model import Flow, FlowCreate -from langflow.services.database.models.user.model import User, UserCreate +from langflow.services.database.models.flow.flow import Flow, FlowCreate +from langflow.services.database.models.user.user import User, UserCreate +import orjson from langflow.services.database.utils import session_getter -from langflow.services.deps import get_db_service +from langflow.services.getters import get_db_service +import pytest +from fastapi.testclient import TestClient +from httpx import AsyncClient +from sqlmodel import SQLModel, Session, create_engine +from sqlmodel.pool import StaticPool +from typer.testing import CliRunner + +# we need to import tmpdir +import tempfile if TYPE_CHECKING: - from langflow.services.database.service import DatabaseService + from langflow.services.database.manager import DatabaseService def pytest_configure(): - pytest.BASIC_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "basic_example.json" - pytest.COMPLEX_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "complex_example.json" - pytest.OPENAPI_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "Openapi.json" - 
pytest.GROUPED_CHAT_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "grouped_chat.json" - pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "one_group_chat.json" - pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH = Path(__file__).parent.absolute() / "data" / "vector_store_grouped.json" + pytest.BASIC_EXAMPLE_PATH = ( + Path(__file__).parent.absolute() / "data" / "basic_example.json" + ) + pytest.COMPLEX_EXAMPLE_PATH = ( + Path(__file__).parent.absolute() / "data" / "complex_example.json" + ) + pytest.OPENAPI_EXAMPLE_PATH = ( + Path(__file__).parent.absolute() / "data" / "Openapi.json" + ) + pytest.GROUPED_CHAT_EXAMPLE_PATH = ( + Path(__file__).parent.absolute() / "data" / "grouped_chat.json" + ) + pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH = ( + Path(__file__).parent.absolute() / "data" / "one_group_chat.json" + ) + pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH = ( + Path(__file__).parent.absolute() / "data" / "vector_store_grouped.json" + ) pytest.BASIC_CHAT_WITH_PROMPT_AND_HISTORY = ( - Path(__file__).parent.absolute() / "data" / "BasicChatwithPromptandHistory.json" + Path(__file__).parent.absolute() / "data" / "BasicChatWithPromptAndHistory.json" + ) + pytest.CHAT_INPUT = Path(__file__).parent.absolute() / "data" / "ChatInputTest.json" + pytest.TWO_OUTPUTS = ( + Path(__file__).parent.absolute() / "data" / "TwoOutputsTest.json" + ) + pytest.VECTOR_STORE_PATH = ( + Path(__file__).parent.absolute() / "data" / "Vector_store.json" ) - pytest.VECTOR_STORE_PATH = Path(__file__).parent.absolute() / "data" / "Vector_store.json" pytest.CODE_WITH_SYNTAX_ERROR = """ def get_text(): retun "Hello World" """ +@pytest.fixture(autouse=True) +def check_openai_api_key_in_environment_variables(): + import os + + assert ( + os.environ.get("OPENAI_API_KEY") is not None + ), "OPENAI_API_KEY is not set in environment variables" + + @pytest.fixture() async def async_client() -> AsyncGenerator: from langflow.main import create_app @@ -54,7 +81,9 @@ async def 
async_client() -> AsyncGenerator: @pytest.fixture(name="session") def session_fixture(): - engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) + engine = create_engine( + "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool + ) SQLModel.metadata.create_all(engine) with Session(engine) as session: yield session @@ -90,7 +119,9 @@ def distributed_client_fixture(session: Session, monkeypatch, distributed_env): monkeypatch.setenv("LANGFLOW_AUTO_LOGIN", "false") # monkeypatch langflow.services.task.manager.USE_CELERY to True # monkeypatch.setattr(manager, "USE_CELERY", True) - monkeypatch.setattr(celery_app, "celery_app", celery_app.make_celery("langflow", Config)) + monkeypatch.setattr( + celery_app, "celery_app", celery_app.make_celery("langflow", Config) + ) # def get_session_override(): # return session @@ -225,7 +256,7 @@ def test_user(client): username="testuser", password="testpassword", ) - response = client.post("/api/v1/users", json=user_data.model_dump()) + response = client.post("/api/v1/users", json=user_data.dict()) assert response.status_code == 201 return response.json() @@ -241,7 +272,11 @@ def active_user(client): is_superuser=False, ) # check if user exists - if active_user := session.query(User).filter(User.username == user.username).first(): + if ( + active_user := session.query(User) + .filter(User.username == user.username) + .first() + ): return active_user session.add(user) session.commit() @@ -261,16 +296,13 @@ def logged_in_headers(client, active_user): @pytest.fixture def flow(client, json_flow: str, active_user): - from langflow.services.database.models.flow.model import FlowCreate + from langflow.services.database.models.flow.flow import FlowCreate loaded_json = json.loads(json_flow) flow_data = FlowCreate( - name="test_flow", - data=loaded_json.get("data"), - user_id=active_user.id, - description="description", + name="test_flow", data=loaded_json.get("data"), 
user_id=active_user.id ) - flow = Flow.model_validate(flow_data.model_dump()) + flow = Flow(**flow_data.dict()) with session_getter(get_db_service()) as session: session.add(flow) session.commit() @@ -280,11 +312,31 @@ def flow(client, json_flow: str, active_user): @pytest.fixture -def added_flow(client, json_flow_with_prompt_and_history, logged_in_headers): +def json_flow_with_prompt_and_history(): + with open(pytest.BASIC_CHAT_WITH_PROMPT_AND_HISTORY, "r") as f: + return f.read() + + +@pytest.fixture +def json_chat_input(): + with open(pytest.CHAT_INPUT, "r") as f: + return f.read() + + +@pytest.fixture +def json_two_outputs(): + with open(pytest.TWO_OUTPUTS, "r") as f: + return f.read() + + +@pytest.fixture +def added_flow_with_prompt_and_history( + client, json_flow_with_prompt_and_history, logged_in_headers +): flow = orjson.loads(json_flow_with_prompt_and_history) data = flow["data"] flow = FlowCreate(name="Basic Chat", description="description", data=data) - response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers) + response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers) assert response.status_code == 201 assert response.json()["name"] == flow.name assert response.json()["data"] == flow.data @@ -292,28 +344,37 @@ def added_flow(client, json_flow_with_prompt_and_history, logged_in_headers): @pytest.fixture -def added_vector_store(client, json_vector_store, logged_in_headers): - vector_store = orjson.loads(json_vector_store) - data = vector_store["data"] - vector_store = FlowCreate(name="Vector Store", description="description", data=data) - response = client.post("api/v1/flows/", json=vector_store.model_dump(), headers=logged_in_headers) +def added_flow_chat_input(client, json_chat_input, logged_in_headers): + flow = orjson.loads(json_chat_input) + data = flow["data"] + flow = FlowCreate(name="Chat Input", description="description", data=data) + response = client.post("api/v1/flows/", 
json=flow.dict(), headers=logged_in_headers) assert response.status_code == 201 - assert response.json()["name"] == vector_store.name - assert response.json()["data"] == vector_store.data + assert response.json()["name"] == flow.name + assert response.json()["data"] == flow.data return response.json() @pytest.fixture -def test_component_code(): - path = Path(__file__).parent.absolute() / "data" / "component.py" - # load the content as a string - with open(path, "r") as f: - return f.read() +def added_flow_two_outputs(client, json_two_outputs, logged_in_headers): + flow = orjson.loads(json_two_outputs) + data = flow["data"] + flow = FlowCreate(name="Two Outputs", description="description", data=data) + response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers) + assert response.status_code == 201 + assert response.json()["name"] == flow.name + assert response.json()["data"] == flow.data @pytest.fixture -def test_component_with_templatefield_code(): - path = Path(__file__).parent.absolute() / "data" / "component_with_templatefield.py" - # load the content as a string - with open(path, "r") as f: - return f.read() +def added_vector_store(client, json_vector_store, logged_in_headers): + vector_store = orjson.loads(json_vector_store) + data = vector_store["data"] + vector_store = FlowCreate(name="Vector Store", description="description", data=data) + response = client.post( + "api/v1/flows/", json=vector_store.dict(), headers=logged_in_headers + ) + assert response.status_code == 201 + assert response.json()["name"] == vector_store.name + assert response.json()["data"] == vector_store.data + return response.json() diff --git a/tests/data/ChatInputTest.json b/tests/data/ChatInputTest.json new file mode 100644 index 000000000..24771bab9 --- /dev/null +++ b/tests/data/ChatInputTest.json @@ -0,0 +1 @@ 
+{"name":"ChatInputTest","description":"","data":{"nodes":[{"width":384,"height":359,"id":"PromptTemplate-IKKOx","type":"genericNode","position":{"x":880,"y":646.9375},"data":{"type":"PromptTemplate","node":{"template":{"output_parser":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"output_parser","advanced":false,"dynamic":false,"info":"","type":"BaseOutputParser","list":false},"input_variables":{"required":true,"placeholder":"","show":false,"multiline":false,"password":false,"name":"input_variables","advanced":false,"dynamic":false,"info":"","type":"str","list":true,"value":["input"]},"partial_variables":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"partial_variables","advanced":false,"dynamic":false,"info":"","type":"code","list":false},"template":{"required":true,"placeholder":"","show":true,"multiline":true,"password":false,"name":"template","advanced":false,"dynamic":false,"info":"","type":"prompt","list":false,"value":"Input: {input}\nAI:"},"template_format":{"required":false,"placeholder":"","show":false,"multiline":false,"value":"f-string","password":false,"name":"template_format","advanced":false,"dynamic":false,"info":"","type":"str","list":false},"validate_template":{"required":false,"placeholder":"","show":false,"multiline":false,"value":true,"password":false,"name":"validate_template","advanced":false,"dynamic":false,"info":"","type":"bool","list":false},"_type":"PromptTemplate","input":{"required":false,"placeholder":"","show":true,"multiline":true,"value":"","password":false,"name":"input","display_name":"input","advanced":false,"input_types":["Document","BaseOutputParser","str"],"dynamic":false,"info":"","type":"str","list":false}},"description":"A prompt template for a language 
model.","base_classes":["BasePromptTemplate","PromptTemplate","StringPromptTemplate"],"name":"","display_name":"PromptTemplate","documentation":"https://python.langchain.com/docs/modules/model_io/prompts/prompt_templates/","custom_fields":{"":["input"],"template":["input"]},"output_types":[],"field_formatters":{"formatters":{"openai_api_key":{}},"base_formatters":{"kwargs":{},"optional":{},"list":{},"dict":{},"union":{},"multiline":{},"show":{},"password":{},"default":{},"headers":{},"dict_code_file":{},"model_fields":{"MODEL_DICT":{"OpenAI":["text-davinci-003","text-davinci-002","text-curie-001","text-babbage-001","text-ada-001"],"ChatOpenAI":["gpt-3.5-turbo-0613","gpt-3.5-turbo","gpt-3.5-turbo-16k-0613","gpt-3.5-turbo-16k","gpt-4-0613","gpt-4-32k-0613","gpt-4","gpt-4-32k"],"Anthropic":["claude-v1","claude-v1-100k","claude-instant-v1","claude-instant-v1-100k","claude-v1.3","claude-v1.3-100k","claude-v1.2","claude-v1.0","claude-instant-v1.1","claude-instant-v1.1-100k","claude-instant-v1.0"],"ChatAnthropic":["claude-v1","claude-v1-100k","claude-instant-v1","claude-instant-v1-100k","claude-v1.3","claude-v1.3-100k","claude-v1.2","claude-v1.0","claude-instant-v1.1","claude-instant-v1.1-100k","claude-instant-v1.0"]}}}},"beta":false,"error":null},"id":"PromptTemplate-IKKOx"},"selected":false,"positionAbsolute":{"x":880,"y":646.9375},"dragging":false},{"width":384,"height":307,"id":"LLMChain-e2dhN","type":"genericNode","position":{"x":1449.330344958542,"y":880.1760221487797},"data":{"type":"LLMChain","node":{"template":{"callbacks":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"callbacks","advanced":false,"dynamic":false,"info":"","type":"langchain.callbacks.base.BaseCallbackHandler","list":true},"llm":{"required":true,"placeholder":"","show":true,"multiline":false,"password":false,"name":"llm","advanced":false,"dynamic":false,"info":"","type":"BaseLanguageModel","list":false},"memory":{"required":false,"placeholder":"","show":tr
ue,"multiline":false,"password":false,"name":"memory","advanced":false,"dynamic":false,"info":"","type":"BaseMemory","list":false},"output_parser":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"output_parser","advanced":false,"dynamic":false,"info":"","type":"BaseLLMOutputParser","list":false},"prompt":{"required":true,"placeholder":"","show":true,"multiline":false,"password":false,"name":"prompt","advanced":false,"dynamic":false,"info":"","type":"BasePromptTemplate","list":false},"llm_kwargs":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"llm_kwargs","advanced":false,"dynamic":false,"info":"","type":"code","list":false},"metadata":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"metadata","advanced":false,"dynamic":false,"info":"","type":"code","list":false},"output_key":{"required":true,"placeholder":"","show":true,"multiline":false,"value":"text","password":false,"name":"output_key","advanced":true,"dynamic":false,"info":"","type":"str","list":false},"return_final_only":{"required":false,"placeholder":"","show":false,"multiline":false,"value":true,"password":false,"name":"return_final_only","advanced":false,"dynamic":false,"info":"","type":"bool","list":false},"tags":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"tags","advanced":false,"dynamic":false,"info":"","type":"str","list":true},"verbose":{"required":false,"placeholder":"","show":false,"multiline":false,"value":false,"password":false,"name":"verbose","advanced":true,"dynamic":false,"info":"","type":"bool","list":false},"_type":"LLMChain"},"description":"Chain to run queries against 
LLMs.","base_classes":["Chain","LLMChain","function","Text"],"display_name":"LLMChain","custom_fields":{},"output_types":[],"documentation":"https://python.langchain.com/docs/modules/chains/foundational/llm_chain","beta":false,"error":null},"id":"LLMChain-e2dhN"},"positionAbsolute":{"x":1449.330344958542,"y":880.1760221487797}},{"width":384,"height":621,"id":"ChatOpenAI-2I57f","type":"genericNode","position":{"x":393.3551923753797,"y":1061.025177453298},"data":{"type":"ChatOpenAI","node":{"template":{"callbacks":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"callbacks","advanced":false,"dynamic":false,"info":"","type":"langchain.callbacks.base.BaseCallbackHandler","list":true},"cache":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"cache","advanced":false,"dynamic":false,"info":"","type":"bool","list":false},"client":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"client","advanced":false,"dynamic":false,"info":"","type":"Any","list":false},"max_retries":{"required":false,"placeholder":"","show":false,"multiline":false,"value":6,"password":false,"name":"max_retries","advanced":false,"dynamic":false,"info":"","type":"int","list":false},"max_tokens":{"required":false,"placeholder":"","show":true,"multiline":false,"password":true,"name":"max_tokens","advanced":false,"dynamic":false,"info":"","type":"int","list":false,"value":""},"metadata":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"metadata","advanced":false,"dynamic":false,"info":"","type":"code","list":false},"model_kwargs":{"required":false,"placeholder":"","show":true,"multiline":false,"password":false,"name":"model_kwargs","advanced":true,"dynamic":false,"info":"","type":"code","list":false},"model_name":{"required":false,"placeholder":"","show":true,"multiline":false,"value":"gpt-3.5-turbo-0613","password":false,"options":["gpt-3.5-turbo-
0613","gpt-3.5-turbo","gpt-3.5-turbo-16k-0613","gpt-3.5-turbo-16k","gpt-4-0613","gpt-4-32k-0613","gpt-4","gpt-4-32k"],"name":"model_name","advanced":false,"dynamic":false,"info":"","type":"str","list":true},"n":{"required":false,"placeholder":"","show":false,"multiline":false,"value":1,"password":false,"name":"n","advanced":false,"dynamic":false,"info":"","type":"int","list":false},"openai_api_base":{"required":false,"placeholder":"","show":true,"multiline":false,"password":false,"name":"openai_api_base","display_name":"OpenAI API Base","advanced":false,"dynamic":false,"info":"\nThe base URL of the OpenAI API. Defaults to https://api.openai.com/v1.\n\nYou can change this to use other APIs like JinaChat, LocalAI and Prem.\n","type":"str","list":false},"openai_api_key":{"required":false,"placeholder":"","show":true,"multiline":false,"value":"","password":true,"name":"openai_api_key","display_name":"OpenAI API Key","advanced":false,"dynamic":false,"info":"","type":"str","list":false},"openai_organization":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"openai_organization","display_name":"OpenAI Organization","advanced":false,"dynamic":false,"info":"","type":"str","list":false},"openai_proxy":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"openai_proxy","display_name":"OpenAI 
Proxy","advanced":false,"dynamic":false,"info":"","type":"str","list":false},"request_timeout":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"request_timeout","advanced":false,"dynamic":false,"info":"","type":"float","list":false},"streaming":{"required":false,"placeholder":"","show":false,"multiline":false,"value":false,"password":false,"name":"streaming","advanced":false,"dynamic":false,"info":"","type":"bool","list":false},"tags":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"tags","advanced":false,"dynamic":false,"info":"","type":"str","list":true},"temperature":{"required":false,"placeholder":"","show":true,"multiline":false,"value":0.7,"password":false,"name":"temperature","advanced":false,"dynamic":false,"info":"","type":"float","list":false},"tiktoken_model_name":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"tiktoken_model_name","advanced":false,"dynamic":false,"info":"","type":"str","list":false},"verbose":{"required":false,"placeholder":"","show":false,"multiline":false,"value":false,"password":false,"name":"verbose","advanced":false,"dynamic":false,"info":"","type":"bool","list":false},"_type":"ChatOpenAI"},"description":"`OpenAI` Chat large language models API.","base_classes":["BaseChatModel","ChatOpenAI","BaseLanguageModel","BaseLLM"],"display_name":"ChatOpenAI","custom_fields":{},"output_types":[],"documentation":"https://python.langchain.com/docs/modules/model_io/models/chat/integrations/openai","beta":false,"error":null},"id":"ChatOpenAI-2I57f"},"selected":false,"positionAbsolute":{"x":393.3551923753797,"y":1061.025177453298},"dragging":false},{"width":384,"height":359,"id":"ChatInput-207IY","type":"genericNode","position":{"x":415.1018926651509,"y":506.62736462360317},"data":{"type":"ChatInput","node":{"template":{"code":{"dynamic":true,"required":true,"placeholder":"","show":false,"multiline":true,"value":"from 
typing import Optional\nfrom langflow import CustomComponent\n\n\nclass ChatInput(CustomComponent):\n display_name = \"Chat Input\"\n\n def build(self, message: Optional[str] = \"\") -> str:\n return message\n","password":false,"name":"code","advanced":false,"type":"code","list":false},"_type":"CustomComponent","message":{"required":false,"placeholder":"","show":true,"multiline":false,"value":"","password":false,"name":"message","display_name":"message","advanced":false,"dynamic":false,"info":"","type":"str","list":false}},"description":"Used to get user input from the chat.","base_classes":["str"],"display_name":"Chat Input","custom_fields":{"message":null},"output_types":["ChatInput"],"documentation":"","beta":true,"error":null},"id":"ChatInput-207IY"},"positionAbsolute":{"x":415.1018926651509,"y":506.62736462360317}},{"width":384,"height":389,"id":"ChatOutput-1jlJy","type":"genericNode","position":{"x":2002.8008888732943,"y":926.1397178702218},"data":{"type":"ChatOutput","node":{"template":{"code":{"dynamic":true,"required":true,"placeholder":"","show":true,"multiline":true,"value":"from typing import Optional, Text\nfrom langflow.api.v1.schemas import ChatMessage\nfrom langflow.services.utils import get_chat_manager\nfrom langflow import CustomComponent\nfrom anyio.from_thread import start_blocking_portal\nfrom loguru import logger\n\n\nclass ChatOutput(CustomComponent):\n display_name = \"Chat Output\"\n description = \"Used to send a message to the chat.\"\n\n field_config = {\n \"code\": {\n \"show\": False,\n }\n }\n\n def build_config(self):\n return {\"message\": {\"input_types\": [\"Text\"]}}\n\n def build(self, message: Optional[Text], is_ai: bool = False) -> Text:\n if not message:\n return \"\"\n try:\n chat_manager = get_chat_manager()\n chat_message = ChatMessage(message=message, is_bot=is_ai)\n # send_message is a coroutine\n # run in a thread safe manner\n with start_blocking_portal() as portal:\n portal.call(chat_manager.send_message, 
chat_message)\n chat_manager.chat_history.add_message(\n chat_manager.cache_manager.current_client_id, chat_message\n )\n except Exception as exc:\n logger.exception(exc)\n logger.debug(f\"Error sending message to chat: {exc}\")\n self.repr_value = message\n return message\n","password":false,"name":"code","advanced":false,"type":"code","list":false},"_type":"CustomComponent","is_ai":{"required":true,"placeholder":"","show":true,"multiline":false,"value":true,"password":false,"name":"is_ai","display_name":"is_ai","advanced":false,"dynamic":false,"info":"","type":"bool","list":false},"message":{"required":false,"placeholder":"","show":true,"multiline":false,"password":false,"name":"message","display_name":"message","advanced":false,"input_types":["Text"],"dynamic":false,"info":"","type":"Text","list":false}},"description":"Used to send a message to the chat.","base_classes":["str"],"display_name":"Chat Output","custom_fields":{"is_ai":null,"message":null},"output_types":["ChatOutput"],"documentation":"","beta":true,"error":null},"id":"ChatOutput-1jlJy"},"selected":true,"dragging":false,"positionAbsolute":{"x":2002.8008888732943,"y":926.1397178702218}}],"edges":[{"source":"PromptTemplate-IKKOx","sourceHandle":"PromptTemplate|PromptTemplate-IKKOx|BasePromptTemplate|PromptTemplate|StringPromptTemplate","target":"LLMChain-e2dhN","targetHandle":"BasePromptTemplate|prompt|LLMChain-e2dhN","style":{"stroke":"#555"},"className":"","animated":false,"id":"reactflow__edge-PromptTemplate-IKKOxPromptTemplate|PromptTemplate-IKKOx|StringPromptTemplate|BasePromptTemplate|PromptTemplate-LLMChain-e2dhNBasePromptTemplate|prompt|LLMChain-e2dhN"},{"source":"ChatOpenAI-2I57f","sourceHandle":"ChatOpenAI|ChatOpenAI-2I57f|BaseChatModel|ChatOpenAI|BaseLanguageModel|BaseLLM","target":"LLMChain-e2dhN","targetHandle":"BaseLanguageModel|llm|LLMChain-e2dhN","style":{"stroke":"#555"},"className":"","animated":false,"id":"reactflow__edge-ChatOpenAI-2I57fChatOpenAI|ChatOpenAI-2I57f|BaseChatModel|ChatO
penAI|BaseLanguageModel|BaseLLM-LLMChain-e2dhNBaseLanguageModel|llm|LLMChain-e2dhN"},{"source":"ChatInput-207IY","sourceHandle":"ChatInput|ChatInput-207IY|str","target":"PromptTemplate-IKKOx","targetHandle":"Document;BaseOutputParser;str|input|PromptTemplate-IKKOx","style":{"stroke":"#555"},"className":"","animated":false,"id":"reactflow__edge-ChatInput-207IYChatInput|ChatInput-207IY|str-PromptTemplate-IKKOxDocument;BaseOutputParser;str|input|PromptTemplate-IKKOx"},{"source":"LLMChain-e2dhN","sourceHandle":"LLMChain|LLMChain-e2dhN|Chain|LLMChain|function|Text","target":"ChatOutput-1jlJy","targetHandle":"Text|message|ChatOutput-1jlJy","style":{"stroke":"#555"},"className":"stroke-foreground stroke-connection","animated":true,"id":"reactflow__edge-LLMChain-e2dhNLLMChain|LLMChain-e2dhN|Chain|LLMChain|function|Text-ChatOutput-1jlJyText|message|ChatOutput-1jlJy"}],"viewport":{"x":-141.98308184453367,"y":-104.98637616656356,"zoom":0.4788209787464315}},"id":"b3388ab9-b5dc-4447-b560-79caef40faa5","user_id":"c65bfea3-3eea-4e71-8fc4-106238eb0583"} \ No newline at end of file diff --git a/tests/data/TwoOutputsTest.json b/tests/data/TwoOutputsTest.json new file mode 100644 index 000000000..9dfb5cd43 --- /dev/null +++ b/tests/data/TwoOutputsTest.json @@ -0,0 +1 @@ 
+{"name":"TwoOutputsTest","description":"","data":{"nodes":[{"width":384,"height":359,"id":"PromptTemplate-CweKz","type":"genericNode","position":{"x":969.6448076246203,"y":528.7788853763968},"data":{"type":"PromptTemplate","node":{"template":{"output_parser":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"output_parser","advanced":false,"dynamic":false,"info":"","type":"BaseOutputParser","list":false},"input_variables":{"required":true,"placeholder":"","show":false,"multiline":false,"password":false,"name":"input_variables","advanced":false,"dynamic":false,"info":"","type":"str","list":true,"value":["input"]},"partial_variables":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"partial_variables","advanced":false,"dynamic":false,"info":"","type":"code","list":false},"template":{"required":true,"placeholder":"","show":true,"multiline":true,"password":false,"name":"template","advanced":false,"dynamic":false,"info":"","type":"prompt","list":false,"value":"Input: {input}\nAI:"},"template_format":{"required":false,"placeholder":"","show":false,"multiline":false,"value":"f-string","password":false,"name":"template_format","advanced":false,"dynamic":false,"info":"","type":"str","list":false},"validate_template":{"required":false,"placeholder":"","show":false,"multiline":false,"value":true,"password":false,"name":"validate_template","advanced":false,"dynamic":false,"info":"","type":"bool","list":false},"_type":"PromptTemplate","input":{"required":false,"placeholder":"","show":true,"multiline":true,"value":"","password":false,"name":"input","display_name":"input","advanced":false,"input_types":["Document","BaseOutputParser","str"],"dynamic":false,"info":"","type":"str","list":false}},"description":"A prompt template for a language 
model.","base_classes":["BasePromptTemplate","StringPromptTemplate","PromptTemplate"],"name":"","display_name":"PromptTemplate","documentation":"https://python.langchain.com/docs/modules/model_io/prompts/prompt_templates/","custom_fields":{"":["input"],"template":["input"]},"output_types":[],"field_formatters":{"formatters":{"openai_api_key":{}},"base_formatters":{"kwargs":{},"optional":{},"list":{},"dict":{},"union":{},"multiline":{},"show":{},"password":{},"default":{},"headers":{},"dict_code_file":{},"model_fields":{"MODEL_DICT":{"OpenAI":["text-davinci-003","text-davinci-002","text-curie-001","text-babbage-001","text-ada-001"],"ChatOpenAI":["gpt-3.5-turbo-0613","gpt-3.5-turbo","gpt-3.5-turbo-16k-0613","gpt-3.5-turbo-16k","gpt-4-0613","gpt-4-32k-0613","gpt-4","gpt-4-32k"],"Anthropic":["claude-v1","claude-v1-100k","claude-instant-v1","claude-instant-v1-100k","claude-v1.3","claude-v1.3-100k","claude-v1.2","claude-v1.0","claude-instant-v1.1","claude-instant-v1.1-100k","claude-instant-v1.0"],"ChatAnthropic":["claude-v1","claude-v1-100k","claude-instant-v1","claude-instant-v1-100k","claude-v1.3","claude-v1.3-100k","claude-v1.2","claude-v1.0","claude-instant-v1.1","claude-instant-v1.1-100k","claude-instant-v1.0"]}}}},"beta":false,"error":null},"id":"PromptTemplate-CweKz"},"selected":false,"positionAbsolute":{"x":969.6448076246203,"y":528.7788853763968}},{"width":384,"height":307,"id":"LLMChain-HUM6g","type":"genericNode","position":{"x":1515.3241458756393,"y":732.4536491407735},"data":{"type":"LLMChain","node":{"template":{"callbacks":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"callbacks","advanced":false,"dynamic":false,"info":"","type":"langchain.callbacks.base.BaseCallbackHandler","list":true},"llm":{"required":true,"placeholder":"","show":true,"multiline":false,"password":false,"name":"llm","advanced":false,"dynamic":false,"info":"","type":"BaseLanguageModel","list":false},"memory":{"required":false,"placeholder":"","s
how":true,"multiline":false,"password":false,"name":"memory","advanced":false,"dynamic":false,"info":"","type":"BaseMemory","list":false},"output_parser":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"output_parser","advanced":false,"dynamic":false,"info":"","type":"BaseLLMOutputParser","list":false},"prompt":{"required":true,"placeholder":"","show":true,"multiline":false,"password":false,"name":"prompt","advanced":false,"dynamic":false,"info":"","type":"BasePromptTemplate","list":false},"llm_kwargs":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"llm_kwargs","advanced":false,"dynamic":false,"info":"","type":"code","list":false},"metadata":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"metadata","advanced":false,"dynamic":false,"info":"","type":"code","list":false},"output_key":{"required":true,"placeholder":"","show":true,"multiline":false,"value":"text","password":false,"name":"output_key","advanced":true,"dynamic":false,"info":"","type":"str","list":false},"return_final_only":{"required":false,"placeholder":"","show":false,"multiline":false,"value":true,"password":false,"name":"return_final_only","advanced":false,"dynamic":false,"info":"","type":"bool","list":false},"tags":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"tags","advanced":false,"dynamic":false,"info":"","type":"str","list":true},"verbose":{"required":false,"placeholder":"","show":false,"multiline":false,"value":false,"password":false,"name":"verbose","advanced":true,"dynamic":false,"info":"","type":"bool","list":false},"_type":"LLMChain"},"description":"Chain to run queries against 
LLMs.","base_classes":["LLMChain","Chain","function","Text"],"display_name":"LLMChain","custom_fields":{},"output_types":[],"documentation":"https://python.langchain.com/docs/modules/chains/foundational/llm_chain","beta":false,"error":null},"id":"LLMChain-HUM6g"},"selected":false,"positionAbsolute":{"x":1515.3241458756393,"y":732.4536491407735},"dragging":false},{"width":384,"height":621,"id":"ChatOpenAI-02kOF","type":"genericNode","position":{"x":483,"y":942.8665628296949},"data":{"type":"ChatOpenAI","node":{"template":{"callbacks":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"callbacks","advanced":false,"dynamic":false,"info":"","type":"langchain.callbacks.base.BaseCallbackHandler","list":true},"cache":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"cache","advanced":false,"dynamic":false,"info":"","type":"bool","list":false},"client":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"client","advanced":false,"dynamic":false,"info":"","type":"Any","list":false},"max_retries":{"required":false,"placeholder":"","show":false,"multiline":false,"value":6,"password":false,"name":"max_retries","advanced":false,"dynamic":false,"info":"","type":"int","list":false},"max_tokens":{"required":false,"placeholder":"","show":true,"multiline":false,"password":true,"name":"max_tokens","advanced":false,"dynamic":false,"info":"","type":"int","list":false,"value":""},"metadata":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"metadata","advanced":false,"dynamic":false,"info":"","type":"code","list":false},"model_kwargs":{"required":false,"placeholder":"","show":true,"multiline":false,"password":false,"name":"model_kwargs","advanced":true,"dynamic":false,"info":"","type":"code","list":false},"model_name":{"required":false,"placeholder":"","show":true,"multiline":false,"value":"gpt-3.5-turbo-0613","password":false,"opti
ons":["gpt-3.5-turbo-0613","gpt-3.5-turbo","gpt-3.5-turbo-16k-0613","gpt-3.5-turbo-16k","gpt-4-0613","gpt-4-32k-0613","gpt-4","gpt-4-32k"],"name":"model_name","advanced":false,"dynamic":false,"info":"","type":"str","list":true},"n":{"required":false,"placeholder":"","show":false,"multiline":false,"value":1,"password":false,"name":"n","advanced":false,"dynamic":false,"info":"","type":"int","list":false},"openai_api_base":{"required":false,"placeholder":"","show":true,"multiline":false,"password":false,"name":"openai_api_base","display_name":"OpenAI API Base","advanced":false,"dynamic":false,"info":"\nThe base URL of the OpenAI API. Defaults to https://api.openai.com/v1.\n\nYou can change this to use other APIs like JinaChat, LocalAI and Prem.\n","type":"str","list":false},"openai_api_key":{"required":false,"placeholder":"","show":true,"multiline":false,"value":"","password":true,"name":"openai_api_key","display_name":"OpenAI API Key","advanced":false,"dynamic":false,"info":"","type":"str","list":false},"openai_organization":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"openai_organization","display_name":"OpenAI Organization","advanced":false,"dynamic":false,"info":"","type":"str","list":false},"openai_proxy":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"openai_proxy","display_name":"OpenAI 
Proxy","advanced":false,"dynamic":false,"info":"","type":"str","list":false},"request_timeout":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"request_timeout","advanced":false,"dynamic":false,"info":"","type":"float","list":false},"streaming":{"required":false,"placeholder":"","show":false,"multiline":false,"value":false,"password":false,"name":"streaming","advanced":false,"dynamic":false,"info":"","type":"bool","list":false},"tags":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"tags","advanced":false,"dynamic":false,"info":"","type":"str","list":true},"temperature":{"required":false,"placeholder":"","show":true,"multiline":false,"value":0.7,"password":false,"name":"temperature","advanced":false,"dynamic":false,"info":"","type":"float","list":false},"tiktoken_model_name":{"required":false,"placeholder":"","show":false,"multiline":false,"password":false,"name":"tiktoken_model_name","advanced":false,"dynamic":false,"info":"","type":"str","list":false},"verbose":{"required":false,"placeholder":"","show":false,"multiline":false,"value":false,"password":false,"name":"verbose","advanced":false,"dynamic":false,"info":"","type":"bool","list":false},"_type":"ChatOpenAI"},"description":"`OpenAI` Chat large language models API.","base_classes":["ChatOpenAI","BaseLanguageModel","BaseChatModel","BaseLLM"],"display_name":"ChatOpenAI","custom_fields":{},"output_types":[],"documentation":"https://python.langchain.com/docs/modules/model_io/models/chat/integrations/openai","beta":false,"error":null},"id":"ChatOpenAI-02kOF"},"selected":false,"positionAbsolute":{"x":483,"y":942.8665628296949}},{"width":384,"height":389,"id":"ChatOutput-8SWFf","type":"genericNode","position":{"x":2035.5749798606498,"y":651.0174452514373},"data":{"type":"ChatOutput","node":{"template":{"code":{"dynamic":true,"required":true,"placeholder":"","show":true,"multiline":true,"value":"from typing import Optional\nfrom 
langflow.api.v1.schemas import ChatMessage\nfrom langflow.services.utils import get_chat_manager\nfrom langflow import CustomComponent\nfrom anyio.from_thread import start_blocking_portal\nfrom loguru import logger\nfrom langflow.field_typing import Text\n\n\nclass ChatOutput(CustomComponent):\n display_name = \"Chat Output\"\n\n def build_config(self):\n return {\"message\": {\"input_types\": [\"str\"]}}\n\n def build(self, message: Optional[Text], is_ai: bool = False) -> Text:\n if not message:\n return \"\"\n try:\n chat_manager = get_chat_manager()\n chat_message = ChatMessage(message=message, is_bot=is_ai)\n # send_message is a coroutine\n # run in a thread safe manner\n with start_blocking_portal() as portal:\n portal.call(chat_manager.send_message, chat_message)\n chat_manager.chat_history.add_message(\n chat_manager.cache_manager.current_client_id, chat_message\n )\n except Exception as exc:\n logger.exception(exc)\n logger.debug(f\"Error sending message to chat: {exc}\")\n\n return message\n","password":false,"name":"code","advanced":false,"type":"code","list":false},"_type":"CustomComponent","is_ai":{"required":true,"placeholder":"","show":true,"multiline":false,"value":false,"password":false,"name":"is_ai","display_name":"is_ai","advanced":false,"dynamic":false,"info":"","type":"bool","list":false},"message":{"required":false,"placeholder":"","show":true,"multiline":false,"password":false,"name":"message","display_name":"message","advanced":false,"input_types":["Text"],"dynamic":false,"info":"","type":"Text","list":false}},"description":"Used to send a message to the chat.","base_classes":["str"],"display_name":"Chat 
Output","custom_fields":{"is_ai":null,"message":null},"output_types":["ChatOutput"],"documentation":"","beta":true,"error":null},"id":"ChatOutput-8SWFf"},"selected":false,"positionAbsolute":{"x":2035.5749798606498,"y":651.0174452514373}},{"width":384,"height":273,"id":"ChatInput-PqtHe","type":"genericNode","position":{"x":504.7467002897712,"y":388.46875},"data":{"type":"ChatInput","node":{"template":{"code":{"dynamic":true,"required":true,"placeholder":"","show":false,"multiline":true,"value":"from typing import Optional\nfrom langflow import CustomComponent\n\n\nclass ChatInput(CustomComponent):\n display_name = \"Chat Input\"\n\n def build(self, message: Optional[str] = \"\") -> str:\n return message\n","password":false,"name":"code","advanced":false,"type":"code","list":false},"_type":"CustomComponent","message":{"required":false,"placeholder":"","show":true,"multiline":false,"value":"","password":false,"name":"message","display_name":"message","advanced":false,"dynamic":false,"info":"","type":"str","list":false}},"description":"Used to get user input from the chat.","base_classes":["str"],"display_name":"Chat Input","custom_fields":{"message":null},"output_types":["ChatInput"],"documentation":"","beta":true,"error":null},"id":"ChatInput-PqtHe"},"selected":false,"positionAbsolute":{"x":504.7467002897712,"y":388.46875}},{"width":384,"height":475,"id":"Tool-jyI4N","type":"genericNode","position":{"x":2044.485030617051,"y":1131.4250055845532},"data":{"type":"Tool","node":{"template":{"func":{"required":true,"placeholder":"","show":true,"multiline":true,"password":false,"name":"func","advanced":false,"dynamic":false,"info":"","type":"function","list":false},"description":{"required":true,"placeholder":"","show":true,"multiline":true,"value":"Test 
tool","password":false,"name":"description","advanced":false,"dynamic":false,"info":"","type":"str","list":false},"name":{"required":true,"placeholder":"","show":true,"multiline":true,"value":"Tool","password":false,"name":"name","advanced":false,"dynamic":false,"info":"","type":"str","list":false},"return_direct":{"required":true,"placeholder":"","show":true,"multiline":false,"value":false,"password":false,"name":"return_direct","advanced":false,"dynamic":false,"info":"","type":"bool","list":false},"_type":"Tool"},"description":"Converts a chain, agent or function into a tool.","base_classes":["Tool","BaseTool"],"display_name":"Tool","custom_fields":{},"output_types":[],"documentation":"","beta":false,"error":null},"id":"Tool-jyI4N"},"selected":true,"positionAbsolute":{"x":2044.485030617051,"y":1131.4250055845532},"dragging":false}],"edges":[{"source":"PromptTemplate-CweKz","target":"LLMChain-HUM6g","sourceHandle":"PromptTemplate|PromptTemplate-CweKz|BasePromptTemplate|StringPromptTemplate|PromptTemplate","targetHandle":"BasePromptTemplate|prompt|LLMChain-HUM6g","id":"reactflow__edge-PromptTemplate-CweKzPromptTemplate|PromptTemplate-CweKz|BasePromptTemplate|StringPromptTemplate|PromptTemplate-LLMChain-HUM6gBasePromptTemplate|prompt|LLMChain-HUM6g","style":{"stroke":"#555"},"className":"stroke-gray-900 ","animated":false,"selected":false},{"source":"ChatOpenAI-02kOF","target":"LLMChain-HUM6g","sourceHandle":"ChatOpenAI|ChatOpenAI-02kOF|ChatOpenAI|BaseLanguageModel|BaseChatModel|BaseLLM","targetHandle":"BaseLanguageModel|llm|LLMChain-HUM6g","id":"reactflow__edge-ChatOpenAI-02kOFChatOpenAI|ChatOpenAI-02kOF|ChatOpenAI|BaseLanguageModel|BaseChatModel|BaseLLM-LLMChain-HUM6gBaseLanguageModel|llm|LLMChain-HUM6g","style":{"stroke":"#555"},"className":"stroke-gray-900 
","animated":false,"selected":false},{"source":"ChatInput-PqtHe","target":"PromptTemplate-CweKz","sourceHandle":"ChatInput|ChatInput-PqtHe|str","targetHandle":"Document;BaseOutputParser;str|input|PromptTemplate-CweKz","id":"reactflow__edge-ChatInput-PqtHeChatInput|ChatInput-PqtHe|str-PromptTemplate-CweKzDocument;BaseOutputParser;str|input|PromptTemplate-CweKz","style":{"stroke":"#555"},"className":"stroke-gray-900 ","animated":false,"selected":false},{"source":"LLMChain-HUM6g","sourceHandle":"LLMChain|LLMChain-HUM6g|LLMChain|Chain|function|Text","target":"ChatOutput-8SWFf","targetHandle":"Text|message|ChatOutput-8SWFf","style":{"stroke":"#555"},"className":"stroke-foreground stroke-connection","animated":true,"id":"reactflow__edge-LLMChain-HUM6gLLMChain|LLMChain-HUM6g|LLMChain|Chain|function|Text-ChatOutput-8SWFfText|message|ChatOutput-8SWFf"},{"source":"LLMChain-HUM6g","sourceHandle":"LLMChain|LLMChain-HUM6g|LLMChain|Chain|function|Text","target":"Tool-jyI4N","targetHandle":"function|func|Tool-jyI4N","style":{"stroke":"#555"},"className":"stroke-foreground stroke-connection","animated":false,"id":"reactflow__edge-LLMChain-HUM6gLLMChain|LLMChain-HUM6g|LLMChain|Chain|function|Text-Tool-jyI4Nfunction|func|Tool-jyI4N"}],"viewport":{"x":-401.32668426335044,"y":-129.59138346130635,"zoom":0.5073779796520557}},"id":"cf923ccb-e14c-4754-96eb-a8a3b5bbe082","user_id":"c65bfea3-3eea-4e71-8fc4-106238eb0583"} \ No newline at end of file diff --git a/tests/test_agents_template.py b/tests/test_agents_template.py index 01891ec05..e354d4a16 100644 --- a/tests/test_agents_template.py +++ b/tests/test_agents_template.py @@ -12,7 +12,7 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers): "ZeroShotAgent", "BaseSingleActionAgent", "Agent", - "Callable", + "function", } template = zero_shot_agent["template"] @@ -28,7 +28,6 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers): "list": True, "advanced": False, "info": "", - "fileTypes": [], } # Additional 
assertions for other template variables @@ -44,7 +43,6 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "fileTypes": [], } assert template["llm"] == { "required": True, @@ -58,7 +56,6 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "fileTypes": [], } assert template["output_parser"] == { "required": False, @@ -72,7 +69,6 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "fileTypes": [], } assert template["input_variables"] == { "required": False, @@ -86,7 +82,6 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers): "list": True, "advanced": False, "info": "", - "fileTypes": [], } assert template["prefix"] == { "required": False, @@ -101,7 +96,6 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "fileTypes": [], } assert template["suffix"] == { "required": False, @@ -116,7 +110,6 @@ def test_zero_shot_agent(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "fileTypes": [], } @@ -142,9 +135,6 @@ def test_json_agent(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "file_path": "", - "fileTypes": [], - "value": "", } assert template["llm"] == { "required": True, @@ -159,9 +149,6 @@ def test_json_agent(client: TestClient, logged_in_headers): "advanced": False, "display_name": "LLM", "info": "", - "file_path": "", - "fileTypes": [], - "value": "", } @@ -182,12 +169,87 @@ def test_csv_agent(client: TestClient, logged_in_headers): "show": True, "multiline": False, "value": "", - "fileTypes": [".csv"], + "suffixes": [".csv"], + "fileTypes": ["csv"], "password": False, "name": "path", "type": "file", "list": False, - "file_path": "", + "file_path": None, + "advanced": False, + "info": "", + } + assert template["llm"] == { + 
"required": True, + "dynamic": False, + "placeholder": "", + "show": True, + "multiline": False, + "password": False, + "name": "llm", + "type": "BaseLanguageModel", + "list": False, + "advanced": False, + "display_name": "LLM", + "info": "", + } + + +def test_initialize_agent(client: TestClient, logged_in_headers): + response = client.get("api/v1/all", headers=logged_in_headers) + assert response.status_code == 200 + json_response = response.json() + agents = json_response["agents"] + + initialize_agent = agents["AgentInitializer"] + assert initialize_agent["base_classes"] == ["AgentExecutor", "function"] + template = initialize_agent["template"] + + assert template["agent"] == { + "required": True, + "dynamic": False, + "placeholder": "", + "show": True, + "multiline": False, + "value": "zero-shot-react-description", + "password": False, + "options": [ + "zero-shot-react-description", + "react-docstore", + "self-ask-with-search", + "conversational-react-description", + "openai-functions", + "openai-multi-functions", + ], + "name": "agent", + "type": "str", + "list": True, + "advanced": False, + "info": "", + } + assert template["memory"] == { + "required": False, + "dynamic": False, + "placeholder": "", + "show": True, + "multiline": False, + "password": False, + "name": "memory", + "type": "BaseChatMemory", + "list": False, + "advanced": False, + "info": "", + } + assert template["tools"] == { + "required": True, + "dynamic": False, + "placeholder": "", + "show": True, + "multiline": False, + "password": False, + "name": "tools", + "type": "Tool", + "list": True, "advanced": False, "info": "", } @@ -204,7 +266,4 @@ def test_csv_agent(client: TestClient, logged_in_headers): "advanced": False, "display_name": "LLM", "info": "", - "file_path": "", - "fileTypes": [], - "value": "", } diff --git a/tests/test_api_key.py b/tests/test_api_key.py index 7988793d4..43b91fa43 100644 --- a/tests/test_api_key.py +++ b/tests/test_api_key.py @@ -6,7 +6,9 @@ from 
langflow.services.database.models.api_key import ApiKeyCreate def api_key(client, logged_in_headers, active_user): api_key = ApiKeyCreate(name="test-api-key") - response = client.post("api/v1/api_key", data=api_key.json(), headers=logged_in_headers) + response = client.post( + "api/v1/api_key", data=api_key.json(), headers=logged_in_headers + ) assert response.status_code == 200, response.text return response.json() @@ -26,7 +28,9 @@ def test_get_api_keys(client, logged_in_headers, api_key): def test_create_api_key(client, logged_in_headers): api_key_name = "test-api-key" - response = client.post("api/v1/api_key", json={"name": api_key_name}, headers=logged_in_headers) + response = client.post( + "api/v1/api_key", json={"name": api_key_name}, headers=logged_in_headers + ) assert response.status_code == 200 data = response.json() assert "name" in data and data["name"] == api_key_name diff --git a/tests/test_cache.py b/tests/test_cache.py index 8cfebe230..736884673 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,9 +1,8 @@ import json +from langflow.graph import Graph import pytest -from langflow.graph import Graph - def get_graph(_type="basic"): """Get a graph from a json file""" diff --git a/tests/test_chains_template.py b/tests/test_chains_template.py index 2e705ac00..4a8038d14 100644 --- a/tests/test_chains_template.py +++ b/tests/test_chains_template.py @@ -1,5 +1,6 @@ from fastapi.testclient import TestClient + # def test_chains_settings(client: TestClient, logged_in_headers): # response = client.get("api/v1/all", headers=logged_in_headers) # assert response.status_code == 200 @@ -8,53 +9,21 @@ from fastapi.testclient import TestClient # assert set(chains.keys()) == set(settings.chains) -def test_llm_checker_chain(client: TestClient, logged_in_headers): - response = client.get("api/v1/all", headers=logged_in_headers) - assert response.status_code == 200 - json_response = response.json() - chains = json_response["chains"] - chain = 
chains["LLMCheckerChain"] - - # Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects - assert set(chain["base_classes"]) == { - "Callable", - "LLMCheckerChain", - "Chain", - } - - template = chain["template"] - assert template["llm"] == { - "required": True, - "dynamic": False, - "placeholder": "", - "show": True, - "multiline": False, - "password": False, - "name": "llm", - "type": "BaseLanguageModel", - "list": False, - "advanced": False, - "info": "", - "fileTypes": [], - } - assert template["_type"] == "LLMCheckerChain" - - # Test the description object - assert chain["description"] == "" - - -def test_llm_math_chain(client: TestClient, logged_in_headers): +# Test the ConversationChain object +def test_conversation_chain(client: TestClient, logged_in_headers): response = client.get("api/v1/all", headers=logged_in_headers) assert response.status_code == 200 json_response = response.json() chains = json_response["chains"] - chain = chains["LLMMathChain"] + chain = chains["ConversationChain"] # Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects assert set(chain["base_classes"]) == { - "Callable", - "LLMMathChain", + "ConversationChain", + "LLMChain", "Chain", + "function", + "Text", } template = chain["template"] @@ -70,7 +39,195 @@ def test_llm_math_chain(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "fileTypes": [], + } + assert template["verbose"] == { + "required": False, + "dynamic": False, + "placeholder": "", + "show": False, + "multiline": False, + "password": False, + "name": "verbose", + "type": "bool", + "list": False, + "advanced": True, + "info": "", + } + assert template["llm"] == { + "required": True, + "dynamic": False, + "placeholder": "", + "show": True, + "multiline": False, + "password": False, + "name": "llm", + "type": "BaseLanguageModel", + "list": False, + "advanced": False, + "info": "", + } + assert 
template["input_key"] == { + "required": True, + "dynamic": False, + "placeholder": "", + "show": True, + "multiline": False, + "value": "input", + "password": False, + "name": "input_key", + "type": "str", + "list": False, + "advanced": True, + "info": "", + } + assert template["output_key"] == { + "required": True, + "dynamic": False, + "placeholder": "", + "show": True, + "multiline": False, + "value": "response", + "password": False, + "name": "output_key", + "type": "str", + "list": False, + "advanced": True, + "info": "", + } + assert template["_type"] == "ConversationChain" + + # Test the description object + assert ( + chain["description"] + == "Chain to have a conversation and load context from memory." + ) + + +def test_llm_chain(client: TestClient, logged_in_headers): + response = client.get("api/v1/all", headers=logged_in_headers) + assert response.status_code == 200 + json_response = response.json() + chains = json_response["chains"] + chain = chains["LLMChain"] + + # Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects + assert set(chain["base_classes"]) == {"function", "LLMChain", "Chain", "Text"} + + template = chain["template"] + assert template["memory"] == { + "required": False, + "dynamic": False, + "placeholder": "", + "show": True, + "multiline": False, + "password": False, + "name": "memory", + "type": "BaseMemory", + "list": False, + "advanced": False, + "info": "", + } + assert template["verbose"] == { + "required": False, + "dynamic": False, + "placeholder": "", + "show": False, + "multiline": False, + "value": False, + "password": False, + "name": "verbose", + "type": "bool", + "list": False, + "advanced": True, + "info": "", + } + assert template["llm"] == { + "required": True, + "dynamic": False, + "placeholder": "", + "show": True, + "multiline": False, + "password": False, + "name": "llm", + "type": "BaseLanguageModel", + "list": False, + "advanced": False, + "info": "", + } + assert 
template["output_key"] == { + "required": True, + "dynamic": False, + "placeholder": "", + "show": True, + "multiline": False, + "value": "text", + "password": False, + "name": "output_key", + "type": "str", + "list": False, + "advanced": True, + "info": "", + } + + +def test_llm_checker_chain(client: TestClient, logged_in_headers): + response = client.get("api/v1/all", headers=logged_in_headers) + assert response.status_code == 200 + json_response = response.json() + chains = json_response["chains"] + chain = chains["LLMCheckerChain"] + + # Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects + assert set(chain["base_classes"]) == { + "function", + "LLMCheckerChain", + "Chain", + "Text", + } + + template = chain["template"] + assert template["llm"] == { + "required": True, + "dynamic": False, + "placeholder": "", + "show": True, + "multiline": False, + "password": False, + "name": "llm", + "type": "BaseLanguageModel", + "list": False, + "advanced": False, + "info": "", + } + assert template["_type"] == "LLMCheckerChain" + + # Test the description object + assert chain["description"] == "" + + +def test_llm_math_chain(client: TestClient, logged_in_headers): + response = client.get("api/v1/all", headers=logged_in_headers) + assert response.status_code == 200 + json_response = response.json() + chains = json_response["chains"] + + chain = chains["LLMMathChain"] + # Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects + assert set(chain["base_classes"]) == {"function", "LLMMathChain", "Chain", "Text"} + + template = chain["template"] + assert template["memory"] == { + "required": False, + "dynamic": False, + "placeholder": "", + "show": True, + "multiline": False, + "password": False, + "name": "memory", + "type": "BaseMemory", + "list": False, + "advanced": False, + "info": "", } assert template["verbose"] == { "required": False, @@ -85,7 +242,6 @@ def test_llm_math_chain(client: 
TestClient, logged_in_headers): "list": False, "advanced": True, "info": "", - "fileTypes": [], } assert template["llm"] == { "required": True, @@ -99,7 +255,6 @@ def test_llm_math_chain(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "fileTypes": [], } assert template["input_key"] == { "required": True, @@ -114,7 +269,6 @@ def test_llm_math_chain(client: TestClient, logged_in_headers): "list": False, "advanced": True, "info": "", - "fileTypes": [], } assert template["output_key"] == { "required": True, @@ -129,12 +283,14 @@ def test_llm_math_chain(client: TestClient, logged_in_headers): "list": False, "advanced": True, "info": "", - "fileTypes": [], } assert template["_type"] == "LLMMathChain" # Test the description object - assert chain["description"] == "Chain that interprets a prompt and executes python code to do math." + assert ( + chain["description"] + == "Chain that interprets a prompt and executes python code to do math." + ) def test_series_character_chain(client: TestClient, logged_in_headers): @@ -147,7 +303,7 @@ def test_series_character_chain(client: TestClient, logged_in_headers): # Test the base classes, template, memory, verbose, llm, input_key, output_key, and _type objects assert set(chain["base_classes"]) == { - "Callable", + "function", "LLMChain", "BaseCustomChain", "Chain", @@ -169,9 +325,6 @@ def test_series_character_chain(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "fileTypes": [], - "file_path": "", - "value": "", } assert template["character"] == { "required": True, @@ -185,9 +338,6 @@ def test_series_character_chain(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "fileTypes": [], - "file_path": "", - "value": "", } assert template["series"] == { "required": True, @@ -201,9 +351,6 @@ def test_series_character_chain(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "fileTypes": 
[], - "file_path": "", - "value": "", } assert template["_type"] == "SeriesCharacterChain" @@ -247,12 +394,12 @@ def test_mid_journey_prompt_chain(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "file_path": "", - "fileTypes": [], - "value": "", } # Test the description object - assert chain["description"] == "MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts." + assert ( + chain["description"] + == "MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts." + ) def test_time_travel_guide_chain(client: TestClient, logged_in_headers): @@ -288,9 +435,6 @@ def test_time_travel_guide_chain(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "file_path": "", - "fileTypes": [], - "value": "", } assert template["memory"] == { "required": False, @@ -304,9 +448,6 @@ def test_time_travel_guide_chain(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "file_path": "", - "fileTypes": [], - "value": "", } assert chain["description"] == "Time travel guide chain." 
diff --git a/tests/test_cli.py b/tests/test_cli.py index ee95a271c..ee938db12 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,10 +1,9 @@ from pathlib import Path from tempfile import tempdir - +from langflow.__main__ import app import pytest -from langflow.__main__ import app -from langflow.services import deps +from langflow.services import getters @pytest.fixture(scope="module") @@ -27,12 +26,11 @@ def test_components_path(runner, client, default_settings): ["run", "--components-path", str(temp_dir), *default_settings], ) assert result.exit_code == 0, result.stdout - settings_service = deps.get_settings_service() + settings_service = getters.get_settings_service() assert str(temp_dir) in settings_service.settings.COMPONENTS_PATH def test_superuser(runner, client, session): result = runner.invoke(app, ["superuser"], input="admin\nadmin\n") assert result.exit_code == 0, result.stdout - assert "Superuser creation failed." not in result.output, result.output - assert "Superuser created successfully." in result.output, result.output + assert "Superuser created successfully." 
in result.stdout diff --git a/tests/test_custom_component.py b/tests/test_custom_component.py index 96a7d7acc..47c9cbfb2 100644 --- a/tests/test_custom_component.py +++ b/tests/test_custom_component.py @@ -1,13 +1,19 @@ import ast +import pytest import types from uuid import uuid4 -import pytest -from langflow.interface.custom.base import CustomComponent -from langflow.interface.custom.code_parser.code_parser import CodeParser, CodeSyntaxError -from langflow.interface.custom.custom_component.component import Component, ComponentCodeNullError -from langflow.interface.custom.utils import build_custom_component_template + +from fastapi import HTTPException from langflow.services.database.models.flow import Flow, FlowCreate +from langflow.interface.custom.base import CustomComponent +from langflow.interface.custom.component import ( + Component, + ComponentCodeNullError, + ComponentFunctionEntrypointNameNullError, +) +from langflow.interface.custom.code_parser import CodeParser, CodeSyntaxError + code_default = """ from langflow import Prompt @@ -47,7 +53,7 @@ def test_code_parser_get_tree(): Test the __get_tree method of the CodeParser class. """ parser = CodeParser(code_default) - tree = parser.get_tree() + tree = parser._CodeParser__get_tree() assert isinstance(tree, ast.AST) @@ -60,23 +66,23 @@ def test_code_parser_syntax_error(): parser = CodeParser(code_syntax_error) with pytest.raises(CodeSyntaxError): - parser.get_tree() + parser._CodeParser__get_tree() def test_component_init(): """ Test the initialization of the Component class. """ - component = Component(code=code_default, _function_entrypoint_name="build") + component = Component(code=code_default, function_entrypoint_name="build") assert component.code == code_default - assert component._function_entrypoint_name == "build" + assert component.function_entrypoint_name == "build" def test_component_get_code_tree(): """ Test the get_code_tree method of the Component class. 
""" - component = Component(code=code_default, _function_entrypoint_name="build") + component = Component(code=code_default, function_entrypoint_name="build") tree = component.get_code_tree(component.code) assert "imports" in tree @@ -86,20 +92,19 @@ def test_component_code_null_error(): Test the get_function method raises the ComponentCodeNullError when the code is empty. """ - component = Component(code="", _function_entrypoint_name="") + component = Component(code="", function_entrypoint_name="") with pytest.raises(ComponentCodeNullError): component.get_function() -# TODO: Validate if we should remove this -# def test_component_function_entrypoint_name_null_error(): -# """ -# Test the get_function method raises the ComponentFunctionEntrypointNameNullError -# when the function_entrypoint_name is empty. -# """ -# component = Component(code=code_default, _function_entrypoint_name="") -# with pytest.raises(ComponentFunctionEntrypointNameNullError): -# component.get_function() +def test_component_function_entrypoint_name_null_error(): + """ + Test the get_function method raises the ComponentFunctionEntrypointNameNullError + when the function_entrypoint_name is empty. + """ + component = Component(code=code_default, function_entrypoint_name="") + with pytest.raises(ComponentFunctionEntrypointNameNullError): + component.get_function() def test_custom_component_init(): @@ -108,7 +113,9 @@ def test_custom_component_init(): """ function_entrypoint_name = "build" - custom_component = CustomComponent(code=code_default, function_entrypoint_name=function_entrypoint_name) + custom_component = CustomComponent( + code=code_default, function_entrypoint_name=function_entrypoint_name + ) assert custom_component.code == code_default assert custom_component.function_entrypoint_name == function_entrypoint_name @@ -117,8 +124,10 @@ def test_custom_component_build_template_config(): """ Test the build_template_config property of the CustomComponent class. 
""" - custom_component = CustomComponent(code=code_default, function_entrypoint_name="build") - config = custom_component.template_config + custom_component = CustomComponent( + code=code_default, function_entrypoint_name="build" + ) + config = custom_component.build_template_config assert isinstance(config, dict) @@ -126,7 +135,9 @@ def test_custom_component_get_function(): """ Test the get_function property of the CustomComponent class. """ - custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build") + custom_component = CustomComponent( + code="def build(): pass", function_entrypoint_name="build" + ) my_function = custom_component.get_function assert isinstance(my_function, types.FunctionType) @@ -137,7 +148,7 @@ def test_code_parser_parse_imports_import(): class with an import statement. """ parser = CodeParser(code_default) - tree = parser.get_tree() + tree = parser._CodeParser__get_tree() for node in ast.walk(tree): if isinstance(node, ast.Import): parser.parse_imports(node) @@ -150,7 +161,7 @@ def test_code_parser_parse_imports_importfrom(): class with an import from statement. """ parser = CodeParser("from os import path") - tree = parser.get_tree() + tree = parser._CodeParser__get_tree() for node in ast.walk(tree): if isinstance(node, ast.ImportFrom): parser.parse_imports(node) @@ -162,7 +173,7 @@ def test_code_parser_parse_functions(): Test the parse_functions method of the CodeParser class. """ parser = CodeParser("def test(): pass") - tree = parser.get_tree() + tree = parser._CodeParser__get_tree() for node in ast.walk(tree): if isinstance(node, ast.FunctionDef): parser.parse_functions(node) @@ -175,7 +186,7 @@ def test_code_parser_parse_classes(): Test the parse_classes method of the CodeParser class. 
""" parser = CodeParser("class Test: pass") - tree = parser.get_tree() + tree = parser._CodeParser__get_tree() for node in ast.walk(tree): if isinstance(node, ast.ClassDef): parser.parse_classes(node) @@ -188,7 +199,7 @@ def test_code_parser_parse_global_vars(): Test the parse_global_vars method of the CodeParser class. """ parser = CodeParser("x = 1") - tree = parser.get_tree() + tree = parser._CodeParser__get_tree() for node in ast.walk(tree): if isinstance(node, ast.Assign): parser.parse_global_vars(node) @@ -201,7 +212,7 @@ def test_component_get_function_valid(): Test the get_function method of the Component class with valid code and function_entrypoint_name. """ - component = Component(code="def build(): pass", _function_entrypoint_name="build") + component = Component(code="def build(): pass", function_entrypoint_name="build") my_function = component.get_function() assert callable(my_function) @@ -211,7 +222,9 @@ def test_custom_component_get_function_entrypoint_args(): Test the get_function_entrypoint_args property of the CustomComponent class. """ - custom_component = CustomComponent(code=code_default, function_entrypoint_name="build") + custom_component = CustomComponent( + code=code_default, function_entrypoint_name="build" + ) args = custom_component.get_function_entrypoint_args assert len(args) == 4 assert args[0]["name"] == "self" @@ -224,18 +237,20 @@ def test_custom_component_get_function_entrypoint_return_type(): Test the get_function_entrypoint_return_type property of the CustomComponent class. 
""" - from langchain.schema import Document - - custom_component = CustomComponent(code=code_default, function_entrypoint_name="build") + custom_component = CustomComponent( + code=code_default, function_entrypoint_name="build" + ) return_type = custom_component.get_function_entrypoint_return_type - assert return_type == [Document] + assert return_type == ["Document"] def test_custom_component_get_main_class_name(): """ Test the get_main_class_name property of the CustomComponent class. """ - custom_component = CustomComponent(code=code_default, function_entrypoint_name="build") + custom_component = CustomComponent( + code=code_default, function_entrypoint_name="build" + ) class_name = custom_component.get_main_class_name assert class_name == "YourComponent" @@ -245,7 +260,9 @@ def test_custom_component_get_function_valid(): Test the get_function property of the CustomComponent class with valid code and function_entrypoint_name. """ - custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build") + custom_component = CustomComponent( + code="def build(): pass", function_entrypoint_name="build" + ) my_function = custom_component.get_function assert callable(my_function) @@ -280,7 +297,9 @@ def test_code_parser_parse_callable_details_no_args(): parser = CodeParser("") node = ast.FunctionDef( name="test", - args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]), + args=ast.arguments( + args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[] + ), body=[], decorator_list=[], returns=None, @@ -309,7 +328,7 @@ def test_code_parser_parse_ann_assign(): stmt = ast.AnnAssign( target=ast.Name(id="x", ctx=ast.Store()), annotation=ast.Name(id="int", ctx=ast.Load()), - value=ast.Constant(n=1), + value=ast.Num(n=1), simple=1, ) result = parser.parse_ann_assign(stmt) @@ -326,7 +345,9 @@ def test_code_parser_parse_function_def_not_init(): parser = CodeParser("") stmt = ast.FunctionDef( 
name="test", - args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]), + args=ast.arguments( + args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[] + ), body=[], decorator_list=[], returns=None, @@ -344,7 +365,9 @@ def test_code_parser_parse_function_def_init(): parser = CodeParser("") stmt = ast.FunctionDef( name="__init__", - args=ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]), + args=ast.arguments( + args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[] + ), body=[], decorator_list=[], returns=None, @@ -359,17 +382,29 @@ def test_component_get_code_tree_syntax_error(): Test the get_code_tree method of the Component class raises the CodeSyntaxError when given incorrect syntax. """ - component = Component(code="import os as", _function_entrypoint_name="build") + component = Component(code="import os as", function_entrypoint_name="build") with pytest.raises(CodeSyntaxError): component.get_code_tree(component.code) +def test_custom_component_class_template_validation_no_code(): + """ + Test the _class_template_validation method of the CustomComponent class + raises the HTTPException when the code is None. + """ + custom_component = CustomComponent(code=None, function_entrypoint_name="build") + with pytest.raises(HTTPException): + custom_component._class_template_validation(custom_component.code) + + def test_custom_component_get_code_tree_syntax_error(): """ Test the get_code_tree method of the CustomComponent class raises the CodeSyntaxError when given incorrect syntax. 
""" - custom_component = CustomComponent(code="import os as", function_entrypoint_name="build") + custom_component = CustomComponent( + code="import os as", function_entrypoint_name="build" + ) with pytest.raises(CodeSyntaxError): custom_component.get_code_tree(custom_component.code) @@ -423,7 +458,9 @@ def test_custom_component_build_not_implemented(): Test the build method of the CustomComponent class raises the NotImplementedError. """ - custom_component = CustomComponent(code="def build(): pass", function_entrypoint_name="build") + custom_component = CustomComponent( + code="def build(): pass", function_entrypoint_name="build" + ) with pytest.raises(NotImplementedError): custom_component.build() @@ -431,7 +468,7 @@ def test_custom_component_build_not_implemented(): def test_build_config_no_code(): component = CustomComponent(code=None) - assert component.get_function_entrypoint_args == [] + assert component.get_function_entrypoint_args == "" assert component.get_function_entrypoint_return_type == [] @@ -457,7 +494,9 @@ def test_flow(db): } # Create flow - flow = FlowCreate(id=uuid4(), name="Test Flow", description="Fixture flow", data=flow_data) + flow = FlowCreate( + id=uuid4(), name="Test Flow", description="Fixture flow", data=flow_data + ) # Add to database db.add(flow) @@ -518,36 +557,3 @@ def test_build_config_field_value_keys(component): config = component.build_config() field_values = config["fields"].values() assert all("type" in value for value in field_values) - - -def test_create_and_validate_component_valid_code(test_component_code): - component = CustomComponent(code=test_component_code) - assert isinstance(component, CustomComponent) - - -def test_build_langchain_template_custom_component_valid_code(test_component_code): - component = CustomComponent(code=test_component_code) - frontend_node = build_custom_component_template(component) - assert isinstance(frontend_node, dict) - template = frontend_node["template"] - assert isinstance(template, 
dict) - assert "param" in template - param_options = template["param"]["options"] - # Now run it again with an update field - frontend_node = build_custom_component_template(component, update_field="param") - new_param_options = frontend_node["template"]["param"]["options"] - assert param_options != new_param_options - - -def test_build_langchain_template_custom_component_templatefield(test_component_with_templatefield_code): - component = CustomComponent(code=test_component_with_templatefield_code) - frontend_node = build_custom_component_template(component) - assert isinstance(frontend_node, dict) - template = frontend_node["template"] - assert isinstance(template, dict) - assert "param" in template - param_options = template["param"]["options"] - # Now run it again with an update field - frontend_node = build_custom_component_template(component, update_field="param") - new_param_options = frontend_node["template"]["param"]["options"] - assert param_options != new_param_options diff --git a/tests/test_custom_types.py b/tests/test_custom_types.py index ba54b7023..b65f58d0a 100644 --- a/tests/test_custom_types.py +++ b/tests/test_custom_types.py @@ -18,7 +18,9 @@ def test_python_function_tool(): with pytest.raises(SyntaxError): code = pytest.CODE_WITH_SYNTAX_ERROR func = get_function(code) - func = PythonFunctionTool(name="Test", description="Testing", code=code, func=func) + func = PythonFunctionTool( + name="Test", description="Testing", code=code, func=func + ) def test_python_function(): diff --git a/tests/test_database.py b/tests/test_database.py index f52252856..21f0cec17 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -1,15 +1,17 @@ -from uuid import UUID, uuid4 - +from langflow.services.database.models.base import orjson_dumps +from langflow.services.database.utils import session_getter +from langflow.services.getters import get_db_service import orjson import pytest -from fastapi.testclient import TestClient -from 
langflow.api.v1.schemas import FlowListCreate -from langflow.services.database.models.base import orjson_dumps -from langflow.services.database.models.flow import Flow, FlowCreate, FlowUpdate -from langflow.services.database.utils import session_getter -from langflow.services.deps import get_db_service + +from uuid import UUID, uuid4 from sqlmodel import Session +from fastapi.testclient import TestClient + +from langflow.api.v1.schemas import FlowListCreate +from langflow.services.database.models.flow import Flow, FlowCreate, FlowUpdate + @pytest.fixture(scope="module") def json_style(): @@ -25,17 +27,21 @@ def json_style(): ) -def test_create_flow(client: TestClient, json_flow: str, active_user, logged_in_headers): +def test_create_flow( + client: TestClient, json_flow: str, active_user, logged_in_headers +): flow = orjson.loads(json_flow) data = flow["data"] flow = FlowCreate(name="Test Flow", description="description", data=data) - response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers) + response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers) assert response.status_code == 201 assert response.json()["name"] == flow.name assert response.json()["data"] == flow.data # flow is optional so we can create a flow without a flow - flow = FlowCreate(name="Test Flow", description="description") - response = client.post("api/v1/flows/", json=flow.model_dump(exclude_unset=True), headers=logged_in_headers) + flow = FlowCreate(name="Test Flow") + response = client.post( + "api/v1/flows/", json=flow.dict(exclude_unset=True), headers=logged_in_headers + ) assert response.status_code == 201 assert response.json()["name"] == flow.name assert response.json()["data"] == flow.data @@ -45,13 +51,13 @@ def test_read_flows(client: TestClient, json_flow: str, active_user, logged_in_h flow_data = orjson.loads(json_flow) data = flow_data["data"] flow = FlowCreate(name="Test Flow", description="description", data=data) - 
response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers) + response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers) assert response.status_code == 201 assert response.json()["name"] == flow.name assert response.json()["data"] == flow.data flow = FlowCreate(name="Test Flow", description="description", data=data) - response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers) + response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers) assert response.status_code == 201 assert response.json()["name"] == flow.name assert response.json()["data"] == flow.data @@ -65,7 +71,7 @@ def test_read_flow(client: TestClient, json_flow: str, active_user, logged_in_he flow = orjson.loads(json_flow) data = flow["data"] flow = FlowCreate(name="Test Flow", description="description", data=data) - response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers) + response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers) flow_id = response.json()["id"] # flow_id should be a UUID but is a string # turn it into a UUID flow_id = UUID(flow_id) @@ -76,12 +82,14 @@ def test_read_flow(client: TestClient, json_flow: str, active_user, logged_in_he assert response.json()["data"] == flow.data -def test_update_flow(client: TestClient, json_flow: str, active_user, logged_in_headers): +def test_update_flow( + client: TestClient, json_flow: str, active_user, logged_in_headers +): flow = orjson.loads(json_flow) data = flow["data"] flow = FlowCreate(name="Test Flow", description="description", data=data) - response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers) + response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers) flow_id = response.json()["id"] updated_flow = FlowUpdate( @@ -89,7 +97,9 @@ def test_update_flow(client: TestClient, json_flow: str, active_user, logged_in_ 
description="updated description", data=data, ) - response = client.patch(f"api/v1/flows/{flow_id}", json=updated_flow.model_dump(), headers=logged_in_headers) + response = client.patch( + f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers + ) assert response.status_code == 200 assert response.json()["name"] == updated_flow.name @@ -97,18 +107,22 @@ def test_update_flow(client: TestClient, json_flow: str, active_user, logged_in_ # assert response.json()["data"] == updated_flow.data -def test_delete_flow(client: TestClient, json_flow: str, active_user, logged_in_headers): +def test_delete_flow( + client: TestClient, json_flow: str, active_user, logged_in_headers +): flow = orjson.loads(json_flow) data = flow["data"] flow = FlowCreate(name="Test Flow", description="description", data=data) - response = client.post("api/v1/flows/", json=flow.model_dump(), headers=logged_in_headers) + response = client.post("api/v1/flows/", json=flow.dict(), headers=logged_in_headers) flow_id = response.json()["id"] response = client.delete(f"api/v1/flows/{flow_id}", headers=logged_in_headers) assert response.status_code == 200 assert response.json()["message"] == "Flow deleted successfully" -def test_create_flows(client: TestClient, session: Session, json_flow: str, logged_in_headers): +def test_create_flows( + client: TestClient, session: Session, json_flow: str, logged_in_headers +): flow = orjson.loads(json_flow) data = flow["data"] # Create test data @@ -119,7 +133,9 @@ def test_create_flows(client: TestClient, session: Session, json_flow: str, logg ] ) # Make request to endpoint - response = client.post("api/v1/flows/batch/", json=flow_list.model_dump(), headers=logged_in_headers) + response = client.post( + "api/v1/flows/batch/", json=flow_list.dict(), headers=logged_in_headers + ) # Check response status code assert response.status_code == 201 # Check response data @@ -133,7 +149,9 @@ def test_create_flows(client: TestClient, session: Session, 
json_flow: str, logg assert response_data[1]["data"] == data -def test_upload_file(client: TestClient, session: Session, json_flow: str, logged_in_headers): +def test_upload_file( + client: TestClient, session: Session, json_flow: str, logged_in_headers +): flow = orjson.loads(json_flow) data = flow["data"] # Create test data @@ -143,7 +161,7 @@ def test_upload_file(client: TestClient, session: Session, json_flow: str, logge FlowCreate(name="Flow 2", description="description", data=data), ] ) - file_contents = orjson_dumps(flow_list.model_dump()) + file_contents = orjson_dumps(flow_list.dict()) response = client.post( "api/v1/flows/upload/", files={"file": ("examples.json", file_contents, "application/json")}, @@ -182,7 +200,7 @@ def test_download_file( with session_getter(db_manager) as session: for flow in flow_list.flows: flow.user_id = active_user.id - db_flow = Flow.model_validate(flow, from_attributes=True) + db_flow = Flow.from_orm(flow) session.add(db_flow) session.commit() # Make request to endpoint @@ -200,7 +218,9 @@ def test_download_file( assert response_data[1]["data"] == data -def test_create_flow_with_invalid_data(client: TestClient, active_user, logged_in_headers): +def test_create_flow_with_invalid_data( + client: TestClient, active_user, logged_in_headers +): flow = {"name": "a" * 256, "data": "Invalid flow data"} response = client.post("api/v1/flows/", json=flow, headers=logged_in_headers) assert response.status_code == 422 @@ -212,19 +232,29 @@ def test_get_nonexistent_flow(client: TestClient, active_user, logged_in_headers assert response.status_code == 404 -def test_update_flow_idempotency(client: TestClient, json_flow: str, active_user, logged_in_headers): +def test_update_flow_idempotency( + client: TestClient, json_flow: str, active_user, logged_in_headers +): flow_data = orjson.loads(json_flow) data = flow_data["data"] flow_data = FlowCreate(name="Test Flow", description="description", data=data) - response = client.post("api/v1/flows/", 
json=flow_data.model_dump(), headers=logged_in_headers) + response = client.post( + "api/v1/flows/", json=flow_data.dict(), headers=logged_in_headers + ) flow_id = response.json()["id"] updated_flow = FlowCreate(name="Updated Flow", description="description", data=data) - response1 = client.put(f"api/v1/flows/{flow_id}", json=updated_flow.model_dump(), headers=logged_in_headers) - response2 = client.put(f"api/v1/flows/{flow_id}", json=updated_flow.model_dump(), headers=logged_in_headers) + response1 = client.put( + f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers + ) + response2 = client.put( + f"api/v1/flows/{flow_id}", json=updated_flow.dict(), headers=logged_in_headers + ) assert response1.json() == response2.json() -def test_update_nonexistent_flow(client: TestClient, json_flow: str, active_user, logged_in_headers): +def test_update_nonexistent_flow( + client: TestClient, json_flow: str, active_user, logged_in_headers +): flow_data = orjson.loads(json_flow) data = flow_data["data"] uuid = uuid4() @@ -233,7 +263,9 @@ def test_update_nonexistent_flow(client: TestClient, json_flow: str, active_user description="description", data=data, ) - response = client.patch(f"api/v1/flows/{uuid}", json=updated_flow.model_dump(), headers=logged_in_headers) + response = client.patch( + f"api/v1/flows/{uuid}", json=updated_flow.dict(), headers=logged_in_headers + ) assert response.status_code == 404 diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 338224004..a7c721b09 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -1,17 +1,18 @@ -import time -import uuid from collections import namedtuple - +import uuid +from langflow.processing.process import Result +from langflow.services.auth.utils import get_password_hash +from langflow.services.database.models.api_key.api_key import ApiKey +from langflow.services.getters import get_settings_service +from langflow.services.database.utils import session_getter +from 
langflow.services.getters import get_db_service import pytest from fastapi.testclient import TestClient from langflow.interface.tools.constants import CUSTOM_TOOLS -from langflow.processing.process import Result -from langflow.services.auth.utils import get_password_hash -from langflow.services.database.models.api_key.model import ApiKey -from langflow.services.database.utils import session_getter -from langflow.services.deps import get_db_service, get_settings_service from langflow.template.frontend_node.chains import TimeTravelGuideChainNode +import time + def run_post(client, flow_id, headers, post_data): response = client.post( @@ -24,13 +25,16 @@ def run_post(client, flow_id, headers, post_data): # Helper function to poll task status -def poll_task_status(client, headers, href, max_attempts=20, sleep_time=2): +def poll_task_status(client, headers, href, max_attempts=20, sleep_time=1): for _ in range(max_attempts): task_status_response = client.get( href, headers=headers, ) - if task_status_response.status_code == 200 and task_status_response.json()["status"] == "SUCCESS": + if ( + task_status_response.status_code == 200 + and task_status_response.json()["status"] == "SUCCESS" + ): return task_status_response.json() time.sleep(sleep_time) return None # Return None if task did not complete in time @@ -72,14 +76,11 @@ PROMPT_REQUEST = { "text-ada-001", ], "ChatOpenAI": [ - "gpt-3.5-turbo-0613", - "gpt-3.5-turbo", - "gpt-3.5-turbo-16k-0613", - "gpt-3.5-turbo-16k", - "gpt-4-0613", - "gpt-4-32k-0613", + "gpt-4-1106-preview", "gpt-4", "gpt-4-32k", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", ], "Anthropic": [ "claude-v1", @@ -126,7 +127,11 @@ def created_api_key(active_user): ) db_manager = get_db_service() with session_getter(db_manager) as session: - if existing_api_key := session.query(ApiKey).filter(ApiKey.api_key == api_key.api_key).first(): + if ( + existing_api_key := session.query(ApiKey) + .filter(ApiKey.api_key == api_key.api_key) + .first() + ): return 
existing_api_key session.add(api_key) session.commit() @@ -185,7 +190,9 @@ def test_process_flow_invalid_id(client, monkeypatch, created_api_key): } invalid_id = uuid.uuid4() - response = client.post(f"api/v1/process/{invalid_id}", headers=headers, json=post_data) + response = client.post( + f"api/v1/process/{invalid_id}", headers=headers, json=post_data + ) assert response.status_code == 404 assert f"Flow {invalid_id} not found" in response.json()["detail"] @@ -226,7 +233,9 @@ def test_process_flow_without_autologin(client, flow, monkeypatch, created_api_k monkeypatch.setattr(endpoints, "process_graph_cached", mock_process_graph_cached) monkeypatch.setattr(crud, "update_total_uses", mock_update_total_uses) - monkeypatch.setattr(endpoints, "process_graph_cached_task", mock_process_graph_cached_task) + monkeypatch.setattr( + endpoints, "process_graph_cached_task", mock_process_graph_cached_task + ) api_key = created_api_key.api_key headers = {"x-api-key": api_key} @@ -410,6 +419,105 @@ def test_various_prompts(client, prompt, expected_input_variables): assert response.json()["input_variables"] == expected_input_variables +def test_get_vertices_flow_not_found(client, logged_in_headers): + response = client.get( + "/api/v1/build/nonexistent_id/vertices", headers=logged_in_headers + ) + assert ( + response.status_code == 500 + ) # Or whatever status code you've set for invalid ID + + +def test_get_vertices(client, added_flow_with_prompt_and_history, logged_in_headers): + flow_id = added_flow_with_prompt_and_history["id"] + response = client.get( + f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers + ) + assert response.status_code == 200 + assert "ids" in response.json() + # The response should contain the list in this order + # ['ConversationBufferMemory-Lu2Nb', 'PromptTemplate-5Q0W8', 'ChatOpenAI-vy7fV', 'LLMChain-UjBh1'] + # The important part is before the - (ConversationBufferMemory, PromptTemplate, ChatOpenAI, LLMChain) + ids = [id.split("-")[0] for 
id in response.json()["ids"]] + assert ids == [ + "ConversationBufferMemory", + "PromptTemplate", + "ChatOpenAI", + "LLMChain", + ] + + +def test_build_vertex_invalid_flow_id(client, logged_in_headers): + response = client.post( + "/api/v1/build/nonexistent_id/vertices/vertex_id", headers=logged_in_headers + ) + assert response.status_code == 500 + + +def test_build_vertex_invalid_vertex_id( + client, added_flow_with_prompt_and_history, logged_in_headers +): + flow_id = added_flow_with_prompt_and_history["id"] + response = client.post( + f"/api/v1/build/{flow_id}/vertices/invalid_vertex_id", headers=logged_in_headers + ) + assert response.status_code == 500 + + +def test_build_all_vertices_in_sequence_with_chat_input( + client, added_flow_chat_input, logged_in_headers +): + flow_id = added_flow_chat_input["id"] + + # First, get all the vertices in the correct sequence + response = client.get( + f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers + ) + assert response.status_code == 200 + assert "ids" in response.json() + vertex_ids = response.json()["ids"] + + # Now, iterate through each vertex and build it + for vertex_id in vertex_ids: + response = client.post( + f"/api/v1/build/{flow_id}/vertices/{vertex_id}", headers=logged_in_headers + ) + json_response = response.json() + assert ( + response.status_code == 200 + ), f"Failed at vertex {vertex_id}: {json_response}" + assert "valid" in json_response + assert json_response["valid"], json_response["params"] + + +def test_build_all_vertices_in_sequence_with_two_outputs( + client, added_flow_two_outputs, logged_in_headers +): + """This tests the case where a node has two outputs, one of which is Text and the other (in this case) is + a LLMChain. 
We need to make sure the correct output is passed in both cases.""" + flow_id = added_flow_two_outputs["id"] + + # First, get all the vertices in the correct sequence + response = client.get( + f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers + ) + assert response.status_code == 200 + assert "ids" in response.json() + vertex_ids = response.json()["ids"] + + # Now, iterate through each vertex and build it + for vertex_id in vertex_ids: + response = client.post( + f"/api/v1/build/{flow_id}/vertices/{vertex_id}", headers=logged_in_headers + ) + json_response = response.json() + assert ( + response.status_code == 200 + ), f"Failed at vertex {vertex_id}: {json_response}" + assert "valid" in json_response + assert json_response["valid"], json_response["params"] + + def test_basic_chat_in_process(client, added_flow, created_api_key): # Run the /api/v1/process/{flow_id} endpoint headers = {"x-api-key": created_api_key.api_key} @@ -498,7 +606,9 @@ def test_basic_chat_with_two_session_ids_and_names(client, added_flow, created_a @pytest.mark.async_test -def test_vector_store_in_process(distributed_client, added_vector_store, created_api_key): +def test_vector_store_in_process( + distributed_client, added_vector_store, created_api_key +): # Run the /api/v1/process/{flow_id} endpoint headers = {"x-api-key": created_api_key.api_key} post_data = {"inputs": {"input": "What is Langflow?"}} @@ -549,7 +659,9 @@ def test_async_task_processing(distributed_client, added_flow, created_api_key): # Test function without loop @pytest.mark.async_test -def test_async_task_processing_vector_store(client, added_vector_store, created_api_key): +def test_async_task_processing_vector_store( + client, added_vector_store, created_api_key +): headers = {"x-api-key": created_api_key.api_key} post_data = {"inputs": {"input": "How do I upload examples?"}} @@ -578,4 +690,6 @@ def test_async_task_processing_vector_store(client, added_vector_store, created_ # Validate that the task completed 
successfully and the result is as expected assert "result" in task_status_json, task_status_json assert "output" in task_status_json["result"], task_status_json["result"] - assert "Langflow" in task_status_json["result"]["output"], task_status_json["result"] + assert "Langflow" in task_status_json["result"]["output"], task_status_json[ + "result" + ] diff --git a/tests/test_frontend_nodes.py b/tests/test_frontend_nodes.py index e92ad1fe4..00fe9fcb1 100644 --- a/tests/test_frontend_nodes.py +++ b/tests/test_frontend_nodes.py @@ -31,14 +31,17 @@ def test_template_field_defaults(sample_template_field: TemplateField): assert sample_template_field.is_list is False assert sample_template_field.show is True assert sample_template_field.multiline is False - assert sample_template_field.value == "" + assert sample_template_field.value is None + assert sample_template_field.suffixes == [] assert sample_template_field.file_types == [] - assert sample_template_field.file_path == "" + assert sample_template_field.file_path is None assert sample_template_field.password is False assert sample_template_field.name == "test_field" -def test_template_to_dict(sample_template: Template, sample_template_field: TemplateField): +def test_template_to_dict( + sample_template: Template, sample_template_field: TemplateField +): template_dict = sample_template.to_dict() assert template_dict["_type"] == "test_template" assert len(template_dict) == 2 # _type and test_field diff --git a/tests/test_graph.py b/tests/test_graph.py index cca6d4f49..f32bc21d7 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,31 +1,32 @@ import copy import json import os -import pickle from pathlib import Path +import pickle from typing import Type, Union - -import pytest +from langflow.graph.edge.base import Edge +from langflow.graph.vertex.base import Vertex from langchain.agents import AgentExecutor +import pytest from langchain.chains.base import Chain from langchain.llms.fake import FakeListLLM - 
from langflow.graph import Graph -from langflow.graph.edge.base import Edge +from langflow.graph.vertex.types import ( + FileToolVertex, + LLMVertex, + ToolkitVertex, +) +from langflow.processing.process import get_result_and_thought +from langflow.utils.payload import get_root_node from langflow.graph.graph.utils import ( find_last_node, - process_flow, set_new_target_handle, ungroup_node, + process_flow, update_source_handle, update_target_handle, update_template, ) -from langflow.graph.utils import UnbuiltObject -from langflow.graph.vertex.base import Vertex -from langflow.graph.vertex.types import FileToolVertex, LLMVertex, ToolkitVertex -from langflow.processing.process import get_result_and_thought -from langflow.utils.payload import get_root_vertex # Test cases for the graph module @@ -46,7 +47,13 @@ def sample_nodes(): return [ { "id": "node1", - "data": {"node": {"template": {"some_field": {"show": True, "advanced": False, "name": "Name1"}}}}, + "data": { + "node": { + "template": { + "some_field": {"show": True, "advanced": False, "name": "Name1"} + } + } + }, }, { "id": "node2", @@ -64,7 +71,11 @@ def sample_nodes(): }, { "id": "node3", - "data": {"node": {"template": {"unrelated_field": {"show": True, "advanced": True}}}}, + "data": { + "node": { + "template": {"unrelated_field": {"show": True, "advanced": True}} + } + }, }, ] @@ -82,8 +93,8 @@ def test_graph_structure(basic_graph): assert isinstance(node, Vertex) for edge in basic_graph.edges: assert isinstance(edge, Edge) - assert edge.source_id in basic_graph.vertex_map.keys() - assert edge.target_id in basic_graph.vertex_map.keys() + assert edge.source in basic_graph.vertices + assert edge.target in basic_graph.vertices def test_circular_dependencies(basic_graph): @@ -91,7 +102,7 @@ def test_circular_dependencies(basic_graph): def check_circular(node, visited): visited.add(node) - neighbors = basic_graph.get_vertices_with_target(node) + neighbors = basic_graph.get_nodes_with_target(node) for 
neighbor in neighbors: if neighbor in visited: return True @@ -124,13 +135,13 @@ def test_invalid_node_types(): Graph(graph_data["nodes"], graph_data["edges"]) -def test_get_vertices_with_target(basic_graph): +def test_get_nodes_with_target(basic_graph): """Test getting connected nodes""" assert isinstance(basic_graph, Graph) # Get root node - root = get_root_vertex(basic_graph) + root = get_root_node(basic_graph) assert root is not None - connected_nodes = basic_graph.get_vertices_with_target(root.id) + connected_nodes = basic_graph.get_nodes_with_target(root) assert connected_nodes is not None @@ -139,66 +150,23 @@ def test_get_node_neighbors_basic(basic_graph): assert isinstance(basic_graph, Graph) # Get root node - root = get_root_vertex(basic_graph) + root = get_root_node(basic_graph) assert root is not None - neighbors = basic_graph.get_vertex_neighbors(root) + neighbors = basic_graph.get_node_neighbors(root) assert neighbors is not None assert isinstance(neighbors, dict) # Root Node is an Agent, it requires an LLMChain and tools # We need to check if there is a Chain in the one of the neighbors' # data attribute in the type key - assert any("ConversationBufferMemory" in neighbor.data["type"] for neighbor, val in neighbors.items() if val) + assert any( + "ConversationBufferMemory" in neighbor.data["type"] + for neighbor, val in neighbors.items() + if val + ) - assert any("OpenAI" in neighbor.data["type"] for neighbor, val in neighbors.items() if val) - - -# def test_get_node_neighbors_complex(complex_graph): -# """Test getting node neighbors""" -# assert isinstance(complex_graph, Graph) -# # Get root node -# root = get_root_node(complex_graph) -# assert root is not None -# neighbors = complex_graph.get_nodes_with_target(root) -# assert neighbors is not None -# # Neighbors should be a list of nodes -# assert isinstance(neighbors, list) -# # Root Node is an Agent, it requires an LLMChain and tools -# # We need to check if there is a Chain in the one of the 
neighbors' -# assert any("Chain" in neighbor.data["type"] for neighbor in neighbors) -# # assert Tool is in the neighbors -# assert any("Tool" in neighbor.data["type"] for neighbor in neighbors) -# # Now on to the Chain's neighbors -# chain = next(neighbor for neighbor in neighbors if "Chain" in neighbor.data["type"]) -# chain_neighbors = complex_graph.get_nodes_with_target(chain) -# assert chain_neighbors is not None -# # Check if there is a LLM in the chain's neighbors -# assert any("OpenAI" in neighbor.data["type"] for neighbor in chain_neighbors) -# # Chain should have a Prompt as a neighbor -# assert any("Prompt" in neighbor.data["type"] for neighbor in chain_neighbors) -# # Now on to the Tool's neighbors -# tool = next(neighbor for neighbor in neighbors if "Tool" in neighbor.data["type"]) -# tool_neighbors = complex_graph.get_nodes_with_target(tool) -# assert tool_neighbors is not None -# # Check if there is an Agent in the tool's neighbors -# assert any("Agent" in neighbor.data["type"] for neighbor in tool_neighbors) -# # This Agent has a Tool that has a PythonFunction as func -# agent = next( -# neighbor for neighbor in tool_neighbors if "Agent" in neighbor.data["type"] -# ) -# agent_neighbors = complex_graph.get_nodes_with_target(agent) -# assert agent_neighbors is not None -# # Check if there is a Tool in the agent's neighbors -# assert any("Tool" in neighbor.data["type"] for neighbor in agent_neighbors) -# # This Tool has a PythonFunction as func -# tool = next( -# neighbor for neighbor in agent_neighbors if "Tool" in neighbor.data["type"] -# ) -# tool_neighbors = complex_graph.get_nodes_with_target(tool) -# assert tool_neighbors is not None -# # Check if there is a PythonFunction in the tool's neighbors -# assert any( -# "PythonFunctionTool" in neighbor.data["type"] for neighbor in tool_neighbors -# ) + assert any( + "OpenAI" in neighbor.data["type"] for neighbor, val in neighbors.items() if val + ) def test_get_node(basic_graph): @@ -212,7 +180,7 @@ 
def test_get_node(basic_graph): def test_build_nodes(basic_graph): """Test building nodes""" - assert len(basic_graph.vertices) == len(basic_graph._vertices) + assert len(basic_graph.vertices) == len(basic_graph._nodes) for node in basic_graph.vertices: assert isinstance(node, Vertex) @@ -222,21 +190,20 @@ def test_build_edges(basic_graph): assert len(basic_graph.edges) == len(basic_graph._edges) for edge in basic_graph.edges: assert isinstance(edge, Edge) - - assert isinstance(edge.source_id, str) - assert isinstance(edge.target_id, str) + assert isinstance(edge.source, Vertex) + assert isinstance(edge.target, Vertex) -def test_get_root_vertex(client, basic_graph, complex_graph): +def test_get_root_node(client, basic_graph, complex_graph): """Test getting root node""" assert isinstance(basic_graph, Graph) - root = get_root_vertex(basic_graph) + root = get_root_node(basic_graph) assert root is not None assert isinstance(root, Vertex) assert root.data["type"] == "TimeTravelGuideChain" # For complex example, the root node is a ZeroShotAgent too assert isinstance(complex_graph, Graph) - root = get_root_vertex(complex_graph) + root = get_root_node(complex_graph) assert root is not None assert isinstance(root, Vertex) assert root.data["type"] == "ZeroShotAgent" @@ -272,7 +239,7 @@ def test_build_params(basic_graph): # The matched_type attribute should be in the source_types attr assert all(edge.matched_type in edge.source_types for edge in basic_graph.edges) # Get the root node - root = get_root_vertex(basic_graph) + root = get_root_node(basic_graph) # Root node is a TimeTravelGuideChain # which requires an llm and memory assert root is not None @@ -281,32 +248,29 @@ def test_build_params(basic_graph): assert "memory" in root.params -@pytest.mark.asyncio -async def test_build(basic_graph): +def test_build(basic_graph): """Test Node's build method""" - await assert_agent_was_built(basic_graph) + assert_agent_was_built(basic_graph) -async def 
assert_agent_was_built(graph): +def assert_agent_was_built(graph): """Assert that the agent was built""" assert isinstance(graph, Graph) # Now we test the build method # Build the Agent - result = await graph.build() + result = graph.build() # The agent should be a AgentExecutor assert isinstance(result, Chain) -@pytest.mark.asyncio -async def test_llm_node_build(basic_graph): +def test_llm_node_build(basic_graph): llm_node = get_node_by_type(basic_graph, LLMVertex) assert llm_node is not None - built_object = await llm_node.build() - assert built_object is not UnbuiltObject() + built_object = llm_node.build() + assert built_object is not None -@pytest.mark.asyncio -async def test_toolkit_node_build(client, openapi_graph): +def test_toolkit_node_build(client, openapi_graph): # Write a file to the disk file_path = "api-with-examples.yaml" with open(file_path, "w") as f: @@ -314,31 +278,36 @@ async def test_toolkit_node_build(client, openapi_graph): toolkit_node = get_node_by_type(openapi_graph, ToolkitVertex) assert toolkit_node is not None - built_object = await toolkit_node.build() - assert built_object is not UnbuiltObject + built_object = toolkit_node.build() + assert built_object is not None # Remove the file os.remove(file_path) assert not Path(file_path).exists() -@pytest.mark.asyncio -async def test_file_tool_node_build(client, openapi_graph): +def test_file_tool_node_build(client, openapi_graph): file_path = "api-with-examples.yaml" with open(file_path, "w") as f: f.write("openapi: 3.0.0") assert Path(file_path).exists() file_tool_node = get_node_by_type(openapi_graph, FileToolVertex) - assert file_tool_node is not UnbuiltObject and file_tool_node is not None - built_object = await file_tool_node.build() - assert built_object is not UnbuiltObject + assert file_tool_node is not None + built_object = file_tool_node.build() + assert built_object is not None # Remove the file os.remove(file_path) assert not Path(file_path).exists() -@pytest.mark.asyncio -async 
def test_get_result_and_thought(basic_graph): +# def test_wrapper_node_build(openapi_graph): +# wrapper_node = get_node_by_type(openapi_graph, WrapperVertex) +# assert wrapper_node is not None +# built_object = wrapper_node.build() +# assert built_object is not None + + +def test_get_result_and_thought(basic_graph): """Test the get_result_and_thought method""" responses = [ "Final Answer: I am a response", @@ -350,7 +319,7 @@ async def test_get_result_and_thought(basic_graph): assert llm_node is not None llm_node._built_object = FakeListLLM(responses=responses) llm_node._built = True - langchain_object = await basic_graph.build() + langchain_object = basic_graph.build() # assert all nodes are built assert all(node._built for node in basic_graph.vertices) # now build again and check if FakeListLLM was used @@ -370,7 +339,9 @@ def test_find_last_node(grouped_chat_json_flow): def test_ungroup_node(grouped_chat_json_flow): grouped_chat_data = json.loads(grouped_chat_json_flow).get("data") - group_node = grouped_chat_data["nodes"][2] # Assuming the first node is a group node + group_node = grouped_chat_data["nodes"][ + 2 + ] # Assuming the first node is a group node base_flow = copy.deepcopy(grouped_chat_data) ungroup_node(group_node["data"], base_flow) # after ungroup_node is called, the base_flow and grouped_chat_data should be different @@ -422,9 +393,14 @@ def test_process_flow_one_group(one_grouped_chat_json_flow): assert "edges" in processed_flow # Now get the node that has ChatOpenAI in its id - chat_openai_node = next((node for node in processed_flow["nodes"] if "ChatOpenAI" in node["id"]), None) + chat_openai_node = next( + (node for node in processed_flow["nodes"] if "ChatOpenAI" in node["id"]), None + ) assert chat_openai_node is not None - assert chat_openai_node["data"]["node"]["template"]["openai_api_key"]["value"] == "test" + assert ( + chat_openai_node["data"]["node"]["template"]["openai_api_key"]["value"] + == "test" + ) def 
test_process_flow_vector_store_grouped(vector_store_grouped_json_flow): @@ -471,15 +447,19 @@ def test_update_template(sample_template, sample_nodes): node2_updated = next((n for n in nodes_copy if n["id"] == "node2"), None) node3_updated = next((n for n in nodes_copy if n["id"] == "node3"), None) - assert node1_updated is not None assert node1_updated["data"]["node"]["template"]["some_field"]["show"] is True assert node1_updated["data"]["node"]["template"]["some_field"]["advanced"] is False - assert node1_updated["data"]["node"]["template"]["some_field"]["display_name"] == "Name1" + assert ( + node1_updated["data"]["node"]["template"]["some_field"]["display_name"] + == "Name1" + ) - assert node2_updated is not None assert node2_updated["data"]["node"]["template"]["other_field"]["show"] is False assert node2_updated["data"]["node"]["template"]["other_field"]["advanced"] is True - assert node2_updated["data"]["node"]["template"]["other_field"]["display_name"] == "DisplayName2" + assert ( + node2_updated["data"]["node"]["template"]["other_field"]["display_name"] + == "DisplayName2" + ) # Ensure node3 remains unchanged assert node3_updated == sample_nodes[2] @@ -510,7 +490,9 @@ def test_set_new_target_handle(): "data": { "node": { "flow": True, - "template": {"field_1": {"proxy": {"field": "new_field", "id": "new_id"}}}, + "template": { + "field_1": {"proxy": {"field": "new_field", "id": "new_id"}} + }, } } } @@ -530,34 +512,34 @@ def test_update_source_handle(): "nodes": [{"id": "some_node"}, {"id": "last_node"}], "edges": [{"source": "some_node"}], } - updated_edge = update_source_handle(new_edge, flow_data["nodes"], flow_data["edges"]) + updated_edge = update_source_handle( + new_edge, flow_data["nodes"], flow_data["edges"] + ) assert updated_edge["source"] == "last_node" assert updated_edge["data"]["sourceHandle"]["id"] == "last_node" -@pytest.mark.asyncio -async def test_pickle_graph(json_vector_store): +def test_pickle_graph(json_vector_store): loaded_json = 
json.loads(json_vector_store) graph = Graph.from_payload(loaded_json) assert isinstance(graph, Graph) - first_result = await graph.build() + first_result = graph.build() assert isinstance(first_result, AgentExecutor) pickled = pickle.dumps(graph) - assert pickled is not UnbuiltObject + assert pickled is not None unpickled = pickle.loads(pickled) - assert unpickled is not UnbuiltObject - result = await unpickled.build() + assert unpickled is not None + result = unpickled.build() assert isinstance(result, AgentExecutor) -@pytest.mark.asyncio -async def test_pickle_each_vertex(json_vector_store): +def test_pickle_each_vertex(json_vector_store): loaded_json = json.loads(json_vector_store) graph = Graph.from_payload(loaded_json) assert isinstance(graph, Graph) - for vertex in graph.vertices: - await vertex.build() + for vertex in graph.nodes: + vertex.build() pickled = pickle.dumps(vertex) - assert pickled is not UnbuiltObject + assert pickled is not None unpickled = pickle.loads(pickled) - assert unpickled is not UnbuiltObject + assert unpickled is not None diff --git a/tests/test_llms_template.py b/tests/test_llms_template.py index 30a15c932..d63d1d016 100644 --- a/tests/test_llms_template.py +++ b/tests/test_llms_template.py @@ -22,7 +22,6 @@ def test_openai(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "fileTypes": [], } assert template["verbose"] == { "required": False, @@ -36,7 +35,6 @@ def test_openai(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "fileTypes": [], } assert template["client"] == { "required": False, @@ -50,7 +48,6 @@ def test_openai(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "fileTypes": [], } assert template["model_name"] == { "required": False, @@ -72,7 +69,6 @@ def test_openai(client: TestClient, logged_in_headers): "list": True, "advanced": False, "info": "", - "fileTypes": [], } # Add more assertions for other 
properties here assert template["temperature"] == { @@ -88,8 +84,6 @@ def test_openai(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1}, - "fileTypes": [], } assert template["max_tokens"] == { "required": False, @@ -104,7 +98,6 @@ def test_openai(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "fileTypes": [], } assert template["top_p"] == { "required": False, @@ -119,8 +112,6 @@ def test_openai(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1}, - "fileTypes": [], } assert template["frequency_penalty"] == { "required": False, @@ -135,8 +126,6 @@ def test_openai(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1}, - "fileTypes": [], } assert template["presence_penalty"] == { "required": False, @@ -151,8 +140,6 @@ def test_openai(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1}, - "fileTypes": [], } assert template["n"] == { "required": False, @@ -167,7 +154,6 @@ def test_openai(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "fileTypes": [], } assert template["best_of"] == { "required": False, @@ -182,7 +168,6 @@ def test_openai(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "fileTypes": [], } assert template["model_kwargs"] == { "required": False, @@ -196,7 +181,6 @@ def test_openai(client: TestClient, logged_in_headers): "list": False, "advanced": True, "info": "", - "fileTypes": [], } assert template["openai_api_key"] == { "required": False, @@ -212,7 +196,6 @@ def test_openai(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "fileTypes": [], } assert 
template["batch_size"] == { "required": False, @@ -227,7 +210,6 @@ def test_openai(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "fileTypes": [], } assert template["request_timeout"] == { "required": False, @@ -241,8 +223,6 @@ def test_openai(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1}, - "fileTypes": [], } assert template["logit_bias"] == { "required": False, @@ -256,7 +236,6 @@ def test_openai(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "fileTypes": [], } assert template["max_retries"] == { "required": False, @@ -264,14 +243,13 @@ def test_openai(client: TestClient, logged_in_headers): "placeholder": "", "show": False, "multiline": False, - "value": 2, + "value": 6, "password": False, "name": "max_retries", "type": "int", "list": False, "advanced": False, "info": "", - "fileTypes": [], } assert template["streaming"] == { "required": False, @@ -286,7 +264,6 @@ def test_openai(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "fileTypes": [], } @@ -312,7 +289,6 @@ def test_chat_open_ai(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "fileTypes": [], } assert template["client"] == { "required": False, @@ -326,7 +302,6 @@ def test_chat_open_ai(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "fileTypes": [], } assert template["model_name"] == { "required": False, @@ -334,11 +309,10 @@ def test_chat_open_ai(client: TestClient, logged_in_headers): "placeholder": "", "show": True, "multiline": False, - "value": "gpt-4-1106-preview", + "value": "gpt-3.5-turbo-0613", "password": False, "options": [ "gpt-4-1106-preview", - "gpt-4-vision-preview", "gpt-4", "gpt-4-32k", "gpt-3.5-turbo", @@ -349,7 +323,6 @@ def test_chat_open_ai(client: TestClient, logged_in_headers): "list": True, 
"advanced": False, "info": "", - "fileTypes": [], } assert template["temperature"] == { "required": False, @@ -364,8 +337,6 @@ def test_chat_open_ai(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1}, - "fileTypes": [], } assert template["model_kwargs"] == { "required": False, @@ -379,7 +350,6 @@ def test_chat_open_ai(client: TestClient, logged_in_headers): "list": False, "advanced": True, "info": "", - "fileTypes": [], } assert template["openai_api_key"] == { "required": False, @@ -395,7 +365,6 @@ def test_chat_open_ai(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "fileTypes": [], } assert template["request_timeout"] == { "required": False, @@ -409,8 +378,6 @@ def test_chat_open_ai(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "rangeSpec": {"max": 1.0, "min": -1.0, "step": 0.1}, - "fileTypes": [], } assert template["max_retries"] == { "required": False, @@ -418,14 +385,13 @@ def test_chat_open_ai(client: TestClient, logged_in_headers): "placeholder": "", "show": False, "multiline": False, - "value": 2, + "value": 6, "password": False, "name": "max_retries", "type": "int", "list": False, "advanced": False, "info": "", - "fileTypes": [], } assert template["streaming"] == { "required": False, @@ -440,7 +406,6 @@ def test_chat_open_ai(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "fileTypes": [], } assert template["n"] == { "required": False, @@ -455,7 +420,6 @@ def test_chat_open_ai(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "fileTypes": [], } assert template["max_tokens"] == { @@ -470,7 +434,6 @@ def test_chat_open_ai(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "fileTypes": [], } assert template["_type"] == "ChatOpenAI" assert ( diff --git a/tests/test_loading.py 
b/tests/test_loading.py index eb3987e93..8a797579f 100644 --- a/tests/test_loading.py +++ b/tests/test_loading.py @@ -1,10 +1,9 @@ import json - import pytest from langchain.chains.base import Chain +from langflow.processing.process import load_flow_from_json from langflow.graph import Graph -from langflow.processing.load import load_flow_from_json -from langflow.utils.payload import get_root_vertex +from langflow.utils.payload import get_root_node def test_load_flow_from_json(): @@ -16,22 +15,21 @@ def test_load_flow_from_json(): def test_load_flow_from_json_with_tweaks(): """Test loading a flow from a json file and applying tweaks""" - tweaks = {"dndnode_82": {"model_name": "test model"}} + tweaks = {"dndnode_82": {"model_name": "gpt-3.5-turbo-16k-0613"}} loaded = load_flow_from_json(pytest.BASIC_EXAMPLE_PATH, tweaks=tweaks) assert loaded is not None assert isinstance(loaded, Chain) - assert loaded.llm.model_name == "test model" + assert loaded.llm.model_name == "gpt-3.5-turbo-16k-0613" -def test_get_root_vertex(): +def test_get_root_node(): with open(pytest.BASIC_EXAMPLE_PATH, "r") as f: flow_graph = json.load(f) data_graph = flow_graph["data"] nodes = data_graph["nodes"] edges = data_graph["edges"] graph = Graph(nodes, edges) - root = get_root_vertex(graph) + root = get_root_node(graph) assert root is not None assert hasattr(root, "id") assert hasattr(root, "data") - assert hasattr(root, "data") diff --git a/tests/test_login.py b/tests/test_login.py index 399c7b761..f505f4100 100644 --- a/tests/test_login.py +++ b/tests/test_login.py @@ -1,5 +1,5 @@ from langflow.services.database.utils import session_getter -from langflow.services.deps import get_db_service +from langflow.services.getters import get_db_service import pytest from langflow.services.database.models.user import User from langflow.services.auth.utils import get_password_hash @@ -9,7 +9,9 @@ from langflow.services.auth.utils import get_password_hash def test_user(): return User( username="testuser", 
- password=get_password_hash("testpassword"), # Assuming password needs to be hashed + password=get_password_hash( + "testpassword" + ), # Assuming password needs to be hashed is_active=True, is_superuser=False, ) @@ -21,13 +23,17 @@ def test_login_successful(client, test_user): session.add(test_user) session.commit() - response = client.post("api/v1/login", data={"username": "testuser", "password": "testpassword"}) + response = client.post( + "api/v1/login", data={"username": "testuser", "password": "testpassword"} + ) assert response.status_code == 200 assert "access_token" in response.json() def test_login_unsuccessful_wrong_username(client): - response = client.post("api/v1/login", data={"username": "wrongusername", "password": "testpassword"}) + response = client.post( + "api/v1/login", data={"username": "wrongusername", "password": "testpassword"} + ) assert response.status_code == 401 assert response.json()["detail"] == "Incorrect username or password" @@ -37,6 +43,8 @@ def test_login_unsuccessful_wrong_password(client, test_user, session): session.add(test_user) session.commit() - response = client.post("api/v1/login", data={"username": "testuser", "password": "wrongpassword"}) + response = client.post( + "api/v1/login", data={"username": "testuser", "password": "wrongpassword"} + ) assert response.status_code == 401 assert response.json()["detail"] == "Incorrect username or password" diff --git a/tests/test_process.py b/tests/test_process.py index c8e4ec9cc..0588800dc 100644 --- a/tests/test_process.py +++ b/tests/test_process.py @@ -1,6 +1,5 @@ -import pytest from langflow.processing.process import process_tweaks -from langflow.services.deps import get_session_service +from langflow.services.getters import get_session_service def test_no_tweaks(): @@ -198,42 +197,39 @@ def test_tweak_not_in_template(): assert result == graph_data -@pytest.mark.asyncio -async def test_load_langchain_object_with_cached_session(client, basic_graph_data): +def 
test_load_langchain_object_with_cached_session(client, basic_graph_data): # Provide a non-existent session_id session_service = get_session_service() session_id1 = "non-existent-session-id" - graph1, artifacts1 = await session_service.load_session(session_id1, basic_graph_data) + graph1, artifacts1 = session_service.load_session(session_id1, basic_graph_data) # Use the new session_id to get the langchain_object again - graph2, artifacts2 = await session_service.load_session(session_id1, basic_graph_data) + graph2, artifacts2 = session_service.load_session(session_id1, basic_graph_data) assert graph1 == graph2 assert artifacts1 == artifacts2 -@pytest.mark.asyncio -async def test_load_langchain_object_with_no_cached_session(client, basic_graph_data): +def test_load_langchain_object_with_no_cached_session(client, basic_graph_data): # Provide a non-existent session_id session_service = get_session_service() session_id1 = "non-existent-session-id" session_id = session_service.build_key(session_id1, basic_graph_data) - graph1, artifacts1 = await session_service.load_session(session_id, basic_graph_data) + graph1, artifacts1 = session_service.load_session(session_id, basic_graph_data) # Clear the cache session_service.clear_session(session_id) # Use the new session_id to get the langchain_object again - graph2, artifacts2 = await session_service.load_session(session_id, basic_graph_data) + graph2, artifacts2 = session_service.load_session(session_id, basic_graph_data) assert id(graph1) != id(graph2) # Since the cache was cleared, objects should be different -@pytest.mark.asyncio -async def test_load_langchain_object_without_session_id(client, basic_graph_data): +def test_load_langchain_object_without_session_id(client, basic_graph_data): # Provide a non-existent session_id session_service = get_session_service() session_id1 = None - graph1, artifacts1 = await session_service.load_session(session_id1, basic_graph_data) + graph1, artifacts1 = 
session_service.load_session(session_id1, basic_graph_data) # Use the new session_id to get the langchain_object again - graph2, artifacts2 = await session_service.load_session(session_id1, basic_graph_data) + graph2, artifacts2 = session_service.load_session(session_id1, basic_graph_data) assert graph1 == graph2 diff --git a/tests/test_prompts_template.py b/tests/test_prompts_template.py index ca2bedb13..bc2e935e9 100644 --- a/tests/test_prompts_template.py +++ b/tests/test_prompts_template.py @@ -1,6 +1,5 @@ from fastapi.testclient import TestClient - -from langflow.services.deps import get_settings_service +from langflow.services.getters import get_settings_service def test_prompts_settings(client: TestClient, logged_in_headers): @@ -32,7 +31,6 @@ def test_prompt_template(client: TestClient, logged_in_headers): "list": True, "advanced": False, "info": "", - "fileTypes": [], } assert template["output_parser"] == { @@ -47,7 +45,6 @@ def test_prompt_template(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "fileTypes": [], } assert template["partial_variables"] == { @@ -62,7 +59,6 @@ def test_prompt_template(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "fileTypes": [], } assert template["template"] == { @@ -77,7 +73,6 @@ def test_prompt_template(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "fileTypes": [], } assert template["template_format"] == { @@ -93,7 +88,6 @@ def test_prompt_template(client: TestClient, logged_in_headers): "list": False, "advanced": False, "info": "", - "fileTypes": [], } assert template["validate_template"] == { @@ -102,12 +96,11 @@ def test_prompt_template(client: TestClient, logged_in_headers): "placeholder": "", "show": False, "multiline": False, - "value": False, + "value": True, "password": False, "name": "validate_template", "type": "bool", "list": False, "advanced": False, "info": "", - "fileTypes": [], } diff 
--git a/tests/test_setup_superuser.py b/tests/test_setup_superuser.py index 03d3882fb..8cdcfc0c8 100644 --- a/tests/test_setup_superuser.py +++ b/tests/test_setup_superuser.py @@ -1,11 +1,17 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import patch, MagicMock +from langflow.services.database.models.user.user import User +from langflow.services.settings.constants import ( + DEFAULT_SUPERUSER, + DEFAULT_SUPERUSER_PASSWORD, +) +from langflow.services.utils import ( + teardown_superuser, +) -from langflow.services.settings.constants import DEFAULT_SUPERUSER, DEFAULT_SUPERUSER_PASSWORD -from langflow.services.utils import teardown_superuser -# @patch("langflow.services.deps.get_session") +# @patch("langflow.services.getters.get_session") # @patch("langflow.services.utils.create_super_user") -# @patch("langflow.services.deps.get_settings_service") +# @patch("langflow.services.getters.get_settings_service") # # @patch("langflow.services.utils.verify_password") # def test_setup_superuser( # mock_get_session, mock_create_super_user, mock_get_settings_service @@ -86,9 +92,11 @@ from langflow.services.utils import teardown_superuser # assert str(actual_expr) == str(expected_expr) -@patch("langflow.services.deps.get_settings_service") -@patch("langflow.services.deps.get_session") -def test_teardown_superuser_default_superuser(mock_get_session, mock_get_settings_service): +@patch("langflow.services.getters.get_settings_service") +@patch("langflow.services.getters.get_session") +def test_teardown_superuser_default_superuser( + mock_get_session, mock_get_settings_service +): mock_settings_service = MagicMock() mock_settings_service.auth_settings.AUTO_LOGIN = True mock_settings_service.auth_settings.SUPERUSER = DEFAULT_SUPERUSER @@ -103,12 +111,20 @@ def test_teardown_superuser_default_superuser(mock_get_session, mock_get_setting teardown_superuser(mock_settings_service, mock_session) - mock_session.query.assert_not_called() + 
mock_session.query.assert_called_once_with(User) + actual_expr = mock_session.query.return_value.filter.call_args[0][0] + expected_expr = User.username == DEFAULT_SUPERUSER + + assert str(actual_expr) == str(expected_expr) + mock_session.delete.assert_called_once_with(mock_user) + mock_session.commit.assert_called_once() -@patch("langflow.services.deps.get_settings_service") -@patch("langflow.services.deps.get_session") -def test_teardown_superuser_no_default_superuser(mock_get_session, mock_get_settings_service): +@patch("langflow.services.getters.get_settings_service") +@patch("langflow.services.getters.get_session") +def test_teardown_superuser_no_default_superuser( + mock_get_session, mock_get_settings_service +): ADMIN_USER_NAME = "admin_user" mock_settings_service = MagicMock() mock_settings_service.auth_settings.AUTO_LOGIN = False @@ -119,11 +135,11 @@ def test_teardown_superuser_no_default_superuser(mock_get_session, mock_get_sett mock_session = MagicMock() mock_user = MagicMock() mock_user.is_superuser = False - mock_session.exec.return_value.filter.return_value.first.return_value = mock_user + mock_session.query.return_value.filter.return_value.first.return_value = mock_user mock_get_session.return_value = [mock_session] teardown_superuser(mock_settings_service, mock_session) - mock_session.exec.assert_called_once() + mock_session.query.assert_not_called() mock_session.delete.assert_not_called() mock_session.commit.assert_not_called() diff --git a/tests/test_template.py b/tests/test_template.py index 6dcb789ee..f0b55d3c0 100644 --- a/tests/test_template.py +++ b/tests/test_template.py @@ -2,8 +2,6 @@ import importlib from typing import Dict, List, Optional import pytest -from pydantic import BaseModel - from langflow.utils.constants import CHAT_OPENAI_MODELS, OPENAI_MODELS from langflow.utils.util import ( build_template_from_class, @@ -12,6 +10,7 @@ from langflow.utils.util import ( get_base_classes, get_default_factory, ) +from pydantic import BaseModel 
# Dummy classes for testing purposes @@ -66,9 +65,11 @@ def test_build_template_from_function(): assert "base_classes" in result # Test with add_function=True - result_with_function = build_template_from_function("ExampleClass1", type_to_loader_dict, add_function=True) + result_with_function = build_template_from_function( + "ExampleClass1", type_to_loader_dict, add_function=True + ) assert result_with_function is not None - assert "Callable" in result_with_function["base_classes"] + assert "function" in result_with_function["base_classes"] # Test with invalid name with pytest.raises(ValueError, match=r".* not found"): @@ -236,7 +237,7 @@ def test_format_dict(): "password": False, "multiline": False, "options": CHAT_OPENAI_MODELS, - "value": "gpt-4-1106-preview", + "value": "gpt-4-1106-preview", }, } assert format_dict(input_dict, "OpenAI") == expected_output_openai diff --git a/tests/test_user.py b/tests/test_user.py index d7e814845..49962c8d1 100644 --- a/tests/test_user.py +++ b/tests/test_user.py @@ -1,12 +1,11 @@ from datetime import datetime - -import pytest - from langflow.services.auth.utils import create_super_user, get_password_hash -from langflow.services.database.models.user import UserUpdate -from langflow.services.database.models.user.model import User + +from langflow.services.database.models.user.user import User from langflow.services.database.utils import session_getter -from langflow.services.deps import get_db_service, get_settings_service +from langflow.services.getters import get_db_service, get_settings_service +import pytest +from langflow.services.database.models.user import UserUpdate @pytest.fixture @@ -86,11 +85,15 @@ def test_deactivated_user_cannot_access(client, deactivated_user, logged_in_head assert response.json()["detail"] == "The user doesn't have enough privileges" -def test_data_consistency_after_update(client, active_user, logged_in_headers, super_user_headers): +def test_data_consistency_after_update( + client, active_user, 
logged_in_headers, super_user_headers +): user_id = active_user.id update_data = UserUpdate(is_active=False) - response = client.patch(f"/api/v1/users/{user_id}", json=update_data.model_dump(), headers=super_user_headers) + response = client.patch( + f"/api/v1/users/{user_id}", json=update_data.dict(), headers=super_user_headers + ) assert response.status_code == 200, response.json() # Fetch the updated user from the database @@ -117,7 +120,7 @@ def test_inactive_user(client): username="inactiveuser", password=get_password_hash("testpassword"), is_active=False, - last_login_at=datetime.now(), + last_login_at="2023-01-01T00:00:00", # Set to a valid datetime string ) session.add(user) session.commit() @@ -164,13 +167,17 @@ def test_patch_user(client, active_user, logged_in_headers): username="newname", ) - response = client.patch(f"/api/v1/users/{user_id}", json=update_data.model_dump(), headers=logged_in_headers) + response = client.patch( + f"/api/v1/users/{user_id}", json=update_data.dict(), headers=logged_in_headers + ) assert response.status_code == 200, response.json() update_data = UserUpdate( profile_image="new_image", ) - response = client.patch(f"/api/v1/users/{user_id}", json=update_data.model_dump(), headers=logged_in_headers) + response = client.patch( + f"/api/v1/users/{user_id}", json=update_data.dict(), headers=logged_in_headers + ) assert response.status_code == 200, response.json() @@ -182,7 +189,7 @@ def test_patch_reset_password(client, active_user, logged_in_headers): response = client.patch( f"/api/v1/users/{user_id}/reset-password", - json=update_data.model_dump(), + json=update_data.dict(), headers=logged_in_headers, ) assert response.status_code == 200, response.json() @@ -198,13 +205,19 @@ def test_patch_user_wrong_id(client, active_user, logged_in_headers): username="newname", ) - response = client.patch(f"/api/v1/users/{user_id}", json=update_data.model_dump(), headers=logged_in_headers) + response = client.patch( + 
f"/api/v1/users/{user_id}", json=update_data.dict(), headers=logged_in_headers + ) assert response.status_code == 422, response.json() - json_response = response.json() - detail = json_response["detail"] - assert detail[0]["type"] == "uuid_parsing" - assert detail[0]["loc"] == ["path", "user_id"] - assert detail[0]["input"] == "wrong_id" + assert response.json() == { + "detail": [ + { + "loc": ["path", "user_id"], + "msg": "value is not a valid uuid", + "type": "type_error.uuid", + } + ] + } def test_delete_user(client, test_user, super_user_headers): @@ -218,11 +231,15 @@ def test_delete_user_wrong_id(client, test_user, super_user_headers): user_id = "wrong_id" response = client.delete(f"/api/v1/users/{user_id}", headers=super_user_headers) assert response.status_code == 422 - json_response = response.json() - detail = json_response["detail"] - assert detail[0]["type"] == "uuid_parsing" - assert detail[0]["loc"] == ["path", "user_id"] - assert detail[0]["input"] == "wrong_id" + assert response.json() == { + "detail": [ + { + "loc": ["path", "user_id"], + "msg": "value is not a valid uuid", + "type": "type_error.uuid", + } + ] + } def test_normal_user_cant_delete_user(client, test_user, logged_in_headers): diff --git a/tests/test_vectorstore_template.py b/tests/test_vectorstore_template.py index 3b5c7ed42..9dd131dbc 100644 --- a/tests/test_vectorstore_template.py +++ b/tests/test_vectorstore_template.py @@ -1,5 +1,5 @@ from fastapi.testclient import TestClient -from langflow.services.deps import get_settings_service +from langflow.services.getters import get_settings_service # check that all agents are in settings.agents diff --git a/tests/test_websocket.py b/tests/test_websocket.py index c4c9ee322..5016eb704 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -31,7 +31,9 @@ def test_websocket_endpoint(client: TestClient, active_user, logged_in_headers): # Assuming your websocket_endpoint uses chat_service which caches data from stream_build 
access_token = logged_in_headers["Authorization"].split(" ")[1] with pytest.raises(WebSocketDisconnect): - with client.websocket_connect(f"api/v1/chat/non_existing_client_id?token={access_token}") as websocket: + with client.websocket_connect( + f"api/v1/chat/non_existing_client_id?token={access_token}" + ) as websocket: websocket.send_json({"type": "test"}) data = websocket.receive_json() assert "Please, build the flow before sending messages" in data["message"]