🔨 refactor(base.py): refactor the _build method in the Vertex class for better readability and maintainability

🐛 fix(base.py): fix a bug where the built object is not updated correctly in the Vertex class
 feat(base.py): add validation to check if the built object is None in the Vertex class
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-07-09 18:09:53 -03:00
commit d8dee14ed5

View file

@ -123,89 +123,119 @@ class Vertex:
self.params = params
def _build(self):
# The params dict is used to build the module
# it contains values and keys that point to nodes which
# have their own params dict
# When build is called, we iterate through the params dict
# and if the value is a node, we call build on that node
# and use the output of that build as the value for the param
# if the value is not a node, then we use the value as the param
# and continue
# Another aspect is that the node_type is the class that we need to import
# and instantiate with these built params
"""
Initiate the build process.
"""
logger.debug(f"Building {self.vertex_type}")
# Build each node in the params dict
self._build_each_node_in_params_dict()
self._get_and_instantiate_class()
self._validate_built_object()
self._built = True
def _build_each_node_in_params_dict(self):
"""
Iterates over each node in the params dictionary and builds it.
"""
for key, value in self.params.copy().items():
# Check if Node or list of Nodes and not self
# to avoid recursion
if isinstance(value, Vertex):
if self._is_node(value):
if value == self:
del self.params[key]
continue
result = value.build()
# If the key is "func", then we need to use the run method
if key == "func":
if not isinstance(result, types.FunctionType):
# func can be
# PythonFunction(code='\ndef upper_case(text: str) -> str:\n return text.upper()\n')
# so we need to check if there is an attribute called run
if hasattr(result, "run"):
result = result.run # type: ignore
elif hasattr(result, "get_function"):
result = result.get_function() # type: ignore
elif inspect.iscoroutinefunction(result):
self.params["coroutine"] = result
else:
# turn result which is a function into a coroutine
# so that it can be awaited
self.params["coroutine"] = sync_to_async(result)
if isinstance(result, list):
# If the result is a list, then we need to extend the list
# with the result but first check if the key exists
# if it doesn't, then we need to create a new list
if isinstance(self.params[key], list):
self.params[key].extend(result)
self._build_node_and_update_params(key, value)
elif isinstance(value, list) and self._is_list_of_nodes(value):
self._build_list_of_nodes_and_update_params(key, value)
self.params[key] = result
elif isinstance(value, list) and all(
isinstance(node, Vertex) for node in value
):
self.params[key] = []
for node in value:
built = node.build()
if isinstance(built, list):
self.params[key].extend(built)
else:
self.params[key].append(built)
def _is_node(self, value):
"""
Checks if the provided value is an instance of Vertex.
"""
return isinstance(value, Vertex)
# Get the class from LANGCHAIN_TYPES_DICT
# and instantiate it with the params
# and return the instance
def _is_list_of_nodes(self, value):
"""
Checks if the provided value is a list of Vertex instances.
"""
return all(self._is_node(node) for node in value)
def _build_node_and_update_params(self, key, node):
"""
Builds a given node and updates the params dictionary accordingly.
"""
result = node.build()
self._handle_func(key, result)
if isinstance(result, list):
self._extend_params_list_with_result(key, result)
self.params[key] = result
def _build_list_of_nodes_and_update_params(self, key, nodes):
"""
Iterates over a list of nodes, builds each and updates the params dictionary.
"""
self.params[key] = []
for node in nodes:
built = node.build()
if isinstance(built, list):
self.params[key].extend(built)
else:
self.params[key].append(built)
def _handle_func(self, key, result):
"""
Handles 'func' key by checking if the result is a function and setting it as coroutine.
"""
if key == "func":
if not isinstance(result, types.FunctionType):
if hasattr(result, "run"):
result = result.run # type: ignore
elif hasattr(result, "get_function"):
result = result.get_function() # type: ignore
elif inspect.iscoroutinefunction(result):
self.params["coroutine"] = result
else:
self.params["coroutine"] = sync_to_async(result)
def _extend_params_list_with_result(self, key, result):
"""
Extends a list in the params dictionary with the given result if it exists.
"""
if isinstance(self.params[key], list):
self.params[key].extend(result)
def _get_and_instantiate_class(self):
"""
Gets the class from a dictionary and instantiates it with the params.
"""
if self.base_type is None:
raise ValueError(f"Base type for node {self.vertex_type} not found")
try:
if self.base_type is None:
raise ValueError(f"Base type for node {self.vertex_type} not found")
result = loading.instantiate_class(
node_type=self.vertex_type,
base_type=self.base_type,
params=self.params,
)
# Result could be the _built_object or
# (_built_object, dict) tuple
if isinstance(result, tuple):
self._built_object, self.artifacts = result
else:
self._built_object = result
self._update_built_object_and_artifacts(result)
except Exception as exc:
raise ValueError(
f"Error building node {self.vertex_type}: {str(exc)}"
) from exc
def _update_built_object_and_artifacts(self, result):
"""
Updates the built object and its artifacts.
"""
if isinstance(result, tuple):
self._built_object, self.artifacts = result
else:
self._built_object = result
def _validate_built_object(self):
"""
Checks if the built object is None and raises a ValueError if so.
"""
if self._built_object is None:
raise ValueError(f"Node type {self.vertex_type} not found")
self._built = True
def build(self, force: bool = False) -> Any:
if not self._built or force:
self._build()