tls: asynchronous SNICallback
authorFedor Indutny <fedor.indutny@gmail.com>
Sat, 3 Aug 2013 17:29:54 +0000 (21:29 +0400)
committerFedor Indutny <fedor.indutny@gmail.com>
Tue, 6 Aug 2013 12:13:01 +0000 (16:13 +0400)
Make ClientHelloParser handle SNI extension, and extend `_tls_wrap.js`
to support loading SNI Context from both hello, and resumed session.

fix #5967

doc/api/tls.markdown
lib/_tls_wrap.js
src/node_crypto_clienthello-inl.h
src/node_crypto_clienthello.cc
src/node_crypto_clienthello.h
src/tls_wrap.cc
src/tls_wrap.h
test/simple/test-tls-sni-option.js

index e4fd3aa..e1e8d2c 100644 (file)
@@ -156,9 +156,10 @@ automatically set as a listener for the [secureConnection][] event.  The
   - `NPNProtocols`: An array or `Buffer` of possible NPN protocols. (Protocols
     should be ordered by their priority).
 
-  - `SNICallback`: A function that will be called if client supports SNI TLS
-    extension. Only one argument will be passed to it: `servername`. And
-    `SNICallback` should return SecureContext instance.
+  - `SNICallback(servername, cb)`: A function that will be called if client
+    supports SNI TLS extension. Two argument will be passed to it: `servername`,
+    and `cb`. `SNICallback` should invoke `cb(null, ctx)`, where `ctx` is a
+    SecureContext instance.
     (You can use `crypto.createCredentials(...).context` to get proper
     SecureContext). If `SNICallback` wasn't provided - default callback with
     high-level API will be used (see below).
index b432f89..d4c2b57 100644 (file)
@@ -49,24 +49,67 @@ function onhandshakedone() {
 
 function onclienthello(hello) {
   var self = this,
-      once = false;
+      onceSession = false,
+      onceSNI = false;
 
   function callback(err, session) {
-    if (once)
-      return self.destroy(new Error('TLS session callback was called twice'));
-    once = true;
+    if (onceSession)
+      return self.destroy(new Error('TLS session callback was called 2 times'));
+    onceSession = true;
 
     if (err)
       return self.destroy(err);
 
-    self.ssl.loadSession(session);
+    // NOTE: That we have disabled OpenSSL's internal session storage in
+    // `node_crypto.cc` and hence its safe to rely on getting servername only
+    // from clienthello or this place.
+    var ret = self.ssl.loadSession(session);
+
+    // Servername came from SSL session
+    // NOTE: TLS Session ticket doesn't include servername information
+    //
+    // Another note, From RFC3546:
+    //
+    //   If, on the other hand, the older
+    //   session is resumed, then the server MUST ignore extensions appearing
+    //   in the client hello, and send a server hello containing no
+    //   extensions; in this case the extension functionality negotiated
+    //   during the original session initiation is applied to the resumed
+    //   session.
+    //
+    // Therefore we should account session loading when dealing with servername
+    if (ret && ret.servername) {
+      self._SNICallback(ret.servername, onSNIResult);
+    } else if (hello.servername && self._SNICallback) {
+      self._SNICallback(hello.servername, onSNIResult);
+    } else {
+      self.ssl.endParser();
+    }
+  }
+
+  function onSNIResult(err, context) {
+    if (onceSNI)
+      return self.destroy(new Error('TLS SNI callback was called 2 times'));
+    onceSNI = true;
+
+    if (err)
+      return self.destroy(err);
+
+    if (context)
+      self.ssl.sni_context = context;
+
+    self.ssl.endParser();
   }
 
   if (hello.sessionId.length <= 0 ||
       hello.tlsTicket ||
       this.server &&
       !this.server.emit('resumeSession', hello.sessionId, callback)) {
-    callback(null, null);
+    // Invoke SNI callback, since we've no session to resume
+    if (hello.servername && this._SNICallback)
+      this._SNICallback(hello.servername, onSNIResult);
+    else
+      this.ssl.endParser();
   }
 }
 
@@ -94,6 +137,7 @@ function TLSSocket(socket, options) {
   this._tlsOptions = options;
   this._secureEstablished = false;
   this._controlReleased = false;
+  this._SNICallback = null;
   this.ssl = null;
   this.servername = null;
   this.npnProtocol = null;
@@ -176,7 +220,8 @@ TLSSocket.prototype._init = function() {
       (options.SNICallback !== SNICallback ||
        options.server._contexts.length)) {
     assert(typeof options.SNICallback === 'function');
-    this.ssl.onsniselect = options.SNICallback;
+    this._SNICallback = options.SNICallback;
+    this.ssl.enableHelloParser();
   }
 
   if (process.features.tls_npn && options.NPNProtocols)
@@ -499,7 +544,7 @@ Server.prototype.addContext = function(servername, credentials) {
   this._contexts.push([re, crypto.createCredentials(credentials).context]);
 };
 
-function SNICallback(servername) {
+function SNICallback(servername, callback) {
   var ctx;
 
   this._contexts.some(function(elem) {
@@ -509,7 +554,7 @@ function SNICallback(servername) {
     }
   });
 
-  return ctx;
+  callback(null, ctx);
 }
 
 Server.prototype.SNICallback = SNICallback;
index 82c0d27..7b735dd 100644 (file)
@@ -34,6 +34,8 @@ inline void ClientHelloParser::Reset() {
   session_id_ = NULL;
   tls_ticket_size_ = -1;
   tls_ticket_ = NULL;
+  servername_size_ = 0;
+  servername_ = NULL;
 }
 
 inline void ClientHelloParser::Start(ClientHelloParser::OnHelloCb onhello_cb,
index 5c1ecfa..424b30e 100644 (file)
@@ -123,6 +123,8 @@ void ClientHelloParser::ParseHeader(const uint8_t* data, size_t avail) {
   hello.session_id_ = session_id_;
   hello.session_size_ = session_size_;
   hello.has_ticket_ = tls_ticket_ != NULL && tls_ticket_size_ != 0;
+  hello.servername_ = servername_;
+  hello.servername_size_ = servername_size_;
   onhello_cb_(cb_arg_, hello);
 }
 
@@ -134,6 +136,29 @@ void ClientHelloParser::ParseExtension(ClientHelloParser::ExtensionType type,
   // That's because we're heavily relying on OpenSSL to solve any problem with
   // incoming data.
   switch (type) {
+    case kServerName:
+      {
+        if (len < 2)
+          return;
+        uint16_t server_names_len = (data[0] << 8) + data[1];
+        if (server_names_len + 2 > len)
+          return;
+        for (size_t offset = 2; offset < 2 + server_names_len; ) {
+          if (offset + 3 > len)
+            return;
+          uint8_t name_type = data[offset];
+          if (name_type != kServernameHostname)
+            return;
+          uint16_t name_len = (data[offset + 1] << 8) + data[offset + 2];
+          offset += 3;
+          if (offset + name_len > len)
+            return;
+          servername_ = data + offset;
+          servername_size_ = name_len;
+          offset += name_len;
+        }
+      }
+      break;
     case kTLSSessionTicket:
       tls_ticket_size_ = len;
       tls_ticket_ = data + len;
index 6c98f5c..4301d9b 100644 (file)
@@ -46,11 +46,15 @@ class ClientHelloParser {
     inline uint8_t session_size() const { return session_size_; }
     inline const uint8_t* session_id() const { return session_id_; }
     inline bool has_ticket() const { return has_ticket_; }
+    inline uint8_t servername_size() const { return servername_size_; }
+    inline const uint8_t* servername() const { return servername_; }
 
    private:
     uint8_t session_size_;
     const uint8_t* session_id_;
     bool has_ticket_;
+    uint8_t servername_size_;
+    const uint8_t* servername_;
 
     friend class ClientHelloParser;
   };
@@ -71,6 +75,7 @@ class ClientHelloParser {
   static const uint8_t kSSL2HeaderMask = 0x3f;
   static const size_t kMaxTLSFrameLen = 16 * 1024 + 5;
   static const size_t kMaxSSLExFrameLen = 32 * 1024;
+  static const uint8_t kServernameHostname = 0;
 
   enum ParseState {
     kWaiting,
@@ -93,6 +98,7 @@ class ClientHelloParser {
   };
 
   enum ExtensionType {
+    kServerName = 0,
     kTLSSessionTicket = 35
   };
 
@@ -115,6 +121,8 @@ class ClientHelloParser {
   size_t extension_offset_;
   uint8_t session_size_;
   const uint8_t* session_id_;
+  uint16_t servername_size_;
+  const uint8_t* servername_;
   uint16_t tls_ticket_size_;
   const uint8_t* tls_ticket_;
 };
index 0fbb7e5..b3172a1 100644 (file)
@@ -49,7 +49,6 @@ using v8::Value;
 
 static Cached<String> onread_sym;
 static Cached<String> onerror_sym;
-static Cached<String> onsniselect_sym;
 static Cached<String> onhandshakestart_sym;
 static Cached<String> onhandshakedone_sym;
 static Cached<String> onclienthello_sym;
@@ -67,6 +66,8 @@ static Cached<String> version_sym;
 static Cached<String> ext_key_usage_sym;
 static Cached<String> sessionid_sym;
 static Cached<String> tls_ticket_sym;
+static Cached<String> servername_sym;
+static Cached<String> sni_context_sym;
 
 static Persistent<Function> tlsWrap;
 
@@ -174,7 +175,6 @@ TLSCallbacks::~TLSCallbacks() {
 #endif  // OPENSSL_NPN_NEGOTIATED
 
 #ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
-  servername_.Dispose();
   sni_context_.Dispose();
 #endif  // SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
 }
@@ -640,7 +640,6 @@ void TLSCallbacks::DoRead(uv_stream_t* handle,
 
   // Parse ClientHello first
   if (!hello_.IsEnded()) {
-    assert(session_callbacks_);
     size_t avail = 0;
     uint8_t* data = reinterpret_cast<uint8_t*>(enc_in->Peek(&avail));
     assert(avail == 0 || data != NULL);
@@ -770,6 +769,16 @@ void TLSCallbacks::EnableSessionCallbacks(
   UNWRAP(TLSCallbacks);
 
   wrap->session_callbacks_ = true;
+  EnableHelloParser(args);
+}
+
+
+void TLSCallbacks::EnableHelloParser(
+    const FunctionCallbackInfo<Value>& args) {
+  HandleScope scope(node_isolate);
+
+  UNWRAP(TLSCallbacks);
+
   wrap->hello_.Start(OnClientHello, OnClientHelloParseEnd, wrap);
 }
 
@@ -785,6 +794,14 @@ void TLSCallbacks::OnClientHello(void* arg,
       reinterpret_cast<const char*>(hello.session_id()),
                                     hello.session_size());
   hello_obj->Set(sessionid_sym, buff);
+  if (hello.servername() == NULL) {
+    hello_obj->Set(servername_sym, String::Empty(node_isolate));
+  } else {
+    Local<String> servername = String::New(
+        reinterpret_cast<const char*>(hello.servername()),
+        hello.servername_size());
+    hello_obj->Set(servername_sym, servername);
+  }
   hello_obj->Set(tls_ticket_sym, Boolean::New(hello.has_ticket()));
 
   Handle<Value> argv[1] = { hello_obj };
@@ -999,7 +1016,23 @@ void TLSCallbacks::LoadSession(const FunctionCallbackInfo<Value>& args) {
     if (wrap->next_sess_ != NULL)
       SSL_SESSION_free(wrap->next_sess_);
     wrap->next_sess_ = sess;
+
+    Local<Object> info = Object::New();
+#ifndef OPENSSL_NO_TLSEXT
+    if (sess->tlsext_hostname == NULL) {
+      info->Set(servername_sym, False(node_isolate));
+    } else {
+      info->Set(servername_sym, String::New(sess->tlsext_hostname));
+    }
+#endif
+    args.GetReturnValue().Set(info);
   }
+}
+
+void TLSCallbacks::EndParser(const FunctionCallbackInfo<Value>& args) {
+  HandleScope scope(node_isolate);
+
+  UNWRAP(TLSCallbacks);
 
   wrap->hello_.End();
 }
@@ -1143,8 +1176,10 @@ void TLSCallbacks::GetServername(const FunctionCallbackInfo<Value>& args) {
 
   UNWRAP(TLSCallbacks);
 
-  if (wrap->kind_ == kTLSServer && !wrap->servername_.IsEmpty()) {
-    args.GetReturnValue().Set(wrap->servername_);
+  const char* servername = SSL_get_servername(wrap->ssl_,
+                                              TLSEXT_NAMETYPE_host_name);
+  if (servername != NULL) {
+    args.GetReturnValue().Set(String::New(servername));
   } else {
     args.GetReturnValue().Set(false);
   }
@@ -1179,25 +1214,22 @@ int TLSCallbacks::SelectSNIContextCallback(SSL* s, int* ad, void* arg) {
 
   const char* servername = SSL_get_servername(s, TLSEXT_NAMETYPE_host_name);
 
-  if (servername) {
-    p->servername_.Reset(node_isolate, String::New(servername));
-
+  if (servername != NULL) {
     // Call the SNI callback and use its return value as context
     Local<Object> object = p->object();
-    if (object->Has(onsniselect_sym)) {
-      p->sni_context_.Dispose();
+    Local<Value> ctx;
+    if (object->Has(sni_context_sym)) {
+      ctx = object->Get(sni_context_sym);
+    }
 
-      Local<Value> arg = PersistentToLocal(node_isolate, p->servername_);
-      Handle<Value> ret = MakeCallback(object, onsniselect_sym, 1, &arg);
+    if (ctx.IsEmpty() || ctx->IsUndefined())
+      return SSL_TLSEXT_ERR_NOACK;
 
-      // If ret is SecureContext
-      if (ret->IsUndefined())
-        return SSL_TLSEXT_ERR_NOACK;
+    p->sni_context_.Dispose();
+    p->sni_context_.Reset(node_isolate, ctx);
 
-      p->sni_context_.Reset(node_isolate, ret);
-      SecureContext* sc = ObjectWrap::Unwrap<SecureContext>(ret.As<Object>());
-      SSL_set_SSL_CTX(s, sc->ctx_);
-    }
+    SecureContext* sc = ObjectWrap::Unwrap<SecureContext>(ctx.As<Object>());
+    SSL_set_SSL_CTX(s, sc->ctx_);
   }
 
   return SSL_TLSEXT_ERR_OK;
@@ -1219,6 +1251,7 @@ void TLSCallbacks::Initialize(Handle<Object> target) {
   NODE_SET_PROTOTYPE_METHOD(t, "getSession", GetSession);
   NODE_SET_PROTOTYPE_METHOD(t, "setSession", SetSession);
   NODE_SET_PROTOTYPE_METHOD(t, "loadSession", LoadSession);
+  NODE_SET_PROTOTYPE_METHOD(t, "endParser", EndParser);
   NODE_SET_PROTOTYPE_METHOD(t, "getCurrentCipher", GetCurrentCipher);
   NODE_SET_PROTOTYPE_METHOD(t, "verifyError", VerifyError);
   NODE_SET_PROTOTYPE_METHOD(t, "setVerifyMode", SetVerifyMode);
@@ -1226,6 +1259,9 @@ void TLSCallbacks::Initialize(Handle<Object> target) {
   NODE_SET_PROTOTYPE_METHOD(t,
                             "enableSessionCallbacks",
                             EnableSessionCallbacks);
+  NODE_SET_PROTOTYPE_METHOD(t,
+                            "enableHelloParser",
+                            EnableHelloParser);
 
 #ifdef OPENSSL_NPN_NEGOTIATED
   NODE_SET_PROTOTYPE_METHOD(t, "getNegotiatedProtocol", GetNegotiatedProto);
@@ -1240,7 +1276,6 @@ void TLSCallbacks::Initialize(Handle<Object> target) {
   tlsWrap.Reset(node_isolate, t->GetFunction());
 
   onread_sym = String::New("onread");
-  onsniselect_sym = String::New("onsniselect");
   onerror_sym = String::New("onerror");
   onhandshakestart_sym = String::New("onhandshakestart");
   onhandshakedone_sym = String::New("onhandshakedone");
@@ -1260,6 +1295,8 @@ void TLSCallbacks::Initialize(Handle<Object> target) {
   ext_key_usage_sym = String::New("ext_key_usage");
   sessionid_sym = String::New("sessionId");
   tls_ticket_sym = String::New("tlsTicket");
+  servername_sym = String::New("servername");
+  sni_context_sym = String::New("sni_context");
 }
 
 }  // namespace node
index eb8d7ce..bb56794 100644 (file)
@@ -108,12 +108,15 @@ class TLSCallbacks : public StreamWrapCallbacks {
   static void GetSession(const v8::FunctionCallbackInfo<v8::Value>& args);
   static void SetSession(const v8::FunctionCallbackInfo<v8::Value>& args);
   static void LoadSession(const v8::FunctionCallbackInfo<v8::Value>& args);
+  static void EndParser(const v8::FunctionCallbackInfo<v8::Value>& args);
   static void GetCurrentCipher(const v8::FunctionCallbackInfo<v8::Value>& args);
   static void VerifyError(const v8::FunctionCallbackInfo<v8::Value>& args);
   static void SetVerifyMode(const v8::FunctionCallbackInfo<v8::Value>& args);
   static void IsSessionReused(const v8::FunctionCallbackInfo<v8::Value>& args);
   static void EnableSessionCallbacks(
       const v8::FunctionCallbackInfo<v8::Value>& args);
+  static void EnableHelloParser(
+      const v8::FunctionCallbackInfo<v8::Value>& args);
 
   // TLS Session API
   static SSL_SESSION* GetSessionCallback(SSL* s,
@@ -178,7 +181,6 @@ class TLSCallbacks : public StreamWrapCallbacks {
 #endif  // OPENSSL_NPN_NEGOTIATED
 
 #ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
-  v8::Persistent<v8::String> servername_;
   v8::Persistent<v8::Value> sni_context_;
 #endif  // SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
 };
index 3518202..aaf37c7 100644 (file)
@@ -42,10 +42,16 @@ function loadPEM(n) {
 var serverOptions = {
   key: loadPEM('agent2-key'),
   cert: loadPEM('agent2-cert'),
-  SNICallback: function(servername) {
+  SNICallback: function(servername, callback) {
     var credentials = SNIContexts[servername];
-    if (credentials)
-      return crypto.createCredentials(credentials).context;
+
+    // Just to test asynchronous callback
+    setTimeout(function() {
+      if (credentials)
+        callback(null, crypto.createCredentials(credentials).context);
+      else
+        callback(null, null);
+    }, 100);
   }
 };