From 69f945f686ff8c6f4ea0bce903501ea45cc2a4d3 Mon Sep 17 00:00:00 2001 From: eidheim Date: Thu, 7 Sep 2017 11:16:01 +0200 Subject: [PATCH] Fixes #155: added max streambuf config items to limit large requests/responses if needed --- client_http.hpp | 109 +++++++++++++++++++++++++++++-------------- client_https.hpp | 8 +++- server_http.hpp | 46 ++++++++++++------ server_https.hpp | 2 +- tests/parse_test.cpp | 8 ++-- 5 files changed, 116 insertions(+), 57 deletions(-) diff --git a/client_http.hpp b/client_http.hpp index 041da2a..7a4767b 100644 --- a/client_http.hpp +++ b/client_http.hpp @@ -66,17 +66,16 @@ namespace SimpleWeb { friend class ClientBase; friend class Client; + asio::streambuf streambuf; + + Response(size_t max_response_streambuf_size) noexcept : streambuf(max_response_streambuf_size), content(streambuf) {} + public: std::string http_version, status_code; Content content; CaseInsensitiveMultimap header; - - private: - asio::streambuf content_buffer; - - Response() noexcept : content(content_buffer) {} }; class Config { @@ -90,6 +89,9 @@ namespace SimpleWeb { long timeout = 0; /// Set connect timeout in seconds. Default value: 0 (Config::timeout is then used instead). long timeout_connect = 0; + /// Maximum size of response stream buffer. Defaults to architecture maximum. + /// Reaching this limit will result in a message_size error code. + size_t max_response_streambuf_size = static_cast(-1); /// Set proxy server (server:port) std::string proxy_server; }; @@ -138,11 +140,11 @@ namespace SimpleWeb { class Session { public: - Session(std::shared_ptr connection, std::unique_ptr request_buffer) noexcept - : connection(std::move(connection)), request_buffer(std::move(request_buffer)), response(new Response()) {} + Session(size_t max_response_streambuf_size, std::shared_ptr connection, std::unique_ptr request_streambuf) noexcept + : connection(std::move(connection)), request_streambuf(std::move(request_streambuf)), response(new Response(max_response_streambuf_size)) {} std::shared_ptr connection; - std::unique_ptr request_buffer; + std::unique_ptr request_streambuf; std::shared_ptr response; std::function &, const error_code &)> callback; }; @@ -219,7 +221,7 @@ namespace SimpleWeb { /// Do not use concurrently with the synchronous request functions. void request(const std::string &method, const std::string &path, string_view content, const CaseInsensitiveMultimap &header, std::function, const error_code &)> &&request_callback_) { - auto session = std::make_shared(get_connection(), create_request_header(method, path, header)); + auto session = std::make_shared(config.max_response_streambuf_size, get_connection(), create_request_header(method, path, header)); auto response = session->response; auto request_callback = std::make_shared, const error_code &)>>(std::move(request_callback_)); session->callback = [this, response, request_callback](const std::shared_ptr &connection, const error_code &ec) { @@ -248,7 +250,7 @@ namespace SimpleWeb { (*request_callback)(response, ec); }; - std::ostream write_stream(session->request_buffer.get()); + std::ostream write_stream(session->request_streambuf.get()); if(content.size() > 0) write_stream << "Content-Length: " << content.size() << "\r\n"; write_stream << "\r\n" @@ -278,7 +280,7 @@ namespace SimpleWeb { /// Asynchronous request where setting and/or running Client's io_service is required. void request(const std::string &method, const std::string &path, std::istream &content, const CaseInsensitiveMultimap &header, std::function, const error_code &)> &&request_callback_) { - auto session = std::make_shared(get_connection(), create_request_header(method, path, header)); + auto session = std::make_shared(config.max_response_streambuf_size, get_connection(), create_request_header(method, path, header)); auto response = session->response; auto request_callback = std::make_shared, const error_code &)>>(std::move(request_callback_)); session->callback = [this, response, request_callback](const std::shared_ptr &connection, const error_code &ec) { @@ -310,7 +312,7 @@ namespace SimpleWeb { content.seekg(0, std::ios::end); auto content_length = content.tellg(); content.seekg(0, std::ios::beg); - std::ostream write_stream(session->request_buffer.get()); + std::ostream write_stream(session->request_streambuf.get()); if(content_length > 0) write_stream << "Content-Length: " << content_length << "\r\n"; write_stream << "\r\n"; @@ -407,13 +409,13 @@ namespace SimpleWeb { if(!config.proxy_server.empty() && std::is_same::value) corrected_path = "http://" + host + ':' + std::to_string(port) + corrected_path; - std::unique_ptr request_buffer(new asio::streambuf()); - std::ostream write_stream(request_buffer.get()); + std::unique_ptr streambuf(new asio::streambuf()); + std::ostream write_stream(streambuf.get()); write_stream << method << " " << corrected_path << " HTTP/1.1\r\n"; write_stream << "Host: " << host << "\r\n"; for(auto &h : header) write_stream << h.first << ": " << h.second << "\r\n"; - return request_buffer; + return streambuf; } std::pair parse_host_port(const std::string &host_port, unsigned short default_port) const noexcept { @@ -432,7 +434,7 @@ namespace SimpleWeb { void write(const std::shared_ptr &session) { session->connection->set_timeout(); - asio::async_write(*session->connection->socket, session->request_buffer->data(), [this, session](const error_code &ec, size_t /*bytes_transferred*/) { + asio::async_write(*session->connection->socket, session->request_streambuf->data(), [this, session](const error_code &ec, size_t /*bytes_transferred*/) { session->connection->cancel_timeout(); auto lock = session->connection->handler_runner->continue_lock(); if(!lock) @@ -446,15 +448,18 @@ namespace SimpleWeb { void read(const std::shared_ptr &session) { session->connection->set_timeout(); - asio::async_read_until(*session->connection->socket, session->response->content_buffer, "\r\n\r\n", [this, session](const error_code &ec, size_t bytes_transferred) { + asio::async_read_until(*session->connection->socket, session->response->streambuf, "\r\n\r\n", [this, session](const error_code &ec, size_t bytes_transferred) { session->connection->cancel_timeout(); auto lock = session->connection->handler_runner->continue_lock(); if(!lock) return; + if((!ec || ec == asio::error::not_found) && session->response->streambuf.size() == session->response->streambuf.max_size()) { + session->callback(session->connection, make_error_code::make_error_code(errc::message_size)); + return; + } if(!ec) { session->connection->attempt_reconnect = true; - - size_t num_additional_bytes = session->response->content_buffer.size() - bytes_transferred; + size_t num_additional_bytes = session->response->streambuf.size() - bytes_transferred; if(!ResponseMessage::parse(session->response->content, session->response->http_version, session->response->status_code, session->response->header)) { session->callback(session->connection, make_error_code::make_error_code(errc::protocol_error)); @@ -466,13 +471,18 @@ namespace SimpleWeb { auto content_length = stoull(header_it->second); if(content_length > num_additional_bytes) { session->connection->set_timeout(); - asio::async_read(*session->connection->socket, session->response->content_buffer, asio::transfer_exactly(content_length - num_additional_bytes), [this, session](const error_code &ec, size_t /*bytes_transferred*/) { + asio::async_read(*session->connection->socket, session->response->streambuf, asio::transfer_exactly(content_length - num_additional_bytes), [this, session](const error_code &ec, size_t /*bytes_transferred*/) { session->connection->cancel_timeout(); auto lock = session->connection->handler_runner->continue_lock(); if(!lock) return; - if(!ec) + if(!ec) { + if(session->response->streambuf.size() == session->response->streambuf.max_size()) { + session->callback(session->connection, make_error_code::make_error_code(errc::message_size)); + return; + } session->callback(session->connection, ec); + } else session->callback(session->connection, ec); }); @@ -486,13 +496,18 @@ namespace SimpleWeb { } else if(session->response->http_version < "1.1" || ((header_it = session->response->header.find("Session")) != session->response->header.end() && header_it->second == "close")) { session->connection->set_timeout(); - asio::async_read(*session->connection->socket, session->response->content_buffer, [this, session](const error_code &ec, size_t /*bytes_transferred*/) { + asio::async_read(*session->connection->socket, session->response->streambuf, [this, session](const error_code &ec, size_t /*bytes_transferred*/) { session->connection->cancel_timeout(); auto lock = session->connection->handler_runner->continue_lock(); if(!lock) return; - if(!ec) + if(!ec) { + if(session->response->streambuf.size() == session->response->streambuf.max_size()) { + session->callback(session->connection, make_error_code::make_error_code(errc::message_size)); + return; + } session->callback(session->connection, ec); + } else session->callback(session->connection, ec == asio::error::eof ? error_code() : ec); }); @@ -525,15 +540,31 @@ namespace SimpleWeb { } void read_chunked(const std::shared_ptr &session, const std::shared_ptr &tmp_streambuf) { + if(tmp_streambuf->size() >= config.max_response_streambuf_size) { + session->callback(session->connection, make_error_code::make_error_code(errc::message_size)); + return; + } + // chunked_streambuf is needed as new read buffer with its size adjusted depending on the size of tmp_streambuf + auto chunked_streambuf = std::make_shared(config.max_response_streambuf_size - tmp_streambuf->size()); + // Move excess read data from session->response->streambuf to chunked_streambuf + if(session->response->streambuf.size() > 0) { + std::ostream chunked_stream(chunked_streambuf.get()); + chunked_stream << &session->response->streambuf; + } session->connection->set_timeout(); - asio::async_read_until(*session->connection->socket, session->response->content_buffer, "\r\n", [this, session, tmp_streambuf](const error_code &ec, size_t bytes_transferred) { + asio::async_read_until(*session->connection->socket, *chunked_streambuf, "\r\n", [this, session, chunked_streambuf, tmp_streambuf](const error_code &ec, size_t bytes_transferred) { session->connection->cancel_timeout(); auto lock = session->connection->handler_runner->continue_lock(); if(!lock) return; + if((!ec || ec == asio::error::not_found) && chunked_streambuf->size() == chunked_streambuf->max_size()) { + session->callback(session->connection, make_error_code::make_error_code(errc::message_size)); + return; + } if(!ec) { std::string line; - getline(session->response->content, line); + std::istream chunked_stream(chunked_streambuf.get()); + getline(chunked_stream, line); bytes_transferred -= line.size() + 1; line.pop_back(); unsigned long length; @@ -545,25 +576,26 @@ namespace SimpleWeb { return; } - auto num_additional_bytes = session->response->content_buffer.size() - bytes_transferred; + auto num_additional_bytes = chunked_streambuf->size() - bytes_transferred; - auto post_process = [this, session, tmp_streambuf, length]() { + auto post_process = [this, session, chunked_streambuf, tmp_streambuf, length]() { + std::istream chunked_stream(chunked_streambuf.get()); std::ostream tmp_stream(tmp_streambuf.get()); if(length > 0) { - std::vector buffer(static_cast(length)); - session->response->content.read(&buffer[0], static_cast(length)); - tmp_stream.write(&buffer[0], static_cast(length)); + std::unique_ptr buffer(new char[length]); + chunked_stream.read(buffer.get(), static_cast(length)); + tmp_stream.write(buffer.get(), static_cast(length)); } // Remove "\r\n" - session->response->content.get(); - session->response->content.get(); + chunked_stream.get(); + chunked_stream.get(); if(length > 0) this->read_chunked(session, tmp_streambuf); else { - std::ostream response_stream(&session->response->content_buffer); - response_stream << tmp_stream.rdbuf(); + std::ostream response_stream(&session->response->streambuf); + response_stream << tmp_streambuf.get(); error_code ec; session->callback(session->connection, ec); } @@ -571,13 +603,18 @@ namespace SimpleWeb { if((2 + length) > num_additional_bytes) { session->connection->set_timeout(); - asio::async_read(*session->connection->socket, session->response->content_buffer, asio::transfer_exactly(2 + length - num_additional_bytes), [this, session, post_process](const error_code &ec, size_t /*bytes_transferred*/) { + asio::async_read(*session->connection->socket, *chunked_streambuf, asio::transfer_exactly(2 + length - num_additional_bytes), [this, session, chunked_streambuf, post_process](const error_code &ec, size_t /*bytes_transferred*/) { session->connection->cancel_timeout(); auto lock = session->connection->handler_runner->continue_lock(); if(!lock) return; - if(!ec) + if(!ec) { + if(chunked_streambuf->size() == chunked_streambuf->max_size()) { + session->callback(session->connection, make_error_code::make_error_code(errc::message_size)); + return; + } post_process(); + } else session->callback(session->connection, ec); }); diff --git a/client_https.hpp b/client_https.hpp index 3f8b63f..3259248 100644 --- a/client_https.hpp +++ b/client_https.hpp @@ -76,13 +76,17 @@ namespace SimpleWeb { if(!lock) return; if(!ec) { - std::shared_ptr response(new Response()); + std::shared_ptr response(new Response(this->config.max_response_streambuf_size)); session->connection->set_timeout(this->config.timeout_connect); - asio::async_read_until(session->connection->socket->next_layer(), response->content_buffer, "\r\n\r\n", [this, session, response](const error_code &ec, size_t /*bytes_transferred*/) { + asio::async_read_until(session->connection->socket->next_layer(), response->streambuf, "\r\n\r\n", [this, session, response](const error_code &ec, size_t /*bytes_transferred*/) { session->connection->cancel_timeout(); auto lock = session->connection->handler_runner->continue_lock(); if(!lock) return; + if((!ec || ec == asio::error::not_found) && response->streambuf.size() == response->streambuf.max_size()) { + session->callback(session->connection, make_error_code::make_error_code(errc::message_size)); + return; + } if(!ec) { if(!ResponseMessage::parse(response->content, response->http_version, response->status_code, response->header)) session->callback(session->connection, make_error_code::make_error_code(errc::protocol_error)); diff --git a/server_http.hpp b/server_http.hpp index f1a8077..5b3bb70 100644 --- a/server_http.hpp +++ b/server_http.hpp @@ -2,10 +2,10 @@ #define SERVER_HTTP_HPP #include "utility.hpp" -#include #include #include #include +#include #include #include #include @@ -181,6 +181,10 @@ namespace SimpleWeb { friend class Server; friend class Session; + asio::streambuf streambuf; + Request(size_t max_request_streambuf_size, const std::string &remote_endpoint_address = std::string(), unsigned short remote_endpoint_port = 0) noexcept + : streambuf(max_request_streambuf_size), content(streambuf), remote_endpoint_address(remote_endpoint_address), remote_endpoint_port(remote_endpoint_port) {} + public: std::string method, path, query_string, http_version; @@ -197,12 +201,6 @@ namespace SimpleWeb { CaseInsensitiveMultimap parse_query_string() noexcept { return SimpleWeb::QueryString::parse(query_string); } - - private: - asio::streambuf streambuf; - - Request(const std::string &remote_endpoint_address = std::string(), unsigned short remote_endpoint_port = 0) noexcept - : content(streambuf), remote_endpoint_address(remote_endpoint_address), remote_endpoint_port(remote_endpoint_port) {} }; protected: @@ -250,13 +248,13 @@ namespace SimpleWeb { class Session { public: - Session(std::shared_ptr connection) noexcept : connection(std::move(connection)) { + Session(size_t max_request_streambuf_size, std::shared_ptr connection) noexcept : connection(std::move(connection)) { try { auto remote_endpoint = this->connection->socket->lowest_layer().remote_endpoint(); - request = std::shared_ptr(new Request(remote_endpoint.address().to_string(), remote_endpoint.port())); + request = std::shared_ptr(new Request(max_request_streambuf_size, remote_endpoint.address().to_string(), remote_endpoint.port())); } catch(...) { - request = std::shared_ptr(new Request()); + request = std::shared_ptr(new Request(max_request_streambuf_size)); } } @@ -280,6 +278,9 @@ namespace SimpleWeb { long timeout_request = 5; /// Timeout on content handling. Defaults to 300 seconds. long timeout_content = 300; + /// Maximum size of request stream buffer. Defaults to architecture maximum. + /// Reaching this limit will result in a message_size error code. + size_t max_request_streambuf_size = static_cast(-1); /// IPv4 address in dotted decimal form or IPv6 address in hexadecimal notation. /// If empty, the address will be any address. std::string address; @@ -422,6 +423,14 @@ namespace SimpleWeb { auto lock = session->connection->handler_runner->continue_lock(); if(!lock) return; + if((!ec || ec == asio::error::not_found) && session->request->streambuf.size() == session->request->streambuf.max_size()) { + auto response = std::shared_ptr(new Response(session, this->config.timeout_content)); + response->write(StatusCode::client_error_payload_too_large); + response->send(); + if(this->on_error) + this->on_error(session->request, make_error_code::make_error_code(errc::message_size)); + return; + } if(!ec) { // request->streambuf.size() is not necessarily the same as bytes_transferred, from Boost-docs: // "After a successful async_read_until operation, the streambuf may contain additional data beyond the delimiter" @@ -455,8 +464,17 @@ namespace SimpleWeb { auto lock = session->connection->handler_runner->continue_lock(); if(!lock) return; - if(!ec) + if(!ec) { + if(session->request->streambuf.size() == session->request->streambuf.max_size()) { + auto response = std::shared_ptr(new Response(session, this->config.timeout_content)); + response->write(StatusCode::client_error_payload_too_large); + response->send(); + if(this->on_error) + this->on_error(session->request, make_error_code::make_error_code(errc::message_size)); + return; + } this->find_resource(session); + } else if(this->on_error) this->on_error(session->request, ec); }); @@ -521,13 +539,13 @@ namespace SimpleWeb { if(case_insensitive_equal(it->second, "close")) return; else if(case_insensitive_equal(it->second, "keep-alive")) { - auto new_session = std::make_shared(response->session->connection); + auto new_session = std::make_shared(this->config.max_request_streambuf_size, response->session->connection); this->read_request_and_content(new_session); return; } } if(response->session->request->http_version >= "1.1") { - auto new_session = std::make_shared(response->session->connection); + auto new_session = std::make_shared(this->config.max_request_streambuf_size, response->session->connection); this->read_request_and_content(new_session); return; } @@ -560,7 +578,7 @@ namespace SimpleWeb { protected: void accept() override { - auto session = std::make_shared(create_connection(*io_service)); + auto session = std::make_shared(config.max_request_streambuf_size, create_connection(*io_service)); acceptor->async_accept(*session->connection->socket, [this, session](const error_code &ec) { auto lock = session->connection->handler_runner->continue_lock(); diff --git a/server_https.hpp b/server_https.hpp index 1f8b81d..5f5e278 100644 --- a/server_https.hpp +++ b/server_https.hpp @@ -48,7 +48,7 @@ namespace SimpleWeb { asio::ssl::context context; void accept() override { - auto session = std::make_shared(create_connection(*io_service, context)); + auto session = std::make_shared(config.max_request_streambuf_size, create_connection(*io_service, context)); acceptor->async_accept(session->connection->socket->lowest_layer(), [this, session](const error_code &ec) { auto lock = session->connection->handler_runner->continue_lock(); diff --git a/tests/parse_test.cpp b/tests/parse_test.cpp index 8c8a19e..b456d31 100644 --- a/tests/parse_test.cpp +++ b/tests/parse_test.cpp @@ -13,7 +13,7 @@ public: void accept() noexcept override {} void parse_request_test() { - auto session = std::make_shared(create_connection(*io_service)); + auto session = std::make_shared(static_cast(-1), create_connection(*io_service)); std::ostream stream(&session->request->content.streambuf); stream << "GET /test/ HTTP/1.1\r\n"; @@ -72,9 +72,9 @@ public: } void parse_response_header_test() { - std::shared_ptr response(new Response()); + std::shared_ptr response(new Response(static_cast(-1))); - ostream stream(&response->content_buffer); + ostream stream(&response->streambuf); stream << "HTTP/1.1 200 OK\r\n"; stream << "TestHeader: test\r\n"; stream << "TestHeader2:test2\r\n"; @@ -152,7 +152,7 @@ int main() { asio::io_service io_service; asio::ip::tcp::socket socket(io_service); - SimpleWeb::Server::Request request; + SimpleWeb::Server::Request request(static_cast(-1)); { request.query_string = ""; auto queries = request.parse_query_string();