parent
abcddeeaf1
commit
e1f4041908
3 changed files with 16 additions and 4 deletions
|
|
@ -62,7 +62,12 @@ class Graph:
|
|||
if isinstance(node, ToolkitNode):
|
||||
node.params["llm"] = llm_node
|
||||
# remove invalid nodes
|
||||
self.nodes = [node for node in self.nodes if self._validate_node(node)]
|
||||
self.nodes = [
|
||||
node
|
||||
for node in self.nodes
|
||||
if self._validate_node(node)
|
||||
or (len(self.nodes) == 1 and len(self.edges) == 0)
|
||||
]
|
||||
|
||||
def _validate_node(self, node: Node) -> bool:
|
||||
# All nodes that do not have edges are invalid
|
||||
|
|
|
|||
|
|
@ -191,9 +191,12 @@ def get_result_and_thought_using_graph(langchain_object, message: str):
|
|||
if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
|
||||
memory_key = langchain_object.memory.memory_key
|
||||
|
||||
for key in langchain_object.input_keys:
|
||||
if key not in [memory_key, "chat_history"]:
|
||||
chat_input = {key: message}
|
||||
if hasattr(langchain_object, "input_keys"):
|
||||
for key in langchain_object.input_keys:
|
||||
if key not in [memory_key, "chat_history"]:
|
||||
chat_input = {key: message}
|
||||
else:
|
||||
chat_input = message # type: ignore
|
||||
|
||||
if hasattr(langchain_object, "return_intermediate_steps"):
|
||||
# https://github.com/hwchase17/langchain/issues/2068
|
||||
|
|
|
|||
|
|
@ -33,6 +33,10 @@ def get_root_node(graph):
|
|||
Returns the root node of the template.
|
||||
"""
|
||||
incoming_edges = {edge.source for edge in graph.edges}
|
||||
|
||||
if not incoming_edges and len(graph.nodes) == 1:
|
||||
return graph.nodes[0]
|
||||
|
||||
return next((node for node in graph.nodes if node not in incoming_edges), None)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue