Upstream version 9.37.197.0
[platform/framework/web/crosswalk.git] / src / third_party / libjingle / source / talk / base / nssstreamadapter.cc
1 /*
2  * libjingle
3  * Copyright 2004--2008, Google Inc.
4  * Copyright 2004--2011, RTFM, Inc.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are met:
8  *
9  *  1. Redistributions of source code must retain the above copyright notice,
10  *     this list of conditions and the following disclaimer.
11  *  2. Redistributions in binary form must reproduce the above copyright notice,
12  *     this list of conditions and the following disclaimer in the documentation
13  *     and/or other materials provided with the distribution.
14  *  3. The name of the author may not be used to endorse or promote products
15  *     derived from this software without specific prior written permission.
16  *
17  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
18  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
19  * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
20  * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
21  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
23  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
24  * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
25  * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
26  * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27  */
28
29 #include <vector>
30
31 #if HAVE_CONFIG_H
32 #include "config.h"
33 #endif  // HAVE_CONFIG_H
34
35 #if HAVE_NSS_SSL_H
36
37 #include "talk/base/nssstreamadapter.h"
38
39 #include "keyhi.h"
40 #include "nspr.h"
41 #include "nss.h"
42 #include "pk11pub.h"
43 #include "secerr.h"
44
45 #ifdef NSS_SSL_RELATIVE_PATH
46 #include "ssl.h"
47 #include "sslerr.h"
48 #include "sslproto.h"
49 #else
50 #include "net/third_party/nss/ssl/ssl.h"
51 #include "net/third_party/nss/ssl/sslerr.h"
52 #include "net/third_party/nss/ssl/sslproto.h"
53 #endif
54
55 #include "talk/base/nssidentity.h"
56 #include "talk/base/safe_conversions.h"
57 #include "talk/base/thread.h"
58
59 namespace talk_base {
60
61 PRDescIdentity NSSStreamAdapter::nspr_layer_identity = PR_INVALID_IO_LAYER;
62
63 #define UNIMPLEMENTED \
64   PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0); \
65   LOG(LS_ERROR) \
66   << "Call to unimplemented function "<< __FUNCTION__; ASSERT(false)
67
68 #ifdef SRTP_AES128_CM_HMAC_SHA1_80
69 #define HAVE_DTLS_SRTP
70 #endif
71
72 #ifdef HAVE_DTLS_SRTP
73 // SRTP cipher suite table
74 struct SrtpCipherMapEntry {
75   const char* external_name;
76   PRUint16 cipher_id;
77 };
78
79 // This isn't elegant, but it's better than an external reference
80 static const SrtpCipherMapEntry kSrtpCipherMap[] = {
81   {"AES_CM_128_HMAC_SHA1_80", SRTP_AES128_CM_HMAC_SHA1_80 },
82   {"AES_CM_128_HMAC_SHA1_32", SRTP_AES128_CM_HMAC_SHA1_32 },
83   {NULL, 0}
84 };
85 #endif
86
87
88 // Implementation of NSPR methods
89 static PRStatus StreamClose(PRFileDesc *socket) {
90   ASSERT(!socket->lower);
91   socket->dtor(socket);
92   return PR_SUCCESS;
93 }
94
95 static PRInt32 StreamRead(PRFileDesc *socket, void *buf, PRInt32 length) {
96   StreamInterface *stream = reinterpret_cast<StreamInterface *>(socket->secret);
97   size_t read;
98   int error;
99   StreamResult result = stream->Read(buf, length, &read, &error);
100   if (result == SR_SUCCESS) {
101     return checked_cast<PRInt32>(read);
102   }
103
104   if (result == SR_EOS) {
105     return 0;
106   }
107
108   if (result == SR_BLOCK) {
109     PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
110     return -1;
111   }
112
113   PR_SetError(PR_UNKNOWN_ERROR, error);
114   return -1;
115 }
116
117 static PRInt32 StreamWrite(PRFileDesc *socket, const void *buf,
118                            PRInt32 length) {
119   StreamInterface *stream = reinterpret_cast<StreamInterface *>(socket->secret);
120   size_t written;
121   int error;
122   StreamResult result = stream->Write(buf, length, &written, &error);
123   if (result == SR_SUCCESS) {
124     return checked_cast<PRInt32>(written);
125   }
126
127   if (result == SR_BLOCK) {
128     LOG(LS_INFO) <<
129         "NSSStreamAdapter: write to underlying transport would block";
130     PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
131     return -1;
132   }
133
134   LOG(LS_ERROR) << "Write error";
135   PR_SetError(PR_UNKNOWN_ERROR, error);
136   return -1;
137 }
138
139 static PRInt32 StreamAvailable(PRFileDesc *socket) {
140   UNIMPLEMENTED;
141   return -1;
142 }
143
144 PRInt64 StreamAvailable64(PRFileDesc *socket) {
145   UNIMPLEMENTED;
146   return -1;
147 }
148
149 static PRStatus StreamSync(PRFileDesc *socket) {
150   UNIMPLEMENTED;
151   return PR_FAILURE;
152 }
153
154 static PROffset32 StreamSeek(PRFileDesc *socket, PROffset32 offset,
155                              PRSeekWhence how) {
156   UNIMPLEMENTED;
157   return -1;
158 }
159
160 static PROffset64 StreamSeek64(PRFileDesc *socket, PROffset64 offset,
161                                PRSeekWhence how) {
162   UNIMPLEMENTED;
163   return -1;
164 }
165
166 static PRStatus StreamFileInfo(PRFileDesc *socket, PRFileInfo *info) {
167   UNIMPLEMENTED;
168   return PR_FAILURE;
169 }
170
171 static PRStatus StreamFileInfo64(PRFileDesc *socket, PRFileInfo64 *info) {
172   UNIMPLEMENTED;
173   return PR_FAILURE;
174 }
175
176 static PRInt32 StreamWritev(PRFileDesc *socket, const PRIOVec *iov,
177                      PRInt32 iov_size, PRIntervalTime timeout) {
178   UNIMPLEMENTED;
179   return -1;
180 }
181
182 static PRStatus StreamConnect(PRFileDesc *socket, const PRNetAddr *addr,
183                               PRIntervalTime timeout) {
184   UNIMPLEMENTED;
185   return PR_FAILURE;
186 }
187
188 static PRFileDesc *StreamAccept(PRFileDesc *sd, PRNetAddr *addr,
189                                 PRIntervalTime timeout) {
190   UNIMPLEMENTED;
191   return NULL;
192 }
193
194 static PRStatus StreamBind(PRFileDesc *socket, const PRNetAddr *addr) {
195   UNIMPLEMENTED;
196   return PR_FAILURE;
197 }
198
199 static PRStatus StreamListen(PRFileDesc *socket, PRIntn depth) {
200   UNIMPLEMENTED;
201   return PR_FAILURE;
202 }
203
204 static PRStatus StreamShutdown(PRFileDesc *socket, PRIntn how) {
205   UNIMPLEMENTED;
206   return PR_FAILURE;
207 }
208
209 // Note: this is always nonblocking and ignores the timeout.
210 // TODO(ekr@rtfm.com): In future verify that the socket is
211 // actually in non-blocking mode.
212 // This function does not support peek.
213 static PRInt32 StreamRecv(PRFileDesc *socket, void *buf, PRInt32 amount,
214                    PRIntn flags, PRIntervalTime to) {
215   ASSERT(flags == 0);
216
217   if (flags != 0) {
218     PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0);
219     return -1;
220   }
221
222   return StreamRead(socket, buf, amount);
223 }
224
225 // Note: this is always nonblocking and assumes a zero timeout.
226 // This function does not support peek.
227 static PRInt32 StreamSend(PRFileDesc *socket, const void *buf,
228                           PRInt32 amount, PRIntn flags,
229                           PRIntervalTime to) {
230   ASSERT(flags == 0);
231
232   return StreamWrite(socket, buf, amount);
233 }
234
235 static PRInt32 StreamRecvfrom(PRFileDesc *socket, void *buf,
236                               PRInt32 amount, PRIntn flags,
237                               PRNetAddr *addr, PRIntervalTime to) {
238   UNIMPLEMENTED;
239   return -1;
240 }
241
242 static PRInt32 StreamSendto(PRFileDesc *socket, const void *buf,
243                             PRInt32 amount, PRIntn flags,
244                             const PRNetAddr *addr, PRIntervalTime to) {
245   UNIMPLEMENTED;
246   return -1;
247 }
248
249 static PRInt16 StreamPoll(PRFileDesc *socket, PRInt16 in_flags,
250                           PRInt16 *out_flags) {
251   UNIMPLEMENTED;
252   return -1;
253 }
254
255 static PRInt32 StreamAcceptRead(PRFileDesc *sd, PRFileDesc **nd,
256                                 PRNetAddr **raddr,
257                                 void *buf, PRInt32 amount, PRIntervalTime t) {
258   UNIMPLEMENTED;
259   return -1;
260 }
261
262 static PRInt32 StreamTransmitFile(PRFileDesc *sd, PRFileDesc *socket,
263                                   const void *headers, PRInt32 hlen,
264                                   PRTransmitFileFlags flags, PRIntervalTime t) {
265   UNIMPLEMENTED;
266   return -1;
267 }
268
269 static PRStatus StreamGetPeerName(PRFileDesc *socket, PRNetAddr *addr) {
270   // TODO(ekr@rtfm.com): Modify to return unique names for each channel
271   // somehow, as opposed to always the same static address. The current
272   // implementation messes up the session cache, which is why it's off
273   // elsewhere
274   addr->inet.family = PR_AF_INET;
275   addr->inet.port = 0;
276   addr->inet.ip = 0;
277
278   return PR_SUCCESS;
279 }
280
281 static PRStatus StreamGetSockName(PRFileDesc *socket, PRNetAddr *addr) {
282   UNIMPLEMENTED;
283   return PR_FAILURE;
284 }
285
286 static PRStatus StreamGetSockOption(PRFileDesc *socket, PRSocketOptionData *opt) {
287   switch (opt->option) {
288     case PR_SockOpt_Nonblocking:
289       opt->value.non_blocking = PR_TRUE;
290       return PR_SUCCESS;
291     default:
292       UNIMPLEMENTED;
293       break;
294   }
295
296   return PR_FAILURE;
297 }
298
299 // Imitate setting socket options. These are mostly noops.
300 static PRStatus StreamSetSockOption(PRFileDesc *socket,
301                                     const PRSocketOptionData *opt) {
302   switch (opt->option) {
303     case PR_SockOpt_Nonblocking:
304       return PR_SUCCESS;
305     case PR_SockOpt_NoDelay:
306       return PR_SUCCESS;
307     default:
308       UNIMPLEMENTED;
309       break;
310   }
311
312   return PR_FAILURE;
313 }
314
315 static PRInt32 StreamSendfile(PRFileDesc *out, PRSendFileData *in,
316                               PRTransmitFileFlags flags, PRIntervalTime to) {
317   UNIMPLEMENTED;
318   return -1;
319 }
320
321 static PRStatus StreamConnectContinue(PRFileDesc *socket, PRInt16 flags) {
322   UNIMPLEMENTED;
323   return PR_FAILURE;
324 }
325
326 static PRIntn StreamReserved(PRFileDesc *socket) {
327   UNIMPLEMENTED;
328   return -1;
329 }
330
331 static const struct PRIOMethods nss_methods = {
332   PR_DESC_LAYERED,
333   StreamClose,
334   StreamRead,
335   StreamWrite,
336   StreamAvailable,
337   StreamAvailable64,
338   StreamSync,
339   StreamSeek,
340   StreamSeek64,
341   StreamFileInfo,
342   StreamFileInfo64,
343   StreamWritev,
344   StreamConnect,
345   StreamAccept,
346   StreamBind,
347   StreamListen,
348   StreamShutdown,
349   StreamRecv,
350   StreamSend,
351   StreamRecvfrom,
352   StreamSendto,
353   StreamPoll,
354   StreamAcceptRead,
355   StreamTransmitFile,
356   StreamGetSockName,
357   StreamGetPeerName,
358   StreamReserved,
359   StreamReserved,
360   StreamGetSockOption,
361   StreamSetSockOption,
362   StreamSendfile,
363   StreamConnectContinue,
364   StreamReserved,
365   StreamReserved,
366   StreamReserved,
367   StreamReserved
368 };
369
370 NSSStreamAdapter::NSSStreamAdapter(StreamInterface *stream)
371     : SSLStreamAdapterHelper(stream),
372       ssl_fd_(NULL),
373       cert_ok_(false) {
374 }
375
376 bool NSSStreamAdapter::Init() {
377   if (nspr_layer_identity == PR_INVALID_IO_LAYER) {
378     nspr_layer_identity = PR_GetUniqueIdentity("nssstreamadapter");
379   }
380   PRFileDesc *pr_fd = PR_CreateIOLayerStub(nspr_layer_identity, &nss_methods);
381   if (!pr_fd)
382     return false;
383   pr_fd->secret = reinterpret_cast<PRFilePrivate *>(stream());
384
385   PRFileDesc *ssl_fd;
386   if (ssl_mode_ == SSL_MODE_DTLS) {
387     ssl_fd = DTLS_ImportFD(NULL, pr_fd);
388   } else {
389     ssl_fd = SSL_ImportFD(NULL, pr_fd);
390   }
391   ASSERT(ssl_fd != NULL);  // This should never happen
392   if (!ssl_fd) {
393     PR_Close(pr_fd);
394     return false;
395   }
396
397   SECStatus rv;
398   // Turn on security.
399   rv = SSL_OptionSet(ssl_fd, SSL_SECURITY, PR_TRUE);
400   if (rv != SECSuccess) {
401     LOG(LS_ERROR) << "Error enabling security on SSL Socket";
402     return false;
403   }
404
405   // Disable SSLv2.
406   rv = SSL_OptionSet(ssl_fd, SSL_ENABLE_SSL2, PR_FALSE);
407   if (rv != SECSuccess) {
408     LOG(LS_ERROR) << "Error disabling SSL2";
409     return false;
410   }
411
412   // Disable caching.
413   // TODO(ekr@rtfm.com): restore this when I have the caching
414   // identity set.
415   rv = SSL_OptionSet(ssl_fd, SSL_NO_CACHE, PR_TRUE);
416   if (rv != SECSuccess) {
417     LOG(LS_ERROR) << "Error disabling cache";
418     return false;
419   }
420
421   // Disable session tickets.
422   rv = SSL_OptionSet(ssl_fd, SSL_ENABLE_SESSION_TICKETS, PR_FALSE);
423   if (rv != SECSuccess) {
424     LOG(LS_ERROR) << "Error enabling tickets";
425     return false;
426   }
427
428   // Disable renegotiation.
429   rv = SSL_OptionSet(ssl_fd, SSL_ENABLE_RENEGOTIATION,
430                      SSL_RENEGOTIATE_NEVER);
431   if (rv != SECSuccess) {
432     LOG(LS_ERROR) << "Error disabling renegotiation";
433     return false;
434   }
435
436   // Disable false start.
437   rv = SSL_OptionSet(ssl_fd, SSL_ENABLE_FALSE_START, PR_FALSE);
438   if (rv != SECSuccess) {
439     LOG(LS_ERROR) << "Error disabling false start";
440     return false;
441   }
442
443   ssl_fd_ = ssl_fd;
444
445   return true;
446 }
447
448 NSSStreamAdapter::~NSSStreamAdapter() {
449   if (ssl_fd_)
450     PR_Close(ssl_fd_);
451 };
452
453
454 int NSSStreamAdapter::BeginSSL() {
455   SECStatus rv;
456
457   if (!Init()) {
458     Error("Init", -1, false);
459     return -1;
460   }
461
462   ASSERT(state_ == SSL_CONNECTING);
463   // The underlying stream has been opened. If we are in peer-to-peer mode
464   // then a peer certificate must have been specified by now.
465   ASSERT(!ssl_server_name_.empty() ||
466          peer_certificate_.get() != NULL ||
467          !peer_certificate_digest_algorithm_.empty());
468   LOG(LS_INFO) << "BeginSSL: "
469                << (!ssl_server_name_.empty() ? ssl_server_name_ :
470                                                "with peer");
471
472   if (role_ == SSL_CLIENT) {
473     LOG(LS_INFO) << "BeginSSL: as client";
474
475     rv = SSL_GetClientAuthDataHook(ssl_fd_, GetClientAuthDataHook,
476                                    this);
477     if (rv != SECSuccess) {
478       Error("BeginSSL", -1, false);
479       return -1;
480     }
481   } else {
482     LOG(LS_INFO) << "BeginSSL: as server";
483     NSSIdentity *identity;
484
485     if (identity_.get()) {
486       identity = static_cast<NSSIdentity *>(identity_.get());
487     } else {
488       LOG(LS_ERROR) << "Can't be an SSL server without an identity";
489       Error("BeginSSL", -1, false);
490       return -1;
491     }
492     rv = SSL_ConfigSecureServer(ssl_fd_, identity->certificate().certificate(),
493                                 identity->keypair()->privkey(),
494                                 kt_rsa);
495     if (rv != SECSuccess) {
496       Error("BeginSSL", -1, false);
497       return -1;
498     }
499
500     // Insist on a certificate from the client
501     rv = SSL_OptionSet(ssl_fd_, SSL_REQUEST_CERTIFICATE, PR_TRUE);
502     if (rv != SECSuccess) {
503       Error("BeginSSL", -1, false);
504       return -1;
505     }
506
507     rv = SSL_OptionSet(ssl_fd_, SSL_REQUIRE_CERTIFICATE, PR_TRUE);
508     if (rv != SECSuccess) {
509       Error("BeginSSL", -1, false);
510       return -1;
511     }
512   }
513
514   // Set the version range.
515   SSLVersionRange vrange;
516   vrange.min =  (ssl_mode_ == SSL_MODE_DTLS) ?
517       SSL_LIBRARY_VERSION_TLS_1_1 :
518       SSL_LIBRARY_VERSION_TLS_1_0;
519   vrange.max = SSL_LIBRARY_VERSION_TLS_1_1;
520
521   rv = SSL_VersionRangeSet(ssl_fd_, &vrange);
522   if (rv != SECSuccess) {
523     Error("BeginSSL", -1, false);
524     return -1;
525   }
526
527   // SRTP
528 #ifdef HAVE_DTLS_SRTP
529   if (!srtp_ciphers_.empty()) {
530     rv = SSL_SetSRTPCiphers(
531         ssl_fd_, &srtp_ciphers_[0],
532         checked_cast<unsigned int>(srtp_ciphers_.size()));
533     if (rv != SECSuccess) {
534       Error("BeginSSL", -1, false);
535       return -1;
536     }
537   }
538 #endif
539
540   // Certificate validation
541   rv = SSL_AuthCertificateHook(ssl_fd_, AuthCertificateHook, this);
542   if (rv != SECSuccess) {
543     Error("BeginSSL", -1, false);
544     return -1;
545   }
546
547   // Now start the handshake
548   rv = SSL_ResetHandshake(ssl_fd_, role_ == SSL_SERVER ? PR_TRUE : PR_FALSE);
549   if (rv != SECSuccess) {
550     Error("BeginSSL", -1, false);
551     return -1;
552   }
553
554   return ContinueSSL();
555 }
556
557 int NSSStreamAdapter::ContinueSSL() {
558   LOG(LS_INFO) << "ContinueSSL";
559   ASSERT(state_ == SSL_CONNECTING);
560
561   // Clear the DTLS timer
562   Thread::Current()->Clear(this, MSG_DTLS_TIMEOUT);
563
564   SECStatus rv = SSL_ForceHandshake(ssl_fd_);
565
566   if (rv == SECSuccess) {
567     LOG(LS_INFO) << "Handshake complete";
568
569     ASSERT(cert_ok_);
570     if (!cert_ok_) {
571       Error("ContinueSSL", -1, true);
572       return -1;
573     }
574
575     state_ = SSL_CONNECTED;
576     StreamAdapterInterface::OnEvent(stream(), SE_OPEN|SE_READ|SE_WRITE, 0);
577     return 0;
578   }
579
580   PRInt32 err = PR_GetError();
581   switch (err) {
582     case SSL_ERROR_RX_MALFORMED_HANDSHAKE:
583       if (ssl_mode_ != SSL_MODE_DTLS) {
584         Error("ContinueSSL", -1, true);
585         return -1;
586       } else {
587         LOG(LS_INFO) << "Malformed DTLS message. Ignoring.";
588         // Fall through
589       }
590     case PR_WOULD_BLOCK_ERROR:
591       LOG(LS_INFO) << "Would have blocked";
592       if (ssl_mode_ == SSL_MODE_DTLS) {
593         PRIntervalTime timeout;
594
595         SECStatus rv = DTLS_GetHandshakeTimeout(ssl_fd_, &timeout);
596         if (rv == SECSuccess) {
597           LOG(LS_INFO) << "Timeout is " << timeout << " ms";
598           Thread::Current()->PostDelayed(PR_IntervalToMilliseconds(timeout),
599                                          this, MSG_DTLS_TIMEOUT, 0);
600         }
601       }
602
603       return 0;
604     default:
605       LOG(LS_INFO) << "Error " << err;
606       break;
607   }
608
609   Error("ContinueSSL", -1, true);
610   return -1;
611 }
612
613 void NSSStreamAdapter::Cleanup() {
614   if (state_ != SSL_ERROR) {
615     state_ = SSL_CLOSED;
616   }
617
618   if (ssl_fd_) {
619     PR_Close(ssl_fd_);
620     ssl_fd_ = NULL;
621   }
622
623   identity_.reset();
624   peer_certificate_.reset();
625
626   Thread::Current()->Clear(this, MSG_DTLS_TIMEOUT);
627 }
628
629 StreamResult NSSStreamAdapter::Read(void* data, size_t data_len,
630                                     size_t* read, int* error) {
631   // SSL_CONNECTED sanity check.
632   switch (state_) {
633     case SSL_NONE:
634     case SSL_WAIT:
635     case SSL_CONNECTING:
636       return SR_BLOCK;
637
638     case SSL_CONNECTED:
639       break;
640
641     case SSL_CLOSED:
642       return SR_EOS;
643
644     case SSL_ERROR:
645     default:
646       if (error)
647         *error = ssl_error_code_;
648       return SR_ERROR;
649   }
650
651   PRInt32 rv = PR_Read(ssl_fd_, data, checked_cast<PRInt32>(data_len));
652
653   if (rv == 0) {
654     return SR_EOS;
655   }
656
657   // Error
658   if (rv < 0) {
659     PRInt32 err = PR_GetError();
660
661     switch (err) {
662       case PR_WOULD_BLOCK_ERROR:
663         return SR_BLOCK;
664       default:
665         Error("Read", -1, false);
666         *error = err;  // libjingle semantics are that this is impl-specific
667         return SR_ERROR;
668     }
669   }
670
671   // Success
672   *read = rv;
673
674   return SR_SUCCESS;
675 }
676
677 StreamResult NSSStreamAdapter::Write(const void* data, size_t data_len,
678                                      size_t* written, int* error) {
679   // SSL_CONNECTED sanity check.
680   switch (state_) {
681     case SSL_NONE:
682     case SSL_WAIT:
683     case SSL_CONNECTING:
684       return SR_BLOCK;
685
686     case SSL_CONNECTED:
687       break;
688
689     case SSL_ERROR:
690     case SSL_CLOSED:
691     default:
692       if (error)
693         *error = ssl_error_code_;
694       return SR_ERROR;
695   }
696
697   PRInt32 rv = PR_Write(ssl_fd_, data, checked_cast<PRInt32>(data_len));
698
699   // Error
700   if (rv < 0) {
701     PRInt32 err = PR_GetError();
702
703     switch (err) {
704       case PR_WOULD_BLOCK_ERROR:
705         return SR_BLOCK;
706       default:
707         Error("Write", -1, false);
708         *error = err;  // libjingle semantics are that this is impl-specific
709         return SR_ERROR;
710     }
711   }
712
713   // Success
714   *written = rv;
715
716   return SR_SUCCESS;
717 }
718
719 void NSSStreamAdapter::OnEvent(StreamInterface* stream, int events,
720                                int err) {
721   int events_to_signal = 0;
722   int signal_error = 0;
723   ASSERT(stream == this->stream());
724   if ((events & SE_OPEN)) {
725     LOG(LS_INFO) << "NSSStreamAdapter::OnEvent SE_OPEN";
726     if (state_ != SSL_WAIT) {
727       ASSERT(state_ == SSL_NONE);
728       events_to_signal |= SE_OPEN;
729     } else {
730       state_ = SSL_CONNECTING;
731       if (int err = BeginSSL()) {
732         Error("BeginSSL", err, true);
733         return;
734       }
735     }
736   }
737   if ((events & (SE_READ|SE_WRITE))) {
738     LOG(LS_INFO) << "NSSStreamAdapter::OnEvent"
739                  << ((events & SE_READ) ? " SE_READ" : "")
740                  << ((events & SE_WRITE) ? " SE_WRITE" : "");
741     if (state_ == SSL_NONE) {
742       events_to_signal |= events & (SE_READ|SE_WRITE);
743     } else if (state_ == SSL_CONNECTING) {
744       if (int err = ContinueSSL()) {
745         Error("ContinueSSL", err, true);
746         return;
747       }
748     } else if (state_ == SSL_CONNECTED) {
749       if (events & SE_WRITE) {
750         LOG(LS_INFO) << " -- onStreamWriteable";
751         events_to_signal |= SE_WRITE;
752       }
753       if (events & SE_READ) {
754         LOG(LS_INFO) << " -- onStreamReadable";
755         events_to_signal |= SE_READ;
756       }
757     }
758   }
759   if ((events & SE_CLOSE)) {
760     LOG(LS_INFO) << "NSSStreamAdapter::OnEvent(SE_CLOSE, " << err << ")";
761     Cleanup();
762     events_to_signal |= SE_CLOSE;
763     // SE_CLOSE is the only event that uses the final parameter to OnEvent().
764     ASSERT(signal_error == 0);
765     signal_error = err;
766   }
767   if (events_to_signal)
768     StreamAdapterInterface::OnEvent(stream, events_to_signal, signal_error);
769 }
770
771 void NSSStreamAdapter::OnMessage(Message* msg) {
772   // Process our own messages and then pass others to the superclass
773   if (MSG_DTLS_TIMEOUT == msg->message_id) {
774     LOG(LS_INFO) << "DTLS timeout expired";
775     ContinueSSL();
776   } else {
777     StreamInterface::OnMessage(msg);
778   }
779 }
780
781 // Certificate verification callback. Called to check any certificate
782 SECStatus NSSStreamAdapter::AuthCertificateHook(void *arg,
783                                                 PRFileDesc *fd,
784                                                 PRBool checksig,
785                                                 PRBool isServer) {
786   LOG(LS_INFO) << "NSSStreamAdapter::AuthCertificateHook";
787   // SSL_PeerCertificate returns a pointer that is owned by the caller, and
788   // the NSSCertificate constructor copies its argument, so |raw_peer_cert|
789   // must be destroyed in this function.
790   CERTCertificate* raw_peer_cert = SSL_PeerCertificate(fd);
791   NSSCertificate peer_cert(raw_peer_cert);
792   CERT_DestroyCertificate(raw_peer_cert);
793
794   NSSStreamAdapter *stream = reinterpret_cast<NSSStreamAdapter *>(arg);
795   stream->cert_ok_ = false;
796
797   // Read the peer's certificate chain.
798   CERTCertList* cert_list = SSL_PeerCertificateChain(fd);
799   ASSERT(cert_list != NULL);
800
801   // If the peer provided multiple certificates, check that they form a valid
802   // chain as defined by RFC 5246 Section 7.4.2: "Each following certificate
803   // MUST directly certify the one preceding it.".  This check does NOT
804   // verify other requirements, such as whether the chain reaches a trusted
805   // root, self-signed certificates have valid signatures, certificates are not
806   // expired, etc.
807   // Even if the chain is valid, the leaf certificate must still match a
808   // provided certificate or digest.
809   if (!NSSCertificate::IsValidChain(cert_list)) {
810     CERT_DestroyCertList(cert_list);
811     PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
812     return SECFailure;
813   }
814
815   if (stream->peer_certificate_.get()) {
816     LOG(LS_INFO) << "Checking against specified certificate";
817
818     // The peer certificate was specified
819     if (reinterpret_cast<NSSCertificate *>(stream->peer_certificate_.get())->
820         Equals(&peer_cert)) {
821       LOG(LS_INFO) << "Accepted peer certificate";
822       stream->cert_ok_ = true;
823     }
824   } else if (!stream->peer_certificate_digest_algorithm_.empty()) {
825     LOG(LS_INFO) << "Checking against specified digest";
826     // The peer certificate digest was specified
827     unsigned char digest[64];  // Maximum size
828     size_t digest_length;
829
830     if (!peer_cert.ComputeDigest(
831             stream->peer_certificate_digest_algorithm_,
832             digest, sizeof(digest), &digest_length)) {
833       LOG(LS_ERROR) << "Digest computation failed";
834     } else {
835       Buffer computed_digest(digest, digest_length);
836       if (computed_digest == stream->peer_certificate_digest_value_) {
837         LOG(LS_INFO) << "Accepted peer certificate";
838         stream->cert_ok_ = true;
839       }
840     }
841   } else {
842     // Other modes, but we haven't implemented yet
843     // TODO(ekr@rtfm.com): Implement real certificate validation
844     UNIMPLEMENTED;
845   }
846
847   if (!stream->cert_ok_ && stream->ignore_bad_cert()) {
848     LOG(LS_WARNING) << "Ignoring cert error while verifying cert chain";
849     stream->cert_ok_ = true;
850   }
851
852   if (stream->cert_ok_)
853     stream->peer_certificate_.reset(new NSSCertificate(cert_list));
854
855   CERT_DestroyCertList(cert_list);
856
857   if (stream->cert_ok_)
858     return SECSuccess;
859
860   PORT_SetError(SEC_ERROR_UNTRUSTED_CERT);
861   return SECFailure;
862 }
863
864
865 SECStatus NSSStreamAdapter::GetClientAuthDataHook(void *arg, PRFileDesc *fd,
866                                                   CERTDistNames *caNames,
867                                                   CERTCertificate **pRetCert,
868                                                   SECKEYPrivateKey **pRetKey) {
869   LOG(LS_INFO) << "Client cert requested";
870   NSSStreamAdapter *stream = reinterpret_cast<NSSStreamAdapter *>(arg);
871
872   if (!stream->identity_.get()) {
873     LOG(LS_ERROR) << "No identity available";
874     return SECFailure;
875   }
876
877   NSSIdentity *identity = static_cast<NSSIdentity *>(stream->identity_.get());
878   // Destroyed internally by NSS
879   *pRetCert = CERT_DupCertificate(identity->certificate().certificate());
880   *pRetKey = SECKEY_CopyPrivateKey(identity->keypair()->privkey());
881
882   return SECSuccess;
883 }
884
885 // RFC 5705 Key Exporter
886 bool NSSStreamAdapter::ExportKeyingMaterial(const std::string& label,
887                                             const uint8* context,
888                                             size_t context_len,
889                                             bool use_context,
890                                             uint8* result,
891                                             size_t result_len) {
892   SECStatus rv = SSL_ExportKeyingMaterial(
893       ssl_fd_,
894       label.c_str(),
895       checked_cast<unsigned int>(label.size()),
896       use_context,
897       context,
898       checked_cast<unsigned int>(context_len),
899       result,
900       checked_cast<unsigned int>(result_len));
901
902   return rv == SECSuccess;
903 }
904
905 bool NSSStreamAdapter::SetDtlsSrtpCiphers(
906     const std::vector<std::string>& ciphers) {
907 #ifdef HAVE_DTLS_SRTP
908   std::vector<PRUint16> internal_ciphers;
909   if (state_ != SSL_NONE)
910     return false;
911
912   for (std::vector<std::string>::const_iterator cipher = ciphers.begin();
913        cipher != ciphers.end(); ++cipher) {
914     bool found = false;
915     for (const SrtpCipherMapEntry *entry = kSrtpCipherMap; entry->cipher_id;
916          ++entry) {
917       if (*cipher == entry->external_name) {
918         found = true;
919         internal_ciphers.push_back(entry->cipher_id);
920         break;
921       }
922     }
923
924     if (!found) {
925       LOG(LS_ERROR) << "Could not find cipher: " << *cipher;
926       return false;
927     }
928   }
929
930   if (internal_ciphers.empty())
931     return false;
932
933   srtp_ciphers_ = internal_ciphers;
934
935   return true;
936 #else
937   return false;
938 #endif
939 }
940
941 bool NSSStreamAdapter::GetDtlsSrtpCipher(std::string* cipher) {
942 #ifdef HAVE_DTLS_SRTP
943   ASSERT(state_ == SSL_CONNECTED);
944   if (state_ != SSL_CONNECTED)
945     return false;
946
947   PRUint16 selected_cipher;
948
949   SECStatus rv = SSL_GetSRTPCipher(ssl_fd_, &selected_cipher);
950   if (rv == SECFailure)
951     return false;
952
953   for (const SrtpCipherMapEntry *entry = kSrtpCipherMap;
954        entry->cipher_id; ++entry) {
955     if (selected_cipher == entry->cipher_id) {
956       *cipher = entry->external_name;
957       return true;
958     }
959   }
960
961   ASSERT(false);  // This should never happen
962 #endif
963   return false;
964 }
965
966
967 bool NSSContext::initialized;
968 NSSContext *NSSContext::global_nss_context;
969
970 // Static initialization and shutdown
971 NSSContext *NSSContext::Instance() {
972   if (!global_nss_context) {
973     NSSContext *new_ctx = new NSSContext();
974
975     if (!(new_ctx->slot_ = PK11_GetInternalSlot())) {
976       delete new_ctx;
977       goto fail;
978     }
979
980     global_nss_context = new_ctx;
981   }
982
983  fail:
984   return global_nss_context;
985 }
986
987
988
989 bool NSSContext::InitializeSSL(VerificationCallback callback) {
990   ASSERT(!callback);
991
992   if (!initialized) {
993     SECStatus rv;
994
995     rv = NSS_NoDB_Init(NULL);
996     if (rv != SECSuccess) {
997       LOG(LS_ERROR) << "Couldn't initialize NSS error=" <<
998           PORT_GetError();
999       return false;
1000     }
1001
1002     NSS_SetDomesticPolicy();
1003
1004     initialized = true;
1005   }
1006
1007   return true;
1008 }
1009
1010 bool NSSContext::InitializeSSLThread() {
1011   // Not needed
1012   return true;
1013 }
1014
1015 bool NSSContext::CleanupSSL() {
1016   // Not needed
1017   return true;
1018 }
1019
1020 bool NSSStreamAdapter::HaveDtls() {
1021   return true;
1022 }
1023
1024 bool NSSStreamAdapter::HaveDtlsSrtp() {
1025 #ifdef HAVE_DTLS_SRTP
1026   return true;
1027 #else
1028   return false;
1029 #endif
1030 }
1031
1032 bool NSSStreamAdapter::HaveExporter() {
1033   return true;
1034 }
1035
1036 }  // namespace talk_base
1037
1038 #endif  // HAVE_NSS_SSL_H