Upstream version 7.36.149.0
[platform/framework/web/crosswalk.git] / src / net / socket / ssl_server_socket_unittest.cc
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 // This test suite uses SSLClientSocket to test the implementation of
6 // SSLServerSocket. In order to establish connections between the sockets
7 // we need two additional classes:
8 // 1. FakeSocket
9 //    Connects SSL socket to FakeDataChannel. This class is just a stub.
10 //
11 // 2. FakeDataChannel
12 //    Implements the actual exchange of data between two FakeSockets.
13 //
14 // Implementations of these two classes are included in this file.
15
16 #include "net/socket/ssl_server_socket.h"
17
18 #include <stdlib.h>
19
20 #include <queue>
21
22 #include "base/compiler_specific.h"
23 #include "base/file_util.h"
24 #include "base/files/file_path.h"
25 #include "base/message_loop/message_loop.h"
26 #include "base/path_service.h"
27 #include "crypto/nss_util.h"
28 #include "crypto/rsa_private_key.h"
29 #include "net/base/address_list.h"
30 #include "net/base/completion_callback.h"
31 #include "net/base/host_port_pair.h"
32 #include "net/base/io_buffer.h"
33 #include "net/base/ip_endpoint.h"
34 #include "net/base/net_errors.h"
35 #include "net/base/net_log.h"
36 #include "net/base/test_data_directory.h"
37 #include "net/cert/cert_status_flags.h"
38 #include "net/cert/mock_cert_verifier.h"
39 #include "net/cert/x509_certificate.h"
40 #include "net/http/transport_security_state.h"
41 #include "net/socket/client_socket_factory.h"
42 #include "net/socket/socket_test_util.h"
43 #include "net/socket/ssl_client_socket.h"
44 #include "net/socket/stream_socket.h"
45 #include "net/ssl/ssl_config_service.h"
46 #include "net/ssl/ssl_info.h"
47 #include "net/test/cert_test_util.h"
48 #include "testing/gtest/include/gtest/gtest.h"
49 #include "testing/platform_test.h"
50
51 namespace net {
52
53 namespace {
54
55 class FakeDataChannel {
56  public:
57   FakeDataChannel()
58       : read_buf_len_(0),
59         closed_(false),
60         write_called_after_close_(false),
61         weak_factory_(this) {
62   }
63
64   int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) {
65     if (closed_)
66       return 0;
67     if (data_.empty()) {
68       read_callback_ = callback;
69       read_buf_ = buf;
70       read_buf_len_ = buf_len;
71       return net::ERR_IO_PENDING;
72     }
73     return PropogateData(buf, buf_len);
74   }
75
76   int Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback) {
77     if (closed_) {
78       if (write_called_after_close_)
79         return net::ERR_CONNECTION_RESET;
80       write_called_after_close_ = true;
81       write_callback_ = callback;
82       base::MessageLoop::current()->PostTask(
83           FROM_HERE, base::Bind(&FakeDataChannel::DoWriteCallback,
84                                 weak_factory_.GetWeakPtr()));
85       return net::ERR_IO_PENDING;
86     }
87     data_.push(new net::DrainableIOBuffer(buf, buf_len));
88     base::MessageLoop::current()->PostTask(
89         FROM_HERE, base::Bind(&FakeDataChannel::DoReadCallback,
90                               weak_factory_.GetWeakPtr()));
91     return buf_len;
92   }
93
94   // Closes the FakeDataChannel. After Close() is called, Read() returns 0,
95   // indicating EOF, and Write() fails with ERR_CONNECTION_RESET. Note that
96   // after the FakeDataChannel is closed, the first Write() call completes
97   // asynchronously, which is necessary to reproduce bug 127822.
98   void Close() {
99     closed_ = true;
100   }
101
102  private:
103   void DoReadCallback() {
104     if (read_callback_.is_null() || data_.empty())
105       return;
106
107     int copied = PropogateData(read_buf_, read_buf_len_);
108     CompletionCallback callback = read_callback_;
109     read_callback_.Reset();
110     read_buf_ = NULL;
111     read_buf_len_ = 0;
112     callback.Run(copied);
113   }
114
115   void DoWriteCallback() {
116     if (write_callback_.is_null())
117       return;
118
119     CompletionCallback callback = write_callback_;
120     write_callback_.Reset();
121     callback.Run(net::ERR_CONNECTION_RESET);
122   }
123
124   int PropogateData(scoped_refptr<net::IOBuffer> read_buf, int read_buf_len) {
125     scoped_refptr<net::DrainableIOBuffer> buf = data_.front();
126     int copied = std::min(buf->BytesRemaining(), read_buf_len);
127     memcpy(read_buf->data(), buf->data(), copied);
128     buf->DidConsume(copied);
129
130     if (!buf->BytesRemaining())
131       data_.pop();
132     return copied;
133   }
134
135   CompletionCallback read_callback_;
136   scoped_refptr<net::IOBuffer> read_buf_;
137   int read_buf_len_;
138
139   CompletionCallback write_callback_;
140
141   std::queue<scoped_refptr<net::DrainableIOBuffer> > data_;
142
143   // True if Close() has been called.
144   bool closed_;
145
146   // Controls the completion of Write() after the FakeDataChannel is closed.
147   // After the FakeDataChannel is closed, the first Write() call completes
148   // asynchronously.
149   bool write_called_after_close_;
150
151   base::WeakPtrFactory<FakeDataChannel> weak_factory_;
152
153   DISALLOW_COPY_AND_ASSIGN(FakeDataChannel);
154 };
155
156 class FakeSocket : public StreamSocket {
157  public:
158   FakeSocket(FakeDataChannel* incoming_channel,
159              FakeDataChannel* outgoing_channel)
160       : incoming_(incoming_channel),
161         outgoing_(outgoing_channel) {
162   }
163
164   virtual ~FakeSocket() {
165   }
166
167   virtual int Read(IOBuffer* buf, int buf_len,
168                    const CompletionCallback& callback) OVERRIDE {
169     // Read random number of bytes.
170     buf_len = rand() % buf_len + 1;
171     return incoming_->Read(buf, buf_len, callback);
172   }
173
174   virtual int Write(IOBuffer* buf, int buf_len,
175                     const CompletionCallback& callback) OVERRIDE {
176     // Write random number of bytes.
177     buf_len = rand() % buf_len + 1;
178     return outgoing_->Write(buf, buf_len, callback);
179   }
180
181   virtual int SetReceiveBufferSize(int32 size) OVERRIDE {
182     return net::OK;
183   }
184
185   virtual int SetSendBufferSize(int32 size) OVERRIDE {
186     return net::OK;
187   }
188
189   virtual int Connect(const CompletionCallback& callback) OVERRIDE {
190     return net::OK;
191   }
192
193   virtual void Disconnect() OVERRIDE {
194     incoming_->Close();
195     outgoing_->Close();
196   }
197
198   virtual bool IsConnected() const OVERRIDE {
199     return true;
200   }
201
202   virtual bool IsConnectedAndIdle() const OVERRIDE {
203     return true;
204   }
205
206   virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE {
207       net::IPAddressNumber ip_address(net::kIPv4AddressSize);
208     *address = net::IPEndPoint(ip_address, 0 /*port*/);
209     return net::OK;
210   }
211
212   virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE {
213     net::IPAddressNumber ip_address(4);
214     *address = net::IPEndPoint(ip_address, 0);
215     return net::OK;
216   }
217
218   virtual const BoundNetLog& NetLog() const OVERRIDE {
219     return net_log_;
220   }
221
222   virtual void SetSubresourceSpeculation() OVERRIDE {}
223   virtual void SetOmniboxSpeculation() OVERRIDE {}
224
225   virtual bool WasEverUsed() const OVERRIDE {
226     return true;
227   }
228
229   virtual bool UsingTCPFastOpen() const OVERRIDE {
230     return false;
231   }
232
233
234   virtual bool WasNpnNegotiated() const OVERRIDE {
235     return false;
236   }
237
238   virtual NextProto GetNegotiatedProtocol() const OVERRIDE {
239     return kProtoUnknown;
240   }
241
242   virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE {
243     return false;
244   }
245
246  private:
247   net::BoundNetLog net_log_;
248   FakeDataChannel* incoming_;
249   FakeDataChannel* outgoing_;
250
251   DISALLOW_COPY_AND_ASSIGN(FakeSocket);
252 };
253
254 }  // namespace
255
256 // Verify the correctness of the test helper classes first.
257 TEST(FakeSocketTest, DataTransfer) {
258   // Establish channels between two sockets.
259   FakeDataChannel channel_1;
260   FakeDataChannel channel_2;
261   FakeSocket client(&channel_1, &channel_2);
262   FakeSocket server(&channel_2, &channel_1);
263
264   const char kTestData[] = "testing123";
265   const int kTestDataSize = strlen(kTestData);
266   const int kReadBufSize = 1024;
267   scoped_refptr<net::IOBuffer> write_buf = new net::StringIOBuffer(kTestData);
268   scoped_refptr<net::IOBuffer> read_buf = new net::IOBuffer(kReadBufSize);
269
270   // Write then read.
271   int written =
272       server.Write(write_buf.get(), kTestDataSize, CompletionCallback());
273   EXPECT_GT(written, 0);
274   EXPECT_LE(written, kTestDataSize);
275
276   int read = client.Read(read_buf.get(), kReadBufSize, CompletionCallback());
277   EXPECT_GT(read, 0);
278   EXPECT_LE(read, written);
279   EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read));
280
281   // Read then write.
282   TestCompletionCallback callback;
283   EXPECT_EQ(net::ERR_IO_PENDING,
284             server.Read(read_buf.get(), kReadBufSize, callback.callback()));
285
286   written = client.Write(write_buf.get(), kTestDataSize, CompletionCallback());
287   EXPECT_GT(written, 0);
288   EXPECT_LE(written, kTestDataSize);
289
290   read = callback.WaitForResult();
291   EXPECT_GT(read, 0);
292   EXPECT_LE(read, written);
293   EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read));
294 }
295
296 class SSLServerSocketTest : public PlatformTest {
297  public:
298   SSLServerSocketTest()
299       : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()),
300         cert_verifier_(new MockCertVerifier()),
301         transport_security_state_(new TransportSecurityState) {
302     cert_verifier_->set_default_result(net::CERT_STATUS_AUTHORITY_INVALID);
303   }
304
305  protected:
306   void Initialize() {
307     scoped_ptr<ClientSocketHandle> client_connection(new ClientSocketHandle);
308     client_connection->SetSocket(
309         scoped_ptr<StreamSocket>(new FakeSocket(&channel_1_, &channel_2_)));
310     scoped_ptr<StreamSocket> server_socket(
311         new FakeSocket(&channel_2_, &channel_1_));
312
313     base::FilePath certs_dir(GetTestCertsDirectory());
314
315     base::FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der");
316     std::string cert_der;
317     ASSERT_TRUE(base::ReadFileToString(cert_path, &cert_der));
318
319     scoped_refptr<net::X509Certificate> cert =
320         X509Certificate::CreateFromBytes(cert_der.data(), cert_der.size());
321
322     base::FilePath key_path = certs_dir.AppendASCII("unittest.key.bin");
323     std::string key_string;
324     ASSERT_TRUE(base::ReadFileToString(key_path, &key_string));
325     std::vector<uint8> key_vector(
326         reinterpret_cast<const uint8*>(key_string.data()),
327         reinterpret_cast<const uint8*>(key_string.data() +
328                                        key_string.length()));
329
330     scoped_ptr<crypto::RSAPrivateKey> private_key(
331         crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector));
332
333     net::SSLConfig ssl_config;
334     ssl_config.false_start_enabled = false;
335     ssl_config.channel_id_enabled = false;
336
337     // Certificate provided by the host doesn't need authority.
338     net::SSLConfig::CertAndStatus cert_and_status;
339     cert_and_status.cert_status = CERT_STATUS_AUTHORITY_INVALID;
340     cert_and_status.der_cert = cert_der;
341     ssl_config.allowed_bad_certs.push_back(cert_and_status);
342
343     net::HostPortPair host_and_pair("unittest", 0);
344     net::SSLClientSocketContext context;
345     context.cert_verifier = cert_verifier_.get();
346     context.transport_security_state = transport_security_state_.get();
347     client_socket_ =
348         socket_factory_->CreateSSLClientSocket(
349             client_connection.Pass(), host_and_pair, ssl_config, context);
350     server_socket_ = net::CreateSSLServerSocket(
351         server_socket.Pass(),
352         cert.get(), private_key.get(), net::SSLConfig());
353   }
354
355   FakeDataChannel channel_1_;
356   FakeDataChannel channel_2_;
357   scoped_ptr<net::SSLClientSocket> client_socket_;
358   scoped_ptr<net::SSLServerSocket> server_socket_;
359   net::ClientSocketFactory* socket_factory_;
360   scoped_ptr<net::MockCertVerifier> cert_verifier_;
361   scoped_ptr<net::TransportSecurityState> transport_security_state_;
362 };
363
364 // SSLServerSocket is only implemented using NSS.
365 #if defined(USE_NSS) || defined(OS_WIN) || defined(OS_MACOSX)
366
367 // This test only executes creation of client and server sockets. This is to
368 // test that creation of sockets doesn't crash and have minimal code to run
369 // under valgrind in order to help debugging memory problems.
370 TEST_F(SSLServerSocketTest, Initialize) {
371   Initialize();
372 }
373
374 // This test executes Connect() on SSLClientSocket and Handshake() on
375 // SSLServerSocket to make sure handshaking between the two sockets is
376 // completed successfully.
377 TEST_F(SSLServerSocketTest, Handshake) {
378   Initialize();
379
380   TestCompletionCallback connect_callback;
381   TestCompletionCallback handshake_callback;
382
383   int server_ret = server_socket_->Handshake(handshake_callback.callback());
384   EXPECT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING);
385
386   int client_ret = client_socket_->Connect(connect_callback.callback());
387   EXPECT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING);
388
389   if (client_ret == net::ERR_IO_PENDING) {
390     EXPECT_EQ(net::OK, connect_callback.WaitForResult());
391   }
392   if (server_ret == net::ERR_IO_PENDING) {
393     EXPECT_EQ(net::OK, handshake_callback.WaitForResult());
394   }
395
396   // Make sure the cert status is expected.
397   SSLInfo ssl_info;
398   client_socket_->GetSSLInfo(&ssl_info);
399   EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status);
400 }
401
402 TEST_F(SSLServerSocketTest, DataTransfer) {
403   Initialize();
404
405   TestCompletionCallback connect_callback;
406   TestCompletionCallback handshake_callback;
407
408   // Establish connection.
409   int client_ret = client_socket_->Connect(connect_callback.callback());
410   ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING);
411
412   int server_ret = server_socket_->Handshake(handshake_callback.callback());
413   ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING);
414
415   client_ret = connect_callback.GetResult(client_ret);
416   ASSERT_EQ(net::OK, client_ret);
417   server_ret = handshake_callback.GetResult(server_ret);
418   ASSERT_EQ(net::OK, server_ret);
419
420   const int kReadBufSize = 1024;
421   scoped_refptr<net::StringIOBuffer> write_buf =
422       new net::StringIOBuffer("testing123");
423   scoped_refptr<net::DrainableIOBuffer> read_buf =
424       new net::DrainableIOBuffer(new net::IOBuffer(kReadBufSize),
425                                  kReadBufSize);
426
427   // Write then read.
428   TestCompletionCallback write_callback;
429   TestCompletionCallback read_callback;
430   server_ret = server_socket_->Write(
431       write_buf.get(), write_buf->size(), write_callback.callback());
432   EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING);
433   client_ret = client_socket_->Read(
434       read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
435   EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING);
436
437   server_ret = write_callback.GetResult(server_ret);
438   EXPECT_GT(server_ret, 0);
439   client_ret = read_callback.GetResult(client_ret);
440   ASSERT_GT(client_ret, 0);
441
442   read_buf->DidConsume(client_ret);
443   while (read_buf->BytesConsumed() < write_buf->size()) {
444     client_ret = client_socket_->Read(
445         read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
446     EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING);
447     client_ret = read_callback.GetResult(client_ret);
448     ASSERT_GT(client_ret, 0);
449     read_buf->DidConsume(client_ret);
450   }
451   EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed());
452   read_buf->SetOffset(0);
453   EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size()));
454
455   // Read then write.
456   write_buf = new net::StringIOBuffer("hello123");
457   server_ret = server_socket_->Read(
458       read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
459   EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING);
460   client_ret = client_socket_->Write(
461       write_buf.get(), write_buf->size(), write_callback.callback());
462   EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING);
463
464   server_ret = read_callback.GetResult(server_ret);
465   ASSERT_GT(server_ret, 0);
466   client_ret = write_callback.GetResult(client_ret);
467   EXPECT_GT(client_ret, 0);
468
469   read_buf->DidConsume(server_ret);
470   while (read_buf->BytesConsumed() < write_buf->size()) {
471     server_ret = server_socket_->Read(
472         read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
473     EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING);
474     server_ret = read_callback.GetResult(server_ret);
475     ASSERT_GT(server_ret, 0);
476     read_buf->DidConsume(server_ret);
477   }
478   EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed());
479   read_buf->SetOffset(0);
480   EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size()));
481 }
482
483 // A regression test for bug 127822 (http://crbug.com/127822).
484 // If the server closes the connection after the handshake is finished,
485 // the client's Write() call should not cause an infinite loop.
486 // NOTE: this is a test for SSLClientSocket rather than SSLServerSocket.
487 TEST_F(SSLServerSocketTest, ClientWriteAfterServerClose) {
488   Initialize();
489
490   TestCompletionCallback connect_callback;
491   TestCompletionCallback handshake_callback;
492
493   // Establish connection.
494   int client_ret = client_socket_->Connect(connect_callback.callback());
495   ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING);
496
497   int server_ret = server_socket_->Handshake(handshake_callback.callback());
498   ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING);
499
500   client_ret = connect_callback.GetResult(client_ret);
501   ASSERT_EQ(net::OK, client_ret);
502   server_ret = handshake_callback.GetResult(server_ret);
503   ASSERT_EQ(net::OK, server_ret);
504
505   scoped_refptr<net::StringIOBuffer> write_buf =
506       new net::StringIOBuffer("testing123");
507
508   // The server closes the connection. The server needs to write some
509   // data first so that the client's Read() calls from the transport
510   // socket won't return ERR_IO_PENDING.  This ensures that the client
511   // will call Read() on the transport socket again.
512   TestCompletionCallback write_callback;
513
514   server_ret = server_socket_->Write(
515       write_buf.get(), write_buf->size(), write_callback.callback());
516   EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING);
517
518   server_ret = write_callback.GetResult(server_ret);
519   EXPECT_GT(server_ret, 0);
520
521   server_socket_->Disconnect();
522
523   // The client writes some data. This should not cause an infinite loop.
524   client_ret = client_socket_->Write(
525       write_buf.get(), write_buf->size(), write_callback.callback());
526   EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING);
527
528   client_ret = write_callback.GetResult(client_ret);
529   EXPECT_GT(client_ret, 0);
530
531   base::MessageLoop::current()->PostDelayedTask(
532       FROM_HERE, base::MessageLoop::QuitClosure(),
533       base::TimeDelta::FromMilliseconds(10));
534   base::MessageLoop::current()->Run();
535 }
536
537 // This test executes ExportKeyingMaterial() on the client and server sockets,
538 // after connecting them, and verifies that the results match.
539 // This test will fail if False Start is enabled (see crbug.com/90208).
540 TEST_F(SSLServerSocketTest, ExportKeyingMaterial) {
541   Initialize();
542
543   TestCompletionCallback connect_callback;
544   TestCompletionCallback handshake_callback;
545
546   int client_ret = client_socket_->Connect(connect_callback.callback());
547   ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING);
548
549   int server_ret = server_socket_->Handshake(handshake_callback.callback());
550   ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING);
551
552   if (client_ret == net::ERR_IO_PENDING) {
553     ASSERT_EQ(net::OK, connect_callback.WaitForResult());
554   }
555   if (server_ret == net::ERR_IO_PENDING) {
556     ASSERT_EQ(net::OK, handshake_callback.WaitForResult());
557   }
558
559   const int kKeyingMaterialSize = 32;
560   const char* kKeyingLabel = "EXPERIMENTAL-server-socket-test";
561   const char* kKeyingContext = "";
562   unsigned char server_out[kKeyingMaterialSize];
563   int rv = server_socket_->ExportKeyingMaterial(kKeyingLabel,
564                                                 false, kKeyingContext,
565                                                 server_out, sizeof(server_out));
566   ASSERT_EQ(net::OK, rv);
567
568   unsigned char client_out[kKeyingMaterialSize];
569   rv = client_socket_->ExportKeyingMaterial(kKeyingLabel,
570                                             false, kKeyingContext,
571                                             client_out, sizeof(client_out));
572   ASSERT_EQ(net::OK, rv);
573   EXPECT_EQ(0, memcmp(server_out, client_out, sizeof(server_out)));
574
575   const char* kKeyingLabelBad = "EXPERIMENTAL-server-socket-test-bad";
576   unsigned char client_bad[kKeyingMaterialSize];
577   rv = client_socket_->ExportKeyingMaterial(kKeyingLabelBad,
578                                             false, kKeyingContext,
579                                             client_bad, sizeof(client_bad));
580   ASSERT_EQ(rv, net::OK);
581   EXPECT_NE(0, memcmp(server_out, client_bad, sizeof(server_out)));
582 }
583 #endif
584
585 }  // namespace net