diff --git a/client_http.hpp b/client_http.hpp index 6945914..e4bf516 100644 --- a/client_http.hpp +++ b/client_http.hpp @@ -70,32 +70,6 @@ namespace SimpleWeb { asio::streambuf content_buffer; Response() : content(content_buffer) {} - - void parse_header() { - std::string line; - getline(content, line); - size_t version_end = line.find(' '); - if(version_end != std::string::npos) { - if(5 < line.size()) - http_version = line.substr(5, version_end - 5); - if((version_end + 1) < line.size()) - status_code = line.substr(version_end + 1, line.size() - (version_end + 1) - 1); - - getline(content, line); - size_t param_end; - while((param_end = line.find(':')) != std::string::npos) { - size_t value_start = param_end + 1; - if((value_start) < line.size()) { - if(line[value_start] == ' ') - value_start++; - if(value_start < line.size()) - header.insert(std::make_pair(line.substr(0, param_end), line.substr(value_start, line.size() - value_start - 1))); - } - - getline(content, line); - } - } - } }; class Config { @@ -490,7 +464,10 @@ namespace SimpleWeb { size_t num_additional_bytes = session->response->content_buffer.size() - bytes_transferred; - session->response->parse_header(); + if(!ResponseMessage::parse(session->response->content, session->response->http_version, session->response->status_code, session->response->header)) { + session->callback(session->connection, make_error_code::make_error_code(errc::protocol_error)); + return; + } auto header_it = session->response->header.find("Content-Length"); if(header_it != session->response->header.end()) { diff --git a/client_https.hpp b/client_https.hpp index b58788a..9d1fe1b 100644 --- a/client_https.hpp +++ b/client_https.hpp @@ -84,11 +84,14 @@ namespace SimpleWeb { if(cancel_pair.first) return; if(!ec) { - response->parse_header(); - if(response->status_code.empty() || response->status_code.compare(0, 3, "200") != 0) - session->callback(session->connection, make_error_code::make_error_code(errc::permission_denied)); - else - this->handshake(session); + if(!ResponseMessage::parse(response->content, response->http_version, response->status_code, response->header)) + session->callback(session->connection, make_error_code::make_error_code(errc::protocol_error)); + else { + if(response->status_code.empty() || response->status_code.compare(0, 3, "200") != 0) + session->callback(session->connection, make_error_code::make_error_code(errc::permission_denied)); + else + this->handshake(session); + } } else session->callback(session->connection, ec); diff --git a/server_http.hpp b/server_http.hpp index 48b223e..b42b835 100644 --- a/server_http.hpp +++ b/server_http.hpp @@ -196,62 +196,6 @@ namespace SimpleWeb { Request(const std::string &remote_endpoint_address = std::string(), unsigned short remote_endpoint_port = 0) : content(streambuf), remote_endpoint_address(remote_endpoint_address), remote_endpoint_port(remote_endpoint_port) {} - - bool parse() { - std::string line; - getline(content, line); - size_t method_end; - if((method_end = line.find(' ')) != std::string::npos) { - method = line.substr(0, method_end); - - size_t query_start = std::string::npos; - size_t path_and_query_string_end = std::string::npos; - for(size_t i = method_end + 1; i < line.size(); ++i) { - if(line[i] == '?' && (i + 1) < line.size()) - query_start = i + 1; - else if(line[i] == ' ') { - path_and_query_string_end = i; - break; - } - } - if(path_and_query_string_end != std::string::npos) { - if(query_start != std::string::npos) { - path = line.substr(method_end + 1, query_start - method_end - 2); - query_string = line.substr(query_start, path_and_query_string_end - query_start); - } - else - path = line.substr(method_end + 1, path_and_query_string_end - method_end - 1); - - size_t protocol_end; - if((protocol_end = line.find('/', path_and_query_string_end + 1)) != std::string::npos) { - if(line.compare(path_and_query_string_end + 1, protocol_end - path_and_query_string_end - 1, "HTTP") != 0) - return false; - http_version = line.substr(protocol_end + 1, line.size() - protocol_end - 2); - } - else - return false; - - getline(content, line); - size_t param_end; - while((param_end = line.find(':')) != std::string::npos) { - size_t value_start = param_end + 1; - if(value_start < line.size()) { - if(line[value_start] == ' ') - value_start++; - if(value_start < line.size()) - header.emplace(line.substr(0, param_end), line.substr(value_start, line.size() - value_start - 1)); - } - - getline(content, line); - } - } - else - return false; - } - else - return false; - return true; - } }; protected: @@ -494,8 +438,12 @@ namespace SimpleWeb { // streambuf (maybe some bytes of the content) is appended to in the async_read-function below (for retrieving content). size_t num_additional_bytes = session->request->streambuf.size() - bytes_transferred; - if(!session->request->parse()) + if(!RequestMessage::parse(session->request->content, session->request->method, session->request->path, + session->request->query_string, session->request->http_version, session->request->header)) { + if(this->on_error) + this->on_error(session->request, make_error_code::make_error_code(errc::protocol_error)); return; + } // If content, read that as well auto it = session->request->header.find("Content-Length"); diff --git a/tests/parse_test.cpp b/tests/parse_test.cpp index 935e295..7887aed 100644 --- a/tests/parse_test.cpp +++ b/tests/parse_test.cpp @@ -23,7 +23,8 @@ public: stream << "TestHeader3:test3b\r\n"; stream << "\r\n"; - assert(session->request->parse()); + assert(RequestMessage::parse(session->request->content, session->request->method, session->request->path, + session->request->query_string, session->request->http_version, session->request->header)); assert(session->request->method == "GET"); assert(session->request->path == "/test/"); @@ -81,7 +82,7 @@ public: stream << "TestHeader3:test3b\r\n"; stream << "\r\n"; - response->parse_header(); + assert(ResponseMessage::parse(response->content, response->http_version, response->status_code, response->header)); assert(response->http_version == "1.1"); assert(response->status_code == "200 OK"); diff --git a/utility.hpp b/utility.hpp index cae0b48..dd5b0da 100644 --- a/utility.hpp +++ b/utility.hpp @@ -135,6 +135,106 @@ namespace SimpleWeb { return result; } }; + + + class RequestMessage { + public: + /// Parse request line and header fields + static bool parse(std::istream &stream, std::string &method, std::string &path, std::string &query_string, std::string &version, CaseInsensitiveMultimap &header) { + header.clear(); + std::string line; + getline(stream, line); + size_t method_end; + if((method_end = line.find(' ')) != std::string::npos) { + method = line.substr(0, method_end); + + size_t query_start = std::string::npos; + size_t path_and_query_string_end = std::string::npos; + for(size_t i = method_end + 1; i < line.size(); ++i) { + if(line[i] == '?' && (i + 1) < line.size()) + query_start = i + 1; + else if(line[i] == ' ') { + path_and_query_string_end = i; + break; + } + } + if(path_and_query_string_end != std::string::npos) { + if(query_start != std::string::npos) { + path = line.substr(method_end + 1, query_start - method_end - 2); + query_string = line.substr(query_start, path_and_query_string_end - query_start); + } + else + path = line.substr(method_end + 1, path_and_query_string_end - method_end - 1); + + size_t protocol_end; + if((protocol_end = line.find('/', path_and_query_string_end + 1)) != std::string::npos) { + if(line.compare(path_and_query_string_end + 1, protocol_end - path_and_query_string_end - 1, "HTTP") != 0) + return false; + version = line.substr(protocol_end + 1, line.size() - protocol_end - 2); + } + else + return false; + + getline(stream, line); + size_t param_end; + while((param_end = line.find(':')) != std::string::npos) { + size_t value_start = param_end + 1; + if(value_start < line.size()) { + if(line[value_start] == ' ') + value_start++; + if(value_start < line.size()) + header.emplace(line.substr(0, param_end), line.substr(value_start, line.size() - value_start - 1)); + } + + getline(stream, line); + } + } + else + return false; + } + else + return false; + return true; + } + }; + + class ResponseMessage { + public: + /// Parse status line and header fields + static bool parse(std::istream &stream, std::string &version, std::string &status_code, CaseInsensitiveMultimap &header) { + header.clear(); + std::string line; + getline(stream, line); + size_t version_end = line.find(' '); + if(version_end != std::string::npos) { + if(5 < line.size()) + version = line.substr(5, version_end - 5); + else + return false; + if((version_end + 1) < line.size()) + status_code = line.substr(version_end + 1, line.size() - (version_end + 1) - 1); + else + return false; + + getline(stream, line); + size_t param_end; + while((param_end = line.find(':')) != std::string::npos) { + size_t value_start = param_end + 1; + if((value_start) < line.size()) { + if(line[value_start] == ' ') + value_start++; + if(value_start < line.size()) + header.insert(std::make_pair(line.substr(0, param_end), line.substr(value_start, line.size() - value_start - 1))); + } + + getline(stream, line); + } + } + else + return false; + return true; + } + }; } #ifdef PTHREAD_RWLOCK_INITIALIZER