🐛 fix(base.py): change return type of generator_build method to Generator[Vertex, None, None] for better type hinting and clarity

🐛 fix(base.py): add optional user_id parameter to _build and _build_list_of_nodes_and_update_params methods in Vertex class to support building with user-specific data
🐛 fix(base.py): add optional user_id parameter to _get_and_instantiate_class method in Vertex class to support building with user-specific data
🐛 fix(custom_component.py): add user_id attribute to CustomComponent class to store the user ID associated with the component
🐛 fix(custom_component.py): add user_id parameter to list_flows method in CustomComponent class to filter flows by user ID
🐛 fix(custom_component.py): add user_id parameter to get_flow method in CustomComponent class to filter flows by user ID
🐛 fix(conftest.py): add client parameter to active_user fixture in tests to fix missing dependency error
This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-08-25 15:19:31 -03:00
commit 1a50bcd183
4 changed files with 33 additions and 19 deletions

View file

@ -144,7 +144,7 @@ class Graph:
return list(reversed(sorted_vertices))
def generator_build(self) -> Generator:
def generator_build(self) -> Generator[Vertex, None, None]:
"""Builds each vertex in the graph and yields it."""
sorted_vertices = self.topological_sort()
logger.debug("Sorted vertices: %s", sorted_vertices)

View file

@ -133,13 +133,13 @@ class Vertex:
# Add _type to params
self.params = params
def _build(self):
def _build(self, user_id=None):
"""
Initiate the build process.
"""
logger.debug(f"Building {self.vertex_type}")
self._build_each_node_in_params_dict()
self._get_and_instantiate_class()
self._build_each_node_in_params_dict(user_id)
self._get_and_instantiate_class(user_id)
self._validate_built_object()
self._built = True
@ -169,23 +169,25 @@ class Vertex:
"""
return all(self._is_node(node) for node in value)
def _build_node_and_update_params(self, key, node):
def _build_node_and_update_params(self, key, node, user_id=None):
"""
Builds a given node and updates the params dictionary accordingly.
"""
result = node.build()
result = node.build(user_id)
self._handle_func(key, result)
if isinstance(result, list):
self._extend_params_list_with_result(key, result)
self.params[key] = result
def _build_list_of_nodes_and_update_params(self, key, nodes):
def _build_list_of_nodes_and_update_params(
self, key, nodes: List["Vertex"], user_id=None
):
"""
Iterates over a list of nodes, builds each and updates the params dictionary.
"""
self.params[key] = []
for node in nodes:
built = node.build()
built = node.build(user_id)
if isinstance(built, list):
if key not in self.params:
self.params[key] = []
@ -215,7 +217,7 @@ class Vertex:
if isinstance(self.params[key], list):
self.params[key].extend(result)
def _get_and_instantiate_class(self):
def _get_and_instantiate_class(self, user_id=None):
"""
Gets the class from a dictionary and instantiates it with the params.
"""
@ -226,6 +228,7 @@ class Vertex:
node_type=self.vertex_type,
base_type=self.base_type,
params=self.params,
user_id=user_id,
)
self._update_built_object_and_artifacts(result)
except Exception as exc:
@ -255,9 +258,9 @@ class Vertex:
raise ValueError(message)
def build(self, force: bool = False) -> Any:
def build(self, force: bool = False, user_id=None) -> Any:
if not self._built or force:
self._build()
self._build(user_id)
return self._built_object

View file

@ -1,4 +1,5 @@
from typing import Any, Callable, List, Optional
from typing import Any, Callable, List, Optional, Union
from uuid import UUID
from fastapi import HTTPException
from langflow.interface.custom.constants import CUSTOM_COMPONENT_SUPPORTED_TYPES
from langflow.interface.custom.component import Component
@ -22,6 +23,7 @@ class CustomComponent(Component, extra=Extra.allow):
function: Optional[Callable] = None
return_type_valid_list = list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())
repr_value: Optional[Any] = ""
user_id: Optional[Union[UUID, str]] = None
def __init__(self, **data):
super().__init__(**data)
@ -187,11 +189,16 @@ class CustomComponent(Component, extra=Extra.allow):
return build_sorted_vertices_with_caching(graph_data)
def list_flows(self, *, get_session: Optional[Callable] = None) -> List[Flow]:
get_session = get_session or session_getter
db_manager = get_db_manager()
with get_session(db_manager) as session:
flows = session.query(Flow).all()
return flows
if not self.user_id:
raise ValueError("Session is invalid")
try:
get_session = get_session or session_getter
db_manager = get_db_manager()
with get_session(db_manager) as session:
flows = session.query(Flow).filter(Flow.user_id == self.user_id).all()
return flows
except Exception as e:
raise ValueError("Session is invalid") from e
def get_flow(
self,
@ -207,7 +214,11 @@ class CustomComponent(Component, extra=Extra.allow):
if flow_id:
flow = session.query(Flow).get(flow_id)
elif flow_name:
flow = session.query(Flow).filter(Flow.name == flow_name).first()
flow = (
session.query(Flow)
.filter(Flow.name == flow_name)
.filter(Flow.user_id == self.user_id)
).first()
else:
raise ValueError("Either flow_name or flow_id must be provided")

View file

@ -170,7 +170,7 @@ def test_user(client):
@pytest.fixture(scope="function")
def active_user(session):
def active_user(client, session):
user = User(
username="activeuser",
password=get_password_hash(