🔨 refactor(base.py): refactor the _build method in the Vertex class for better readability and maintainability
🐛 fix(base.py): fix a bug where the built object is not updated correctly in the Vertex class ✨ feat(base.py): add validation to check if the built object is None in the Vertex class
This commit is contained in:
parent
008a1c4079
commit
d8dee14ed5
1 changed files with 91 additions and 61 deletions
|
|
@ -123,89 +123,119 @@ class Vertex:
|
|||
self.params = params
|
||||
|
||||
def _build(self):
|
||||
# The params dict is used to build the module
|
||||
# it contains values and keys that point to nodes which
|
||||
# have their own params dict
|
||||
# When build is called, we iterate through the params dict
|
||||
# and if the value is a node, we call build on that node
|
||||
# and use the output of that build as the value for the param
|
||||
# if the value is not a node, then we use the value as the param
|
||||
# and continue
|
||||
# Another aspect is that the node_type is the class that we need to import
|
||||
# and instantiate with these built params
|
||||
"""
|
||||
Initiate the build process.
|
||||
"""
|
||||
logger.debug(f"Building {self.vertex_type}")
|
||||
# Build each node in the params dict
|
||||
self._build_each_node_in_params_dict()
|
||||
self._get_and_instantiate_class()
|
||||
self._validate_built_object()
|
||||
|
||||
self._built = True
|
||||
|
||||
def _build_each_node_in_params_dict(self):
|
||||
"""
|
||||
Iterates over each node in the params dictionary and builds it.
|
||||
"""
|
||||
for key, value in self.params.copy().items():
|
||||
# Check if Node or list of Nodes and not self
|
||||
# to avoid recursion
|
||||
if isinstance(value, Vertex):
|
||||
if self._is_node(value):
|
||||
if value == self:
|
||||
del self.params[key]
|
||||
continue
|
||||
result = value.build()
|
||||
# If the key is "func", then we need to use the run method
|
||||
if key == "func":
|
||||
if not isinstance(result, types.FunctionType):
|
||||
# func can be
|
||||
# PythonFunction(code='\ndef upper_case(text: str) -> str:\n return text.upper()\n')
|
||||
# so we need to check if there is an attribute called run
|
||||
if hasattr(result, "run"):
|
||||
result = result.run # type: ignore
|
||||
elif hasattr(result, "get_function"):
|
||||
result = result.get_function() # type: ignore
|
||||
elif inspect.iscoroutinefunction(result):
|
||||
self.params["coroutine"] = result
|
||||
else:
|
||||
# turn result which is a function into a coroutine
|
||||
# so that it can be awaited
|
||||
self.params["coroutine"] = sync_to_async(result)
|
||||
if isinstance(result, list):
|
||||
# If the result is a list, then we need to extend the list
|
||||
# with the result but first check if the key exists
|
||||
# if it doesn't, then we need to create a new list
|
||||
if isinstance(self.params[key], list):
|
||||
self.params[key].extend(result)
|
||||
self._build_node_and_update_params(key, value)
|
||||
elif isinstance(value, list) and self._is_list_of_nodes(value):
|
||||
self._build_list_of_nodes_and_update_params(key, value)
|
||||
|
||||
self.params[key] = result
|
||||
elif isinstance(value, list) and all(
|
||||
isinstance(node, Vertex) for node in value
|
||||
):
|
||||
self.params[key] = []
|
||||
for node in value:
|
||||
built = node.build()
|
||||
if isinstance(built, list):
|
||||
self.params[key].extend(built)
|
||||
else:
|
||||
self.params[key].append(built)
|
||||
def _is_node(self, value):
|
||||
"""
|
||||
Checks if the provided value is an instance of Vertex.
|
||||
"""
|
||||
return isinstance(value, Vertex)
|
||||
|
||||
# Get the class from LANGCHAIN_TYPES_DICT
|
||||
# and instantiate it with the params
|
||||
# and return the instance
|
||||
def _is_list_of_nodes(self, value):
|
||||
"""
|
||||
Checks if the provided value is a list of Vertex instances.
|
||||
"""
|
||||
return all(self._is_node(node) for node in value)
|
||||
|
||||
def _build_node_and_update_params(self, key, node):
|
||||
"""
|
||||
Builds a given node and updates the params dictionary accordingly.
|
||||
"""
|
||||
result = node.build()
|
||||
self._handle_func(key, result)
|
||||
if isinstance(result, list):
|
||||
self._extend_params_list_with_result(key, result)
|
||||
self.params[key] = result
|
||||
|
||||
def _build_list_of_nodes_and_update_params(self, key, nodes):
|
||||
"""
|
||||
Iterates over a list of nodes, builds each and updates the params dictionary.
|
||||
"""
|
||||
self.params[key] = []
|
||||
for node in nodes:
|
||||
built = node.build()
|
||||
if isinstance(built, list):
|
||||
self.params[key].extend(built)
|
||||
else:
|
||||
self.params[key].append(built)
|
||||
|
||||
def _handle_func(self, key, result):
|
||||
"""
|
||||
Handles 'func' key by checking if the result is a function and setting it as coroutine.
|
||||
"""
|
||||
if key == "func":
|
||||
if not isinstance(result, types.FunctionType):
|
||||
if hasattr(result, "run"):
|
||||
result = result.run # type: ignore
|
||||
elif hasattr(result, "get_function"):
|
||||
result = result.get_function() # type: ignore
|
||||
elif inspect.iscoroutinefunction(result):
|
||||
self.params["coroutine"] = result
|
||||
else:
|
||||
self.params["coroutine"] = sync_to_async(result)
|
||||
|
||||
def _extend_params_list_with_result(self, key, result):
|
||||
"""
|
||||
Extends a list in the params dictionary with the given result if it exists.
|
||||
"""
|
||||
if isinstance(self.params[key], list):
|
||||
self.params[key].extend(result)
|
||||
|
||||
def _get_and_instantiate_class(self):
|
||||
"""
|
||||
Gets the class from a dictionary and instantiates it with the params.
|
||||
"""
|
||||
if self.base_type is None:
|
||||
raise ValueError(f"Base type for node {self.vertex_type} not found")
|
||||
try:
|
||||
if self.base_type is None:
|
||||
raise ValueError(f"Base type for node {self.vertex_type} not found")
|
||||
result = loading.instantiate_class(
|
||||
node_type=self.vertex_type,
|
||||
base_type=self.base_type,
|
||||
params=self.params,
|
||||
)
|
||||
# Result could be the _built_object or
|
||||
# (_built_object, dict) tuple
|
||||
if isinstance(result, tuple):
|
||||
self._built_object, self.artifacts = result
|
||||
else:
|
||||
self._built_object = result
|
||||
self._update_built_object_and_artifacts(result)
|
||||
except Exception as exc:
|
||||
raise ValueError(
|
||||
f"Error building node {self.vertex_type}: {str(exc)}"
|
||||
) from exc
|
||||
|
||||
def _update_built_object_and_artifacts(self, result):
|
||||
"""
|
||||
Updates the built object and its artifacts.
|
||||
"""
|
||||
if isinstance(result, tuple):
|
||||
self._built_object, self.artifacts = result
|
||||
else:
|
||||
self._built_object = result
|
||||
|
||||
def _validate_built_object(self):
|
||||
"""
|
||||
Checks if the built object is None and raises a ValueError if so.
|
||||
"""
|
||||
if self._built_object is None:
|
||||
raise ValueError(f"Node type {self.vertex_type} not found")
|
||||
|
||||
self._built = True
|
||||
|
||||
def build(self, force: bool = False) -> Any:
|
||||
if not self._built or force:
|
||||
self._build()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue