fix: remove subscribe call and add unsubscribe method in StateService (#5727)

* feat: Implement unsubscribe functionality in state management

- Added `unsubscribe` method to `GraphStateManager` for removing observers from the state service.
- Introduced `unsubscribe` method in `StateService` with a `NotImplementedError` for future implementation.
- Implemented `unsubscribe` in `InMemoryStateService` to allow observers to be removed from the list, ensuring better management of state subscriptions.

This enhancement improves the flexibility of state management by allowing observers to unsubscribe, thus preventing potential memory leaks and ensuring cleaner state handling.

* fix: remove state manager subscribe call
This commit is contained in:
Gabriel Luiz Freitas Almeida 2025-01-16 15:01:01 -03:00 committed by GitHub
commit dafa85c7c6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 16 additions and 6 deletions

View file

@ -644,10 +644,7 @@ class Graph:
if run_id is None:
run_id = uuid.uuid4()
run_id_str = str(run_id)
for vertex in self.vertices:
self.state_manager.subscribe(run_id_str, vertex.update_graph_state)
self._run_id = run_id_str
self._run_id = str(run_id)
if self.tracing_service:
self.tracing_service.set_run_id(run_id)
@ -1430,6 +1427,7 @@ class Graph:
vertex.results = cached_vertex_dict["results"]
try:
vertex.finalize_build()
if vertex.result is not None:
vertex.result.used_frozen_result = True
except Exception: # noqa: BLE001

View file

@ -33,3 +33,6 @@ class GraphStateManager:
def subscribe(self, key, observer: Callable) -> None:
self.state_service.subscribe(key, observer)
def unsubscribe(self, key, observer: Callable) -> None:
self.state_service.unsubscribe(key, observer)

View file

@ -23,6 +23,9 @@ class StateService(Service):
def subscribe(self, key, observer: Callable) -> None:
raise NotImplementedError
def unsubscribe(self, key, observer: Callable) -> None:
raise NotImplementedError
def notify_observers(self, key, new_state) -> None:
raise NotImplementedError
@ -30,8 +33,8 @@ class StateService(Service):
class InMemoryStateService(StateService):
def __init__(self, settings_service: SettingsService):
self.settings_service = settings_service
self.states: dict = {}
self.observers: dict = defaultdict(list)
self.states: dict[str, dict] = {}
self.observers: dict[str, list[Callable]] = defaultdict(list)
self.lock = Lock()
def append_state(self, key, new_state, run_id: str) -> None:
@ -72,3 +75,9 @@ class InMemoryStateService(StateService):
except Exception: # noqa: BLE001
logger.exception(f"Error in observer {callback} for key {key}")
logger.warning("Callbacks not implemented yet")
def unsubscribe(self, key, observer: Callable) -> None:
with self.lock:
if observer in self.observers[key]:
# Use list.remove() since observers[key] is a list
self.observers[key].remove(observer)