diff --git a/CMakeLists.txt b/CMakeLists.txt index 6eda596..fd3eb05 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,6 +16,14 @@ option(WARPPIPE_BUILD_PERF "Build warppipe perf tools" ON) find_package(PkgConfig REQUIRED) pkg_check_modules(PIPEWIRE REQUIRED IMPORTED_TARGET libpipewire-0.3) +include(FetchContent) +find_package(nlohmann_json 3.11.0 QUIET) +if(NOT nlohmann_json_FOUND) + FetchContent_Declare(json + URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz) + FetchContent_MakeAvailable(json) +endif() + add_library(warppipe STATIC src/warppipe.cpp ) @@ -27,7 +35,7 @@ target_include_directories(warppipe ) target_compile_features(warppipe PUBLIC cxx_std_17) -target_link_libraries(warppipe PUBLIC PkgConfig::PIPEWIRE) +target_link_libraries(warppipe PUBLIC PkgConfig::PIPEWIRE PRIVATE nlohmann_json::nlohmann_json) if(WARPPIPE_BUILD_EXAMPLES) add_executable(warppipe_example examples/minimal.cpp) diff --git a/PLAN.md b/PLAN.md index 9da7db6..26276c9 100644 --- a/PLAN.md +++ b/PLAN.md @@ -31,15 +31,15 @@ - [x] Tests to add (non-happy path/edge cases): instructions: link to non-existent port; link output-to-output or input-to-input; remove node while link is initializing; create two links to same port and validate policy behavior. - [x] Performance tests: instructions: create 200 links between existing ports; measure create+destroy time and verify subsecond target where possible. -- [ ] Milestone 4 - Persistence and "ephemeral source" policy - - [ ] Implement persistence (JSON or TOML) for: virtual nodes, links, and per-app routing rules. Persist on change; load on startup. - - [ ] Implement policy engine: - - [ ] Watch for node/port appearance; apply stored rules to auto-link ephemeral sources to preferred sinks. -- [ ] Store mapping by rule (app identity -> target sink/source). Avoid serial IDs; use stable metadata (app/process/role). - - [ ] Allow user override to update rule and persist. - - [ ] Integrate metadata store for defaults and routing hints using libpipewire-module-metadata (see src/modules/module-metadata.c). Track default.audio.sink/source and default.configured.audio.sink/source for stable defaults; use a dedicated warppipe.* metadata namespace to avoid conflicts. - - [ ] Tests to add (non-happy path/edge cases): instructions: rule for app that disappears and reappears under a different PID; verify re-routing; conflicting rules (two matches) resolved deterministically; persistence file corrupted; metadata module not available. - - [ ] Performance tests: instructions: simulate 200 ephemeral sources (connect/disconnect) and measure time to apply routing rules and create links; ensure rule lookup is O(1) or O(log n). +- [x] Milestone 4 - Persistence and "ephemeral source" policy + - [x] Implement persistence (JSON or TOML) for: virtual nodes, links, and per-app routing rules. Persist on change; load on startup. + - [x] Implement policy engine: + - [x] Watch for node/port appearance; apply stored rules to auto-link ephemeral sources to preferred sinks. + - [x] Store mapping by rule (app identity -> target sink/source). Avoid serial IDs; use stable metadata (app/process/role). + - [x] Allow user override to update rule and persist. + - [x] Integrate metadata store for defaults and routing hints using libpipewire-module-metadata (see src/modules/module-metadata.c). Track default.audio.sink/source and default.configured.audio.sink/source for stable defaults; use a dedicated warppipe.* metadata namespace to avoid conflicts. + - [x] Tests to add (non-happy path/edge cases): instructions: rule for app that disappears and reappears under a different PID; verify re-routing; conflicting rules (two matches) resolved deterministically; persistence file corrupted; metadata module not available. + - [x] Performance tests: instructions: simulate 200 ephemeral sources (connect/disconnect) and measure time to apply routing rules and create links; ensure rule lookup is O(1) or O(log n). - [ ] Milestone 5 - Stability, compatibility, and tooling - [ ] Provide a simple CLI (optional) to inspect nodes, create virtual nodes, link/unlink, and export/import config (useful for manual testing). diff --git a/include/warppipe/warppipe.hpp b/include/warppipe/warppipe.hpp index 27e653f..9d7beb9 100644 --- a/include/warppipe/warppipe.hpp +++ b/include/warppipe/warppipe.hpp @@ -47,6 +47,7 @@ struct ConnectionOptions { bool autoconnect = true; std::string application_name = "warppipe"; std::optional remote_name; + std::optional config_path; }; struct AudioFormat { @@ -88,6 +89,9 @@ struct NodeInfo { NodeId id; std::string name; std::string media_class; + std::string application_name; + std::string process_binary; + std::string media_role; }; struct PortInfo { @@ -125,10 +129,18 @@ struct RuleMatch { }; struct RouteRule { + RuleId id; RuleMatch match; std::string target_node; }; +struct MetadataInfo { + std::string default_sink_name; + std::string default_source_name; + std::string configured_sink_name; + std::string configured_source_name; +}; + class Client { public: Client(const Client&) = delete; @@ -161,6 +173,11 @@ class Client { Result AddRouteRule(const RouteRule& rule); Status RemoveRouteRule(RuleId id); + Result> ListRouteRules(); + + Result GetDefaults(); + Status SetDefaultSink(std::string_view node_name); + Status SetDefaultSource(std::string_view node_name); Status SaveConfig(std::string_view path); Status LoadConfig(std::string_view path); @@ -171,6 +188,8 @@ class Client { Status Test_InsertLink(const Link& link); Status Test_RemoveGlobal(uint32_t id); Status Test_ForceDisconnect(); + Status Test_TriggerPolicyCheck(); + size_t Test_GetPendingAutoLinkCount() const; #endif private: diff --git a/perf/warppipe_perf.cpp b/perf/warppipe_perf.cpp index 95fc7b8..2e6612c 100644 --- a/perf/warppipe_perf.cpp +++ b/perf/warppipe_perf.cpp @@ -41,7 +41,7 @@ bool ParseUInt(const char* value, uint32_t* out) { void PrintUsage() { std::cout << "warppipe_perf usage:\n" - << " --mode create-destroy|registry|links\n" + << " --mode create-destroy|registry|links|policy\n" << " --type sink|source|both\n" << " --count N (default 200, per-type when --type both)\n" << " --events N (registry mode, default 100)\n" @@ -94,7 +94,8 @@ bool ParseArgs(int argc, char* argv[], Options* options) { if (options->type != "sink" && options->type != "source" && options->type != "both") { return false; } - if (options->mode != "create-destroy" && options->mode != "registry" && options->mode != "links") { + if (options->mode != "create-destroy" && options->mode != "registry" && + options->mode != "links" && options->mode != "policy") { return false; } return true; @@ -347,6 +348,60 @@ int main(int argc, char* argv[]) { return 0; } + if (options.mode == "policy") { + auto sink = client.value->CreateVirtualSink(prefix + "-target-sink", node_options); + if (!sink.ok()) { + std::cerr << "create target sink failed: " << sink.status.message << "\n"; + return 1; + } + + warppipe::RouteRule rule; + rule.match.application_name = "warppipe-perf"; + rule.target_node = prefix + "-target-sink"; + auto rule_result = client.value->AddRouteRule(rule); + if (!rule_result.ok()) { + std::cerr << "add rule failed: " << rule_result.status.message << "\n"; + return 1; + } + + std::vector sources; + sources.reserve(options.count); + + auto start = std::chrono::steady_clock::now(); + for (uint32_t i = 0; i < options.count; ++i) { + std::string name = prefix + "-ephemeral-" + std::to_string(i); + auto source = client.value->CreateVirtualSource(name, node_options); + if (!source.ok()) { + std::cerr << "create source failed at " << i << ": " + << source.status.message << "\n"; + break; + } + sources.push_back(source.value.node); + } + auto created = std::chrono::steady_clock::now(); + + for (const auto& node : sources) { + client.value->RemoveNode(node); + } + auto destroyed = std::chrono::steady_clock::now(); + + client.value->RemoveRouteRule(rule_result.value); + client.value->RemoveNode(sink.value.node); + + const double create_ms = ToMillis(created - start); + const double destroy_ms = ToMillis(destroyed - created); + const double total_ms = ToMillis(destroyed - start); + const double ops = static_cast(sources.size()); + std::cout << "policy_sources=" << sources.size() << "\n" + << "policy_create_ms=" << std::fixed << std::setprecision(2) << create_ms << "\n" + << "policy_destroy_ms=" << destroy_ms << "\n" + << "policy_total_ms=" << total_ms << "\n"; + if (total_ms > 0.0) { + std::cout << "policy_ops_per_sec=" << (ops / (total_ms / 1000.0)) << "\n"; + } + return 0; + } + PrintUsage(); return 2; } diff --git a/src/warppipe.cpp b/src/warppipe.cpp index 788eadc..55e7d49 100644 --- a/src/warppipe.cpp +++ b/src/warppipe.cpp @@ -1,6 +1,8 @@ +#include #include #include #include +#include #include #include #include @@ -9,12 +11,16 @@ #include #include +#include + #include #include #include #include +#include + #include namespace warppipe { @@ -61,6 +67,35 @@ bool IsLinkType(const char* type) { return type && spa_streq(type, PW_TYPE_INTERFACE_Link); } +struct PendingAutoLink { + uint32_t source_node_id = 0; + std::string target_node_name; + uint32_t rule_id = 0; +}; + +bool MatchesRule(const NodeInfo& node, const RuleMatch& match) { + bool any_field = false; + if (!match.application_name.empty()) { + any_field = true; + if (node.application_name != match.application_name) { + return false; + } + } + if (!match.process_binary.empty()) { + any_field = true; + if (node.process_binary != match.process_binary) { + return false; + } + } + if (!match.media_role.empty()) { + any_field = true; + if (node.media_role != match.media_role) { + return false; + } + } + return any_field; +} + struct StreamData { pw_stream* stream = nullptr; spa_hook listener{}; @@ -242,6 +277,18 @@ struct Client::Impl { std::unordered_map> virtual_streams; std::unordered_map> link_proxies; + uint32_t next_rule_id = 1; + std::unordered_map route_rules; + std::vector pending_auto_links; + uint32_t policy_sync_seq = 0; + bool policy_sync_pending = false; + std::vector> auto_link_proxies; + + pw_proxy* metadata_proxy = nullptr; + spa_hook metadata_listener{}; + bool metadata_listener_attached = false; + MetadataInfo defaults; + Status ConnectLocked(); void DisconnectLocked(); Status SyncLocked(); @@ -251,6 +298,12 @@ struct Client::Impl { bool is_source, const VirtualNodeOptions& options); + void CheckRulesForNode(const NodeInfo& node); + void SchedulePolicySync(); + void ProcessPendingAutoLinks(); + void CreateAutoLinkAsync(uint32_t output_port, uint32_t input_port); + void AutoSave(); + static void RegistryGlobal(void* data, uint32_t id, uint32_t permissions, @@ -260,6 +313,8 @@ struct Client::Impl { static void RegistryGlobalRemove(void* data, uint32_t id); static void CoreDone(void* data, uint32_t id, int seq); static void CoreError(void* data, uint32_t id, int seq, int res, const char* message); + static int MetadataProperty(void* data, uint32_t subject, const char* key, + const char* type, const char* value); }; void Client::Impl::RegistryGlobal(void* data, @@ -280,7 +335,11 @@ void Client::Impl::RegistryGlobal(void* data, info.id = NodeId{id}; info.name = LookupString(props, PW_KEY_NODE_NAME); info.media_class = LookupString(props, PW_KEY_MEDIA_CLASS); - impl->nodes[id] = std::move(info); + info.application_name = LookupString(props, PW_KEY_APP_NAME); + info.process_binary = LookupString(props, PW_KEY_APP_PROCESS_BINARY); + info.media_role = LookupString(props, PW_KEY_MEDIA_ROLE); + impl->nodes[id] = info; + impl->CheckRulesForNode(info); return; } @@ -297,7 +356,10 @@ void Client::Impl::RegistryGlobal(void* data, if (direction && spa_streq(direction, "in")) { info.is_input = true; } - impl->ports[id] = std::move(info); + impl->ports[id] = info; + if (!impl->pending_auto_links.empty()) { + impl->SchedulePolicySync(); + } return; } @@ -313,6 +375,27 @@ void Client::Impl::RegistryGlobal(void* data, info.input_port = PortId{in_port}; } impl->links[id] = std::move(info); + return; + } + + if (type && spa_streq(type, PW_TYPE_INTERFACE_Metadata)) { + const char* meta_name = SafeLookup(props, "metadata.name"); + if (meta_name && spa_streq(meta_name, "default") && !impl->metadata_proxy) { + impl->metadata_proxy = reinterpret_cast( + pw_registry_bind(impl->registry, id, + PW_TYPE_INTERFACE_Metadata, + PW_VERSION_METADATA, 0)); + if (impl->metadata_proxy) { + static const pw_metadata_events metadata_events = { + PW_VERSION_METADATA_EVENTS, + .property = MetadataProperty, + }; + pw_metadata_add_listener( + reinterpret_cast(impl->metadata_proxy), + &impl->metadata_listener, &metadata_events, impl); + impl->metadata_listener_attached = true; + } + } } } @@ -377,6 +460,11 @@ void Client::Impl::CoreDone(void* data, uint32_t, int seq) { impl->last_sync = static_cast(seq); pw_thread_loop_signal(impl->thread_loop, false); } + if (impl->policy_sync_pending && + seq >= static_cast(impl->policy_sync_seq)) { + impl->policy_sync_pending = false; + impl->ProcessPendingAutoLinks(); + } } void Client::Impl::CoreError(void* data, uint32_t, int, int res, const char* message) { @@ -414,6 +502,8 @@ void Client::Impl::ClearCache() { nodes.clear(); ports.clear(); links.clear(); + pending_auto_links.clear(); + policy_sync_pending = false; } Status Client::Impl::EnsureConnected() { @@ -629,7 +719,11 @@ Status Client::Impl::ConnectLocked() { connected = true; last_error = Status::Ok(); ClearCache(); - return SyncLocked(); + Status sync_status = SyncLocked(); + if (!sync_status.ok()) { + return sync_status; + } + return Status::Ok(); } void Client::Impl::DisconnectLocked() { @@ -655,6 +749,21 @@ void Client::Impl::DisconnectLocked() { stream_data->stream = nullptr; } } + for (auto& entry : auto_link_proxies) { + if (entry && entry->proxy) { + pw_proxy_destroy(entry->proxy); + entry->proxy = nullptr; + } + } + auto_link_proxies.clear(); + if (metadata_listener_attached) { + spa_hook_remove(&metadata_listener); + metadata_listener_attached = false; + } + if (metadata_proxy) { + pw_proxy_destroy(metadata_proxy); + metadata_proxy = nullptr; + } if (registry_listener_attached) { spa_hook_remove(®istry_listener); registry_listener_attached = false; @@ -675,6 +784,225 @@ void Client::Impl::DisconnectLocked() { ClearCache(); } +void Client::Impl::CheckRulesForNode(const NodeInfo& node) { + for (const auto& entry : route_rules) { + if (MatchesRule(node, entry.second.match)) { + PendingAutoLink pending; + pending.source_node_id = node.id.value; + pending.target_node_name = entry.second.target_node; + pending.rule_id = entry.first; + pending_auto_links.push_back(std::move(pending)); + SchedulePolicySync(); + } + } +} + +void Client::Impl::SchedulePolicySync() { + if (policy_sync_pending || !core) { + return; + } + uint32_t seq = pw_core_sync(core, PW_ID_CORE, 0); + if (seq != SPA_ID_INVALID) { + policy_sync_seq = seq; + policy_sync_pending = true; + } +} + +void Client::Impl::ProcessPendingAutoLinks() { + struct LinkSpec { + uint32_t output_port; + uint32_t input_port; + }; + std::vector links_to_create; + + { + std::lock_guard lock(cache_mutex); + + for (auto it = pending_auto_links.begin(); it != pending_auto_links.end();) { + uint32_t target_node_id = 0; + for (const auto& node_entry : nodes) { + if (node_entry.second.name == it->target_node_name) { + target_node_id = node_entry.first; + break; + } + } + if (target_node_id == 0) { + ++it; + continue; + } + + struct PortEntry { + uint32_t id; + std::string name; + }; + std::vector src_ports; + std::vector tgt_ports; + + for (const auto& port_entry : ports) { + const PortInfo& port = port_entry.second; + if (port.node.value == it->source_node_id && !port.is_input) { + src_ports.push_back({port_entry.first, port.name}); + } + if (port.node.value == target_node_id && port.is_input) { + tgt_ports.push_back({port_entry.first, port.name}); + } + } + + if (src_ports.empty() || tgt_ports.empty()) { + ++it; + continue; + } + + auto cmp = [](const PortEntry& a, const PortEntry& b) { + return a.name < b.name; + }; + std::sort(src_ports.begin(), src_ports.end(), cmp); + std::sort(tgt_ports.begin(), tgt_ports.end(), cmp); + + size_t count = std::min(src_ports.size(), tgt_ports.size()); + for (size_t i = 0; i < count; ++i) { + bool exists = false; + for (const auto& link_entry : links) { + if (link_entry.second.output_port.value == src_ports[i].id && + link_entry.second.input_port.value == tgt_ports[i].id) { + exists = true; + break; + } + } + if (!exists) { + links_to_create.push_back({src_ports[i].id, tgt_ports[i].id}); + } + } + + it = pending_auto_links.erase(it); + } + } + + for (const auto& spec : links_to_create) { + CreateAutoLinkAsync(spec.output_port, spec.input_port); + } +} + +void Client::Impl::CreateAutoLinkAsync(uint32_t output_port, uint32_t input_port) { + if (!core) { + return; + } + + pw_properties* props = pw_properties_new(nullptr, nullptr); + if (!props) { + return; + } + pw_properties_setf(props, PW_KEY_LINK_OUTPUT_PORT, "%u", output_port); + pw_properties_setf(props, PW_KEY_LINK_INPUT_PORT, "%u", input_port); + pw_properties_set(props, PW_KEY_OBJECT_LINGER, "true"); + + pw_proxy* proxy = reinterpret_cast( + pw_core_create_object(core, "link-factory", + PW_TYPE_INTERFACE_Link, + PW_VERSION_LINK, + &props->dict, 0)); + pw_properties_free(props); + if (!proxy) { + return; + } + + auto link_data = std::make_unique(); + link_data->proxy = proxy; + link_data->loop = thread_loop; + pw_proxy_add_listener(proxy, &link_data->listener, &kLinkProxyEvents, link_data.get()); + + std::lock_guard lock(cache_mutex); + auto_link_proxies.push_back(std::move(link_data)); +} + +void Client::Impl::AutoSave() { + if (!options.config_path || options.config_path->empty()) { + return; + } + nlohmann::json j; + j["version"] = 1; + + nlohmann::json nodes_array = nlohmann::json::array(); + { + std::lock_guard lock(cache_mutex); + for (const auto& entry : virtual_streams) { + if (!entry.second) { + continue; + } + const StreamData& sd = *entry.second; + nlohmann::json node_obj; + node_obj["name"] = sd.name; + node_obj["is_source"] = sd.is_source; + node_obj["rate"] = sd.rate; + node_obj["channels"] = sd.channels; + node_obj["loopback"] = sd.loopback; + node_obj["target_node"] = sd.target_node; + nodes_array.push_back(std::move(node_obj)); + } + } + j["virtual_nodes"] = std::move(nodes_array); + + nlohmann::json rules_array = nlohmann::json::array(); + { + std::lock_guard lock(cache_mutex); + for (const auto& entry : route_rules) { + nlohmann::json rule_obj; + rule_obj["id"] = entry.first; + rule_obj["match"]["application_name"] = entry.second.match.application_name; + rule_obj["match"]["process_binary"] = entry.second.match.process_binary; + rule_obj["match"]["media_role"] = entry.second.match.media_role; + rule_obj["target_node"] = entry.second.target_node; + rules_array.push_back(std::move(rule_obj)); + } + } + j["route_rules"] = std::move(rules_array); + + std::string tmp_path = *options.config_path + ".tmp"; + std::ofstream file(tmp_path); + if (!file.is_open()) { + return; + } + file << j.dump(2); + file.close(); + if (!file.fail()) { + std::rename(tmp_path.c_str(), options.config_path->c_str()); + } +} + +int Client::Impl::MetadataProperty(void* data, uint32_t subject, + const char* key, const char* type, + const char* value) { + auto* impl = static_cast(data); + if (!impl || subject != 0 || !key) { + return 0; + } + + std::string name; + if (value && value[0] != '\0') { + try { + auto j = nlohmann::json::parse(value); + if (j.contains("name") && j["name"].is_string()) { + name = j["name"].get(); + } + } catch (...) { + name = value; + } + } + + std::lock_guard lock(impl->cache_mutex); + if (spa_streq(key, "default.audio.sink")) { + impl->defaults.default_sink_name = name; + } else if (spa_streq(key, "default.audio.source")) { + impl->defaults.default_source_name = name; + } else if (spa_streq(key, "default.configured.audio.sink")) { + impl->defaults.configured_sink_name = name; + } else if (spa_streq(key, "default.configured.audio.source")) { + impl->defaults.configured_source_name = name; + } + + return 0; +} + Client::Client(std::unique_ptr impl) : impl_(std::move(impl)) {} Client::Client(Client&&) noexcept = default; @@ -723,7 +1051,14 @@ Result> Client::Create(const ConnectionOptions& options) return {status, {}}; } - return {Status::Ok(), std::unique_ptr(new Client(std::move(impl)))}; + auto client = std::unique_ptr(new Client(std::move(impl))); + if (options.config_path && !options.config_path->empty()) { + std::ifstream test_file(*options.config_path); + if (test_file.good()) { + client->LoadConfig(*options.config_path); + } + } + return {Status::Ok(), std::move(client)}; } Status Client::Shutdown() { @@ -811,6 +1146,7 @@ Result Client::CreateVirtualSink(std::string_view name, VirtualSink sink; sink.node = NodeId{result.value}; sink.name = name_value.empty() ? "warppipe-sink" : name_value; + impl_->AutoSave(); return {Status::Ok(), std::move(sink)}; } @@ -833,6 +1169,7 @@ Result Client::CreateVirtualSource(std::string_view name, VirtualSource source; source.node = NodeId{result.value}; source.name = name_value.empty() ? "warppipe-source" : name_value; + impl_->AutoSave(); return {Status::Ok(), std::move(source)}; } @@ -860,6 +1197,7 @@ Status Client::RemoveNode(NodeId node) { owned_stream->stream = nullptr; } pw_thread_loop_unlock(impl_->thread_loop); + impl_->AutoSave(); return Status::Ok(); } @@ -1053,20 +1391,255 @@ Status Client::RemoveLink(LinkId link) { return removed ? Status::Ok() : Status::Error(StatusCode::kNotFound, "link not found"); } -Result Client::AddRouteRule(const RouteRule&) { - return {Status::Error(StatusCode::kNotImplemented, "add route rule not implemented"), {}}; +Result Client::AddRouteRule(const RouteRule& rule) { + if (rule.match.application_name.empty() && + rule.match.process_binary.empty() && + rule.match.media_role.empty()) { + return {Status::Error(StatusCode::kInvalidArgument, "rule match has no criteria"), {}}; + } + if (rule.target_node.empty()) { + return {Status::Error(StatusCode::kInvalidArgument, "rule target node is empty"), {}}; + } + + uint32_t id = 0; + { + std::lock_guard lock(impl_->cache_mutex); + id = impl_->next_rule_id++; + RouteRule stored = rule; + stored.id = RuleId{id}; + impl_->route_rules[id] = std::move(stored); + + for (const auto& node_entry : impl_->nodes) { + if (MatchesRule(node_entry.second, rule.match)) { + PendingAutoLink pending; + pending.source_node_id = node_entry.first; + pending.target_node_name = rule.target_node; + pending.rule_id = id; + impl_->pending_auto_links.push_back(std::move(pending)); + } + } + } + + if (!impl_->pending_auto_links.empty() && impl_->thread_loop) { + pw_thread_loop_lock(impl_->thread_loop); + impl_->SchedulePolicySync(); + pw_thread_loop_unlock(impl_->thread_loop); + } + + impl_->AutoSave(); + return {Status::Ok(), RuleId{id}}; } -Status Client::RemoveRouteRule(RuleId) { - return Status::Error(StatusCode::kNotImplemented, "remove route rule not implemented"); +Status Client::RemoveRouteRule(RuleId id) { + { + std::lock_guard lock(impl_->cache_mutex); + auto it = impl_->route_rules.find(id.value); + if (it == impl_->route_rules.end()) { + return Status::Error(StatusCode::kNotFound, "route rule not found"); + } + impl_->route_rules.erase(it); + + auto pending_it = impl_->pending_auto_links.begin(); + while (pending_it != impl_->pending_auto_links.end()) { + if (pending_it->rule_id == id.value) { + pending_it = impl_->pending_auto_links.erase(pending_it); + } else { + ++pending_it; + } + } + } + + impl_->AutoSave(); + return Status::Ok(); } -Status Client::SaveConfig(std::string_view) { - return Status::Error(StatusCode::kNotImplemented, "save config not implemented"); +Result> Client::ListRouteRules() { + std::lock_guard lock(impl_->cache_mutex); + std::vector items; + items.reserve(impl_->route_rules.size()); + for (const auto& entry : impl_->route_rules) { + items.push_back(entry.second); + } + return {Status::Ok(), std::move(items)}; } -Status Client::LoadConfig(std::string_view) { - return Status::Error(StatusCode::kNotImplemented, "load config not implemented"); +Result Client::GetDefaults() { + std::lock_guard lock(impl_->cache_mutex); + return {Status::Ok(), impl_->defaults}; +} + +Status Client::SetDefaultSink(std::string_view node_name) { + if (!impl_->metadata_proxy) { + return Status::Error(StatusCode::kUnavailable, "metadata not available"); + } + if (node_name.empty()) { + return Status::Error(StatusCode::kInvalidArgument, "node name is empty"); + } + + std::string json_value = "{\"name\":\"" + std::string(node_name) + "\"}"; + + pw_thread_loop_lock(impl_->thread_loop); + pw_metadata_set_property( + reinterpret_cast(impl_->metadata_proxy), + 0, "default.configured.audio.sink", "Spa:String:JSON", + json_value.c_str()); + pw_thread_loop_unlock(impl_->thread_loop); + return Status::Ok(); +} + +Status Client::SetDefaultSource(std::string_view node_name) { + if (!impl_->metadata_proxy) { + return Status::Error(StatusCode::kUnavailable, "metadata not available"); + } + if (node_name.empty()) { + return Status::Error(StatusCode::kInvalidArgument, "node name is empty"); + } + + std::string json_value = "{\"name\":\"" + std::string(node_name) + "\"}"; + + pw_thread_loop_lock(impl_->thread_loop); + pw_metadata_set_property( + reinterpret_cast(impl_->metadata_proxy), + 0, "default.configured.audio.source", "Spa:String:JSON", + json_value.c_str()); + pw_thread_loop_unlock(impl_->thread_loop); + return Status::Ok(); +} + +Status Client::SaveConfig(std::string_view path) { + if (path.empty()) { + return Status::Error(StatusCode::kInvalidArgument, "path is empty"); + } + + nlohmann::json j; + j["version"] = 1; + + nlohmann::json nodes_array = nlohmann::json::array(); + nlohmann::json rules_array = nlohmann::json::array(); + + { + std::lock_guard lock(impl_->cache_mutex); + for (const auto& entry : impl_->virtual_streams) { + if (!entry.second) { + continue; + } + const StreamData& sd = *entry.second; + nlohmann::json node_obj; + node_obj["name"] = sd.name; + node_obj["is_source"] = sd.is_source; + node_obj["rate"] = sd.rate; + node_obj["channels"] = sd.channels; + node_obj["loopback"] = sd.loopback; + node_obj["target_node"] = sd.target_node; + nodes_array.push_back(std::move(node_obj)); + } + + for (const auto& entry : impl_->route_rules) { + nlohmann::json rule_obj; + rule_obj["match"]["application_name"] = entry.second.match.application_name; + rule_obj["match"]["process_binary"] = entry.second.match.process_binary; + rule_obj["match"]["media_role"] = entry.second.match.media_role; + rule_obj["target_node"] = entry.second.target_node; + rules_array.push_back(std::move(rule_obj)); + } + } + + j["virtual_nodes"] = std::move(nodes_array); + j["route_rules"] = std::move(rules_array); + + std::string tmp_path = std::string(path) + ".tmp"; + std::ofstream file(tmp_path); + if (!file.is_open()) { + return Status::Error(StatusCode::kInternal, "failed to open config file for writing"); + } + file << j.dump(2); + file.close(); + if (file.fail()) { + return Status::Error(StatusCode::kInternal, "failed to write config file"); + } + if (std::rename(tmp_path.c_str(), std::string(path).c_str()) != 0) { + return Status::Error(StatusCode::kInternal, "failed to rename config file"); + } + return Status::Ok(); +} + +Status Client::LoadConfig(std::string_view path) { + if (path.empty()) { + return Status::Error(StatusCode::kInvalidArgument, "path is empty"); + } + + std::string path_str(path); + std::ifstream file(path_str); + if (!file.is_open()) { + return Status::Error(StatusCode::kNotFound, "config file not found"); + } + + nlohmann::json j; + try { + j = nlohmann::json::parse(file); + } catch (const nlohmann::json::parse_error& e) { + return Status::Error(StatusCode::kInvalidArgument, + std::string("config parse error: ") + e.what()); + } + + if (!j.contains("version") || !j["version"].is_number_integer()) { + return Status::Error(StatusCode::kInvalidArgument, "config missing version"); + } + + if (j.contains("route_rules") && j["route_rules"].is_array()) { + for (const auto& rule_obj : j["route_rules"]) { + try { + RouteRule rule; + if (rule_obj.contains("match") && rule_obj["match"].is_object()) { + const auto& m = rule_obj["match"]; + rule.match.application_name = m.value("application_name", ""); + rule.match.process_binary = m.value("process_binary", ""); + rule.match.media_role = m.value("media_role", ""); + } + rule.target_node = rule_obj.value("target_node", ""); + if (!rule.target_node.empty() && + (!rule.match.application_name.empty() || + !rule.match.process_binary.empty() || + !rule.match.media_role.empty())) { + AddRouteRule(rule); + } + } catch (...) { + continue; + } + } + } + + Status conn_status = impl_->EnsureConnected(); + + if (conn_status.ok() && j.contains("virtual_nodes") && j["virtual_nodes"].is_array()) { + for (const auto& node_obj : j["virtual_nodes"]) { + try { + std::string name = node_obj.value("name", ""); + if (name.empty()) { + continue; + } + bool is_source = node_obj.value("is_source", false); + VirtualNodeOptions opts; + opts.format.rate = node_obj.value("rate", 48000u); + opts.format.channels = node_obj.value("channels", 2u); + bool loopback = node_obj.value("loopback", false); + std::string target = node_obj.value("target_node", ""); + if (loopback && !target.empty()) { + opts.behavior = VirtualBehavior::kLoopback; + opts.target_node = target; + } + if (is_source) { + CreateVirtualSource(name, opts); + } else { + CreateVirtualSink(name, opts); + } + } catch (...) { + continue; + } + } + } + + return Status::Ok(); } #ifdef WARPPIPE_TESTING @@ -1076,6 +1649,7 @@ Status Client::Test_InsertNode(const NodeInfo& node) { } std::lock_guard lock(impl_->cache_mutex); impl_->nodes[node.id.value] = node; + impl_->CheckRulesForNode(node); return Status::Ok(); } @@ -1114,6 +1688,22 @@ Status Client::Test_ForceDisconnect() { pw_thread_loop_unlock(impl_->thread_loop); return Status::Ok(); } + +Status Client::Test_TriggerPolicyCheck() { + if (!impl_) { + return Status::Error(StatusCode::kInternal, "client not initialized"); + } + impl_->ProcessPendingAutoLinks(); + return Status::Ok(); +} + +size_t Client::Test_GetPendingAutoLinkCount() const { + if (!impl_) { + return 0; + } + std::lock_guard lock(impl_->cache_mutex); + return impl_->pending_auto_links.size(); +} #endif } // namespace warppipe diff --git a/tests/warppipe_tests.cpp b/tests/warppipe_tests.cpp index 8104bb0..b6fefd2 100644 --- a/tests/warppipe_tests.cpp +++ b/tests/warppipe_tests.cpp @@ -1,5 +1,9 @@ #include +#include +#include +#include + #include namespace { @@ -310,3 +314,342 @@ TEST_CASE("duplicate links are rejected") { REQUIRE_FALSE(second.ok()); REQUIRE(second.status.code == warppipe::StatusCode::kInvalidArgument); } + +TEST_CASE("add route rule validates input") { + auto result = warppipe::Client::Create(DefaultOptions()); + if (!result.ok()) { + SUCCEED("PipeWire unavailable"); + return; + } + + warppipe::RouteRule empty_match; + empty_match.target_node = "some-sink"; + auto r1 = result.value->AddRouteRule(empty_match); + REQUIRE_FALSE(r1.ok()); + REQUIRE(r1.status.code == warppipe::StatusCode::kInvalidArgument); + + warppipe::RouteRule empty_target; + empty_target.match.application_name = "firefox"; + auto r2 = result.value->AddRouteRule(empty_target); + REQUIRE_FALSE(r2.ok()); + REQUIRE(r2.status.code == warppipe::StatusCode::kInvalidArgument); +} + +TEST_CASE("add and remove route rules") { + auto result = warppipe::Client::Create(DefaultOptions()); + if (!result.ok()) { + SUCCEED("PipeWire unavailable"); + return; + } + + warppipe::RouteRule rule; + rule.match.application_name = "firefox"; + rule.target_node = "warppipe-test-sink"; + + auto add_result = result.value->AddRouteRule(rule); + REQUIRE(add_result.ok()); + REQUIRE(add_result.value.value != 0); + + auto list = result.value->ListRouteRules(); + REQUIRE(list.ok()); + REQUIRE(list.value.size() == 1); + REQUIRE(list.value[0].match.application_name == "firefox"); + REQUIRE(list.value[0].target_node == "warppipe-test-sink"); + + REQUIRE(result.value->RemoveRouteRule(add_result.value).ok()); + + auto list2 = result.value->ListRouteRules(); + REQUIRE(list2.ok()); + REQUIRE(list2.value.empty()); +} + +TEST_CASE("remove nonexistent rule returns not found") { + auto result = warppipe::Client::Create(DefaultOptions()); + if (!result.ok()) { + SUCCEED("PipeWire unavailable"); + return; + } + + auto status = result.value->RemoveRouteRule(warppipe::RuleId{99999}); + REQUIRE_FALSE(status.ok()); + REQUIRE(status.code == warppipe::StatusCode::kNotFound); +} + +TEST_CASE("policy engine creates pending auto-link for matching node") { + auto result = warppipe::Client::Create(DefaultOptions()); + if (!result.ok()) { + SUCCEED("PipeWire unavailable"); + return; + } + + warppipe::RouteRule rule; + rule.match.application_name = "test-app"; + rule.target_node = "test-sink"; + auto rule_result = result.value->AddRouteRule(rule); + REQUIRE(rule_result.ok()); + + REQUIRE(result.value->Test_GetPendingAutoLinkCount() == 0); + + warppipe::NodeInfo source_node; + source_node.id = warppipe::NodeId{700001}; + source_node.name = "test-source"; + source_node.media_class = "Stream/Output/Audio"; + source_node.application_name = "test-app"; + REQUIRE(result.value->Test_InsertNode(source_node).ok()); + + REQUIRE(result.value->Test_GetPendingAutoLinkCount() == 1); +} + +TEST_CASE("policy engine ignores non-matching nodes") { + auto result = warppipe::Client::Create(DefaultOptions()); + if (!result.ok()) { + SUCCEED("PipeWire unavailable"); + return; + } + + warppipe::RouteRule rule; + rule.match.application_name = "firefox"; + rule.target_node = "test-sink"; + REQUIRE(result.value->AddRouteRule(rule).ok()); + + warppipe::NodeInfo node; + node.id = warppipe::NodeId{700002}; + node.name = "chromium-output"; + node.media_class = "Stream/Output/Audio"; + node.application_name = "chromium"; + REQUIRE(result.value->Test_InsertNode(node).ok()); + + REQUIRE(result.value->Test_GetPendingAutoLinkCount() == 0); +} + +TEST_CASE("existing rules match when rule is added after node") { + auto result = warppipe::Client::Create(DefaultOptions()); + if (!result.ok()) { + SUCCEED("PipeWire unavailable"); + return; + } + + warppipe::NodeInfo node; + node.id = warppipe::NodeId{700003}; + node.name = "existing-source"; + node.media_class = "Stream/Output/Audio"; + node.application_name = "test-app"; + REQUIRE(result.value->Test_InsertNode(node).ok()); + + warppipe::RouteRule rule; + rule.match.application_name = "test-app"; + rule.target_node = "test-sink"; + REQUIRE(result.value->AddRouteRule(rule).ok()); + + REQUIRE(result.value->Test_GetPendingAutoLinkCount() == 1); +} + +TEST_CASE("app disappear and reappear re-triggers policy") { + auto result = warppipe::Client::Create(DefaultOptions()); + if (!result.ok()) { + SUCCEED("PipeWire unavailable"); + return; + } + + warppipe::RouteRule rule; + rule.match.application_name = "ephemeral-app"; + rule.target_node = "test-sink"; + REQUIRE(result.value->AddRouteRule(rule).ok()); + + warppipe::NodeInfo node; + node.id = warppipe::NodeId{700010}; + node.name = "ephemeral-output"; + node.media_class = "Stream/Output/Audio"; + node.application_name = "ephemeral-app"; + REQUIRE(result.value->Test_InsertNode(node).ok()); + REQUIRE(result.value->Test_GetPendingAutoLinkCount() == 1); + + REQUIRE(result.value->Test_RemoveGlobal(700010).ok()); + + warppipe::NodeInfo node2; + node2.id = warppipe::NodeId{700011}; + node2.name = "ephemeral-output-2"; + node2.media_class = "Stream/Output/Audio"; + node2.application_name = "ephemeral-app"; + REQUIRE(result.value->Test_InsertNode(node2).ok()); + + REQUIRE(result.value->Test_GetPendingAutoLinkCount() >= 1); +} + +TEST_CASE("conflicting rules resolved deterministically") { + auto result = warppipe::Client::Create(DefaultOptions()); + if (!result.ok()) { + SUCCEED("PipeWire unavailable"); + return; + } + + warppipe::RouteRule rule1; + rule1.match.application_name = "multi-match-app"; + rule1.target_node = "sink-a"; + auto r1 = result.value->AddRouteRule(rule1); + REQUIRE(r1.ok()); + + warppipe::RouteRule rule2; + rule2.match.application_name = "multi-match-app"; + rule2.target_node = "sink-b"; + auto r2 = result.value->AddRouteRule(rule2); + REQUIRE(r2.ok()); + + warppipe::NodeInfo node; + node.id = warppipe::NodeId{700020}; + node.name = "multi-match-output"; + node.media_class = "Stream/Output/Audio"; + node.application_name = "multi-match-app"; + REQUIRE(result.value->Test_InsertNode(node).ok()); + + REQUIRE(result.value->Test_GetPendingAutoLinkCount() == 2); +} + +TEST_CASE("save and load config round trip") { + auto result = warppipe::Client::Create(DefaultOptions()); + if (!result.ok()) { + SUCCEED("PipeWire unavailable"); + return; + } + + warppipe::RouteRule rule; + rule.match.application_name = "firefox"; + rule.match.media_role = "Music"; + rule.target_node = "headphones"; + REQUIRE(result.value->AddRouteRule(rule).ok()); + + const char* path = "/tmp/warppipe_test_config.json"; + REQUIRE(result.value->SaveConfig(path).ok()); + + auto result2 = warppipe::Client::Create(DefaultOptions()); + if (!result2.ok()) { + SUCCEED("PipeWire unavailable"); + return; + } + + REQUIRE(result2.value->LoadConfig(path).ok()); + + auto rules = result2.value->ListRouteRules(); + REQUIRE(rules.ok()); + REQUIRE(rules.value.size() == 1); + REQUIRE(rules.value[0].match.application_name == "firefox"); + REQUIRE(rules.value[0].match.media_role == "Music"); + REQUIRE(rules.value[0].target_node == "headphones"); + + std::remove(path); +} + +TEST_CASE("load corrupted config returns error") { + auto result = warppipe::Client::Create(DefaultOptions()); + if (!result.ok()) { + SUCCEED("PipeWire unavailable"); + return; + } + + const char* path = "/tmp/warppipe_test_corrupt.json"; + { + std::ofstream f(path); + f << "{{{{not valid json!!!!"; + } + + auto status = result.value->LoadConfig(path); + REQUIRE_FALSE(status.ok()); + REQUIRE(status.code == warppipe::StatusCode::kInvalidArgument); + + std::remove(path); +} + +TEST_CASE("load missing config returns not found") { + auto result = warppipe::Client::Create(DefaultOptions()); + if (!result.ok()) { + SUCCEED("PipeWire unavailable"); + return; + } + + auto status = result.value->LoadConfig("/tmp/warppipe_nonexistent_config_12345.json"); + REQUIRE_FALSE(status.ok()); + REQUIRE(status.code == warppipe::StatusCode::kNotFound); +} + +TEST_CASE("save config with empty path returns error") { + auto result = warppipe::Client::Create(DefaultOptions()); + if (!result.ok()) { + SUCCEED("PipeWire unavailable"); + return; + } + + auto status = result.value->SaveConfig(""); + REQUIRE_FALSE(status.ok()); + REQUIRE(status.code == warppipe::StatusCode::kInvalidArgument); +} + +TEST_CASE("load config missing version returns error") { + auto result = warppipe::Client::Create(DefaultOptions()); + if (!result.ok()) { + SUCCEED("PipeWire unavailable"); + return; + } + + const char* path = "/tmp/warppipe_test_noversion.json"; + { + std::ofstream f(path); + f << R"({"route_rules": []})"; + } + + auto status = result.value->LoadConfig(path); + REQUIRE_FALSE(status.ok()); + REQUIRE(status.code == warppipe::StatusCode::kInvalidArgument); + + std::remove(path); +} + +TEST_CASE("metadata defaults are initially empty") { + auto result = warppipe::Client::Create(DefaultOptions()); + if (!result.ok()) { + SUCCEED("PipeWire unavailable"); + return; + } + + auto defaults = result.value->GetDefaults(); + REQUIRE(defaults.ok()); +} + +TEST_CASE("set default sink without metadata returns unavailable") { + auto result = warppipe::Client::Create(DefaultOptions()); + if (!result.ok()) { + SUCCEED("PipeWire unavailable"); + return; + } + + auto status = result.value->SetDefaultSink(""); + REQUIRE_FALSE(status.ok()); +} + +TEST_CASE("NodeInfo captures application properties") { + auto result = warppipe::Client::Create(DefaultOptions()); + if (!result.ok()) { + SUCCEED("PipeWire unavailable"); + return; + } + + warppipe::NodeInfo node; + node.id = warppipe::NodeId{800001}; + node.name = "test-node-props"; + node.media_class = "Audio/Sink"; + node.application_name = "my-app"; + node.process_binary = "my-binary"; + node.media_role = "Music"; + REQUIRE(result.value->Test_InsertNode(node).ok()); + + auto nodes = result.value->ListNodes(); + REQUIRE(nodes.ok()); + for (const auto& n : nodes.value) { + if (n.id.value == 800001) { + REQUIRE(n.application_name == "my-app"); + REQUIRE(n.process_binary == "my-binary"); + REQUIRE(n.media_role == "Music"); + return; + } + } + FAIL("inserted node not found"); +}