🐛 fix(types.py): pass user_id parameter to build methods in AgentVertex, LLMVertex, WrapperVertex, ChainVertex, and PromptVertex to enable user-specific functionality

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-08-28 17:55:29 -03:00
commit 1b79c1bd7e

View file

@ -21,18 +21,18 @@ class AgentVertex(Vertex):
elif isinstance(source_node, ChainVertex):
self.chains.append(source_node)
def build(self, force: bool = False) -> Any:
def build(self, force: bool = False, user_id=None) -> Any:
if not self._built or force:
self._set_tools_and_chains()
# First, build the tools
for tool_node in self.tools:
tool_node.build()
tool_node.build(user_id=user_id)
# Next, build the chains and the rest
for chain_node in self.chains:
chain_node.build(tools=self.tools)
chain_node.build(tools=self.tools, user_id=user_id)
self._build()
self._build(user_id=user_id)
return self._built_object
@ -49,13 +49,13 @@ class LLMVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="llms")
def build(self, force: bool = False) -> Any:
def build(self, force: bool = False, user_id=None) -> Any:
# LLM is different because some models might take up too much memory
# or time to load. So we only load them when we need them.ß
if self.vertex_type == self.built_node_type:
return self.class_built_object
if not self._built or force:
self._build()
self._build(user_id=user_id)
self.built_node_type = self.vertex_type
self.class_built_object = self._built_object
# Avoid deepcopying the LLM
@ -77,11 +77,11 @@ class WrapperVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="wrappers")
def build(self, force: bool = False) -> Any:
def build(self, force: bool = False, user_id=None) -> Any:
if not self._built or force:
if "headers" in self.params:
self.params["headers"] = ast.literal_eval(self.params["headers"])
self._build()
self._build(user_id=user_id)
return self._built_object
@ -149,6 +149,7 @@ class ChainVertex(Vertex):
self,
force: bool = False,
tools: Optional[List[Union[ToolkitVertex, ToolVertex]]] = None,
user_id=None,
) -> Any:
if not self._built or force:
# Check if the chain requires a PromptVertex
@ -157,7 +158,7 @@ class ChainVertex(Vertex):
# Build the PromptVertex, passing the tools if available
self.params[key] = value.build(tools=tools, force=force)
self._build()
self._build(user_id=user_id)
return self._built_object
@ -170,6 +171,7 @@ class PromptVertex(Vertex):
self,
force: bool = False,
tools: Optional[List[Union[ToolkitVertex, ToolVertex]]] = None,
user_id=None,
) -> Any:
if not self._built or force:
if (
@ -180,7 +182,7 @@ class PromptVertex(Vertex):
# Check if it is a ZeroShotPrompt and needs a tool
if "ShotPrompt" in self.vertex_type:
tools = (
[tool_node.build() for tool_node in tools]
[tool_node.build(user_id=user_id) for tool_node in tools]
if tools is not None
else []
)
@ -208,7 +210,7 @@ class PromptVertex(Vertex):
else:
self.params.pop("input_variables", None)
self._build()
self._build(user_id=user_id)
return self._built_object
def _built_object_repr(self):