#include <algorithm>
#include <iterator>
+#include <set>
#include <string>
#include <vector>
#include "base/bind.h"
#include "base/containers/hash_tables.h"
#include "base/stl_util.h"
+#include "base/strings/string_number_conversions.h"
#include "base/strings/string_util.h"
#include "base/strings/stringprintf.h"
#include "base/time/time.h"
#include "net/http/http_stream_parser.h"
#include "net/socket/client_socket_handle.h"
#include "net/websockets/websocket_basic_stream.h"
+#include "net/websockets/websocket_deflate_predictor.h"
+#include "net/websockets/websocket_deflate_predictor_impl.h"
+#include "net/websockets/websocket_deflate_stream.h"
+#include "net/websockets/websocket_deflater.h"
#include "net/websockets/websocket_extension_parser.h"
#include "net/websockets/websocket_handshake_constants.h"
#include "net/websockets/websocket_handshake_handler.h"
#include "net/websockets/websocket_stream.h"
namespace net {
+
+// TODO(ricea): If more extensions are added, replace this with a more general
+// mechanism.
+struct WebSocketExtensionParams {
+ WebSocketExtensionParams()
+ : deflate_enabled(false),
+ client_window_bits(15),
+ deflate_mode(WebSocketDeflater::TAKE_OVER_CONTEXT) {}
+
+ bool deflate_enabled;
+ int client_window_bits;
+ WebSocketDeflater::ContextTakeOverMode deflate_mode;
+};
+
namespace {
enum GetHeaderResult {
return true;
}
+bool ValidatePerMessageDeflateExtension(const WebSocketExtension& extension,
+ std::string* failure_message,
+ WebSocketExtensionParams* params) {
+ static const char kClientPrefix[] = "client_";
+ static const char kServerPrefix[] = "server_";
+ static const char kNoContextTakeover[] = "no_context_takeover";
+ static const char kMaxWindowBits[] = "max_window_bits";
+ const size_t kPrefixLen = arraysize(kClientPrefix) - 1;
+ COMPILE_ASSERT(kPrefixLen == arraysize(kServerPrefix) - 1,
+ the_strings_server_and_client_must_be_the_same_length);
+ typedef std::vector<WebSocketExtension::Parameter> ParameterVector;
+
+ DCHECK_EQ("permessage-deflate", extension.name());
+ const ParameterVector& parameters = extension.parameters();
+ std::set<std::string> seen_names;
+ for (ParameterVector::const_iterator it = parameters.begin();
+ it != parameters.end(); ++it) {
+ const std::string& name = it->name();
+ if (seen_names.count(name) != 0) {
+ *failure_message =
+ "Received duplicate permessage-deflate extension parameter " + name;
+ return false;
+ }
+ seen_names.insert(name);
+ const std::string client_or_server(name, 0, kPrefixLen);
+ const bool is_client = (client_or_server == kClientPrefix);
+ if (!is_client && client_or_server != kServerPrefix) {
+ *failure_message =
+ "Received an unexpected permessage-deflate extension parameter";
+ return false;
+ }
+ const std::string rest(name, kPrefixLen);
+ if (rest == kNoContextTakeover) {
+ if (it->HasValue()) {
+ *failure_message = "Received invalid " + name + " parameter";
+ return false;
+ }
+ if (is_client)
+ params->deflate_mode = WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT;
+ } else if (rest == kMaxWindowBits) {
+ if (!it->HasValue()) {
+ *failure_message = name + " must have value";
+ return false;
+ }
+ int bits = 0;
+ if (!base::StringToInt(it->value(), &bits) || bits < 8 || bits > 15 ||
+ it->value()[0] == '0' ||
+ it->value().find_first_not_of("0123456789") != std::string::npos) {
+ *failure_message = "Received invalid " + name + " parameter";
+ return false;
+ }
+ if (is_client)
+ params->client_window_bits = bits;
+ } else {
+ *failure_message =
+ "Received an unexpected permessage-deflate extension parameter";
+ return false;
+ }
+ }
+ params->deflate_enabled = true;
+ return true;
+}
+
bool ValidateExtensions(const HttpResponseHeaders* headers,
const std::vector<std::string>& requested_extensions,
std::string* extensions,
- std::string* failure_message) {
+ std::string* failure_message,
+ WebSocketExtensionParams* params) {
void* state = NULL;
std::string value;
+ std::vector<std::string> accepted_extensions;
+ // TODO(ricea): If adding support for additional extensions, generalise this
+ // code.
+ bool seen_permessage_deflate = false;
while (headers->EnumerateHeader(
- &state, websockets::kSecWebSocketExtensions, &value)) {
+ &state, websockets::kSecWebSocketExtensions, &value)) {
WebSocketExtensionParser parser;
parser.Parse(value);
if (parser.has_error()) {
value;
return false;
}
- // TODO(ricea): Accept permessage-deflate with valid parameters.
- *failure_message =
- "Found an unsupported extension '" +
- parser.extension().name() +
- "' in 'Sec-WebSocket-Extensions' header";
- return false;
+ if (parser.extension().name() == "permessage-deflate") {
+ if (seen_permessage_deflate) {
+ *failure_message = "Received duplicate permessage-deflate response";
+ return false;
+ }
+ seen_permessage_deflate = true;
+ if (!ValidatePerMessageDeflateExtension(
+ parser.extension(), failure_message, params))
+ return false;
+ } else {
+ *failure_message =
+ "Found an unsupported extension '" +
+ parser.extension().name() +
+ "' in 'Sec-WebSocket-Extensions' header";
+ return false;
+ }
+ accepted_extensions.push_back(value);
}
+ *extensions = JoinString(accepted_extensions, ", ");
return true;
}
}
enriched_headers.SetHeader(websockets::kSecWebSocketKey, handshake_challenge);
- AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol,
- requested_sub_protocols_,
- &enriched_headers);
AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions,
requested_extensions_,
&enriched_headers);
+ AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol,
+ requested_sub_protocols_,
+ &enriched_headers);
ComputeSecWebSocketAccept(handshake_challenge,
&handshake_challenge_response_);
callback));
if (rv == ERR_IO_PENDING)
return rv;
- if (rv == OK)
- return ValidateResponse();
- OnFinishOpeningHandshake();
- return rv;
+ return ValidateResponse(rv);
}
const HttpResponseInfo* WebSocketBasicHandshakeStream::GetResponseInfo() const {
}
scoped_ptr<WebSocketStream> WebSocketBasicHandshakeStream::Upgrade() {
- // TODO(ricea): Add deflate support.
-
// The HttpStreamParser object has a pointer to our ClientSocketHandle. Make
// sure it does not touch it again before it is destroyed.
state_.DeleteParser();
- return scoped_ptr<WebSocketStream>(
+ scoped_ptr<WebSocketStream> basic_stream(
new WebSocketBasicStream(state_.ReleaseConnection(),
state_.read_buf(),
sub_protocol_,
extensions_));
+ DCHECK(extension_params_.get());
+ if (extension_params_->deflate_enabled) {
+ return scoped_ptr<WebSocketStream>(
+ new WebSocketDeflateStream(basic_stream.Pass(),
+ extension_params_->deflate_mode,
+ extension_params_->client_window_bits,
+ scoped_ptr<WebSocketDeflatePredictor>(
+ new WebSocketDeflatePredictorImpl)));
+ } else {
+ return basic_stream.Pass();
+ }
}
void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting(
void WebSocketBasicHandshakeStream::ReadResponseHeadersCallback(
const CompletionCallback& callback,
int result) {
- if (result == OK)
- result = ValidateResponse();
- else
- OnFinishOpeningHandshake();
- callback.Run(result);
+ callback.Run(ValidateResponse(result));
}
void WebSocketBasicHandshakeStream::OnFinishOpeningHandshake() {
DCHECK(connect_delegate_);
DCHECK(http_response_info_);
scoped_refptr<HttpResponseHeaders> headers = http_response_info_->headers;
- scoped_ptr<WebSocketHandshakeResponseInfo> response(
- new WebSocketHandshakeResponseInfo(url_,
- headers->response_code(),
- headers->GetStatusText(),
- headers,
- http_response_info_->response_time));
- connect_delegate_->OnFinishOpeningHandshake(response.Pass());
+ // If the headers are too large, HttpStreamParser will just not parse them at
+ // all.
+ if (headers) {
+ scoped_ptr<WebSocketHandshakeResponseInfo> response(
+ new WebSocketHandshakeResponseInfo(url_,
+ headers->response_code(),
+ headers->GetStatusText(),
+ headers,
+ http_response_info_->response_time));
+ connect_delegate_->OnFinishOpeningHandshake(response.Pass());
+ }
}
-int WebSocketBasicHandshakeStream::ValidateResponse() {
+int WebSocketBasicHandshakeStream::ValidateResponse(int rv) {
DCHECK(http_response_info_);
- const scoped_refptr<HttpResponseHeaders>& headers =
- http_response_info_->headers;
-
- switch (headers->response_code()) {
- case HTTP_SWITCHING_PROTOCOLS:
- OnFinishOpeningHandshake();
- return ValidateUpgradeResponse(headers);
-
- // We need to pass these through for authentication to work.
- case HTTP_UNAUTHORIZED:
- case HTTP_PROXY_AUTHENTICATION_REQUIRED:
- return OK;
-
- // Other status codes are potentially risky (see the warnings in the
- // WHATWG WebSocket API spec) and so are dropped by default.
- default:
- failure_message_ = base::StringPrintf("Unexpected status code: %d",
- headers->response_code());
- OnFinishOpeningHandshake();
- return ERR_INVALID_RESPONSE;
+ const HttpResponseHeaders* headers = http_response_info_->headers.get();
+ if (rv >= 0) {
+ switch (headers->response_code()) {
+ case HTTP_SWITCHING_PROTOCOLS:
+ OnFinishOpeningHandshake();
+ return ValidateUpgradeResponse(headers);
+
+ // We need to pass these through for authentication to work.
+ case HTTP_UNAUTHORIZED:
+ case HTTP_PROXY_AUTHENTICATION_REQUIRED:
+ return OK;
+
+ // Other status codes are potentially risky (see the warnings in the
+ // WHATWG WebSocket API spec) and so are dropped by default.
+ default:
+ failure_message_ = base::StringPrintf(
+ "Error during WebSocket handshake: Unexpected response code: %d",
+ headers->response_code());
+ OnFinishOpeningHandshake();
+ return ERR_INVALID_RESPONSE;
+ }
+ } else {
+ if (rv == ERR_EMPTY_RESPONSE) {
+ failure_message_ =
+ "Connection closed before receiving a handshake response";
+ return rv;
+ }
+ failure_message_ =
+ std::string("Error during WebSocket handshake: ") + ErrorToString(rv);
+ OnFinishOpeningHandshake();
+ return rv;
}
}
int WebSocketBasicHandshakeStream::ValidateUpgradeResponse(
- const scoped_refptr<HttpResponseHeaders>& headers) {
- if (ValidateUpgrade(headers.get(), &failure_message_) &&
- ValidateSecWebSocketAccept(headers.get(),
+ const HttpResponseHeaders* headers) {
+ extension_params_.reset(new WebSocketExtensionParams);
+ if (ValidateUpgrade(headers, &failure_message_) &&
+ ValidateSecWebSocketAccept(headers,
handshake_challenge_response_,
&failure_message_) &&
- ValidateConnection(headers.get(), &failure_message_) &&
- ValidateSubProtocol(headers.get(),
+ ValidateConnection(headers, &failure_message_) &&
+ ValidateSubProtocol(headers,
requested_sub_protocols_,
&sub_protocol_,
&failure_message_) &&
- ValidateExtensions(headers.get(),
+ ValidateExtensions(headers,
requested_extensions_,
&extensions_,
- &failure_message_)) {
+ &failure_message_,
+ extension_params_.get())) {
return OK;
}
failure_message_ = "Error during WebSocket handshake: " + failure_message_;