ServerBase now derives from and use std::enable_shared_from_this

This commit is contained in:
eidheim 2017-07-03 16:58:47 +02:00
commit 30f4f94a03
3 changed files with 118 additions and 110 deletions

View file

@ -43,21 +43,25 @@ namespace SimpleWeb {
class Server;
template <class socket_type>
class ServerBase {
class ServerBase : public std::enable_shared_from_this<ServerBase<socket_type>> {
ServerBase(const ServerBase &) = delete;
ServerBase &operator=(const ServerBase &) = delete;
protected:
class Session;
public:
virtual ~ServerBase() {}
class Response : public std::ostream {
friend class ServerBase<socket_type>;
friend class Server<socket_type>;
asio::streambuf streambuf;
std::shared_ptr<socket_type> socket;
Response(std::shared_ptr<socket_type> &socket) : std::ostream(&streambuf), socket(socket) {}
Response(const std::shared_ptr<socket_type> &socket) : std::ostream(&streambuf), socket(socket) {}
template <class size_type>
void write_header(const CaseInsensitiveMultimap &header, size_type size) {
@ -156,6 +160,7 @@ namespace SimpleWeb {
class Request {
friend class ServerBase<socket_type>;
friend class Server<socket_type>;
friend class Session;
public:
std::string method, path, query_string, http_version;
@ -187,6 +192,18 @@ namespace SimpleWeb {
asio::streambuf streambuf;
};
protected:
class Session {
public:
Session(const std::shared_ptr<ServerBase<socket_type>> &self, const std::shared_ptr<socket_type> &socket)
: self(self), socket(socket), request(new Request(*this->socket)) {}
std::shared_ptr<ServerBase<socket_type>> self;
std::shared_ptr<socket_type> socket;
std::shared_ptr<Request> request;
};
public:
class Config {
friend class ServerBase<socket_type>;
@ -262,7 +279,7 @@ namespace SimpleWeb {
threads.clear();
for(size_t c = 1; c < config.thread_pool_size; c++) {
threads.emplace_back([this]() {
io_service->run();
this->io_service->run();
});
}
@ -271,9 +288,8 @@ namespace SimpleWeb {
io_service->run();
//Wait for the rest of the threads, if any, to finish as well
for(auto &t : threads) {
for(auto &t : threads)
t.join();
}
}
}
@ -285,7 +301,7 @@ namespace SimpleWeb {
///Use this function if you need to recursively send parts of a longer message
void send(std::shared_ptr<Response> &response, const std::function<void(const error_code &)> &callback = nullptr) const {
asio::async_write(*response->socket, response->streambuf, [this, response, callback](const error_code &ec, size_t /*bytes_transferred*/) mutable {
asio::async_write(*response->socket, response->streambuf, [response, callback](const error_code &ec, size_t /*bytes_transferred*/) mutable {
if(callback)
callback(ec);
});
@ -304,31 +320,27 @@ namespace SimpleWeb {
virtual void accept() = 0;
std::shared_ptr<asio::deadline_timer> get_timeout_timer(std::shared_ptr<socket_type> &socket, long seconds) {
std::shared_ptr<asio::deadline_timer> get_timeout_timer(std::shared_ptr<Session> &session, long seconds) {
if(seconds == 0)
return nullptr;
auto timer = std::make_shared<asio::deadline_timer>(*io_service);
timer->expires_from_now(boost::posix_time::seconds(seconds));
timer->async_wait([socket](const error_code &ec) mutable {
timer->async_wait([session](const error_code &ec) mutable {
if(!ec) {
error_code ec;
socket->lowest_layer().shutdown(asio::ip::tcp::socket::shutdown_both, ec);
socket->lowest_layer().close();
session->socket->lowest_layer().shutdown(asio::ip::tcp::socket::shutdown_both, ec);
session->socket->lowest_layer().close();
}
});
return timer;
}
void read_request_and_content(std::shared_ptr<socket_type> &socket) {
//Create new streambuf (Request::streambuf) for async_read_until()
//shared_ptr is used to pass temporary objects to the asynchronous functions
std::shared_ptr<Request> request(new Request(*socket));
void read_request_and_content(std::shared_ptr<Session> &session) {
//Set timeout on the following asio::async-read or write function
auto timer = this->get_timeout_timer(socket, config.timeout_request);
auto timer = get_timeout_timer(session, config.timeout_request);
asio::async_read_until(*socket, request->streambuf, "\r\n\r\n", [this, socket, request, timer](const error_code &ec, size_t bytes_transferred) mutable {
asio::async_read_until(*session->socket, session->request->streambuf, "\r\n\r\n", [this, session, timer](const error_code &ec, size_t bytes_transferred) mutable {
if(timer)
timer->cancel();
if(!ec) {
@ -336,52 +348,52 @@ namespace SimpleWeb {
//"After a successful async_read_until operation, the streambuf may contain additional data beyond the delimiter"
//The chosen solution is to extract lines from the stream directly when parsing the header. What is left of the
//streambuf (maybe some bytes of the content) is appended to in the async_read-function below (for retrieving content).
size_t num_additional_bytes = request->streambuf.size() - bytes_transferred;
size_t num_additional_bytes = session->request->streambuf.size() - bytes_transferred;
if(!this->parse_request(request))
if(!this->parse_request(session))
return;
//If content, read that as well
auto it = request->header.find("Content-Length");
if(it != request->header.end()) {
auto it = session->request->header.find("Content-Length");
if(it != session->request->header.end()) {
unsigned long long content_length;
try {
content_length = stoull(it->second);
}
catch(const std::exception &e) {
if(on_error)
on_error(request, make_error_code::make_error_code(errc::protocol_error));
if(this->on_error)
this->on_error(session->request, make_error_code::make_error_code(errc::protocol_error));
return;
}
if(content_length > num_additional_bytes) {
//Set timeout on the following asio::async-read or write function
auto timer = this->get_timeout_timer(socket, config.timeout_content);
asio::async_read(*socket, request->streambuf, asio::transfer_exactly(content_length - num_additional_bytes), [this, socket, request, timer](const error_code &ec, size_t /*bytes_transferred*/) mutable {
auto timer = this->get_timeout_timer(session, config.timeout_content);
asio::async_read(*session->socket, session->request->streambuf, asio::transfer_exactly(content_length - num_additional_bytes), [this, session, timer](const error_code &ec, size_t /*bytes_transferred*/) mutable {
if(timer)
timer->cancel();
if(!ec)
this->find_resource(socket, request);
else if(on_error)
on_error(request, ec);
this->find_resource(session);
else if(this->on_error)
this->on_error(session->request, ec);
});
}
else
this->find_resource(socket, request);
this->find_resource(session);
}
else
this->find_resource(socket, request);
this->find_resource(session);
}
else if(on_error)
on_error(request, ec);
else if(this->on_error)
this->on_error(session->request, ec);
});
}
bool parse_request(std::shared_ptr<Request> &request) const {
bool parse_request(std::shared_ptr<Session> &session) const {
std::string line;
getline(request->content, line);
getline(session->request->content, line);
size_t method_end;
if((method_end = line.find(' ')) != std::string::npos) {
request->method = line.substr(0, method_end);
session->request->method = line.substr(0, method_end);
size_t query_start = std::string::npos;
size_t path_and_query_string_end = std::string::npos;
@ -395,22 +407,22 @@ namespace SimpleWeb {
}
if(path_and_query_string_end != std::string::npos) {
if(query_start != std::string::npos) {
request->path = line.substr(method_end + 1, query_start - method_end - 2);
request->query_string = line.substr(query_start, path_and_query_string_end - query_start);
session->request->path = line.substr(method_end + 1, query_start - method_end - 2);
session->request->query_string = line.substr(query_start, path_and_query_string_end - query_start);
}
else
request->path = line.substr(method_end + 1, path_and_query_string_end - method_end - 1);
session->request->path = line.substr(method_end + 1, path_and_query_string_end - method_end - 1);
size_t protocol_end;
if((protocol_end = line.find('/', path_and_query_string_end + 1)) != std::string::npos) {
if(line.compare(path_and_query_string_end + 1, protocol_end - path_and_query_string_end - 1, "HTTP") != 0)
return false;
request->http_version = line.substr(protocol_end + 1, line.size() - protocol_end - 2);
session->request->http_version = line.substr(protocol_end + 1, line.size() - protocol_end - 2);
}
else
return false;
getline(request->content, line);
getline(session->request->content, line);
size_t param_end;
while((param_end = line.find(':')) != std::string::npos) {
size_t value_start = param_end + 1;
@ -418,10 +430,10 @@ namespace SimpleWeb {
if(line[value_start] == ' ')
value_start++;
if(value_start < line.size())
request->header.emplace(line.substr(0, param_end), line.substr(value_start, line.size() - value_start - 1));
session->request->header.emplace(line.substr(0, param_end), line.substr(value_start, line.size() - value_start - 1));
}
getline(request->content, line);
getline(session->request->content, line);
}
}
else
@ -432,42 +444,42 @@ namespace SimpleWeb {
return true;
}
void
find_resource(std::shared_ptr<socket_type> &socket, std::shared_ptr<Request> &request) {
void find_resource(std::shared_ptr<Session> &session) {
//Upgrade connection
if(on_upgrade) {
auto it = request->header.find("Upgrade");
if(it != request->header.end()) {
on_upgrade(socket, request);
auto it = session->request->header.find("Upgrade");
if(it != session->request->header.end()) {
on_upgrade(session->socket, session->request);
return;
}
}
//Find path- and method-match, and call write_response
for(auto &regex_method : resource) {
auto it = regex_method.second.find(request->method);
auto it = regex_method.second.find(session->request->method);
if(it != regex_method.second.end()) {
regex::smatch sm_res;
if(regex::regex_match(request->path, sm_res, regex_method.first)) {
request->path_match = std::move(sm_res);
write_response(socket, request, it->second);
if(regex::regex_match(session->request->path, sm_res, regex_method.first)) {
session->request->path_match = std::move(sm_res);
write_response(session, it->second);
return;
}
}
}
auto it = default_resource.find(request->method);
if(it != default_resource.end()) {
write_response(socket, request, it->second);
}
auto it = default_resource.find(session->request->method);
if(it != default_resource.end())
write_response(session, it->second);
}
void write_response(std::shared_ptr<socket_type> &socket, std::shared_ptr<Request> &request,
void write_response(std::shared_ptr<Session> &session,
std::function<void(std::shared_ptr<typename ServerBase<socket_type>::Response> &, std::shared_ptr<typename ServerBase<socket_type>::Request> &)> &resource_function) {
//Set timeout on the following asio::async-read or write function
auto timer = this->get_timeout_timer(socket, config.timeout_content);
auto timer = get_timeout_timer(session, config.timeout_content);
auto response = std::shared_ptr<Response>(new Response(socket), [this, request, timer](Response *response_ptr) mutable {
auto self = session->self;
auto request = session->request;
auto response = std::shared_ptr<Response>(new Response(session->socket), [self, request, timer](Response *response_ptr) mutable {
auto response = std::shared_ptr<Response>(response_ptr);
this->send(response, [this, response, request, timer](const error_code &ec) mutable {
self->send(response, [self, response, request, timer](const error_code &ec) mutable {
if(timer)
timer->cancel();
if(!ec) {
@ -476,28 +488,31 @@ namespace SimpleWeb {
auto range = request->header.equal_range("Connection");
for(auto it = range.first; it != range.second; it++) {
if(case_insensitive_equal(it->second, "close")) {
if(case_insensitive_equal(it->second, "close"))
return;
}
else if(case_insensitive_equal(it->second, "keep-alive")) {
this->read_request_and_content(response->socket);
auto session = std::make_shared<Session>(self, response->socket);
self->read_request_and_content(session);
return;
}
}
if(request->http_version >= "1.1")
this->read_request_and_content(response->socket);
if(request->http_version >= "1.1") {
auto session = std::make_shared<Session>(self, response->socket);
self->read_request_and_content(session);
return;
}
}
else if(on_error)
on_error(request, ec);
else if(self->on_error)
self->on_error(request, ec);
});
});
try {
resource_function(response, request);
resource_function(response, session->request);
}
catch(const std::exception &e) {
if(on_error)
on_error(request, make_error_code::make_error_code(errc::operation_canceled));
on_error(session->request, make_error_code::make_error_code(errc::operation_canceled));
return;
}
}
@ -524,23 +539,21 @@ namespace SimpleWeb {
void accept() override {
//Create new socket for this connection
//Shared_ptr is used to pass temporary objects to the asynchronous functions
auto socket = std::make_shared<HTTP>(*io_service);
auto session = std::make_shared<Session>(this->shared_from_this(), std::make_shared<HTTP>(*io_service));
acceptor->async_accept(*socket, [this, socket](const error_code &ec) mutable {
acceptor->async_accept(*session->socket, [this, session](const error_code &ec) mutable {
//Immediately start accepting a new connection (if io_service hasn't been stopped)
if(ec != asio::error::operation_aborted)
accept();
this->accept();
if(!ec) {
asio::ip::tcp::no_delay option(true);
socket->set_option(option);
session->socket->set_option(option);
this->read_request_and_content(socket);
}
else if(on_error) {
std::shared_ptr<Request> request(new Request(*socket));
on_error(request, ec);
this->read_request_and_content(session);
}
else if(this->on_error)
this->on_error(session->request, ec);
});
}
};