From f3f527467ff9f94af7854581d8099074ad485af1 Mon Sep 17 00:00:00 2001 From: eidheim Date: Fri, 7 Jul 2017 23:17:10 +0200 Subject: [PATCH] Now close connections when Server::stop is called --- client_http.hpp | 2 +- server_http.hpp | 92 +++++++++++++++++++++++++++++++------------- server_https.hpp | 8 ++-- tests/parse_test.cpp | 2 +- 4 files changed, 72 insertions(+), 32 deletions(-) diff --git a/client_http.hpp b/client_http.hpp index cf9da1a..bb9c6f7 100644 --- a/client_http.hpp +++ b/client_http.hpp @@ -121,7 +121,7 @@ namespace SimpleWeb { public: Connection(std::unique_ptr &&socket) : socket(std::move(socket)) {} - std::unique_ptr socket; + std::unique_ptr socket; // Socket must be unique_ptr since asio::ssl::stream is not movable std::mutex socket_close_mutex; bool in_use = false; bool attempt_reconnect = true; diff --git a/server_http.hpp b/server_http.hpp index fd43736..81cfd87 100644 --- a/server_http.hpp +++ b/server_http.hpp @@ -8,6 +8,7 @@ #include #include #include +#include #ifdef USE_STANDALONE_ASIO #include @@ -92,7 +93,7 @@ namespace SimpleWeb { return; auto self = this->shared_from_this(); session->set_timeout(session->server->config.timeout_content); - asio::async_write(*session->socket, streambuf, [self, callback](const error_code &ec, size_t /*bytes_transferred*/) { + asio::async_write(*session->connection->socket, streambuf, [self, callback](const error_code &ec, size_t /*bytes_transferred*/) { self->session->cancel_timeout(); auto lock = self->session->cancel_callbacks_mutex->shared_lock(); if(*self->session->cancel_callbacks) @@ -265,17 +266,31 @@ namespace SimpleWeb { }; protected: + class Connection { + public: + Connection(std::unique_ptr &&socket) : socket(std::move(socket)) {} + + std::unique_ptr socket; // Socket must be unique_ptr since asio::ssl::stream is not movable + std::mutex socket_close_mutex; + + void close() { + error_code ec; + std::unique_lock lock(socket_close_mutex); // the following operations seems to be needed to run sequentially + socket->lowest_layer().shutdown(asio::ip::tcp::socket::shutdown_both, ec); + socket->lowest_layer().close(ec); + } + }; + class Session { public: - Session(ServerBase *server, const std::shared_ptr &socket) + Session(ServerBase *server, const std::shared_ptr &connection) : server(server), cancel_callbacks(server->cancel_callbacks), cancel_callbacks_mutex(server->cancel_callbacks_mutex), - socket(socket), socket_close_mutex(new std::mutex()), request(new Request(*this->socket)) {} + connection(connection), request(new Request(*connection->socket)) {} ServerBase *server; std::shared_ptr cancel_callbacks; std::shared_ptr cancel_callbacks_mutex; - std::shared_ptr socket; - std::shared_ptr socket_close_mutex; + std::shared_ptr connection; std::shared_ptr request; std::unique_ptr timer; @@ -286,17 +301,12 @@ namespace SimpleWeb { return; } - timer = std::unique_ptr(new asio::deadline_timer(socket->get_io_service())); + timer = std::unique_ptr(new asio::deadline_timer(connection->socket->get_io_service())); timer->expires_from_now(boost::posix_time::seconds(seconds)); - auto socket = this->socket; - auto socket_close_mutex = this->socket_close_mutex; - timer->async_wait([socket, socket_close_mutex](const error_code &ec) { - if(!ec) { - error_code ec; - std::unique_lock lock(*socket_close_mutex); // the following operations seems to be needed to run sequentially - socket->lowest_layer().shutdown(asio::ip::tcp::socket::shutdown_both, ec); - socket->lowest_layer().close(ec); - } + auto connection = this->connection; + timer->async_wait([connection](const error_code &ec) { + if(!ec) + connection->close(); }); } @@ -351,7 +361,7 @@ namespace SimpleWeb { std::function::Request>, const error_code &)> on_error; - std::function, std::shared_ptr::Request>)> on_upgrade; + std::function &, std::shared_ptr::Request>)> on_upgrade; /// If you have your own asio::io_service, store its pointer here before running start(). std::shared_ptr io_service; @@ -399,12 +409,19 @@ namespace SimpleWeb { } } - /// Stop accepting new requests, and close current sessions. + /// Stop accepting new requests, and close current connections. void stop() { if(acceptor) { acceptor->close(); if(internal_io_service) io_service->stop(); + + std::unique_lock lock(*connections_mutex); + if(!internal_io_service) { + for(auto &connection : *connections) + connection->close(); + } + connections->clear(); } } @@ -422,16 +439,39 @@ namespace SimpleWeb { std::unique_ptr acceptor; std::vector threads; + std::shared_ptr> connections; + std::shared_ptr connections_mutex; + std::shared_ptr cancel_callbacks; std::shared_ptr cancel_callbacks_mutex; - ServerBase(unsigned short port) : config(port), cancel_callbacks(new bool(false)), cancel_callbacks_mutex(new SharedMutex()) {} + ServerBase(unsigned short port) : config(port), connections(new std::unordered_set()), connections_mutex(new std::mutex()), + cancel_callbacks(new bool(false)), cancel_callbacks_mutex(new SharedMutex()) {} virtual void accept() = 0; + std::shared_ptr create_connection(socket_type *socket) { + auto connections = this->connections; + auto connections_mutex = this->connections_mutex; + auto connection = std::shared_ptr(new Connection(std::unique_ptr(socket)), [connections, connections_mutex](Connection *connection) { + { + std::unique_lock lock(*connections_mutex); + auto it = connections->find(connection); + if(it != connections->end()) + connections->erase(it); + } + delete connection; + }); + { + std::unique_lock lock(*connections_mutex); + connections->emplace(connection.get()); + } + return connection; + } + void read_request_and_content(const std::shared_ptr &session) { session->set_timeout(config.timeout_request); - asio::async_read_until(*session->socket, session->request->streambuf, "\r\n\r\n", [this, session](const error_code &ec, size_t bytes_transferred) { + asio::async_read_until(*session->connection->socket, session->request->streambuf, "\r\n\r\n", [this, session](const error_code &ec, size_t bytes_transferred) { session->cancel_timeout(); auto lock = session->cancel_callbacks_mutex->shared_lock(); if(*session->cancel_callbacks) @@ -460,7 +500,7 @@ namespace SimpleWeb { } if(content_length > num_additional_bytes) { session->set_timeout(config.timeout_content); - asio::async_read(*session->socket, session->request->streambuf, asio::transfer_exactly(content_length - num_additional_bytes), [this, session](const error_code &ec, size_t /*bytes_transferred*/) { + asio::async_read(*session->connection->socket, session->request->streambuf, asio::transfer_exactly(content_length - num_additional_bytes), [this, session](const error_code &ec, size_t /*bytes_transferred*/) { session->cancel_timeout(); auto lock = session->cancel_callbacks_mutex->shared_lock(); if(*session->cancel_callbacks) @@ -487,7 +527,7 @@ namespace SimpleWeb { if(on_upgrade) { auto it = session->request->header.find("Upgrade"); if(it != session->request->header.end()) { - on_upgrade(session->socket, session->request); + on_upgrade(session->connection->socket, session->request); return; } } @@ -523,13 +563,13 @@ namespace SimpleWeb { if(case_insensitive_equal(it->second, "close")) return; else if(case_insensitive_equal(it->second, "keep-alive")) { - auto new_session = std::make_shared(response->session->server, response->session->socket); + auto new_session = std::make_shared(response->session->server, response->session->connection); this->read_request_and_content(new_session); return; } } if(response->session->request->http_version >= "1.1") { - auto new_session = std::make_shared(response->session->server, response->session->socket); + auto new_session = std::make_shared(response->session->server, response->session->connection); this->read_request_and_content(new_session); return; } @@ -565,9 +605,9 @@ namespace SimpleWeb { protected: void accept() override { - auto session = std::make_shared(this, std::make_shared(*io_service)); + auto session = std::make_shared(this, create_connection(new HTTP(*io_service))); - acceptor->async_accept(*session->socket, [this, session](const error_code &ec) { + acceptor->async_accept(*session->connection->socket, [this, session](const error_code &ec) { auto lock = session->cancel_callbacks_mutex->shared_lock(); if(*session->cancel_callbacks) return; @@ -579,7 +619,7 @@ namespace SimpleWeb { if(!ec) { asio::ip::tcp::no_delay option(true); error_code ec; - session->socket->set_option(option, ec); + session->connection->socket->set_option(option, ec); this->read_request_and_content(session); } diff --git a/server_https.hpp b/server_https.hpp index 44a209c..76390c6 100644 --- a/server_https.hpp +++ b/server_https.hpp @@ -51,9 +51,9 @@ namespace SimpleWeb { asio::ssl::context context; void accept() override { - auto session = std::make_shared(this, std::make_shared(*io_service, context)); + auto session = std::make_shared(this, create_connection(new HTTPS(*io_service, context))); - acceptor->async_accept(session->socket->lowest_layer(), [this, session](const error_code &ec) { + acceptor->async_accept(session->connection->socket->lowest_layer(), [this, session](const error_code &ec) { auto lock = session->cancel_callbacks_mutex->shared_lock(); if(*session->cancel_callbacks) return; @@ -64,10 +64,10 @@ namespace SimpleWeb { if(!ec) { asio::ip::tcp::no_delay option(true); error_code ec; - session->socket->lowest_layer().set_option(option, ec); + session->connection->socket->lowest_layer().set_option(option, ec); session->set_timeout(config.timeout_request); - session->socket->async_handshake(asio::ssl::stream_base::server, [this, session](const error_code &ec) { + session->connection->socket->async_handshake(asio::ssl::stream_base::server, [this, session](const error_code &ec) { session->cancel_timeout(); auto lock = session->cancel_callbacks_mutex->shared_lock(); if(*session->cancel_callbacks) diff --git a/tests/parse_test.cpp b/tests/parse_test.cpp index 35cfb57..965b229 100644 --- a/tests/parse_test.cpp +++ b/tests/parse_test.cpp @@ -13,7 +13,7 @@ public: void accept() override {} void parse_request_test() { - auto session = std::make_shared(this, std::make_shared(*io_service)); + auto session = std::make_shared(this, create_connection(new HTTP(*io_service))); std::ostream stream(&session->request->content.streambuf); stream << "GET /test/ HTTP/1.1\r\n";