This commit is contained in:
Gabriel Almeida 2023-03-26 10:01:50 -03:00
commit 405041f9c8
6 changed files with 73 additions and 39 deletions

View file

@ -55,7 +55,7 @@ def list_tools():
tool_params = util.get_tool_params(util.get_tools_dict(tool))
if tool_params and tool_params["name"] in settings.tools or settings.dev:
tools.append(tool_params["name"])
tools.append("BaseTool")
return tools
@ -84,3 +84,15 @@ def list_memories():
for memory in memory_type_to_cls_dict.values()
if memory.__name__ in settings.memories or settings.dev
]
LANGCHAIN_TYPES_DICT = {
k: list_function() for k, list_function in get_type_dict().items()
}
# Now we'll build a dict with Langchain types and ours
ALL_TYPES_DICT = {
**LANGCHAIN_TYPES_DICT,
"Custom": ["Custom Tool", "Python Function"],
}

View file

@ -25,16 +25,26 @@ from langflow.interface.types import get_type_list
from langflow.utils import payload, util
def instantiate_class(module_type: str, base_type: str, params: Dict) -> Any:
def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any:
"""Instantiate class from module type and key, and params"""
class_object = import_by_type(_type=base_type, name=module_type)
class_object = import_by_type(_type=base_type, name=node_type)
if base_type == "agents":
# We need to initialize it differently
allowed_tools = params["allowed_tools"]
llm_chain = params["llm_chain"]
return load_agent_executor(class_object, allowed_tools, llm_chain)
elif base_type == "tools" or module_type != "ZeroShotPrompt":
elif base_type == "tools" or node_type != "ZeroShotPrompt":
return class_object(**params)
elif node_type == "PythonFunction":
# If the node_type is "PythonFunction"
# we need to get the function from the params
# which will be a str containing a python function
# and then we need to compile it and return the function
# as the instance
function_string = params["function"]
if isinstance(function_string, str):
return util.eval_function(function_string)
raise ValueError("Function should be a string")
else:
return ZeroShotAgent.create_prompt(**params, tools=[])

View file

@ -1,3 +1,5 @@
OPENAI_MODELS = [
"text-davinci-003",
"text-davinci-002",
@ -6,3 +8,4 @@ OPENAI_MODELS = [
"text-ada-001",
]
CHAT_OPENAI_MODELS = ["gpt-3.5-turbo", "gpt-4", "gpt-4-32k"]

View file

@ -1,10 +1,8 @@
import types
from typing import Dict, List, Union
from langflow.interface import listing, loading
from langflow.interface import loading
from langflow.utils import payload, util
LANGCHAIN_TYPES_DICT = {
k: list_function() for k, list_function in listing.get_type_dict().items()
}
from langflow.interface.listing import ALL_TYPES_DICT
class Node:
@ -38,11 +36,11 @@ class Node:
]
template_dict = self.data["node"]["template"]
self.module_type = (
self.node_type = (
self.data["type"] if "Tool" not in self.output else template_dict["_type"]
)
def _build_params(self) -> Dict:
def _build_params(self):
# Some params are required, some are optional
# but most importantly, some params are python base classes
# like str and others are LangChain objects like LLMChain, BasePromptTemplate
@ -68,7 +66,7 @@ class Node:
continue
if value["type"] not in ["str", "bool"]:
# Get the edge that connects to this node
edge = next(
edge: Edge = next(
(
edge
for edge in self.edges
@ -83,21 +81,16 @@ class Node:
# or create a new list if it doesn't
if edge is None and value["required"]:
raise ValueError(
f"Required input {key} for module {self.module_type} is not connected"
f"Required input {key} for module {self.node_type} is not connected"
)
if value["list"]:
if key in params:
params[key].append(edge.source)
else:
params[key] = [edge.source]
else:
if not value["required"] and edge is None:
continue
elif value["required"] or edge is not None:
params[key] = edge.source
else:
if not value["required"] and not value.get("value"):
continue
elif value["required"] or value.get("value"):
params[key] = value["value"]
# Add _type to params
@ -112,7 +105,7 @@ class Node:
# and use the output of that build as the value for the param
# if the value is not a node, then we use the value as the param
# and continue
# Another aspect is that the module_type is the class that we need to import
# Another aspect is that the node_type is the class that we need to import
# and instantiate with these built params
# Build each node in the params dict
@ -120,7 +113,10 @@ class Node:
# Check if Node or list of Nodes
if isinstance(value, Node):
result = value.build()
self.params[key] = result.run if key == "func" else result
# If the key is "func", then we need to use the run method
if key == "func" and not isinstance(result, types.FunctionType):
result = result.run
self.params[key] = result
elif isinstance(value, list) and all(
isinstance(node, Node) for node in value
):
@ -130,12 +126,13 @@ class Node:
# and instantiate it with the params
# and return the instance
instance = None
for base_type, value in LANGCHAIN_TYPES_DICT.items():
for base_type, value in ALL_TYPES_DICT.items():
if base_type == "tools":
value = util.get_tools_dict()
if self.module_type in value:
if self.node_type in value:
instance = loading.instantiate_class(
module_type=self.module_type,
node_type=self.node_type,
base_type=base_type,
params=self.params,
)
@ -260,11 +257,11 @@ class Graph:
def _build_nodes(self) -> List[Node]:
return [Node(node) for node in self._nodes]
def get_children_by_module_type(self, node: Node, module_type: str) -> List[Node]:
def get_children_by_node_type(self, node: Node, node_type: str) -> List[Node]:
children = []
module_types = [node.data["type"]]
node_types = [node.data["type"]]
if "node" in node.data:
module_types += node.data["node"]["base_classes"]
if module_type in module_types:
node_types += node.data["node"]["base_classes"]
if node_type in node_types:
children.append(node)
return children

View file

@ -45,7 +45,7 @@ def build_json(root, graph) -> Dict:
else:
# Otherwise, find all children whose type matches the type
# specified in the template
module_type = root.module_type
node_type = root.node_type
local_nodes = graph.get_nodes_with_target(root)
if len(local_nodes) == 1:
@ -58,25 +58,23 @@ def build_json(root, graph) -> Dict:
if key == "_type":
continue
module_type = value["type"]
node_type = value["type"]
if "value" in value and value["value"] is not None:
# If the value is specified, use it
value = value["value"]
elif "dict" in module_type:
elif "dict" in node_type:
# If the value is a dictionary, create an empty dictionary
value = {}
else:
# Otherwise, recursively build the child nodes
children = []
for local_node in local_nodes:
node_children = graph.get_children_by_module_type(
local_node, module_type
)
node_children = graph.get_children_by_node_type(local_node, node_type)
children.extend(node_children)
if value["required"] and not children:
raise ValueError(f"No child with type {module_type} found")
raise ValueError(f"No child with type {node_type} found")
values = [build_json(child, graph) for child in children]
value = list(values) if value["list"] else next(iter(values), None)
final_dict[key] = value

View file

@ -3,6 +3,7 @@ import importlib
import inspect
import re
from typing import Dict, Optional
import types
from langchain.agents.load_tools import (
_BASE_TOOLS,
@ -122,6 +123,21 @@ def build_template_from_class(
}
def eval_function(function_string: str):
    """Compile a Python function from its source string and return it.

    The string is executed in an isolated namespace and the first
    function object found in that namespace is returned.

    SECURITY NOTE: this calls ``exec`` on ``function_string`` — it must
    never be fed untrusted input, as arbitrary code will run with full
    interpreter privileges.

    :param function_string: source text containing at least one ``def``.
    :return: the first ``types.FunctionType`` object defined by the string.
    :raises ValueError: if the string defines no function.
    """
    # Fresh dict so the executed code cannot touch (or see) our globals.
    namespace: Dict = {}
    exec(function_string, namespace)
    # Pick the first real function defined by the snippet; classes,
    # constants and builtins injected by exec are ignored.
    function_object = next(
        (obj for obj in namespace.values() if isinstance(obj, types.FunctionType)),
        None,
    )
    if function_object is None:
        raise ValueError("Function string does not contain a function")
    return function_object
def get_base_classes(cls):
"""Get the base classes of a class.
These are used to determine the output of the nodes.
@ -165,9 +181,6 @@ class GenericTool(Tool):
super().__init__(name=name, description=description, func=func)
# from langchain.llms.base import BaseLLM
def get_base_tool(name, description, func: callable) -> BaseTool:
    """Wrap *func* in a GenericTool carrying the given name and description.

    Bug fix: the previous version ignored both parameters and hardcoded
    placeholder values (``"Generic Tool"`` / ``"Bacon"``); the caller's
    ``name`` and ``description`` are now passed through.
    """
    return GenericTool(func=func, name=name, description=description)
@ -205,6 +218,7 @@ def get_tool_params(func, **kwargs):
tool_params["description"] = ast.literal_eval(
keyword.value
)
return tool_params
return {
"name": ast.literal_eval(tool.args[0]),