Upstream version 5.34.104.0
[platform/framework/web/crosswalk.git] / src / third_party / libjingle / source / talk / base / nssstreamadapter.cc
1 /*
2  * libjingle
3  * Copyright 2004--2008, Google Inc.
4  * Copyright 2004--2011, RTFM, Inc.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are met:
8  *
9  *  1. Redistributions of source code must retain the above copyright notice,
10  *     this list of conditions and the following disclaimer.
11  *  2. Redistributions in binary form must reproduce the above copyright notice,
12  *     this list of conditions and the following disclaimer in the documentation
13  *     and/or other materials provided with the distribution.
14  *  3. The name of the author may not be used to endorse or promote products
15  *     derived from this software without specific prior written permission.
16  *
17  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
18  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
19  * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
20  * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
21  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
23  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
24  * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
25  * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
26  * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27  */
28
29 #include <vector>
30
31 #if HAVE_CONFIG_H
32 #include "config.h"
33 #endif  // HAVE_CONFIG_H
34
35 #if HAVE_NSS_SSL_H
36
37 #include "talk/base/nssstreamadapter.h"
38
39 #include "keyhi.h"
40 #include "nspr.h"
41 #include "nss.h"
42 #include "pk11pub.h"
43 #include "secerr.h"
44
45 #ifdef NSS_SSL_RELATIVE_PATH
46 #include "ssl.h"
47 #include "sslerr.h"
48 #include "sslproto.h"
49 #else
50 #include "net/third_party/nss/ssl/ssl.h"
51 #include "net/third_party/nss/ssl/sslerr.h"
52 #include "net/third_party/nss/ssl/sslproto.h"
53 #endif
54
55 #include "talk/base/nssidentity.h"
56 #include "talk/base/safe_conversions.h"
57 #include "talk/base/thread.h"
58
59 namespace talk_base {
60
61 PRDescIdentity NSSStreamAdapter::nspr_layer_identity = PR_INVALID_IO_LAYER;
62
63 #define UNIMPLEMENTED \
64   PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0); \
65   LOG(LS_ERROR) \
66   << "Call to unimplemented function "<< __FUNCTION__; ASSERT(false)
67
68 #ifdef SRTP_AES128_CM_HMAC_SHA1_80
69 #define HAVE_DTLS_SRTP
70 #endif
71
72 #ifdef HAVE_DTLS_SRTP
73 // SRTP cipher suite table
74 struct SrtpCipherMapEntry {
75   const char* external_name;
76   PRUint16 cipher_id;
77 };
78
79 // This isn't elegant, but it's better than an external reference
80 static const SrtpCipherMapEntry kSrtpCipherMap[] = {
81   {"AES_CM_128_HMAC_SHA1_80", SRTP_AES128_CM_HMAC_SHA1_80 },
82   {"AES_CM_128_HMAC_SHA1_32", SRTP_AES128_CM_HMAC_SHA1_32 },
83   {NULL, 0}
84 };
85 #endif
86
87
88 // Implementation of NSPR methods
89 static PRStatus StreamClose(PRFileDesc *socket) {
90   ASSERT(!socket->lower);
91   socket->dtor(socket);
92   return PR_SUCCESS;
93 }
94
95 static PRInt32 StreamRead(PRFileDesc *socket, void *buf, PRInt32 length) {
96   StreamInterface *stream = reinterpret_cast<StreamInterface *>(socket->secret);
97   size_t read;
98   int error;
99   StreamResult result = stream->Read(buf, length, &read, &error);
100   if (result == SR_SUCCESS) {
101     return checked_cast<PRInt32>(read);
102   }
103
104   if (result == SR_EOS) {
105     return 0;
106   }
107
108   if (result == SR_BLOCK) {
109     PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
110     return -1;
111   }
112
113   PR_SetError(PR_UNKNOWN_ERROR, error);
114   return -1;
115 }
116
117 static PRInt32 StreamWrite(PRFileDesc *socket, const void *buf,
118                            PRInt32 length) {
119   StreamInterface *stream = reinterpret_cast<StreamInterface *>(socket->secret);
120   size_t written;
121   int error;
122   StreamResult result = stream->Write(buf, length, &written, &error);
123   if (result == SR_SUCCESS) {
124     return checked_cast<PRInt32>(written);
125   }
126
127   if (result == SR_BLOCK) {
128     LOG(LS_INFO) <<
129         "NSSStreamAdapter: write to underlying transport would block";
130     PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
131     return -1;
132   }
133
134   LOG(LS_ERROR) << "Write error";
135   PR_SetError(PR_UNKNOWN_ERROR, error);
136   return -1;
137 }
138
139 static PRInt32 StreamAvailable(PRFileDesc *socket) {
140   UNIMPLEMENTED;
141   return -1;
142 }
143
144 PRInt64 StreamAvailable64(PRFileDesc *socket) {
145   UNIMPLEMENTED;
146   return -1;
147 }
148
149 static PRStatus StreamSync(PRFileDesc *socket) {
150   UNIMPLEMENTED;
151   return PR_FAILURE;
152 }
153
154 static PROffset32 StreamSeek(PRFileDesc *socket, PROffset32 offset,
155                              PRSeekWhence how) {
156   UNIMPLEMENTED;
157   return -1;
158 }
159
160 static PROffset64 StreamSeek64(PRFileDesc *socket, PROffset64 offset,
161                                PRSeekWhence how) {
162   UNIMPLEMENTED;
163   return -1;
164 }
165
166 static PRStatus StreamFileInfo(PRFileDesc *socket, PRFileInfo *info) {
167   UNIMPLEMENTED;
168   return PR_FAILURE;
169 }
170
171 static PRStatus StreamFileInfo64(PRFileDesc *socket, PRFileInfo64 *info) {
172   UNIMPLEMENTED;
173   return PR_FAILURE;
174 }
175
176 static PRInt32 StreamWritev(PRFileDesc *socket, const PRIOVec *iov,
177                      PRInt32 iov_size, PRIntervalTime timeout) {
178   UNIMPLEMENTED;
179   return -1;
180 }
181
182 static PRStatus StreamConnect(PRFileDesc *socket, const PRNetAddr *addr,
183                               PRIntervalTime timeout) {
184   UNIMPLEMENTED;
185   return PR_FAILURE;
186 }
187
188 static PRFileDesc *StreamAccept(PRFileDesc *sd, PRNetAddr *addr,
189                                 PRIntervalTime timeout) {
190   UNIMPLEMENTED;
191   return NULL;
192 }
193
194 static PRStatus StreamBind(PRFileDesc *socket, const PRNetAddr *addr) {
195   UNIMPLEMENTED;
196   return PR_FAILURE;
197 }
198
199 static PRStatus StreamListen(PRFileDesc *socket, PRIntn depth) {
200   UNIMPLEMENTED;
201   return PR_FAILURE;
202 }
203
204 static PRStatus StreamShutdown(PRFileDesc *socket, PRIntn how) {
205   UNIMPLEMENTED;
206   return PR_FAILURE;
207 }
208
209 // Note: this is always nonblocking and ignores the timeout.
210 // TODO(ekr@rtfm.com): In future verify that the socket is
211 // actually in non-blocking mode.
212 // This function does not support peek.
213 static PRInt32 StreamRecv(PRFileDesc *socket, void *buf, PRInt32 amount,
214                    PRIntn flags, PRIntervalTime to) {
215   ASSERT(flags == 0);
216
217   if (flags != 0) {
218     PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0);
219     return -1;
220   }
221
222   return StreamRead(socket, buf, amount);
223 }
224
225 // Note: this is always nonblocking and assumes a zero timeout.
226 // This function does not support peek.
227 static PRInt32 StreamSend(PRFileDesc *socket, const void *buf,
228                           PRInt32 amount, PRIntn flags,
229                           PRIntervalTime to) {
230   ASSERT(flags == 0);
231
232   return StreamWrite(socket, buf, amount);
233 }
234
235 static PRInt32 StreamRecvfrom(PRFileDesc *socket, void *buf,
236                               PRInt32 amount, PRIntn flags,
237                               PRNetAddr *addr, PRIntervalTime to) {
238   UNIMPLEMENTED;
239   return -1;
240 }
241
242 static PRInt32 StreamSendto(PRFileDesc *socket, const void *buf,
243                             PRInt32 amount, PRIntn flags,
244                             const PRNetAddr *addr, PRIntervalTime to) {
245   UNIMPLEMENTED;
246   return -1;
247 }
248
249 static PRInt16 StreamPoll(PRFileDesc *socket, PRInt16 in_flags,
250                           PRInt16 *out_flags) {
251   UNIMPLEMENTED;
252   return -1;
253 }
254
255 static PRInt32 StreamAcceptRead(PRFileDesc *sd, PRFileDesc **nd,
256                                 PRNetAddr **raddr,
257                                 void *buf, PRInt32 amount, PRIntervalTime t) {
258   UNIMPLEMENTED;
259   return -1;
260 }
261
262 static PRInt32 StreamTransmitFile(PRFileDesc *sd, PRFileDesc *socket,
263                                   const void *headers, PRInt32 hlen,
264                                   PRTransmitFileFlags flags, PRIntervalTime t) {
265   UNIMPLEMENTED;
266   return -1;
267 }
268
269 static PRStatus StreamGetPeerName(PRFileDesc *socket, PRNetAddr *addr) {
270   // TODO(ekr@rtfm.com): Modify to return unique names for each channel
271   // somehow, as opposed to always the same static address. The current
272   // implementation messes up the session cache, which is why it's off
273   // elsewhere
274   addr->inet.family = PR_AF_INET;
275   addr->inet.port = 0;
276   addr->inet.ip = 0;
277
278   return PR_SUCCESS;
279 }
280
281 static PRStatus StreamGetSockName(PRFileDesc *socket, PRNetAddr *addr) {
282   UNIMPLEMENTED;
283   return PR_FAILURE;
284 }
285
286 static PRStatus StreamGetSockOption(PRFileDesc *socket, PRSocketOptionData *opt) {
287   switch (opt->option) {
288     case PR_SockOpt_Nonblocking:
289       opt->value.non_blocking = PR_TRUE;
290       return PR_SUCCESS;
291     default:
292       UNIMPLEMENTED;
293       break;
294   }
295
296   return PR_FAILURE;
297 }
298
299 // Imitate setting socket options. These are mostly noops.
300 static PRStatus StreamSetSockOption(PRFileDesc *socket,
301                                     const PRSocketOptionData *opt) {
302   switch (opt->option) {
303     case PR_SockOpt_Nonblocking:
304       return PR_SUCCESS;
305     case PR_SockOpt_NoDelay:
306       return PR_SUCCESS;
307     default:
308       UNIMPLEMENTED;
309       break;
310   }
311
312   return PR_FAILURE;
313 }
314
315 static PRInt32 StreamSendfile(PRFileDesc *out, PRSendFileData *in,
316                               PRTransmitFileFlags flags, PRIntervalTime to) {
317   UNIMPLEMENTED;
318   return -1;
319 }
320
321 static PRStatus StreamConnectContinue(PRFileDesc *socket, PRInt16 flags) {
322   UNIMPLEMENTED;
323   return PR_FAILURE;
324 }
325
326 static PRIntn StreamReserved(PRFileDesc *socket) {
327   UNIMPLEMENTED;
328   return -1;
329 }
330
331 static const struct PRIOMethods nss_methods = {
332   PR_DESC_LAYERED,
333   StreamClose,
334   StreamRead,
335   StreamWrite,
336   StreamAvailable,
337   StreamAvailable64,
338   StreamSync,
339   StreamSeek,
340   StreamSeek64,
341   StreamFileInfo,
342   StreamFileInfo64,
343   StreamWritev,
344   StreamConnect,
345   StreamAccept,
346   StreamBind,
347   StreamListen,
348   StreamShutdown,
349   StreamRecv,
350   StreamSend,
351   StreamRecvfrom,
352   StreamSendto,
353   StreamPoll,
354   StreamAcceptRead,
355   StreamTransmitFile,
356   StreamGetSockName,
357   StreamGetPeerName,
358   StreamReserved,
359   StreamReserved,
360   StreamGetSockOption,
361   StreamSetSockOption,
362   StreamSendfile,
363   StreamConnectContinue,
364   StreamReserved,
365   StreamReserved,
366   StreamReserved,
367   StreamReserved
368 };
369
370 NSSStreamAdapter::NSSStreamAdapter(StreamInterface *stream)
371     : SSLStreamAdapterHelper(stream),
372       ssl_fd_(NULL),
373       cert_ok_(false) {
374 }
375
376 bool NSSStreamAdapter::Init() {
377   if (nspr_layer_identity == PR_INVALID_IO_LAYER) {
378     nspr_layer_identity = PR_GetUniqueIdentity("nssstreamadapter");
379   }
380   PRFileDesc *pr_fd = PR_CreateIOLayerStub(nspr_layer_identity, &nss_methods);
381   if (!pr_fd)
382     return false;
383   pr_fd->secret = reinterpret_cast<PRFilePrivate *>(stream());
384
385   PRFileDesc *ssl_fd;
386   if (ssl_mode_ == SSL_MODE_DTLS) {
387     ssl_fd = DTLS_ImportFD(NULL, pr_fd);
388   } else {
389     ssl_fd = SSL_ImportFD(NULL, pr_fd);
390   }
391   ASSERT(ssl_fd != NULL);  // This should never happen
392   if (!ssl_fd) {
393     PR_Close(pr_fd);
394     return false;
395   }
396
397   SECStatus rv;
398   // Turn on security.
399   rv = SSL_OptionSet(ssl_fd, SSL_SECURITY, PR_TRUE);
400   if (rv != SECSuccess) {
401     LOG(LS_ERROR) << "Error enabling security on SSL Socket";
402     return false;
403   }
404
405   // Disable SSLv2.
406   rv = SSL_OptionSet(ssl_fd, SSL_ENABLE_SSL2, PR_FALSE);
407   if (rv != SECSuccess) {
408     LOG(LS_ERROR) << "Error disabling SSL2";
409     return false;
410   }
411
412   // Disable caching.
413   // TODO(ekr@rtfm.com): restore this when I have the caching
414   // identity set.
415   rv = SSL_OptionSet(ssl_fd, SSL_NO_CACHE, PR_TRUE);
416   if (rv != SECSuccess) {
417     LOG(LS_ERROR) << "Error disabling cache";
418     return false;
419   }
420
421   // Disable session tickets.
422   rv = SSL_OptionSet(ssl_fd, SSL_ENABLE_SESSION_TICKETS, PR_FALSE);
423   if (rv != SECSuccess) {
424     LOG(LS_ERROR) << "Error enabling tickets";
425     return false;
426   }
427
428   // Disable renegotiation.
429   rv = SSL_OptionSet(ssl_fd, SSL_ENABLE_RENEGOTIATION,
430                      SSL_RENEGOTIATE_NEVER);
431   if (rv != SECSuccess) {
432     LOG(LS_ERROR) << "Error disabling renegotiation";
433     return false;
434   }
435
436   // Disable false start.
437   rv = SSL_OptionSet(ssl_fd, SSL_ENABLE_FALSE_START, PR_FALSE);
438   if (rv != SECSuccess) {
439     LOG(LS_ERROR) << "Error disabling false start";
440     return false;
441   }
442
443   ssl_fd_ = ssl_fd;
444
445   return true;
446 }
447
448 NSSStreamAdapter::~NSSStreamAdapter() {
449   if (ssl_fd_)
450     PR_Close(ssl_fd_);
451 };
452
453
454 int NSSStreamAdapter::BeginSSL() {
455   SECStatus rv;
456
457   if (!Init()) {
458     Error("Init", -1, false);
459     return -1;
460   }
461
462   ASSERT(state_ == SSL_CONNECTING);
463   // The underlying stream has been opened. If we are in peer-to-peer mode
464   // then a peer certificate must have been specified by now.
465   ASSERT(!ssl_server_name_.empty() ||
466          peer_certificate_.get() != NULL ||
467          !peer_certificate_digest_algorithm_.empty());
468   LOG(LS_INFO) << "BeginSSL: "
469                << (!ssl_server_name_.empty() ? ssl_server_name_ :
470                                                "with peer");
471
472   if (role_ == SSL_CLIENT) {
473     LOG(LS_INFO) << "BeginSSL: as client";
474
475     rv = SSL_GetClientAuthDataHook(ssl_fd_, GetClientAuthDataHook,
476                                    this);
477     if (rv != SECSuccess) {
478       Error("BeginSSL", -1, false);
479       return -1;
480     }
481   } else {
482     LOG(LS_INFO) << "BeginSSL: as server";
483     NSSIdentity *identity;
484
485     if (identity_.get()) {
486       identity = static_cast<NSSIdentity *>(identity_.get());
487     } else {
488       LOG(LS_ERROR) << "Can't be an SSL server without an identity";
489       Error("BeginSSL", -1, false);
490       return -1;
491     }
492     rv = SSL_ConfigSecureServer(ssl_fd_, identity->certificate().certificate(),
493                                 identity->keypair()->privkey(),
494                                 kt_rsa);
495     if (rv != SECSuccess) {
496       Error("BeginSSL", -1, false);
497       return -1;
498     }
499
500     // Insist on a certificate from the client
501     rv = SSL_OptionSet(ssl_fd_, SSL_REQUEST_CERTIFICATE, PR_TRUE);
502     if (rv != SECSuccess) {
503       Error("BeginSSL", -1, false);
504       return -1;
505     }
506
507     rv = SSL_OptionSet(ssl_fd_, SSL_REQUIRE_CERTIFICATE, PR_TRUE);
508     if (rv != SECSuccess) {
509       Error("BeginSSL", -1, false);
510       return -1;
511     }
512   }
513
514   // Set the version range.
515   SSLVersionRange vrange;
516   vrange.min =  (ssl_mode_ == SSL_MODE_DTLS) ?
517       SSL_LIBRARY_VERSION_TLS_1_1 :
518       SSL_LIBRARY_VERSION_TLS_1_0;
519   vrange.max = SSL_LIBRARY_VERSION_TLS_1_1;
520
521   rv = SSL_VersionRangeSet(ssl_fd_, &vrange);
522   if (rv != SECSuccess) {
523     Error("BeginSSL", -1, false);
524     return -1;
525   }
526
527   // SRTP
528 #ifdef HAVE_DTLS_SRTP
529   if (!srtp_ciphers_.empty()) {
530     rv = SSL_SetSRTPCiphers(
531         ssl_fd_, &srtp_ciphers_[0],
532         checked_cast<unsigned int>(srtp_ciphers_.size()));
533     if (rv != SECSuccess) {
534       Error("BeginSSL", -1, false);
535       return -1;
536     }
537   }
538 #endif
539
540   // Certificate validation
541   rv = SSL_AuthCertificateHook(ssl_fd_, AuthCertificateHook, this);
542   if (rv != SECSuccess) {
543     Error("BeginSSL", -1, false);
544     return -1;
545   }
546
547   // Now start the handshake
548   rv = SSL_ResetHandshake(ssl_fd_, role_ == SSL_SERVER ? PR_TRUE : PR_FALSE);
549   if (rv != SECSuccess) {
550     Error("BeginSSL", -1, false);
551     return -1;
552   }
553
554   return ContinueSSL();
555 }
556
557 int NSSStreamAdapter::ContinueSSL() {
558   LOG(LS_INFO) << "ContinueSSL";
559   ASSERT(state_ == SSL_CONNECTING);
560
561   // Clear the DTLS timer
562   Thread::Current()->Clear(this, MSG_DTLS_TIMEOUT);
563
564   SECStatus rv = SSL_ForceHandshake(ssl_fd_);
565
566   if (rv == SECSuccess) {
567     LOG(LS_INFO) << "Handshake complete";
568
569     ASSERT(cert_ok_);
570     if (!cert_ok_) {
571       Error("ContinueSSL", -1, true);
572       return -1;
573     }
574
575     state_ = SSL_CONNECTED;
576     StreamAdapterInterface::OnEvent(stream(), SE_OPEN|SE_READ|SE_WRITE, 0);
577     return 0;
578   }
579
580   PRInt32 err = PR_GetError();
581   switch (err) {
582     case SSL_ERROR_RX_MALFORMED_HANDSHAKE:
583       if (ssl_mode_ != SSL_MODE_DTLS) {
584         Error("ContinueSSL", -1, true);
585         return -1;
586       } else {
587         LOG(LS_INFO) << "Malformed DTLS message. Ignoring.";
588         // Fall through
589       }
590     case PR_WOULD_BLOCK_ERROR:
591       LOG(LS_INFO) << "Would have blocked";
592       if (ssl_mode_ == SSL_MODE_DTLS) {
593         PRIntervalTime timeout;
594
595         SECStatus rv = DTLS_GetHandshakeTimeout(ssl_fd_, &timeout);
596         if (rv == SECSuccess) {
597           LOG(LS_INFO) << "Timeout is " << timeout << " ms";
598           Thread::Current()->PostDelayed(PR_IntervalToMilliseconds(timeout),
599                                          this, MSG_DTLS_TIMEOUT, 0);
600         }
601       }
602
603       return 0;
604     default:
605       LOG(LS_INFO) << "Error " << err;
606       break;
607   }
608
609   Error("ContinueSSL", -1, true);
610   return -1;
611 }
612
613 void NSSStreamAdapter::Cleanup() {
614   if (state_ != SSL_ERROR) {
615     state_ = SSL_CLOSED;
616   }
617
618   if (ssl_fd_) {
619     PR_Close(ssl_fd_);
620     ssl_fd_ = NULL;
621   }
622
623   identity_.reset();
624   peer_certificate_.reset();
625
626   Thread::Current()->Clear(this, MSG_DTLS_TIMEOUT);
627 }
628
629 StreamResult NSSStreamAdapter::Read(void* data, size_t data_len,
630                                     size_t* read, int* error) {
631   // SSL_CONNECTED sanity check.
632   switch (state_) {
633     case SSL_NONE:
634     case SSL_WAIT:
635     case SSL_CONNECTING:
636       return SR_BLOCK;
637
638     case SSL_CONNECTED:
639       break;
640
641     case SSL_CLOSED:
642       return SR_EOS;
643
644     case SSL_ERROR:
645     default:
646       if (error)
647         *error = ssl_error_code_;
648       return SR_ERROR;
649   }
650
651   PRInt32 rv = PR_Read(ssl_fd_, data, checked_cast<PRInt32>(data_len));
652
653   if (rv == 0) {
654     return SR_EOS;
655   }
656
657   // Error
658   if (rv < 0) {
659     PRInt32 err = PR_GetError();
660
661     switch (err) {
662       case PR_WOULD_BLOCK_ERROR:
663         return SR_BLOCK;
664       default:
665         Error("Read", -1, false);
666         *error = err;  // libjingle semantics are that this is impl-specific
667         return SR_ERROR;
668     }
669   }
670
671   // Success
672   *read = rv;
673
674   return SR_SUCCESS;
675 }
676
677 StreamResult NSSStreamAdapter::Write(const void* data, size_t data_len,
678                                      size_t* written, int* error) {
679   // SSL_CONNECTED sanity check.
680   switch (state_) {
681     case SSL_NONE:
682     case SSL_WAIT:
683     case SSL_CONNECTING:
684       return SR_BLOCK;
685
686     case SSL_CONNECTED:
687       break;
688
689     case SSL_ERROR:
690     case SSL_CLOSED:
691     default:
692       if (error)
693         *error = ssl_error_code_;
694       return SR_ERROR;
695   }
696
697   PRInt32 rv = PR_Write(ssl_fd_, data, checked_cast<PRInt32>(data_len));
698
699   // Error
700   if (rv < 0) {
701     PRInt32 err = PR_GetError();
702
703     switch (err) {
704       case PR_WOULD_BLOCK_ERROR:
705         return SR_BLOCK;
706       default:
707         Error("Write", -1, false);
708         *error = err;  // libjingle semantics are that this is impl-specific
709         return SR_ERROR;
710     }
711   }
712
713   // Success
714   *written = rv;
715
716   return SR_SUCCESS;
717 }
718
719 void NSSStreamAdapter::OnEvent(StreamInterface* stream, int events,
720                                int err) {
721   int events_to_signal = 0;
722   int signal_error = 0;
723   ASSERT(stream == this->stream());
724   if ((events & SE_OPEN)) {
725     LOG(LS_INFO) << "NSSStreamAdapter::OnEvent SE_OPEN";
726     if (state_ != SSL_WAIT) {
727       ASSERT(state_ == SSL_NONE);
728       events_to_signal |= SE_OPEN;
729     } else {
730       state_ = SSL_CONNECTING;
731       if (int err = BeginSSL()) {
732         Error("BeginSSL", err, true);
733         return;
734       }
735     }
736   }
737   if ((events & (SE_READ|SE_WRITE))) {
738     LOG(LS_INFO) << "NSSStreamAdapter::OnEvent"
739                  << ((events & SE_READ) ? " SE_READ" : "")
740                  << ((events & SE_WRITE) ? " SE_WRITE" : "");
741     if (state_ == SSL_NONE) {
742       events_to_signal |= events & (SE_READ|SE_WRITE);
743     } else if (state_ == SSL_CONNECTING) {
744       if (int err = ContinueSSL()) {
745         Error("ContinueSSL", err, true);
746         return;
747       }
748     } else if (state_ == SSL_CONNECTED) {
749       if (events & SE_WRITE) {
750         LOG(LS_INFO) << " -- onStreamWriteable";
751         events_to_signal |= SE_WRITE;
752       }
753       if (events & SE_READ) {
754         LOG(LS_INFO) << " -- onStreamReadable";
755         events_to_signal |= SE_READ;
756       }
757     }
758   }
759   if ((events & SE_CLOSE)) {
760     LOG(LS_INFO) << "NSSStreamAdapter::OnEvent(SE_CLOSE, " << err << ")";
761     Cleanup();
762     events_to_signal |= SE_CLOSE;
763     // SE_CLOSE is the only event that uses the final parameter to OnEvent().
764     ASSERT(signal_error == 0);
765     signal_error = err;
766   }
767   if (events_to_signal)
768     StreamAdapterInterface::OnEvent(stream, events_to_signal, signal_error);
769 }
770
771 void NSSStreamAdapter::OnMessage(Message* msg) {
772   // Process our own messages and then pass others to the superclass
773   if (MSG_DTLS_TIMEOUT == msg->message_id) {
774     LOG(LS_INFO) << "DTLS timeout expired";
775     ContinueSSL();
776   } else {
777     StreamInterface::OnMessage(msg);
778   }
779 }
780
781 // Certificate verification callback. Called to check any certificate
782 SECStatus NSSStreamAdapter::AuthCertificateHook(void *arg,
783                                                 PRFileDesc *fd,
784                                                 PRBool checksig,
785                                                 PRBool isServer) {
786   LOG(LS_INFO) << "NSSStreamAdapter::AuthCertificateHook";
787   NSSCertificate peer_cert(SSL_PeerCertificate(fd));
788   NSSStreamAdapter *stream = reinterpret_cast<NSSStreamAdapter *>(arg);
789   stream->cert_ok_ = false;
790
791   // Read the peer's certificate chain.
792   CERTCertList* cert_list = SSL_PeerCertificateChain(fd);
793   ASSERT(cert_list != NULL);
794
795   // If the peer provided multiple certificates, check that they form a valid
796   // chain as defined by RFC 5246 Section 7.4.2: "Each following certificate
797   // MUST directly certify the one preceding it.".  This check does NOT
798   // verify other requirements, such as whether the chain reaches a trusted
799   // root, self-signed certificates have valid signatures, certificates are not
800   // expired, etc.
801   // Even if the chain is valid, the leaf certificate must still match a
802   // provided certificate or digest.
803   if (!NSSCertificate::IsValidChain(cert_list)) {
804     CERT_DestroyCertList(cert_list);
805     PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
806     return SECFailure;
807   }
808
809   if (stream->peer_certificate_.get()) {
810     LOG(LS_INFO) << "Checking against specified certificate";
811
812     // The peer certificate was specified
813     if (reinterpret_cast<NSSCertificate *>(stream->peer_certificate_.get())->
814         Equals(&peer_cert)) {
815       LOG(LS_INFO) << "Accepted peer certificate";
816       stream->cert_ok_ = true;
817     }
818   } else if (!stream->peer_certificate_digest_algorithm_.empty()) {
819     LOG(LS_INFO) << "Checking against specified digest";
820     // The peer certificate digest was specified
821     unsigned char digest[64];  // Maximum size
822     std::size_t digest_length;
823
824     if (!peer_cert.ComputeDigest(
825             stream->peer_certificate_digest_algorithm_,
826             digest, sizeof(digest), &digest_length)) {
827       LOG(LS_ERROR) << "Digest computation failed";
828     } else {
829       Buffer computed_digest(digest, digest_length);
830       if (computed_digest == stream->peer_certificate_digest_value_) {
831         LOG(LS_INFO) << "Accepted peer certificate";
832         stream->cert_ok_ = true;
833       }
834     }
835   } else {
836     // Other modes, but we haven't implemented yet
837     // TODO(ekr@rtfm.com): Implement real certificate validation
838     UNIMPLEMENTED;
839   }
840
841   if (!stream->cert_ok_ && stream->ignore_bad_cert()) {
842     LOG(LS_WARNING) << "Ignoring cert error while verifying cert chain";
843     stream->cert_ok_ = true;
844   }
845
846   if (stream->cert_ok_)
847     stream->peer_certificate_.reset(new NSSCertificate(cert_list));
848
849   CERT_DestroyCertList(cert_list);
850
851   if (stream->cert_ok_)
852     return SECSuccess;
853
854   PORT_SetError(SEC_ERROR_UNTRUSTED_CERT);
855   return SECFailure;
856 }
857
858
859 SECStatus NSSStreamAdapter::GetClientAuthDataHook(void *arg, PRFileDesc *fd,
860                                                   CERTDistNames *caNames,
861                                                   CERTCertificate **pRetCert,
862                                                   SECKEYPrivateKey **pRetKey) {
863   LOG(LS_INFO) << "Client cert requested";
864   NSSStreamAdapter *stream = reinterpret_cast<NSSStreamAdapter *>(arg);
865
866   if (!stream->identity_.get()) {
867     LOG(LS_ERROR) << "No identity available";
868     return SECFailure;
869   }
870
871   NSSIdentity *identity = static_cast<NSSIdentity *>(stream->identity_.get());
872   // Destroyed internally by NSS
873   *pRetCert = CERT_DupCertificate(identity->certificate().certificate());
874   *pRetKey = SECKEY_CopyPrivateKey(identity->keypair()->privkey());
875
876   return SECSuccess;
877 }
878
879 // RFC 5705 Key Exporter
880 bool NSSStreamAdapter::ExportKeyingMaterial(const std::string& label,
881                                             const uint8* context,
882                                             size_t context_len,
883                                             bool use_context,
884                                             uint8* result,
885                                             size_t result_len) {
886   SECStatus rv = SSL_ExportKeyingMaterial(
887       ssl_fd_,
888       label.c_str(),
889       checked_cast<unsigned int>(label.size()),
890       use_context,
891       context,
892       checked_cast<unsigned int>(context_len),
893       result,
894       checked_cast<unsigned int>(result_len));
895
896   return rv == SECSuccess;
897 }
898
899 bool NSSStreamAdapter::SetDtlsSrtpCiphers(
900     const std::vector<std::string>& ciphers) {
901 #ifdef HAVE_DTLS_SRTP
902   std::vector<PRUint16> internal_ciphers;
903   if (state_ != SSL_NONE)
904     return false;
905
906   for (std::vector<std::string>::const_iterator cipher = ciphers.begin();
907        cipher != ciphers.end(); ++cipher) {
908     bool found = false;
909     for (const SrtpCipherMapEntry *entry = kSrtpCipherMap; entry->cipher_id;
910          ++entry) {
911       if (*cipher == entry->external_name) {
912         found = true;
913         internal_ciphers.push_back(entry->cipher_id);
914         break;
915       }
916     }
917
918     if (!found) {
919       LOG(LS_ERROR) << "Could not find cipher: " << *cipher;
920       return false;
921     }
922   }
923
924   if (internal_ciphers.empty())
925     return false;
926
927   srtp_ciphers_ = internal_ciphers;
928
929   return true;
930 #else
931   return false;
932 #endif
933 }
934
935 bool NSSStreamAdapter::GetDtlsSrtpCipher(std::string* cipher) {
936 #ifdef HAVE_DTLS_SRTP
937   ASSERT(state_ == SSL_CONNECTED);
938   if (state_ != SSL_CONNECTED)
939     return false;
940
941   PRUint16 selected_cipher;
942
943   SECStatus rv = SSL_GetSRTPCipher(ssl_fd_, &selected_cipher);
944   if (rv == SECFailure)
945     return false;
946
947   for (const SrtpCipherMapEntry *entry = kSrtpCipherMap;
948        entry->cipher_id; ++entry) {
949     if (selected_cipher == entry->cipher_id) {
950       *cipher = entry->external_name;
951       return true;
952     }
953   }
954
955   ASSERT(false);  // This should never happen
956 #endif
957   return false;
958 }
959
960
961 bool NSSContext::initialized;
962 NSSContext *NSSContext::global_nss_context;
963
964 // Static initialization and shutdown
965 NSSContext *NSSContext::Instance() {
966   if (!global_nss_context) {
967     NSSContext *new_ctx = new NSSContext();
968
969     if (!(new_ctx->slot_ = PK11_GetInternalSlot())) {
970       delete new_ctx;
971       goto fail;
972     }
973
974     global_nss_context = new_ctx;
975   }
976
977  fail:
978   return global_nss_context;
979 }
980
981
982
983 bool NSSContext::InitializeSSL(VerificationCallback callback) {
984   ASSERT(!callback);
985
986   if (!initialized) {
987     SECStatus rv;
988
989     rv = NSS_NoDB_Init(NULL);
990     if (rv != SECSuccess) {
991       LOG(LS_ERROR) << "Couldn't initialize NSS error=" <<
992           PORT_GetError();
993       return false;
994     }
995
996     NSS_SetDomesticPolicy();
997
998     initialized = true;
999   }
1000
1001   return true;
1002 }
1003
1004 bool NSSContext::InitializeSSLThread() {
1005   // Not needed
1006   return true;
1007 }
1008
1009 bool NSSContext::CleanupSSL() {
1010   // Not needed
1011   return true;
1012 }
1013
1014 bool NSSStreamAdapter::HaveDtls() {
1015   return true;
1016 }
1017
1018 bool NSSStreamAdapter::HaveDtlsSrtp() {
1019 #ifdef HAVE_DTLS_SRTP
1020   return true;
1021 #else
1022   return false;
1023 #endif
1024 }
1025
1026 bool NSSStreamAdapter::HaveExporter() {
1027   return true;
1028 }
1029
1030 }  // namespace talk_base
1031
1032 #endif  // HAVE_NSS_SSL_H