feat: added edge validation

This commit is contained in:
Gabriel Almeida 2023-03-24 18:48:03 -03:00
commit cccbeeb4d2
4 changed files with 173 additions and 119 deletions

View file

@ -10,6 +10,26 @@ class Node:
def _parse_data(self) -> None:
self.data = self._data["data"]
# Data dict:
# {'type': 'LLMChain', 'node': {'template': {'_type': 'llm_chain', 'memory': {'type': 'BaseMemory', 'required': False, 'placeholder': '', 'list': False, 'show': True, 'password': False, 'multiline': False, 'value': None}, 'verbose': {'type': 'bool', 'required': False, 'placeholder': '', 'list': False, 'show': False, 'password': False, 'multiline': False, 'value': False}, 'prompt': {'type': 'BasePromptTemplate', 'required': True, 'placeholder': '', 'list': False, 'show': True, 'password': False, 'multiline': False}, 'llm': {'type': 'BaseLanguageModel', 'required': True, 'placeholder': '', 'list': False, 'show': True, 'password': False, 'multiline': False}, 'output_key': {'type': 'str', 'required': False, 'placeholder': '', 'list': False, 'show': False, 'password': True, 'multiline': False, 'value': 'text'}}, 'description': 'Chain to run queries against LLMs.', 'base_classes': ['Chain']}, 'id': 'dndnode_1', 'value': None}
# base_classes are the classes that the node can be cast to
self.output = self.data["node"]["base_classes"]
template_dict = {
key: value
for key, value in self.data["node"]["template"].items()
if isinstance(value, dict)
}
self.required_inputs = [
template_dict[key]["type"]
for key, value in template_dict.items()
if value["required"]
]
self.optional_inputs = [
template_dict[key]["type"]
for key, value in template_dict.items()
if not value["required"]
]
def add_edge(self, edge: "Edge") -> None:
self.edges.append(edge)
@ -28,9 +48,23 @@ class Edge:
def __init__(self, source: "Node", target: "Node"):
self.source: "Node" = source
self.target: "Node" = target
self.validate_edge()
def validate_edge(self) -> None:
# Validate that the outputs of the source node are valid inputs for the target node
self.coming_out = self.source.output
self.going_in = self.target.required_inputs + self.target.optional_inputs
# Both lists contain strings and sometimes a string contains the value we are looking for
# e.g. comgin_out=["Chain"] and going_in=["LLMChain"]
# so we need to check if any of the strings in coming_out is in going_in
self.valid = any(
output in going_in
for output in self.coming_out
for going_in in self.going_in
)
def __repr__(self) -> str:
return f"Edge(source={self.source.id}, target={self.target.id})"
return f"Edge(source={self.source.id}, target={self.target.id}, valid={self.valid}, coming_out={self.coming_out}, going_in={self.going_in})"
class Graph:
@ -92,3 +126,12 @@ class Graph:
def _build_nodes(self) -> List[Node]:
return [Node(node) for node in self._nodes]
def get_children_by_module_type(self, node: Node, module_type: str) -> List[Node]:
children = []
module_types = [node.data["type"]]
if "node" in node.data:
module_types += node.data["node"]["base_classes"]
if module_type in module_types:
children.append(node)
return children

View file

@ -72,11 +72,10 @@ def build_json(root: Node, graph: Graph) -> Dict:
# Otherwise, recursively build the child nodes
children = []
for local_node in local_nodes:
module_types = [local_node.data["type"]]
if "node" in local_node.data:
module_types += local_node.data["node"]["base_classes"]
if module_type in module_types:
children.append(local_node)
node_children = graph.get_children_by_module_type(
local_node, module_type
)
children.extend(node_children)
if value["required"] and not children:
raise ValueError(f"No child with type {module_type} found")