diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index f0d3986cf..2b22d352c 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -144,7 +144,7 @@ class Graph: return list(reversed(sorted_vertices)) - def generator_build(self) -> Generator: + def generator_build(self) -> Generator[Vertex, None, None]: """Builds each vertex in the graph and yields it.""" sorted_vertices = self.topological_sort() logger.debug("Sorted vertices: %s", sorted_vertices) diff --git a/src/backend/langflow/graph/vertex/base.py b/src/backend/langflow/graph/vertex/base.py index 425f66315..ade10365b 100644 --- a/src/backend/langflow/graph/vertex/base.py +++ b/src/backend/langflow/graph/vertex/base.py @@ -133,13 +133,13 @@ class Vertex: # Add _type to params self.params = params - def _build(self): + def _build(self, user_id=None): """ Initiate the build process. """ logger.debug(f"Building {self.vertex_type}") - self._build_each_node_in_params_dict() - self._get_and_instantiate_class() + self._build_each_node_in_params_dict(user_id) + self._get_and_instantiate_class(user_id) self._validate_built_object() self._built = True @@ -169,23 +169,25 @@ class Vertex: """ return all(self._is_node(node) for node in value) - def _build_node_and_update_params(self, key, node): + def _build_node_and_update_params(self, key, node, user_id=None): """ Builds a given node and updates the params dictionary accordingly. """ - result = node.build() + result = node.build(user_id) self._handle_func(key, result) if isinstance(result, list): self._extend_params_list_with_result(key, result) self.params[key] = result - def _build_list_of_nodes_and_update_params(self, key, nodes): + def _build_list_of_nodes_and_update_params( + self, key, nodes: List["Vertex"], user_id=None + ): """ Iterates over a list of nodes, builds each and updates the params dictionary. """ self.params[key] = [] for node in nodes: - built = node.build() + built = node.build(user_id) if isinstance(built, list): if key not in self.params: self.params[key] = [] @@ -215,7 +217,7 @@ class Vertex: if isinstance(self.params[key], list): self.params[key].extend(result) - def _get_and_instantiate_class(self): + def _get_and_instantiate_class(self, user_id=None): """ Gets the class from a dictionary and instantiates it with the params. """ @@ -226,6 +228,7 @@ class Vertex: node_type=self.vertex_type, base_type=self.base_type, params=self.params, + user_id=user_id, ) self._update_built_object_and_artifacts(result) except Exception as exc: @@ -255,9 +258,9 @@ class Vertex: raise ValueError(message) - def build(self, force: bool = False) -> Any: + def build(self, force: bool = False, user_id=None) -> Any: if not self._built or force: - self._build() + self._build(user_id) return self._built_object diff --git a/src/backend/langflow/interface/custom/custom_component.py b/src/backend/langflow/interface/custom/custom_component.py index 88d2bcc82..1357daf68 100644 --- a/src/backend/langflow/interface/custom/custom_component.py +++ b/src/backend/langflow/interface/custom/custom_component.py @@ -1,4 +1,5 @@ -from typing import Any, Callable, List, Optional +from typing import Any, Callable, List, Optional, Union +from uuid import UUID from fastapi import HTTPException from langflow.interface.custom.constants import CUSTOM_COMPONENT_SUPPORTED_TYPES from langflow.interface.custom.component import Component @@ -22,6 +23,7 @@ class CustomComponent(Component, extra=Extra.allow): function: Optional[Callable] = None return_type_valid_list = list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys()) repr_value: Optional[Any] = "" + user_id: Optional[Union[UUID, str]] = None def __init__(self, **data): super().__init__(**data) @@ -187,11 +189,16 @@ class CustomComponent(Component, extra=Extra.allow): return build_sorted_vertices_with_caching(graph_data) def list_flows(self, *, get_session: Optional[Callable] = None) -> List[Flow]: - get_session = get_session or session_getter - db_manager = get_db_manager() - with get_session(db_manager) as session: - flows = session.query(Flow).all() - return flows + if not self.user_id: + raise ValueError("Session is invalid") + try: + get_session = get_session or session_getter + db_manager = get_db_manager() + with get_session(db_manager) as session: + flows = session.query(Flow).filter(Flow.user_id == self.user_id).all() + return flows + except Exception as e: + raise ValueError("Session is invalid") from e def get_flow( self, @@ -207,7 +214,11 @@ class CustomComponent(Component, extra=Extra.allow): if flow_id: flow = session.query(Flow).get(flow_id) elif flow_name: - flow = session.query(Flow).filter(Flow.name == flow_name).first() + flow = ( + session.query(Flow) + .filter(Flow.name == flow_name) + .filter(Flow.user_id == self.user_id) + ).first() else: raise ValueError("Either flow_name or flow_id must be provided") diff --git a/tests/conftest.py b/tests/conftest.py index 9abe89d49..1359664a0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -170,7 +170,7 @@ def test_user(client): @pytest.fixture(scope="function") -def active_user(session): +def active_user(client, session): user = User( username="activeuser", password=get_password_hash(