🔨 refactor(types.py): reorder class definitions to match the order of their usage in the code

The order of the class definitions in the file has been changed to match the order of their usage in the code. This improves the readability of the code and makes it easier to understand the relationships between the classes. No functionality has been changed.
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-06-02 14:21:38 -03:00
commit df3cdb90b7

View file

@ -8,13 +8,13 @@ class AgentVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="agents")
self.tools: List[Union[ToolVertex, ToolkitVertex]] = []
self.tools: List[Union[ToolkitVertex, ToolVertex]] = []
self.chains: List[ChainVertex] = []
def _set_tools_and_chains(self) -> None:
for edge in self.edges:
source_node = edge.source
if isinstance(source_node, ToolVertex):
if isinstance(source_node, (ToolVertex, ToolkitVertex)):
self.tools.append(source_node)
elif isinstance(source_node, ChainVertex):
self.chains.append(source_node)
@ -40,74 +40,6 @@ class ToolVertex(Vertex):
super().__init__(data, base_type="tools")
class PromptVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="prompts")
def build(
self,
force: bool = False,
tools: Optional[Union[List[Vertex], List[ToolVertex]]] = None,
) -> Any:
if not self._built or force:
if (
"input_variables" not in self.params
or self.params["input_variables"] is None
):
self.params["input_variables"] = []
# Check if it is a ZeroShotPrompt and needs a tool
if "ShotPrompt" in self.vertex_type:
tools = (
[tool_node.build() for tool_node in tools]
if tools is not None
else []
)
# flatten the list of tools if it is a list of lists
# first check if it is a list
if tools and isinstance(tools, list) and isinstance(tools[0], list):
tools = flatten_list(tools)
self.params["tools"] = tools
prompt_params = [
key
for key, value in self.params.items()
if isinstance(value, str) and key != "format_instructions"
]
else:
prompt_params = ["template"]
for param in prompt_params:
prompt_text = self.params[param]
variables = extract_input_variables_from_prompt(prompt_text)
self.params["input_variables"].extend(variables)
self.params["input_variables"] = list(set(self.params["input_variables"]))
self._build()
return self._built_object
class ChainVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="chains")
def build(
self,
force: bool = False,
tools: Optional[Union[List[Vertex], List[ToolVertex]]] = None,
) -> Any:
if not self._built or force:
# Check if the chain requires a PromptVertex
for key, value in self.params.items():
if isinstance(value, PromptVertex):
# Build the PromptVertex, passing the tools if available
self.params[key] = value.build(tools=tools, force=force)
self._build()
#! Cannot deepcopy SQLDatabaseChain
if self.vertex_type in ["SQLDatabaseChain"]:
return self._built_object
return self._built_object
class LLMVertex(Vertex):
built_node_type = None
class_built_object = None
@ -193,3 +125,68 @@ class TextSplitterVertex(Vertex):
return f"""{self.vertex_type}({len(self._built_object)} documents)
\nDocuments: {self._built_object[:3]}..."""
return f"{self.vertex_type}()"
class ChainVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="chains")
def build(
self,
force: bool = False,
tools: Optional[List[Union[ToolkitVertex, ToolVertex]]] = None,
) -> Any:
if not self._built or force:
# Check if the chain requires a PromptVertex
for key, value in self.params.items():
if isinstance(value, PromptVertex):
# Build the PromptVertex, passing the tools if available
self.params[key] = value.build(tools=tools, force=force)
self._build()
return self._built_object
class PromptVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="prompts")
def build(
self,
force: bool = False,
tools: Optional[List[Union[ToolkitVertex, ToolVertex]]] = None,
) -> Any:
if not self._built or force:
if (
"input_variables" not in self.params
or self.params["input_variables"] is None
):
self.params["input_variables"] = []
# Check if it is a ZeroShotPrompt and needs a tool
if "ShotPrompt" in self.vertex_type:
tools = (
[tool_node.build() for tool_node in tools]
if tools is not None
else []
)
# flatten the list of tools if it is a list of lists
# first check if it is a list
if tools and isinstance(tools, list) and isinstance(tools[0], list):
tools = flatten_list(tools)
self.params["tools"] = tools
prompt_params = [
key
for key, value in self.params.items()
if isinstance(value, str) and key != "format_instructions"
]
else:
prompt_params = ["template"]
for param in prompt_params:
prompt_text = self.params[param]
variables = extract_input_variables_from_prompt(prompt_text)
self.params["input_variables"].extend(variables)
self.params["input_variables"] = list(set(self.params["input_variables"]))
self._build()
return self._built_object