1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/websockets/websocket_basic_handshake_stream.h"
12 #include "base/base64.h"
13 #include "base/basictypes.h"
14 #include "base/bind.h"
15 #include "base/containers/hash_tables.h"
16 #include "base/stl_util.h"
17 #include "base/strings/string_util.h"
18 #include "base/strings/stringprintf.h"
19 #include "base/time/time.h"
20 #include "crypto/random.h"
21 #include "net/http/http_request_headers.h"
22 #include "net/http/http_request_info.h"
23 #include "net/http/http_response_body_drainer.h"
24 #include "net/http/http_response_headers.h"
25 #include "net/http/http_status_code.h"
26 #include "net/http/http_stream_parser.h"
27 #include "net/socket/client_socket_handle.h"
28 #include "net/websockets/websocket_basic_stream.h"
29 #include "net/websockets/websocket_extension_parser.h"
30 #include "net/websockets/websocket_handshake_constants.h"
31 #include "net/websockets/websocket_handshake_handler.h"
32 #include "net/websockets/websocket_handshake_request_info.h"
33 #include "net/websockets/websocket_handshake_response_info.h"
34 #include "net/websockets/websocket_stream.h"
39 enum GetHeaderResult {
45 std::string MissingHeaderMessage(const std::string& header_name) {
46 return std::string("'") + header_name + "' header is missing";
49 std::string MultipleHeaderValuesMessage(const std::string& header_name) {
53 "' header must not appear more than once in a response";
56 std::string GenerateHandshakeChallenge() {
57 std::string raw_challenge(websockets::kRawChallengeLength, '\0');
58 crypto::RandBytes(string_as_array(&raw_challenge), raw_challenge.length());
59 std::string encoded_challenge;
60 base::Base64Encode(raw_challenge, &encoded_challenge);
61 return encoded_challenge;
64 void AddVectorHeaderIfNonEmpty(const char* name,
65 const std::vector<std::string>& value,
66 HttpRequestHeaders* headers) {
69 headers->SetHeader(name, JoinString(value, ", "));
72 GetHeaderResult GetSingleHeaderValue(const HttpResponseHeaders* headers,
73 const base::StringPiece& name,
76 size_t num_values = 0;
77 std::string temp_value;
78 while (headers->EnumerateHeader(&state, name, &temp_value)) {
80 return GET_HEADER_MULTIPLE;
83 return num_values > 0 ? GET_HEADER_OK : GET_HEADER_MISSING;
86 bool ValidateHeaderHasSingleValue(GetHeaderResult result,
87 const std::string& header_name,
88 std::string* failure_message) {
89 if (result == GET_HEADER_MISSING) {
90 *failure_message = MissingHeaderMessage(header_name);
93 if (result == GET_HEADER_MULTIPLE) {
94 *failure_message = MultipleHeaderValuesMessage(header_name);
97 DCHECK_EQ(result, GET_HEADER_OK);
101 bool ValidateUpgrade(const HttpResponseHeaders* headers,
102 std::string* failure_message) {
104 GetHeaderResult result =
105 GetSingleHeaderValue(headers, websockets::kUpgrade, &value);
106 if (!ValidateHeaderHasSingleValue(result,
107 websockets::kUpgrade,
112 if (!LowerCaseEqualsASCII(value, websockets::kWebSocketLowercase)) {
114 "'Upgrade' header value is not 'WebSocket': " + value;
120 bool ValidateSecWebSocketAccept(const HttpResponseHeaders* headers,
121 const std::string& expected,
122 std::string* failure_message) {
124 GetHeaderResult result =
125 GetSingleHeaderValue(headers, websockets::kSecWebSocketAccept, &actual);
126 if (!ValidateHeaderHasSingleValue(result,
127 websockets::kSecWebSocketAccept,
132 if (expected != actual) {
133 *failure_message = "Incorrect 'Sec-WebSocket-Accept' header value";
139 bool ValidateConnection(const HttpResponseHeaders* headers,
140 std::string* failure_message) {
141 // Connection header is permitted to contain other tokens.
142 if (!headers->HasHeader(HttpRequestHeaders::kConnection)) {
143 *failure_message = MissingHeaderMessage(HttpRequestHeaders::kConnection);
146 if (!headers->HasHeaderValue(HttpRequestHeaders::kConnection,
147 websockets::kUpgrade)) {
148 *failure_message = "'Connection' header value must contain 'Upgrade'";
154 bool ValidateSubProtocol(
155 const HttpResponseHeaders* headers,
156 const std::vector<std::string>& requested_sub_protocols,
157 std::string* sub_protocol,
158 std::string* failure_message) {
161 base::hash_set<std::string> requested_set(requested_sub_protocols.begin(),
162 requested_sub_protocols.end());
164 bool has_multiple_protocols = false;
165 bool has_invalid_protocol = false;
167 while (!has_invalid_protocol || !has_multiple_protocols) {
168 std::string temp_value;
169 if (!headers->EnumerateHeader(
170 &state, websockets::kSecWebSocketProtocol, &temp_value))
173 if (requested_set.count(value) == 0)
174 has_invalid_protocol = true;
176 has_multiple_protocols = true;
179 if (has_multiple_protocols) {
181 MultipleHeaderValuesMessage(websockets::kSecWebSocketProtocol);
183 } else if (count > 0 && requested_sub_protocols.size() == 0) {
185 std::string("Response must not include 'Sec-WebSocket-Protocol' "
186 "header if not present in request: ")
189 } else if (has_invalid_protocol) {
191 "'Sec-WebSocket-Protocol' header value '" +
193 "' in response does not match any of sent values";
195 } else if (requested_sub_protocols.size() > 0 && count == 0) {
197 "Sent non-empty 'Sec-WebSocket-Protocol' header "
198 "but no response was received";
201 *sub_protocol = value;
205 bool ValidateExtensions(const HttpResponseHeaders* headers,
206 const std::vector<std::string>& requested_extensions,
207 std::string* extensions,
208 std::string* failure_message) {
211 while (headers->EnumerateHeader(
212 &state, websockets::kSecWebSocketExtensions, &value)) {
213 WebSocketExtensionParser parser;
215 if (parser.has_error()) {
216 // TODO(yhirano) Set appropriate failure message.
218 "'Sec-WebSocket-Extensions' header value is "
219 "rejected by the parser: " +
223 // TODO(ricea): Accept permessage-deflate with valid parameters.
225 "Found an unsupported extension '" +
226 parser.extension().name() +
227 "' in 'Sec-WebSocket-Extensions' header";
235 WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream(
236 scoped_ptr<ClientSocketHandle> connection,
237 WebSocketStream::ConnectDelegate* connect_delegate,
239 std::vector<std::string> requested_sub_protocols,
240 std::vector<std::string> requested_extensions)
241 : state_(connection.release(), using_proxy),
242 connect_delegate_(connect_delegate),
243 http_response_info_(NULL),
244 requested_sub_protocols_(requested_sub_protocols),
245 requested_extensions_(requested_extensions) {}
247 WebSocketBasicHandshakeStream::~WebSocketBasicHandshakeStream() {}
249 int WebSocketBasicHandshakeStream::InitializeStream(
250 const HttpRequestInfo* request_info,
251 RequestPriority priority,
252 const BoundNetLog& net_log,
253 const CompletionCallback& callback) {
254 url_ = request_info->url;
255 state_.Initialize(request_info, priority, net_log, callback);
259 int WebSocketBasicHandshakeStream::SendRequest(
260 const HttpRequestHeaders& headers,
261 HttpResponseInfo* response,
262 const CompletionCallback& callback) {
263 DCHECK(!headers.HasHeader(websockets::kSecWebSocketKey));
264 DCHECK(!headers.HasHeader(websockets::kSecWebSocketProtocol));
265 DCHECK(!headers.HasHeader(websockets::kSecWebSocketExtensions));
266 DCHECK(headers.HasHeader(HttpRequestHeaders::kOrigin));
267 DCHECK(headers.HasHeader(websockets::kUpgrade));
268 DCHECK(headers.HasHeader(HttpRequestHeaders::kConnection));
269 DCHECK(headers.HasHeader(websockets::kSecWebSocketVersion));
272 http_response_info_ = response;
274 // Create a copy of the headers object, so that we can add the
// Sec-WebSocket-Key header.
276 HttpRequestHeaders enriched_headers;
277 enriched_headers.CopyFrom(headers);
278 std::string handshake_challenge;
279 if (handshake_challenge_for_testing_) {
280 handshake_challenge = *handshake_challenge_for_testing_;
281 handshake_challenge_for_testing_.reset();
283 handshake_challenge = GenerateHandshakeChallenge();
285 enriched_headers.SetHeader(websockets::kSecWebSocketKey, handshake_challenge);
287 AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol,
288 requested_sub_protocols_,
290 AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions,
291 requested_extensions_,
294 ComputeSecWebSocketAccept(handshake_challenge,
295 &handshake_challenge_response_);
297 DCHECK(connect_delegate_);
298 scoped_ptr<WebSocketHandshakeRequestInfo> request(
299 new WebSocketHandshakeRequestInfo(url_, base::Time::Now()));
300 request->headers.CopyFrom(enriched_headers);
301 connect_delegate_->OnStartOpeningHandshake(request.Pass());
303 return parser()->SendRequest(
304 state_.GenerateRequestLine(), enriched_headers, response, callback);
307 int WebSocketBasicHandshakeStream::ReadResponseHeaders(
308 const CompletionCallback& callback) {
309 // HttpStreamParser uses a weak pointer when reading from the
310 // socket, so it won't be called back after being destroyed. The
311 // HttpStreamParser is owned by HttpBasicState which is owned by this object,
312 // so this use of base::Unretained() is safe.
313 int rv = parser()->ReadResponseHeaders(
314 base::Bind(&WebSocketBasicHandshakeStream::ReadResponseHeadersCallback,
315 base::Unretained(this),
317 if (rv == ERR_IO_PENDING)
320 return ValidateResponse();
321 OnFinishOpeningHandshake();
325 const HttpResponseInfo* WebSocketBasicHandshakeStream::GetResponseInfo() const {
326 return parser()->GetResponseInfo();
329 int WebSocketBasicHandshakeStream::ReadResponseBody(
332 const CompletionCallback& callback) {
333 return parser()->ReadResponseBody(buf, buf_len, callback);
336 void WebSocketBasicHandshakeStream::Close(bool not_reusable) {
337 // This class ignores the value of |not_reusable| and never lets the socket be
340 parser()->Close(true);
343 bool WebSocketBasicHandshakeStream::IsResponseBodyComplete() const {
344 return parser()->IsResponseBodyComplete();
347 bool WebSocketBasicHandshakeStream::CanFindEndOfResponse() const {
348 return parser() && parser()->CanFindEndOfResponse();
351 bool WebSocketBasicHandshakeStream::IsConnectionReused() const {
352 return parser()->IsConnectionReused();
355 void WebSocketBasicHandshakeStream::SetConnectionReused() {
356 parser()->SetConnectionReused();
359 bool WebSocketBasicHandshakeStream::IsConnectionReusable() const {
363 int64 WebSocketBasicHandshakeStream::GetTotalReceivedBytes() const {
367 bool WebSocketBasicHandshakeStream::GetLoadTimingInfo(
368 LoadTimingInfo* load_timing_info) const {
369 return state_.connection()->GetLoadTimingInfo(IsConnectionReused(),
373 void WebSocketBasicHandshakeStream::GetSSLInfo(SSLInfo* ssl_info) {
374 parser()->GetSSLInfo(ssl_info);
377 void WebSocketBasicHandshakeStream::GetSSLCertRequestInfo(
378 SSLCertRequestInfo* cert_request_info) {
379 parser()->GetSSLCertRequestInfo(cert_request_info);
382 bool WebSocketBasicHandshakeStream::IsSpdyHttpStream() const { return false; }
384 void WebSocketBasicHandshakeStream::Drain(HttpNetworkSession* session) {
385 HttpResponseBodyDrainer* drainer = new HttpResponseBodyDrainer(this);
386 drainer->Start(session);
387 // |drainer| will delete itself.
390 void WebSocketBasicHandshakeStream::SetPriority(RequestPriority priority) {
391 // TODO(ricea): See TODO comment in HttpBasicStream::SetPriority(). If it is
392 // gone, then copy whatever has happened there over here.
395 scoped_ptr<WebSocketStream> WebSocketBasicHandshakeStream::Upgrade() {
396 // TODO(ricea): Add deflate support.
398 // The HttpStreamParser object has a pointer to our ClientSocketHandle. Make
399 // sure it does not touch it again before it is destroyed.
400 state_.DeleteParser();
401 return scoped_ptr<WebSocketStream>(
402 new WebSocketBasicStream(state_.ReleaseConnection(),
408 void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting(
409 const std::string& key) {
410 handshake_challenge_for_testing_.reset(new std::string(key));
413 std::string WebSocketBasicHandshakeStream::GetFailureMessage() const {
414 return failure_message_;
417 void WebSocketBasicHandshakeStream::ReadResponseHeadersCallback(
418 const CompletionCallback& callback,
421 result = ValidateResponse();
423 OnFinishOpeningHandshake();
424 callback.Run(result);
427 void WebSocketBasicHandshakeStream::OnFinishOpeningHandshake() {
428 DCHECK(connect_delegate_);
429 DCHECK(http_response_info_);
430 scoped_refptr<HttpResponseHeaders> headers = http_response_info_->headers;
431 scoped_ptr<WebSocketHandshakeResponseInfo> response(
432 new WebSocketHandshakeResponseInfo(url_,
433 headers->response_code(),
434 headers->GetStatusText(),
436 http_response_info_->response_time));
437 connect_delegate_->OnFinishOpeningHandshake(response.Pass());
440 int WebSocketBasicHandshakeStream::ValidateResponse() {
441 DCHECK(http_response_info_);
442 const scoped_refptr<HttpResponseHeaders>& headers =
443 http_response_info_->headers;
445 switch (headers->response_code()) {
446 case HTTP_SWITCHING_PROTOCOLS:
447 OnFinishOpeningHandshake();
448 return ValidateUpgradeResponse(headers);
450 // We need to pass these through for authentication to work.
451 case HTTP_UNAUTHORIZED:
452 case HTTP_PROXY_AUTHENTICATION_REQUIRED:
455 // Other status codes are potentially risky (see the warnings in the
456 // WHATWG WebSocket API spec) and so are dropped by default.
458 failure_message_ = base::StringPrintf("Unexpected status code: %d",
459 headers->response_code());
460 OnFinishOpeningHandshake();
461 return ERR_INVALID_RESPONSE;
465 int WebSocketBasicHandshakeStream::ValidateUpgradeResponse(
466 const scoped_refptr<HttpResponseHeaders>& headers) {
467 if (ValidateUpgrade(headers.get(), &failure_message_) &&
468 ValidateSecWebSocketAccept(headers.get(),
469 handshake_challenge_response_,
470 &failure_message_) &&
471 ValidateConnection(headers.get(), &failure_message_) &&
472 ValidateSubProtocol(headers.get(),
473 requested_sub_protocols_,
475 &failure_message_) &&
476 ValidateExtensions(headers.get(),
477 requested_extensions_,
479 &failure_message_)) {
482 failure_message_ = "Error during WebSocket handshake: " + failure_message_;
483 return ERR_INVALID_RESPONSE;