- add sources.
[platform/framework/web/crosswalk.git] / src / net / socket / ssl_server_socket_unittest.cc
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 // This test suite uses SSLClientSocket to test the implementation of
6 // SSLServerSocket. In order to establish connections between the sockets
7 // we need two additional classes:
8 // 1. FakeSocket
9 //    Connects SSL socket to FakeDataChannel. This class is just a stub.
10 //
11 // 2. FakeDataChannel
12 //    Implements the actual exchange of data between two FakeSockets.
13 //
14 // Implementations of these two classes are included in this file.
15
16 #include "net/socket/ssl_server_socket.h"
17
18 #include <stdlib.h>
19
20 #include <queue>
21
22 #include "base/compiler_specific.h"
23 #include "base/file_util.h"
24 #include "base/files/file_path.h"
25 #include "base/message_loop/message_loop.h"
26 #include "base/path_service.h"
27 #include "crypto/nss_util.h"
28 #include "crypto/rsa_private_key.h"
29 #include "net/base/address_list.h"
30 #include "net/base/completion_callback.h"
31 #include "net/base/host_port_pair.h"
32 #include "net/base/io_buffer.h"
33 #include "net/base/ip_endpoint.h"
34 #include "net/base/net_errors.h"
35 #include "net/base/net_log.h"
36 #include "net/base/test_data_directory.h"
37 #include "net/cert/cert_status_flags.h"
38 #include "net/cert/mock_cert_verifier.h"
39 #include "net/cert/x509_certificate.h"
40 #include "net/http/transport_security_state.h"
41 #include "net/socket/client_socket_factory.h"
42 #include "net/socket/socket_test_util.h"
43 #include "net/socket/ssl_client_socket.h"
44 #include "net/socket/stream_socket.h"
45 #include "net/ssl/ssl_config_service.h"
46 #include "net/ssl/ssl_info.h"
47 #include "net/test/cert_test_util.h"
48 #include "testing/gtest/include/gtest/gtest.h"
49 #include "testing/platform_test.h"
50
51 namespace net {
52
53 namespace {
54
55 class FakeDataChannel {
56  public:
57   FakeDataChannel()
58       : read_buf_len_(0),
59         weak_factory_(this),
60         closed_(false),
61         write_called_after_close_(false) {
62   }
63
64   int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) {
65     if (closed_)
66       return 0;
67     if (data_.empty()) {
68       read_callback_ = callback;
69       read_buf_ = buf;
70       read_buf_len_ = buf_len;
71       return net::ERR_IO_PENDING;
72     }
73     return PropogateData(buf, buf_len);
74   }
75
76   int Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback) {
77     if (closed_) {
78       if (write_called_after_close_)
79         return net::ERR_CONNECTION_RESET;
80       write_called_after_close_ = true;
81       write_callback_ = callback;
82       base::MessageLoop::current()->PostTask(
83           FROM_HERE, base::Bind(&FakeDataChannel::DoWriteCallback,
84                                 weak_factory_.GetWeakPtr()));
85       return net::ERR_IO_PENDING;
86     }
87     data_.push(new net::DrainableIOBuffer(buf, buf_len));
88     base::MessageLoop::current()->PostTask(
89         FROM_HERE, base::Bind(&FakeDataChannel::DoReadCallback,
90                               weak_factory_.GetWeakPtr()));
91     return buf_len;
92   }
93
94   // Closes the FakeDataChannel. After Close() is called, Read() returns 0,
95   // indicating EOF, and Write() fails with ERR_CONNECTION_RESET. Note that
96   // after the FakeDataChannel is closed, the first Write() call completes
97   // asynchronously, which is necessary to reproduce bug 127822.
98   void Close() {
99     closed_ = true;
100   }
101
102  private:
103   void DoReadCallback() {
104     if (read_callback_.is_null() || data_.empty())
105       return;
106
107     int copied = PropogateData(read_buf_, read_buf_len_);
108     CompletionCallback callback = read_callback_;
109     read_callback_.Reset();
110     read_buf_ = NULL;
111     read_buf_len_ = 0;
112     callback.Run(copied);
113   }
114
115   void DoWriteCallback() {
116     if (write_callback_.is_null())
117       return;
118
119     CompletionCallback callback = write_callback_;
120     write_callback_.Reset();
121     callback.Run(net::ERR_CONNECTION_RESET);
122   }
123
124   int PropogateData(scoped_refptr<net::IOBuffer> read_buf, int read_buf_len) {
125     scoped_refptr<net::DrainableIOBuffer> buf = data_.front();
126     int copied = std::min(buf->BytesRemaining(), read_buf_len);
127     memcpy(read_buf->data(), buf->data(), copied);
128     buf->DidConsume(copied);
129
130     if (!buf->BytesRemaining())
131       data_.pop();
132     return copied;
133   }
134
135   CompletionCallback read_callback_;
136   scoped_refptr<net::IOBuffer> read_buf_;
137   int read_buf_len_;
138
139   CompletionCallback write_callback_;
140
141   std::queue<scoped_refptr<net::DrainableIOBuffer> > data_;
142
143   base::WeakPtrFactory<FakeDataChannel> weak_factory_;
144
145   // True if Close() has been called.
146   bool closed_;
147
148   // Controls the completion of Write() after the FakeDataChannel is closed.
149   // After the FakeDataChannel is closed, the first Write() call completes
150   // asynchronously.
151   bool write_called_after_close_;
152
153   DISALLOW_COPY_AND_ASSIGN(FakeDataChannel);
154 };
155
156 class FakeSocket : public StreamSocket {
157  public:
158   FakeSocket(FakeDataChannel* incoming_channel,
159              FakeDataChannel* outgoing_channel)
160       : incoming_(incoming_channel),
161         outgoing_(outgoing_channel) {
162   }
163
164   virtual ~FakeSocket() {
165   }
166
167   virtual int Read(IOBuffer* buf, int buf_len,
168                    const CompletionCallback& callback) OVERRIDE {
169     // Read random number of bytes.
170     buf_len = rand() % buf_len + 1;
171     return incoming_->Read(buf, buf_len, callback);
172   }
173
174   virtual int Write(IOBuffer* buf, int buf_len,
175                     const CompletionCallback& callback) OVERRIDE {
176     // Write random number of bytes.
177     buf_len = rand() % buf_len + 1;
178     return outgoing_->Write(buf, buf_len, callback);
179   }
180
181   virtual bool SetReceiveBufferSize(int32 size) OVERRIDE {
182     return true;
183   }
184
185   virtual bool SetSendBufferSize(int32 size) OVERRIDE {
186     return true;
187   }
188
189   virtual int Connect(const CompletionCallback& callback) OVERRIDE {
190     return net::OK;
191   }
192
193   virtual void Disconnect() OVERRIDE {
194     incoming_->Close();
195     outgoing_->Close();
196   }
197
198   virtual bool IsConnected() const OVERRIDE {
199     return true;
200   }
201
202   virtual bool IsConnectedAndIdle() const OVERRIDE {
203     return true;
204   }
205
206   virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE {
207       net::IPAddressNumber ip_address(net::kIPv4AddressSize);
208     *address = net::IPEndPoint(ip_address, 0 /*port*/);
209     return net::OK;
210   }
211
212   virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE {
213     net::IPAddressNumber ip_address(4);
214     *address = net::IPEndPoint(ip_address, 0);
215     return net::OK;
216   }
217
218   virtual const BoundNetLog& NetLog() const OVERRIDE {
219     return net_log_;
220   }
221
222   virtual void SetSubresourceSpeculation() OVERRIDE {}
223   virtual void SetOmniboxSpeculation() OVERRIDE {}
224
225   virtual bool WasEverUsed() const OVERRIDE {
226     return true;
227   }
228
229   virtual bool UsingTCPFastOpen() const OVERRIDE {
230     return false;
231   }
232
233
234   virtual bool WasNpnNegotiated() const OVERRIDE {
235     return false;
236   }
237
238   virtual NextProto GetNegotiatedProtocol() const OVERRIDE {
239     return kProtoUnknown;
240   }
241
242   virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE {
243     return false;
244   }
245
246  private:
247   net::BoundNetLog net_log_;
248   FakeDataChannel* incoming_;
249   FakeDataChannel* outgoing_;
250
251   DISALLOW_COPY_AND_ASSIGN(FakeSocket);
252 };
253
254 }  // namespace
255
256 // Verify the correctness of the test helper classes first.
257 TEST(FakeSocketTest, DataTransfer) {
258   // Establish channels between two sockets.
259   FakeDataChannel channel_1;
260   FakeDataChannel channel_2;
261   FakeSocket client(&channel_1, &channel_2);
262   FakeSocket server(&channel_2, &channel_1);
263
264   const char kTestData[] = "testing123";
265   const int kTestDataSize = strlen(kTestData);
266   const int kReadBufSize = 1024;
267   scoped_refptr<net::IOBuffer> write_buf = new net::StringIOBuffer(kTestData);
268   scoped_refptr<net::IOBuffer> read_buf = new net::IOBuffer(kReadBufSize);
269
270   // Write then read.
271   int written =
272       server.Write(write_buf.get(), kTestDataSize, CompletionCallback());
273   EXPECT_GT(written, 0);
274   EXPECT_LE(written, kTestDataSize);
275
276   int read = client.Read(read_buf.get(), kReadBufSize, CompletionCallback());
277   EXPECT_GT(read, 0);
278   EXPECT_LE(read, written);
279   EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read));
280
281   // Read then write.
282   TestCompletionCallback callback;
283   EXPECT_EQ(net::ERR_IO_PENDING,
284             server.Read(read_buf.get(), kReadBufSize, callback.callback()));
285
286   written = client.Write(write_buf.get(), kTestDataSize, CompletionCallback());
287   EXPECT_GT(written, 0);
288   EXPECT_LE(written, kTestDataSize);
289
290   read = callback.WaitForResult();
291   EXPECT_GT(read, 0);
292   EXPECT_LE(read, written);
293   EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read));
294 }
295
296 class SSLServerSocketTest : public PlatformTest {
297  public:
298   SSLServerSocketTest()
299       : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()),
300         cert_verifier_(new MockCertVerifier()),
301         transport_security_state_(new TransportSecurityState) {
302     cert_verifier_->set_default_result(net::CERT_STATUS_AUTHORITY_INVALID);
303   }
304
305  protected:
306   void Initialize() {
307     scoped_ptr<ClientSocketHandle> client_connection(new ClientSocketHandle);
308     client_connection->SetSocket(
309         scoped_ptr<StreamSocket>(new FakeSocket(&channel_1_, &channel_2_)));
310     scoped_ptr<StreamSocket> server_socket(
311         new FakeSocket(&channel_2_, &channel_1_));
312
313     base::FilePath certs_dir(GetTestCertsDirectory());
314
315     base::FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der");
316     std::string cert_der;
317     ASSERT_TRUE(base::ReadFileToString(cert_path, &cert_der));
318
319     scoped_refptr<net::X509Certificate> cert =
320         X509Certificate::CreateFromBytes(cert_der.data(), cert_der.size());
321
322     base::FilePath key_path = certs_dir.AppendASCII("unittest.key.bin");
323     std::string key_string;
324     ASSERT_TRUE(base::ReadFileToString(key_path, &key_string));
325     std::vector<uint8> key_vector(
326         reinterpret_cast<const uint8*>(key_string.data()),
327         reinterpret_cast<const uint8*>(key_string.data() +
328                                        key_string.length()));
329
330     scoped_ptr<crypto::RSAPrivateKey> private_key(
331         crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector));
332
333     net::SSLConfig ssl_config;
334     ssl_config.cached_info_enabled = false;
335     ssl_config.false_start_enabled = false;
336     ssl_config.channel_id_enabled = false;
337
338     // Certificate provided by the host doesn't need authority.
339     net::SSLConfig::CertAndStatus cert_and_status;
340     cert_and_status.cert_status = CERT_STATUS_AUTHORITY_INVALID;
341     cert_and_status.der_cert = cert_der;
342     ssl_config.allowed_bad_certs.push_back(cert_and_status);
343
344     net::HostPortPair host_and_pair("unittest", 0);
345     net::SSLClientSocketContext context;
346     context.cert_verifier = cert_verifier_.get();
347     context.transport_security_state = transport_security_state_.get();
348     client_socket_ =
349         socket_factory_->CreateSSLClientSocket(
350             client_connection.Pass(), host_and_pair, ssl_config, context);
351     server_socket_ = net::CreateSSLServerSocket(
352         server_socket.Pass(),
353         cert.get(), private_key.get(), net::SSLConfig());
354   }
355
356   FakeDataChannel channel_1_;
357   FakeDataChannel channel_2_;
358   scoped_ptr<net::SSLClientSocket> client_socket_;
359   scoped_ptr<net::SSLServerSocket> server_socket_;
360   net::ClientSocketFactory* socket_factory_;
361   scoped_ptr<net::MockCertVerifier> cert_verifier_;
362   scoped_ptr<net::TransportSecurityState> transport_security_state_;
363 };
364
365 // SSLServerSocket is only implemented using NSS.
366 #if defined(USE_NSS) || defined(OS_WIN) || defined(OS_MACOSX)
367
368 // This test only executes creation of client and server sockets. This is to
369 // test that creation of sockets doesn't crash and have minimal code to run
370 // under valgrind in order to help debugging memory problems.
371 TEST_F(SSLServerSocketTest, Initialize) {
372   Initialize();
373 }
374
375 // This test executes Connect() on SSLClientSocket and Handshake() on
376 // SSLServerSocket to make sure handshaking between the two sockets is
377 // completed successfully.
378 TEST_F(SSLServerSocketTest, Handshake) {
379   Initialize();
380
381   TestCompletionCallback connect_callback;
382   TestCompletionCallback handshake_callback;
383
384   int server_ret = server_socket_->Handshake(handshake_callback.callback());
385   EXPECT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING);
386
387   int client_ret = client_socket_->Connect(connect_callback.callback());
388   EXPECT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING);
389
390   if (client_ret == net::ERR_IO_PENDING) {
391     EXPECT_EQ(net::OK, connect_callback.WaitForResult());
392   }
393   if (server_ret == net::ERR_IO_PENDING) {
394     EXPECT_EQ(net::OK, handshake_callback.WaitForResult());
395   }
396
397   // Make sure the cert status is expected.
398   SSLInfo ssl_info;
399   client_socket_->GetSSLInfo(&ssl_info);
400   EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status);
401 }
402
403 TEST_F(SSLServerSocketTest, DataTransfer) {
404   Initialize();
405
406   TestCompletionCallback connect_callback;
407   TestCompletionCallback handshake_callback;
408
409   // Establish connection.
410   int client_ret = client_socket_->Connect(connect_callback.callback());
411   ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING);
412
413   int server_ret = server_socket_->Handshake(handshake_callback.callback());
414   ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING);
415
416   client_ret = connect_callback.GetResult(client_ret);
417   ASSERT_EQ(net::OK, client_ret);
418   server_ret = handshake_callback.GetResult(server_ret);
419   ASSERT_EQ(net::OK, server_ret);
420
421   const int kReadBufSize = 1024;
422   scoped_refptr<net::StringIOBuffer> write_buf =
423       new net::StringIOBuffer("testing123");
424   scoped_refptr<net::DrainableIOBuffer> read_buf =
425       new net::DrainableIOBuffer(new net::IOBuffer(kReadBufSize),
426                                  kReadBufSize);
427
428   // Write then read.
429   TestCompletionCallback write_callback;
430   TestCompletionCallback read_callback;
431   server_ret = server_socket_->Write(
432       write_buf.get(), write_buf->size(), write_callback.callback());
433   EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING);
434   client_ret = client_socket_->Read(
435       read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
436   EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING);
437
438   server_ret = write_callback.GetResult(server_ret);
439   EXPECT_GT(server_ret, 0);
440   client_ret = read_callback.GetResult(client_ret);
441   ASSERT_GT(client_ret, 0);
442
443   read_buf->DidConsume(client_ret);
444   while (read_buf->BytesConsumed() < write_buf->size()) {
445     client_ret = client_socket_->Read(
446         read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
447     EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING);
448     client_ret = read_callback.GetResult(client_ret);
449     ASSERT_GT(client_ret, 0);
450     read_buf->DidConsume(client_ret);
451   }
452   EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed());
453   read_buf->SetOffset(0);
454   EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size()));
455
456   // Read then write.
457   write_buf = new net::StringIOBuffer("hello123");
458   server_ret = server_socket_->Read(
459       read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
460   EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING);
461   client_ret = client_socket_->Write(
462       write_buf.get(), write_buf->size(), write_callback.callback());
463   EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING);
464
465   server_ret = read_callback.GetResult(server_ret);
466   ASSERT_GT(server_ret, 0);
467   client_ret = write_callback.GetResult(client_ret);
468   EXPECT_GT(client_ret, 0);
469
470   read_buf->DidConsume(server_ret);
471   while (read_buf->BytesConsumed() < write_buf->size()) {
472     server_ret = server_socket_->Read(
473         read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
474     EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING);
475     server_ret = read_callback.GetResult(server_ret);
476     ASSERT_GT(server_ret, 0);
477     read_buf->DidConsume(server_ret);
478   }
479   EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed());
480   read_buf->SetOffset(0);
481   EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size()));
482 }
483
484 // A regression test for bug 127822 (http://crbug.com/127822).
485 // If the server closes the connection after the handshake is finished,
486 // the client's Write() call should not cause an infinite loop.
487 // NOTE: this is a test for SSLClientSocket rather than SSLServerSocket.
488 TEST_F(SSLServerSocketTest, ClientWriteAfterServerClose) {
489   Initialize();
490
491   TestCompletionCallback connect_callback;
492   TestCompletionCallback handshake_callback;
493
494   // Establish connection.
495   int client_ret = client_socket_->Connect(connect_callback.callback());
496   ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING);
497
498   int server_ret = server_socket_->Handshake(handshake_callback.callback());
499   ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING);
500
501   client_ret = connect_callback.GetResult(client_ret);
502   ASSERT_EQ(net::OK, client_ret);
503   server_ret = handshake_callback.GetResult(server_ret);
504   ASSERT_EQ(net::OK, server_ret);
505
506   scoped_refptr<net::StringIOBuffer> write_buf =
507       new net::StringIOBuffer("testing123");
508
509   // The server closes the connection. The server needs to write some
510   // data first so that the client's Read() calls from the transport
511   // socket won't return ERR_IO_PENDING.  This ensures that the client
512   // will call Read() on the transport socket again.
513   TestCompletionCallback write_callback;
514
515   server_ret = server_socket_->Write(
516       write_buf.get(), write_buf->size(), write_callback.callback());
517   EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING);
518
519   server_ret = write_callback.GetResult(server_ret);
520   EXPECT_GT(server_ret, 0);
521
522   server_socket_->Disconnect();
523
524   // The client writes some data. This should not cause an infinite loop.
525   client_ret = client_socket_->Write(
526       write_buf.get(), write_buf->size(), write_callback.callback());
527   EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING);
528
529   client_ret = write_callback.GetResult(client_ret);
530   EXPECT_GT(client_ret, 0);
531
532   base::MessageLoop::current()->PostDelayedTask(
533       FROM_HERE, base::MessageLoop::QuitClosure(),
534       base::TimeDelta::FromMilliseconds(10));
535   base::MessageLoop::current()->Run();
536 }
537
538 // This test executes ExportKeyingMaterial() on the client and server sockets,
539 // after connecting them, and verifies that the results match.
540 // This test will fail if False Start is enabled (see crbug.com/90208).
541 TEST_F(SSLServerSocketTest, ExportKeyingMaterial) {
542   Initialize();
543
544   TestCompletionCallback connect_callback;
545   TestCompletionCallback handshake_callback;
546
547   int client_ret = client_socket_->Connect(connect_callback.callback());
548   ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING);
549
550   int server_ret = server_socket_->Handshake(handshake_callback.callback());
551   ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING);
552
553   if (client_ret == net::ERR_IO_PENDING) {
554     ASSERT_EQ(net::OK, connect_callback.WaitForResult());
555   }
556   if (server_ret == net::ERR_IO_PENDING) {
557     ASSERT_EQ(net::OK, handshake_callback.WaitForResult());
558   }
559
560   const int kKeyingMaterialSize = 32;
561   const char* kKeyingLabel = "EXPERIMENTAL-server-socket-test";
562   const char* kKeyingContext = "";
563   unsigned char server_out[kKeyingMaterialSize];
564   int rv = server_socket_->ExportKeyingMaterial(kKeyingLabel,
565                                                 false, kKeyingContext,
566                                                 server_out, sizeof(server_out));
567   ASSERT_EQ(net::OK, rv);
568
569   unsigned char client_out[kKeyingMaterialSize];
570   rv = client_socket_->ExportKeyingMaterial(kKeyingLabel,
571                                             false, kKeyingContext,
572                                             client_out, sizeof(client_out));
573   ASSERT_EQ(net::OK, rv);
574   EXPECT_EQ(0, memcmp(server_out, client_out, sizeof(server_out)));
575
576   const char* kKeyingLabelBad = "EXPERIMENTAL-server-socket-test-bad";
577   unsigned char client_bad[kKeyingMaterialSize];
578   rv = client_socket_->ExportKeyingMaterial(kKeyingLabelBad,
579                                             false, kKeyingContext,
580                                             client_bad, sizeof(client_bad));
581   ASSERT_EQ(rv, net::OK);
582   EXPECT_NE(0, memcmp(server_out, client_bad, sizeof(server_out)));
583 }
584 #endif
585
586 }  // namespace net