From 7dbda097f55d07fe5aa9a980cebe51b9d628540a Mon Sep 17 00:00:00 2001 From: Gabriel Almeida Date: Sun, 26 Mar 2023 01:15:01 -0300 Subject: [PATCH] feat: complex example test passing --- src/backend/langflow/interface/loading.py | 20 ++++++++++++- src/backend/langflow/utils/graph.py | 34 +++++++---------------- tests/data/complex_example.json | 2 +- tests/test_graph.py | 8 ++++++ tests/test_loading.py | 2 +- 5 files changed, 39 insertions(+), 27 deletions(-) diff --git a/src/backend/langflow/interface/loading.py b/src/backend/langflow/interface/loading.py index 59cda2b8a..0d304e3c3 100644 --- a/src/backend/langflow/interface/loading.py +++ b/src/backend/langflow/interface/loading.py @@ -17,8 +17,26 @@ from langchain.agents.load_tools import ( _EXTRA_LLM_TOOLS, _EXTRA_OPTIONAL_TOOLS, ) -from langflow.utils.graph import Graph from langchain.agents import agent as agent_module +from langflow.utils.graph import Graph + +from langflow.interface.importing import import_by_type + +from langchain.agents import ZeroShotAgent + + +def instantiate_class(module_type: str, base_type: str, params: Dict) -> Any: + """Instantiate class from module type and key, and params""" + class_object = import_by_type(_type=base_type, name=module_type) + if base_type == "agents": + # We need to initialize it differently + allowed_tools = params["allowed_tools"] + llm_chain = params["llm_chain"] + return load_agent_executor(class_object, allowed_tools, llm_chain) + elif base_type == "tools" or module_type != "ZeroShotPrompt": + return class_object(**params) + else: + return ZeroShotAgent.create_prompt(**params, tools=[]) def load_flow_from_json(path: str): diff --git a/src/backend/langflow/utils/graph.py b/src/backend/langflow/utils/graph.py index 0589ee8f5..601f626a0 100644 --- a/src/backend/langflow/utils/graph.py +++ b/src/backend/langflow/utils/graph.py @@ -1,6 +1,5 @@ from typing import Dict, List, Union -from langflow.interface import listing -from langflow.interface.importing import import_by_type +from langflow.interface import listing, loading from langflow.utils import payload, util LANGCHAIN_TYPES_DICT = { @@ -105,8 +104,6 @@ class Node: self.params = params def build(self): - from langflow.interface.loading import load_agent_executor - # The params dict is used to build the module # it contains values and keys that point to nodes which # have their own params dict @@ -122,8 +119,8 @@ class Node: for key, value in self.params.items(): # Check if Node or list of Nodes if isinstance(value, Node): - self.params[key] = value.build() - + result = value.build() + self.params[key] = result.run if key == "func" else result elif isinstance(value, list) and all( isinstance(node, Node) for node in value ): @@ -133,26 +130,15 @@ class Node: # and instantiate it with the params # and return the instance instance = None - for key, value in LANGCHAIN_TYPES_DICT.items(): - if key == "tools": + for base_type, value in LANGCHAIN_TYPES_DICT.items(): + if base_type == "tools": value = util.get_tools_dict() if self.module_type in value: - class_object = import_by_type(_type=key, name=self.module_type) - if key == "agents": - # We need to initialize it differently - allowed_tools = self.params["allowed_tools"] - llm_chain = self.params["llm_chain"] - instance = load_agent_executor( - class_object, allowed_tools, llm_chain - ) - elif key == "tools": - instance = class_object(**self.params) - elif self.module_type == "ZeroShotPrompt": - from langchain.agents import ZeroShotAgent - - instance = ZeroShotAgent.create_prompt(**self.params, tools=[]) - else: - instance = class_object(**self.params) + instance = loading.instantiate_class( + module_type=self.module_type, + base_type=base_type, + params=self.params, + ) break return instance diff --git a/tests/data/complex_example.json b/tests/data/complex_example.json index df11acf76..9d478c687 100644 --- a/tests/data/complex_example.json +++ b/tests/data/complex_example.json @@ -149,7 +149,7 @@ "show": true, "password": true, "multiline": false, - "value": null + "value": "sk-" }, "batch_size": { "type": "int", diff --git a/tests/test_graph.py b/tests/test_graph.py index e5a72e949..bbcaffa85 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -281,3 +281,11 @@ def test_build(): agent = graph.build() # The agent should be a AgentExecutor assert isinstance(agent, AgentExecutor) + + # Now we test the complex example + graph = get_graph(basic=False) + assert isinstance(graph, Graph) + # Now we test the build method + agent = graph.build() + # The agent should be a AgentExecutor + assert isinstance(agent, AgentExecutor) diff --git a/tests/test_loading.py b/tests/test_loading.py index 96a530921..783474886 100644 --- a/tests/test_loading.py +++ b/tests/test_loading.py @@ -2,7 +2,7 @@ import json from langchain import LLMChain, OpenAI from langflow.utils.graph import Graph import pytest -from pathlib import Path + from langflow import load_flow_from_json from langflow.interface.loading import extract_json from langflow.utils.payload import get_root_node, build_json