Imported Upstream version 1.46.0
[platform/upstream/nghttp2.git] / src / shrpx_memcached_connection.cc
1 /*
2  * nghttp2 - HTTP/2 C Library
3  *
4  * Copyright (c) 2015 Tatsuhiro Tsujikawa
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining
7  * a copy of this software and associated documentation files (the
8  * "Software"), to deal in the Software without restriction, including
9  * without limitation the rights to use, copy, modify, merge, publish,
10  * distribute, sublicense, and/or sell copies of the Software, and to
11  * permit persons to whom the Software is furnished to do so, subject to
12  * the following conditions:
13  *
14  * The above copyright notice and this permission notice shall be
15  * included in all copies or substantial portions of the Software.
16  *
17  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
20  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
21  * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
22  * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
23  * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
24  */
25 #include "shrpx_memcached_connection.h"
26
27 #include <limits.h>
28 #include <sys/uio.h>
29
30 #include <cerrno>
31
32 #include "shrpx_memcached_request.h"
33 #include "shrpx_memcached_result.h"
34 #include "shrpx_config.h"
35 #include "shrpx_tls.h"
36 #include "shrpx_log.h"
37 #include "util.h"
38
39 namespace shrpx {
40
41 namespace {
42 void timeoutcb(struct ev_loop *loop, ev_timer *w, int revents) {
43   auto conn = static_cast<Connection *>(w->data);
44   auto mconn = static_cast<MemcachedConnection *>(conn->data);
45
46   if (w == &conn->rt && !conn->expired_rt()) {
47     return;
48   }
49
50   if (LOG_ENABLED(INFO)) {
51     MCLOG(INFO, mconn) << "Time out";
52   }
53
54   mconn->disconnect();
55 }
56 } // namespace
57
58 namespace {
59 void readcb(struct ev_loop *loop, ev_io *w, int revents) {
60   auto conn = static_cast<Connection *>(w->data);
61   auto mconn = static_cast<MemcachedConnection *>(conn->data);
62
63   if (mconn->on_read() != 0) {
64     mconn->reconnect_or_fail();
65     return;
66   }
67 }
68 } // namespace
69
70 namespace {
71 void writecb(struct ev_loop *loop, ev_io *w, int revents) {
72   auto conn = static_cast<Connection *>(w->data);
73   auto mconn = static_cast<MemcachedConnection *>(conn->data);
74
75   if (mconn->on_write() != 0) {
76     mconn->reconnect_or_fail();
77     return;
78   }
79 }
80 } // namespace
81
82 namespace {
83 void connectcb(struct ev_loop *loop, ev_io *w, int revents) {
84   auto conn = static_cast<Connection *>(w->data);
85   auto mconn = static_cast<MemcachedConnection *>(conn->data);
86
87   if (mconn->connected() != 0) {
88     mconn->disconnect();
89     return;
90   }
91
92   writecb(loop, w, revents);
93 }
94 } // namespace
95
96 constexpr auto write_timeout = 10_s;
97 constexpr auto read_timeout = 10_s;
98
99 MemcachedConnection::MemcachedConnection(const Address *addr,
100                                          struct ev_loop *loop, SSL_CTX *ssl_ctx,
101                                          const StringRef &sni_name,
102                                          MemchunkPool *mcpool,
103                                          std::mt19937 &gen)
104     : conn_(loop, -1, nullptr, mcpool, write_timeout, read_timeout, {}, {},
105             connectcb, readcb, timeoutcb, this, 0, 0., Proto::MEMCACHED),
106       do_read_(&MemcachedConnection::noop),
107       do_write_(&MemcachedConnection::noop),
108       sni_name_(sni_name),
109       connect_blocker_(
110           gen, loop, [] {}, [] {}),
111       parse_state_{},
112       addr_(addr),
113       ssl_ctx_(ssl_ctx),
114       sendsum_(0),
115       try_count_(0),
116       connected_(false) {}
117
118 MemcachedConnection::~MemcachedConnection() { conn_.disconnect(); }
119
120 namespace {
121 void clear_request(std::deque<std::unique_ptr<MemcachedRequest>> &q) {
122   for (auto &req : q) {
123     if (req->cb) {
124       req->cb(req.get(),
125               MemcachedResult(MemcachedStatusCode::EXT_NETWORK_ERROR));
126     }
127   }
128   q.clear();
129 }
130 } // namespace
131
132 void MemcachedConnection::disconnect() {
133   clear_request(recvq_);
134   clear_request(sendq_);
135
136   sendbufv_.clear();
137   sendsum_ = 0;
138
139   parse_state_ = {};
140
141   connected_ = false;
142
143   conn_.disconnect();
144
145   assert(recvbuf_.rleft() == 0);
146   recvbuf_.reset();
147
148   do_read_ = do_write_ = &MemcachedConnection::noop;
149 }
150
151 int MemcachedConnection::initiate_connection() {
152   assert(conn_.fd == -1);
153
154   if (ssl_ctx_) {
155     auto ssl = tls::create_ssl(ssl_ctx_);
156     if (!ssl) {
157       return -1;
158     }
159     conn_.set_ssl(ssl);
160     conn_.tls.client_session_cache = &tls_session_cache_;
161   }
162
163   conn_.fd = util::create_nonblock_socket(addr_->su.storage.ss_family);
164
165   if (conn_.fd == -1) {
166     auto error = errno;
167     MCLOG(WARN, this) << "socket() failed; errno=" << error;
168
169     return -1;
170   }
171
172   int rv;
173   rv = connect(conn_.fd, &addr_->su.sa, addr_->len);
174   if (rv != 0 && errno != EINPROGRESS) {
175     auto error = errno;
176     MCLOG(WARN, this) << "connect() failed; errno=" << error;
177
178     close(conn_.fd);
179     conn_.fd = -1;
180
181     return -1;
182   }
183
184   if (ssl_ctx_) {
185     if (!util::numeric_host(sni_name_.c_str())) {
186       SSL_set_tlsext_host_name(conn_.tls.ssl, sni_name_.c_str());
187     }
188
189     auto session = tls::reuse_tls_session(tls_session_cache_);
190     if (session) {
191       SSL_set_session(conn_.tls.ssl, session);
192       SSL_SESSION_free(session);
193     }
194
195     conn_.prepare_client_handshake();
196   }
197
198   if (LOG_ENABLED(INFO)) {
199     MCLOG(INFO, this) << "Connecting to memcached server";
200   }
201
202   ev_io_set(&conn_.wev, conn_.fd, EV_WRITE);
203   ev_io_set(&conn_.rev, conn_.fd, EV_READ);
204
205   ev_set_cb(&conn_.wev, connectcb);
206
207   conn_.wlimit.startw();
208   ev_timer_again(conn_.loop, &conn_.wt);
209
210   return 0;
211 }
212
213 int MemcachedConnection::connected() {
214   auto sock_error = util::get_socket_error(conn_.fd);
215   if (sock_error != 0) {
216     MCLOG(WARN, this) << "memcached connect failed; addr="
217                       << util::to_numeric_addr(addr_)
218                       << ": errno=" << sock_error;
219
220     connect_blocker_.on_failure();
221
222     conn_.wlimit.stopw();
223
224     return -1;
225   }
226
227   if (LOG_ENABLED(INFO)) {
228     MCLOG(INFO, this) << "connected to memcached server";
229   }
230
231   conn_.rlimit.startw();
232
233   ev_set_cb(&conn_.wev, writecb);
234
235   if (conn_.tls.ssl) {
236     conn_.again_rt();
237
238     do_read_ = &MemcachedConnection::tls_handshake;
239     do_write_ = &MemcachedConnection::tls_handshake;
240
241     return 0;
242   }
243
244   ev_timer_stop(conn_.loop, &conn_.wt);
245
246   connected_ = true;
247
248   connect_blocker_.on_success();
249
250   do_read_ = &MemcachedConnection::read_clear;
251   do_write_ = &MemcachedConnection::write_clear;
252
253   return 0;
254 }
255
256 int MemcachedConnection::on_write() { return do_write_(*this); }
257 int MemcachedConnection::on_read() { return do_read_(*this); }
258
259 int MemcachedConnection::tls_handshake() {
260   ERR_clear_error();
261
262   conn_.last_read = ev_now(conn_.loop);
263
264   auto rv = conn_.tls_handshake();
265   if (rv == SHRPX_ERR_INPROGRESS) {
266     return 0;
267   }
268
269   if (rv < 0) {
270     connect_blocker_.on_failure();
271     return rv;
272   }
273
274   if (LOG_ENABLED(INFO)) {
275     LOG(INFO) << "SSL/TLS handshake completed";
276   }
277
278   auto &tlsconf = get_config()->tls;
279
280   if (!tlsconf.insecure &&
281       tls::check_cert(conn_.tls.ssl, addr_, sni_name_) != 0) {
282     connect_blocker_.on_failure();
283     return -1;
284   }
285
286   ev_timer_stop(conn_.loop, &conn_.rt);
287   ev_timer_stop(conn_.loop, &conn_.wt);
288
289   connected_ = true;
290
291   connect_blocker_.on_success();
292
293   do_read_ = &MemcachedConnection::read_tls;
294   do_write_ = &MemcachedConnection::write_tls;
295
296   return on_write();
297 }
298
299 int MemcachedConnection::write_tls() {
300   if (!connected_) {
301     return 0;
302   }
303
304   conn_.last_read = ev_now(conn_.loop);
305
306   std::array<struct iovec, MAX_WR_IOVCNT> iov;
307   std::array<uint8_t, 16_k> buf;
308
309   for (; !sendq_.empty();) {
310     auto iovcnt = fill_request_buffer(iov.data(), iov.size());
311     auto p = std::begin(buf);
312     for (size_t i = 0; i < iovcnt; ++i) {
313       auto &v = iov[i];
314       auto n = std::min(static_cast<size_t>(std::end(buf) - p), v.iov_len);
315       p = std::copy_n(static_cast<uint8_t *>(v.iov_base), n, p);
316       if (p == std::end(buf)) {
317         break;
318       }
319     }
320
321     auto nwrite = conn_.write_tls(buf.data(), p - std::begin(buf));
322     if (nwrite < 0) {
323       return -1;
324     }
325     if (nwrite == 0) {
326       return 0;
327     }
328
329     drain_send_queue(nwrite);
330   }
331
332   conn_.wlimit.stopw();
333   ev_timer_stop(conn_.loop, &conn_.wt);
334
335   return 0;
336 }
337
338 int MemcachedConnection::read_tls() {
339   if (!connected_) {
340     return 0;
341   }
342
343   conn_.last_read = ev_now(conn_.loop);
344
345   for (;;) {
346     auto nread = conn_.read_tls(recvbuf_.last, recvbuf_.wleft());
347
348     if (nread == 0) {
349       return 0;
350     }
351
352     if (nread < 0) {
353       return -1;
354     }
355
356     recvbuf_.write(nread);
357
358     if (parse_packet() != 0) {
359       return -1;
360     }
361   }
362
363   return 0;
364 }
365
366 int MemcachedConnection::write_clear() {
367   if (!connected_) {
368     return 0;
369   }
370
371   conn_.last_read = ev_now(conn_.loop);
372
373   std::array<struct iovec, MAX_WR_IOVCNT> iov;
374
375   for (; !sendq_.empty();) {
376     auto iovcnt = fill_request_buffer(iov.data(), iov.size());
377     auto nwrite = conn_.writev_clear(iov.data(), iovcnt);
378     if (nwrite < 0) {
379       return -1;
380     }
381     if (nwrite == 0) {
382       return 0;
383     }
384
385     drain_send_queue(nwrite);
386   }
387
388   conn_.wlimit.stopw();
389   ev_timer_stop(conn_.loop, &conn_.wt);
390
391   return 0;
392 }
393
394 int MemcachedConnection::read_clear() {
395   if (!connected_) {
396     return 0;
397   }
398
399   conn_.last_read = ev_now(conn_.loop);
400
401   for (;;) {
402     auto nread = conn_.read_clear(recvbuf_.last, recvbuf_.wleft());
403
404     if (nread == 0) {
405       return 0;
406     }
407
408     if (nread < 0) {
409       return -1;
410     }
411
412     recvbuf_.write(nread);
413
414     if (parse_packet() != 0) {
415       return -1;
416     }
417   }
418
419   return 0;
420 }
421
422 int MemcachedConnection::parse_packet() {
423   auto in = recvbuf_.pos;
424
425   for (;;) {
426     auto busy = false;
427
428     switch (parse_state_.state) {
429     case MemcachedParseState::HEADER24: {
430       if (recvbuf_.last - in < 24) {
431         recvbuf_.drain_reset(in - recvbuf_.pos);
432         return 0;
433       }
434
435       if (recvq_.empty()) {
436         MCLOG(WARN, this)
437             << "Response received, but there is no in-flight request.";
438         return -1;
439       }
440
441       auto &req = recvq_.front();
442
443       if (*in != MEMCACHED_RES_MAGIC) {
444         MCLOG(WARN, this) << "Response has bad magic: "
445                           << static_cast<uint32_t>(*in);
446         return -1;
447       }
448       ++in;
449
450       parse_state_.op = static_cast<MemcachedOp>(*in++);
451       parse_state_.keylen = util::get_uint16(in);
452       in += 2;
453       parse_state_.extralen = *in++;
454       // skip 1 byte reserved data type
455       ++in;
456       parse_state_.status_code =
457           static_cast<MemcachedStatusCode>(util::get_uint16(in));
458       in += 2;
459       parse_state_.totalbody = util::get_uint32(in);
460       in += 4;
461       // skip 4 bytes opaque
462       in += 4;
463       parse_state_.cas = util::get_uint64(in);
464       in += 8;
465
466       if (req->op != parse_state_.op) {
467         MCLOG(WARN, this)
468             << "opcode in response does not match to the request: want "
469             << static_cast<uint32_t>(req->op) << ", got "
470             << static_cast<uint32_t>(parse_state_.op);
471         return -1;
472       }
473
474       if (parse_state_.keylen != 0) {
475         MCLOG(WARN, this) << "zero length keylen expected: got "
476                           << parse_state_.keylen;
477         return -1;
478       }
479
480       if (parse_state_.totalbody > 16_k) {
481         MCLOG(WARN, this) << "totalbody is too large: got "
482                           << parse_state_.totalbody;
483         return -1;
484       }
485
486       if (parse_state_.op == MemcachedOp::GET &&
487           parse_state_.status_code == MemcachedStatusCode::NO_ERROR &&
488           parse_state_.extralen == 0) {
489         MCLOG(WARN, this) << "response for GET does not have extra";
490         return -1;
491       }
492
493       if (parse_state_.totalbody <
494           parse_state_.keylen + parse_state_.extralen) {
495         MCLOG(WARN, this) << "totalbody is too short: totalbody "
496                           << parse_state_.totalbody << ", want min "
497                           << parse_state_.keylen + parse_state_.extralen;
498         return -1;
499       }
500
501       if (parse_state_.extralen) {
502         parse_state_.state = MemcachedParseState::EXTRA;
503         parse_state_.read_left = parse_state_.extralen;
504       } else {
505         parse_state_.state = MemcachedParseState::VALUE;
506         parse_state_.read_left = parse_state_.totalbody - parse_state_.keylen -
507                                  parse_state_.extralen;
508       }
509       busy = true;
510       break;
511     }
512     case MemcachedParseState::EXTRA: {
513       // We don't use extra for now. Just read and forget.
514       auto n = std::min(static_cast<size_t>(recvbuf_.last - in),
515                         parse_state_.read_left);
516
517       parse_state_.read_left -= n;
518       in += n;
519       if (parse_state_.read_left) {
520         recvbuf_.reset();
521         return 0;
522       }
523       parse_state_.state = MemcachedParseState::VALUE;
524       // since we require keylen == 0, totalbody - extralen ==
525       // valuelen
526       parse_state_.read_left =
527           parse_state_.totalbody - parse_state_.keylen - parse_state_.extralen;
528       busy = true;
529       break;
530     }
531     case MemcachedParseState::VALUE: {
532       auto n = std::min(static_cast<size_t>(recvbuf_.last - in),
533                         parse_state_.read_left);
534
535       parse_state_.value.insert(std::end(parse_state_.value), in, in + n);
536
537       parse_state_.read_left -= n;
538       in += n;
539       if (parse_state_.read_left) {
540         recvbuf_.reset();
541         return 0;
542       }
543
544       if (LOG_ENABLED(INFO)) {
545         if (parse_state_.status_code != MemcachedStatusCode::NO_ERROR) {
546           MCLOG(INFO, this) << "response returned error status: "
547                             << static_cast<uint16_t>(parse_state_.status_code);
548         }
549       }
550
551       // We require at least one complete response to clear try count.
552       try_count_ = 0;
553
554       auto req = std::move(recvq_.front());
555       recvq_.pop_front();
556
557       if (sendq_.empty() && recvq_.empty()) {
558         ev_timer_stop(conn_.loop, &conn_.rt);
559       }
560
561       if (!req->canceled && req->cb) {
562         req->cb(req.get(), MemcachedResult(parse_state_.status_code,
563                                            std::move(parse_state_.value)));
564       }
565
566       parse_state_ = {};
567       break;
568     }
569     }
570
571     if (!busy && in == recvbuf_.last) {
572       break;
573     }
574   }
575
576   assert(in == recvbuf_.last);
577   recvbuf_.reset();
578
579   return 0;
580 }
581
582 #undef DEFAULT_WR_IOVCNT
583 #define DEFAULT_WR_IOVCNT 128
584
585 #if defined(IOV_MAX) && IOV_MAX < DEFAULT_WR_IOVCNT
586 #  define MAX_WR_IOVCNT IOV_MAX
587 #else // !defined(IOV_MAX) || IOV_MAX >= DEFAULT_WR_IOVCNT
588 #  define MAX_WR_IOVCNT DEFAULT_WR_IOVCNT
589 #endif // !defined(IOV_MAX) || IOV_MAX >= DEFAULT_WR_IOVCNT
590
591 size_t MemcachedConnection::fill_request_buffer(struct iovec *iov,
592                                                 size_t iovlen) {
593   if (sendsum_ == 0) {
594     for (auto &req : sendq_) {
595       if (req->canceled) {
596         continue;
597       }
598       if (serialized_size(req.get()) + sendsum_ > 1300) {
599         break;
600       }
601       sendbufv_.emplace_back();
602       sendbufv_.back().req = req.get();
603       make_request(&sendbufv_.back(), req.get());
604       sendsum_ += sendbufv_.back().left();
605     }
606
607     if (sendsum_ == 0) {
608       sendq_.clear();
609       return 0;
610     }
611   }
612
613   size_t iovcnt = 0;
614   for (auto &buf : sendbufv_) {
615     if (iovcnt + 2 > iovlen) {
616       break;
617     }
618
619     auto req = buf.req;
620     if (buf.headbuf.rleft()) {
621       iov[iovcnt++] = {buf.headbuf.pos, buf.headbuf.rleft()};
622     }
623     if (buf.send_value_left) {
624       iov[iovcnt++] = {req->value.data() + req->value.size() -
625                            buf.send_value_left,
626                        buf.send_value_left};
627     }
628   }
629
630   return iovcnt;
631 }
632
633 void MemcachedConnection::drain_send_queue(size_t nwrite) {
634   sendsum_ -= nwrite;
635
636   while (nwrite > 0) {
637     auto &buf = sendbufv_.front();
638     auto &req = sendq_.front();
639     if (req->canceled) {
640       sendq_.pop_front();
641       continue;
642     }
643     assert(buf.req == req.get());
644     auto n = std::min(static_cast<size_t>(nwrite), buf.headbuf.rleft());
645     buf.headbuf.drain(n);
646     nwrite -= n;
647     n = std::min(static_cast<size_t>(nwrite), buf.send_value_left);
648     buf.send_value_left -= n;
649     nwrite -= n;
650
651     if (buf.headbuf.rleft() || buf.send_value_left) {
652       break;
653     }
654     sendbufv_.pop_front();
655     recvq_.push_back(std::move(sendq_.front()));
656     sendq_.pop_front();
657   }
658
659   // start read timer only when we wait for responses.
660   if (recvq_.empty()) {
661     ev_timer_stop(conn_.loop, &conn_.rt);
662   } else if (!ev_is_active(&conn_.rt)) {
663     conn_.again_rt();
664   }
665 }
666
667 size_t MemcachedConnection::serialized_size(MemcachedRequest *req) {
668   switch (req->op) {
669   case MemcachedOp::GET:
670     return 24 + req->key.size();
671   case MemcachedOp::ADD:
672   default:
673     return 24 + 8 + req->key.size() + req->value.size();
674   }
675 }
676
677 void MemcachedConnection::make_request(MemcachedSendbuf *sendbuf,
678                                        MemcachedRequest *req) {
679   auto &headbuf = sendbuf->headbuf;
680
681   std::fill(std::begin(headbuf.buf), std::end(headbuf.buf), 0);
682
683   headbuf[0] = MEMCACHED_REQ_MAGIC;
684   headbuf[1] = static_cast<uint8_t>(req->op);
685   switch (req->op) {
686   case MemcachedOp::GET:
687     util::put_uint16be(&headbuf[2], req->key.size());
688     util::put_uint32be(&headbuf[8], req->key.size());
689     headbuf.write(24);
690     break;
691   case MemcachedOp::ADD:
692     util::put_uint16be(&headbuf[2], req->key.size());
693     headbuf[4] = 8;
694     util::put_uint32be(&headbuf[8], 8 + req->key.size() + req->value.size());
695     util::put_uint32be(&headbuf[28], req->expiry);
696     headbuf.write(32);
697     break;
698   }
699
700   headbuf.write(req->key.c_str(), req->key.size());
701
702   sendbuf->send_value_left = req->value.size();
703 }
704
705 int MemcachedConnection::add_request(std::unique_ptr<MemcachedRequest> req) {
706   if (connect_blocker_.blocked()) {
707     return -1;
708   }
709
710   sendq_.push_back(std::move(req));
711
712   if (connected_) {
713     signal_write();
714     return 0;
715   }
716
717   if (conn_.fd == -1 && initiate_connection() != 0) {
718     connect_blocker_.on_failure();
719     disconnect();
720     return -1;
721   }
722
723   return 0;
724 }
725
726 // TODO should we start write timer too?
727 void MemcachedConnection::signal_write() { conn_.wlimit.startw(); }
728
729 int MemcachedConnection::noop() { return 0; }
730
731 void MemcachedConnection::reconnect_or_fail() {
732   if (!connected_ || (recvq_.empty() && sendq_.empty())) {
733     disconnect();
734     return;
735   }
736
737   constexpr size_t MAX_TRY_COUNT = 3;
738
739   if (++try_count_ >= MAX_TRY_COUNT) {
740     if (LOG_ENABLED(INFO)) {
741       MCLOG(INFO, this) << "Tried " << MAX_TRY_COUNT
742                         << " times, and all failed.  Aborting";
743     }
744     try_count_ = 0;
745     disconnect();
746     return;
747   }
748
749   std::vector<std::unique_ptr<MemcachedRequest>> q;
750   q.reserve(recvq_.size() + sendq_.size());
751
752   if (LOG_ENABLED(INFO)) {
753     MCLOG(INFO, this) << "Retry connection, enqueue "
754                       << recvq_.size() + sendq_.size() << " request(s) again";
755   }
756
757   q.insert(std::end(q), std::make_move_iterator(std::begin(recvq_)),
758            std::make_move_iterator(std::end(recvq_)));
759   q.insert(std::end(q), std::make_move_iterator(std::begin(sendq_)),
760            std::make_move_iterator(std::end(sendq_)));
761
762   recvq_.clear();
763   sendq_.clear();
764
765   disconnect();
766
767   sendq_.insert(std::end(sendq_), std::make_move_iterator(std::begin(q)),
768                 std::make_move_iterator(std::end(q)));
769
770   if (initiate_connection() != 0) {
771     connect_blocker_.on_failure();
772     disconnect();
773     return;
774   }
775 }
776
777 } // namespace shrpx