Upstream version 5.34.104.0
[platform/framework/web/crosswalk.git] / src / net / socket / ssl_client_socket_unittest.cc
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/ssl_client_socket.h"
6
7 #include "base/callback_helpers.h"
8 #include "base/memory/ref_counted.h"
9 #include "net/base/address_list.h"
10 #include "net/base/io_buffer.h"
11 #include "net/base/net_errors.h"
12 #include "net/base/net_log.h"
13 #include "net/base/net_log_unittest.h"
14 #include "net/base/test_completion_callback.h"
15 #include "net/base/test_data_directory.h"
16 #include "net/cert/mock_cert_verifier.h"
17 #include "net/cert/test_root_certs.h"
18 #include "net/dns/host_resolver.h"
19 #include "net/http/transport_security_state.h"
20 #include "net/socket/client_socket_factory.h"
21 #include "net/socket/client_socket_handle.h"
22 #include "net/socket/socket_test_util.h"
23 #include "net/socket/tcp_client_socket.h"
24 #include "net/ssl/ssl_cert_request_info.h"
25 #include "net/ssl/ssl_config_service.h"
26 #include "net/test/cert_test_util.h"
27 #include "net/test/spawned_test_server/spawned_test_server.h"
28 #include "testing/gtest/include/gtest/gtest.h"
29 #include "testing/platform_test.h"
30
31 //-----------------------------------------------------------------------------
32
33 namespace net {
34
35 namespace {
36
// Default-constructed SSLConfig, used by tests that need no special SSL
// settings (tests that do, copy it or build their own SSLConfig).
const SSLConfig kDefaultSSLConfig;
38
39 // WrappedStreamSocket is a base class that wraps an existing StreamSocket,
40 // forwarding the Socket and StreamSocket interfaces to the underlying
41 // transport.
42 // This is to provide a common base class for subclasses to override specific
43 // StreamSocket methods for testing, while still communicating with a 'real'
44 // StreamSocket.
45 class WrappedStreamSocket : public StreamSocket {
46  public:
47   explicit WrappedStreamSocket(scoped_ptr<StreamSocket> transport)
48       : transport_(transport.Pass()) {}
49   virtual ~WrappedStreamSocket() {}
50
51   // StreamSocket implementation:
52   virtual int Connect(const CompletionCallback& callback) OVERRIDE {
53     return transport_->Connect(callback);
54   }
55   virtual void Disconnect() OVERRIDE { transport_->Disconnect(); }
56   virtual bool IsConnected() const OVERRIDE {
57     return transport_->IsConnected();
58   }
59   virtual bool IsConnectedAndIdle() const OVERRIDE {
60     return transport_->IsConnectedAndIdle();
61   }
62   virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE {
63     return transport_->GetPeerAddress(address);
64   }
65   virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE {
66     return transport_->GetLocalAddress(address);
67   }
68   virtual const BoundNetLog& NetLog() const OVERRIDE {
69     return transport_->NetLog();
70   }
71   virtual void SetSubresourceSpeculation() OVERRIDE {
72     transport_->SetSubresourceSpeculation();
73   }
74   virtual void SetOmniboxSpeculation() OVERRIDE {
75     transport_->SetOmniboxSpeculation();
76   }
77   virtual bool WasEverUsed() const OVERRIDE {
78     return transport_->WasEverUsed();
79   }
80   virtual bool UsingTCPFastOpen() const OVERRIDE {
81     return transport_->UsingTCPFastOpen();
82   }
83   virtual bool WasNpnNegotiated() const OVERRIDE {
84     return transport_->WasNpnNegotiated();
85   }
86   virtual NextProto GetNegotiatedProtocol() const OVERRIDE {
87     return transport_->GetNegotiatedProtocol();
88   }
89   virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE {
90     return transport_->GetSSLInfo(ssl_info);
91   }
92
93   // Socket implementation:
94   virtual int Read(IOBuffer* buf,
95                    int buf_len,
96                    const CompletionCallback& callback) OVERRIDE {
97     return transport_->Read(buf, buf_len, callback);
98   }
99   virtual int Write(IOBuffer* buf,
100                     int buf_len,
101                     const CompletionCallback& callback) OVERRIDE {
102     return transport_->Write(buf, buf_len, callback);
103   }
104   virtual bool SetReceiveBufferSize(int32 size) OVERRIDE {
105     return transport_->SetReceiveBufferSize(size);
106   }
107   virtual bool SetSendBufferSize(int32 size) OVERRIDE {
108     return transport_->SetSendBufferSize(size);
109   }
110
111  protected:
112   scoped_ptr<StreamSocket> transport_;
113 };
114
115 // ReadBufferingStreamSocket is a wrapper for an existing StreamSocket that
116 // will ensure a certain amount of data is internally buffered before
117 // satisfying a Read() request. It exists to mimic OS-level internal
118 // buffering, but in a way to guarantee that X number of bytes will be
119 // returned to callers of Read(), regardless of how quickly the OS receives
120 // them from the TestServer.
121 class ReadBufferingStreamSocket : public WrappedStreamSocket {
122  public:
123   explicit ReadBufferingStreamSocket(scoped_ptr<StreamSocket> transport);
124   virtual ~ReadBufferingStreamSocket() {}
125
126   // Socket implementation:
127   virtual int Read(IOBuffer* buf,
128                    int buf_len,
129                    const CompletionCallback& callback) OVERRIDE;
130
131   // Sets the internal buffer to |size|. This must not be greater than
132   // the largest value supplied to Read() - that is, it does not handle
133   // having "leftovers" at the end of Read().
134   // Each call to Read() will be prevented from completion until at least
135   // |size| data has been read.
136   // Set to 0 to turn off buffering, causing Read() to transparently
137   // read via the underlying transport.
138   void SetBufferSize(int size);
139
140  private:
141   enum State {
142     STATE_NONE,
143     STATE_READ,
144     STATE_READ_COMPLETE,
145   };
146
147   int DoLoop(int result);
148   int DoRead();
149   int DoReadComplete(int result);
150   void OnReadCompleted(int result);
151
152   State state_;
153   scoped_refptr<GrowableIOBuffer> read_buffer_;
154   int buffer_size_;
155
156   scoped_refptr<IOBuffer> user_read_buf_;
157   CompletionCallback user_read_callback_;
158 };
159
160 ReadBufferingStreamSocket::ReadBufferingStreamSocket(
161     scoped_ptr<StreamSocket> transport)
162     : WrappedStreamSocket(transport.Pass()),
163       read_buffer_(new GrowableIOBuffer()),
164       buffer_size_(0) {}
165
166 void ReadBufferingStreamSocket::SetBufferSize(int size) {
167   DCHECK(!user_read_buf_.get());
168   buffer_size_ = size;
169   read_buffer_->SetCapacity(size);
170 }
171
172 int ReadBufferingStreamSocket::Read(IOBuffer* buf,
173                                     int buf_len,
174                                     const CompletionCallback& callback) {
175   if (buffer_size_ == 0)
176     return transport_->Read(buf, buf_len, callback);
177
178   if (buf_len < buffer_size_)
179     return ERR_UNEXPECTED;
180
181   state_ = STATE_READ;
182   user_read_buf_ = buf;
183   int result = DoLoop(OK);
184   if (result == ERR_IO_PENDING)
185     user_read_callback_ = callback;
186   else
187     user_read_buf_ = NULL;
188   return result;
189 }
190
191 int ReadBufferingStreamSocket::DoLoop(int result) {
192   int rv = result;
193   do {
194     State current_state = state_;
195     state_ = STATE_NONE;
196     switch (current_state) {
197       case STATE_READ:
198         rv = DoRead();
199         break;
200       case STATE_READ_COMPLETE:
201         rv = DoReadComplete(rv);
202         break;
203       case STATE_NONE:
204       default:
205         NOTREACHED() << "Unexpected state: " << current_state;
206         rv = ERR_UNEXPECTED;
207         break;
208     }
209   } while (rv != ERR_IO_PENDING && state_ != STATE_NONE);
210   return rv;
211 }
212
213 int ReadBufferingStreamSocket::DoRead() {
214   state_ = STATE_READ_COMPLETE;
215   int rv =
216       transport_->Read(read_buffer_.get(),
217                        read_buffer_->RemainingCapacity(),
218                        base::Bind(&ReadBufferingStreamSocket::OnReadCompleted,
219                                   base::Unretained(this)));
220   return rv;
221 }
222
223 int ReadBufferingStreamSocket::DoReadComplete(int result) {
224   state_ = STATE_NONE;
225   if (result <= 0)
226     return result;
227
228   read_buffer_->set_offset(read_buffer_->offset() + result);
229   if (read_buffer_->RemainingCapacity() > 0) {
230     state_ = STATE_READ;
231     return OK;
232   }
233
234   memcpy(user_read_buf_->data(),
235          read_buffer_->StartOfBuffer(),
236          read_buffer_->capacity());
237   read_buffer_->set_offset(0);
238   return read_buffer_->capacity();
239 }
240
241 void ReadBufferingStreamSocket::OnReadCompleted(int result) {
242   result = DoLoop(result);
243   if (result == ERR_IO_PENDING)
244     return;
245
246   user_read_buf_ = NULL;
247   base::ResetAndReturn(&user_read_callback_).Run(result);
248 }
249
250 // Simulates synchronously receiving an error during Read() or Write()
251 class SynchronousErrorStreamSocket : public WrappedStreamSocket {
252  public:
253   explicit SynchronousErrorStreamSocket(scoped_ptr<StreamSocket> transport);
254   virtual ~SynchronousErrorStreamSocket() {}
255
256   // Socket implementation:
257   virtual int Read(IOBuffer* buf,
258                    int buf_len,
259                    const CompletionCallback& callback) OVERRIDE;
260   virtual int Write(IOBuffer* buf,
261                     int buf_len,
262                     const CompletionCallback& callback) OVERRIDE;
263
264   // Sets the next Read() call and all future calls to return |error|.
265   // If there is already a pending asynchronous read, the configured error
266   // will not be returned until that asynchronous read has completed and Read()
267   // is called again.
268   void SetNextReadError(Error error) {
269     DCHECK_GE(0, error);
270     have_read_error_ = true;
271     pending_read_error_ = error;
272   }
273
274   // Sets the next Write() call and all future calls to return |error|.
275   // If there is already a pending asynchronous write, the configured error
276   // will not be returned until that asynchronous write has completed and
277   // Write() is called again.
278   void SetNextWriteError(Error error) {
279     DCHECK_GE(0, error);
280     have_write_error_ = true;
281     pending_write_error_ = error;
282   }
283
284  private:
285   bool have_read_error_;
286   int pending_read_error_;
287
288   bool have_write_error_;
289   int pending_write_error_;
290
291   DISALLOW_COPY_AND_ASSIGN(SynchronousErrorStreamSocket);
292 };
293
// Takes ownership of |transport|. Neither reads nor writes are armed to fail
// until SetNextReadError()/SetNextWriteError() is called.
SynchronousErrorStreamSocket::SynchronousErrorStreamSocket(
    scoped_ptr<StreamSocket> transport)
    : WrappedStreamSocket(transport.Pass()),
      have_read_error_(false),
      pending_read_error_(OK),
      have_write_error_(false),
      pending_write_error_(OK) {}
301
302 int SynchronousErrorStreamSocket::Read(IOBuffer* buf,
303                                        int buf_len,
304                                        const CompletionCallback& callback) {
305   if (have_read_error_)
306     return pending_read_error_;
307   return transport_->Read(buf, buf_len, callback);
308 }
309
310 int SynchronousErrorStreamSocket::Write(IOBuffer* buf,
311                                         int buf_len,
312                                         const CompletionCallback& callback) {
313   if (have_write_error_)
314     return pending_write_error_;
315   return transport_->Write(buf, buf_len, callback);
316 }
317
318 // FakeBlockingStreamSocket wraps an existing StreamSocket and simulates the
319 // underlying transport needing to complete things asynchronously in a
320 // deterministic manner (e.g.: independent of the TestServer and the OS's
321 // semantics).
322 class FakeBlockingStreamSocket : public WrappedStreamSocket {
323  public:
324   explicit FakeBlockingStreamSocket(scoped_ptr<StreamSocket> transport);
325   virtual ~FakeBlockingStreamSocket() {}
326
327   // Socket implementation:
328   virtual int Read(IOBuffer* buf,
329                    int buf_len,
330                    const CompletionCallback& callback) OVERRIDE {
331     return read_state_.RunWrappedFunction(buf, buf_len, callback);
332   }
333   virtual int Write(IOBuffer* buf,
334                     int buf_len,
335                     const CompletionCallback& callback) OVERRIDE {
336     return write_state_.RunWrappedFunction(buf, buf_len, callback);
337   }
338
339   // Causes the next call to Read() to return ERR_IO_PENDING, not completing
340   // (invoking the callback) until UnblockRead() has been called and the
341   // underlying transport has completed.
342   void SetNextReadShouldBlock() { read_state_.SetShouldBlock(); }
343   void UnblockRead() { read_state_.Unblock(); }
344
345   // Causes the next call to Write() to return ERR_IO_PENDING, not completing
346   // (invoking the callback) until UnblockWrite() has been called and the
347   // underlying transport has completed.
348   void SetNextWriteShouldBlock() { write_state_.SetShouldBlock(); }
349   void UnblockWrite() { write_state_.Unblock(); }
350
351  private:
352   // Tracks the state for simulating a blocking Read/Write operation.
353   class BlockingState {
354    public:
355     // Wrapper for the underlying Socket function to call (ie: Read/Write).
356     typedef base::Callback<int(IOBuffer*, int, const CompletionCallback&)>
357         WrappedSocketFunction;
358
359     explicit BlockingState(const WrappedSocketFunction& function);
360     ~BlockingState() {}
361
362     // Sets the next call to RunWrappedFunction() to block, returning
363     // ERR_IO_PENDING and not invoking the user callback until Unblock() is
364     // called.
365     void SetShouldBlock();
366
367     // Unblocks the currently blocked pending function, invoking the user
368     // callback if the results are immediately available.
369     // Note: It's not valid to call this unless SetShouldBlock() has been
370     // called beforehand.
371     void Unblock();
372
373     // Performs the wrapped socket function on the underlying transport. If
374     // configured to block via SetShouldBlock(), then |user_callback| will not
375     // be invoked until Unblock() has been called.
376     int RunWrappedFunction(IOBuffer* buf,
377                            int len,
378                            const CompletionCallback& user_callback);
379
380    private:
381     // Handles completion from the underlying wrapped socket function.
382     void OnCompleted(int result);
383
384     WrappedSocketFunction wrapped_function_;
385     bool should_block_;
386     bool have_result_;
387     int pending_result_;
388     CompletionCallback user_callback_;
389   };
390
391   BlockingState read_state_;
392   BlockingState write_state_;
393
394   DISALLOW_COPY_AND_ASSIGN(FakeBlockingStreamSocket);
395 };
396
// Takes ownership of |transport| and binds the read/write blocking states to
// its Socket::Read and Socket::Write.
// Reading |transport_| in the member initializers below is safe only because
// the WrappedStreamSocket base (which takes ownership of |transport|) is
// constructed before read_state_/write_state_. base::Unretained is used
// because |transport_| is a member of the base class, so it outlives both
// BlockingState members (members are destroyed before the base).
FakeBlockingStreamSocket::FakeBlockingStreamSocket(
    scoped_ptr<StreamSocket> transport)
    : WrappedStreamSocket(transport.Pass()),
      read_state_(base::Bind(&Socket::Read,
                             base::Unretained(transport_.get()))),
      write_state_(base::Bind(&Socket::Write,
                              base::Unretained(transport_.get()))) {}
404
// Wraps |function|. Starts unblocked, with no stored result.
FakeBlockingStreamSocket::BlockingState::BlockingState(
    const WrappedSocketFunction& function)
    : wrapped_function_(function),
      should_block_(false),
      have_result_(false),
      pending_result_(OK) {}
411
// Arms blocking for the next RunWrappedFunction(). Blocking must not already
// be armed (there is no nesting); the DCHECK enforces this in debug builds.
void FakeBlockingStreamSocket::BlockingState::SetShouldBlock() {
  DCHECK(!should_block_);
  should_block_ = true;
}
416
417 void FakeBlockingStreamSocket::BlockingState::Unblock() {
418   DCHECK(should_block_);
419   should_block_ = false;
420
421   // If the operation is still pending in the underlying transport, immediately
422   // return - OnCompleted() will handle invoking the callback once the transport
423   // has completed.
424   if (!have_result_)
425     return;
426
427   have_result_ = false;
428
429   base::ResetAndReturn(&user_callback_).Run(pending_result_);
430 }
431
432 int FakeBlockingStreamSocket::BlockingState::RunWrappedFunction(
433     IOBuffer* buf,
434     int len,
435     const CompletionCallback& callback) {
436
437   // The callback to be called by the underlying transport. Either forward
438   // directly to the user's callback if not set to block, or intercept it with
439   // OnCompleted so that the user's callback is not invoked until Unblock() is
440   // called.
441   CompletionCallback transport_callback =
442       !should_block_ ? callback : base::Bind(&BlockingState::OnCompleted,
443                                              base::Unretained(this));
444   int rv = wrapped_function_.Run(buf, len, transport_callback);
445   if (should_block_) {
446     user_callback_ = callback;
447     // May have completed synchronously.
448     have_result_ = (rv != ERR_IO_PENDING);
449     pending_result_ = rv;
450     return ERR_IO_PENDING;
451   }
452
453   return rv;
454 }
455
456 void FakeBlockingStreamSocket::BlockingState::OnCompleted(int result) {
457   if (should_block_) {
458     // Store the result so that the callback can be invoked once Unblock() is
459     // called.
460     have_result_ = true;
461     pending_result_ = result;
462     return;
463   }
464
465   // Otherwise, the Unblock() function was called before the underlying
466   // transport completed, so run the user's callback immediately.
467   base::ResetAndReturn(&user_callback_).Run(result);
468 }
469
470 // CompletionCallback that will delete the associated StreamSocket when
471 // the callback is invoked.
472 class DeleteSocketCallback : public TestCompletionCallbackBase {
473  public:
474   explicit DeleteSocketCallback(StreamSocket* socket)
475       : socket_(socket),
476         callback_(base::Bind(&DeleteSocketCallback::OnComplete,
477                              base::Unretained(this))) {}
478   virtual ~DeleteSocketCallback() {}
479
480   const CompletionCallback& callback() const { return callback_; }
481
482  private:
483   void OnComplete(int result) {
484     if (socket_) {
485       delete socket_;
486       socket_ = NULL;
487     } else {
488       ADD_FAILURE() << "Deleting socket twice";
489     }
490     SetResult(result);
491   }
492
493   StreamSocket* socket_;
494   CompletionCallback callback_;
495
496   DISALLOW_COPY_AND_ASSIGN(DeleteSocketCallback);
497 };
498
499 class SSLClientSocketTest : public PlatformTest {
500  public:
501   SSLClientSocketTest()
502       : socket_factory_(ClientSocketFactory::GetDefaultFactory()),
503         cert_verifier_(new MockCertVerifier),
504         transport_security_state_(new TransportSecurityState) {
505     cert_verifier_->set_default_result(OK);
506     context_.cert_verifier = cert_verifier_.get();
507     context_.transport_security_state = transport_security_state_.get();
508   }
509
510  protected:
511   scoped_ptr<SSLClientSocket> CreateSSLClientSocket(
512       scoped_ptr<StreamSocket> transport_socket,
513       const HostPortPair& host_and_port,
514       const SSLConfig& ssl_config) {
515     scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle);
516     connection->SetSocket(transport_socket.Pass());
517     return socket_factory_->CreateSSLClientSocket(
518         connection.Pass(), host_and_port, ssl_config, context_);
519   }
520
521   ClientSocketFactory* socket_factory_;
522   scoped_ptr<MockCertVerifier> cert_verifier_;
523   scoped_ptr<TransportSecurityState> transport_security_state_;
524   SSLClientSocketContext context_;
525 };
526
527 //-----------------------------------------------------------------------------
528
529 // LogContainsSSLConnectEndEvent returns true if the given index in the given
530 // log is an SSL connect end event. The NSS sockets will cork in an attempt to
531 // merge the first application data record with the Finished message when false
532 // starting. However, in order to avoid the server timing out the handshake,
533 // they'll give up waiting for application data and send the Finished after a
534 // timeout. This means that an SSL connect end event may appear as a socket
535 // write.
536 static bool LogContainsSSLConnectEndEvent(
537     const CapturingNetLog::CapturedEntryList& log,
538     int i) {
539   return LogContainsEndEvent(log, i, NetLog::TYPE_SSL_CONNECT) ||
540          LogContainsEvent(
541              log, i, NetLog::TYPE_SOCKET_BYTES_SENT, NetLog::PHASE_NONE);
542 }
543
544 TEST_F(SSLClientSocketTest, Connect) {
545   SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
546                                 SpawnedTestServer::kLocalhost,
547                                 base::FilePath());
548   ASSERT_TRUE(test_server.Start());
549
550   AddressList addr;
551   ASSERT_TRUE(test_server.GetAddressList(&addr));
552
553   TestCompletionCallback callback;
554   CapturingNetLog log;
555   scoped_ptr<StreamSocket> transport(
556       new TCPClientSocket(addr, &log, NetLog::Source()));
557   int rv = transport->Connect(callback.callback());
558   if (rv == ERR_IO_PENDING)
559     rv = callback.WaitForResult();
560   EXPECT_EQ(OK, rv);
561
562   scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
563       transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
564
565   EXPECT_FALSE(sock->IsConnected());
566
567   rv = sock->Connect(callback.callback());
568
569   CapturingNetLog::CapturedEntryList entries;
570   log.GetEntries(&entries);
571   EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
572   if (rv == ERR_IO_PENDING)
573     rv = callback.WaitForResult();
574   EXPECT_EQ(OK, rv);
575   EXPECT_TRUE(sock->IsConnected());
576   log.GetEntries(&entries);
577   EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1));
578
579   sock->Disconnect();
580   EXPECT_FALSE(sock->IsConnected());
581 }
582
583 TEST_F(SSLClientSocketTest, ConnectExpired) {
584   SpawnedTestServer::SSLOptions ssl_options(
585       SpawnedTestServer::SSLOptions::CERT_EXPIRED);
586   SpawnedTestServer test_server(
587       SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath());
588   ASSERT_TRUE(test_server.Start());
589
590   cert_verifier_->set_default_result(ERR_CERT_DATE_INVALID);
591
592   AddressList addr;
593   ASSERT_TRUE(test_server.GetAddressList(&addr));
594
595   TestCompletionCallback callback;
596   CapturingNetLog log;
597   scoped_ptr<StreamSocket> transport(
598       new TCPClientSocket(addr, &log, NetLog::Source()));
599   int rv = transport->Connect(callback.callback());
600   if (rv == ERR_IO_PENDING)
601     rv = callback.WaitForResult();
602   EXPECT_EQ(OK, rv);
603
604   scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
605       transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
606
607   EXPECT_FALSE(sock->IsConnected());
608
609   rv = sock->Connect(callback.callback());
610
611   CapturingNetLog::CapturedEntryList entries;
612   log.GetEntries(&entries);
613   EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
614   if (rv == ERR_IO_PENDING)
615     rv = callback.WaitForResult();
616
617   EXPECT_EQ(ERR_CERT_DATE_INVALID, rv);
618
619   // Rather than testing whether or not the underlying socket is connected,
620   // test that the handshake has finished. This is because it may be
621   // desirable to disconnect the socket before showing a user prompt, since
622   // the user may take indefinitely long to respond.
623   log.GetEntries(&entries);
624   EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1));
625 }
626
627 TEST_F(SSLClientSocketTest, ConnectMismatched) {
628   SpawnedTestServer::SSLOptions ssl_options(
629       SpawnedTestServer::SSLOptions::CERT_MISMATCHED_NAME);
630   SpawnedTestServer test_server(
631       SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath());
632   ASSERT_TRUE(test_server.Start());
633
634   cert_verifier_->set_default_result(ERR_CERT_COMMON_NAME_INVALID);
635
636   AddressList addr;
637   ASSERT_TRUE(test_server.GetAddressList(&addr));
638
639   TestCompletionCallback callback;
640   CapturingNetLog log;
641   scoped_ptr<StreamSocket> transport(
642       new TCPClientSocket(addr, &log, NetLog::Source()));
643   int rv = transport->Connect(callback.callback());
644   if (rv == ERR_IO_PENDING)
645     rv = callback.WaitForResult();
646   EXPECT_EQ(OK, rv);
647
648   scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
649       transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
650
651   EXPECT_FALSE(sock->IsConnected());
652
653   rv = sock->Connect(callback.callback());
654
655   CapturingNetLog::CapturedEntryList entries;
656   log.GetEntries(&entries);
657   EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
658   if (rv == ERR_IO_PENDING)
659     rv = callback.WaitForResult();
660
661   EXPECT_EQ(ERR_CERT_COMMON_NAME_INVALID, rv);
662
663   // Rather than testing whether or not the underlying socket is connected,
664   // test that the handshake has finished. This is because it may be
665   // desirable to disconnect the socket before showing a user prompt, since
666   // the user may take indefinitely long to respond.
667   log.GetEntries(&entries);
668   EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1));
669 }
670
671 // Attempt to connect to a page which requests a client certificate. It should
672 // return an error code on connect.
673 TEST_F(SSLClientSocketTest, ConnectClientAuthCertRequested) {
674   SpawnedTestServer::SSLOptions ssl_options;
675   ssl_options.request_client_certificate = true;
676   SpawnedTestServer test_server(
677       SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath());
678   ASSERT_TRUE(test_server.Start());
679
680   AddressList addr;
681   ASSERT_TRUE(test_server.GetAddressList(&addr));
682
683   TestCompletionCallback callback;
684   CapturingNetLog log;
685   scoped_ptr<StreamSocket> transport(
686       new TCPClientSocket(addr, &log, NetLog::Source()));
687   int rv = transport->Connect(callback.callback());
688   if (rv == ERR_IO_PENDING)
689     rv = callback.WaitForResult();
690   EXPECT_EQ(OK, rv);
691
692   scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
693       transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
694
695   EXPECT_FALSE(sock->IsConnected());
696
697   rv = sock->Connect(callback.callback());
698
699   CapturingNetLog::CapturedEntryList entries;
700   log.GetEntries(&entries);
701   EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
702   if (rv == ERR_IO_PENDING)
703     rv = callback.WaitForResult();
704
705   log.GetEntries(&entries);
706   // Because we prematurely kill the handshake at CertificateRequest,
707   // the server may still send data (notably the ServerHelloDone)
708   // after the error is returned. As a result, the SSL_CONNECT may not
709   // be the last entry. See http://crbug.com/54445. We use
710   // ExpectLogContainsSomewhere instead of
711   // LogContainsSSLConnectEndEvent to avoid assuming, e.g., only one
712   // extra read instead of two. This occurs before the handshake ends,
713   // so the corking logic of LogContainsSSLConnectEndEvent isn't
714   // necessary.
715   //
716   // TODO(davidben): When SSL_RestartHandshakeAfterCertReq in NSS is
717   // fixed and we can respond to the first CertificateRequest
718   // without closing the socket, add a unit test for sending the
719   // certificate. This test may still be useful as we'll want to close
720   // the socket on a timeout if the user takes a long time to pick a
721   // cert. Related bug: https://bugzilla.mozilla.org/show_bug.cgi?id=542832
722   ExpectLogContainsSomewhere(
723       entries, 0, NetLog::TYPE_SSL_CONNECT, NetLog::PHASE_END);
724   EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED, rv);
725   EXPECT_FALSE(sock->IsConnected());
726 }
727
728 // Connect to a server requesting optional client authentication. Send it a
729 // null certificate. It should allow the connection.
730 //
731 // TODO(davidben): Also test providing an actual certificate.
732 TEST_F(SSLClientSocketTest, ConnectClientAuthSendNullCert) {
733   SpawnedTestServer::SSLOptions ssl_options;
734   ssl_options.request_client_certificate = true;
735   SpawnedTestServer test_server(
736       SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath());
737   ASSERT_TRUE(test_server.Start());
738
739   AddressList addr;
740   ASSERT_TRUE(test_server.GetAddressList(&addr));
741
742   TestCompletionCallback callback;
743   CapturingNetLog log;
744   scoped_ptr<StreamSocket> transport(
745       new TCPClientSocket(addr, &log, NetLog::Source()));
746   int rv = transport->Connect(callback.callback());
747   if (rv == ERR_IO_PENDING)
748     rv = callback.WaitForResult();
749   EXPECT_EQ(OK, rv);
750
751   SSLConfig ssl_config = kDefaultSSLConfig;
752   ssl_config.send_client_cert = true;
753   ssl_config.client_cert = NULL;
754
755   scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
756       transport.Pass(), test_server.host_port_pair(), ssl_config));
757
758   EXPECT_FALSE(sock->IsConnected());
759
760   // Our test server accepts certificate-less connections.
761   // TODO(davidben): Add a test which requires them and verify the error.
762   rv = sock->Connect(callback.callback());
763
764   CapturingNetLog::CapturedEntryList entries;
765   log.GetEntries(&entries);
766   EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
767   if (rv == ERR_IO_PENDING)
768     rv = callback.WaitForResult();
769
770   EXPECT_EQ(OK, rv);
771   EXPECT_TRUE(sock->IsConnected());
772   log.GetEntries(&entries);
773   EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1));
774
775   // We responded to the server's certificate request with a Certificate
776   // message with no client certificate in it.  ssl_info.client_cert_sent
777   // should be false in this case.
778   SSLInfo ssl_info;
779   sock->GetSSLInfo(&ssl_info);
780   EXPECT_FALSE(ssl_info.client_cert_sent);
781
782   sock->Disconnect();
783   EXPECT_FALSE(sock->IsConnected());
784 }
785
786 // TODO(wtc): Add unit tests for IsConnectedAndIdle:
787 //   - Server closes an SSL connection (with a close_notify alert message).
788 //   - Server closes the underlying TCP connection directly.
789 //   - Server sends data unexpectedly.
790
791 TEST_F(SSLClientSocketTest, Read) {
792   SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
793                                 SpawnedTestServer::kLocalhost,
794                                 base::FilePath());
795   ASSERT_TRUE(test_server.Start());
796
797   AddressList addr;
798   ASSERT_TRUE(test_server.GetAddressList(&addr));
799
800   TestCompletionCallback callback;
801   scoped_ptr<StreamSocket> transport(
802       new TCPClientSocket(addr, NULL, NetLog::Source()));
803   int rv = transport->Connect(callback.callback());
804   if (rv == ERR_IO_PENDING)
805     rv = callback.WaitForResult();
806   EXPECT_EQ(OK, rv);
807
808   scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
809       transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
810
811   rv = sock->Connect(callback.callback());
812   if (rv == ERR_IO_PENDING)
813     rv = callback.WaitForResult();
814   EXPECT_EQ(OK, rv);
815   EXPECT_TRUE(sock->IsConnected());
816
817   const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
818   scoped_refptr<IOBuffer> request_buffer(
819       new IOBuffer(arraysize(request_text) - 1));
820   memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1);
821
822   rv = sock->Write(
823       request_buffer.get(), arraysize(request_text) - 1, callback.callback());
824   EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
825
826   if (rv == ERR_IO_PENDING)
827     rv = callback.WaitForResult();
828   EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv);
829
830   scoped_refptr<IOBuffer> buf(new IOBuffer(4096));
831   for (;;) {
832     rv = sock->Read(buf.get(), 4096, callback.callback());
833     EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
834
835     if (rv == ERR_IO_PENDING)
836       rv = callback.WaitForResult();
837
838     EXPECT_GE(rv, 0);
839     if (rv <= 0)
840       break;
841   }
842 }
843
844 // Tests that the SSLClientSocket properly handles when the underlying transport
845 // synchronously returns an error code - such as if an intermediary terminates
846 // the socket connection uncleanly.
847 // This is a regression test for http://crbug.com/238536
848 TEST_F(SSLClientSocketTest, Read_WithSynchronousError) {
849   SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
850                                 SpawnedTestServer::kLocalhost,
851                                 base::FilePath());
852   ASSERT_TRUE(test_server.Start());
853
854   AddressList addr;
855   ASSERT_TRUE(test_server.GetAddressList(&addr));
856
857   TestCompletionCallback callback;
858   scoped_ptr<StreamSocket> real_transport(
859       new TCPClientSocket(addr, NULL, NetLog::Source()));
860   scoped_ptr<SynchronousErrorStreamSocket> transport(
861       new SynchronousErrorStreamSocket(real_transport.Pass()));
862   int rv = callback.GetResult(transport->Connect(callback.callback()));
863   EXPECT_EQ(OK, rv);
864
865   // Disable TLS False Start to avoid handshake non-determinism.
866   SSLConfig ssl_config;
867   ssl_config.false_start_enabled = false;
868
869   SynchronousErrorStreamSocket* raw_transport = transport.get();
870   scoped_ptr<SSLClientSocket> sock(
871       CreateSSLClientSocket(transport.PassAs<StreamSocket>(),
872                             test_server.host_port_pair(),
873                             ssl_config));
874
875   rv = callback.GetResult(sock->Connect(callback.callback()));
876   EXPECT_EQ(OK, rv);
877   EXPECT_TRUE(sock->IsConnected());
878
879   const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
880   static const int kRequestTextSize =
881       static_cast<int>(arraysize(request_text) - 1);
882   scoped_refptr<IOBuffer> request_buffer(new IOBuffer(kRequestTextSize));
883   memcpy(request_buffer->data(), request_text, kRequestTextSize);
884
885   rv = callback.GetResult(
886       sock->Write(request_buffer.get(), kRequestTextSize, callback.callback()));
887   EXPECT_EQ(kRequestTextSize, rv);
888
889   // Simulate an unclean/forcible shutdown.
890   raw_transport->SetNextReadError(ERR_CONNECTION_RESET);
891
892   scoped_refptr<IOBuffer> buf(new IOBuffer(4096));
893
894   // Note: This test will hang if this bug has regressed. Simply checking that
895   // rv != ERR_IO_PENDING is insufficient, as ERR_IO_PENDING is a legitimate
896   // result when using a dedicated task runner for NSS.
897   rv = callback.GetResult(sock->Read(buf.get(), 4096, callback.callback()));
898
899 #if !defined(USE_OPENSSL)
900   // SSLClientSocketNSS records the error exactly
901   EXPECT_EQ(ERR_CONNECTION_RESET, rv);
902 #else
903   // SSLClientSocketOpenSSL treats any errors as a simple EOF.
904   EXPECT_EQ(0, rv);
905 #endif
906 }
907
908 // Tests that the SSLClientSocket properly handles when the underlying transport
909 // asynchronously returns an error code while writing data - such as if an
910 // intermediary terminates the socket connection uncleanly.
911 // This is a regression test for http://crbug.com/249848
912 TEST_F(SSLClientSocketTest, Write_WithSynchronousError) {
913   SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
914                                 SpawnedTestServer::kLocalhost,
915                                 base::FilePath());
916   ASSERT_TRUE(test_server.Start());
917
918   AddressList addr;
919   ASSERT_TRUE(test_server.GetAddressList(&addr));
920
921   TestCompletionCallback callback;
922   scoped_ptr<StreamSocket> real_transport(
923       new TCPClientSocket(addr, NULL, NetLog::Source()));
924   // Note: |error_socket|'s ownership is handed to |transport|, but a pointer
925   // is retained in order to configure additional errors.
926   scoped_ptr<SynchronousErrorStreamSocket> error_socket(
927       new SynchronousErrorStreamSocket(real_transport.Pass()));
928   SynchronousErrorStreamSocket* raw_error_socket = error_socket.get();
929   scoped_ptr<FakeBlockingStreamSocket> transport(
930       new FakeBlockingStreamSocket(error_socket.PassAs<StreamSocket>()));
931   FakeBlockingStreamSocket* raw_transport = transport.get();
932   int rv = callback.GetResult(transport->Connect(callback.callback()));
933   EXPECT_EQ(OK, rv);
934
935   // Disable TLS False Start to avoid handshake non-determinism.
936   SSLConfig ssl_config;
937   ssl_config.false_start_enabled = false;
938
939   scoped_ptr<SSLClientSocket> sock(
940       CreateSSLClientSocket(transport.PassAs<StreamSocket>(),
941                             test_server.host_port_pair(),
942                             ssl_config));
943
944   rv = callback.GetResult(sock->Connect(callback.callback()));
945   EXPECT_EQ(OK, rv);
946   EXPECT_TRUE(sock->IsConnected());
947
948   const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
949   static const int kRequestTextSize =
950       static_cast<int>(arraysize(request_text) - 1);
951   scoped_refptr<IOBuffer> request_buffer(new IOBuffer(kRequestTextSize));
952   memcpy(request_buffer->data(), request_text, kRequestTextSize);
953
954   // Simulate an unclean/forcible shutdown on the underlying socket.
955   // However, simulate this error asynchronously.
956   raw_error_socket->SetNextWriteError(ERR_CONNECTION_RESET);
957   raw_transport->SetNextWriteShouldBlock();
958
959   // This write should complete synchronously, because the TLS ciphertext
960   // can be created and placed into the outgoing buffers independent of the
961   // underlying transport.
962   rv = callback.GetResult(
963       sock->Write(request_buffer.get(), kRequestTextSize, callback.callback()));
964   EXPECT_EQ(kRequestTextSize, rv);
965
966   scoped_refptr<IOBuffer> buf(new IOBuffer(4096));
967
968   rv = sock->Read(buf.get(), 4096, callback.callback());
969   EXPECT_EQ(ERR_IO_PENDING, rv);
970
971   // Now unblock the outgoing request, having it fail with the connection
972   // being reset.
973   raw_transport->UnblockWrite();
974
975   // Note: This will cause an inifite loop if this bug has regressed. Simply
976   // checking that rv != ERR_IO_PENDING is insufficient, as ERR_IO_PENDING
977   // is a legitimate result when using a dedicated task runner for NSS.
978   rv = callback.GetResult(rv);
979
980 #if !defined(USE_OPENSSL)
981   // SSLClientSocketNSS records the error exactly
982   EXPECT_EQ(ERR_CONNECTION_RESET, rv);
983 #else
984   // SSLClientSocketOpenSSL treats any errors as a simple EOF.
985   EXPECT_EQ(0, rv);
986 #endif
987 }
988
989 // Test the full duplex mode, with Read and Write pending at the same time.
990 // This test also serves as a regression test for http://crbug.com/29815.
991 TEST_F(SSLClientSocketTest, Read_FullDuplex) {
992   SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
993                                 SpawnedTestServer::kLocalhost,
994                                 base::FilePath());
995   ASSERT_TRUE(test_server.Start());
996
997   AddressList addr;
998   ASSERT_TRUE(test_server.GetAddressList(&addr));
999
1000   TestCompletionCallback callback;  // Used for everything except Write.
1001
1002   scoped_ptr<StreamSocket> transport(
1003       new TCPClientSocket(addr, NULL, NetLog::Source()));
1004   int rv = transport->Connect(callback.callback());
1005   if (rv == ERR_IO_PENDING)
1006     rv = callback.WaitForResult();
1007   EXPECT_EQ(OK, rv);
1008
1009   scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
1010       transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
1011
1012   rv = sock->Connect(callback.callback());
1013   if (rv == ERR_IO_PENDING)
1014     rv = callback.WaitForResult();
1015   EXPECT_EQ(OK, rv);
1016   EXPECT_TRUE(sock->IsConnected());
1017
1018   // Issue a "hanging" Read first.
1019   scoped_refptr<IOBuffer> buf(new IOBuffer(4096));
1020   rv = sock->Read(buf.get(), 4096, callback.callback());
1021   // We haven't written the request, so there should be no response yet.
1022   ASSERT_EQ(ERR_IO_PENDING, rv);
1023
1024   // Write the request.
1025   // The request is padded with a User-Agent header to a size that causes the
1026   // memio circular buffer (4k bytes) in SSLClientSocketNSS to wrap around.
1027   // This tests the fix for http://crbug.com/29815.
1028   std::string request_text = "GET / HTTP/1.1\r\nUser-Agent: long browser name ";
1029   for (int i = 0; i < 3770; ++i)
1030     request_text.push_back('*');
1031   request_text.append("\r\n\r\n");
1032   scoped_refptr<IOBuffer> request_buffer(new StringIOBuffer(request_text));
1033
1034   TestCompletionCallback callback2;  // Used for Write only.
1035   rv = sock->Write(
1036       request_buffer.get(), request_text.size(), callback2.callback());
1037   EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
1038
1039   if (rv == ERR_IO_PENDING)
1040     rv = callback2.WaitForResult();
1041   EXPECT_EQ(static_cast<int>(request_text.size()), rv);
1042
1043   // Now get the Read result.
1044   rv = callback.WaitForResult();
1045   EXPECT_GT(rv, 0);
1046 }
1047
// Attempts to Read() and Write() from an SSLClientSocket (NSS or OpenSSL
// build) in full duplex mode when the underlying transport is blocked on
// sending data. When the underlying transport completes due to an error, it
// should invoke both the Read() and Write() callbacks. If the socket is
// deleted by the Read() callback, the Write() callback should not be invoked.
// Regression test for http://crbug.com/232633
TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) {
  SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
                                SpawnedTestServer::kLocalhost,
                                base::FilePath());
  ASSERT_TRUE(test_server.Start());

  AddressList addr;
  ASSERT_TRUE(test_server.GetAddressList(&addr));

  TestCompletionCallback callback;
  scoped_ptr<StreamSocket> real_transport(
      new TCPClientSocket(addr, NULL, NetLog::Source()));
  // Note: |error_socket|'s ownership is handed to |transport|, but a pointer
  // is retained in order to configure additional errors.
  scoped_ptr<SynchronousErrorStreamSocket> error_socket(
      new SynchronousErrorStreamSocket(real_transport.Pass()));
  SynchronousErrorStreamSocket* raw_error_socket = error_socket.get();
  // Likewise, |raw_transport| stays usable after |transport| is handed to the
  // SSL socket, so transport reads/writes can be blocked and unblocked below.
  scoped_ptr<FakeBlockingStreamSocket> transport(
      new FakeBlockingStreamSocket(error_socket.PassAs<StreamSocket>()));
  FakeBlockingStreamSocket* raw_transport = transport.get();

  int rv = callback.GetResult(transport->Connect(callback.callback()));
  EXPECT_EQ(OK, rv);

  // Disable TLS False Start to avoid handshake non-determinism.
  SSLConfig ssl_config;
  ssl_config.false_start_enabled = false;

  scoped_ptr<SSLClientSocket> sock =
      CreateSSLClientSocket(transport.PassAs<StreamSocket>(),
                            test_server.host_port_pair(),
                            ssl_config);

  rv = callback.GetResult(sock->Connect(callback.callback()));
  EXPECT_EQ(OK, rv);
  EXPECT_TRUE(sock->IsConnected());

  // A request of just over 20 KB. For NSS it must be larger than all of the
  // internal buffers; the ASSERT_LT guard in the NSS-only section checks this.
  std::string request_text = "GET / HTTP/1.1\r\nUser-Agent: long browser name ";
  request_text.append(20 * 1024, '*');
  request_text.append("\r\n\r\n");
  scoped_refptr<DrainableIOBuffer> request_buffer(new DrainableIOBuffer(
      new StringIOBuffer(request_text), request_text.size()));

  // Simulate errors being returned from the underlying Read() and Write() ...
  raw_error_socket->SetNextReadError(ERR_CONNECTION_RESET);
  raw_error_socket->SetNextWriteError(ERR_CONNECTION_RESET);
  // ... but have those errors returned asynchronously. Because the Write() will
  // return first, this will trigger the error.
  raw_transport->SetNextReadShouldBlock();
  raw_transport->SetNextWriteShouldBlock();

  // Enqueue a Read() before calling Write(). It should return ERR_IO_PENDING
  // ("hang") because SetNextReadShouldBlock() blocks the transport Read().
  // Ownership of the SSL socket passes to |read_callback|; |raw_sock| is only
  // valid until that callback runs and deletes the socket.
  SSLClientSocket* raw_sock = sock.get();
  DeleteSocketCallback read_callback(sock.release());
  scoped_refptr<IOBuffer> read_buf(new IOBuffer(4096));
  rv = raw_sock->Read(read_buf.get(), 4096, read_callback.callback());

  // Ensure things didn't complete synchronously, otherwise |raw_sock| has
  // already been deleted by |read_callback|.
  ASSERT_EQ(ERR_IO_PENDING, rv);
  ASSERT_FALSE(read_callback.have_result());

#if !defined(USE_OPENSSL)
  // NSS follows a pattern where a call to PR_Write will only consume as
  // much data as it can encode into application data records before the
  // internal memio buffer is full, which should only fill if writing a large
  // amount of data and the underlying transport is blocked. Once this happens,
  // NSS will return (total size of all application data records it wrote) - 1,
  // with the caller expected to resume with the remaining unsent data.
  //
  // This causes SSLClientSocketNSS::Write to return that it wrote some data
  // before it will return ERR_IO_PENDING, so make an extra call to Write() to
  // get the socket in the state needed for the test below.
  //
  // This is not needed for OpenSSL, because for OpenSSL,
  // SSL_MODE_ENABLE_PARTIAL_WRITE is not specified - thus
  // SSLClientSocketOpenSSL::Write() will not return until all of
  // |request_buffer| has been written to the underlying BIO (although not
  // necessarily the underlying transport).
  rv = callback.GetResult(raw_sock->Write(request_buffer.get(),
                                          request_buffer->BytesRemaining(),
                                          callback.callback()));
  ASSERT_LT(0, rv);
  request_buffer->DidConsume(rv);

  // Guard to ensure that |request_buffer| was larger than all of the internal
  // buffers (transport, memio, NSS) along the way - otherwise the next call
  // to Write() will crash with an invalid buffer.
  ASSERT_LT(0, request_buffer->BytesRemaining());
#endif

  // Attempt to write the remaining data. NSS will not be able to consume the
  // application data because the internal buffers are full, while OpenSSL will
  // report that it is blocked because the underlying transport is blocked.
  rv = raw_sock->Write(request_buffer.get(),
                       request_buffer->BytesRemaining(),
                       callback.callback());
  ASSERT_EQ(ERR_IO_PENDING, rv);
  ASSERT_FALSE(callback.have_result());

  // Now unblock Write(), which will invoke OnSendComplete and (eventually)
  // call the Read() callback, deleting the socket and thus aborting calling
  // the Write() callback.
  raw_transport->UnblockWrite();

  rv = read_callback.WaitForResult();

#if !defined(USE_OPENSSL)
  // NSS records the error exactly.
  EXPECT_EQ(ERR_CONNECTION_RESET, rv);
#else
  // OpenSSL treats any errors as a simple EOF.
  EXPECT_EQ(0, rv);
#endif

  // The Write callback should not have been called.
  EXPECT_FALSE(callback.have_result());
}
1172
1173 // Tests that the SSLClientSocket does not crash if data is received on the
1174 // transport socket after a failing write. This can occur if we have a Write
1175 // error in a SPDY socket.
1176 // Regression test for http://crbug.com/335557
1177 TEST_F(SSLClientSocketTest, Read_WithWriteError) {
1178   SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
1179                                 SpawnedTestServer::kLocalhost,
1180                                 base::FilePath());
1181   ASSERT_TRUE(test_server.Start());
1182
1183   AddressList addr;
1184   ASSERT_TRUE(test_server.GetAddressList(&addr));
1185
1186   TestCompletionCallback callback;
1187   scoped_ptr<StreamSocket> real_transport(
1188       new TCPClientSocket(addr, NULL, NetLog::Source()));
1189   // Note: |error_socket|'s ownership is handed to |transport|, but a pointer
1190   // is retained in order to configure additional errors.
1191   scoped_ptr<SynchronousErrorStreamSocket> error_socket(
1192       new SynchronousErrorStreamSocket(real_transport.Pass()));
1193   SynchronousErrorStreamSocket* raw_error_socket = error_socket.get();
1194   scoped_ptr<FakeBlockingStreamSocket> transport(
1195       new FakeBlockingStreamSocket(error_socket.PassAs<StreamSocket>()));
1196   FakeBlockingStreamSocket* raw_transport = transport.get();
1197
1198   int rv = callback.GetResult(transport->Connect(callback.callback()));
1199   EXPECT_EQ(OK, rv);
1200
1201   // Disable TLS False Start to avoid handshake non-determinism.
1202   SSLConfig ssl_config;
1203   ssl_config.false_start_enabled = false;
1204
1205   scoped_ptr<SSLClientSocket> sock(
1206       CreateSSLClientSocket(transport.PassAs<StreamSocket>(),
1207                             test_server.host_port_pair(),
1208                             ssl_config));
1209
1210   rv = callback.GetResult(sock->Connect(callback.callback()));
1211   EXPECT_EQ(OK, rv);
1212   EXPECT_TRUE(sock->IsConnected());
1213
1214   // Send a request so there is something to read from the socket.
1215   const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
1216   static const int kRequestTextSize =
1217       static_cast<int>(arraysize(request_text) - 1);
1218   scoped_refptr<IOBuffer> request_buffer(new IOBuffer(kRequestTextSize));
1219   memcpy(request_buffer->data(), request_text, kRequestTextSize);
1220
1221   rv = callback.GetResult(
1222       sock->Write(request_buffer.get(), kRequestTextSize, callback.callback()));
1223   EXPECT_EQ(kRequestTextSize, rv);
1224
1225   // Start a hanging read.
1226   TestCompletionCallback read_callback;
1227   raw_transport->SetNextReadShouldBlock();
1228   scoped_refptr<IOBuffer> buf(new IOBuffer(4096));
1229   rv = sock->Read(buf.get(), 4096, read_callback.callback());
1230   EXPECT_EQ(ERR_IO_PENDING, rv);
1231
1232   // Perform another write, but have it fail. Write a request larger than the
1233   // internal socket buffers so that the request hits the underlying transport
1234   // socket and detects the error.
1235   std::string long_request_text =
1236       "GET / HTTP/1.1\r\nUser-Agent: long browser name ";
1237   long_request_text.append(20 * 1024, '*');
1238   long_request_text.append("\r\n\r\n");
1239   scoped_refptr<DrainableIOBuffer> long_request_buffer(new DrainableIOBuffer(
1240       new StringIOBuffer(long_request_text), long_request_text.size()));
1241
1242   raw_error_socket->SetNextWriteError(ERR_CONNECTION_RESET);
1243
1244   // Write as much data as possible until hitting an error. This is necessary
1245   // for NSS. PR_Write will only consume as much data as it can encode into
1246   // application data records before the internal memio buffer is full, which
1247   // should only fill if writing a large amount of data and the underlying
1248   // transport is blocked. Once this happens, NSS will return (total size of all
1249   // application data records it wrote) - 1, with the caller expected to resume
1250   // with the remaining unsent data.
1251   do {
1252     rv = callback.GetResult(sock->Write(long_request_buffer.get(),
1253                                         long_request_buffer->BytesRemaining(),
1254                                         callback.callback()));
1255     if (rv > 0) {
1256       long_request_buffer->DidConsume(rv);
1257       // Abort if the entire buffer is ever consumed.
1258       ASSERT_LT(0, long_request_buffer->BytesRemaining());
1259     }
1260   } while (rv > 0);
1261
1262 #if !defined(USE_OPENSSL)
1263   // NSS records the error exactly.
1264   EXPECT_EQ(ERR_CONNECTION_RESET, rv);
1265 #else
1266   // OpenSSL treats the reset as a generic protocol error.
1267   EXPECT_EQ(ERR_SSL_PROTOCOL_ERROR, rv);
1268 #endif
1269
1270   // Release the read. Some bytes should go through.
1271   raw_transport->UnblockRead();
1272   rv = read_callback.WaitForResult();
1273
1274   // Per the fix for http://crbug.com/249848, write failures currently break
1275   // reads. Change this assertion if they're changed to not collide.
1276   EXPECT_EQ(ERR_CONNECTION_RESET, rv);
1277 }
1278
1279 TEST_F(SSLClientSocketTest, Read_SmallChunks) {
1280   SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
1281                                 SpawnedTestServer::kLocalhost,
1282                                 base::FilePath());
1283   ASSERT_TRUE(test_server.Start());
1284
1285   AddressList addr;
1286   ASSERT_TRUE(test_server.GetAddressList(&addr));
1287
1288   TestCompletionCallback callback;
1289   scoped_ptr<StreamSocket> transport(
1290       new TCPClientSocket(addr, NULL, NetLog::Source()));
1291   int rv = transport->Connect(callback.callback());
1292   if (rv == ERR_IO_PENDING)
1293     rv = callback.WaitForResult();
1294   EXPECT_EQ(OK, rv);
1295
1296   scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
1297       transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
1298
1299   rv = sock->Connect(callback.callback());
1300   if (rv == ERR_IO_PENDING)
1301     rv = callback.WaitForResult();
1302   EXPECT_EQ(OK, rv);
1303
1304   const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
1305   scoped_refptr<IOBuffer> request_buffer(
1306       new IOBuffer(arraysize(request_text) - 1));
1307   memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1);
1308
1309   rv = sock->Write(
1310       request_buffer.get(), arraysize(request_text) - 1, callback.callback());
1311   EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
1312
1313   if (rv == ERR_IO_PENDING)
1314     rv = callback.WaitForResult();
1315   EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv);
1316
1317   scoped_refptr<IOBuffer> buf(new IOBuffer(1));
1318   for (;;) {
1319     rv = sock->Read(buf.get(), 1, callback.callback());
1320     EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
1321
1322     if (rv == ERR_IO_PENDING)
1323       rv = callback.WaitForResult();
1324
1325     EXPECT_GE(rv, 0);
1326     if (rv <= 0)
1327       break;
1328   }
1329 }
1330
1331 TEST_F(SSLClientSocketTest, Read_ManySmallRecords) {
1332   SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
1333                                 SpawnedTestServer::kLocalhost,
1334                                 base::FilePath());
1335   ASSERT_TRUE(test_server.Start());
1336
1337   AddressList addr;
1338   ASSERT_TRUE(test_server.GetAddressList(&addr));
1339
1340   TestCompletionCallback callback;
1341
1342   scoped_ptr<StreamSocket> real_transport(
1343       new TCPClientSocket(addr, NULL, NetLog::Source()));
1344   scoped_ptr<ReadBufferingStreamSocket> transport(
1345       new ReadBufferingStreamSocket(real_transport.Pass()));
1346   ReadBufferingStreamSocket* raw_transport = transport.get();
1347   int rv = callback.GetResult(transport->Connect(callback.callback()));
1348   ASSERT_EQ(OK, rv);
1349
1350   scoped_ptr<SSLClientSocket> sock(
1351       CreateSSLClientSocket(transport.PassAs<StreamSocket>(),
1352                             test_server.host_port_pair(),
1353                             kDefaultSSLConfig));
1354
1355   rv = callback.GetResult(sock->Connect(callback.callback()));
1356   ASSERT_EQ(OK, rv);
1357   ASSERT_TRUE(sock->IsConnected());
1358
1359   const char request_text[] = "GET /ssl-many-small-records HTTP/1.0\r\n\r\n";
1360   scoped_refptr<IOBuffer> request_buffer(
1361       new IOBuffer(arraysize(request_text) - 1));
1362   memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1);
1363
1364   rv = callback.GetResult(sock->Write(
1365       request_buffer.get(), arraysize(request_text) - 1, callback.callback()));
1366   ASSERT_GT(rv, 0);
1367   ASSERT_EQ(static_cast<int>(arraysize(request_text) - 1), rv);
1368
1369   // Note: This relies on SSLClientSocketNSS attempting to read up to 17K of
1370   // data (the max SSL record size) at a time. Ensure that at least 15K worth
1371   // of SSL data is buffered first. The 15K of buffered data is made up of
1372   // many smaller SSL records (the TestServer writes along 1350 byte
1373   // plaintext boundaries), although there may also be a few records that are
1374   // smaller or larger, due to timing and SSL False Start.
1375   // 15K was chosen because 15K is smaller than the 17K (max) read issued by
1376   // the SSLClientSocket implementation, and larger than the minimum amount
1377   // of ciphertext necessary to contain the 8K of plaintext requested below.
1378   raw_transport->SetBufferSize(15000);
1379
1380   scoped_refptr<IOBuffer> buffer(new IOBuffer(8192));
1381   rv = callback.GetResult(sock->Read(buffer.get(), 8192, callback.callback()));
1382   ASSERT_EQ(rv, 8192);
1383 }
1384
1385 TEST_F(SSLClientSocketTest, Read_Interrupted) {
1386   SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
1387                                 SpawnedTestServer::kLocalhost,
1388                                 base::FilePath());
1389   ASSERT_TRUE(test_server.Start());
1390
1391   AddressList addr;
1392   ASSERT_TRUE(test_server.GetAddressList(&addr));
1393
1394   TestCompletionCallback callback;
1395   scoped_ptr<StreamSocket> transport(
1396       new TCPClientSocket(addr, NULL, NetLog::Source()));
1397   int rv = transport->Connect(callback.callback());
1398   if (rv == ERR_IO_PENDING)
1399     rv = callback.WaitForResult();
1400   EXPECT_EQ(OK, rv);
1401
1402   scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
1403       transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
1404
1405   rv = sock->Connect(callback.callback());
1406   if (rv == ERR_IO_PENDING)
1407     rv = callback.WaitForResult();
1408   EXPECT_EQ(OK, rv);
1409
1410   const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
1411   scoped_refptr<IOBuffer> request_buffer(
1412       new IOBuffer(arraysize(request_text) - 1));
1413   memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1);
1414
1415   rv = sock->Write(
1416       request_buffer.get(), arraysize(request_text) - 1, callback.callback());
1417   EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
1418
1419   if (rv == ERR_IO_PENDING)
1420     rv = callback.WaitForResult();
1421   EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv);
1422
1423   // Do a partial read and then exit.  This test should not crash!
1424   scoped_refptr<IOBuffer> buf(new IOBuffer(512));
1425   rv = sock->Read(buf.get(), 512, callback.callback());
1426   EXPECT_TRUE(rv > 0 || rv == ERR_IO_PENDING);
1427
1428   if (rv == ERR_IO_PENDING)
1429     rv = callback.WaitForResult();
1430
1431   EXPECT_GT(rv, 0);
1432 }
1433
1434 TEST_F(SSLClientSocketTest, Read_FullLogging) {
1435   SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
1436                                 SpawnedTestServer::kLocalhost,
1437                                 base::FilePath());
1438   ASSERT_TRUE(test_server.Start());
1439
1440   AddressList addr;
1441   ASSERT_TRUE(test_server.GetAddressList(&addr));
1442
1443   TestCompletionCallback callback;
1444   CapturingNetLog log;
1445   log.SetLogLevel(NetLog::LOG_ALL);
1446   scoped_ptr<StreamSocket> transport(
1447       new TCPClientSocket(addr, &log, NetLog::Source()));
1448   int rv = transport->Connect(callback.callback());
1449   if (rv == ERR_IO_PENDING)
1450     rv = callback.WaitForResult();
1451   EXPECT_EQ(OK, rv);
1452
1453   scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
1454       transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
1455
1456   rv = sock->Connect(callback.callback());
1457   if (rv == ERR_IO_PENDING)
1458     rv = callback.WaitForResult();
1459   EXPECT_EQ(OK, rv);
1460   EXPECT_TRUE(sock->IsConnected());
1461
1462   const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
1463   scoped_refptr<IOBuffer> request_buffer(
1464       new IOBuffer(arraysize(request_text) - 1));
1465   memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1);
1466
1467   rv = sock->Write(
1468       request_buffer.get(), arraysize(request_text) - 1, callback.callback());
1469   EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
1470
1471   if (rv == ERR_IO_PENDING)
1472     rv = callback.WaitForResult();
1473   EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv);
1474
1475   CapturingNetLog::CapturedEntryList entries;
1476   log.GetEntries(&entries);
1477   size_t last_index = ExpectLogContainsSomewhereAfter(
1478       entries, 5, NetLog::TYPE_SSL_SOCKET_BYTES_SENT, NetLog::PHASE_NONE);
1479
1480   scoped_refptr<IOBuffer> buf(new IOBuffer(4096));
1481   for (;;) {
1482     rv = sock->Read(buf.get(), 4096, callback.callback());
1483     EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
1484
1485     if (rv == ERR_IO_PENDING)
1486       rv = callback.WaitForResult();
1487
1488     EXPECT_GE(rv, 0);
1489     if (rv <= 0)
1490       break;
1491
1492     log.GetEntries(&entries);
1493     last_index =
1494         ExpectLogContainsSomewhereAfter(entries,
1495                                         last_index + 1,
1496                                         NetLog::TYPE_SSL_SOCKET_BYTES_RECEIVED,
1497                                         NetLog::PHASE_NONE);
1498   }
1499 }
1500
1501 // Regression test for http://crbug.com/42538
1502 TEST_F(SSLClientSocketTest, PrematureApplicationData) {
1503   SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
1504                                 SpawnedTestServer::kLocalhost,
1505                                 base::FilePath());
1506   ASSERT_TRUE(test_server.Start());
1507
1508   AddressList addr;
1509   TestCompletionCallback callback;
1510
1511   static const unsigned char application_data[] = {
1512       0x17, 0x03, 0x01, 0x00, 0x4a, 0x02, 0x00, 0x00, 0x46, 0x03, 0x01, 0x4b,
1513       0xc2, 0xf8, 0xb2, 0xc1, 0x56, 0x42, 0xb9, 0x57, 0x7f, 0xde, 0x87, 0x46,
1514       0xf7, 0xa3, 0x52, 0x42, 0x21, 0xf0, 0x13, 0x1c, 0x9c, 0x83, 0x88, 0xd6,
1515       0x93, 0x0c, 0xf6, 0x36, 0x30, 0x05, 0x7e, 0x20, 0xb5, 0xb5, 0x73, 0x36,
1516       0x53, 0x83, 0x0a, 0xfc, 0x17, 0x63, 0xbf, 0xa0, 0xe4, 0x42, 0x90, 0x0d,
1517       0x2f, 0x18, 0x6d, 0x20, 0xd8, 0x36, 0x3f, 0xfc, 0xe6, 0x01, 0xfa, 0x0f,
1518       0xa5, 0x75, 0x7f, 0x09, 0x00, 0x04, 0x00, 0x16, 0x03, 0x01, 0x11, 0x57,
1519       0x0b, 0x00, 0x11, 0x53, 0x00, 0x11, 0x50, 0x00, 0x06, 0x22, 0x30, 0x82,
1520       0x06, 0x1e, 0x30, 0x82, 0x05, 0x06, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02,
1521       0x0a};
1522
1523   // All reads and writes complete synchronously (async=false).
1524   MockRead data_reads[] = {
1525       MockRead(SYNCHRONOUS,
1526                reinterpret_cast<const char*>(application_data),
1527                arraysize(application_data)),
1528       MockRead(SYNCHRONOUS, OK), };
1529
1530   StaticSocketDataProvider data(data_reads, arraysize(data_reads), NULL, 0);
1531
1532   scoped_ptr<StreamSocket> transport(
1533       new MockTCPClientSocket(addr, NULL, &data));
1534   int rv = transport->Connect(callback.callback());
1535   if (rv == ERR_IO_PENDING)
1536     rv = callback.WaitForResult();
1537   EXPECT_EQ(OK, rv);
1538
1539   scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
1540       transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
1541
1542   rv = sock->Connect(callback.callback());
1543   if (rv == ERR_IO_PENDING)
1544     rv = callback.WaitForResult();
1545   EXPECT_EQ(ERR_SSL_PROTOCOL_ERROR, rv);
1546 }
1547
1548 TEST_F(SSLClientSocketTest, CipherSuiteDisables) {
1549   // Rather than exhaustively disabling every RC4 ciphersuite defined at
1550   // http://www.iana.org/assignments/tls-parameters/tls-parameters.xml,
1551   // only disabling those cipher suites that the test server actually
1552   // implements.
1553   const uint16 kCiphersToDisable[] = {0x0005,  // TLS_RSA_WITH_RC4_128_SHA
1554   };
1555
1556   SpawnedTestServer::SSLOptions ssl_options;
1557   // Enable only RC4 on the test server.
1558   ssl_options.bulk_ciphers = SpawnedTestServer::SSLOptions::BULK_CIPHER_RC4;
1559   SpawnedTestServer test_server(
1560       SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath());
1561   ASSERT_TRUE(test_server.Start());
1562
1563   AddressList addr;
1564   ASSERT_TRUE(test_server.GetAddressList(&addr));
1565
1566   TestCompletionCallback callback;
1567   CapturingNetLog log;
1568   scoped_ptr<StreamSocket> transport(
1569       new TCPClientSocket(addr, &log, NetLog::Source()));
1570   int rv = transport->Connect(callback.callback());
1571   if (rv == ERR_IO_PENDING)
1572     rv = callback.WaitForResult();
1573   EXPECT_EQ(OK, rv);
1574
1575   SSLConfig ssl_config;
1576   for (size_t i = 0; i < arraysize(kCiphersToDisable); ++i)
1577     ssl_config.disabled_cipher_suites.push_back(kCiphersToDisable[i]);
1578
1579   scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
1580       transport.Pass(), test_server.host_port_pair(), ssl_config));
1581
1582   EXPECT_FALSE(sock->IsConnected());
1583
1584   rv = sock->Connect(callback.callback());
1585   CapturingNetLog::CapturedEntryList entries;
1586   log.GetEntries(&entries);
1587   EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
1588
1589   // NSS has special handling that maps a handshake_failure alert received
1590   // immediately after a client_hello to be a mismatched cipher suite error,
1591   // leading to ERR_SSL_VERSION_OR_CIPHER_MISMATCH. When using OpenSSL or
1592   // Secure Transport (OS X), the handshake_failure is bubbled up without any
1593   // interpretation, leading to ERR_SSL_PROTOCOL_ERROR. Either way, a failure
1594   // indicates that no cipher suite was negotiated with the test server.
1595   if (rv == ERR_IO_PENDING)
1596     rv = callback.WaitForResult();
1597   EXPECT_TRUE(rv == ERR_SSL_VERSION_OR_CIPHER_MISMATCH ||
1598               rv == ERR_SSL_PROTOCOL_ERROR);
1599   // The exact ordering differs between SSLClientSocketNSS (which issues an
1600   // extra read) and SSLClientSocketMac (which does not). Just make sure the
1601   // error appears somewhere in the log.
1602   log.GetEntries(&entries);
1603   ExpectLogContainsSomewhere(
1604       entries, 0, NetLog::TYPE_SSL_HANDSHAKE_ERROR, NetLog::PHASE_NONE);
1605
1606   // We cannot test sock->IsConnected(), as the NSS implementation disconnects
1607   // the socket when it encounters an error, whereas other implementations
1608   // leave it connected.
1609   // Because this an error that the test server is mutually aware of, as opposed
1610   // to being an error such as a certificate name mismatch, which is
1611   // client-only, the exact index of the SSL connect end depends on how
1612   // quickly the test server closes the underlying socket. If the test server
1613   // closes before the IO message loop pumps messages, there may be a 0-byte
1614   // Read event in the NetLog due to TCPClientSocket picking up the EOF. As a
1615   // result, the SSL connect end event will be the second-to-last entry,
1616   // rather than the last entry.
1617   EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1) ||
1618               LogContainsSSLConnectEndEvent(entries, -2));
1619 }
1620
1621 // When creating an SSLClientSocket, it is allowed to pass in a
1622 // ClientSocketHandle that is not obtained from a client socket pool.
1623 // Here we verify that such a simple ClientSocketHandle, not associated with any
1624 // client socket pool, can be destroyed safely.
1625 TEST_F(SSLClientSocketTest, ClientSocketHandleNotFromPool) {
1626   SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
1627                                 SpawnedTestServer::kLocalhost,
1628                                 base::FilePath());
1629   ASSERT_TRUE(test_server.Start());
1630
1631   AddressList addr;
1632   ASSERT_TRUE(test_server.GetAddressList(&addr));
1633
1634   TestCompletionCallback callback;
1635   scoped_ptr<StreamSocket> transport(
1636       new TCPClientSocket(addr, NULL, NetLog::Source()));
1637   int rv = transport->Connect(callback.callback());
1638   if (rv == ERR_IO_PENDING)
1639     rv = callback.WaitForResult();
1640   EXPECT_EQ(OK, rv);
1641
1642   scoped_ptr<ClientSocketHandle> socket_handle(new ClientSocketHandle());
1643   socket_handle->SetSocket(transport.Pass());
1644
1645   scoped_ptr<SSLClientSocket> sock(
1646       socket_factory_->CreateSSLClientSocket(socket_handle.Pass(),
1647                                              test_server.host_port_pair(),
1648                                              kDefaultSSLConfig,
1649                                              context_));
1650
1651   EXPECT_FALSE(sock->IsConnected());
1652   rv = sock->Connect(callback.callback());
1653   if (rv == ERR_IO_PENDING)
1654     rv = callback.WaitForResult();
1655   EXPECT_EQ(OK, rv);
1656 }
1657
1658 // Verifies that SSLClientSocket::ExportKeyingMaterial return a success
1659 // code and different keying label results in different keying material.
1660 TEST_F(SSLClientSocketTest, ExportKeyingMaterial) {
1661   SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
1662                                 SpawnedTestServer::kLocalhost,
1663                                 base::FilePath());
1664   ASSERT_TRUE(test_server.Start());
1665
1666   AddressList addr;
1667   ASSERT_TRUE(test_server.GetAddressList(&addr));
1668
1669   TestCompletionCallback callback;
1670
1671   scoped_ptr<StreamSocket> transport(
1672       new TCPClientSocket(addr, NULL, NetLog::Source()));
1673   int rv = transport->Connect(callback.callback());
1674   if (rv == ERR_IO_PENDING)
1675     rv = callback.WaitForResult();
1676   EXPECT_EQ(OK, rv);
1677
1678   scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
1679       transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
1680
1681   rv = sock->Connect(callback.callback());
1682   if (rv == ERR_IO_PENDING)
1683     rv = callback.WaitForResult();
1684   EXPECT_EQ(OK, rv);
1685   EXPECT_TRUE(sock->IsConnected());
1686
1687   const int kKeyingMaterialSize = 32;
1688   const char* kKeyingLabel1 = "client-socket-test-1";
1689   const char* kKeyingContext = "";
1690   unsigned char client_out1[kKeyingMaterialSize];
1691   memset(client_out1, 0, sizeof(client_out1));
1692   rv = sock->ExportKeyingMaterial(
1693       kKeyingLabel1, false, kKeyingContext, client_out1, sizeof(client_out1));
1694   EXPECT_EQ(rv, OK);
1695
1696   const char* kKeyingLabel2 = "client-socket-test-2";
1697   unsigned char client_out2[kKeyingMaterialSize];
1698   memset(client_out2, 0, sizeof(client_out2));
1699   rv = sock->ExportKeyingMaterial(
1700       kKeyingLabel2, false, kKeyingContext, client_out2, sizeof(client_out2));
1701   EXPECT_EQ(rv, OK);
1702   EXPECT_NE(memcmp(client_out1, client_out2, kKeyingMaterialSize), 0);
1703 }
1704
1705 // Verifies that SSLClientSocket::ClearSessionCache can be called without
1706 // explicit NSS initialization.
1707 TEST(SSLClientSocket, ClearSessionCache) {
1708   SSLClientSocket::ClearSessionCache();
1709 }
1710
1711 // This tests that SSLInfo contains a properly re-constructed certificate
1712 // chain. That, in turn, verifies that GetSSLInfo is giving us the chain as
1713 // verified, not the chain as served by the server. (They may be different.)
1714 //
1715 // CERT_CHAIN_WRONG_ROOT is redundant-server-chain.pem. It contains A
1716 // (end-entity) -> B -> C, and C is signed by D. redundant-validated-chain.pem
1717 // contains a chain of A -> B -> C2, where C2 is the same public key as C, but
1718 // a self-signed root. Such a situation can occur when a new root (C2) is
1719 // cross-certified by an old root (D) and has two different versions of its
1720 // floating around. Servers may supply C2 as an intermediate, but the
1721 // SSLClientSocket should return the chain that was verified, from
1722 // verify_result, instead.
1723 TEST_F(SSLClientSocketTest, VerifyReturnChainProperlyOrdered) {
1724   // By default, cause the CertVerifier to treat all certificates as
1725   // expired.
1726   cert_verifier_->set_default_result(ERR_CERT_DATE_INVALID);
1727
1728   // We will expect SSLInfo to ultimately contain this chain.
1729   CertificateList certs =
1730       CreateCertificateListFromFile(GetTestCertsDirectory(),
1731                                     "redundant-validated-chain.pem",
1732                                     X509Certificate::FORMAT_AUTO);
1733   ASSERT_EQ(3U, certs.size());
1734
1735   X509Certificate::OSCertHandles temp_intermediates;
1736   temp_intermediates.push_back(certs[1]->os_cert_handle());
1737   temp_intermediates.push_back(certs[2]->os_cert_handle());
1738
1739   CertVerifyResult verify_result;
1740   verify_result.verified_cert = X509Certificate::CreateFromHandle(
1741       certs[0]->os_cert_handle(), temp_intermediates);
1742
1743   // Add a rule that maps the server cert (A) to the chain of A->B->C2
1744   // rather than A->B->C.
1745   cert_verifier_->AddResultForCert(certs[0].get(), verify_result, OK);
1746
1747   // Load and install the root for the validated chain.
1748   scoped_refptr<X509Certificate> root_cert = ImportCertFromFile(
1749       GetTestCertsDirectory(), "redundant-validated-chain-root.pem");
1750   ASSERT_NE(static_cast<X509Certificate*>(NULL), root_cert);
1751   ScopedTestRoot scoped_root(root_cert.get());
1752
1753   // Set up a test server with CERT_CHAIN_WRONG_ROOT.
1754   SpawnedTestServer::SSLOptions ssl_options(
1755       SpawnedTestServer::SSLOptions::CERT_CHAIN_WRONG_ROOT);
1756   SpawnedTestServer test_server(
1757       SpawnedTestServer::TYPE_HTTPS,
1758       ssl_options,
1759       base::FilePath(FILE_PATH_LITERAL("net/data/ssl")));
1760   ASSERT_TRUE(test_server.Start());
1761
1762   AddressList addr;
1763   ASSERT_TRUE(test_server.GetAddressList(&addr));
1764
1765   TestCompletionCallback callback;
1766   CapturingNetLog log;
1767   scoped_ptr<StreamSocket> transport(
1768       new TCPClientSocket(addr, &log, NetLog::Source()));
1769   int rv = transport->Connect(callback.callback());
1770   if (rv == ERR_IO_PENDING)
1771     rv = callback.WaitForResult();
1772   EXPECT_EQ(OK, rv);
1773
1774   scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
1775       transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
1776   EXPECT_FALSE(sock->IsConnected());
1777   rv = sock->Connect(callback.callback());
1778
1779   CapturingNetLog::CapturedEntryList entries;
1780   log.GetEntries(&entries);
1781   EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
1782   if (rv == ERR_IO_PENDING)
1783     rv = callback.WaitForResult();
1784
1785   EXPECT_EQ(OK, rv);
1786   EXPECT_TRUE(sock->IsConnected());
1787   log.GetEntries(&entries);
1788   EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1));
1789
1790   SSLInfo ssl_info;
1791   sock->GetSSLInfo(&ssl_info);
1792
1793   // Verify that SSLInfo contains the corrected re-constructed chain A -> B
1794   // -> C2.
1795   const X509Certificate::OSCertHandles& intermediates =
1796       ssl_info.cert->GetIntermediateCertificates();
1797   ASSERT_EQ(2U, intermediates.size());
1798   EXPECT_TRUE(X509Certificate::IsSameOSCert(ssl_info.cert->os_cert_handle(),
1799                                             certs[0]->os_cert_handle()));
1800   EXPECT_TRUE(X509Certificate::IsSameOSCert(intermediates[0],
1801                                             certs[1]->os_cert_handle()));
1802   EXPECT_TRUE(X509Certificate::IsSameOSCert(intermediates[1],
1803                                             certs[2]->os_cert_handle()));
1804
1805   sock->Disconnect();
1806   EXPECT_FALSE(sock->IsConnected());
1807 }
1808
1809 // Verifies the correctness of GetSSLCertRequestInfo.
1810 class SSLClientSocketCertRequestInfoTest : public SSLClientSocketTest {
1811  protected:
1812   // Creates a test server with the given SSLOptions, connects to it and returns
1813   // the SSLCertRequestInfo reported by the socket.
1814   scoped_refptr<SSLCertRequestInfo> GetCertRequest(
1815       SpawnedTestServer::SSLOptions ssl_options) {
1816     SpawnedTestServer test_server(
1817         SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath());
1818     if (!test_server.Start())
1819       return NULL;
1820
1821     AddressList addr;
1822     if (!test_server.GetAddressList(&addr))
1823       return NULL;
1824
1825     TestCompletionCallback callback;
1826     CapturingNetLog log;
1827     scoped_ptr<StreamSocket> transport(
1828         new TCPClientSocket(addr, &log, NetLog::Source()));
1829     int rv = transport->Connect(callback.callback());
1830     if (rv == ERR_IO_PENDING)
1831       rv = callback.WaitForResult();
1832     EXPECT_EQ(OK, rv);
1833
1834     scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
1835         transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig));
1836     EXPECT_FALSE(sock->IsConnected());
1837
1838     rv = sock->Connect(callback.callback());
1839     if (rv == ERR_IO_PENDING)
1840       rv = callback.WaitForResult();
1841     scoped_refptr<SSLCertRequestInfo> request_info = new SSLCertRequestInfo();
1842     sock->GetSSLCertRequestInfo(request_info.get());
1843     sock->Disconnect();
1844     EXPECT_FALSE(sock->IsConnected());
1845
1846     return request_info;
1847   }
1848 };
1849
1850 TEST_F(SSLClientSocketCertRequestInfoTest, NoAuthorities) {
1851   SpawnedTestServer::SSLOptions ssl_options;
1852   ssl_options.request_client_certificate = true;
1853   scoped_refptr<SSLCertRequestInfo> request_info = GetCertRequest(ssl_options);
1854   ASSERT_TRUE(request_info.get());
1855   EXPECT_EQ(0u, request_info->cert_authorities.size());
1856 }
1857
1858 TEST_F(SSLClientSocketCertRequestInfoTest, TwoAuthorities) {
1859   const base::FilePath::CharType kThawteFile[] =
1860       FILE_PATH_LITERAL("thawte.single.pem");
1861   const unsigned char kThawteDN[] = {
1862       0x30, 0x4c, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13,
1863       0x02, 0x5a, 0x41, 0x31, 0x25, 0x30, 0x23, 0x06, 0x03, 0x55, 0x04, 0x0a,
1864       0x13, 0x1c, 0x54, 0x68, 0x61, 0x77, 0x74, 0x65, 0x20, 0x43, 0x6f, 0x6e,
1865       0x73, 0x75, 0x6c, 0x74, 0x69, 0x6e, 0x67, 0x20, 0x28, 0x50, 0x74, 0x79,
1866       0x29, 0x20, 0x4c, 0x74, 0x64, 0x2e, 0x31, 0x16, 0x30, 0x14, 0x06, 0x03,
1867       0x55, 0x04, 0x03, 0x13, 0x0d, 0x54, 0x68, 0x61, 0x77, 0x74, 0x65, 0x20,
1868       0x53, 0x47, 0x43, 0x20, 0x43, 0x41};
1869   const size_t kThawteLen = sizeof(kThawteDN);
1870
1871   const base::FilePath::CharType kDiginotarFile[] =
1872       FILE_PATH_LITERAL("diginotar_root_ca.pem");
1873   const unsigned char kDiginotarDN[] = {
1874       0x30, 0x5f, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13,
1875       0x02, 0x4e, 0x4c, 0x31, 0x12, 0x30, 0x10, 0x06, 0x03, 0x55, 0x04, 0x0a,
1876       0x13, 0x09, 0x44, 0x69, 0x67, 0x69, 0x4e, 0x6f, 0x74, 0x61, 0x72, 0x31,
1877       0x1a, 0x30, 0x18, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x11, 0x44, 0x69,
1878       0x67, 0x69, 0x4e, 0x6f, 0x74, 0x61, 0x72, 0x20, 0x52, 0x6f, 0x6f, 0x74,
1879       0x20, 0x43, 0x41, 0x31, 0x20, 0x30, 0x1e, 0x06, 0x09, 0x2a, 0x86, 0x48,
1880       0x86, 0xf7, 0x0d, 0x01, 0x09, 0x01, 0x16, 0x11, 0x69, 0x6e, 0x66, 0x6f,
1881       0x40, 0x64, 0x69, 0x67, 0x69, 0x6e, 0x6f, 0x74, 0x61, 0x72, 0x2e, 0x6e,
1882       0x6c};
1883   const size_t kDiginotarLen = sizeof(kDiginotarDN);
1884
1885   SpawnedTestServer::SSLOptions ssl_options;
1886   ssl_options.request_client_certificate = true;
1887   ssl_options.client_authorities.push_back(
1888       GetTestClientCertsDirectory().Append(kThawteFile));
1889   ssl_options.client_authorities.push_back(
1890       GetTestClientCertsDirectory().Append(kDiginotarFile));
1891   scoped_refptr<SSLCertRequestInfo> request_info = GetCertRequest(ssl_options);
1892   ASSERT_TRUE(request_info.get());
1893   ASSERT_EQ(2u, request_info->cert_authorities.size());
1894   EXPECT_EQ(std::string(reinterpret_cast<const char*>(kThawteDN), kThawteLen),
1895             request_info->cert_authorities[0]);
1896   EXPECT_EQ(
1897       std::string(reinterpret_cast<const char*>(kDiginotarDN), kDiginotarLen),
1898       request_info->cert_authorities[1]);
1899 }
1900
1901 }  // namespace
1902
1903 TEST_F(SSLClientSocketTest, ConnectSignedCertTimestampsEnabledTLSExtension) {
1904   SpawnedTestServer::SSLOptions ssl_options;
1905   ssl_options.signed_cert_timestamps_tls_ext = "test";
1906
1907   SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
1908                                 ssl_options,
1909                                 base::FilePath());
1910   ASSERT_TRUE(test_server.Start());
1911
1912   AddressList addr;
1913   ASSERT_TRUE(test_server.GetAddressList(&addr));
1914
1915   TestCompletionCallback callback;
1916   CapturingNetLog log;
1917   scoped_ptr<StreamSocket> transport(
1918       new TCPClientSocket(addr, &log, NetLog::Source()));
1919   int rv = transport->Connect(callback.callback());
1920   if (rv == ERR_IO_PENDING)
1921     rv = callback.WaitForResult();
1922   EXPECT_EQ(OK, rv);
1923
1924   SSLConfig ssl_config;
1925   ssl_config.signed_cert_timestamps_enabled = true;
1926
1927   scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
1928       transport.Pass(), test_server.host_port_pair(), ssl_config));
1929
1930   EXPECT_FALSE(sock->IsConnected());
1931
1932   rv = sock->Connect(callback.callback());
1933
1934   CapturingNetLog::CapturedEntryList entries;
1935   log.GetEntries(&entries);
1936   EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
1937   if (rv == ERR_IO_PENDING)
1938     rv = callback.WaitForResult();
1939   EXPECT_EQ(OK, rv);
1940   EXPECT_TRUE(sock->IsConnected());
1941   log.GetEntries(&entries);
1942   EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1));
1943
1944 #if !defined(USE_OPENSSL)
1945   EXPECT_TRUE(sock->signed_cert_timestamps_received_);
1946 #else
1947   // Enabling CT for OpenSSL is currently a noop.
1948   EXPECT_FALSE(sock->signed_cert_timestamps_received_);
1949 #endif
1950
1951   sock->Disconnect();
1952   EXPECT_FALSE(sock->IsConnected());
1953 }
1954
1955 // Test that enabling Signed Certificate Timestamps enables OCSP stapling.
1956 TEST_F(SSLClientSocketTest, ConnectSignedCertTimestampsEnabledOCSP) {
1957   SpawnedTestServer::SSLOptions ssl_options;
1958   ssl_options.staple_ocsp_response = true;
1959   // The test server currently only knows how to generate OCSP responses
1960   // for a freshly minted certificate.
1961   ssl_options.server_certificate = SpawnedTestServer::SSLOptions::CERT_AUTO;
1962
1963   SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
1964                                 ssl_options,
1965                                 base::FilePath());
1966   ASSERT_TRUE(test_server.Start());
1967
1968   AddressList addr;
1969   ASSERT_TRUE(test_server.GetAddressList(&addr));
1970
1971   TestCompletionCallback callback;
1972   CapturingNetLog log;
1973   scoped_ptr<StreamSocket> transport(
1974       new TCPClientSocket(addr, &log, NetLog::Source()));
1975   int rv = transport->Connect(callback.callback());
1976   if (rv == ERR_IO_PENDING)
1977     rv = callback.WaitForResult();
1978   EXPECT_EQ(OK, rv);
1979
1980   SSLConfig ssl_config;
1981   // Enabling Signed Cert Timestamps ensures we request OCSP stapling for
1982   // Certificate Transparency verification regardless of whether the platform
1983   // is able to process the OCSP status itself.
1984   ssl_config.signed_cert_timestamps_enabled = true;
1985
1986   scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
1987       transport.Pass(), test_server.host_port_pair(), ssl_config));
1988
1989   EXPECT_FALSE(sock->IsConnected());
1990
1991   rv = sock->Connect(callback.callback());
1992
1993   CapturingNetLog::CapturedEntryList entries;
1994   log.GetEntries(&entries);
1995   EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
1996   if (rv == ERR_IO_PENDING)
1997     rv = callback.WaitForResult();
1998   EXPECT_EQ(OK, rv);
1999   EXPECT_TRUE(sock->IsConnected());
2000   log.GetEntries(&entries);
2001   EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1));
2002
2003 #if !defined(USE_OPENSSL)
2004   EXPECT_TRUE(sock->stapled_ocsp_response_received_);
2005 #else
2006   // OCSP stapling isn't currently supported in the OpenSSL socket.
2007   EXPECT_FALSE(sock->stapled_ocsp_response_received_);
2008 #endif
2009
2010   sock->Disconnect();
2011   EXPECT_FALSE(sock->IsConnected());
2012 }
2013
2014 TEST_F(SSLClientSocketTest, ConnectSignedCertTimestampsDisabled) {
2015   SpawnedTestServer::SSLOptions ssl_options;
2016   ssl_options.signed_cert_timestamps_tls_ext = "test";
2017
2018   SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS,
2019                                 ssl_options,
2020                                 base::FilePath());
2021   ASSERT_TRUE(test_server.Start());
2022
2023   AddressList addr;
2024   ASSERT_TRUE(test_server.GetAddressList(&addr));
2025
2026   TestCompletionCallback callback;
2027   CapturingNetLog log;
2028   scoped_ptr<StreamSocket> transport(
2029       new TCPClientSocket(addr, &log, NetLog::Source()));
2030   int rv = transport->Connect(callback.callback());
2031   if (rv == ERR_IO_PENDING)
2032     rv = callback.WaitForResult();
2033   EXPECT_EQ(OK, rv);
2034
2035   SSLConfig ssl_config;
2036   ssl_config.signed_cert_timestamps_enabled = false;
2037
2038   scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket(
2039       transport.Pass(), test_server.host_port_pair(), ssl_config));
2040
2041   EXPECT_FALSE(sock->IsConnected());
2042
2043   rv = sock->Connect(callback.callback());
2044
2045   CapturingNetLog::CapturedEntryList entries;
2046   log.GetEntries(&entries);
2047   EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT));
2048   if (rv == ERR_IO_PENDING)
2049     rv = callback.WaitForResult();
2050   EXPECT_EQ(OK, rv);
2051   EXPECT_TRUE(sock->IsConnected());
2052   log.GetEntries(&entries);
2053   EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1));
2054
2055   EXPECT_FALSE(sock->signed_cert_timestamps_received_);
2056
2057   sock->Disconnect();
2058   EXPECT_FALSE(sock->IsConnected());
2059 }
2060
2061 }  // namespace net