Imported Upstream version 1.41.0
[platform/upstream/grpc.git] / test / cpp / qps / client_async.cc
1 /*
2  *
3  * Copyright 2015 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18
19 #include <forward_list>
20 #include <functional>
21 #include <list>
22 #include <memory>
23 #include <mutex>
24 #include <sstream>
25 #include <string>
26 #include <thread>
27 #include <utility>
28 #include <vector>
29
30 #include "absl/memory/memory.h"
31
32 #include <grpc/grpc.h>
33 #include <grpc/support/cpu.h>
34 #include <grpc/support/log.h>
35 #include <grpcpp/alarm.h>
36 #include <grpcpp/channel.h>
37 #include <grpcpp/client_context.h>
38 #include <grpcpp/generic/generic_stub.h>
39
40 #include "src/core/lib/surface/completion_queue.h"
41 #include "src/proto/grpc/testing/benchmark_service.grpc.pb.h"
42 #include "test/cpp/qps/client.h"
43 #include "test/cpp/qps/usage_timer.h"
44 #include "test/cpp/util/create_test_channel.h"
45
46 namespace grpc {
47 namespace testing {
48
49 class ClientRpcContext {
50  public:
51   ClientRpcContext() {}
52   virtual ~ClientRpcContext() {}
53   // next state, return false if done. Collect stats when appropriate
54   virtual bool RunNextState(bool, HistogramEntry* entry) = 0;
55   virtual void StartNewClone(CompletionQueue* cq) = 0;
56   static void* tag(ClientRpcContext* c) { return static_cast<void*>(c); }
57   static ClientRpcContext* detag(void* t) {
58     return static_cast<ClientRpcContext*>(t);
59   }
60
61   virtual void Start(CompletionQueue* cq, const ClientConfig& config) = 0;
62   virtual void TryCancel() = 0;
63 };
64
65 template <class RequestType, class ResponseType>
66 class ClientRpcContextUnaryImpl : public ClientRpcContext {
67  public:
68   ClientRpcContextUnaryImpl(
69       BenchmarkService::Stub* stub, const RequestType& req,
70       std::function<gpr_timespec()> next_issue,
71       std::function<
72           std::unique_ptr<grpc::ClientAsyncResponseReader<ResponseType>>(
73               BenchmarkService::Stub*, grpc::ClientContext*, const RequestType&,
74               CompletionQueue*)>
75           prepare_req,
76       std::function<void(grpc::Status, ResponseType*, HistogramEntry*)> on_done)
77       : context_(),
78         stub_(stub),
79         cq_(nullptr),
80         req_(req),
81         response_(),
82         next_state_(State::READY),
83         callback_(on_done),
84         next_issue_(std::move(next_issue)),
85         prepare_req_(prepare_req) {}
86   ~ClientRpcContextUnaryImpl() override {}
87   void Start(CompletionQueue* cq, const ClientConfig& config) override {
88     GPR_ASSERT(!config.use_coalesce_api());  // not supported.
89     StartInternal(cq);
90   }
91   bool RunNextState(bool /*ok*/, HistogramEntry* entry) override {
92     switch (next_state_) {
93       case State::READY:
94         start_ = UsageTimer::Now();
95         response_reader_ = prepare_req_(stub_, &context_, req_, cq_);
96         response_reader_->StartCall();
97         next_state_ = State::RESP_DONE;
98         response_reader_->Finish(&response_, &status_,
99                                  ClientRpcContext::tag(this));
100         return true;
101       case State::RESP_DONE:
102         if (status_.ok()) {
103           entry->set_value((UsageTimer::Now() - start_) * 1e9);
104         }
105         callback_(status_, &response_, entry);
106         next_state_ = State::INVALID;
107         return false;
108       default:
109         GPR_ASSERT(false);
110         return false;
111     }
112   }
113   void StartNewClone(CompletionQueue* cq) override {
114     auto* clone = new ClientRpcContextUnaryImpl(stub_, req_, next_issue_,
115                                                 prepare_req_, callback_);
116     clone->StartInternal(cq);
117   }
118   void TryCancel() override { context_.TryCancel(); }
119
120  private:
121   grpc::ClientContext context_;
122   BenchmarkService::Stub* stub_;
123   CompletionQueue* cq_;
124   std::unique_ptr<Alarm> alarm_;
125   const RequestType& req_;
126   ResponseType response_;
127   enum State { INVALID, READY, RESP_DONE };
128   State next_state_;
129   std::function<void(grpc::Status, ResponseType*, HistogramEntry*)> callback_;
130   std::function<gpr_timespec()> next_issue_;
131   std::function<std::unique_ptr<grpc::ClientAsyncResponseReader<ResponseType>>(
132       BenchmarkService::Stub*, grpc::ClientContext*, const RequestType&,
133       CompletionQueue*)>
134       prepare_req_;
135   grpc::Status status_;
136   double start_;
137   std::unique_ptr<grpc::ClientAsyncResponseReader<ResponseType>>
138       response_reader_;
139
140   void StartInternal(CompletionQueue* cq) {
141     cq_ = cq;
142     if (!next_issue_) {  // ready to issue
143       RunNextState(true, nullptr);
144     } else {  // wait for the issue time
145       alarm_ = absl::make_unique<Alarm>();
146       alarm_->Set(cq_, next_issue_(), ClientRpcContext::tag(this));
147     }
148   }
149 };
150
151 template <class StubType, class RequestType>
152 class AsyncClient : public ClientImpl<StubType, RequestType> {
153   // Specify which protected members we are using since there is no
154   // member name resolution until the template types are fully resolved
155  public:
156   using Client::closed_loop_;
157   using Client::NextIssuer;
158   using Client::SetupLoadTest;
159   using ClientImpl<StubType, RequestType>::cores_;
160   using ClientImpl<StubType, RequestType>::channels_;
161   using ClientImpl<StubType, RequestType>::request_;
162   AsyncClient(const ClientConfig& config,
163               std::function<ClientRpcContext*(
164                   StubType*, std::function<gpr_timespec()> next_issue,
165                   const RequestType&)>
166                   setup_ctx,
167               std::function<std::unique_ptr<StubType>(std::shared_ptr<Channel>)>
168                   create_stub)
169       : ClientImpl<StubType, RequestType>(config, create_stub),
170         num_async_threads_(NumThreads(config)) {
171     SetupLoadTest(config, num_async_threads_);
172
173     int tpc = std::max(1, config.threads_per_cq());      // 1 if unspecified
174     int num_cqs = (num_async_threads_ + tpc - 1) / tpc;  // ceiling operator
175     for (int i = 0; i < num_cqs; i++) {
176       cli_cqs_.emplace_back(new CompletionQueue);
177     }
178
179     for (int i = 0; i < num_async_threads_; i++) {
180       cq_.emplace_back(i % cli_cqs_.size());
181       next_issuers_.emplace_back(NextIssuer(i));
182       shutdown_state_.emplace_back(new PerThreadShutdownState());
183     }
184
185     int t = 0;
186     for (int ch = 0; ch < config.client_channels(); ch++) {
187       for (int i = 0; i < config.outstanding_rpcs_per_channel(); i++) {
188         auto* cq = cli_cqs_[t].get();
189         auto ctx =
190             setup_ctx(channels_[ch].get_stub(), next_issuers_[t], request_);
191         ctx->Start(cq, config);
192       }
193       t = (t + 1) % cli_cqs_.size();
194     }
195   }
196   ~AsyncClient() override {
197     for (auto cq = cli_cqs_.begin(); cq != cli_cqs_.end(); cq++) {
198       void* got_tag;
199       bool ok;
200       while ((*cq)->Next(&got_tag, &ok)) {
201         delete ClientRpcContext::detag(got_tag);
202       }
203     }
204   }
205
206   int GetPollCount() override {
207     int count = 0;
208     for (auto cq = cli_cqs_.begin(); cq != cli_cqs_.end(); cq++) {
209       count += grpc_get_cq_poll_num((*cq)->cq());
210     }
211     return count;
212   }
213
214  protected:
215   const int num_async_threads_;
216
217  private:
218   struct PerThreadShutdownState {
219     mutable std::mutex mutex;
220     bool shutdown;
221     PerThreadShutdownState() : shutdown(false) {}
222   };
223
224   int NumThreads(const ClientConfig& config) {
225     int num_threads = config.async_client_threads();
226     if (num_threads <= 0) {  // Use dynamic sizing
227       num_threads = cores_;
228       gpr_log(GPR_INFO, "Sizing async client to %d threads", num_threads);
229     }
230     return num_threads;
231   }
232   void DestroyMultithreading() final {
233     for (auto ss = shutdown_state_.begin(); ss != shutdown_state_.end(); ++ss) {
234       std::lock_guard<std::mutex> lock((*ss)->mutex);
235       (*ss)->shutdown = true;
236     }
237     for (auto cq = cli_cqs_.begin(); cq != cli_cqs_.end(); cq++) {
238       (*cq)->Shutdown();
239     }
240     this->EndThreads();  // this needed for resolution
241   }
242
243   ClientRpcContext* ProcessTag(size_t thread_idx, void* tag) {
244     ClientRpcContext* ctx = ClientRpcContext::detag(tag);
245     if (shutdown_state_[thread_idx]->shutdown) {
246       ctx->TryCancel();
247       delete ctx;
248       bool ok;
249       while (cli_cqs_[cq_[thread_idx]]->Next(&tag, &ok)) {
250         ctx = ClientRpcContext::detag(tag);
251         ctx->TryCancel();
252         delete ctx;
253       }
254       return nullptr;
255     }
256     return ctx;
257   }
258
259   void ThreadFunc(size_t thread_idx, Client::Thread* t) final {
260     void* got_tag;
261     bool ok;
262
263     HistogramEntry entry;
264     HistogramEntry* entry_ptr = &entry;
265     if (!cli_cqs_[cq_[thread_idx]]->Next(&got_tag, &ok)) {
266       return;
267     }
268     std::mutex* shutdown_mu = &shutdown_state_[thread_idx]->mutex;
269     shutdown_mu->lock();
270     ClientRpcContext* ctx = ProcessTag(thread_idx, got_tag);
271     if (ctx == nullptr) {
272       shutdown_mu->unlock();
273       return;
274     }
275     while (cli_cqs_[cq_[thread_idx]]->DoThenAsyncNext(
276         [&, ctx, ok, entry_ptr, shutdown_mu]() {
277           if (!ctx->RunNextState(ok, entry_ptr)) {
278             // The RPC and callback are done, so clone the ctx
279             // and kickstart the new one
280             ctx->StartNewClone(cli_cqs_[cq_[thread_idx]].get());
281             delete ctx;
282           }
283           shutdown_mu->unlock();
284         },
285         &got_tag, &ok, gpr_inf_future(GPR_CLOCK_REALTIME))) {
286       t->UpdateHistogram(entry_ptr);
287       entry = HistogramEntry();
288       shutdown_mu->lock();
289       ctx = ProcessTag(thread_idx, got_tag);
290       if (ctx == nullptr) {
291         shutdown_mu->unlock();
292         return;
293       }
294     }
295   }
296
297   std::vector<std::unique_ptr<CompletionQueue>> cli_cqs_;
298   std::vector<int> cq_;
299   std::vector<std::function<gpr_timespec()>> next_issuers_;
300   std::vector<std::unique_ptr<PerThreadShutdownState>> shutdown_state_;
301 };
302
303 static std::unique_ptr<BenchmarkService::Stub> BenchmarkStubCreator(
304     const std::shared_ptr<Channel>& ch) {
305   return BenchmarkService::NewStub(ch);
306 }
307
308 class AsyncUnaryClient final
309     : public AsyncClient<BenchmarkService::Stub, SimpleRequest> {
310  public:
311   explicit AsyncUnaryClient(const ClientConfig& config)
312       : AsyncClient<BenchmarkService::Stub, SimpleRequest>(
313             config, SetupCtx, BenchmarkStubCreator) {
314     StartThreads(num_async_threads_);
315   }
316   ~AsyncUnaryClient() override {}
317
318  private:
319   static void CheckDone(const grpc::Status& s, SimpleResponse* /*response*/,
320                         HistogramEntry* entry) {
321     entry->set_status(s.error_code());
322   }
323   static std::unique_ptr<grpc::ClientAsyncResponseReader<SimpleResponse>>
324   PrepareReq(BenchmarkService::Stub* stub, grpc::ClientContext* ctx,
325              const SimpleRequest& request, CompletionQueue* cq) {
326     return stub->PrepareAsyncUnaryCall(ctx, request, cq);
327   };
328   static ClientRpcContext* SetupCtx(BenchmarkService::Stub* stub,
329                                     std::function<gpr_timespec()> next_issue,
330                                     const SimpleRequest& req) {
331     return new ClientRpcContextUnaryImpl<SimpleRequest, SimpleResponse>(
332         stub, req, std::move(next_issue), AsyncUnaryClient::PrepareReq,
333         AsyncUnaryClient::CheckDone);
334   }
335 };
336
337 template <class RequestType, class ResponseType>
338 class ClientRpcContextStreamingPingPongImpl : public ClientRpcContext {
339  public:
340   ClientRpcContextStreamingPingPongImpl(
341       BenchmarkService::Stub* stub, const RequestType& req,
342       std::function<gpr_timespec()> next_issue,
343       std::function<std::unique_ptr<
344           grpc::ClientAsyncReaderWriter<RequestType, ResponseType>>(
345           BenchmarkService::Stub*, grpc::ClientContext*, CompletionQueue*)>
346           prepare_req,
347       std::function<void(grpc::Status, ResponseType*)> on_done)
348       : context_(),
349         stub_(stub),
350         cq_(nullptr),
351         req_(req),
352         response_(),
353         next_state_(State::INVALID),
354         callback_(on_done),
355         next_issue_(std::move(next_issue)),
356         prepare_req_(prepare_req),
357         coalesce_(false) {}
358   ~ClientRpcContextStreamingPingPongImpl() override {}
359   void Start(CompletionQueue* cq, const ClientConfig& config) override {
360     StartInternal(cq, config.messages_per_stream(), config.use_coalesce_api());
361   }
362   bool RunNextState(bool ok, HistogramEntry* entry) override {
363     while (true) {
364       switch (next_state_) {
365         case State::STREAM_IDLE:
366           if (!next_issue_) {  // ready to issue
367             next_state_ = State::READY_TO_WRITE;
368           } else {
369             next_state_ = State::WAIT;
370           }
371           break;  // loop around, don't return
372         case State::WAIT:
373           next_state_ = State::READY_TO_WRITE;
374           alarm_ = absl::make_unique<Alarm>();
375           alarm_->Set(cq_, next_issue_(), ClientRpcContext::tag(this));
376           return true;
377         case State::READY_TO_WRITE:
378           if (!ok) {
379             return false;
380           }
381           start_ = UsageTimer::Now();
382           next_state_ = State::WRITE_DONE;
383           if (coalesce_ && messages_issued_ == messages_per_stream_ - 1) {
384             stream_->WriteLast(req_, WriteOptions(),
385                                ClientRpcContext::tag(this));
386           } else {
387             stream_->Write(req_, ClientRpcContext::tag(this));
388           }
389           return true;
390         case State::WRITE_DONE:
391           if (!ok) {
392             return false;
393           }
394           next_state_ = State::READ_DONE;
395           stream_->Read(&response_, ClientRpcContext::tag(this));
396           return true;
397           break;
398         case State::READ_DONE:
399           entry->set_value((UsageTimer::Now() - start_) * 1e9);
400           callback_(status_, &response_);
401           if ((messages_per_stream_ != 0) &&
402               (++messages_issued_ >= messages_per_stream_)) {
403             next_state_ = State::WRITES_DONE_DONE;
404             if (coalesce_) {
405               // WritesDone should have been called on the last Write.
406               // loop around to call Finish.
407               break;
408             }
409             stream_->WritesDone(ClientRpcContext::tag(this));
410             return true;
411           }
412           next_state_ = State::STREAM_IDLE;
413           break;  // loop around
414         case State::WRITES_DONE_DONE:
415           next_state_ = State::FINISH_DONE;
416           stream_->Finish(&status_, ClientRpcContext::tag(this));
417           return true;
418         case State::FINISH_DONE:
419           next_state_ = State::INVALID;
420           return false;
421           break;
422         default:
423           GPR_ASSERT(false);
424           return false;
425       }
426     }
427   }
428   void StartNewClone(CompletionQueue* cq) override {
429     auto* clone = new ClientRpcContextStreamingPingPongImpl(
430         stub_, req_, next_issue_, prepare_req_, callback_);
431     clone->StartInternal(cq, messages_per_stream_, coalesce_);
432   }
433   void TryCancel() override { context_.TryCancel(); }
434
435  private:
436   grpc::ClientContext context_;
437   BenchmarkService::Stub* stub_;
438   CompletionQueue* cq_;
439   std::unique_ptr<Alarm> alarm_;
440   const RequestType& req_;
441   ResponseType response_;
442   enum State {
443     INVALID,
444     STREAM_IDLE,
445     WAIT,
446     READY_TO_WRITE,
447     WRITE_DONE,
448     READ_DONE,
449     WRITES_DONE_DONE,
450     FINISH_DONE
451   };
452   State next_state_;
453   std::function<void(grpc::Status, ResponseType*)> callback_;
454   std::function<gpr_timespec()> next_issue_;
455   std::function<
456       std::unique_ptr<grpc::ClientAsyncReaderWriter<RequestType, ResponseType>>(
457           BenchmarkService::Stub*, grpc::ClientContext*, CompletionQueue*)>
458       prepare_req_;
459   grpc::Status status_;
460   double start_;
461   std::unique_ptr<grpc::ClientAsyncReaderWriter<RequestType, ResponseType>>
462       stream_;
463
464   // Allow a limit on number of messages in a stream
465   int messages_per_stream_;
466   int messages_issued_;
467   // Whether to use coalescing API.
468   bool coalesce_;
469
470   void StartInternal(CompletionQueue* cq, int messages_per_stream,
471                      bool coalesce) {
472     cq_ = cq;
473     messages_per_stream_ = messages_per_stream;
474     messages_issued_ = 0;
475     coalesce_ = coalesce;
476     if (coalesce_) {
477       GPR_ASSERT(messages_per_stream_ != 0);
478       context_.set_initial_metadata_corked(true);
479     }
480     stream_ = prepare_req_(stub_, &context_, cq);
481     next_state_ = State::STREAM_IDLE;
482     stream_->StartCall(ClientRpcContext::tag(this));
483     if (coalesce_) {
484       // When the initial metadata is corked, the tag will not come back and we
485       // need to manually drive the state machine.
486       RunNextState(true, nullptr);
487     }
488   }
489 };
490
491 class AsyncStreamingPingPongClient final
492     : public AsyncClient<BenchmarkService::Stub, SimpleRequest> {
493  public:
494   explicit AsyncStreamingPingPongClient(const ClientConfig& config)
495       : AsyncClient<BenchmarkService::Stub, SimpleRequest>(
496             config, SetupCtx, BenchmarkStubCreator) {
497     StartThreads(num_async_threads_);
498   }
499
500   ~AsyncStreamingPingPongClient() override {}
501
502  private:
503   static void CheckDone(const grpc::Status& /*s*/,
504                         SimpleResponse* /*response*/) {}
505   static std::unique_ptr<
506       grpc::ClientAsyncReaderWriter<SimpleRequest, SimpleResponse>>
507   PrepareReq(BenchmarkService::Stub* stub, grpc::ClientContext* ctx,
508              CompletionQueue* cq) {
509     auto stream = stub->PrepareAsyncStreamingCall(ctx, cq);
510     return stream;
511   };
512   static ClientRpcContext* SetupCtx(BenchmarkService::Stub* stub,
513                                     std::function<gpr_timespec()> next_issue,
514                                     const SimpleRequest& req) {
515     return new ClientRpcContextStreamingPingPongImpl<SimpleRequest,
516                                                      SimpleResponse>(
517         stub, req, std::move(next_issue),
518         AsyncStreamingPingPongClient::PrepareReq,
519         AsyncStreamingPingPongClient::CheckDone);
520   }
521 };
522
523 template <class RequestType, class ResponseType>
524 class ClientRpcContextStreamingFromClientImpl : public ClientRpcContext {
525  public:
526   ClientRpcContextStreamingFromClientImpl(
527       BenchmarkService::Stub* stub, const RequestType& req,
528       std::function<gpr_timespec()> next_issue,
529       std::function<std::unique_ptr<grpc::ClientAsyncWriter<RequestType>>(
530           BenchmarkService::Stub*, grpc::ClientContext*, ResponseType*,
531           CompletionQueue*)>
532           prepare_req,
533       std::function<void(grpc::Status, ResponseType*)> on_done)
534       : context_(),
535         stub_(stub),
536         cq_(nullptr),
537         req_(req),
538         response_(),
539         next_state_(State::INVALID),
540         callback_(on_done),
541         next_issue_(std::move(next_issue)),
542         prepare_req_(prepare_req) {}
543   ~ClientRpcContextStreamingFromClientImpl() override {}
544   void Start(CompletionQueue* cq, const ClientConfig& config) override {
545     GPR_ASSERT(!config.use_coalesce_api());  // not supported yet.
546     StartInternal(cq);
547   }
548   bool RunNextState(bool ok, HistogramEntry* entry) override {
549     while (true) {
550       switch (next_state_) {
551         case State::STREAM_IDLE:
552           if (!next_issue_) {  // ready to issue
553             next_state_ = State::READY_TO_WRITE;
554           } else {
555             next_state_ = State::WAIT;
556           }
557           break;  // loop around, don't return
558         case State::WAIT:
559           alarm_ = absl::make_unique<Alarm>();
560           alarm_->Set(cq_, next_issue_(), ClientRpcContext::tag(this));
561           next_state_ = State::READY_TO_WRITE;
562           return true;
563         case State::READY_TO_WRITE:
564           if (!ok) {
565             return false;
566           }
567           start_ = UsageTimer::Now();
568           next_state_ = State::WRITE_DONE;
569           stream_->Write(req_, ClientRpcContext::tag(this));
570           return true;
571         case State::WRITE_DONE:
572           if (!ok) {
573             return false;
574           }
575           entry->set_value((UsageTimer::Now() - start_) * 1e9);
576           next_state_ = State::STREAM_IDLE;
577           break;  // loop around
578         default:
579           GPR_ASSERT(false);
580           return false;
581       }
582     }
583   }
584   void StartNewClone(CompletionQueue* cq) override {
585     auto* clone = new ClientRpcContextStreamingFromClientImpl(
586         stub_, req_, next_issue_, prepare_req_, callback_);
587     clone->StartInternal(cq);
588   }
589   void TryCancel() override { context_.TryCancel(); }
590
591  private:
592   grpc::ClientContext context_;
593   BenchmarkService::Stub* stub_;
594   CompletionQueue* cq_;
595   std::unique_ptr<Alarm> alarm_;
596   const RequestType& req_;
597   ResponseType response_;
598   enum State {
599     INVALID,
600     STREAM_IDLE,
601     WAIT,
602     READY_TO_WRITE,
603     WRITE_DONE,
604   };
605   State next_state_;
606   std::function<void(grpc::Status, ResponseType*)> callback_;
607   std::function<gpr_timespec()> next_issue_;
608   std::function<std::unique_ptr<grpc::ClientAsyncWriter<RequestType>>(
609       BenchmarkService::Stub*, grpc::ClientContext*, ResponseType*,
610       CompletionQueue*)>
611       prepare_req_;
612   grpc::Status status_;
613   double start_;
614   std::unique_ptr<grpc::ClientAsyncWriter<RequestType>> stream_;
615
616   void StartInternal(CompletionQueue* cq) {
617     cq_ = cq;
618     stream_ = prepare_req_(stub_, &context_, &response_, cq);
619     next_state_ = State::STREAM_IDLE;
620     stream_->StartCall(ClientRpcContext::tag(this));
621   }
622 };
623
624 class AsyncStreamingFromClientClient final
625     : public AsyncClient<BenchmarkService::Stub, SimpleRequest> {
626  public:
627   explicit AsyncStreamingFromClientClient(const ClientConfig& config)
628       : AsyncClient<BenchmarkService::Stub, SimpleRequest>(
629             config, SetupCtx, BenchmarkStubCreator) {
630     StartThreads(num_async_threads_);
631   }
632
633   ~AsyncStreamingFromClientClient() override {}
634
635  private:
636   static void CheckDone(const grpc::Status& /*s*/,
637                         SimpleResponse* /*response*/) {}
638   static std::unique_ptr<grpc::ClientAsyncWriter<SimpleRequest>> PrepareReq(
639       BenchmarkService::Stub* stub, grpc::ClientContext* ctx,
640       SimpleResponse* resp, CompletionQueue* cq) {
641     auto stream = stub->PrepareAsyncStreamingFromClient(ctx, resp, cq);
642     return stream;
643   };
644   static ClientRpcContext* SetupCtx(BenchmarkService::Stub* stub,
645                                     std::function<gpr_timespec()> next_issue,
646                                     const SimpleRequest& req) {
647     return new ClientRpcContextStreamingFromClientImpl<SimpleRequest,
648                                                        SimpleResponse>(
649         stub, req, std::move(next_issue),
650         AsyncStreamingFromClientClient::PrepareReq,
651         AsyncStreamingFromClientClient::CheckDone);
652   }
653 };
654
655 template <class RequestType, class ResponseType>
656 class ClientRpcContextStreamingFromServerImpl : public ClientRpcContext {
657  public:
658   ClientRpcContextStreamingFromServerImpl(
659       BenchmarkService::Stub* stub, const RequestType& req,
660       std::function<gpr_timespec()> next_issue,
661       std::function<std::unique_ptr<grpc::ClientAsyncReader<ResponseType>>(
662           BenchmarkService::Stub*, grpc::ClientContext*, const RequestType&,
663           CompletionQueue*)>
664           prepare_req,
665       std::function<void(grpc::Status, ResponseType*)> on_done)
666       : context_(),
667         stub_(stub),
668         cq_(nullptr),
669         req_(req),
670         response_(),
671         next_state_(State::INVALID),
672         callback_(on_done),
673         next_issue_(std::move(next_issue)),
674         prepare_req_(prepare_req) {}
675   ~ClientRpcContextStreamingFromServerImpl() override {}
676   void Start(CompletionQueue* cq, const ClientConfig& config) override {
677     GPR_ASSERT(!config.use_coalesce_api());  // not supported
678     StartInternal(cq);
679   }
680   bool RunNextState(bool ok, HistogramEntry* entry) override {
681     while (true) {
682       switch (next_state_) {
683         case State::STREAM_IDLE:
684           if (!ok) {
685             return false;
686           }
687           start_ = UsageTimer::Now();
688           next_state_ = State::READ_DONE;
689           stream_->Read(&response_, ClientRpcContext::tag(this));
690           return true;
691         case State::READ_DONE:
692           if (!ok) {
693             return false;
694           }
695           entry->set_value((UsageTimer::Now() - start_) * 1e9);
696           callback_(status_, &response_);
697           next_state_ = State::STREAM_IDLE;
698           break;  // loop around
699         default:
700           GPR_ASSERT(false);
701           return false;
702       }
703     }
704   }
705   void StartNewClone(CompletionQueue* cq) override {
706     auto* clone = new ClientRpcContextStreamingFromServerImpl(
707         stub_, req_, next_issue_, prepare_req_, callback_);
708     clone->StartInternal(cq);
709   }
710   void TryCancel() override { context_.TryCancel(); }
711
712  private:
713   grpc::ClientContext context_;
714   BenchmarkService::Stub* stub_;
715   CompletionQueue* cq_;
716   std::unique_ptr<Alarm> alarm_;
717   const RequestType& req_;
718   ResponseType response_;
719   enum State { INVALID, STREAM_IDLE, READ_DONE };
720   State next_state_;
721   std::function<void(grpc::Status, ResponseType*)> callback_;
722   std::function<gpr_timespec()> next_issue_;
723   std::function<std::unique_ptr<grpc::ClientAsyncReader<ResponseType>>(
724       BenchmarkService::Stub*, grpc::ClientContext*, const RequestType&,
725       CompletionQueue*)>
726       prepare_req_;
727   grpc::Status status_;
728   double start_;
729   std::unique_ptr<grpc::ClientAsyncReader<ResponseType>> stream_;
730
731   void StartInternal(CompletionQueue* cq) {
732     // TODO(vjpai): Add support to rate-pace this
733     cq_ = cq;
734     stream_ = prepare_req_(stub_, &context_, req_, cq);
735     next_state_ = State::STREAM_IDLE;
736     stream_->StartCall(ClientRpcContext::tag(this));
737   }
738 };
739
740 class AsyncStreamingFromServerClient final
741     : public AsyncClient<BenchmarkService::Stub, SimpleRequest> {
742  public:
743   explicit AsyncStreamingFromServerClient(const ClientConfig& config)
744       : AsyncClient<BenchmarkService::Stub, SimpleRequest>(
745             config, SetupCtx, BenchmarkStubCreator) {
746     StartThreads(num_async_threads_);
747   }
748
749   ~AsyncStreamingFromServerClient() override {}
750
751  private:
752   static void CheckDone(const grpc::Status& /*s*/,
753                         SimpleResponse* /*response*/) {}
754   static std::unique_ptr<grpc::ClientAsyncReader<SimpleResponse>> PrepareReq(
755       BenchmarkService::Stub* stub, grpc::ClientContext* ctx,
756       const SimpleRequest& req, CompletionQueue* cq) {
757     auto stream = stub->PrepareAsyncStreamingFromServer(ctx, req, cq);
758     return stream;
759   };
760   static ClientRpcContext* SetupCtx(BenchmarkService::Stub* stub,
761                                     std::function<gpr_timespec()> next_issue,
762                                     const SimpleRequest& req) {
763     return new ClientRpcContextStreamingFromServerImpl<SimpleRequest,
764                                                        SimpleResponse>(
765         stub, req, std::move(next_issue),
766         AsyncStreamingFromServerClient::PrepareReq,
767         AsyncStreamingFromServerClient::CheckDone);
768   }
769 };
770
771 class ClientRpcContextGenericStreamingImpl : public ClientRpcContext {
772  public:
773   ClientRpcContextGenericStreamingImpl(
774       grpc::GenericStub* stub, const ByteBuffer& req,
775       std::function<gpr_timespec()> next_issue,
776       std::function<std::unique_ptr<grpc::GenericClientAsyncReaderWriter>(
777           grpc::GenericStub*, grpc::ClientContext*,
778           const std::string& method_name, CompletionQueue*)>
779           prepare_req,
780       std::function<void(grpc::Status, ByteBuffer*)> on_done)
781       : context_(),
782         stub_(stub),
783         cq_(nullptr),
784         req_(req),
785         response_(),
786         next_state_(State::INVALID),
787         callback_(std::move(on_done)),
788         next_issue_(std::move(next_issue)),
789         prepare_req_(std::move(prepare_req)) {}
790   ~ClientRpcContextGenericStreamingImpl() override {}
791   void Start(CompletionQueue* cq, const ClientConfig& config) override {
792     GPR_ASSERT(!config.use_coalesce_api());  // not supported yet.
793     StartInternal(cq, config.messages_per_stream());
794   }
  // Advances the streaming state machine in response to one completion
  // delivered with this context's tag.
  //
  // `ok` is the success flag of whichever operation this context last issued
  // (StartCall, the pacing alarm, Write, Read, WritesDone or Finish).
  // `entry` receives one latency sample per completed Write/Read round trip.
  //
  // Returns true when another asynchronous operation has been issued with
  // this context's tag, and false when the stream has finished or a
  // Write-side step observed a failure.
  bool RunNextState(bool ok, HistogramEntry* entry) override {
    while (true) {
      switch (next_state_) {
        case State::STREAM_IDLE:
          // Between messages: issue right away when there is no pacing
          // function, otherwise go schedule the next issue time.
          if (!next_issue_) {  // ready to issue
            next_state_ = State::READY_TO_WRITE;
          } else {
            next_state_ = State::WAIT;
          }
          break;  // loop around, don't return
        case State::WAIT:
          // Arm an alarm for the time returned by next_issue_(); its
          // completion re-enters this function in READY_TO_WRITE.
          // NOTE(review): `ok` from the previous operation is not inspected
          // here; READY_TO_WRITE then sees the alarm's `ok` instead.
          next_state_ = State::READY_TO_WRITE;
          alarm_ = absl::make_unique<Alarm>();
          alarm_->Set(cq_, next_issue_(), ClientRpcContext::tag(this));
          return true;
        case State::READY_TO_WRITE:
          if (!ok) {
            return false;
          }
          // Timestamp taken just before the Write; READ_DONE measures from here.
          start_ = UsageTimer::Now();
          next_state_ = State::WRITE_DONE;
          stream_->Write(req_, ClientRpcContext::tag(this));
          return true;
        case State::WRITE_DONE:
          if (!ok) {
            return false;
          }
          next_state_ = State::READ_DONE;
          stream_->Read(&response_, ClientRpcContext::tag(this));
          return true;
        case State::READ_DONE:
          // Record the Write->Read latency, scaled by 1e9 (presumably
          // seconds -> nanoseconds; confirm against UsageTimer::Now()).
          // NOTE(review): `ok` is not checked in this state, so a failed
          // Read still records a sample and invokes the callback. The
          // failure is only noticed later, e.g. by READY_TO_WRITE.
          // status_ is only written by Finish, so here it still holds its
          // initial value.
          entry->set_value((UsageTimer::Now() - start_) * 1e9);
          callback_(status_, &response_);
          // A limit of 0 means "unlimited". Once the limit is hit, half-close
          // the stream and proceed to Finish.
          if ((messages_per_stream_ != 0) &&
              (++messages_issued_ >= messages_per_stream_)) {
            next_state_ = State::WRITES_DONE_DONE;
            stream_->WritesDone(ClientRpcContext::tag(this));
            return true;
          }
          next_state_ = State::STREAM_IDLE;
          break;  // loop around
        case State::WRITES_DONE_DONE:
          // Finish is requested regardless of WritesDone's `ok` so the final
          // status is always collected into status_.
          next_state_ = State::FINISH_DONE;
          stream_->Finish(&status_, ClientRpcContext::tag(this));
          return true;
        case State::FINISH_DONE:
          // Stream fully closed; signal the caller that this context is done.
          next_state_ = State::INVALID;
          return false;
        default:
          GPR_ASSERT(false);
          return false;
      }
    }
  }
849   void StartNewClone(CompletionQueue* cq) override {
850     auto* clone = new ClientRpcContextGenericStreamingImpl(
851         stub_, req_, next_issue_, prepare_req_, callback_);
852     clone->StartInternal(cq, messages_per_stream_);
853   }
854   void TryCancel() override { context_.TryCancel(); }
855
856  private:
857   grpc::ClientContext context_;
858   grpc::GenericStub* stub_;
859   CompletionQueue* cq_;
860   std::unique_ptr<Alarm> alarm_;
861   ByteBuffer req_;
862   ByteBuffer response_;
863   enum State {
864     INVALID,
865     STREAM_IDLE,
866     WAIT,
867     READY_TO_WRITE,
868     WRITE_DONE,
869     READ_DONE,
870     WRITES_DONE_DONE,
871     FINISH_DONE
872   };
873   State next_state_;
874   std::function<void(grpc::Status, ByteBuffer*)> callback_;
875   std::function<gpr_timespec()> next_issue_;
876   std::function<std::unique_ptr<grpc::GenericClientAsyncReaderWriter>(
877       grpc::GenericStub*, grpc::ClientContext*, const std::string&,
878       CompletionQueue*)>
879       prepare_req_;
880   grpc::Status status_;
881   double start_;
882   std::unique_ptr<grpc::GenericClientAsyncReaderWriter> stream_;
883
884   // Allow a limit on number of messages in a stream
885   int messages_per_stream_;
886   int messages_issued_;
887
888   void StartInternal(CompletionQueue* cq, int messages_per_stream) {
889     cq_ = cq;
890     const std::string kMethodName(
891         "/grpc.testing.BenchmarkService/StreamingCall");
892     messages_per_stream_ = messages_per_stream;
893     messages_issued_ = 0;
894     stream_ = prepare_req_(stub_, &context_, kMethodName, cq);
895     next_state_ = State::STREAM_IDLE;
896     stream_->StartCall(ClientRpcContext::tag(this));
897   }
898 };
899
900 static std::unique_ptr<grpc::GenericStub> GenericStubCreator(
901     const std::shared_ptr<Channel>& ch) {
902   return absl::make_unique<grpc::GenericStub>(ch);
903 }
904
905 class GenericAsyncStreamingClient final
906     : public AsyncClient<grpc::GenericStub, ByteBuffer> {
907  public:
908   explicit GenericAsyncStreamingClient(const ClientConfig& config)
909       : AsyncClient<grpc::GenericStub, ByteBuffer>(config, SetupCtx,
910                                                    GenericStubCreator) {
911     StartThreads(num_async_threads_);
912   }
913
914   ~GenericAsyncStreamingClient() override {}
915
916  private:
917   static void CheckDone(const grpc::Status& /*s*/, ByteBuffer* /*response*/) {}
918   static std::unique_ptr<grpc::GenericClientAsyncReaderWriter> PrepareReq(
919       grpc::GenericStub* stub, grpc::ClientContext* ctx,
920       const std::string& method_name, CompletionQueue* cq) {
921     auto stream = stub->PrepareCall(ctx, method_name, cq);
922     return stream;
923   };
924   static ClientRpcContext* SetupCtx(grpc::GenericStub* stub,
925                                     std::function<gpr_timespec()> next_issue,
926                                     const ByteBuffer& req) {
927     return new ClientRpcContextGenericStreamingImpl(
928         stub, req, std::move(next_issue),
929         GenericAsyncStreamingClient::PrepareReq,
930         GenericAsyncStreamingClient::CheckDone);
931   }
932 };
933
934 std::unique_ptr<Client> CreateAsyncClient(const ClientConfig& config) {
935   switch (config.rpc_type()) {
936     case UNARY:
937       return std::unique_ptr<Client>(new AsyncUnaryClient(config));
938     case STREAMING:
939       return std::unique_ptr<Client>(new AsyncStreamingPingPongClient(config));
940     case STREAMING_FROM_CLIENT:
941       return std::unique_ptr<Client>(
942           new AsyncStreamingFromClientClient(config));
943     case STREAMING_FROM_SERVER:
944       return std::unique_ptr<Client>(
945           new AsyncStreamingFromServerClient(config));
946     case STREAMING_BOTH_WAYS:
947       // TODO(vjpai): Implement this
948       assert(false);
949       return nullptr;
950     default:
951       assert(false);
952       return nullptr;
953   }
954 }
955 std::unique_ptr<Client> CreateGenericAsyncStreamingClient(
956     const ClientConfig& config) {
957   return std::unique_ptr<Client>(new GenericAsyncStreamingClient(config));
958 }
959
960 }  // namespace testing
961 }  // namespace grpc