- add third_party src.
[platform/framework/web/crosswalk.git] / src / third_party / libjingle / source / talk / base / schanneladapter.cc
1 /*
2  * libjingle
3  * Copyright 2004--2005, Google Inc.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are met:
7  *
8  *  1. Redistributions of source code must retain the above copyright notice,
9  *     this list of conditions and the following disclaimer.
10  *  2. Redistributions in binary form must reproduce the above copyright notice,
11  *     this list of conditions and the following disclaimer in the documentation
12  *     and/or other materials provided with the distribution.
13  *  3. The name of the author may not be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
17  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
18  * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
19  * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
20  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
22  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
23  * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
24  * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
25  * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  */
27
28 #include "talk/base/win32.h"
29 #define SECURITY_WIN32
30 #include <security.h>
31 #include <schannel.h>
32
33 #include <iomanip>
34 #include <vector>
35
36 #include "talk/base/common.h"
37 #include "talk/base/logging.h"
38 #include "talk/base/schanneladapter.h"
39 #include "talk/base/sec_buffer.h"
40 #include "talk/base/thread.h"
41
42 namespace talk_base {
43
44 /////////////////////////////////////////////////////////////////////////////
45 // SChannelAdapter
46 /////////////////////////////////////////////////////////////////////////////
47
48 extern const ConstantLabel SECURITY_ERRORS[];
49
50 const ConstantLabel SCHANNEL_BUFFER_TYPES[] = {
51   KLABEL(SECBUFFER_EMPTY),              //  0
52   KLABEL(SECBUFFER_DATA),               //  1
53   KLABEL(SECBUFFER_TOKEN),              //  2
54   KLABEL(SECBUFFER_PKG_PARAMS),         //  3
55   KLABEL(SECBUFFER_MISSING),            //  4
56   KLABEL(SECBUFFER_EXTRA),              //  5
57   KLABEL(SECBUFFER_STREAM_TRAILER),     //  6
58   KLABEL(SECBUFFER_STREAM_HEADER),      //  7
59   KLABEL(SECBUFFER_MECHLIST),           // 11
60   KLABEL(SECBUFFER_MECHLIST_SIGNATURE), // 12
61   KLABEL(SECBUFFER_TARGET),             // 13
62   KLABEL(SECBUFFER_CHANNEL_BINDINGS),   // 14
63   LASTLABEL
64 };
65
66 void DescribeBuffer(LoggingSeverity severity, const char* prefix,
67                     const SecBuffer& sb) {
68   LOG_V(severity)
69     << prefix
70     << "(" << sb.cbBuffer
71     << ", " << FindLabel(sb.BufferType & ~SECBUFFER_ATTRMASK,
72                           SCHANNEL_BUFFER_TYPES)
73     << ", " << sb.pvBuffer << ")";
74 }
75
76 void DescribeBuffers(LoggingSeverity severity, const char* prefix,
77                      const SecBufferDesc* sbd) {
78   if (!LOG_CHECK_LEVEL_V(severity))
79     return;
80   LOG_V(severity) << prefix << "(";
81   for (size_t i=0; i<sbd->cBuffers; ++i) {
82     DescribeBuffer(severity, "  ", sbd->pBuffers[i]);
83   }
84   LOG_V(severity) << ")";
85 }
86
87 const ULONG SSL_FLAGS_DEFAULT = ISC_REQ_ALLOCATE_MEMORY
88                               | ISC_REQ_CONFIDENTIALITY
89                               | ISC_REQ_EXTENDED_ERROR
90                               | ISC_REQ_INTEGRITY
91                               | ISC_REQ_REPLAY_DETECT
92                               | ISC_REQ_SEQUENCE_DETECT
93                               | ISC_REQ_STREAM;
94                               //| ISC_REQ_USE_SUPPLIED_CREDS;
95
96 typedef std::vector<char> SChannelBuffer;
97
98 struct SChannelAdapter::SSLImpl {
99   CredHandle cred;
100   CtxtHandle ctx;
101   bool cred_init, ctx_init;
102   SChannelBuffer inbuf, outbuf, readable;
103   SecPkgContext_StreamSizes sizes;
104
105   SSLImpl() : cred_init(false), ctx_init(false) { }
106 };
107
108 SChannelAdapter::SChannelAdapter(AsyncSocket* socket)
109   : SSLAdapter(socket), state_(SSL_NONE),
110     restartable_(false), signal_close_(false), message_pending_(false),
111     impl_(new SSLImpl) {
112 }
113
114 SChannelAdapter::~SChannelAdapter() {
115   Cleanup();
116 }
117
118 int
119 SChannelAdapter::StartSSL(const char* hostname, bool restartable) {
120   if (state_ != SSL_NONE)
121     return ERROR_ALREADY_INITIALIZED;
122
123   ssl_host_name_ = hostname;
124   restartable_ = restartable;
125
126   if (socket_->GetState() != Socket::CS_CONNECTED) {
127     state_ = SSL_WAIT;
128     return 0;
129   }
130
131   state_ = SSL_CONNECTING;
132   if (int err = BeginSSL()) {
133     Error("BeginSSL", err, false);
134     return err;
135   }
136
137   return 0;
138 }
139
140 int
141 SChannelAdapter::BeginSSL() {
142   LOG(LS_VERBOSE) << "BeginSSL: " << ssl_host_name_;
143   ASSERT(state_ == SSL_CONNECTING);
144
145   SECURITY_STATUS ret;
146
147   SCHANNEL_CRED sc_cred = { 0 };
148   sc_cred.dwVersion = SCHANNEL_CRED_VERSION;
149   //sc_cred.dwMinimumCipherStrength = 128; // Note: use system default
150   sc_cred.dwFlags = SCH_CRED_NO_DEFAULT_CREDS | SCH_CRED_AUTO_CRED_VALIDATION;
151
152   ret = AcquireCredentialsHandle(NULL, UNISP_NAME, SECPKG_CRED_OUTBOUND, NULL,
153                                  &sc_cred, NULL, NULL, &impl_->cred, NULL);
154   if (ret != SEC_E_OK) {
155     LOG(LS_ERROR) << "AcquireCredentialsHandle error: "
156                   << ErrorName(ret, SECURITY_ERRORS);
157     return ret;
158   }
159   impl_->cred_init = true;
160
161   if (LOG_CHECK_LEVEL(LS_VERBOSE)) {
162     SecPkgCred_CipherStrengths cipher_strengths = { 0 };
163     ret = QueryCredentialsAttributes(&impl_->cred,
164                                      SECPKG_ATTR_CIPHER_STRENGTHS,
165                                      &cipher_strengths);
166     if (SUCCEEDED(ret)) {
167       LOG(LS_VERBOSE) << "SChannel cipher strength: "
168                   << cipher_strengths.dwMinimumCipherStrength << " - "
169                   << cipher_strengths.dwMaximumCipherStrength;
170     }
171
172     SecPkgCred_SupportedAlgs supported_algs = { 0 };
173     ret = QueryCredentialsAttributes(&impl_->cred,
174                                      SECPKG_ATTR_SUPPORTED_ALGS,
175                                      &supported_algs);
176     if (SUCCEEDED(ret)) {
177       LOG(LS_VERBOSE) << "SChannel supported algorithms:";
178       for (DWORD i=0; i<supported_algs.cSupportedAlgs; ++i) {
179         ALG_ID alg_id = supported_algs.palgSupportedAlgs[i];
180         PCCRYPT_OID_INFO oinfo = CryptFindOIDInfo(CRYPT_OID_INFO_ALGID_KEY,
181                                                   &alg_id, 0);
182         LPCWSTR alg_name = (NULL != oinfo) ? oinfo->pwszName : L"Unknown";
183         LOG(LS_VERBOSE) << "  " << ToUtf8(alg_name) << " (" << alg_id << ")";
184       }
185       CSecBufferBase::FreeSSPI(supported_algs.palgSupportedAlgs);
186     }
187   }
188
189   ULONG flags = SSL_FLAGS_DEFAULT, ret_flags = 0;
190   if (ignore_bad_cert())
191     flags |= ISC_REQ_MANUAL_CRED_VALIDATION;
192
193   CSecBufferBundle<2, CSecBufferBase::FreeSSPI> sb_out;
194   ret = InitializeSecurityContextA(&impl_->cred, NULL,
195                                    const_cast<char*>(ssl_host_name_.c_str()),
196                                    flags, 0, 0, NULL, 0,
197                                    &impl_->ctx, sb_out.desc(),
198                                    &ret_flags, NULL);
199   if (SUCCEEDED(ret))
200     impl_->ctx_init = true;
201   return ProcessContext(ret, NULL, sb_out.desc());
202 }
203
204 int
205 SChannelAdapter::ContinueSSL() {
206   LOG(LS_VERBOSE) << "ContinueSSL";
207   ASSERT(state_ == SSL_CONNECTING);
208
209   SECURITY_STATUS ret;
210
211   CSecBufferBundle<2> sb_in;
212   sb_in[0].BufferType = SECBUFFER_TOKEN;
213   sb_in[0].cbBuffer = static_cast<unsigned long>(impl_->inbuf.size());
214   sb_in[0].pvBuffer = &impl_->inbuf[0];
215   //DescribeBuffers(LS_VERBOSE, "Input Buffer ", sb_in.desc());
216
217   ULONG flags = SSL_FLAGS_DEFAULT, ret_flags = 0;
218   if (ignore_bad_cert())
219     flags |= ISC_REQ_MANUAL_CRED_VALIDATION;
220
221   CSecBufferBundle<2, CSecBufferBase::FreeSSPI> sb_out;
222   ret = InitializeSecurityContextA(&impl_->cred, &impl_->ctx,
223                                    const_cast<char*>(ssl_host_name_.c_str()),
224                                    flags, 0, 0, sb_in.desc(), 0,
225                                    NULL, sb_out.desc(),
226                                    &ret_flags, NULL);
227   return ProcessContext(ret, sb_in.desc(), sb_out.desc());
228 }
229
230 int
231 SChannelAdapter::ProcessContext(long int status, _SecBufferDesc* sbd_in,
232                                 _SecBufferDesc* sbd_out) {
233   if (status != SEC_E_OK && status != SEC_I_CONTINUE_NEEDED &&
234       status != SEC_E_INCOMPLETE_MESSAGE) {
235     LOG(LS_ERROR)
236       << "InitializeSecurityContext error: "
237       << ErrorName(status, SECURITY_ERRORS);
238   }
239   //if (sbd_in)
240   //  DescribeBuffers(LS_VERBOSE, "Input Buffer ", sbd_in);
241   //if (sbd_out)
242   //  DescribeBuffers(LS_VERBOSE, "Output Buffer ", sbd_out);
243
244   if (status == SEC_E_INCOMPLETE_MESSAGE) {
245     // Wait for more input from server.
246     return Flush();
247   }
248
249   if (FAILED(status)) {
250     // We can't continue.  Common errors:
251     // SEC_E_CERT_EXPIRED - Typically, this means the computer clock is wrong.
252     return status;
253   }
254
255   // Note: we check both input and output buffers for SECBUFFER_EXTRA.
256   // Experience shows it appearing in the input, but the documentation claims
257   // it should appear in the output.
258   size_t extra = 0;
259   if (sbd_in) {
260     for (size_t i=0; i<sbd_in->cBuffers; ++i) {
261       SecBuffer& buffer = sbd_in->pBuffers[i];
262       if (buffer.BufferType == SECBUFFER_EXTRA) {
263         extra += buffer.cbBuffer;
264       }
265     }
266   }
267   if (sbd_out) {
268     for (size_t i=0; i<sbd_out->cBuffers; ++i) {
269       SecBuffer& buffer = sbd_out->pBuffers[i];
270       if (buffer.BufferType == SECBUFFER_EXTRA) {
271         extra += buffer.cbBuffer;
272       } else if (buffer.BufferType == SECBUFFER_TOKEN) {
273         impl_->outbuf.insert(impl_->outbuf.end(),
274           reinterpret_cast<char*>(buffer.pvBuffer),
275           reinterpret_cast<char*>(buffer.pvBuffer) + buffer.cbBuffer);
276       }
277     }
278   }
279
280   if (extra) {
281     ASSERT(extra <= impl_->inbuf.size());
282     size_t consumed = impl_->inbuf.size() - extra;
283     memmove(&impl_->inbuf[0], &impl_->inbuf[consumed], extra);
284     impl_->inbuf.resize(extra);
285   } else {
286     impl_->inbuf.clear();
287   }
288
289   if (SEC_I_CONTINUE_NEEDED == status) {
290     // Send data to server and wait for response.
291     // Note: ContinueSSL will result in a Flush, anyway.
292     return impl_->inbuf.empty() ? Flush() : ContinueSSL();
293   }
294
295   if (SEC_E_OK == status) {
296     LOG(LS_VERBOSE) << "QueryContextAttributes";
297     status = QueryContextAttributes(&impl_->ctx, SECPKG_ATTR_STREAM_SIZES,
298                                     &impl_->sizes);
299     if (FAILED(status)) {
300       LOG(LS_ERROR) << "QueryContextAttributes error: "
301                     << ErrorName(status, SECURITY_ERRORS);
302       return status;
303     }
304
305     state_ = SSL_CONNECTED;
306
307     if (int err = DecryptData()) {
308       return err;
309     } else if (int err = Flush()) {
310       return err;
311     } else {
312       // If we decrypted any data, queue up a notification here
313       PostEvent();
314       // Signal our connectedness
315       AsyncSocketAdapter::OnConnectEvent(this);
316     }
317     return 0;
318   }
319
320   if (SEC_I_INCOMPLETE_CREDENTIALS == status) {
321     // We don't support client authentication in schannel.
322     return status;
323   }
324
325   // We don't expect any other codes
326   ASSERT(false);
327   return status;
328 }
329
330 int
331 SChannelAdapter::DecryptData() {
332   SChannelBuffer& inbuf = impl_->inbuf;
333   SChannelBuffer& readable = impl_->readable;
334
335   while (!inbuf.empty()) {
336     CSecBufferBundle<4> in_buf;
337     in_buf[0].BufferType = SECBUFFER_DATA;
338     in_buf[0].cbBuffer = static_cast<unsigned long>(inbuf.size());
339     in_buf[0].pvBuffer = &inbuf[0];
340
341     //DescribeBuffers(LS_VERBOSE, "Decrypt In ", in_buf.desc());
342     SECURITY_STATUS status = DecryptMessage(&impl_->ctx, in_buf.desc(), 0, 0);
343     //DescribeBuffers(LS_VERBOSE, "Decrypt Out ", in_buf.desc());
344
345     // Note: We are explicitly treating SEC_E_OK, SEC_I_CONTEXT_EXPIRED, and
346     // any other successful results as continue.
347     if (SUCCEEDED(status)) {
348       size_t data_len = 0, extra_len = 0;
349       for (size_t i=0; i<in_buf.desc()->cBuffers; ++i) {
350         if (in_buf[i].BufferType == SECBUFFER_DATA) {
351           data_len += in_buf[i].cbBuffer;
352           readable.insert(readable.end(),
353             reinterpret_cast<char*>(in_buf[i].pvBuffer),
354             reinterpret_cast<char*>(in_buf[i].pvBuffer) + in_buf[i].cbBuffer);
355         } else if (in_buf[i].BufferType == SECBUFFER_EXTRA) {
356           extra_len += in_buf[i].cbBuffer;
357         }
358       }
359       // There is a bug on Win2K where SEC_I_CONTEXT_EXPIRED is misclassified.
360       if ((data_len == 0) && (inbuf[0] == 0x15)) {
361         status = SEC_I_CONTEXT_EXPIRED;
362       }
363       if (extra_len) {
364         size_t consumed = inbuf.size() - extra_len;
365         memmove(&inbuf[0], &inbuf[consumed], extra_len);
366         inbuf.resize(extra_len);
367       } else {
368         inbuf.clear();
369       }
370       // TODO: Handle SEC_I_CONTEXT_EXPIRED to do clean shutdown
371       if (status != SEC_E_OK) {
372         LOG(LS_INFO) << "DecryptMessage returned continuation code: "
373                       << ErrorName(status, SECURITY_ERRORS);
374       }
375       continue;
376     }
377
378     if (status == SEC_E_INCOMPLETE_MESSAGE) {
379       break;
380     } else {
381       return status;
382     }
383   }
384
385   return 0;
386 }
387
388 void
389 SChannelAdapter::Cleanup() {
390   if (impl_->ctx_init)
391     DeleteSecurityContext(&impl_->ctx);
392   if (impl_->cred_init)
393     FreeCredentialsHandle(&impl_->cred);
394   delete impl_;
395 }
396
397 void
398 SChannelAdapter::PostEvent() {
399   // Check if there's anything notable to signal
400   if (impl_->readable.empty() && !signal_close_)
401     return;
402
403   // Only one post in the queue at a time
404   if (message_pending_)
405     return;
406
407   if (Thread* thread = Thread::Current()) {
408     message_pending_ = true;
409     thread->Post(this);
410   } else {
411     LOG(LS_ERROR) << "No thread context available for SChannelAdapter";
412     ASSERT(false);
413   }
414 }
415
416 void
417 SChannelAdapter::Error(const char* context, int err, bool signal) {
418   LOG(LS_WARNING) << "SChannelAdapter::Error("
419                   << context << ", "
420                   << ErrorName(err, SECURITY_ERRORS) << ")";
421   state_ = SSL_ERROR;
422   SetError(err);
423   if (signal)
424     AsyncSocketAdapter::OnCloseEvent(this, err);
425 }
426
427 int
428 SChannelAdapter::Read() {
429   char buffer[4096];
430   SChannelBuffer& inbuf = impl_->inbuf;
431   while (true) {
432     int ret = AsyncSocketAdapter::Recv(buffer, sizeof(buffer));
433     if (ret > 0) {
434       inbuf.insert(inbuf.end(), buffer, buffer + ret);
435     } else if (GetError() == EWOULDBLOCK) {
436       return 0;  // Blocking
437     } else {
438       return GetError();
439     }
440   }
441 }
442
443 int
444 SChannelAdapter::Flush() {
445   int result = 0;
446   size_t pos = 0;
447   SChannelBuffer& outbuf = impl_->outbuf;
448   while (pos < outbuf.size()) {
449     int sent = AsyncSocketAdapter::Send(&outbuf[pos], outbuf.size() - pos);
450     if (sent > 0) {
451       pos += sent;
452     } else if (GetError() == EWOULDBLOCK) {
453       break;  // Blocking
454     } else {
455       result = GetError();
456       break;
457     }
458   }
459   if (int remainder = static_cast<int>(outbuf.size() - pos)) {
460     memmove(&outbuf[0], &outbuf[pos], remainder);
461     outbuf.resize(remainder);
462   } else {
463     outbuf.clear();
464   }
465   return result;
466 }
467
468 //
469 // AsyncSocket Implementation
470 //
471
472 int
473 SChannelAdapter::Send(const void* pv, size_t cb) {
474   switch (state_) {
475   case SSL_NONE:
476     return AsyncSocketAdapter::Send(pv, cb);
477
478   case SSL_WAIT:
479   case SSL_CONNECTING:
480     SetError(EWOULDBLOCK);
481     return SOCKET_ERROR;
482
483   case SSL_CONNECTED:
484     break;
485
486   case SSL_ERROR:
487   default:
488     return SOCKET_ERROR;
489   }
490
491   size_t written = 0;
492   SChannelBuffer& outbuf = impl_->outbuf;
493   while (written < cb) {
494     const size_t encrypt_len = std::min<size_t>(cb - written,
495                                                 impl_->sizes.cbMaximumMessage);
496
497     CSecBufferBundle<4> out_buf;
498     out_buf[0].BufferType = SECBUFFER_STREAM_HEADER;
499     out_buf[0].cbBuffer = impl_->sizes.cbHeader;
500     out_buf[1].BufferType = SECBUFFER_DATA;
501     out_buf[1].cbBuffer = static_cast<unsigned long>(encrypt_len);
502     out_buf[2].BufferType = SECBUFFER_STREAM_TRAILER;
503     out_buf[2].cbBuffer = impl_->sizes.cbTrailer;
504
505     size_t packet_len = out_buf[0].cbBuffer
506                       + out_buf[1].cbBuffer
507                       + out_buf[2].cbBuffer;
508
509     SChannelBuffer message;
510     message.resize(packet_len);
511     out_buf[0].pvBuffer = &message[0];
512     out_buf[1].pvBuffer = &message[out_buf[0].cbBuffer];
513     out_buf[2].pvBuffer = &message[out_buf[0].cbBuffer + out_buf[1].cbBuffer];
514
515     memcpy(out_buf[1].pvBuffer,
516            static_cast<const char*>(pv) + written,
517            encrypt_len);
518
519     //DescribeBuffers(LS_VERBOSE, "Encrypt In ", out_buf.desc());
520     SECURITY_STATUS res = EncryptMessage(&impl_->ctx, 0, out_buf.desc(), 0);
521     //DescribeBuffers(LS_VERBOSE, "Encrypt Out ", out_buf.desc());
522
523     if (FAILED(res)) {
524       Error("EncryptMessage", res, false);
525       return SOCKET_ERROR;
526     }
527
528     // We assume that the header and data segments do not change length,
529     // or else encrypting the concatenated packet in-place is wrong.
530     ASSERT(out_buf[0].cbBuffer == impl_->sizes.cbHeader);
531     ASSERT(out_buf[1].cbBuffer == static_cast<unsigned long>(encrypt_len));
532
533     // However, the length of the trailer may change due to padding.
534     ASSERT(out_buf[2].cbBuffer <= impl_->sizes.cbTrailer);
535
536     packet_len = out_buf[0].cbBuffer
537                + out_buf[1].cbBuffer
538                + out_buf[2].cbBuffer;
539
540     written += encrypt_len;
541     outbuf.insert(outbuf.end(), &message[0], &message[packet_len-1]+1);
542   }
543
544   if (int err = Flush()) {
545     state_ = SSL_ERROR;
546     SetError(err);
547     return SOCKET_ERROR;
548   }
549
550   return static_cast<int>(written);
551 }
552
553 int
554 SChannelAdapter::Recv(void* pv, size_t cb) {
555   switch (state_) {
556   case SSL_NONE:
557     return AsyncSocketAdapter::Recv(pv, cb);
558
559   case SSL_WAIT:
560   case SSL_CONNECTING:
561     SetError(EWOULDBLOCK);
562     return SOCKET_ERROR;
563
564   case SSL_CONNECTED:
565     break;
566
567   case SSL_ERROR:
568   default:
569     return SOCKET_ERROR;
570   }
571
572   SChannelBuffer& readable = impl_->readable;
573   if (readable.empty()) {
574     SetError(EWOULDBLOCK);
575     return SOCKET_ERROR;
576   }
577   size_t read = _min(cb, readable.size());
578   memcpy(pv, &readable[0], read);
579   if (size_t remaining = readable.size() - read) {
580     memmove(&readable[0], &readable[read], remaining);
581     readable.resize(remaining);
582   } else {
583     readable.clear();
584   }
585
586   PostEvent();
587   return static_cast<int>(read);
588 }
589
590 int
591 SChannelAdapter::Close() {
592   if (!impl_->readable.empty()) {
593     LOG(WARNING) << "SChannelAdapter::Close with readable data";
594     // Note: this isn't strictly an error, but we're using it temporarily to
595     // track bugs.
596     //ASSERT(false);
597   }
598   if (state_ == SSL_CONNECTED) {
599     DWORD token = SCHANNEL_SHUTDOWN;
600     CSecBufferBundle<1> sb_in;
601     sb_in[0].BufferType = SECBUFFER_TOKEN;
602     sb_in[0].cbBuffer = sizeof(token);
603     sb_in[0].pvBuffer = &token;
604     ApplyControlToken(&impl_->ctx, sb_in.desc());
605     // TODO: In theory, to do a nice shutdown, we need to begin shutdown
606     // negotiation with more calls to InitializeSecurityContext.  Since the
607     // socket api doesn't support nice shutdown at this point, we don't bother.
608   }
609   Cleanup();
610   impl_ = new SSLImpl;
611   state_ = restartable_ ? SSL_WAIT : SSL_NONE;
612   signal_close_ = false;
613   message_pending_ = false;
614   return AsyncSocketAdapter::Close();
615 }
616
617 Socket::ConnState
618 SChannelAdapter::GetState() const {
619   if (signal_close_)
620     return CS_CONNECTED;
621   ConnState state = socket_->GetState();
622   if ((state == CS_CONNECTED)
623       && ((state_ == SSL_WAIT) || (state_ == SSL_CONNECTING)))
624     state = CS_CONNECTING;
625   return state;
626 }
627
628 void
629 SChannelAdapter::OnConnectEvent(AsyncSocket* socket) {
630   LOG(LS_VERBOSE) << "SChannelAdapter::OnConnectEvent";
631   if (state_ != SSL_WAIT) {
632     ASSERT(state_ == SSL_NONE);
633     AsyncSocketAdapter::OnConnectEvent(socket);
634     return;
635   }
636
637   state_ = SSL_CONNECTING;
638   if (int err = BeginSSL()) {
639     Error("BeginSSL", err);
640   }
641 }
642
643 void
644 SChannelAdapter::OnReadEvent(AsyncSocket* socket) {
645   if (state_ == SSL_NONE) {
646     AsyncSocketAdapter::OnReadEvent(socket);
647     return;
648   }
649
650   if (int err = Read()) {
651     Error("Read", err);
652     return;
653   }
654
655   if (impl_->inbuf.empty())
656     return;
657
658   if (state_ == SSL_CONNECTED) {
659     if (int err = DecryptData()) {
660       Error("DecryptData", err);
661     } else if (!impl_->readable.empty()) {
662       AsyncSocketAdapter::OnReadEvent(this);
663     }
664   } else if (state_ == SSL_CONNECTING) {
665     if (int err = ContinueSSL()) {
666       Error("ContinueSSL", err);
667     }
668   }
669 }
670
671 void
672 SChannelAdapter::OnWriteEvent(AsyncSocket* socket) {
673   if (state_ == SSL_NONE) {
674     AsyncSocketAdapter::OnWriteEvent(socket);
675     return;
676   }
677
678   if (int err = Flush()) {
679     Error("Flush", err);
680     return;
681   }
682
683   // See if we have more data to write
684   if (!impl_->outbuf.empty())
685     return;
686
687   // Buffer is empty, submit notification
688   if (state_ == SSL_CONNECTED) {
689     AsyncSocketAdapter::OnWriteEvent(socket);
690   }
691 }
692
693 void
694 SChannelAdapter::OnCloseEvent(AsyncSocket* socket, int err) {
695   if ((state_ == SSL_NONE) || impl_->readable.empty()) {
696     AsyncSocketAdapter::OnCloseEvent(socket, err);
697     return;
698   }
699
700   // If readable is non-empty, then we have a pending Message
701   // that will allow us to signal close (eventually).
702   signal_close_ = true;
703 }
704
705 void
706 SChannelAdapter::OnMessage(Message* pmsg) {
707   if (!message_pending_)
708     return;  // This occurs when socket is closed
709
710   message_pending_ = false;
711   if (!impl_->readable.empty()) {
712     AsyncSocketAdapter::OnReadEvent(this);
713   } else if (signal_close_) {
714     signal_close_ = false;
715     AsyncSocketAdapter::OnCloseEvent(this, 0); // TODO: cache this error?
716   }
717 }
718
719 } // namespace talk_base