Imported Upstream version 1.34.0
[platform/upstream/nghttp2.git] / src / shrpx_tls.cc
1 /*
2  * nghttp2 - HTTP/2 C Library
3  *
4  * Copyright (c) 2012 Tatsuhiro Tsujikawa
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining
7  * a copy of this software and associated documentation files (the
8  * "Software"), to deal in the Software without restriction, including
9  * without limitation the rights to use, copy, modify, merge, publish,
10  * distribute, sublicense, and/or sell copies of the Software, and to
11  * permit persons to whom the Software is furnished to do so, subject to
12  * the following conditions:
13  *
14  * The above copyright notice and this permission notice shall be
15  * included in all copies or substantial portions of the Software.
16  *
17  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
20  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
21  * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
22  * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
23  * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
24  */
25 #include "shrpx_tls.h"
26
27 #ifdef HAVE_SYS_SOCKET_H
28 #  include <sys/socket.h>
29 #endif // HAVE_SYS_SOCKET_H
30 #ifdef HAVE_NETDB_H
31 #  include <netdb.h>
32 #endif // HAVE_NETDB_H
33 #include <netinet/tcp.h>
34 #include <pthread.h>
35 #include <sys/types.h>
36
37 #include <vector>
38 #include <string>
39 #include <iomanip>
40
41 #include <iostream>
42
43 #include <openssl/crypto.h>
44 #include <openssl/x509.h>
45 #include <openssl/x509v3.h>
46 #include <openssl/rand.h>
47 #include <openssl/dh.h>
48 #ifndef OPENSSL_NO_OCSP
49 #  include <openssl/ocsp.h>
50 #endif // OPENSSL_NO_OCSP
51
52 #include <nghttp2/nghttp2.h>
53
54 #include "shrpx_log.h"
55 #include "shrpx_client_handler.h"
56 #include "shrpx_config.h"
57 #include "shrpx_worker.h"
58 #include "shrpx_downstream_connection_pool.h"
59 #include "shrpx_http2_session.h"
60 #include "shrpx_memcached_request.h"
61 #include "shrpx_memcached_dispatcher.h"
62 #include "shrpx_connection_handler.h"
63 #include "util.h"
64 #include "tls.h"
65 #include "template.h"
66 #include "ssl_compat.h"
67 #include "timegm.h"
68
69 using namespace nghttp2;
70
71 namespace shrpx {
72
73 namespace tls {
74
75 #if !OPENSSL_1_1_API
76 namespace {
77 const unsigned char *ASN1_STRING_get0_data(ASN1_STRING *x) {
78   return ASN1_STRING_data(x);
79 }
80 } // namespace
81 #endif // !OPENSSL_1_1_API
82
83 #ifndef OPENSSL_NO_NEXTPROTONEG
84 namespace {
85 int next_proto_cb(SSL *s, const unsigned char **data, unsigned int *len,
86                   void *arg) {
87   auto &prefs = get_config()->tls.alpn_prefs;
88   *data = prefs.data();
89   *len = prefs.size();
90   return SSL_TLSEXT_ERR_OK;
91 }
92 } // namespace
93 #endif // !OPENSSL_NO_NEXTPROTONEG
94
95 namespace {
96 int verify_callback(int preverify_ok, X509_STORE_CTX *ctx) {
97   if (!preverify_ok) {
98     int err = X509_STORE_CTX_get_error(ctx);
99     int depth = X509_STORE_CTX_get_error_depth(ctx);
100     if (err == X509_V_ERR_CERT_HAS_EXPIRED && depth == 0 &&
101         get_config()->tls.client_verify.tolerate_expired) {
102       LOG(INFO) << "The client certificate has expired, but is accepted by "
103                    "configuration";
104       return 1;
105     }
106     LOG(ERROR) << "client certificate verify error:num=" << err << ":"
107                << X509_verify_cert_error_string(err) << ":depth=" << depth;
108   }
109   return preverify_ok;
110 }
111 } // namespace
112
113 int set_alpn_prefs(std::vector<unsigned char> &out,
114                    const std::vector<StringRef> &protos) {
115   size_t len = 0;
116
117   for (const auto &proto : protos) {
118     if (proto.size() > 255) {
119       LOG(FATAL) << "Too long ALPN identifier: " << proto.size();
120       return -1;
121     }
122
123     len += 1 + proto.size();
124   }
125
126   if (len > (1 << 16) - 1) {
127     LOG(FATAL) << "Too long ALPN identifier list: " << len;
128     return -1;
129   }
130
131   out.resize(len);
132   auto ptr = out.data();
133
134   for (const auto &proto : protos) {
135     *ptr++ = proto.size();
136     ptr = std::copy(std::begin(proto), std::end(proto), ptr);
137   }
138
139   return 0;
140 }
141
142 namespace {
143 int ssl_pem_passwd_cb(char *buf, int size, int rwflag, void *user_data) {
144   auto config = static_cast<Config *>(user_data);
145   auto len = static_cast<int>(config->tls.private_key_passwd.size());
146   if (size < len + 1) {
147     LOG(ERROR) << "ssl_pem_passwd_cb: buf is too small " << size;
148     return 0;
149   }
150   // Copy string including last '\0'.
151   memcpy(buf, config->tls.private_key_passwd.c_str(), len + 1);
152   return len;
153 }
154 } // namespace
155
156 namespace {
157 // *al is set to SSL_AD_UNRECOGNIZED_NAME by openssl, so we don't have
158 // to set it explicitly.
159 int servername_callback(SSL *ssl, int *al, void *arg) {
160   auto conn = static_cast<Connection *>(SSL_get_app_data(ssl));
161   auto handler = static_cast<ClientHandler *>(conn->data);
162   auto worker = handler->get_worker();
163
164   auto rawhost = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
165   if (rawhost == nullptr) {
166     return SSL_TLSEXT_ERR_NOACK;
167   }
168
169   auto len = strlen(rawhost);
170   // NI_MAXHOST includes terminal NULL.
171   if (len == 0 || len + 1 > NI_MAXHOST) {
172     return SSL_TLSEXT_ERR_NOACK;
173   }
174
175   std::array<uint8_t, NI_MAXHOST> buf;
176
177   auto end_buf = std::copy_n(rawhost, len, std::begin(buf));
178
179   util::inp_strlower(std::begin(buf), end_buf);
180
181   auto hostname = StringRef{std::begin(buf), end_buf};
182
183   auto cert_tree = worker->get_cert_lookup_tree();
184
185   auto idx = cert_tree->lookup(hostname);
186   if (idx == -1) {
187     return SSL_TLSEXT_ERR_NOACK;
188   }
189
190   handler->set_tls_sni(hostname);
191
192   auto conn_handler = worker->get_connection_handler();
193
194   const auto &ssl_ctx_list = conn_handler->get_indexed_ssl_ctx(idx);
195   assert(!ssl_ctx_list.empty());
196
197 #if !defined(OPENSSL_IS_BORINGSSL) && !LIBRESSL_IN_USE &&                      \
198     OPENSSL_VERSION_NUMBER >= 0x10002000L
199   auto num_shared_curves = SSL_get_shared_curve(ssl, -1);
200
201   for (auto i = 0; i < num_shared_curves; ++i) {
202     auto shared_curve = SSL_get_shared_curve(ssl, i);
203
204     for (auto ssl_ctx : ssl_ctx_list) {
205       auto cert = SSL_CTX_get0_certificate(ssl_ctx);
206
207 #  if OPENSSL_1_1_API
208       auto pubkey = X509_get0_pubkey(cert);
209 #  else  // !OPENSSL_1_1_API
210       auto pubkey = X509_get_pubkey(cert);
211 #  endif // !OPENSSL_1_1_API
212
213       if (EVP_PKEY_base_id(pubkey) != EVP_PKEY_EC) {
214         continue;
215       }
216
217 #  if OPENSSL_1_1_API
218       auto eckey = EVP_PKEY_get0_EC_KEY(pubkey);
219 #  else  // !OPENSSL_1_1_API
220       auto eckey = EVP_PKEY_get1_EC_KEY(pubkey);
221 #  endif // !OPENSSL_1_1_API
222
223       if (eckey == nullptr) {
224         continue;
225       }
226
227       auto ecgroup = EC_KEY_get0_group(eckey);
228       auto cert_curve = EC_GROUP_get_curve_name(ecgroup);
229
230 #  if !OPENSSL_1_1_API
231       EC_KEY_free(eckey);
232       EVP_PKEY_free(pubkey);
233 #  endif // !OPENSSL_1_1_API
234
235       if (shared_curve == cert_curve) {
236         SSL_set_SSL_CTX(ssl, ssl_ctx);
237         return SSL_TLSEXT_ERR_OK;
238       }
239     }
240   }
241 #endif // !defined(OPENSSL_IS_BORINGSSL) && !LIBRESSL_IN_USE &&
242        // OPENSSL_VERSION_NUMBER >= 0x10002000L
243
244   SSL_set_SSL_CTX(ssl, ssl_ctx_list[0]);
245
246   return SSL_TLSEXT_ERR_OK;
247 }
248 } // namespace
249
250 #ifndef OPENSSL_IS_BORINGSSL
251 namespace {
252 std::shared_ptr<std::vector<uint8_t>>
253 get_ocsp_data(TLSContextData *tls_ctx_data) {
254 #  ifdef HAVE_ATOMIC_STD_SHARED_PTR
255   return std::atomic_load_explicit(&tls_ctx_data->ocsp_data,
256                                    std::memory_order_acquire);
257 #  else  // !HAVE_ATOMIC_STD_SHARED_PTR
258   std::lock_guard<std::mutex> g(tls_ctx_data->mu);
259   return tls_ctx_data->ocsp_data;
260 #  endif // !HAVE_ATOMIC_STD_SHARED_PTR
261 }
262 } // namespace
263
264 namespace {
265 int ocsp_resp_cb(SSL *ssl, void *arg) {
266   auto ssl_ctx = SSL_get_SSL_CTX(ssl);
267   auto tls_ctx_data =
268       static_cast<TLSContextData *>(SSL_CTX_get_app_data(ssl_ctx));
269
270   auto data = get_ocsp_data(tls_ctx_data);
271
272   if (!data) {
273     return SSL_TLSEXT_ERR_OK;
274   }
275
276   auto buf =
277       static_cast<uint8_t *>(CRYPTO_malloc(data->size(), __FILE__, __LINE__));
278
279   if (!buf) {
280     return SSL_TLSEXT_ERR_OK;
281   }
282
283   std::copy(std::begin(*data), std::end(*data), buf);
284
285   SSL_set_tlsext_status_ocsp_resp(ssl, buf, data->size());
286
287   return SSL_TLSEXT_ERR_OK;
288 }
289 } // namespace
290 #endif // OPENSSL_IS_BORINGSSL
291
292 constexpr auto MEMCACHED_SESSION_CACHE_KEY_PREFIX =
293     StringRef::from_lit("nghttpx:tls-session-cache:");
294
295 namespace {
296 int tls_session_client_new_cb(SSL *ssl, SSL_SESSION *session) {
297   auto conn = static_cast<Connection *>(SSL_get_app_data(ssl));
298   if (conn->tls.client_session_cache == nullptr) {
299     return 0;
300   }
301
302   try_cache_tls_session(conn->tls.client_session_cache, session,
303                         ev_now(conn->loop));
304
305   return 0;
306 }
307 } // namespace
308
309 namespace {
310 int tls_session_new_cb(SSL *ssl, SSL_SESSION *session) {
311   auto conn = static_cast<Connection *>(SSL_get_app_data(ssl));
312   auto handler = static_cast<ClientHandler *>(conn->data);
313   auto worker = handler->get_worker();
314   auto dispatcher = worker->get_session_cache_memcached_dispatcher();
315   auto &balloc = handler->get_block_allocator();
316
317 #ifdef TLS1_3_VERSION
318   if (SSL_version(ssl) == TLS1_3_VERSION) {
319     return 0;
320   }
321 #endif // TLS1_3_VERSION
322
323   const unsigned char *id;
324   unsigned int idlen;
325
326   id = SSL_SESSION_get_id(session, &idlen);
327
328   if (LOG_ENABLED(INFO)) {
329     LOG(INFO) << "Memcached: cache session, id=" << util::format_hex(id, idlen);
330   }
331
332   auto req = make_unique<MemcachedRequest>();
333   req->op = MEMCACHED_OP_ADD;
334   req->key = MEMCACHED_SESSION_CACHE_KEY_PREFIX.str();
335   req->key +=
336       util::format_hex(balloc, StringRef{id, static_cast<size_t>(idlen)});
337
338   auto sessionlen = i2d_SSL_SESSION(session, nullptr);
339   req->value.resize(sessionlen);
340   auto buf = &req->value[0];
341   i2d_SSL_SESSION(session, &buf);
342   req->expiry = 12_h;
343   req->cb = [](MemcachedRequest *req, MemcachedResult res) {
344     if (LOG_ENABLED(INFO)) {
345       LOG(INFO) << "Memcached: session cache done.  key=" << req->key
346                 << ", status_code=" << res.status_code << ", value="
347                 << std::string(std::begin(res.value), std::end(res.value));
348     }
349     if (res.status_code != 0) {
350       LOG(WARN) << "Memcached: failed to cache session key=" << req->key
351                 << ", status_code=" << res.status_code << ", value="
352                 << std::string(std::begin(res.value), std::end(res.value));
353     }
354   };
355   assert(!req->canceled);
356
357   dispatcher->add_request(std::move(req));
358
359   return 0;
360 }
361 } // namespace
362
363 namespace {
364 SSL_SESSION *tls_session_get_cb(SSL *ssl,
365 #if OPENSSL_1_1_API
366                                 const unsigned char *id,
367 #else  // !OPENSSL_1_1_API
368                                 unsigned char *id,
369 #endif // !OPENSSL_1_1_API
370                                 int idlen, int *copy) {
371   auto conn = static_cast<Connection *>(SSL_get_app_data(ssl));
372   auto handler = static_cast<ClientHandler *>(conn->data);
373   auto worker = handler->get_worker();
374   auto dispatcher = worker->get_session_cache_memcached_dispatcher();
375   auto &balloc = handler->get_block_allocator();
376
377   if (idlen == 0) {
378     return nullptr;
379   }
380
381   if (conn->tls.cached_session) {
382     if (LOG_ENABLED(INFO)) {
383       LOG(INFO) << "Memcached: found cached session, id="
384                 << util::format_hex(id, idlen);
385     }
386
387     // This is required, without this, memory leak occurs.
388     *copy = 0;
389
390     auto session = conn->tls.cached_session;
391     conn->tls.cached_session = nullptr;
392     return session;
393   }
394
395   if (LOG_ENABLED(INFO)) {
396     LOG(INFO) << "Memcached: get cached session, id="
397               << util::format_hex(id, idlen);
398   }
399
400   auto req = make_unique<MemcachedRequest>();
401   req->op = MEMCACHED_OP_GET;
402   req->key = MEMCACHED_SESSION_CACHE_KEY_PREFIX.str();
403   req->key +=
404       util::format_hex(balloc, StringRef{id, static_cast<size_t>(idlen)});
405   req->cb = [conn](MemcachedRequest *, MemcachedResult res) {
406     if (LOG_ENABLED(INFO)) {
407       LOG(INFO) << "Memcached: returned status code " << res.status_code;
408     }
409
410     // We might stop reading, so start it again
411     conn->rlimit.startw();
412     ev_timer_again(conn->loop, &conn->rt);
413
414     conn->wlimit.startw();
415     ev_timer_again(conn->loop, &conn->wt);
416
417     conn->tls.cached_session_lookup_req = nullptr;
418     if (res.status_code != 0) {
419       conn->tls.handshake_state = TLS_CONN_CANCEL_SESSION_CACHE;
420       return;
421     }
422
423     const uint8_t *p = res.value.data();
424
425     auto session = d2i_SSL_SESSION(nullptr, &p, res.value.size());
426     if (!session) {
427       if (LOG_ENABLED(INFO)) {
428         LOG(INFO) << "cannot materialize session";
429       }
430       conn->tls.handshake_state = TLS_CONN_CANCEL_SESSION_CACHE;
431       return;
432     }
433
434     conn->tls.cached_session = session;
435     conn->tls.handshake_state = TLS_CONN_GOT_SESSION_CACHE;
436   };
437
438   conn->tls.handshake_state = TLS_CONN_WAIT_FOR_SESSION_CACHE;
439   conn->tls.cached_session_lookup_req = req.get();
440
441   dispatcher->add_request(std::move(req));
442
443   return nullptr;
444 }
445 } // namespace
446
447 namespace {
448 int ticket_key_cb(SSL *ssl, unsigned char *key_name, unsigned char *iv,
449                   EVP_CIPHER_CTX *ctx, HMAC_CTX *hctx, int enc) {
450   auto conn = static_cast<Connection *>(SSL_get_app_data(ssl));
451   auto handler = static_cast<ClientHandler *>(conn->data);
452   auto worker = handler->get_worker();
453   auto ticket_keys = worker->get_ticket_keys();
454
455   if (!ticket_keys) {
456     // No ticket keys available.
457     return -1;
458   }
459
460   auto &keys = ticket_keys->keys;
461   assert(!keys.empty());
462
463   if (enc) {
464     if (RAND_bytes(iv, EVP_MAX_IV_LENGTH) == 0) {
465       if (LOG_ENABLED(INFO)) {
466         CLOG(INFO, handler) << "session ticket key: RAND_bytes failed";
467       }
468       return -1;
469     }
470
471     auto &key = keys[0];
472
473     if (LOG_ENABLED(INFO)) {
474       CLOG(INFO, handler) << "encrypt session ticket key: "
475                           << util::format_hex(key.data.name);
476     }
477
478     std::copy(std::begin(key.data.name), std::end(key.data.name), key_name);
479
480     EVP_EncryptInit_ex(ctx, get_config()->tls.ticket.cipher, nullptr,
481                        key.data.enc_key.data(), iv);
482     HMAC_Init_ex(hctx, key.data.hmac_key.data(), key.hmac_keylen, key.hmac,
483                  nullptr);
484     return 1;
485   }
486
487   size_t i;
488   for (i = 0; i < keys.size(); ++i) {
489     auto &key = keys[i];
490     if (std::equal(std::begin(key.data.name), std::end(key.data.name),
491                    key_name)) {
492       break;
493     }
494   }
495
496   if (i == keys.size()) {
497     if (LOG_ENABLED(INFO)) {
498       CLOG(INFO, handler) << "session ticket key "
499                           << util::format_hex(key_name, 16) << " not found";
500     }
501     return 0;
502   }
503
504   if (LOG_ENABLED(INFO)) {
505     CLOG(INFO, handler) << "decrypt session ticket key: "
506                         << util::format_hex(key_name, 16);
507   }
508
509   auto &key = keys[i];
510   HMAC_Init_ex(hctx, key.data.hmac_key.data(), key.hmac_keylen, key.hmac,
511                nullptr);
512   EVP_DecryptInit_ex(ctx, key.cipher, nullptr, key.data.enc_key.data(), iv);
513
514   return i == 0 ? 1 : 2;
515 }
516 } // namespace
517
518 namespace {
519 void info_callback(const SSL *ssl, int where, int ret) {
520 #ifdef TLS1_3_VERSION
521   // TLSv1.3 has no renegotiation.
522   if (SSL_version(ssl) == TLS1_3_VERSION) {
523     return;
524   }
525 #endif // TLS1_3_VERSION
526
527   // To mitigate possible DOS attack using lots of renegotiations, we
528   // disable renegotiation. Since OpenSSL does not provide an easy way
529   // to disable it, we check that renegotiation is started in this
530   // callback.
531   if (where & SSL_CB_HANDSHAKE_START) {
532     auto conn = static_cast<Connection *>(SSL_get_app_data(ssl));
533     if (conn && conn->tls.initial_handshake_done) {
534       auto handler = static_cast<ClientHandler *>(conn->data);
535       if (LOG_ENABLED(INFO)) {
536         CLOG(INFO, handler) << "TLS renegotiation started";
537       }
538       handler->start_immediate_shutdown();
539     }
540   }
541 }
542 } // namespace
543
544 #if OPENSSL_VERSION_NUMBER >= 0x10002000L
545 namespace {
546 int alpn_select_proto_cb(SSL *ssl, const unsigned char **out,
547                          unsigned char *outlen, const unsigned char *in,
548                          unsigned int inlen, void *arg) {
549   // We assume that get_config()->npn_list contains ALPN protocol
550   // identifier sorted by preference order.  So we just break when we
551   // found the first overlap.
552   for (const auto &target_proto_id : get_config()->tls.npn_list) {
553     for (auto p = in, end = in + inlen; p < end;) {
554       auto proto_id = p + 1;
555       auto proto_len = *p;
556
557       if (proto_id + proto_len <= end &&
558           util::streq(target_proto_id, StringRef{proto_id, proto_len})) {
559
560         *out = reinterpret_cast<const unsigned char *>(proto_id);
561         *outlen = proto_len;
562
563         return SSL_TLSEXT_ERR_OK;
564       }
565
566       p += 1 + proto_len;
567     }
568   }
569
570   return SSL_TLSEXT_ERR_NOACK;
571 }
572 } // namespace
573 #endif // OPENSSL_VERSION_NUMBER >= 0x10002000L
574
575 #if !LIBRESSL_IN_USE && OPENSSL_VERSION_NUMBER >= 0x10002000L
576
577 #  ifndef TLSEXT_TYPE_signed_certificate_timestamp
578 #    define TLSEXT_TYPE_signed_certificate_timestamp 18
579 #  endif // !TLSEXT_TYPE_signed_certificate_timestamp
580
581 namespace {
582 int sct_add_cb(SSL *ssl, unsigned int ext_type, unsigned int context,
583                const unsigned char **out, size_t *outlen, X509 *x,
584                size_t chainidx, int *al, void *add_arg) {
585   assert(ext_type == TLSEXT_TYPE_signed_certificate_timestamp);
586
587   auto conn = static_cast<Connection *>(SSL_get_app_data(ssl));
588   if (!conn->tls.sct_requested) {
589     return 0;
590   }
591
592   if (LOG_ENABLED(INFO)) {
593     LOG(INFO) << "sct_add_cb is called, chainidx=" << chainidx << ", x=" << x
594               << ", context=" << log::hex << context;
595   }
596
597   // We only have SCTs for leaf certificate.
598   if (chainidx != 0) {
599     return 0;
600   }
601
602   auto ssl_ctx = SSL_get_SSL_CTX(ssl);
603   auto tls_ctx_data =
604       static_cast<TLSContextData *>(SSL_CTX_get_app_data(ssl_ctx));
605
606   *out = tls_ctx_data->sct_data.data();
607   *outlen = tls_ctx_data->sct_data.size();
608
609   return 1;
610 }
611 } // namespace
612
613 namespace {
614 void sct_free_cb(SSL *ssl, unsigned int ext_type, unsigned int context,
615                  const unsigned char *out, void *add_arg) {
616   assert(ext_type == TLSEXT_TYPE_signed_certificate_timestamp);
617 }
618 } // namespace
619
620 namespace {
621 int sct_parse_cb(SSL *ssl, unsigned int ext_type, unsigned int context,
622                  const unsigned char *in, size_t inlen, X509 *x,
623                  size_t chainidx, int *al, void *parse_arg) {
624   assert(ext_type == TLSEXT_TYPE_signed_certificate_timestamp);
625   // client SHOULD send 0 length extension_data, but it is still
626   // SHOULD, and not MUST.
627
628   // For TLSv1.3 Certificate message, sct_add_cb is called even if
629   // client has not sent signed_certificate_timestamp extension in its
630   // ClientHello.  Explicitly remember that client has included it
631   // here.
632   auto conn = static_cast<Connection *>(SSL_get_app_data(ssl));
633   conn->tls.sct_requested = true;
634
635   return 1;
636 }
637 } // namespace
638
639 #  if !OPENSSL_1_1_1_API
640
641 namespace {
642 int legacy_sct_add_cb(SSL *ssl, unsigned int ext_type,
643                       const unsigned char **out, size_t *outlen, int *al,
644                       void *add_arg) {
645   return sct_add_cb(ssl, ext_type, 0, out, outlen, nullptr, 0, al, add_arg);
646 }
647 } // namespace
648
649 namespace {
650 void legacy_sct_free_cb(SSL *ssl, unsigned int ext_type,
651                         const unsigned char *out, void *add_arg) {
652   sct_free_cb(ssl, ext_type, 0, out, add_arg);
653 }
654 } // namespace
655
656 namespace {
657 int legacy_sct_parse_cb(SSL *ssl, unsigned int ext_type,
658                         const unsigned char *in, size_t inlen, int *al,
659                         void *parse_arg) {
660   return sct_parse_cb(ssl, ext_type, 0, in, inlen, nullptr, 0, al, parse_arg);
661 }
662 } // namespace
663
664 #  endif // !OPENSSL_1_1_1_API
665 #endif   // !LIBRESSL_IN_USE && OPENSSL_VERSION_NUMBER >= 0x10002000L
666
667 #ifndef OPENSSL_NO_PSK
668 namespace {
669 unsigned int psk_server_cb(SSL *ssl, const char *identity, unsigned char *psk,
670                            unsigned int max_psk_len) {
671   auto config = get_config();
672   auto &tlsconf = config->tls;
673
674   auto it = tlsconf.psk_secrets.find(StringRef{identity});
675   if (it == std::end(tlsconf.psk_secrets)) {
676     return 0;
677   }
678
679   auto &secret = (*it).second;
680   if (secret.size() > max_psk_len) {
681     LOG(ERROR) << "The size of PSK secret is " << secret.size()
682                << ", but the acceptable maximum size is" << max_psk_len;
683     return 0;
684   }
685
686   std::copy(std::begin(secret), std::end(secret), psk);
687
688   return static_cast<unsigned int>(secret.size());
689 }
690 } // namespace
691 #endif // !OPENSSL_NO_PSK
692
693 #ifndef OPENSSL_NO_PSK
694 namespace {
695 unsigned int psk_client_cb(SSL *ssl, const char *hint, char *identity_out,
696                            unsigned int max_identity_len, unsigned char *psk,
697                            unsigned int max_psk_len) {
698   auto config = get_config();
699   auto &tlsconf = config->tls;
700
701   auto &identity = tlsconf.client.psk.identity;
702   auto &secret = tlsconf.client.psk.secret;
703
704   if (identity.empty()) {
705     return 0;
706   }
707
708   if (identity.size() + 1 > max_identity_len) {
709     LOG(ERROR) << "The size of PSK identity is " << identity.size()
710                << ", but the acceptable maximum size is " << max_identity_len;
711     return 0;
712   }
713
714   if (secret.size() > max_psk_len) {
715     LOG(ERROR) << "The size of PSK secret is " << secret.size()
716                << ", but the acceptable maximum size is " << max_psk_len;
717     return 0;
718   }
719
720   *std::copy(std::begin(identity), std::end(identity), identity_out) = '\0';
721   std::copy(std::begin(secret), std::end(secret), psk);
722
723   return static_cast<unsigned int>(secret.size());
724 }
725 } // namespace
726 #endif // !OPENSSL_NO_PSK
727
728 struct TLSProtocol {
729   StringRef name;
730   long int mask;
731 };
732
733 constexpr TLSProtocol TLS_PROTOS[] = {
734     TLSProtocol{StringRef::from_lit("TLSv1.2"), SSL_OP_NO_TLSv1_2},
735     TLSProtocol{StringRef::from_lit("TLSv1.1"), SSL_OP_NO_TLSv1_1},
736     TLSProtocol{StringRef::from_lit("TLSv1.0"), SSL_OP_NO_TLSv1}};
737
738 long int create_tls_proto_mask(const std::vector<StringRef> &tls_proto_list) {
739   long int res = 0;
740
741   for (auto &supported : TLS_PROTOS) {
742     auto ok = false;
743     for (auto &name : tls_proto_list) {
744       if (util::strieq(supported.name, name)) {
745         ok = true;
746         break;
747       }
748     }
749     if (!ok) {
750       res |= supported.mask;
751     }
752   }
753   return res;
754 }
755
756 SSL_CTX *create_ssl_context(const char *private_key_file, const char *cert_file,
757                             const std::vector<uint8_t> &sct_data
758 #ifdef HAVE_NEVERBLEED
759                             ,
760                             neverbleed_t *nb
761 #endif // HAVE_NEVERBLEED
762 ) {
763   auto ssl_ctx = SSL_CTX_new(SSLv23_server_method());
764   if (!ssl_ctx) {
765     LOG(FATAL) << ERR_error_string(ERR_get_error(), nullptr);
766     DIE();
767   }
768
769   constexpr auto ssl_opts =
770       (SSL_OP_ALL & ~SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS) | SSL_OP_NO_SSLv2 |
771       SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION |
772       SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION | SSL_OP_SINGLE_ECDH_USE |
773       SSL_OP_SINGLE_DH_USE |
774       SSL_OP_CIPHER_SERVER_PREFERENCE
775 #if OPENSSL_1_1_1_API
776       // The reason for disabling built-in anti-replay in OpenSSL is
777       // that it only works if client gets back to the same server.
778       // The freshness check described in
779       // https://tools.ietf.org/html/rfc8446#section-8.3 is still
780       // performed.
781       | SSL_OP_NO_ANTI_REPLAY
782 #endif // OPENSSL_1_1_1_API
783       ;
784
785   auto config = mod_config();
786   auto &tlsconf = config->tls;
787
788   SSL_CTX_set_options(ssl_ctx, ssl_opts | tlsconf.tls_proto_mask);
789
790   if (nghttp2::tls::ssl_ctx_set_proto_versions(
791           ssl_ctx, tlsconf.min_proto_version, tlsconf.max_proto_version) != 0) {
792     LOG(FATAL) << "Could not set TLS protocol version";
793     DIE();
794   }
795
796   const unsigned char sid_ctx[] = "shrpx";
797   SSL_CTX_set_session_id_context(ssl_ctx, sid_ctx, sizeof(sid_ctx) - 1);
798   SSL_CTX_set_session_cache_mode(ssl_ctx, SSL_SESS_CACHE_SERVER);
799
800   if (!tlsconf.session_cache.memcached.host.empty()) {
801     SSL_CTX_sess_set_new_cb(ssl_ctx, tls_session_new_cb);
802     SSL_CTX_sess_set_get_cb(ssl_ctx, tls_session_get_cb);
803   }
804
805   SSL_CTX_set_timeout(ssl_ctx, tlsconf.session_timeout.count());
806
807   if (SSL_CTX_set_cipher_list(ssl_ctx, tlsconf.ciphers.c_str()) == 0) {
808     LOG(FATAL) << "SSL_CTX_set_cipher_list " << tlsconf.ciphers
809                << " failed: " << ERR_error_string(ERR_get_error(), nullptr);
810     DIE();
811   }
812
813 #if OPENSSL_1_1_1_API
814   if (SSL_CTX_set_ciphersuites(ssl_ctx, tlsconf.tls13_ciphers.c_str()) == 0) {
815     LOG(FATAL) << "SSL_CTX_set_ciphersuites " << tlsconf.tls13_ciphers
816                << " failed: " << ERR_error_string(ERR_get_error(), nullptr);
817     DIE();
818   }
819 #endif // OPENSSL_1_1_1_API
820
821 #ifndef OPENSSL_NO_EC
822 #  if !LIBRESSL_LEGACY_API && OPENSSL_VERSION_NUMBER >= 0x10002000L
823   if (SSL_CTX_set1_curves_list(ssl_ctx, tlsconf.ecdh_curves.c_str()) != 1) {
824     LOG(FATAL) << "SSL_CTX_set1_curves_list " << tlsconf.ecdh_curves
825                << " failed";
826     DIE();
827   }
828 #    if !defined(OPENSSL_IS_BORINGSSL) && !OPENSSL_1_1_API
829   // It looks like we need this function call for OpenSSL 1.0.2.  This
830   // function was deprecated in OpenSSL 1.1.0 and BoringSSL.
831   SSL_CTX_set_ecdh_auto(ssl_ctx, 1);
832 #    endif // !defined(OPENSSL_IS_BORINGSSL) && !OPENSSL_1_1_API
833 #  else    // LIBRESSL_LEGACY_API || OPENSSL_VERSION_NUBMER < 0x10002000L
834   // Use P-256, which is sufficiently secure at the time of this
835   // writing.
836   auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1);
837   if (ecdh == nullptr) {
838     LOG(FATAL) << "EC_KEY_new_by_curv_name failed: "
839                << ERR_error_string(ERR_get_error(), nullptr);
840     DIE();
841   }
842   SSL_CTX_set_tmp_ecdh(ssl_ctx, ecdh);
843   EC_KEY_free(ecdh);
844 #  endif   // LIBRESSL_LEGACY_API || OPENSSL_VERSION_NUBMER < 0x10002000L
845 #endif     // OPENSSL_NO_EC
846
847   if (!tlsconf.dh_param_file.empty()) {
848     // Read DH parameters from file
849     auto bio = BIO_new_file(tlsconf.dh_param_file.c_str(), "r");
850     if (bio == nullptr) {
851       LOG(FATAL) << "BIO_new_file() failed: "
852                  << ERR_error_string(ERR_get_error(), nullptr);
853       DIE();
854     }
855     auto dh = PEM_read_bio_DHparams(bio, nullptr, nullptr, nullptr);
856     if (dh == nullptr) {
857       LOG(FATAL) << "PEM_read_bio_DHparams() failed: "
858                  << ERR_error_string(ERR_get_error(), nullptr);
859       DIE();
860     }
861     SSL_CTX_set_tmp_dh(ssl_ctx, dh);
862     DH_free(dh);
863     BIO_free(bio);
864   }
865
866   SSL_CTX_set_mode(ssl_ctx, SSL_MODE_RELEASE_BUFFERS);
867
868   if (SSL_CTX_set_default_verify_paths(ssl_ctx) != 1) {
869     LOG(WARN) << "Could not load system trusted ca certificates: "
870               << ERR_error_string(ERR_get_error(), nullptr);
871   }
872
873   if (!tlsconf.cacert.empty()) {
874     if (SSL_CTX_load_verify_locations(ssl_ctx, tlsconf.cacert.c_str(),
875                                       nullptr) != 1) {
876       LOG(FATAL) << "Could not load trusted ca certificates from "
877                  << tlsconf.cacert << ": "
878                  << ERR_error_string(ERR_get_error(), nullptr);
879       DIE();
880     }
881   }
882
883   if (!tlsconf.private_key_passwd.empty()) {
884     SSL_CTX_set_default_passwd_cb(ssl_ctx, ssl_pem_passwd_cb);
885     SSL_CTX_set_default_passwd_cb_userdata(ssl_ctx, config);
886   }
887
888 #ifndef HAVE_NEVERBLEED
889   if (SSL_CTX_use_PrivateKey_file(ssl_ctx, private_key_file,
890                                   SSL_FILETYPE_PEM) != 1) {
891     LOG(FATAL) << "SSL_CTX_use_PrivateKey_file failed: "
892                << ERR_error_string(ERR_get_error(), nullptr);
893   }
894 #else  // HAVE_NEVERBLEED
895   std::array<char, NEVERBLEED_ERRBUF_SIZE> errbuf;
896   if (neverbleed_load_private_key_file(nb, ssl_ctx, private_key_file,
897                                        errbuf.data()) != 1) {
898     LOG(FATAL) << "neverbleed_load_private_key_file failed: " << errbuf.data();
899     DIE();
900   }
901 #endif // HAVE_NEVERBLEED
902
903   if (SSL_CTX_use_certificate_chain_file(ssl_ctx, cert_file) != 1) {
904     LOG(FATAL) << "SSL_CTX_use_certificate_file failed: "
905                << ERR_error_string(ERR_get_error(), nullptr);
906     DIE();
907   }
908   if (SSL_CTX_check_private_key(ssl_ctx) != 1) {
909     LOG(FATAL) << "SSL_CTX_check_private_key failed: "
910                << ERR_error_string(ERR_get_error(), nullptr);
911     DIE();
912   }
913   if (tlsconf.client_verify.enabled) {
914     if (!tlsconf.client_verify.cacert.empty()) {
915       if (SSL_CTX_load_verify_locations(
916               ssl_ctx, tlsconf.client_verify.cacert.c_str(), nullptr) != 1) {
917
918         LOG(FATAL) << "Could not load trusted ca certificates from "
919                    << tlsconf.client_verify.cacert << ": "
920                    << ERR_error_string(ERR_get_error(), nullptr);
921         DIE();
922       }
923       // It is heard that SSL_CTX_load_verify_locations() may leave
924       // error even though it returns success. See
925       // http://forum.nginx.org/read.php?29,242540
926       ERR_clear_error();
927       auto list = SSL_load_client_CA_file(tlsconf.client_verify.cacert.c_str());
928       if (!list) {
929         LOG(FATAL) << "Could not load ca certificates from "
930                    << tlsconf.client_verify.cacert << ": "
931                    << ERR_error_string(ERR_get_error(), nullptr);
932         DIE();
933       }
934       SSL_CTX_set_client_CA_list(ssl_ctx, list);
935     }
936     SSL_CTX_set_verify(ssl_ctx,
937                        SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE |
938                            SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
939                        verify_callback);
940   }
941   SSL_CTX_set_tlsext_servername_callback(ssl_ctx, servername_callback);
942   SSL_CTX_set_tlsext_ticket_key_cb(ssl_ctx, ticket_key_cb);
943 #ifndef OPENSSL_IS_BORINGSSL
944   SSL_CTX_set_tlsext_status_cb(ssl_ctx, ocsp_resp_cb);
945 #endif // OPENSSL_IS_BORINGSSL
946   SSL_CTX_set_info_callback(ssl_ctx, info_callback);
947
948 #ifdef OPENSSL_IS_BORINGSSL
949   SSL_CTX_set_early_data_enabled(ssl_ctx, 1);
950 #endif // OPENSSL_IS_BORINGSSL
951
952   // NPN advertisement
953 #ifndef OPENSSL_NO_NEXTPROTONEG
954   SSL_CTX_set_next_protos_advertised_cb(ssl_ctx, next_proto_cb, nullptr);
955 #endif // !OPENSSL_NO_NEXTPROTONEG
956 #if OPENSSL_VERSION_NUMBER >= 0x10002000L
957   // ALPN selection callback
958   SSL_CTX_set_alpn_select_cb(ssl_ctx, alpn_select_proto_cb, nullptr);
959 #endif // OPENSSL_VERSION_NUMBER >= 0x10002000L
960
961 #if !LIBRESSL_IN_USE && OPENSSL_VERSION_NUMBER >= 0x10002000L
962   // SSL_extension_supported(TLSEXT_TYPE_signed_certificate_timestamp)
963   // returns 1, which means OpenSSL internally handles it.  But
964   // OpenSSL handles signed_certificate_timestamp extension specially,
965   // and it lets custom handler to process the extension.
966   if (!sct_data.empty()) {
967 #  if OPENSSL_1_1_1_API
968     // It is not entirely clear to me that SSL_EXT_CLIENT_HELLO is
969     // required here.  sct_parse_cb is called without
970     // SSL_EXT_CLIENT_HELLO being set.  But the passed context value
971     // is SSL_EXT_CLIENT_HELLO.
972     if (SSL_CTX_add_custom_ext(
973             ssl_ctx, TLSEXT_TYPE_signed_certificate_timestamp,
974             SSL_EXT_CLIENT_HELLO | SSL_EXT_TLS1_2_SERVER_HELLO |
975                 SSL_EXT_TLS1_3_CERTIFICATE | SSL_EXT_IGNORE_ON_RESUMPTION,
976             sct_add_cb, sct_free_cb, nullptr, sct_parse_cb, nullptr) != 1) {
977       LOG(FATAL) << "SSL_CTX_add_custom_ext failed: "
978                  << ERR_error_string(ERR_get_error(), nullptr);
979       DIE();
980     }
981 #  else  // !OPENSSL_1_1_1_API
982     if (SSL_CTX_add_server_custom_ext(
983             ssl_ctx, TLSEXT_TYPE_signed_certificate_timestamp,
984             legacy_sct_add_cb, legacy_sct_free_cb, nullptr, legacy_sct_parse_cb,
985             nullptr) != 1) {
986       LOG(FATAL) << "SSL_CTX_add_server_custom_ext failed: "
987                  << ERR_error_string(ERR_get_error(), nullptr);
988       DIE();
989     }
990 #  endif // !OPENSSL_1_1_1_API
991   }
992 #endif // !LIBRESSL_IN_USE && OPENSSL_VERSION_NUMBER >= 0x10002000L
993
994 #if OPENSSL_1_1_1_API
995   if (SSL_CTX_set_max_early_data(ssl_ctx, tlsconf.max_early_data) != 1) {
996     LOG(FATAL) << "SSL_CTX_set_max_early_data failed: "
997                << ERR_error_string(ERR_get_error(), nullptr);
998     DIE();
999   }
1000 #endif // OPENSSL_1_1_1_API
1001
1002 #ifndef OPENSSL_NO_PSK
1003   SSL_CTX_set_psk_server_callback(ssl_ctx, psk_server_cb);
1004 #endif // !LIBRESSL_NO_PSK
1005
1006   auto tls_ctx_data = new TLSContextData();
1007   tls_ctx_data->cert_file = cert_file;
1008   tls_ctx_data->sct_data = sct_data;
1009
1010   SSL_CTX_set_app_data(ssl_ctx, tls_ctx_data);
1011
1012   return ssl_ctx;
1013 }
1014
1015 namespace {
1016 int select_h2_next_proto_cb(SSL *ssl, unsigned char **out,
1017                             unsigned char *outlen, const unsigned char *in,
1018                             unsigned int inlen, void *arg) {
1019   if (!util::select_h2(const_cast<const unsigned char **>(out), outlen, in,
1020                        inlen)) {
1021     return SSL_TLSEXT_ERR_NOACK;
1022   }
1023
1024   return SSL_TLSEXT_ERR_OK;
1025 }
1026 } // namespace
1027
1028 namespace {
1029 int select_h1_next_proto_cb(SSL *ssl, unsigned char **out,
1030                             unsigned char *outlen, const unsigned char *in,
1031                             unsigned int inlen, void *arg) {
1032   auto end = in + inlen;
1033   for (; in < end;) {
1034     if (util::streq(NGHTTP2_H1_1_ALPN, StringRef{in, in + (in[0] + 1)})) {
1035       *out = const_cast<unsigned char *>(in) + 1;
1036       *outlen = in[0];
1037       return SSL_TLSEXT_ERR_OK;
1038     }
1039     in += in[0] + 1;
1040   }
1041
1042   return SSL_TLSEXT_ERR_NOACK;
1043 }
1044 } // namespace
1045
1046 namespace {
1047 int select_next_proto_cb(SSL *ssl, unsigned char **out, unsigned char *outlen,
1048                          const unsigned char *in, unsigned int inlen,
1049                          void *arg) {
1050   auto conn = static_cast<Connection *>(SSL_get_app_data(ssl));
1051   switch (conn->proto) {
1052   case PROTO_HTTP1:
1053     return select_h1_next_proto_cb(ssl, out, outlen, in, inlen, arg);
1054   case PROTO_HTTP2:
1055     return select_h2_next_proto_cb(ssl, out, outlen, in, inlen, arg);
1056   default:
1057     return SSL_TLSEXT_ERR_NOACK;
1058   }
1059 }
1060 } // namespace
1061
1062 SSL_CTX *create_ssl_client_context(
1063 #ifdef HAVE_NEVERBLEED
1064     neverbleed_t *nb,
1065 #endif // HAVE_NEVERBLEED
1066     const StringRef &cacert, const StringRef &cert_file,
1067     const StringRef &private_key_file,
1068     int (*next_proto_select_cb)(SSL *s, unsigned char **out,
1069                                 unsigned char *outlen, const unsigned char *in,
1070                                 unsigned int inlen, void *arg)) {
1071   auto ssl_ctx = SSL_CTX_new(SSLv23_client_method());
1072   if (!ssl_ctx) {
1073     LOG(FATAL) << ERR_error_string(ERR_get_error(), nullptr);
1074     DIE();
1075   }
1076
1077   constexpr auto ssl_opts = (SSL_OP_ALL & ~SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS) |
1078                             SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 |
1079                             SSL_OP_NO_COMPRESSION |
1080                             SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION;
1081
1082   auto &tlsconf = get_config()->tls;
1083
1084   SSL_CTX_set_options(ssl_ctx, ssl_opts | tlsconf.tls_proto_mask);
1085
1086   SSL_CTX_set_session_cache_mode(ssl_ctx, SSL_SESS_CACHE_CLIENT |
1087                                               SSL_SESS_CACHE_NO_INTERNAL_STORE);
1088   SSL_CTX_sess_set_new_cb(ssl_ctx, tls_session_client_new_cb);
1089
1090   if (nghttp2::tls::ssl_ctx_set_proto_versions(
1091           ssl_ctx, tlsconf.min_proto_version, tlsconf.max_proto_version) != 0) {
1092     LOG(FATAL) << "Could not set TLS protocol version";
1093     DIE();
1094   }
1095
1096   if (SSL_CTX_set_cipher_list(ssl_ctx, tlsconf.client.ciphers.c_str()) == 0) {
1097     LOG(FATAL) << "SSL_CTX_set_cipher_list " << tlsconf.client.ciphers
1098                << " failed: " << ERR_error_string(ERR_get_error(), nullptr);
1099     DIE();
1100   }
1101
1102 #if OPENSSL_1_1_1_API
1103   if (SSL_CTX_set_ciphersuites(ssl_ctx, tlsconf.client.tls13_ciphers.c_str()) ==
1104       0) {
1105     LOG(FATAL) << "SSL_CTX_set_ciphersuites " << tlsconf.client.tls13_ciphers
1106                << " failed: " << ERR_error_string(ERR_get_error(), nullptr);
1107     DIE();
1108   }
1109 #endif // OPENSSL_1_1_1_API
1110
1111   SSL_CTX_set_mode(ssl_ctx, SSL_MODE_RELEASE_BUFFERS);
1112
1113   if (SSL_CTX_set_default_verify_paths(ssl_ctx) != 1) {
1114     LOG(WARN) << "Could not load system trusted ca certificates: "
1115               << ERR_error_string(ERR_get_error(), nullptr);
1116   }
1117
1118   if (!cacert.empty()) {
1119     if (SSL_CTX_load_verify_locations(ssl_ctx, cacert.c_str(), nullptr) != 1) {
1120
1121       LOG(FATAL) << "Could not load trusted ca certificates from " << cacert
1122                  << ": " << ERR_error_string(ERR_get_error(), nullptr);
1123       DIE();
1124     }
1125   }
1126
1127   if (!tlsconf.insecure) {
1128     SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, nullptr);
1129   }
1130
1131   if (!cert_file.empty()) {
1132     if (SSL_CTX_use_certificate_chain_file(ssl_ctx, cert_file.c_str()) != 1) {
1133
1134       LOG(FATAL) << "Could not load client certificate from " << cert_file
1135                  << ": " << ERR_error_string(ERR_get_error(), nullptr);
1136       DIE();
1137     }
1138   }
1139
1140   if (!private_key_file.empty()) {
1141 #ifndef HAVE_NEVERBLEED
1142     if (SSL_CTX_use_PrivateKey_file(ssl_ctx, private_key_file.c_str(),
1143                                     SSL_FILETYPE_PEM) != 1) {
1144       LOG(FATAL) << "Could not load client private key from "
1145                  << private_key_file << ": "
1146                  << ERR_error_string(ERR_get_error(), nullptr);
1147       DIE();
1148     }
1149 #else  // HAVE_NEVERBLEED
1150     std::array<char, NEVERBLEED_ERRBUF_SIZE> errbuf;
1151     if (neverbleed_load_private_key_file(nb, ssl_ctx, private_key_file.c_str(),
1152                                          errbuf.data()) != 1) {
1153       LOG(FATAL) << "neverbleed_load_private_key_file: could not load client "
1154                     "private key from "
1155                  << private_key_file << ": " << errbuf.data();
1156       DIE();
1157     }
1158 #endif // HAVE_NEVERBLEED
1159   }
1160
1161 #ifndef OPENSSL_NO_PSK
1162   SSL_CTX_set_psk_client_callback(ssl_ctx, psk_client_cb);
1163 #endif // !OPENSSL_NO_PSK
1164
1165   // NPN selection callback.  This is required to set SSL_CTX because
1166   // OpenSSL does not offer SSL_set_next_proto_select_cb.
1167 #ifndef OPENSSL_NO_NEXTPROTONEG
1168   SSL_CTX_set_next_proto_select_cb(ssl_ctx, next_proto_select_cb, nullptr);
1169 #endif // !OPENSSL_NO_NEXTPROTONEG
1170
1171   return ssl_ctx;
1172 }
1173
1174 SSL *create_ssl(SSL_CTX *ssl_ctx) {
1175   auto ssl = SSL_new(ssl_ctx);
1176   if (!ssl) {
1177     LOG(ERROR) << "SSL_new() failed: "
1178                << ERR_error_string(ERR_get_error(), nullptr);
1179     return nullptr;
1180   }
1181
1182   return ssl;
1183 }
1184
1185 ClientHandler *accept_connection(Worker *worker, int fd, sockaddr *addr,
1186                                  int addrlen, const UpstreamAddr *faddr) {
1187   std::array<char, NI_MAXHOST> host;
1188   std::array<char, NI_MAXSERV> service;
1189   int rv;
1190
1191   if (addr->sa_family == AF_UNIX) {
1192     std::copy_n("localhost", sizeof("localhost"), std::begin(host));
1193     service[0] = '\0';
1194   } else {
1195     rv = getnameinfo(addr, addrlen, host.data(), host.size(), service.data(),
1196                      service.size(), NI_NUMERICHOST | NI_NUMERICSERV);
1197     if (rv != 0) {
1198       LOG(ERROR) << "getnameinfo() failed: " << gai_strerror(rv);
1199
1200       return nullptr;
1201     }
1202
1203     rv = util::make_socket_nodelay(fd);
1204     if (rv == -1) {
1205       LOG(WARN) << "Setting option TCP_NODELAY failed: errno=" << errno;
1206     }
1207   }
1208   SSL *ssl = nullptr;
1209   if (faddr->tls) {
1210     auto ssl_ctx = worker->get_sv_ssl_ctx();
1211
1212     assert(ssl_ctx);
1213
1214     ssl = create_ssl(ssl_ctx);
1215     if (!ssl) {
1216       return nullptr;
1217     }
1218     // Disable TLS session ticket if we don't have working ticket
1219     // keys.
1220     if (!worker->get_ticket_keys()) {
1221       SSL_set_options(ssl, SSL_OP_NO_TICKET);
1222     }
1223   }
1224
1225   return new ClientHandler(worker, fd, ssl, StringRef{host.data()},
1226                            StringRef{service.data()}, addr->sa_family, faddr);
1227 }
1228
1229 bool tls_hostname_match(const StringRef &pattern, const StringRef &hostname) {
1230   auto ptWildcard = std::find(std::begin(pattern), std::end(pattern), '*');
1231   if (ptWildcard == std::end(pattern)) {
1232     return util::strieq(pattern, hostname);
1233   }
1234
1235   auto ptLeftLabelEnd = std::find(std::begin(pattern), std::end(pattern), '.');
1236   auto wildcardEnabled = true;
1237   // Do case-insensitive match. At least 2 dots are required to enable
1238   // wildcard match. Also wildcard must be in the left-most label.
1239   // Don't attempt to match a presented identifier where the wildcard
1240   // character is embedded within an A-label.
1241   if (ptLeftLabelEnd == std::end(pattern) ||
1242       std::find(ptLeftLabelEnd + 1, std::end(pattern), '.') ==
1243           std::end(pattern) ||
1244       ptLeftLabelEnd < ptWildcard || util::istarts_with_l(pattern, "xn--")) {
1245     wildcardEnabled = false;
1246   }
1247
1248   if (!wildcardEnabled) {
1249     return util::strieq(pattern, hostname);
1250   }
1251
1252   auto hnLeftLabelEnd =
1253       std::find(std::begin(hostname), std::end(hostname), '.');
1254   if (hnLeftLabelEnd == std::end(hostname) ||
1255       !util::strieq(StringRef{ptLeftLabelEnd, std::end(pattern)},
1256                     StringRef{hnLeftLabelEnd, std::end(hostname)})) {
1257     return false;
1258   }
1259   // Perform wildcard match. Here '*' must match at least one
1260   // character.
1261   if (hnLeftLabelEnd - std::begin(hostname) <
1262       ptLeftLabelEnd - std::begin(pattern)) {
1263     return false;
1264   }
1265   return util::istarts_with(StringRef{std::begin(hostname), hnLeftLabelEnd},
1266                             StringRef{std::begin(pattern), ptWildcard}) &&
1267          util::iends_with(StringRef{std::begin(hostname), hnLeftLabelEnd},
1268                           StringRef{ptWildcard + 1, ptLeftLabelEnd});
1269 }
1270
1271 namespace {
1272 // if return value is not empty, StringRef.c_str() must be freed using
1273 // OPENSSL_free().
1274 StringRef get_common_name(X509 *cert) {
1275   auto subjectname = X509_get_subject_name(cert);
1276   if (!subjectname) {
1277     LOG(WARN) << "Could not get X509 name object from the certificate.";
1278     return StringRef{};
1279   }
1280   int lastpos = -1;
1281   for (;;) {
1282     lastpos = X509_NAME_get_index_by_NID(subjectname, NID_commonName, lastpos);
1283     if (lastpos == -1) {
1284       break;
1285     }
1286     auto entry = X509_NAME_get_entry(subjectname, lastpos);
1287
1288     unsigned char *p;
1289     auto plen = ASN1_STRING_to_UTF8(&p, X509_NAME_ENTRY_get_data(entry));
1290     if (plen < 0) {
1291       continue;
1292     }
1293     if (std::find(p, p + plen, '\0') != p + plen) {
1294       // Embedded NULL is not permitted.
1295       continue;
1296     }
1297     if (plen == 0) {
1298       LOG(WARN) << "X509 name is empty";
1299       OPENSSL_free(p);
1300       continue;
1301     }
1302
1303     return StringRef{p, static_cast<size_t>(plen)};
1304   }
1305   return StringRef{};
1306 }
1307 } // namespace
1308
1309 namespace {
1310 int verify_numeric_hostname(X509 *cert, const StringRef &hostname,
1311                             const Address *addr) {
1312   const void *saddr;
1313   switch (addr->su.storage.ss_family) {
1314   case AF_INET:
1315     saddr = &addr->su.in.sin_addr;
1316     break;
1317   case AF_INET6:
1318     saddr = &addr->su.in6.sin6_addr;
1319     break;
1320   default:
1321     return -1;
1322   }
1323
1324   auto altnames = static_cast<GENERAL_NAMES *>(
1325       X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr));
1326   if (altnames) {
1327     auto altnames_deleter = defer(GENERAL_NAMES_free, altnames);
1328     size_t n = sk_GENERAL_NAME_num(altnames);
1329     auto ip_found = false;
1330     for (size_t i = 0; i < n; ++i) {
1331       auto altname = sk_GENERAL_NAME_value(altnames, i);
1332       if (altname->type != GEN_IPADD) {
1333         continue;
1334       }
1335
1336       auto ip_addr = altname->d.iPAddress->data;
1337       if (!ip_addr) {
1338         continue;
1339       }
1340       size_t ip_addrlen = altname->d.iPAddress->length;
1341
1342       ip_found = true;
1343       if (addr->len == ip_addrlen && memcmp(saddr, ip_addr, ip_addrlen) == 0) {
1344         return 0;
1345       }
1346     }
1347
1348     if (ip_found) {
1349       return -1;
1350     }
1351   }
1352
1353   auto cn = get_common_name(cert);
1354   if (cn.empty()) {
1355     return -1;
1356   }
1357
1358   // cn is not NULL terminated
1359   auto rv = util::streq(hostname, cn);
1360   OPENSSL_free(const_cast<char *>(cn.c_str()));
1361
1362   if (rv) {
1363     return 0;
1364   }
1365
1366   return -1;
1367 }
1368 } // namespace
1369
1370 namespace {
1371 int verify_hostname(X509 *cert, const StringRef &hostname,
1372                     const Address *addr) {
1373   if (util::numeric_host(hostname.c_str())) {
1374     return verify_numeric_hostname(cert, hostname, addr);
1375   }
1376
1377   auto altnames = static_cast<GENERAL_NAMES *>(
1378       X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr));
1379   if (altnames) {
1380     auto dns_found = false;
1381     auto altnames_deleter = defer(GENERAL_NAMES_free, altnames);
1382     size_t n = sk_GENERAL_NAME_num(altnames);
1383     for (size_t i = 0; i < n; ++i) {
1384       auto altname = sk_GENERAL_NAME_value(altnames, i);
1385       if (altname->type != GEN_DNS) {
1386         continue;
1387       }
1388
1389       auto name = ASN1_STRING_get0_data(altname->d.ia5);
1390       if (!name) {
1391         continue;
1392       }
1393
1394       auto len = ASN1_STRING_length(altname->d.ia5);
1395       if (len == 0) {
1396         continue;
1397       }
1398       if (std::find(name, name + len, '\0') != name + len) {
1399         // Embedded NULL is not permitted.
1400         continue;
1401       }
1402
1403       if (name[len - 1] == '.') {
1404         --len;
1405         if (len == 0) {
1406           continue;
1407         }
1408       }
1409
1410       dns_found = true;
1411
1412       if (tls_hostname_match(StringRef{name, static_cast<size_t>(len)},
1413                              hostname)) {
1414         return 0;
1415       }
1416     }
1417
1418     // RFC 6125, section 6.4.4. says that client MUST not seek a match
1419     // for CN if a dns dNSName is found.
1420     if (dns_found) {
1421       return -1;
1422     }
1423   }
1424
1425   auto cn = get_common_name(cert);
1426   if (cn.empty()) {
1427     return -1;
1428   }
1429
1430   if (cn[cn.size() - 1] == '.') {
1431     if (cn.size() == 1) {
1432       OPENSSL_free(const_cast<char *>(cn.c_str()));
1433
1434       return -1;
1435     }
1436     cn = StringRef{cn.c_str(), cn.size() - 1};
1437   }
1438
1439   auto rv = tls_hostname_match(cn, hostname);
1440   OPENSSL_free(const_cast<char *>(cn.c_str()));
1441
1442   return rv ? 0 : -1;
1443 }
1444 } // namespace
1445
1446 int check_cert(SSL *ssl, const Address *addr, const StringRef &host) {
1447   auto cert = SSL_get_peer_certificate(ssl);
1448   if (!cert) {
1449     // By the protocol definition, TLS server always sends certificate
1450     // if it has.  If certificate cannot be retrieved, authentication
1451     // without certificate is used, such as PSK.
1452     return 0;
1453   }
1454   auto cert_deleter = defer(X509_free, cert);
1455
1456   if (verify_hostname(cert, host, addr) != 0) {
1457     LOG(ERROR) << "Certificate verification failed: hostname does not match";
1458     return -1;
1459   }
1460   return 0;
1461 }
1462
1463 int check_cert(SSL *ssl, const DownstreamAddr *addr, const Address *raddr) {
1464   auto hostname =
1465       addr->sni.empty() ? StringRef{addr->host} : StringRef{addr->sni};
1466   return check_cert(ssl, raddr, hostname);
1467 }
1468
1469 CertLookupTree::CertLookupTree() {}
1470
1471 ssize_t CertLookupTree::add_cert(const StringRef &hostname, size_t idx) {
1472   std::array<uint8_t, NI_MAXHOST> buf;
1473
1474   // NI_MAXHOST includes terminal NULL byte
1475   if (hostname.empty() || hostname.size() + 1 > buf.size()) {
1476     return -1;
1477   }
1478
1479   auto wildcard_it = std::find(std::begin(hostname), std::end(hostname), '*');
1480   if (wildcard_it != std::end(hostname) &&
1481       wildcard_it + 1 != std::end(hostname)) {
1482     auto wildcard_prefix = StringRef{std::begin(hostname), wildcard_it};
1483     auto wildcard_suffix = StringRef{wildcard_it + 1, std::end(hostname)};
1484
1485     auto rev_suffix = StringRef{std::begin(buf),
1486                                 std::reverse_copy(std::begin(wildcard_suffix),
1487                                                   std::end(wildcard_suffix),
1488                                                   std::begin(buf))};
1489
1490     WildcardPattern *wpat;
1491
1492     if (wildcard_patterns_.size() !=
1493         rev_wildcard_router_.add_route(rev_suffix, wildcard_patterns_.size())) {
1494       auto wcidx = rev_wildcard_router_.match(rev_suffix);
1495
1496       assert(wcidx != -1);
1497
1498       wpat = &wildcard_patterns_[wcidx];
1499     } else {
1500       wildcard_patterns_.emplace_back();
1501       wpat = &wildcard_patterns_.back();
1502     }
1503
1504     auto rev_prefix = StringRef{std::begin(buf),
1505                                 std::reverse_copy(std::begin(wildcard_prefix),
1506                                                   std::end(wildcard_prefix),
1507                                                   std::begin(buf))};
1508
1509     for (auto &p : wpat->rev_prefix) {
1510       if (p.prefix == rev_prefix) {
1511         return p.idx;
1512       }
1513     }
1514
1515     wpat->rev_prefix.emplace_back(rev_prefix, idx);
1516
1517     return idx;
1518   }
1519
1520   return router_.add_route(hostname, idx);
1521 }
1522
1523 ssize_t CertLookupTree::lookup(const StringRef &hostname) {
1524   std::array<uint8_t, NI_MAXHOST> buf;
1525
1526   // NI_MAXHOST includes terminal NULL byte
1527   if (hostname.empty() || hostname.size() + 1 > buf.size()) {
1528     return -1;
1529   }
1530
1531   // Always prefer exact match
1532   auto idx = router_.match(hostname);
1533   if (idx != -1) {
1534     return idx;
1535   }
1536
1537   if (wildcard_patterns_.empty()) {
1538     return -1;
1539   }
1540
1541   ssize_t best_idx = -1;
1542   size_t best_prefixlen = 0;
1543   const RNode *last_node = nullptr;
1544
1545   auto rev_host = StringRef{
1546       std::begin(buf), std::reverse_copy(std::begin(hostname),
1547                                          std::end(hostname), std::begin(buf))};
1548
1549   for (;;) {
1550     size_t nread = 0;
1551
1552     auto wcidx =
1553         rev_wildcard_router_.match_prefix(&nread, &last_node, rev_host);
1554     if (wcidx == -1) {
1555       return best_idx;
1556     }
1557
1558     // '*' must match at least one byte
1559     if (nread == rev_host.size()) {
1560       return best_idx;
1561     }
1562
1563     rev_host = StringRef{std::begin(rev_host) + nread, std::end(rev_host)};
1564
1565     auto rev_prefix = StringRef{std::begin(rev_host) + 1, std::end(rev_host)};
1566
1567     auto &wpat = wildcard_patterns_[wcidx];
1568     for (auto &wprefix : wpat.rev_prefix) {
1569       if (!util::ends_with(rev_prefix, wprefix.prefix)) {
1570         continue;
1571       }
1572
1573       auto prefixlen =
1574           wprefix.prefix.size() +
1575           (reinterpret_cast<const uint8_t *>(&rev_host[0]) - &buf[0]);
1576
1577       // Breaking a tie with longer suffix
1578       if (prefixlen < best_prefixlen) {
1579         continue;
1580       }
1581
1582       best_idx = wprefix.idx;
1583       best_prefixlen = prefixlen;
1584     }
1585   }
1586 }
1587
1588 void CertLookupTree::dump() const {
1589   std::cerr << "exact:" << std::endl;
1590   router_.dump();
1591   std::cerr << "wildcard suffix (reversed):" << std::endl;
1592   rev_wildcard_router_.dump();
1593 }
1594
1595 int cert_lookup_tree_add_ssl_ctx(
1596     CertLookupTree *lt, std::vector<std::vector<SSL_CTX *>> &indexed_ssl_ctx,
1597     SSL_CTX *ssl_ctx) {
1598   std::array<uint8_t, NI_MAXHOST> buf;
1599
1600 #if LIBRESSL_2_7_API ||                                                        \
1601     (!LIBRESSL_IN_USE && OPENSSL_VERSION_NUMBER >= 0x10002000L)
1602   auto cert = SSL_CTX_get0_certificate(ssl_ctx);
1603 #else  // !LIBRESSL_2_7_API && OPENSSL_VERSION_NUMBER < 0x10002000L
1604   auto tls_ctx_data =
1605       static_cast<TLSContextData *>(SSL_CTX_get_app_data(ssl_ctx));
1606   auto cert = load_certificate(tls_ctx_data->cert_file);
1607   auto cert_deleter = defer(X509_free, cert);
1608 #endif // !LIBRESSL_2_7_API && OPENSSL_VERSION_NUMBER < 0x10002000L
1609
1610   auto altnames = static_cast<GENERAL_NAMES *>(
1611       X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr));
1612   if (altnames) {
1613     auto altnames_deleter = defer(GENERAL_NAMES_free, altnames);
1614     size_t n = sk_GENERAL_NAME_num(altnames);
1615     auto dns_found = false;
1616     for (size_t i = 0; i < n; ++i) {
1617       auto altname = sk_GENERAL_NAME_value(altnames, i);
1618       if (altname->type != GEN_DNS) {
1619         continue;
1620       }
1621
1622       auto name = ASN1_STRING_get0_data(altname->d.ia5);
1623       if (!name) {
1624         continue;
1625       }
1626
1627       auto len = ASN1_STRING_length(altname->d.ia5);
1628       if (len == 0) {
1629         continue;
1630       }
1631       if (std::find(name, name + len, '\0') != name + len) {
1632         // Embedded NULL is not permitted.
1633         continue;
1634       }
1635
1636       if (name[len - 1] == '.') {
1637         --len;
1638         if (len == 0) {
1639           continue;
1640         }
1641       }
1642
1643       dns_found = true;
1644
1645       if (static_cast<size_t>(len) + 1 > buf.size()) {
1646         continue;
1647       }
1648
1649       auto end_buf = std::copy_n(name, len, std::begin(buf));
1650       util::inp_strlower(std::begin(buf), end_buf);
1651
1652       auto idx = lt->add_cert(StringRef{std::begin(buf), end_buf},
1653                               indexed_ssl_ctx.size());
1654       if (idx == -1) {
1655         continue;
1656       }
1657
1658       if (static_cast<size_t>(idx) < indexed_ssl_ctx.size()) {
1659         indexed_ssl_ctx[idx].push_back(ssl_ctx);
1660       } else {
1661         assert(static_cast<size_t>(idx) == indexed_ssl_ctx.size());
1662         indexed_ssl_ctx.emplace_back(std::vector<SSL_CTX *>{ssl_ctx});
1663       }
1664     }
1665
1666     // Don't bother CN if we have dNSName.
1667     if (dns_found) {
1668       return 0;
1669     }
1670   }
1671
1672   auto cn = get_common_name(cert);
1673   if (cn.empty()) {
1674     return 0;
1675   }
1676
1677   if (cn[cn.size() - 1] == '.') {
1678     if (cn.size() == 1) {
1679       OPENSSL_free(const_cast<char *>(cn.c_str()));
1680
1681       return 0;
1682     }
1683
1684     cn = StringRef{cn.c_str(), cn.size() - 1};
1685   }
1686
1687   auto end_buf = std::copy(std::begin(cn), std::end(cn), std::begin(buf));
1688
1689   OPENSSL_free(const_cast<char *>(cn.c_str()));
1690
1691   util::inp_strlower(std::begin(buf), end_buf);
1692
1693   auto idx =
1694       lt->add_cert(StringRef{std::begin(buf), end_buf}, indexed_ssl_ctx.size());
1695   if (idx == -1) {
1696     return 0;
1697   }
1698
1699   if (static_cast<size_t>(idx) < indexed_ssl_ctx.size()) {
1700     indexed_ssl_ctx[idx].push_back(ssl_ctx);
1701   } else {
1702     assert(static_cast<size_t>(idx) == indexed_ssl_ctx.size());
1703     indexed_ssl_ctx.emplace_back(std::vector<SSL_CTX *>{ssl_ctx});
1704   }
1705
1706   return 0;
1707 }
1708
1709 bool in_proto_list(const std::vector<StringRef> &protos,
1710                    const StringRef &needle) {
1711   for (auto &proto : protos) {
1712     if (util::streq(proto, needle)) {
1713       return true;
1714     }
1715   }
1716   return false;
1717 }
1718
1719 bool upstream_tls_enabled(const ConnectionConfig &connconf) {
1720   const auto &faddrs = connconf.listener.addrs;
1721   return std::any_of(std::begin(faddrs), std::end(faddrs),
1722                      [](const UpstreamAddr &faddr) { return faddr.tls; });
1723 }
1724
1725 X509 *load_certificate(const char *filename) {
1726   auto bio = BIO_new(BIO_s_file());
1727   if (!bio) {
1728     fprintf(stderr, "BIO_new() failed\n");
1729     return nullptr;
1730   }
1731   auto bio_deleter = defer(BIO_vfree, bio);
1732   if (!BIO_read_filename(bio, filename)) {
1733     fprintf(stderr, "Could not read certificate file '%s'\n", filename);
1734     return nullptr;
1735   }
1736   auto cert = PEM_read_bio_X509(bio, nullptr, nullptr, nullptr);
1737   if (!cert) {
1738     fprintf(stderr, "Could not read X509 structure from file '%s'\n", filename);
1739     return nullptr;
1740   }
1741
1742   return cert;
1743 }
1744
1745 SSL_CTX *
1746 setup_server_ssl_context(std::vector<SSL_CTX *> &all_ssl_ctx,
1747                          std::vector<std::vector<SSL_CTX *>> &indexed_ssl_ctx,
1748                          CertLookupTree *cert_tree
1749 #ifdef HAVE_NEVERBLEED
1750                          ,
1751                          neverbleed_t *nb
1752 #endif // HAVE_NEVERBLEED
1753 ) {
1754   auto config = get_config();
1755
1756   if (!upstream_tls_enabled(config->conn)) {
1757     return nullptr;
1758   }
1759
1760   auto &tlsconf = config->tls;
1761
1762   auto ssl_ctx = create_ssl_context(tlsconf.private_key_file.c_str(),
1763                                     tlsconf.cert_file.c_str(), tlsconf.sct_data
1764 #ifdef HAVE_NEVERBLEED
1765                                     ,
1766                                     nb
1767 #endif // HAVE_NEVERBLEED
1768   );
1769
1770   all_ssl_ctx.push_back(ssl_ctx);
1771
1772   assert(cert_tree);
1773
1774   if (cert_lookup_tree_add_ssl_ctx(cert_tree, indexed_ssl_ctx, ssl_ctx) == -1) {
1775     LOG(FATAL) << "Failed to add default certificate.";
1776     DIE();
1777   }
1778
1779   for (auto &c : tlsconf.subcerts) {
1780     auto ssl_ctx = create_ssl_context(c.private_key_file.c_str(),
1781                                       c.cert_file.c_str(), c.sct_data
1782 #ifdef HAVE_NEVERBLEED
1783                                       ,
1784                                       nb
1785 #endif // HAVE_NEVERBLEED
1786     );
1787     all_ssl_ctx.push_back(ssl_ctx);
1788
1789     if (cert_lookup_tree_add_ssl_ctx(cert_tree, indexed_ssl_ctx, ssl_ctx) ==
1790         -1) {
1791       LOG(FATAL) << "Failed to add sub certificate.";
1792       DIE();
1793     }
1794   }
1795
1796   return ssl_ctx;
1797 }
1798
1799 SSL_CTX *setup_downstream_client_ssl_context(
1800 #ifdef HAVE_NEVERBLEED
1801     neverbleed_t *nb
1802 #endif // HAVE_NEVERBLEED
1803 ) {
1804   auto &tlsconf = get_config()->tls;
1805
1806   return create_ssl_client_context(
1807 #ifdef HAVE_NEVERBLEED
1808       nb,
1809 #endif // HAVE_NEVERBLEED
1810       tlsconf.cacert, tlsconf.client.cert_file, tlsconf.client.private_key_file,
1811       select_next_proto_cb);
1812 }
1813
1814 void setup_downstream_http2_alpn(SSL *ssl) {
1815 #if OPENSSL_VERSION_NUMBER >= 0x10002000L
1816   // ALPN advertisement
1817   auto alpn = util::get_default_alpn();
1818   SSL_set_alpn_protos(ssl, alpn.data(), alpn.size());
1819 #endif // OPENSSL_VERSION_NUMBER >= 0x10002000L
1820 }
1821
1822 void setup_downstream_http1_alpn(SSL *ssl) {
1823 #if OPENSSL_VERSION_NUMBER >= 0x10002000L
1824   // ALPN advertisement
1825   SSL_set_alpn_protos(ssl, NGHTTP2_H1_1_ALPN.byte(), NGHTTP2_H1_1_ALPN.size());
1826 #endif // OPENSSL_VERSION_NUMBER >= 0x10002000L
1827 }
1828
1829 std::unique_ptr<CertLookupTree> create_cert_lookup_tree() {
1830   auto config = get_config();
1831   if (!upstream_tls_enabled(config->conn)) {
1832     return nullptr;
1833   }
1834   return make_unique<CertLookupTree>();
1835 }
1836
1837 namespace {
1838 std::vector<uint8_t> serialize_ssl_session(SSL_SESSION *session) {
1839   auto len = i2d_SSL_SESSION(session, nullptr);
1840   auto buf = std::vector<uint8_t>(len);
1841   auto p = buf.data();
1842   i2d_SSL_SESSION(session, &p);
1843
1844   return buf;
1845 }
1846 } // namespace
1847
1848 void try_cache_tls_session(TLSSessionCache *cache, SSL_SESSION *session,
1849                            ev_tstamp t) {
1850   if (cache->last_updated + 1_min > t) {
1851     if (LOG_ENABLED(INFO)) {
1852       LOG(INFO) << "Client session cache entry is still fresh.";
1853     }
1854     return;
1855   }
1856
1857   if (LOG_ENABLED(INFO)) {
1858     LOG(INFO) << "Update client cache entry "
1859               << "timestamp = " << t;
1860   }
1861
1862   cache->session_data = serialize_ssl_session(session);
1863   cache->last_updated = t;
1864 }
1865
1866 SSL_SESSION *reuse_tls_session(const TLSSessionCache &cache) {
1867   if (cache.session_data.empty()) {
1868     return nullptr;
1869   }
1870
1871   auto p = cache.session_data.data();
1872   return d2i_SSL_SESSION(nullptr, &p, cache.session_data.size());
1873 }
1874
1875 int proto_version_from_string(const StringRef &v) {
1876 #ifdef TLS1_3_VERSION
1877   if (util::strieq_l("TLSv1.3", v)) {
1878     return TLS1_3_VERSION;
1879   }
1880 #endif // TLS1_3_VERSION
1881   if (util::strieq_l("TLSv1.2", v)) {
1882     return TLS1_2_VERSION;
1883   }
1884   if (util::strieq_l("TLSv1.1", v)) {
1885     return TLS1_1_VERSION;
1886   }
1887   if (util::strieq_l("TLSv1.0", v)) {
1888     return TLS1_VERSION;
1889   }
1890   return -1;
1891 }
1892
1893 int verify_ocsp_response(SSL_CTX *ssl_ctx, const uint8_t *ocsp_resp,
1894                          size_t ocsp_resplen) {
1895
1896 #if !defined(OPENSSL_NO_OCSP) && !LIBRESSL_IN_USE &&                           \
1897     OPENSSL_VERSION_NUMBER >= 0x10002000L
1898   int rv;
1899
1900   STACK_OF(X509) * chain_certs;
1901   SSL_CTX_get0_chain_certs(ssl_ctx, &chain_certs);
1902
1903   auto resp = d2i_OCSP_RESPONSE(nullptr, &ocsp_resp, ocsp_resplen);
1904   if (resp == nullptr) {
1905     LOG(ERROR) << "d2i_OCSP_RESPONSE failed";
1906     return -1;
1907   }
1908   auto resp_deleter = defer(OCSP_RESPONSE_free, resp);
1909
1910   if (OCSP_response_status(resp) != OCSP_RESPONSE_STATUS_SUCCESSFUL) {
1911     LOG(ERROR) << "OCSP response status is not successful";
1912     return -1;
1913   }
1914
1915   ERR_clear_error();
1916
1917   auto bs = OCSP_response_get1_basic(resp);
1918   if (bs == nullptr) {
1919     LOG(ERROR) << "OCSP_response_get1_basic failed: "
1920                << ERR_error_string(ERR_get_error(), nullptr);
1921     return -1;
1922   }
1923   auto bs_deleter = defer(OCSP_BASICRESP_free, bs);
1924
1925   auto store = SSL_CTX_get_cert_store(ssl_ctx);
1926
1927   ERR_clear_error();
1928
1929   rv = OCSP_basic_verify(bs, chain_certs, store, 0);
1930
1931   if (rv != 1) {
1932     LOG(ERROR) << "OCSP_basic_verify failed: "
1933                << ERR_error_string(ERR_get_error(), nullptr);
1934     return -1;
1935   }
1936
1937   auto sresp = OCSP_resp_get0(bs, 0);
1938   if (sresp == nullptr) {
1939     LOG(ERROR) << "OCSP response verification failed: no single response";
1940     return -1;
1941   }
1942
1943 #  if OPENSSL_1_1_API
1944   auto certid = OCSP_SINGLERESP_get0_id(sresp);
1945 #  else  // !OPENSSL_1_1_API
1946   auto certid = sresp->certId;
1947 #  endif // !OPENSSL_1_1_API
1948   assert(certid != nullptr);
1949
1950   ASN1_INTEGER *serial;
1951   rv = OCSP_id_get0_info(nullptr, nullptr, nullptr, &serial,
1952                          const_cast<OCSP_CERTID *>(certid));
1953   if (rv != 1) {
1954     LOG(ERROR) << "OCSP_id_get0_info failed";
1955     return -1;
1956   }
1957
1958   if (serial == nullptr) {
1959     LOG(ERROR) << "OCSP response does not contain serial number";
1960     return -1;
1961   }
1962
1963   auto cert = SSL_CTX_get0_certificate(ssl_ctx);
1964   auto cert_serial = X509_get_serialNumber(cert);
1965
1966   if (ASN1_INTEGER_cmp(cert_serial, serial)) {
1967     LOG(ERROR) << "OCSP verification serial numbers do not match";
1968     return -1;
1969   }
1970
1971   if (LOG_ENABLED(INFO)) {
1972     LOG(INFO) << "OCSP verification succeeded";
1973   }
1974 #endif // !defined(OPENSSL_NO_OCSP) && !LIBRESSL_IN_USE
1975        // && OPENSSL_VERSION_NUMBER >= 0x10002000L
1976
1977   return 0;
1978 }
1979
1980 ssize_t get_x509_fingerprint(uint8_t *dst, size_t dstlen, const X509 *x,
1981                              const EVP_MD *md) {
1982   unsigned int len = dstlen;
1983   if (X509_digest(x, md, dst, &len) != 1) {
1984     return -1;
1985   }
1986   return len;
1987 }
1988
1989 namespace {
1990 StringRef get_x509_name(BlockAllocator &balloc, X509_NAME *nm) {
1991   auto b = BIO_new(BIO_s_mem());
1992   if (!b) {
1993     return StringRef{};
1994   }
1995
1996   auto b_deleter = defer(BIO_free, b);
1997
1998   // Not documented, but it seems that X509_NAME_print_ex returns the
1999   // number of bytes written into b.
2000   auto slen = X509_NAME_print_ex(b, nm, 0, XN_FLAG_RFC2253);
2001   if (slen <= 0) {
2002     return StringRef{};
2003   }
2004
2005   auto iov = make_byte_ref(balloc, slen + 1);
2006   BIO_read(b, iov.base, slen);
2007   iov.base[slen] = '\0';
2008   return StringRef{iov.base, static_cast<size_t>(slen)};
2009 }
2010 } // namespace
2011
2012 StringRef get_x509_subject_name(BlockAllocator &balloc, X509 *x) {
2013   return get_x509_name(balloc, X509_get_subject_name(x));
2014 }
2015
2016 StringRef get_x509_issuer_name(BlockAllocator &balloc, X509 *x) {
2017   return get_x509_name(balloc, X509_get_issuer_name(x));
2018 }
2019
2020 #ifdef WORDS_BIGENDIAN
2021 #  define bswap64(N) (N)
2022 #else /* !WORDS_BIGENDIAN */
2023 #  define bswap64(N)                                                           \
2024     ((uint64_t)(ntohl((uint32_t)(N))) << 32 | ntohl((uint32_t)((N) >> 32)))
2025 #endif /* !WORDS_BIGENDIAN */
2026
2027 StringRef get_x509_serial(BlockAllocator &balloc, X509 *x) {
2028 #if OPENSSL_1_1_API
2029   auto sn = X509_get0_serialNumber(x);
2030   uint64_t r;
2031   if (ASN1_INTEGER_get_uint64(&r, sn) != 1) {
2032     return StringRef{};
2033   }
2034
2035   r = bswap64(r);
2036   return util::format_hex(
2037       balloc, StringRef{reinterpret_cast<uint8_t *>(&r), sizeof(r)});
2038 #else  // !OPENSSL_1_1_API
2039   auto sn = X509_get_serialNumber(x);
2040   auto bn = BN_new();
2041   auto bn_d = defer(BN_free, bn);
2042   if (!ASN1_INTEGER_to_BN(sn, bn)) {
2043     return StringRef{};
2044   }
2045
2046   std::array<uint8_t, 8> b;
2047   auto n = BN_bn2bin(bn, b.data());
2048   assert(n == b.size());
2049
2050   return util::format_hex(balloc, StringRef{std::begin(b), std::end(b)});
2051 #endif // !OPENSSL_1_1_API
2052 }
2053
2054 namespace {
2055 // Performs conversion from |at| to time_t.  The result is stored in
2056 // |t|.  This function returns 0 if it succeeds, or -1.
2057 int time_t_from_asn1_time(time_t &t, const ASN1_TIME *at) {
2058   int rv;
2059
2060 #if OPENSSL_1_1_1_API
2061   struct tm tm;
2062   rv = ASN1_TIME_to_tm(at, &tm);
2063   if (rv != 1) {
2064     return -1;
2065   }
2066
2067   t = nghttp2_timegm(&tm);
2068 #else  // !OPENSSL_1_1_1_API
2069   auto b = BIO_new(BIO_s_mem());
2070   if (!b) {
2071     return -1;
2072   }
2073
2074   auto bio_deleter = defer(BIO_free, b);
2075
2076   rv = ASN1_TIME_print(b, at);
2077   if (rv != 1) {
2078     return -1;
2079   }
2080
2081   unsigned char *s;
2082   auto slen = BIO_get_mem_data(b, &s);
2083   auto tt = util::parse_openssl_asn1_time_print(
2084       StringRef{s, static_cast<size_t>(slen)});
2085   if (tt == 0) {
2086     return -1;
2087   }
2088
2089   t = tt;
2090 #endif // !OPENSSL_1_1_1_API
2091
2092   return 0;
2093 }
2094 } // namespace
2095
2096 int get_x509_not_before(time_t &t, X509 *x) {
2097 #if OPENSSL_1_1_API
2098   auto at = X509_get0_notBefore(x);
2099 #else  // !OPENSSL_1_1_API
2100   auto at = X509_get_notBefore(x);
2101 #endif // !OPENSSL_1_1_API
2102   if (!at) {
2103     return -1;
2104   }
2105
2106   return time_t_from_asn1_time(t, at);
2107 }
2108
2109 int get_x509_not_after(time_t &t, X509 *x) {
2110 #if OPENSSL_1_1_API
2111   auto at = X509_get0_notAfter(x);
2112 #else  // !OPENSSL_1_1_API
2113   auto at = X509_get_notAfter(x);
2114 #endif // !OPENSSL_1_1_API
2115   if (!at) {
2116     return -1;
2117   }
2118
2119   return time_t_from_asn1_time(t, at);
2120 }
2121
2122 } // namespace tls
2123
2124 } // namespace shrpx