Upgrade to 1.46.0
[platform/upstream/nghttp2.git] / src / shrpx_quic.cc
1 /*
2  * nghttp2 - HTTP/2 C Library
3  *
4  * Copyright (c) 2021 Tatsuhiro Tsujikawa
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining
7  * a copy of this software and associated documentation files (the
8  * "Software"), to deal in the Software without restriction, including
9  * without limitation the rights to use, copy, modify, merge, publish,
10  * distribute, sublicense, and/or sell copies of the Software, and to
11  * permit persons to whom the Software is furnished to do so, subject to
12  * the following conditions:
13  *
14  * The above copyright notice and this permission notice shall be
15  * included in all copies or substantial portions of the Software.
16  *
17  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
20  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
21  * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
22  * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
23  * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
24  */
25 #include "shrpx_quic.h"
26
27 #include <sys/types.h>
28 #include <sys/socket.h>
29 #include <netdb.h>
30 #include <netinet/udp.h>
31
32 #include <array>
33 #include <chrono>
34
35 #include <ngtcp2/ngtcp2_crypto.h>
36
37 #include <nghttp3/nghttp3.h>
38
39 #include <openssl/rand.h>
40
41 #include "shrpx_config.h"
42 #include "shrpx_log.h"
43 #include "util.h"
44 #include "xsi_strerror.h"
45
46 bool operator==(const ngtcp2_cid &lhs, const ngtcp2_cid &rhs) {
47   return ngtcp2_cid_eq(&lhs, &rhs);
48 }
49
50 namespace shrpx {
51
52 ngtcp2_tstamp quic_timestamp() {
53   return std::chrono::duration_cast<std::chrono::nanoseconds>(
54              std::chrono::steady_clock::now().time_since_epoch())
55       .count();
56 }
57
58 int quic_send_packet(const UpstreamAddr *faddr, const sockaddr *remote_sa,
59                      size_t remote_salen, const sockaddr *local_sa,
60                      size_t local_salen, const uint8_t *data, size_t datalen,
61                      size_t gso_size) {
62   iovec msg_iov = {const_cast<uint8_t *>(data), datalen};
63   msghdr msg{};
64   msg.msg_name = const_cast<sockaddr *>(remote_sa);
65   msg.msg_namelen = remote_salen;
66   msg.msg_iov = &msg_iov;
67   msg.msg_iovlen = 1;
68
69   uint8_t msg_ctrl[
70 #ifdef UDP_SEGMENT
71       CMSG_SPACE(sizeof(uint16_t)) +
72 #endif // UDP_SEGMENT
73       CMSG_SPACE(sizeof(in6_pktinfo))];
74
75   memset(msg_ctrl, 0, sizeof(msg_ctrl));
76
77   msg.msg_control = msg_ctrl;
78   msg.msg_controllen = sizeof(msg_ctrl);
79
80   size_t controllen = 0;
81
82   auto cm = CMSG_FIRSTHDR(&msg);
83
84   switch (local_sa->sa_family) {
85   case AF_INET: {
86     controllen += CMSG_SPACE(sizeof(in_pktinfo));
87     cm->cmsg_level = IPPROTO_IP;
88     cm->cmsg_type = IP_PKTINFO;
89     cm->cmsg_len = CMSG_LEN(sizeof(in_pktinfo));
90     auto pktinfo = reinterpret_cast<in_pktinfo *>(CMSG_DATA(cm));
91     memset(pktinfo, 0, sizeof(in_pktinfo));
92     auto addrin =
93         reinterpret_cast<sockaddr_in *>(const_cast<sockaddr *>(local_sa));
94     pktinfo->ipi_spec_dst = addrin->sin_addr;
95     break;
96   }
97   case AF_INET6: {
98     controllen += CMSG_SPACE(sizeof(in6_pktinfo));
99     cm->cmsg_level = IPPROTO_IPV6;
100     cm->cmsg_type = IPV6_PKTINFO;
101     cm->cmsg_len = CMSG_LEN(sizeof(in6_pktinfo));
102     auto pktinfo = reinterpret_cast<in6_pktinfo *>(CMSG_DATA(cm));
103     memset(pktinfo, 0, sizeof(in6_pktinfo));
104     auto addrin =
105         reinterpret_cast<sockaddr_in6 *>(const_cast<sockaddr *>(local_sa));
106     pktinfo->ipi6_addr = addrin->sin6_addr;
107     break;
108   }
109   default:
110     assert(0);
111   }
112
113 #ifdef UDP_SEGMENT
114   if (gso_size && datalen > gso_size) {
115     controllen += CMSG_SPACE(sizeof(uint16_t));
116     cm = CMSG_NXTHDR(&msg, cm);
117     cm->cmsg_level = SOL_UDP;
118     cm->cmsg_type = UDP_SEGMENT;
119     cm->cmsg_len = CMSG_LEN(sizeof(uint16_t));
120     *(reinterpret_cast<uint16_t *>(CMSG_DATA(cm))) = gso_size;
121   }
122 #endif // UDP_SEGMENT
123
124   msg.msg_controllen = controllen;
125
126   ssize_t nwrite;
127
128   do {
129     nwrite = sendmsg(faddr->fd, &msg, 0);
130   } while (nwrite == -1 && errno == EINTR);
131
132   if (nwrite == -1) {
133     if (LOG_ENABLED(INFO)) {
134       auto error = errno;
135       LOG(INFO) << "sendmsg failed: errno=" << error;
136     }
137
138     return -errno;
139   }
140
141   if (LOG_ENABLED(INFO)) {
142     LOG(INFO) << "QUIC sent packet: local="
143               << util::to_numeric_addr(local_sa, local_salen)
144               << " remote=" << util::to_numeric_addr(remote_sa, remote_salen)
145               << " " << nwrite << " bytes";
146   }
147
148   return 0;
149 }
150
151 int generate_quic_retry_connection_id(ngtcp2_cid &cid, size_t cidlen,
152                                       const uint8_t *server_id, uint8_t km_id,
153                                       const uint8_t *key) {
154   assert(cidlen == SHRPX_QUIC_SCIDLEN);
155
156   if (RAND_bytes(cid.data, cidlen) != 1) {
157     return -1;
158   }
159
160   cid.datalen = cidlen;
161
162   cid.data[0] = (cid.data[0] & 0x3f) | km_id;
163
164   auto p = cid.data + SHRPX_QUIC_CID_PREFIX_OFFSET;
165
166   std::copy_n(server_id, SHRPX_QUIC_SERVER_IDLEN, p);
167
168   return encrypt_quic_connection_id(p, p, key);
169 }
170
171 int generate_quic_connection_id(ngtcp2_cid &cid, size_t cidlen,
172                                 const uint8_t *cid_prefix, uint8_t km_id,
173                                 const uint8_t *key) {
174   assert(cidlen == SHRPX_QUIC_SCIDLEN);
175
176   if (RAND_bytes(cid.data, cidlen) != 1) {
177     return -1;
178   }
179
180   cid.datalen = cidlen;
181
182   cid.data[0] = (cid.data[0] & 0x3f) | km_id;
183
184   auto p = cid.data + SHRPX_QUIC_CID_PREFIX_OFFSET;
185
186   std::copy_n(cid_prefix, SHRPX_QUIC_CID_PREFIXLEN, p);
187
188   return encrypt_quic_connection_id(p, p, key);
189 }
190
191 int encrypt_quic_connection_id(uint8_t *dest, const uint8_t *src,
192                                const uint8_t *key) {
193   auto ctx = EVP_CIPHER_CTX_new();
194   auto d = defer(EVP_CIPHER_CTX_free, ctx);
195
196   if (!EVP_EncryptInit_ex(ctx, EVP_aes_128_ecb(), nullptr, key, nullptr)) {
197     return -1;
198   }
199
200   EVP_CIPHER_CTX_set_padding(ctx, 0);
201
202   int len;
203
204   if (!EVP_EncryptUpdate(ctx, dest, &len, src, SHRPX_QUIC_DECRYPTED_DCIDLEN) ||
205       !EVP_EncryptFinal_ex(ctx, dest + len, &len)) {
206     return -1;
207   }
208
209   return 0;
210 }
211
212 int decrypt_quic_connection_id(uint8_t *dest, const uint8_t *src,
213                                const uint8_t *key) {
214   auto ctx = EVP_CIPHER_CTX_new();
215   auto d = defer(EVP_CIPHER_CTX_free, ctx);
216
217   if (!EVP_DecryptInit_ex(ctx, EVP_aes_128_ecb(), nullptr, key, nullptr)) {
218     return -1;
219   }
220
221   EVP_CIPHER_CTX_set_padding(ctx, 0);
222
223   int len;
224
225   if (!EVP_DecryptUpdate(ctx, dest, &len, src, SHRPX_QUIC_DECRYPTED_DCIDLEN) ||
226       !EVP_DecryptFinal_ex(ctx, dest + len, &len)) {
227     return -1;
228   }
229
230   return 0;
231 }
232
233 int generate_quic_hashed_connection_id(ngtcp2_cid &dest,
234                                        const Address &remote_addr,
235                                        const Address &local_addr,
236                                        const ngtcp2_cid &cid) {
237   auto ctx = EVP_MD_CTX_new();
238   auto d = defer(EVP_MD_CTX_free, ctx);
239
240   std::array<uint8_t, 32> h;
241   unsigned int hlen = EVP_MD_size(EVP_sha256());
242
243   if (!EVP_DigestInit_ex(ctx, EVP_sha256(), nullptr) ||
244       !EVP_DigestUpdate(ctx, &remote_addr.su.sa, remote_addr.len) ||
245       !EVP_DigestUpdate(ctx, &local_addr.su.sa, local_addr.len) ||
246       !EVP_DigestUpdate(ctx, cid.data, cid.datalen) ||
247       !EVP_DigestFinal_ex(ctx, h.data(), &hlen)) {
248     return -1;
249   }
250
251   assert(hlen == h.size());
252
253   std::copy_n(std::begin(h), sizeof(dest.data), std::begin(dest.data));
254   dest.datalen = sizeof(dest.data);
255
256   return 0;
257 }
258
259 int generate_quic_stateless_reset_token(uint8_t *token, const ngtcp2_cid &cid,
260                                         const uint8_t *secret,
261                                         size_t secretlen) {
262   if (ngtcp2_crypto_generate_stateless_reset_token(token, secret, secretlen,
263                                                    &cid) != 0) {
264     return -1;
265   }
266
267   return 0;
268 }
269
270 int generate_retry_token(uint8_t *token, size_t &tokenlen, const sockaddr *sa,
271                          socklen_t salen, const ngtcp2_cid &retry_scid,
272                          const ngtcp2_cid &odcid, const uint8_t *secret,
273                          size_t secretlen) {
274   auto t = std::chrono::duration_cast<std::chrono::nanoseconds>(
275                std::chrono::system_clock::now().time_since_epoch())
276                .count();
277
278   auto stokenlen = ngtcp2_crypto_generate_retry_token(
279       token, secret, secretlen, sa, salen, &retry_scid, &odcid, t);
280   if (stokenlen < 0) {
281     return -1;
282   }
283
284   tokenlen = stokenlen;
285
286   return 0;
287 }
288
289 int verify_retry_token(ngtcp2_cid &odcid, const uint8_t *token, size_t tokenlen,
290                        const ngtcp2_cid &dcid, const sockaddr *sa,
291                        socklen_t salen, const uint8_t *secret,
292                        size_t secretlen) {
293
294   auto t = std::chrono::duration_cast<std::chrono::nanoseconds>(
295                std::chrono::system_clock::now().time_since_epoch())
296                .count();
297
298   if (ngtcp2_crypto_verify_retry_token(&odcid, token, tokenlen, secret,
299                                        secretlen, sa, salen, &dcid,
300                                        10 * NGTCP2_SECONDS, t) != 0) {
301     return -1;
302   }
303
304   return 0;
305 }
306
307 int generate_token(uint8_t *token, size_t &tokenlen, const sockaddr *sa,
308                    size_t salen, const uint8_t *secret, size_t secretlen) {
309   auto t = std::chrono::duration_cast<std::chrono::nanoseconds>(
310                std::chrono::system_clock::now().time_since_epoch())
311                .count();
312
313   auto stokenlen = ngtcp2_crypto_generate_regular_token(
314       token, secret, secretlen, sa, salen, t);
315   if (stokenlen < 0) {
316     return -1;
317   }
318
319   tokenlen = stokenlen;
320
321   return 0;
322 }
323
324 int verify_token(const uint8_t *token, size_t tokenlen, const sockaddr *sa,
325                  socklen_t salen, const uint8_t *secret, size_t secretlen) {
326   auto t = std::chrono::duration_cast<std::chrono::nanoseconds>(
327                std::chrono::system_clock::now().time_since_epoch())
328                .count();
329
330   if (ngtcp2_crypto_verify_regular_token(token, tokenlen, secret, secretlen, sa,
331                                          salen, 3600 * NGTCP2_SECONDS,
332                                          t) != 0) {
333     return -1;
334   }
335
336   return 0;
337 }
338
339 int generate_quic_connection_id_encryption_key(uint8_t *key, size_t keylen,
340                                                const uint8_t *secret,
341                                                size_t secretlen,
342                                                const uint8_t *salt,
343                                                size_t saltlen) {
344   constexpr uint8_t info[] = "connection id encryption key";
345   ngtcp2_crypto_md sha256;
346   ngtcp2_crypto_md_init(
347       &sha256, reinterpret_cast<void *>(const_cast<EVP_MD *>(EVP_sha256())));
348
349   if (ngtcp2_crypto_hkdf(key, keylen, &sha256, secret, secretlen, salt, saltlen,
350                          info, str_size(info)) != 0) {
351     return -1;
352   }
353
354   return 0;
355 }
356
357 const QUICKeyingMaterial *
358 select_quic_keying_material(const QUICKeyingMaterials &qkms,
359                             const uint8_t *cid) {
360   for (auto &qkm : qkms.keying_materials) {
361     if (((*cid) & 0xc0) == qkm.id) {
362       return &qkm;
363     }
364   }
365
366   return &qkms.keying_materials.front();
367 }
368
369 } // namespace shrpx