diff --git a/server_http.hpp b/server_http.hpp index ac62d35..57d4a00 100644 --- a/server_http.hpp +++ b/server_http.hpp @@ -43,21 +43,25 @@ namespace SimpleWeb { class Server; template - class ServerBase { + class ServerBase : public std::enable_shared_from_this> { ServerBase(const ServerBase &) = delete; ServerBase &operator=(const ServerBase &) = delete; + protected: + class Session; + public: virtual ~ServerBase() {} class Response : public std::ostream { friend class ServerBase; + friend class Server; asio::streambuf streambuf; std::shared_ptr socket; - Response(std::shared_ptr &socket) : std::ostream(&streambuf), socket(socket) {} + Response(const std::shared_ptr &socket) : std::ostream(&streambuf), socket(socket) {} template void write_header(const CaseInsensitiveMultimap &header, size_type size) { @@ -156,6 +160,7 @@ namespace SimpleWeb { class Request { friend class ServerBase; friend class Server; + friend class Session; public: std::string method, path, query_string, http_version; @@ -187,6 +192,18 @@ namespace SimpleWeb { asio::streambuf streambuf; }; + protected: + class Session { + public: + Session(const std::shared_ptr> &self, const std::shared_ptr &socket) + : self(self), socket(socket), request(new Request(*this->socket)) {} + + std::shared_ptr> self; + std::shared_ptr socket; + std::shared_ptr request; + }; + + public: class Config { friend class ServerBase; @@ -262,7 +279,7 @@ namespace SimpleWeb { threads.clear(); for(size_t c = 1; c < config.thread_pool_size; c++) { threads.emplace_back([this]() { - io_service->run(); + this->io_service->run(); }); } @@ -271,9 +288,8 @@ namespace SimpleWeb { io_service->run(); //Wait for the rest of the threads, if any, to finish as well - for(auto &t : threads) { + for(auto &t : threads) t.join(); - } } } @@ -285,7 +301,7 @@ namespace SimpleWeb { ///Use this function if you need to recursively send parts of a longer message void send(std::shared_ptr &response, const std::function &callback = nullptr) const { - asio::async_write(*response->socket, response->streambuf, [this, response, callback](const error_code &ec, size_t /*bytes_transferred*/) mutable { + asio::async_write(*response->socket, response->streambuf, [response, callback](const error_code &ec, size_t /*bytes_transferred*/) mutable { if(callback) callback(ec); }); @@ -304,31 +320,27 @@ namespace SimpleWeb { virtual void accept() = 0; - std::shared_ptr get_timeout_timer(std::shared_ptr &socket, long seconds) { + std::shared_ptr get_timeout_timer(std::shared_ptr &session, long seconds) { if(seconds == 0) return nullptr; auto timer = std::make_shared(*io_service); timer->expires_from_now(boost::posix_time::seconds(seconds)); - timer->async_wait([socket](const error_code &ec) mutable { + timer->async_wait([session](const error_code &ec) mutable { if(!ec) { error_code ec; - socket->lowest_layer().shutdown(asio::ip::tcp::socket::shutdown_both, ec); - socket->lowest_layer().close(); + session->socket->lowest_layer().shutdown(asio::ip::tcp::socket::shutdown_both, ec); + session->socket->lowest_layer().close(); } }); return timer; } - void read_request_and_content(std::shared_ptr &socket) { - //Create new streambuf (Request::streambuf) for async_read_until() - //shared_ptr is used to pass temporary objects to the asynchronous functions - std::shared_ptr request(new Request(*socket)); - + void read_request_and_content(std::shared_ptr &session) { //Set timeout on the following asio::async-read or write function - auto timer = this->get_timeout_timer(socket, config.timeout_request); + auto timer = get_timeout_timer(session, config.timeout_request); - asio::async_read_until(*socket, request->streambuf, "\r\n\r\n", [this, socket, request, timer](const error_code &ec, size_t bytes_transferred) mutable { + asio::async_read_until(*session->socket, session->request->streambuf, "\r\n\r\n", [this, session, timer](const error_code &ec, size_t bytes_transferred) mutable { if(timer) timer->cancel(); if(!ec) { @@ -336,52 +348,52 @@ namespace SimpleWeb { //"After a successful async_read_until operation, the streambuf may contain additional data beyond the delimiter" //The chosen solution is to extract lines from the stream directly when parsing the header. What is left of the //streambuf (maybe some bytes of the content) is appended to in the async_read-function below (for retrieving content). - size_t num_additional_bytes = request->streambuf.size() - bytes_transferred; + size_t num_additional_bytes = session->request->streambuf.size() - bytes_transferred; - if(!this->parse_request(request)) + if(!this->parse_request(session)) return; //If content, read that as well - auto it = request->header.find("Content-Length"); - if(it != request->header.end()) { + auto it = session->request->header.find("Content-Length"); + if(it != session->request->header.end()) { unsigned long long content_length; try { content_length = stoull(it->second); } catch(const std::exception &e) { - if(on_error) - on_error(request, make_error_code::make_error_code(errc::protocol_error)); + if(this->on_error) + this->on_error(session->request, make_error_code::make_error_code(errc::protocol_error)); return; } if(content_length > num_additional_bytes) { //Set timeout on the following asio::async-read or write function - auto timer = this->get_timeout_timer(socket, config.timeout_content); - asio::async_read(*socket, request->streambuf, asio::transfer_exactly(content_length - num_additional_bytes), [this, socket, request, timer](const error_code &ec, size_t /*bytes_transferred*/) mutable { + auto timer = this->get_timeout_timer(session, config.timeout_content); + asio::async_read(*session->socket, session->request->streambuf, asio::transfer_exactly(content_length - num_additional_bytes), [this, session, timer](const error_code &ec, size_t /*bytes_transferred*/) mutable { if(timer) timer->cancel(); if(!ec) - this->find_resource(socket, request); - else if(on_error) - on_error(request, ec); + this->find_resource(session); + else if(this->on_error) + this->on_error(session->request, ec); }); } else - this->find_resource(socket, request); + this->find_resource(session); } else - this->find_resource(socket, request); + this->find_resource(session); } - else if(on_error) - on_error(request, ec); + else if(this->on_error) + this->on_error(session->request, ec); }); } - bool parse_request(std::shared_ptr &request) const { + bool parse_request(std::shared_ptr &session) const { std::string line; - getline(request->content, line); + getline(session->request->content, line); size_t method_end; if((method_end = line.find(' ')) != std::string::npos) { - request->method = line.substr(0, method_end); + session->request->method = line.substr(0, method_end); size_t query_start = std::string::npos; size_t path_and_query_string_end = std::string::npos; @@ -395,22 +407,22 @@ namespace SimpleWeb { } if(path_and_query_string_end != std::string::npos) { if(query_start != std::string::npos) { - request->path = line.substr(method_end + 1, query_start - method_end - 2); - request->query_string = line.substr(query_start, path_and_query_string_end - query_start); + session->request->path = line.substr(method_end + 1, query_start - method_end - 2); + session->request->query_string = line.substr(query_start, path_and_query_string_end - query_start); } else - request->path = line.substr(method_end + 1, path_and_query_string_end - method_end - 1); + session->request->path = line.substr(method_end + 1, path_and_query_string_end - method_end - 1); size_t protocol_end; if((protocol_end = line.find('/', path_and_query_string_end + 1)) != std::string::npos) { if(line.compare(path_and_query_string_end + 1, protocol_end - path_and_query_string_end - 1, "HTTP") != 0) return false; - request->http_version = line.substr(protocol_end + 1, line.size() - protocol_end - 2); + session->request->http_version = line.substr(protocol_end + 1, line.size() - protocol_end - 2); } else return false; - getline(request->content, line); + getline(session->request->content, line); size_t param_end; while((param_end = line.find(':')) != std::string::npos) { size_t value_start = param_end + 1; @@ -418,10 +430,10 @@ namespace SimpleWeb { if(line[value_start] == ' ') value_start++; if(value_start < line.size()) - request->header.emplace(line.substr(0, param_end), line.substr(value_start, line.size() - value_start - 1)); + session->request->header.emplace(line.substr(0, param_end), line.substr(value_start, line.size() - value_start - 1)); } - getline(request->content, line); + getline(session->request->content, line); } } else @@ -432,42 +444,42 @@ namespace SimpleWeb { return true; } - void - find_resource(std::shared_ptr &socket, std::shared_ptr &request) { + void find_resource(std::shared_ptr &session) { //Upgrade connection if(on_upgrade) { - auto it = request->header.find("Upgrade"); - if(it != request->header.end()) { - on_upgrade(socket, request); + auto it = session->request->header.find("Upgrade"); + if(it != session->request->header.end()) { + on_upgrade(session->socket, session->request); return; } } //Find path- and method-match, and call write_response for(auto ®ex_method : resource) { - auto it = regex_method.second.find(request->method); + auto it = regex_method.second.find(session->request->method); if(it != regex_method.second.end()) { regex::smatch sm_res; - if(regex::regex_match(request->path, sm_res, regex_method.first)) { - request->path_match = std::move(sm_res); - write_response(socket, request, it->second); + if(regex::regex_match(session->request->path, sm_res, regex_method.first)) { + session->request->path_match = std::move(sm_res); + write_response(session, it->second); return; } } } - auto it = default_resource.find(request->method); - if(it != default_resource.end()) { - write_response(socket, request, it->second); - } + auto it = default_resource.find(session->request->method); + if(it != default_resource.end()) + write_response(session, it->second); } - void write_response(std::shared_ptr &socket, std::shared_ptr &request, + void write_response(std::shared_ptr &session, std::function::Response> &, std::shared_ptr::Request> &)> &resource_function) { //Set timeout on the following asio::async-read or write function - auto timer = this->get_timeout_timer(socket, config.timeout_content); + auto timer = get_timeout_timer(session, config.timeout_content); - auto response = std::shared_ptr(new Response(socket), [this, request, timer](Response *response_ptr) mutable { + auto self = session->self; + auto request = session->request; + auto response = std::shared_ptr(new Response(session->socket), [self, request, timer](Response *response_ptr) mutable { auto response = std::shared_ptr(response_ptr); - this->send(response, [this, response, request, timer](const error_code &ec) mutable { + self->send(response, [self, response, request, timer](const error_code &ec) mutable { if(timer) timer->cancel(); if(!ec) { @@ -476,28 +488,31 @@ namespace SimpleWeb { auto range = request->header.equal_range("Connection"); for(auto it = range.first; it != range.second; it++) { - if(case_insensitive_equal(it->second, "close")) { + if(case_insensitive_equal(it->second, "close")) return; - } else if(case_insensitive_equal(it->second, "keep-alive")) { - this->read_request_and_content(response->socket); + auto session = std::make_shared(self, response->socket); + self->read_request_and_content(session); return; } } - if(request->http_version >= "1.1") - this->read_request_and_content(response->socket); + if(request->http_version >= "1.1") { + auto session = std::make_shared(self, response->socket); + self->read_request_and_content(session); + return; + } } - else if(on_error) - on_error(request, ec); + else if(self->on_error) + self->on_error(request, ec); }); }); try { - resource_function(response, request); + resource_function(response, session->request); } catch(const std::exception &e) { if(on_error) - on_error(request, make_error_code::make_error_code(errc::operation_canceled)); + on_error(session->request, make_error_code::make_error_code(errc::operation_canceled)); return; } } @@ -524,23 +539,21 @@ namespace SimpleWeb { void accept() override { //Create new socket for this connection //Shared_ptr is used to pass temporary objects to the asynchronous functions - auto socket = std::make_shared(*io_service); + auto session = std::make_shared(this->shared_from_this(), std::make_shared(*io_service)); - acceptor->async_accept(*socket, [this, socket](const error_code &ec) mutable { + acceptor->async_accept(*session->socket, [this, session](const error_code &ec) mutable { //Immediately start accepting a new connection (if io_service hasn't been stopped) if(ec != asio::error::operation_aborted) - accept(); + this->accept(); if(!ec) { asio::ip::tcp::no_delay option(true); - socket->set_option(option); + session->socket->set_option(option); - this->read_request_and_content(socket); - } - else if(on_error) { - std::shared_ptr request(new Request(*socket)); - on_error(request, ec); + this->read_request_and_content(session); } + else if(this->on_error) + this->on_error(session->request, ec); }); } }; diff --git a/server_https.hpp b/server_https.hpp index 7a649b5..73d9980 100644 --- a/server_https.hpp +++ b/server_https.hpp @@ -57,35 +57,31 @@ namespace SimpleWeb { void accept() override { //Create new socket for this connection //Shared_ptr is used to pass temporary objects to the asynchronous functions - auto socket = std::make_shared(*io_service, context); + auto session = std::make_shared(this->shared_from_this(), std::make_shared(*io_service, context)); - acceptor->async_accept((*socket).lowest_layer(), [this, socket](const error_code &ec) mutable { + acceptor->async_accept(session->socket->lowest_layer(), [this, session](const error_code &ec) mutable { //Immediately start accepting a new connection (if io_service hasn't been stopped) if(ec != asio::error::operation_aborted) - accept(); + this->accept(); if(!ec) { asio::ip::tcp::no_delay option(true); - socket->lowest_layer().set_option(option); + session->socket->lowest_layer().set_option(option); //Set timeout on the following asio::ssl::stream::async_handshake - auto timer = get_timeout_timer(socket, config.timeout_request); - socket->async_handshake(asio::ssl::stream_base::server, [this, socket, timer](const error_code &ec) mutable { + auto timer = this->get_timeout_timer(session, config.timeout_request); + session->socket->async_handshake(asio::ssl::stream_base::server, [this, session, timer](const error_code &ec) mutable { if(timer) timer->cancel(); if(!ec) - read_request_and_content(socket); - else if(on_error) { - std::shared_ptr request(new Request(*socket)); - on_error(request, ec); - } + this->read_request_and_content(session); + else if(this->on_error) + this->on_error(session->request, ec); }); } - else if(on_error) { - std::shared_ptr request(new Request(*socket)); - on_error(request, ec); - } + else if(this->on_error) + this->on_error(session->request, ec); }); } }; diff --git a/tests/parse_test.cpp b/tests/parse_test.cpp index 3cf10d6..8be35fd 100644 --- a/tests/parse_test.cpp +++ b/tests/parse_test.cpp @@ -13,10 +13,9 @@ public: void accept() override {} void parse_request_test() { - HTTP socket(*io_service); - std::shared_ptr request(new Request(socket)); + auto session = std::make_shared(this->shared_from_this(), std::make_shared(*io_service)); - std::ostream stream(&request->content.streambuf); + std::ostream stream(&session->request->content.streambuf); stream << "GET /test/ HTTP/1.1\r\n"; stream << "TestHeader: test\r\n"; stream << "TestHeader2:test2\r\n"; @@ -24,28 +23,28 @@ public: stream << "TestHeader3:test3b\r\n"; stream << "\r\n"; - assert(parse_request(request)); + assert(parse_request(session)); - assert(request->method == "GET"); - assert(request->path == "/test/"); - assert(request->http_version == "1.1"); + assert(session->request->method == "GET"); + assert(session->request->path == "/test/"); + assert(session->request->http_version == "1.1"); - assert(request->header.size() == 4); - auto header_it = request->header.find("TestHeader"); - assert(header_it != request->header.end() && header_it->second == "test"); - header_it = request->header.find("TestHeader2"); - assert(header_it != request->header.end() && header_it->second == "test2"); + assert(session->request->header.size() == 4); + auto header_it = session->request->header.find("TestHeader"); + assert(header_it != session->request->header.end() && header_it->second == "test"); + header_it = session->request->header.find("TestHeader2"); + assert(header_it != session->request->header.end() && header_it->second == "test2"); - header_it = request->header.find("testheader"); - assert(header_it != request->header.end() && header_it->second == "test"); - header_it = request->header.find("testheader2"); - assert(header_it != request->header.end() && header_it->second == "test2"); + header_it = session->request->header.find("testheader"); + assert(header_it != session->request->header.end() && header_it->second == "test"); + header_it = session->request->header.find("testheader2"); + assert(header_it != session->request->header.end() && header_it->second == "test2"); - auto range = request->header.equal_range("testheader3"); + auto range = session->request->header.equal_range("testheader3"); auto first = range.first; auto second = first; ++second; - assert(range.first != request->header.end() && range.second != request->header.end() && + assert(range.first != session->request->header.end() && range.second != session->request->header.end() && ((first->second == "test3a" && second->second == "test3b") || (first->second == "test3b" && second->second == "test3a"))); }