feat: added edge validation
This commit is contained in:
parent
230f0d95e9
commit
cccbeeb4d2
4 changed files with 173 additions and 119 deletions
|
|
@ -10,6 +10,26 @@ class Node:
|
|||
|
||||
def _parse_data(self) -> None:
|
||||
self.data = self._data["data"]
|
||||
# Data dict:
|
||||
# {'type': 'LLMChain', 'node': {'template': {'_type': 'llm_chain', 'memory': {'type': 'BaseMemory', 'required': False, 'placeholder': '', 'list': False, 'show': True, 'password': False, 'multiline': False, 'value': None}, 'verbose': {'type': 'bool', 'required': False, 'placeholder': '', 'list': False, 'show': False, 'password': False, 'multiline': False, 'value': False}, 'prompt': {'type': 'BasePromptTemplate', 'required': True, 'placeholder': '', 'list': False, 'show': True, 'password': False, 'multiline': False}, 'llm': {'type': 'BaseLanguageModel', 'required': True, 'placeholder': '', 'list': False, 'show': True, 'password': False, 'multiline': False}, 'output_key': {'type': 'str', 'required': False, 'placeholder': '', 'list': False, 'show': False, 'password': True, 'multiline': False, 'value': 'text'}}, 'description': 'Chain to run queries against LLMs.', 'base_classes': ['Chain']}, 'id': 'dndnode_1', 'value': None}
|
||||
# base_classes are the classes that the node can be cast to
|
||||
self.output = self.data["node"]["base_classes"]
|
||||
template_dict = {
|
||||
key: value
|
||||
for key, value in self.data["node"]["template"].items()
|
||||
if isinstance(value, dict)
|
||||
}
|
||||
|
||||
self.required_inputs = [
|
||||
template_dict[key]["type"]
|
||||
for key, value in template_dict.items()
|
||||
if value["required"]
|
||||
]
|
||||
self.optional_inputs = [
|
||||
template_dict[key]["type"]
|
||||
for key, value in template_dict.items()
|
||||
if not value["required"]
|
||||
]
|
||||
|
||||
def add_edge(self, edge: "Edge") -> None:
|
||||
self.edges.append(edge)
|
||||
|
|
@ -28,9 +48,23 @@ class Edge:
|
|||
def __init__(self, source: "Node", target: "Node"):
|
||||
self.source: "Node" = source
|
||||
self.target: "Node" = target
|
||||
self.validate_edge()
|
||||
|
||||
def validate_edge(self) -> None:
|
||||
# Validate that the outputs of the source node are valid inputs for the target node
|
||||
self.coming_out = self.source.output
|
||||
self.going_in = self.target.required_inputs + self.target.optional_inputs
|
||||
# Both lists contain strings and sometimes a string contains the value we are looking for
|
||||
# e.g. comgin_out=["Chain"] and going_in=["LLMChain"]
|
||||
# so we need to check if any of the strings in coming_out is in going_in
|
||||
self.valid = any(
|
||||
output in going_in
|
||||
for output in self.coming_out
|
||||
for going_in in self.going_in
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Edge(source={self.source.id}, target={self.target.id})"
|
||||
return f"Edge(source={self.source.id}, target={self.target.id}, valid={self.valid}, coming_out={self.coming_out}, going_in={self.going_in})"
|
||||
|
||||
|
||||
class Graph:
|
||||
|
|
@ -92,3 +126,12 @@ class Graph:
|
|||
|
||||
def _build_nodes(self) -> List[Node]:
|
||||
return [Node(node) for node in self._nodes]
|
||||
|
||||
def get_children_by_module_type(self, node: Node, module_type: str) -> List[Node]:
|
||||
children = []
|
||||
module_types = [node.data["type"]]
|
||||
if "node" in node.data:
|
||||
module_types += node.data["node"]["base_classes"]
|
||||
if module_type in module_types:
|
||||
children.append(node)
|
||||
return children
|
||||
|
|
|
|||
|
|
@ -72,11 +72,10 @@ def build_json(root: Node, graph: Graph) -> Dict:
|
|||
# Otherwise, recursively build the child nodes
|
||||
children = []
|
||||
for local_node in local_nodes:
|
||||
module_types = [local_node.data["type"]]
|
||||
if "node" in local_node.data:
|
||||
module_types += local_node.data["node"]["base_classes"]
|
||||
if module_type in module_types:
|
||||
children.append(local_node)
|
||||
node_children = graph.get_children_by_module_type(
|
||||
local_node, module_type
|
||||
)
|
||||
children.extend(node_children)
|
||||
|
||||
if value["required"] and not children:
|
||||
raise ValueError(f"No child with type {module_type} found")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue