Upstream version 5.34.104.0
[platform/framework/web/crosswalk.git] / src / net / websockets / websocket_basic_handshake_stream.cc
index 3d2bcdf..a8d12fc 100644 (file)
@@ -6,6 +6,7 @@
 
 #include <algorithm>
 #include <iterator>
+#include <set>
 #include <string>
 #include <vector>
 
@@ -14,6 +15,7 @@
 #include "base/bind.h"
 #include "base/containers/hash_tables.h"
 #include "base/stl_util.h"
+#include "base/strings/string_number_conversions.h"
 #include "base/strings/string_util.h"
 #include "base/strings/stringprintf.h"
 #include "base/time/time.h"
 #include "net/http/http_stream_parser.h"
 #include "net/socket/client_socket_handle.h"
 #include "net/websockets/websocket_basic_stream.h"
+#include "net/websockets/websocket_deflate_predictor.h"
+#include "net/websockets/websocket_deflate_predictor_impl.h"
+#include "net/websockets/websocket_deflate_stream.h"
+#include "net/websockets/websocket_deflater.h"
 #include "net/websockets/websocket_extension_parser.h"
 #include "net/websockets/websocket_handshake_constants.h"
 #include "net/websockets/websocket_handshake_handler.h"
 #include "net/websockets/websocket_stream.h"
 
 namespace net {
+
+// TODO(ricea): If more extensions are added, replace this with a more general
+// mechanism.
+struct WebSocketExtensionParams {
+  WebSocketExtensionParams()
+      : deflate_enabled(false),
+        client_window_bits(15),
+        deflate_mode(WebSocketDeflater::TAKE_OVER_CONTEXT) {}
+
+  bool deflate_enabled;
+  int client_window_bits;
+  WebSocketDeflater::ContextTakeOverMode deflate_mode;
+};
+
 namespace {
 
 enum GetHeaderResult {
@@ -202,14 +222,82 @@ bool ValidateSubProtocol(
   return true;
 }
 
+bool ValidatePerMessageDeflateExtension(const WebSocketExtension& extension,
+                                        std::string* failure_message,
+                                        WebSocketExtensionParams* params) {
+  static const char kClientPrefix[] = "client_";
+  static const char kServerPrefix[] = "server_";
+  static const char kNoContextTakeover[] = "no_context_takeover";
+  static const char kMaxWindowBits[] = "max_window_bits";
+  const size_t kPrefixLen = arraysize(kClientPrefix) - 1;
+  COMPILE_ASSERT(kPrefixLen == arraysize(kServerPrefix) - 1,
+                 the_strings_server_and_client_must_be_the_same_length);
+  typedef std::vector<WebSocketExtension::Parameter> ParameterVector;
+
+  DCHECK_EQ("permessage-deflate", extension.name());
+  const ParameterVector& parameters = extension.parameters();
+  std::set<std::string> seen_names;
+  for (ParameterVector::const_iterator it = parameters.begin();
+       it != parameters.end(); ++it) {
+    const std::string& name = it->name();
+    if (seen_names.count(name) != 0) {
+      *failure_message =
+          "Received duplicate permessage-deflate extension parameter " + name;
+      return false;
+    }
+    seen_names.insert(name);
+    const std::string client_or_server(name, 0, kPrefixLen);
+    const bool is_client = (client_or_server == kClientPrefix);
+    if (!is_client && client_or_server != kServerPrefix) {
+      *failure_message =
+          "Received an unexpected permessage-deflate extension parameter";
+      return false;
+    }
+    const std::string rest(name, kPrefixLen);
+    if (rest == kNoContextTakeover) {
+      if (it->HasValue()) {
+        *failure_message = "Received invalid " + name + " parameter";
+        return false;
+      }
+      if (is_client)
+        params->deflate_mode = WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT;
+    } else if (rest == kMaxWindowBits) {
+      if (!it->HasValue()) {
+        *failure_message = name + " must have value";
+        return false;
+      }
+      int bits = 0;
+      if (!base::StringToInt(it->value(), &bits) || bits < 8 || bits > 15 ||
+          it->value()[0] == '0' ||
+          it->value().find_first_not_of("0123456789") != std::string::npos) {
+        *failure_message = "Received invalid " + name + " parameter";
+        return false;
+      }
+      if (is_client)
+        params->client_window_bits = bits;
+    } else {
+      *failure_message =
+          "Received an unexpected permessage-deflate extension parameter";
+      return false;
+    }
+  }
+  params->deflate_enabled = true;
+  return true;
+}
+
 bool ValidateExtensions(const HttpResponseHeaders* headers,
                         const std::vector<std::string>& requested_extensions,
                         std::string* extensions,
-                        std::string* failure_message) {
+                        std::string* failure_message,
+                        WebSocketExtensionParams* params) {
   void* state = NULL;
   std::string value;
+  std::vector<std::string> accepted_extensions;
+  // TODO(ricea): If adding support for additional extensions, generalise this
+  // code.
+  bool seen_permessage_deflate = false;
   while (headers->EnumerateHeader(
-      &state, websockets::kSecWebSocketExtensions, &value)) {
+             &state, websockets::kSecWebSocketExtensions, &value)) {
     WebSocketExtensionParser parser;
     parser.Parse(value);
     if (parser.has_error()) {
@@ -220,13 +308,25 @@ bool ValidateExtensions(const HttpResponseHeaders* headers,
           value;
       return false;
     }
-    // TODO(ricea): Accept permessage-deflate with valid parameters.
-    *failure_message =
-        "Found an unsupported extension '" +
-        parser.extension().name() +
-        "' in 'Sec-WebSocket-Extensions' header";
-    return false;
+    if (parser.extension().name() == "permessage-deflate") {
+      if (seen_permessage_deflate) {
+        *failure_message = "Received duplicate permessage-deflate response";
+        return false;
+      }
+      seen_permessage_deflate = true;
+      if (!ValidatePerMessageDeflateExtension(
+              parser.extension(), failure_message, params))
+        return false;
+    } else {
+      *failure_message =
+          "Found an unsupported extension '" +
+          parser.extension().name() +
+          "' in 'Sec-WebSocket-Extensions' header";
+      return false;
+    }
+    accepted_extensions.push_back(value);
   }
+  *extensions = JoinString(accepted_extensions, ", ");
   return true;
 }
 
@@ -284,12 +384,12 @@ int WebSocketBasicHandshakeStream::SendRequest(
   }
   enriched_headers.SetHeader(websockets::kSecWebSocketKey, handshake_challenge);
 
-  AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol,
-                            requested_sub_protocols_,
-                            &enriched_headers);
   AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions,
                             requested_extensions_,
                             &enriched_headers);
+  AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol,
+                            requested_sub_protocols_,
+                            &enriched_headers);
 
   ComputeSecWebSocketAccept(handshake_challenge,
                             &handshake_challenge_response_);
@@ -316,10 +416,7 @@ int WebSocketBasicHandshakeStream::ReadResponseHeaders(
                  callback));
   if (rv == ERR_IO_PENDING)
     return rv;
-  if (rv == OK)
-    return ValidateResponse();
-  OnFinishOpeningHandshake();
-  return rv;
+  return ValidateResponse(rv);
 }
 
 const HttpResponseInfo* WebSocketBasicHandshakeStream::GetResponseInfo() const {
@@ -393,16 +490,25 @@ void WebSocketBasicHandshakeStream::SetPriority(RequestPriority priority) {
 }
 
 scoped_ptr<WebSocketStream> WebSocketBasicHandshakeStream::Upgrade() {
-  // TODO(ricea): Add deflate support.
-
   // The HttpStreamParser object has a pointer to our ClientSocketHandle. Make
   // sure it does not touch it again before it is destroyed.
   state_.DeleteParser();
-  return scoped_ptr<WebSocketStream>(
+  scoped_ptr<WebSocketStream> basic_stream(
       new WebSocketBasicStream(state_.ReleaseConnection(),
                                state_.read_buf(),
                                sub_protocol_,
                                extensions_));
+  DCHECK(extension_params_.get());
+  if (extension_params_->deflate_enabled) {
+    return scoped_ptr<WebSocketStream>(
+        new WebSocketDeflateStream(basic_stream.Pass(),
+                                   extension_params_->deflate_mode,
+                                   extension_params_->client_window_bits,
+                                   scoped_ptr<WebSocketDeflatePredictor>(
+                                       new WebSocketDeflatePredictorImpl)));
+  } else {
+    return basic_stream.Pass();
+  }
 }
 
 void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting(
@@ -417,66 +523,79 @@ std::string WebSocketBasicHandshakeStream::GetFailureMessage() const {
 void WebSocketBasicHandshakeStream::ReadResponseHeadersCallback(
     const CompletionCallback& callback,
     int result) {
-  if (result == OK)
-    result = ValidateResponse();
-  else
-    OnFinishOpeningHandshake();
-  callback.Run(result);
+  callback.Run(ValidateResponse(result));
 }
 
 void WebSocketBasicHandshakeStream::OnFinishOpeningHandshake() {
   DCHECK(connect_delegate_);
   DCHECK(http_response_info_);
   scoped_refptr<HttpResponseHeaders> headers = http_response_info_->headers;
-  scoped_ptr<WebSocketHandshakeResponseInfo> response(
-      new WebSocketHandshakeResponseInfo(url_,
-                                         headers->response_code(),
-                                         headers->GetStatusText(),
-                                         headers,
-                                         http_response_info_->response_time));
-  connect_delegate_->OnFinishOpeningHandshake(response.Pass());
+  // If the headers are too large, HttpStreamParser will just not parse them at
+  // all.
+  if (headers) {
+    scoped_ptr<WebSocketHandshakeResponseInfo> response(
+        new WebSocketHandshakeResponseInfo(url_,
+                                           headers->response_code(),
+                                           headers->GetStatusText(),
+                                           headers,
+                                           http_response_info_->response_time));
+    connect_delegate_->OnFinishOpeningHandshake(response.Pass());
+  }
 }
 
-int WebSocketBasicHandshakeStream::ValidateResponse() {
+int WebSocketBasicHandshakeStream::ValidateResponse(int rv) {
   DCHECK(http_response_info_);
-  const scoped_refptr<HttpResponseHeaders>& headers =
-      http_response_info_->headers;
-
-  switch (headers->response_code()) {
-    case HTTP_SWITCHING_PROTOCOLS:
-      OnFinishOpeningHandshake();
-      return ValidateUpgradeResponse(headers);
-
-    // We need to pass these through for authentication to work.
-    case HTTP_UNAUTHORIZED:
-    case HTTP_PROXY_AUTHENTICATION_REQUIRED:
-      return OK;
-
-    // Other status codes are potentially risky (see the warnings in the
-    // WHATWG WebSocket API spec) and so are dropped by default.
-    default:
-      failure_message_ = base::StringPrintf("Unexpected status code: %d",
-                                            headers->response_code());
-      OnFinishOpeningHandshake();
-      return ERR_INVALID_RESPONSE;
+  const HttpResponseHeaders* headers = http_response_info_->headers.get();
+  if (rv >= 0) {
+    switch (headers->response_code()) {
+      case HTTP_SWITCHING_PROTOCOLS:
+        OnFinishOpeningHandshake();
+        return ValidateUpgradeResponse(headers);
+
+      // We need to pass these through for authentication to work.
+      case HTTP_UNAUTHORIZED:
+      case HTTP_PROXY_AUTHENTICATION_REQUIRED:
+        return OK;
+
+      // Other status codes are potentially risky (see the warnings in the
+      // WHATWG WebSocket API spec) and so are dropped by default.
+      default:
+        failure_message_ = base::StringPrintf(
+            "Error during WebSocket handshake: Unexpected response code: %d",
+            headers->response_code());
+        OnFinishOpeningHandshake();
+        return ERR_INVALID_RESPONSE;
+    }
+  } else {
+    if (rv == ERR_EMPTY_RESPONSE) {
+      failure_message_ =
+          "Connection closed before receiving a handshake response";
+      return rv;
+    }
+    failure_message_ =
+        std::string("Error during WebSocket handshake: ") + ErrorToString(rv);
+    OnFinishOpeningHandshake();
+    return rv;
   }
 }
 
 int WebSocketBasicHandshakeStream::ValidateUpgradeResponse(
-    const scoped_refptr<HttpResponseHeaders>& headers) {
-  if (ValidateUpgrade(headers.get(), &failure_message_) &&
-      ValidateSecWebSocketAccept(headers.get(),
+    const HttpResponseHeaders* headers) {
+  extension_params_.reset(new WebSocketExtensionParams);
+  if (ValidateUpgrade(headers, &failure_message_) &&
+      ValidateSecWebSocketAccept(headers,
                                  handshake_challenge_response_,
                                  &failure_message_) &&
-      ValidateConnection(headers.get(), &failure_message_) &&
-      ValidateSubProtocol(headers.get(),
+      ValidateConnection(headers, &failure_message_) &&
+      ValidateSubProtocol(headers,
                           requested_sub_protocols_,
                           &sub_protocol_,
                           &failure_message_) &&
-      ValidateExtensions(headers.get(),
+      ValidateExtensions(headers,
                          requested_extensions_,
                          &extensions_,
-                         &failure_message_)) {
+                         &failure_message_,
+                         extension_params_.get())) {
     return OK;
   }
   failure_message_ = "Error during WebSocket handshake: " + failure_message_;