// Imported Upstream version 1.36.0
// [platform/upstream/grpc.git] / test / cpp / end2end / client_callback_end2end_test.cc
1 /*
2  *
3  * Copyright 2018 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18
19 #include <grpcpp/channel.h>
20 #include <grpcpp/client_context.h>
21 #include <grpcpp/create_channel.h>
22 #include <grpcpp/generic/generic_stub.h>
23 #include <grpcpp/impl/codegen/proto_utils.h>
24 #include <grpcpp/server.h>
25 #include <grpcpp/server_builder.h>
26 #include <grpcpp/server_context.h>
27 #include <grpcpp/support/client_callback.h>
28 #include <gtest/gtest.h>
29
30 #include <algorithm>
31 #include <condition_variable>
32 #include <functional>
33 #include <mutex>
34 #include <sstream>
35 #include <thread>
36
37 #include "absl/memory/memory.h"
38 #include "src/core/lib/gpr/env.h"
39 #include "src/core/lib/iomgr/iomgr.h"
40 #include "src/proto/grpc/testing/echo.grpc.pb.h"
41 #include "test/core/util/port.h"
42 #include "test/core/util/test_config.h"
43 #include "test/cpp/end2end/interceptors_util.h"
44 #include "test/cpp/end2end/test_service_impl.h"
45 #include "test/cpp/util/byte_buffer_proto_helper.h"
46 #include "test/cpp/util/string_ref_helper.h"
47 #include "test/cpp/util/test_credentials_provider.h"
48
// MAYBE_SKIP_TEST leaves the current test body early when the fixture decided
// at SetUp time that this configuration must not be exercised. Callback tests
// can only run when the iomgr can run in the background or when the transport
// is in-process.
#define MAYBE_SKIP_TEST       \
  do {                        \
    if (do_not_test_) return; \
  } while (false)
59
60 namespace grpc {
61 namespace testing {
62 namespace {
63
// Transport between the test client and the test server: an in-process
// channel obtained from the server, or a TCP connection to a localhost port.
enum class Protocol {
  INPROC,
  TCP,
};
65
66 class TestScenario {
67  public:
68   TestScenario(bool serve_callback, Protocol protocol, bool intercept,
69                const std::string& creds_type)
70       : callback_server(serve_callback),
71         protocol(protocol),
72         use_interceptors(intercept),
73         credentials_type(creds_type) {}
74   void Log() const;
75   bool callback_server;
76   Protocol protocol;
77   bool use_interceptors;
78   const std::string credentials_type;
79 };
80
81 static std::ostream& operator<<(std::ostream& out,
82                                 const TestScenario& scenario) {
83   return out << "TestScenario{callback_server="
84              << (scenario.callback_server ? "true" : "false") << ",protocol="
85              << (scenario.protocol == Protocol::INPROC ? "INPROC" : "TCP")
86              << ",intercept=" << (scenario.use_interceptors ? "true" : "false")
87              << ",creds=" << scenario.credentials_type << "}";
88 }
89
90 void TestScenario::Log() const {
91   std::ostringstream out;
92   out << *this;
93   gpr_log(GPR_DEBUG, "%s", out.str().c_str());
94 }
95
96 class ClientCallbackEnd2endTest
97     : public ::testing::TestWithParam<TestScenario> {
98  protected:
99   ClientCallbackEnd2endTest() { GetParam().Log(); }
100
101   void SetUp() override {
102     ServerBuilder builder;
103
104     auto server_creds = GetCredentialsProvider()->GetServerCredentials(
105         GetParam().credentials_type);
106     // TODO(vjpai): Support testing of AuthMetadataProcessor
107
108     if (GetParam().protocol == Protocol::TCP) {
109       picked_port_ = grpc_pick_unused_port_or_die();
110       server_address_ << "localhost:" << picked_port_;
111       builder.AddListeningPort(server_address_.str(), server_creds);
112     }
113     if (!GetParam().callback_server) {
114       builder.RegisterService(&service_);
115     } else {
116       builder.RegisterService(&callback_service_);
117     }
118
119     if (GetParam().use_interceptors) {
120       std::vector<
121           std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>
122           creators;
123       // Add 20 dummy server interceptors
124       creators.reserve(20);
125       for (auto i = 0; i < 20; i++) {
126         creators.push_back(absl::make_unique<DummyInterceptorFactory>());
127       }
128       builder.experimental().SetInterceptorCreators(std::move(creators));
129     }
130
131     server_ = builder.BuildAndStart();
132     is_server_started_ = true;
133     if (GetParam().protocol == Protocol::TCP &&
134         !grpc_iomgr_run_in_background()) {
135       do_not_test_ = true;
136     }
137   }
138
139   void ResetStub() {
140     ChannelArguments args;
141     auto channel_creds = GetCredentialsProvider()->GetChannelCredentials(
142         GetParam().credentials_type, &args);
143     switch (GetParam().protocol) {
144       case Protocol::TCP:
145         if (!GetParam().use_interceptors) {
146           channel_ = ::grpc::CreateCustomChannel(server_address_.str(),
147                                                  channel_creds, args);
148         } else {
149           channel_ = CreateCustomChannelWithInterceptors(
150               server_address_.str(), channel_creds, args,
151               CreateDummyClientInterceptors());
152         }
153         break;
154       case Protocol::INPROC:
155         if (!GetParam().use_interceptors) {
156           channel_ = server_->InProcessChannel(args);
157         } else {
158           channel_ = server_->experimental().InProcessChannelWithInterceptors(
159               args, CreateDummyClientInterceptors());
160         }
161         break;
162       default:
163         assert(false);
164     }
165     stub_ = grpc::testing::EchoTestService::NewStub(channel_);
166     generic_stub_ = absl::make_unique<GenericStub>(channel_);
167     DummyInterceptor::Reset();
168   }
169
170   void TearDown() override {
171     if (is_server_started_) {
172       // Although we would normally do an explicit shutdown, the server
173       // should also work correctly with just a destructor call. The regular
174       // end2end test uses explicit shutdown, so let this one just do reset.
175       server_.reset();
176     }
177     if (picked_port_ > 0) {
178       grpc_recycle_unused_port(picked_port_);
179     }
180   }
181
182   void SendRpcs(int num_rpcs, bool with_binary_metadata) {
183     std::string test_string("");
184     for (int i = 0; i < num_rpcs; i++) {
185       EchoRequest request;
186       EchoResponse response;
187       ClientContext cli_ctx;
188
189       test_string += "Hello world. ";
190       request.set_message(test_string);
191       std::string val;
192       if (with_binary_metadata) {
193         request.mutable_param()->set_echo_metadata(true);
194         char bytes[8] = {'\0', '\1', '\2', '\3',
195                          '\4', '\5', '\6', static_cast<char>(i)};
196         val = std::string(bytes, 8);
197         cli_ctx.AddMetadata("custom-bin", val);
198       }
199
200       cli_ctx.set_compression_algorithm(GRPC_COMPRESS_GZIP);
201
202       std::mutex mu;
203       std::condition_variable cv;
204       bool done = false;
205       stub_->experimental_async()->Echo(
206           &cli_ctx, &request, &response,
207           [&cli_ctx, &request, &response, &done, &mu, &cv, val,
208            with_binary_metadata](Status s) {
209             GPR_ASSERT(s.ok());
210
211             EXPECT_EQ(request.message(), response.message());
212             if (with_binary_metadata) {
213               EXPECT_EQ(
214                   1u, cli_ctx.GetServerTrailingMetadata().count("custom-bin"));
215               EXPECT_EQ(val, ToString(cli_ctx.GetServerTrailingMetadata()
216                                           .find("custom-bin")
217                                           ->second));
218             }
219             std::lock_guard<std::mutex> l(mu);
220             done = true;
221             cv.notify_one();
222           });
223       std::unique_lock<std::mutex> l(mu);
224       while (!done) {
225         cv.wait(l);
226       }
227     }
228   }
229
230   void SendRpcsGeneric(int num_rpcs, bool maybe_except) {
231     const std::string kMethodName("/grpc.testing.EchoTestService/Echo");
232     std::string test_string("");
233     for (int i = 0; i < num_rpcs; i++) {
234       EchoRequest request;
235       std::unique_ptr<ByteBuffer> send_buf;
236       ByteBuffer recv_buf;
237       ClientContext cli_ctx;
238
239       test_string += "Hello world. ";
240       request.set_message(test_string);
241       send_buf = SerializeToByteBuffer(&request);
242
243       std::mutex mu;
244       std::condition_variable cv;
245       bool done = false;
246       generic_stub_->experimental().UnaryCall(
247           &cli_ctx, kMethodName, send_buf.get(), &recv_buf,
248           [&request, &recv_buf, &done, &mu, &cv, maybe_except](Status s) {
249             GPR_ASSERT(s.ok());
250
251             EchoResponse response;
252             EXPECT_TRUE(ParseFromByteBuffer(&recv_buf, &response));
253             EXPECT_EQ(request.message(), response.message());
254             std::lock_guard<std::mutex> l(mu);
255             done = true;
256             cv.notify_one();
257 #if GRPC_ALLOW_EXCEPTIONS
258             if (maybe_except) {
259               throw - 1;
260             }
261 #else
262             GPR_ASSERT(!maybe_except);
263 #endif
264           });
265       std::unique_lock<std::mutex> l(mu);
266       while (!done) {
267         cv.wait(l);
268       }
269     }
270   }
271
272   void SendGenericEchoAsBidi(int num_rpcs, int reuses, bool do_writes_done) {
273     const std::string kMethodName("/grpc.testing.EchoTestService/Echo");
274     std::string test_string("");
275     for (int i = 0; i < num_rpcs; i++) {
276       test_string += "Hello world. ";
277       class Client : public grpc::experimental::ClientBidiReactor<ByteBuffer,
278                                                                   ByteBuffer> {
279        public:
280         Client(ClientCallbackEnd2endTest* test, const std::string& method_name,
281                const std::string& test_str, int reuses, bool do_writes_done)
282             : reuses_remaining_(reuses), do_writes_done_(do_writes_done) {
283           activate_ = [this, test, method_name, test_str] {
284             if (reuses_remaining_ > 0) {
285               cli_ctx_ = absl::make_unique<ClientContext>();
286               reuses_remaining_--;
287               test->generic_stub_->experimental().PrepareBidiStreamingCall(
288                   cli_ctx_.get(), method_name, this);
289               request_.set_message(test_str);
290               send_buf_ = SerializeToByteBuffer(&request_);
291               StartWrite(send_buf_.get());
292               StartRead(&recv_buf_);
293               StartCall();
294             } else {
295               std::unique_lock<std::mutex> l(mu_);
296               done_ = true;
297               cv_.notify_one();
298             }
299           };
300           activate_();
301         }
302         void OnWriteDone(bool /*ok*/) override {
303           if (do_writes_done_) {
304             StartWritesDone();
305           }
306         }
307         void OnReadDone(bool /*ok*/) override {
308           EchoResponse response;
309           EXPECT_TRUE(ParseFromByteBuffer(&recv_buf_, &response));
310           EXPECT_EQ(request_.message(), response.message());
311         };
312         void OnDone(const Status& s) override {
313           EXPECT_TRUE(s.ok());
314           activate_();
315         }
316         void Await() {
317           std::unique_lock<std::mutex> l(mu_);
318           while (!done_) {
319             cv_.wait(l);
320           }
321         }
322
323         EchoRequest request_;
324         std::unique_ptr<ByteBuffer> send_buf_;
325         ByteBuffer recv_buf_;
326         std::unique_ptr<ClientContext> cli_ctx_;
327         int reuses_remaining_;
328         std::function<void()> activate_;
329         std::mutex mu_;
330         std::condition_variable cv_;
331         bool done_ = false;
332         const bool do_writes_done_;
333       };
334
335       Client rpc(this, kMethodName, test_string, reuses, do_writes_done);
336
337       rpc.Await();
338     }
339   }
340   bool do_not_test_{false};
341   bool is_server_started_{false};
342   int picked_port_{0};
343   std::shared_ptr<Channel> channel_;
344   std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
345   std::unique_ptr<grpc::GenericStub> generic_stub_;
346   TestServiceImpl service_;
347   CallbackTestServiceImpl callback_service_;
348   std::unique_ptr<Server> server_;
349   std::ostringstream server_address_;
350 };
351
352 TEST_P(ClientCallbackEnd2endTest, SimpleRpc) {
353   MAYBE_SKIP_TEST;
354   ResetStub();
355   SendRpcs(1, false);
356 }
357
358 TEST_P(ClientCallbackEnd2endTest, SimpleRpcExpectedError) {
359   MAYBE_SKIP_TEST;
360   ResetStub();
361
362   EchoRequest request;
363   EchoResponse response;
364   ClientContext cli_ctx;
365   ErrorStatus error_status;
366
367   request.set_message("Hello failure");
368   error_status.set_code(1);  // CANCELLED
369   error_status.set_error_message("cancel error message");
370   *request.mutable_param()->mutable_expected_error() = error_status;
371
372   std::mutex mu;
373   std::condition_variable cv;
374   bool done = false;
375
376   stub_->experimental_async()->Echo(
377       &cli_ctx, &request, &response,
378       [&response, &done, &mu, &cv, &error_status](Status s) {
379         EXPECT_EQ("", response.message());
380         EXPECT_EQ(error_status.code(), s.error_code());
381         EXPECT_EQ(error_status.error_message(), s.error_message());
382         std::lock_guard<std::mutex> l(mu);
383         done = true;
384         cv.notify_one();
385       });
386
387   std::unique_lock<std::mutex> l(mu);
388   while (!done) {
389     cv.wait(l);
390   }
391 }
392
393 TEST_P(ClientCallbackEnd2endTest, SimpleRpcUnderLockNested) {
394   MAYBE_SKIP_TEST;
395   ResetStub();
396
397   // The request/response state associated with an RPC and the synchronization
398   // variables needed to notify its completion.
399   struct RpcState {
400     std::mutex mu;
401     std::condition_variable cv;
402     bool done = false;
403     EchoRequest request;
404     EchoResponse response;
405     ClientContext cli_ctx;
406
407     RpcState() = default;
408     ~RpcState() {
409       // Grab the lock to prevent destruction while another is still holding
410       // lock
411       std::lock_guard<std::mutex> lock(mu);
412     }
413   };
414   std::vector<RpcState> rpc_state(3);
415   for (size_t i = 0; i < rpc_state.size(); i++) {
416     std::string message = "Hello locked world";
417     message += std::to_string(i);
418     rpc_state[i].request.set_message(message);
419   }
420
421   // Grab a lock and then start an RPC whose callback grabs the same lock and
422   // then calls this function to start the next RPC under lock (up to a limit of
423   // the size of the rpc_state vector).
424   std::function<void(int)> nested_call = [this, &nested_call,
425                                           &rpc_state](int index) {
426     std::lock_guard<std::mutex> l(rpc_state[index].mu);
427     stub_->experimental_async()->Echo(
428         &rpc_state[index].cli_ctx, &rpc_state[index].request,
429         &rpc_state[index].response,
430         [index, &nested_call, &rpc_state](Status s) {
431           std::lock_guard<std::mutex> l1(rpc_state[index].mu);
432           EXPECT_TRUE(s.ok());
433           rpc_state[index].done = true;
434           rpc_state[index].cv.notify_all();
435           // Call the next level of nesting if possible
436           if (index + 1 < int(rpc_state.size())) {
437             nested_call(index + 1);
438           }
439         });
440   };
441
442   nested_call(0);
443
444   // Wait for completion notifications from all RPCs. Order doesn't matter.
445   for (RpcState& state : rpc_state) {
446     std::unique_lock<std::mutex> l(state.mu);
447     while (!state.done) {
448       state.cv.wait(l);
449     }
450     EXPECT_EQ(state.request.message(), state.response.message());
451   }
452 }
453
454 TEST_P(ClientCallbackEnd2endTest, SimpleRpcUnderLock) {
455   MAYBE_SKIP_TEST;
456   ResetStub();
457   std::mutex mu;
458   std::condition_variable cv;
459   bool done = false;
460   EchoRequest request;
461   request.set_message("Hello locked world.");
462   EchoResponse response;
463   ClientContext cli_ctx;
464   {
465     std::lock_guard<std::mutex> l(mu);
466     stub_->experimental_async()->Echo(
467         &cli_ctx, &request, &response,
468         [&mu, &cv, &done, &request, &response](Status s) {
469           std::lock_guard<std::mutex> l(mu);
470           EXPECT_TRUE(s.ok());
471           EXPECT_EQ(request.message(), response.message());
472           done = true;
473           cv.notify_one();
474         });
475   }
476   std::unique_lock<std::mutex> l(mu);
477   while (!done) {
478     cv.wait(l);
479   }
480 }
481
482 TEST_P(ClientCallbackEnd2endTest, SequentialRpcs) {
483   MAYBE_SKIP_TEST;
484   ResetStub();
485   SendRpcs(10, false);
486 }
487
488 TEST_P(ClientCallbackEnd2endTest, SendClientInitialMetadata) {
489   MAYBE_SKIP_TEST;
490   ResetStub();
491   SimpleRequest request;
492   SimpleResponse response;
493   ClientContext cli_ctx;
494
495   cli_ctx.AddMetadata(kCheckClientInitialMetadataKey,
496                       kCheckClientInitialMetadataVal);
497
498   std::mutex mu;
499   std::condition_variable cv;
500   bool done = false;
501   stub_->experimental_async()->CheckClientInitialMetadata(
502       &cli_ctx, &request, &response, [&done, &mu, &cv](Status s) {
503         GPR_ASSERT(s.ok());
504
505         std::lock_guard<std::mutex> l(mu);
506         done = true;
507         cv.notify_one();
508       });
509   std::unique_lock<std::mutex> l(mu);
510   while (!done) {
511     cv.wait(l);
512   }
513 }
514
515 TEST_P(ClientCallbackEnd2endTest, SimpleRpcWithBinaryMetadata) {
516   MAYBE_SKIP_TEST;
517   ResetStub();
518   SendRpcs(1, true);
519 }
520
521 TEST_P(ClientCallbackEnd2endTest, SequentialRpcsWithVariedBinaryMetadataValue) {
522   MAYBE_SKIP_TEST;
523   ResetStub();
524   SendRpcs(10, true);
525 }
526
527 TEST_P(ClientCallbackEnd2endTest, SequentialGenericRpcs) {
528   MAYBE_SKIP_TEST;
529   ResetStub();
530   SendRpcsGeneric(10, false);
531 }
532
533 TEST_P(ClientCallbackEnd2endTest, SequentialGenericRpcsAsBidi) {
534   MAYBE_SKIP_TEST;
535   ResetStub();
536   SendGenericEchoAsBidi(10, 1, /*do_writes_done=*/true);
537 }
538
539 TEST_P(ClientCallbackEnd2endTest, SequentialGenericRpcsAsBidiWithReactorReuse) {
540   MAYBE_SKIP_TEST;
541   ResetStub();
542   SendGenericEchoAsBidi(10, 10, /*do_writes_done=*/true);
543 }
544
545 TEST_P(ClientCallbackEnd2endTest, GenericRpcNoWritesDone) {
546   MAYBE_SKIP_TEST;
547   ResetStub();
548   SendGenericEchoAsBidi(1, 1, /*do_writes_done=*/false);
549 }
550
#if GRPC_ALLOW_EXCEPTIONS
// Ten generic unary calls whose completion callbacks throw after signalling
// completion. Only built when exceptions are allowed.
TEST_P(ClientCallbackEnd2endTest, ExceptingRpc) {
  MAYBE_SKIP_TEST;
  ResetStub();
  SendRpcsGeneric(/*num_rpcs=*/10, /*maybe_except=*/true);
}
#endif
558
559 TEST_P(ClientCallbackEnd2endTest, MultipleRpcsWithVariedBinaryMetadataValue) {
560   MAYBE_SKIP_TEST;
561   ResetStub();
562   std::vector<std::thread> threads;
563   threads.reserve(10);
564   for (int i = 0; i < 10; ++i) {
565     threads.emplace_back([this] { SendRpcs(10, true); });
566   }
567   for (int i = 0; i < 10; ++i) {
568     threads[i].join();
569   }
570 }
571
572 TEST_P(ClientCallbackEnd2endTest, MultipleRpcs) {
573   MAYBE_SKIP_TEST;
574   ResetStub();
575   std::vector<std::thread> threads;
576   threads.reserve(10);
577   for (int i = 0; i < 10; ++i) {
578     threads.emplace_back([this] { SendRpcs(10, false); });
579   }
580   for (int i = 0; i < 10; ++i) {
581     threads[i].join();
582   }
583 }
584
585 TEST_P(ClientCallbackEnd2endTest, CancelRpcBeforeStart) {
586   MAYBE_SKIP_TEST;
587   ResetStub();
588   EchoRequest request;
589   EchoResponse response;
590   ClientContext context;
591   request.set_message("hello");
592   context.TryCancel();
593
594   std::mutex mu;
595   std::condition_variable cv;
596   bool done = false;
597   stub_->experimental_async()->Echo(
598       &context, &request, &response, [&response, &done, &mu, &cv](Status s) {
599         EXPECT_EQ("", response.message());
600         EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
601         std::lock_guard<std::mutex> l(mu);
602         done = true;
603         cv.notify_one();
604       });
605   std::unique_lock<std::mutex> l(mu);
606   while (!done) {
607     cv.wait(l);
608   }
609   if (GetParam().use_interceptors) {
610     EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
611   }
612 }
613
614 TEST_P(ClientCallbackEnd2endTest, RequestEchoServerCancel) {
615   MAYBE_SKIP_TEST;
616   ResetStub();
617   EchoRequest request;
618   EchoResponse response;
619   ClientContext context;
620   request.set_message("hello");
621   context.AddMetadata(kServerTryCancelRequest,
622                       std::to_string(CANCEL_BEFORE_PROCESSING));
623
624   std::mutex mu;
625   std::condition_variable cv;
626   bool done = false;
627   stub_->experimental_async()->Echo(
628       &context, &request, &response, [&done, &mu, &cv](Status s) {
629         EXPECT_FALSE(s.ok());
630         EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
631         std::lock_guard<std::mutex> l(mu);
632         done = true;
633         cv.notify_one();
634       });
635   std::unique_lock<std::mutex> l(mu);
636   while (!done) {
637     cv.wait(l);
638   }
639 }
640
// Describes whether, and after how many completed operations, the client side
// of a streaming test should cancel the RPC.
struct ClientCancelInfo {
  // True when the client is supposed to cancel.
  bool cancel{false};
  // Number of completed operations after which the client cancels. Only
  // consulted when `cancel` is true; initialized to 0 so that copying a
  // default-constructed instance never reads an indeterminate value.
  int ops_before_cancel{0};

  // No client-side cancellation.
  ClientCancelInfo() = default;
  // Cancel once `ops` operations have completed.
  explicit ClientCancelInfo(int ops) : cancel{true}, ops_before_cancel{ops} {}
};
648
649 class WriteClient : public grpc::experimental::ClientWriteReactor<EchoRequest> {
650  public:
651   WriteClient(grpc::testing::EchoTestService::Stub* stub,
652               ServerTryCancelRequestPhase server_try_cancel,
653               int num_msgs_to_send, ClientCancelInfo client_cancel = {})
654       : server_try_cancel_(server_try_cancel),
655         num_msgs_to_send_(num_msgs_to_send),
656         client_cancel_{client_cancel} {
657     std::string msg{"Hello server."};
658     for (int i = 0; i < num_msgs_to_send; i++) {
659       desired_ += msg;
660     }
661     if (server_try_cancel != DO_NOT_CANCEL) {
662       // Send server_try_cancel value in the client metadata
663       context_.AddMetadata(kServerTryCancelRequest,
664                            std::to_string(server_try_cancel));
665     }
666     context_.set_initial_metadata_corked(true);
667     stub->experimental_async()->RequestStream(&context_, &response_, this);
668     StartCall();
669     request_.set_message(msg);
670     MaybeWrite();
671   }
672   void OnWriteDone(bool ok) override {
673     if (ok) {
674       num_msgs_sent_++;
675       MaybeWrite();
676     }
677   }
678   void OnDone(const Status& s) override {
679     gpr_log(GPR_INFO, "Sent %d messages", num_msgs_sent_);
680     int num_to_send =
681         (client_cancel_.cancel)
682             ? std::min(num_msgs_to_send_, client_cancel_.ops_before_cancel)
683             : num_msgs_to_send_;
684     switch (server_try_cancel_) {
685       case CANCEL_BEFORE_PROCESSING:
686       case CANCEL_DURING_PROCESSING:
687         // If the RPC is canceled by server before / during messages from the
688         // client, it means that the client most likely did not get a chance to
689         // send all the messages it wanted to send. i.e num_msgs_sent <=
690         // num_msgs_to_send
691         EXPECT_LE(num_msgs_sent_, num_to_send);
692         break;
693       case DO_NOT_CANCEL:
694       case CANCEL_AFTER_PROCESSING:
695         // If the RPC was not canceled or canceled after all messages were read
696         // by the server, the client did get a chance to send all its messages
697         EXPECT_EQ(num_msgs_sent_, num_to_send);
698         break;
699       default:
700         assert(false);
701         break;
702     }
703     if ((server_try_cancel_ == DO_NOT_CANCEL) && !client_cancel_.cancel) {
704       EXPECT_TRUE(s.ok());
705       EXPECT_EQ(response_.message(), desired_);
706     } else {
707       EXPECT_FALSE(s.ok());
708       EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
709     }
710     std::unique_lock<std::mutex> l(mu_);
711     done_ = true;
712     cv_.notify_one();
713   }
714   void Await() {
715     std::unique_lock<std::mutex> l(mu_);
716     while (!done_) {
717       cv_.wait(l);
718     }
719   }
720
721  private:
722   void MaybeWrite() {
723     if (client_cancel_.cancel &&
724         num_msgs_sent_ == client_cancel_.ops_before_cancel) {
725       context_.TryCancel();
726     } else if (num_msgs_to_send_ > num_msgs_sent_ + 1) {
727       StartWrite(&request_);
728     } else if (num_msgs_to_send_ == num_msgs_sent_ + 1) {
729       StartWriteLast(&request_, WriteOptions());
730     }
731   }
732   EchoRequest request_;
733   EchoResponse response_;
734   ClientContext context_;
735   const ServerTryCancelRequestPhase server_try_cancel_;
736   int num_msgs_sent_{0};
737   const int num_msgs_to_send_;
738   std::string desired_;
739   const ClientCancelInfo client_cancel_;
740   std::mutex mu_;
741   std::condition_variable cv_;
742   bool done_ = false;
743 };
744
745 TEST_P(ClientCallbackEnd2endTest, RequestStream) {
746   MAYBE_SKIP_TEST;
747   ResetStub();
748   WriteClient test{stub_.get(), DO_NOT_CANCEL, 3};
749   test.Await();
750   // Make sure that the server interceptors were not notified to cancel
751   if (GetParam().use_interceptors) {
752     EXPECT_EQ(0, DummyInterceptor::GetNumTimesCancel());
753   }
754 }
755
756 TEST_P(ClientCallbackEnd2endTest, ClientCancelsRequestStream) {
757   MAYBE_SKIP_TEST;
758   ResetStub();
759   WriteClient test{stub_.get(), DO_NOT_CANCEL, 3, ClientCancelInfo{2}};
760   test.Await();
761   // Make sure that the server interceptors got the cancel
762   if (GetParam().use_interceptors) {
763     EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
764   }
765 }
766
767 // Server to cancel before doing reading the request
768 TEST_P(ClientCallbackEnd2endTest, RequestStreamServerCancelBeforeReads) {
769   MAYBE_SKIP_TEST;
770   ResetStub();
771   WriteClient test{stub_.get(), CANCEL_BEFORE_PROCESSING, 1};
772   test.Await();
773   // Make sure that the server interceptors were notified
774   if (GetParam().use_interceptors) {
775     EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
776   }
777 }
778
779 // Server to cancel while reading a request from the stream in parallel
780 TEST_P(ClientCallbackEnd2endTest, RequestStreamServerCancelDuringRead) {
781   MAYBE_SKIP_TEST;
782   ResetStub();
783   WriteClient test{stub_.get(), CANCEL_DURING_PROCESSING, 10};
784   test.Await();
785   // Make sure that the server interceptors were notified
786   if (GetParam().use_interceptors) {
787     EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
788   }
789 }
790
791 // Server to cancel after reading all the requests but before returning to the
792 // client
793 TEST_P(ClientCallbackEnd2endTest, RequestStreamServerCancelAfterReads) {
794   MAYBE_SKIP_TEST;
795   ResetStub();
796   WriteClient test{stub_.get(), CANCEL_AFTER_PROCESSING, 4};
797   test.Await();
798   // Make sure that the server interceptors were notified
799   if (GetParam().use_interceptors) {
800     EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
801   }
802 }
803
804 TEST_P(ClientCallbackEnd2endTest, UnaryReactor) {
805   MAYBE_SKIP_TEST;
806   ResetStub();
807   class UnaryClient : public grpc::experimental::ClientUnaryReactor {
808    public:
809     explicit UnaryClient(grpc::testing::EchoTestService::Stub* stub) {
810       cli_ctx_.AddMetadata("key1", "val1");
811       cli_ctx_.AddMetadata("key2", "val2");
812       request_.mutable_param()->set_echo_metadata_initially(true);
813       request_.set_message("Hello metadata");
814       stub->experimental_async()->Echo(&cli_ctx_, &request_, &response_, this);
815       StartCall();
816     }
817     void OnReadInitialMetadataDone(bool ok) override {
818       EXPECT_TRUE(ok);
819       EXPECT_EQ(1u, cli_ctx_.GetServerInitialMetadata().count("key1"));
820       EXPECT_EQ(
821           "val1",
822           ToString(cli_ctx_.GetServerInitialMetadata().find("key1")->second));
823       EXPECT_EQ(1u, cli_ctx_.GetServerInitialMetadata().count("key2"));
824       EXPECT_EQ(
825           "val2",
826           ToString(cli_ctx_.GetServerInitialMetadata().find("key2")->second));
827       initial_metadata_done_ = true;
828     }
829     void OnDone(const Status& s) override {
830       EXPECT_TRUE(initial_metadata_done_);
831       EXPECT_EQ(0u, cli_ctx_.GetServerTrailingMetadata().size());
832       EXPECT_TRUE(s.ok());
833       EXPECT_EQ(request_.message(), response_.message());
834       std::unique_lock<std::mutex> l(mu_);
835       done_ = true;
836       cv_.notify_one();
837     }
838     void Await() {
839       std::unique_lock<std::mutex> l(mu_);
840       while (!done_) {
841         cv_.wait(l);
842       }
843     }
844
845    private:
846     EchoRequest request_;
847     EchoResponse response_;
848     ClientContext cli_ctx_;
849     std::mutex mu_;
850     std::condition_variable cv_;
851     bool done_{false};
852     bool initial_metadata_done_{false};
853   };
854
855   UnaryClient test{stub_.get()};
856   test.Await();
857   // Make sure that the server interceptors were not notified of a cancel
858   if (GetParam().use_interceptors) {
859     EXPECT_EQ(0, DummyInterceptor::GetNumTimesCancel());
860   }
861 }
862
863 TEST_P(ClientCallbackEnd2endTest, GenericUnaryReactor) {
864   MAYBE_SKIP_TEST;
865   ResetStub();
866   const std::string kMethodName("/grpc.testing.EchoTestService/Echo");
867   class UnaryClient : public grpc::experimental::ClientUnaryReactor {
868    public:
869     UnaryClient(grpc::GenericStub* stub, const std::string& method_name) {
870       cli_ctx_.AddMetadata("key1", "val1");
871       cli_ctx_.AddMetadata("key2", "val2");
872       request_.mutable_param()->set_echo_metadata_initially(true);
873       request_.set_message("Hello metadata");
874       send_buf_ = SerializeToByteBuffer(&request_);
875
876       stub->experimental().PrepareUnaryCall(&cli_ctx_, method_name,
877                                             send_buf_.get(), &recv_buf_, this);
878       StartCall();
879     }
880     void OnReadInitialMetadataDone(bool ok) override {
881       EXPECT_TRUE(ok);
882       EXPECT_EQ(1u, cli_ctx_.GetServerInitialMetadata().count("key1"));
883       EXPECT_EQ(
884           "val1",
885           ToString(cli_ctx_.GetServerInitialMetadata().find("key1")->second));
886       EXPECT_EQ(1u, cli_ctx_.GetServerInitialMetadata().count("key2"));
887       EXPECT_EQ(
888           "val2",
889           ToString(cli_ctx_.GetServerInitialMetadata().find("key2")->second));
890       initial_metadata_done_ = true;
891     }
892     void OnDone(const Status& s) override {
893       EXPECT_TRUE(initial_metadata_done_);
894       EXPECT_EQ(0u, cli_ctx_.GetServerTrailingMetadata().size());
895       EXPECT_TRUE(s.ok());
896       EchoResponse response;
897       EXPECT_TRUE(ParseFromByteBuffer(&recv_buf_, &response));
898       EXPECT_EQ(request_.message(), response.message());
899       std::unique_lock<std::mutex> l(mu_);
900       done_ = true;
901       cv_.notify_one();
902     }
903     void Await() {
904       std::unique_lock<std::mutex> l(mu_);
905       while (!done_) {
906         cv_.wait(l);
907       }
908     }
909
910    private:
911     EchoRequest request_;
912     std::unique_ptr<ByteBuffer> send_buf_;
913     ByteBuffer recv_buf_;
914     ClientContext cli_ctx_;
915     std::mutex mu_;
916     std::condition_variable cv_;
917     bool done_{false};
918     bool initial_metadata_done_{false};
919   };
920
921   UnaryClient test{generic_stub_.get(), kMethodName};
922   test.Await();
923   // Make sure that the server interceptors were not notified of a cancel
924   if (GetParam().use_interceptors) {
925     EXPECT_EQ(0, DummyInterceptor::GetNumTimesCancel());
926   }
927 }
928
929 class ReadClient : public grpc::experimental::ClientReadReactor<EchoResponse> {
930  public:
931   ReadClient(grpc::testing::EchoTestService::Stub* stub,
932              ServerTryCancelRequestPhase server_try_cancel,
933              ClientCancelInfo client_cancel = {})
934       : server_try_cancel_(server_try_cancel), client_cancel_{client_cancel} {
935     if (server_try_cancel_ != DO_NOT_CANCEL) {
936       // Send server_try_cancel value in the client metadata
937       context_.AddMetadata(kServerTryCancelRequest,
938                            std::to_string(server_try_cancel));
939     }
940     request_.set_message("Hello client ");
941     stub->experimental_async()->ResponseStream(&context_, &request_, this);
942     if (client_cancel_.cancel &&
943         reads_complete_ == client_cancel_.ops_before_cancel) {
944       context_.TryCancel();
945     }
946     // Even if we cancel, read until failure because there might be responses
947     // pending
948     StartRead(&response_);
949     StartCall();
950   }
951   void OnReadDone(bool ok) override {
952     if (!ok) {
953       if (server_try_cancel_ == DO_NOT_CANCEL && !client_cancel_.cancel) {
954         EXPECT_EQ(reads_complete_, kServerDefaultResponseStreamsToSend);
955       }
956     } else {
957       EXPECT_LE(reads_complete_, kServerDefaultResponseStreamsToSend);
958       EXPECT_EQ(response_.message(),
959                 request_.message() + std::to_string(reads_complete_));
960       reads_complete_++;
961       if (client_cancel_.cancel &&
962           reads_complete_ == client_cancel_.ops_before_cancel) {
963         context_.TryCancel();
964       }
965       // Even if we cancel, read until failure because there might be responses
966       // pending
967       StartRead(&response_);
968     }
969   }
970   void OnDone(const Status& s) override {
971     gpr_log(GPR_INFO, "Read %d messages", reads_complete_);
972     switch (server_try_cancel_) {
973       case DO_NOT_CANCEL:
974         if (!client_cancel_.cancel || client_cancel_.ops_before_cancel >
975                                           kServerDefaultResponseStreamsToSend) {
976           EXPECT_TRUE(s.ok());
977           EXPECT_EQ(reads_complete_, kServerDefaultResponseStreamsToSend);
978         } else {
979           EXPECT_GE(reads_complete_, client_cancel_.ops_before_cancel);
980           EXPECT_LE(reads_complete_, kServerDefaultResponseStreamsToSend);
981           // Status might be ok or cancelled depending on whether server
982           // sent status before client cancel went through
983           if (!s.ok()) {
984             EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
985           }
986         }
987         break;
988       case CANCEL_BEFORE_PROCESSING:
989         EXPECT_FALSE(s.ok());
990         EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
991         EXPECT_EQ(reads_complete_, 0);
992         break;
993       case CANCEL_DURING_PROCESSING:
994       case CANCEL_AFTER_PROCESSING:
995         // If server canceled while writing messages, client must have read
996         // less than or equal to the expected number of messages. Even if the
997         // server canceled after writing all messages, the RPC may be canceled
998         // before the Client got a chance to read all the messages.
999         EXPECT_FALSE(s.ok());
1000         EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
1001         EXPECT_LE(reads_complete_, kServerDefaultResponseStreamsToSend);
1002         break;
1003       default:
1004         assert(false);
1005     }
1006     std::unique_lock<std::mutex> l(mu_);
1007     done_ = true;
1008     cv_.notify_one();
1009   }
1010   void Await() {
1011     std::unique_lock<std::mutex> l(mu_);
1012     while (!done_) {
1013       cv_.wait(l);
1014     }
1015   }
1016
1017  private:
1018   EchoRequest request_;
1019   EchoResponse response_;
1020   ClientContext context_;
1021   const ServerTryCancelRequestPhase server_try_cancel_;
1022   int reads_complete_{0};
1023   const ClientCancelInfo client_cancel_;
1024   std::mutex mu_;
1025   std::condition_variable cv_;
1026   bool done_ = false;
1027 };
1028
1029 TEST_P(ClientCallbackEnd2endTest, ResponseStream) {
1030   MAYBE_SKIP_TEST;
1031   ResetStub();
1032   ReadClient test{stub_.get(), DO_NOT_CANCEL};
1033   test.Await();
1034   // Make sure that the server interceptors were not notified of a cancel
1035   if (GetParam().use_interceptors) {
1036     EXPECT_EQ(0, DummyInterceptor::GetNumTimesCancel());
1037   }
1038 }
1039
1040 TEST_P(ClientCallbackEnd2endTest, ClientCancelsResponseStream) {
1041   MAYBE_SKIP_TEST;
1042   ResetStub();
1043   ReadClient test{stub_.get(), DO_NOT_CANCEL, ClientCancelInfo{2}};
1044   test.Await();
1045   // Because cancel in this case races with server finish, we can't be sure that
1046   // server interceptors even see cancellation
1047 }
1048
1049 // Server to cancel before sending any response messages
1050 TEST_P(ClientCallbackEnd2endTest, ResponseStreamServerCancelBefore) {
1051   MAYBE_SKIP_TEST;
1052   ResetStub();
1053   ReadClient test{stub_.get(), CANCEL_BEFORE_PROCESSING};
1054   test.Await();
1055   // Make sure that the server interceptors were notified
1056   if (GetParam().use_interceptors) {
1057     EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
1058   }
1059 }
1060
1061 // Server to cancel while writing a response to the stream in parallel
1062 TEST_P(ClientCallbackEnd2endTest, ResponseStreamServerCancelDuring) {
1063   MAYBE_SKIP_TEST;
1064   ResetStub();
1065   ReadClient test{stub_.get(), CANCEL_DURING_PROCESSING};
1066   test.Await();
1067   // Make sure that the server interceptors were notified
1068   if (GetParam().use_interceptors) {
1069     EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
1070   }
1071 }
1072
1073 // Server to cancel after writing all the respones to the stream but before
1074 // returning to the client
1075 TEST_P(ClientCallbackEnd2endTest, ResponseStreamServerCancelAfter) {
1076   MAYBE_SKIP_TEST;
1077   ResetStub();
1078   ReadClient test{stub_.get(), CANCEL_AFTER_PROCESSING};
1079   test.Await();
1080   // Make sure that the server interceptors were notified
1081   if (GetParam().use_interceptors) {
1082     EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
1083   }
1084 }
1085
1086 class BidiClient
1087     : public grpc::experimental::ClientBidiReactor<EchoRequest, EchoResponse> {
1088  public:
1089   BidiClient(grpc::testing::EchoTestService::Stub* stub,
1090              ServerTryCancelRequestPhase server_try_cancel,
1091              int num_msgs_to_send, bool cork_metadata, bool first_write_async,
1092              ClientCancelInfo client_cancel = {})
1093       : server_try_cancel_(server_try_cancel),
1094         msgs_to_send_{num_msgs_to_send},
1095         client_cancel_{client_cancel} {
1096     if (server_try_cancel_ != DO_NOT_CANCEL) {
1097       // Send server_try_cancel value in the client metadata
1098       context_.AddMetadata(kServerTryCancelRequest,
1099                            std::to_string(server_try_cancel));
1100     }
1101     request_.set_message("Hello fren ");
1102     context_.set_initial_metadata_corked(cork_metadata);
1103     stub->experimental_async()->BidiStream(&context_, this);
1104     MaybeAsyncWrite(first_write_async);
1105     StartRead(&response_);
1106     StartCall();
1107   }
1108   void OnReadDone(bool ok) override {
1109     if (!ok) {
1110       if (server_try_cancel_ == DO_NOT_CANCEL) {
1111         if (!client_cancel_.cancel) {
1112           EXPECT_EQ(reads_complete_, msgs_to_send_);
1113         } else {
1114           EXPECT_LE(reads_complete_, writes_complete_);
1115         }
1116       }
1117     } else {
1118       EXPECT_LE(reads_complete_, msgs_to_send_);
1119       EXPECT_EQ(response_.message(), request_.message());
1120       reads_complete_++;
1121       StartRead(&response_);
1122     }
1123   }
1124   void OnWriteDone(bool ok) override {
1125     if (async_write_thread_.joinable()) {
1126       async_write_thread_.join();
1127       RemoveHold();
1128     }
1129     if (server_try_cancel_ == DO_NOT_CANCEL) {
1130       EXPECT_TRUE(ok);
1131     } else if (!ok) {
1132       return;
1133     }
1134     writes_complete_++;
1135     MaybeWrite();
1136   }
1137   void OnDone(const Status& s) override {
1138     gpr_log(GPR_INFO, "Sent %d messages", writes_complete_);
1139     gpr_log(GPR_INFO, "Read %d messages", reads_complete_);
1140     switch (server_try_cancel_) {
1141       case DO_NOT_CANCEL:
1142         if (!client_cancel_.cancel ||
1143             client_cancel_.ops_before_cancel > msgs_to_send_) {
1144           EXPECT_TRUE(s.ok());
1145           EXPECT_EQ(writes_complete_, msgs_to_send_);
1146           EXPECT_EQ(reads_complete_, writes_complete_);
1147         } else {
1148           EXPECT_FALSE(s.ok());
1149           EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
1150           EXPECT_EQ(writes_complete_, client_cancel_.ops_before_cancel);
1151           EXPECT_LE(reads_complete_, writes_complete_);
1152         }
1153         break;
1154       case CANCEL_BEFORE_PROCESSING:
1155         EXPECT_FALSE(s.ok());
1156         EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
1157         // The RPC is canceled before the server did any work or returned any
1158         // reads, but it's possible that some writes took place first from the
1159         // client
1160         EXPECT_LE(writes_complete_, msgs_to_send_);
1161         EXPECT_EQ(reads_complete_, 0);
1162         break;
1163       case CANCEL_DURING_PROCESSING:
1164         EXPECT_FALSE(s.ok());
1165         EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
1166         EXPECT_LE(writes_complete_, msgs_to_send_);
1167         EXPECT_LE(reads_complete_, writes_complete_);
1168         break;
1169       case CANCEL_AFTER_PROCESSING:
1170         EXPECT_FALSE(s.ok());
1171         EXPECT_EQ(grpc::StatusCode::CANCELLED, s.error_code());
1172         EXPECT_EQ(writes_complete_, msgs_to_send_);
1173         // The Server canceled after reading the last message and after writing
1174         // the message to the client. However, the RPC cancellation might have
1175         // taken effect before the client actually read the response.
1176         EXPECT_LE(reads_complete_, writes_complete_);
1177         break;
1178       default:
1179         assert(false);
1180     }
1181     std::unique_lock<std::mutex> l(mu_);
1182     done_ = true;
1183     cv_.notify_one();
1184   }
1185   void Await() {
1186     std::unique_lock<std::mutex> l(mu_);
1187     while (!done_) {
1188       cv_.wait(l);
1189     }
1190   }
1191
1192  private:
1193   void MaybeAsyncWrite(bool first_write_async) {
1194     if (first_write_async) {
1195       // Make sure that we have a write to issue.
1196       // TODO(vjpai): Make this work with 0 writes case as well.
1197       assert(msgs_to_send_ >= 1);
1198
1199       AddHold();
1200       async_write_thread_ = std::thread([this] {
1201         std::unique_lock<std::mutex> lock(async_write_thread_mu_);
1202         async_write_thread_cv_.wait(
1203             lock, [this] { return async_write_thread_start_; });
1204         MaybeWrite();
1205       });
1206       std::lock_guard<std::mutex> lock(async_write_thread_mu_);
1207       async_write_thread_start_ = true;
1208       async_write_thread_cv_.notify_one();
1209       return;
1210     }
1211     MaybeWrite();
1212   }
1213   void MaybeWrite() {
1214     if (client_cancel_.cancel &&
1215         writes_complete_ == client_cancel_.ops_before_cancel) {
1216       context_.TryCancel();
1217     } else if (writes_complete_ == msgs_to_send_) {
1218       StartWritesDone();
1219     } else {
1220       StartWrite(&request_);
1221     }
1222   }
1223   EchoRequest request_;
1224   EchoResponse response_;
1225   ClientContext context_;
1226   const ServerTryCancelRequestPhase server_try_cancel_;
1227   int reads_complete_{0};
1228   int writes_complete_{0};
1229   const int msgs_to_send_;
1230   const ClientCancelInfo client_cancel_;
1231   std::mutex mu_;
1232   std::condition_variable cv_;
1233   bool done_ = false;
1234   std::thread async_write_thread_;
1235   bool async_write_thread_start_ = false;
1236   std::mutex async_write_thread_mu_;
1237   std::condition_variable async_write_thread_cv_;
1238 };
1239
1240 TEST_P(ClientCallbackEnd2endTest, BidiStream) {
1241   MAYBE_SKIP_TEST;
1242   ResetStub();
1243   BidiClient test(stub_.get(), DO_NOT_CANCEL,
1244                   kServerDefaultResponseStreamsToSend,
1245                   /*cork_metadata=*/false, /*first_write_async=*/false);
1246   test.Await();
1247   // Make sure that the server interceptors were not notified of a cancel
1248   if (GetParam().use_interceptors) {
1249     EXPECT_EQ(0, DummyInterceptor::GetNumTimesCancel());
1250   }
1251 }
1252
1253 TEST_P(ClientCallbackEnd2endTest, BidiStreamFirstWriteAsync) {
1254   MAYBE_SKIP_TEST;
1255   ResetStub();
1256   BidiClient test(stub_.get(), DO_NOT_CANCEL,
1257                   kServerDefaultResponseStreamsToSend,
1258                   /*cork_metadata=*/false, /*first_write_async=*/true);
1259   test.Await();
1260   // Make sure that the server interceptors were not notified of a cancel
1261   if (GetParam().use_interceptors) {
1262     EXPECT_EQ(0, DummyInterceptor::GetNumTimesCancel());
1263   }
1264 }
1265
1266 TEST_P(ClientCallbackEnd2endTest, BidiStreamCorked) {
1267   MAYBE_SKIP_TEST;
1268   ResetStub();
1269   BidiClient test(stub_.get(), DO_NOT_CANCEL,
1270                   kServerDefaultResponseStreamsToSend,
1271                   /*cork_metadata=*/true, /*first_write_async=*/false);
1272   test.Await();
1273   // Make sure that the server interceptors were not notified of a cancel
1274   if (GetParam().use_interceptors) {
1275     EXPECT_EQ(0, DummyInterceptor::GetNumTimesCancel());
1276   }
1277 }
1278
1279 TEST_P(ClientCallbackEnd2endTest, BidiStreamCorkedFirstWriteAsync) {
1280   MAYBE_SKIP_TEST;
1281   ResetStub();
1282   BidiClient test(stub_.get(), DO_NOT_CANCEL,
1283                   kServerDefaultResponseStreamsToSend,
1284                   /*cork_metadata=*/true, /*first_write_async=*/true);
1285   test.Await();
1286   // Make sure that the server interceptors were not notified of a cancel
1287   if (GetParam().use_interceptors) {
1288     EXPECT_EQ(0, DummyInterceptor::GetNumTimesCancel());
1289   }
1290 }
1291
1292 TEST_P(ClientCallbackEnd2endTest, ClientCancelsBidiStream) {
1293   MAYBE_SKIP_TEST;
1294   ResetStub();
1295   BidiClient test(stub_.get(), DO_NOT_CANCEL,
1296                   kServerDefaultResponseStreamsToSend,
1297                   /*cork_metadata=*/false, /*first_write_async=*/false,
1298                   ClientCancelInfo(2));
1299   test.Await();
1300   // Make sure that the server interceptors were notified of a cancel
1301   if (GetParam().use_interceptors) {
1302     EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
1303   }
1304 }
1305
1306 // Server to cancel before reading/writing any requests/responses on the stream
1307 TEST_P(ClientCallbackEnd2endTest, BidiStreamServerCancelBefore) {
1308   MAYBE_SKIP_TEST;
1309   ResetStub();
1310   BidiClient test(stub_.get(), CANCEL_BEFORE_PROCESSING, /*num_msgs_to_send=*/2,
1311                   /*cork_metadata=*/false, /*first_write_async=*/false);
1312   test.Await();
1313   // Make sure that the server interceptors were notified
1314   if (GetParam().use_interceptors) {
1315     EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
1316   }
1317 }
1318
1319 // Server to cancel while reading/writing requests/responses on the stream in
1320 // parallel
1321 TEST_P(ClientCallbackEnd2endTest, BidiStreamServerCancelDuring) {
1322   MAYBE_SKIP_TEST;
1323   ResetStub();
1324   BidiClient test(stub_.get(), CANCEL_DURING_PROCESSING,
1325                   /*num_msgs_to_send=*/10, /*cork_metadata=*/false,
1326                   /*first_write_async=*/false);
1327   test.Await();
1328   // Make sure that the server interceptors were notified
1329   if (GetParam().use_interceptors) {
1330     EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
1331   }
1332 }
1333
1334 // Server to cancel after reading/writing all requests/responses on the stream
1335 // but before returning to the client
1336 TEST_P(ClientCallbackEnd2endTest, BidiStreamServerCancelAfter) {
1337   MAYBE_SKIP_TEST;
1338   ResetStub();
1339   BidiClient test(stub_.get(), CANCEL_AFTER_PROCESSING, /*num_msgs_to_send=*/5,
1340                   /*cork_metadata=*/false, /*first_write_async=*/false);
1341   test.Await();
1342   // Make sure that the server interceptors were notified
1343   if (GetParam().use_interceptors) {
1344     EXPECT_EQ(20, DummyInterceptor::GetNumTimesCancel());
1345   }
1346 }
1347
1348 TEST_P(ClientCallbackEnd2endTest, SimultaneousReadAndWritesDone) {
1349   MAYBE_SKIP_TEST;
1350   ResetStub();
1351   class Client : public grpc::experimental::ClientBidiReactor<EchoRequest,
1352                                                               EchoResponse> {
1353    public:
1354     explicit Client(grpc::testing::EchoTestService::Stub* stub) {
1355       request_.set_message("Hello bidi ");
1356       stub->experimental_async()->BidiStream(&context_, this);
1357       StartWrite(&request_);
1358       StartCall();
1359     }
1360     void OnReadDone(bool ok) override {
1361       EXPECT_TRUE(ok);
1362       EXPECT_EQ(response_.message(), request_.message());
1363     }
1364     void OnWriteDone(bool ok) override {
1365       EXPECT_TRUE(ok);
1366       // Now send out the simultaneous Read and WritesDone
1367       StartWritesDone();
1368       StartRead(&response_);
1369     }
1370     void OnDone(const Status& s) override {
1371       EXPECT_TRUE(s.ok());
1372       EXPECT_EQ(response_.message(), request_.message());
1373       std::unique_lock<std::mutex> l(mu_);
1374       done_ = true;
1375       cv_.notify_one();
1376     }
1377     void Await() {
1378       std::unique_lock<std::mutex> l(mu_);
1379       while (!done_) {
1380         cv_.wait(l);
1381       }
1382     }
1383
1384    private:
1385     EchoRequest request_;
1386     EchoResponse response_;
1387     ClientContext context_;
1388     std::mutex mu_;
1389     std::condition_variable cv_;
1390     bool done_ = false;
1391   } test{stub_.get()};
1392
1393   test.Await();
1394 }
1395
1396 TEST_P(ClientCallbackEnd2endTest, UnimplementedRpc) {
1397   MAYBE_SKIP_TEST;
1398   ChannelArguments args;
1399   const auto& channel_creds = GetCredentialsProvider()->GetChannelCredentials(
1400       GetParam().credentials_type, &args);
1401   std::shared_ptr<Channel> channel =
1402       (GetParam().protocol == Protocol::TCP)
1403           ? ::grpc::CreateCustomChannel(server_address_.str(), channel_creds,
1404                                         args)
1405           : server_->InProcessChannel(args);
1406   std::unique_ptr<grpc::testing::UnimplementedEchoService::Stub> stub;
1407   stub = grpc::testing::UnimplementedEchoService::NewStub(channel);
1408   EchoRequest request;
1409   EchoResponse response;
1410   ClientContext cli_ctx;
1411   request.set_message("Hello world.");
1412   std::mutex mu;
1413   std::condition_variable cv;
1414   bool done = false;
1415   stub->experimental_async()->Unimplemented(
1416       &cli_ctx, &request, &response, [&done, &mu, &cv](Status s) {
1417         EXPECT_EQ(StatusCode::UNIMPLEMENTED, s.error_code());
1418         EXPECT_EQ("", s.error_message());
1419
1420         std::lock_guard<std::mutex> l(mu);
1421         done = true;
1422         cv.notify_one();
1423       });
1424   std::unique_lock<std::mutex> l(mu);
1425   while (!done) {
1426     cv.wait(l);
1427   }
1428 }
1429
1430 TEST_P(ClientCallbackEnd2endTest,
1431        ResponseStreamExtraReactionFlowReadsUntilDone) {
1432   MAYBE_SKIP_TEST;
1433   ResetStub();
1434   class ReadAllIncomingDataClient
1435       : public grpc::experimental::ClientReadReactor<EchoResponse> {
1436    public:
1437     explicit ReadAllIncomingDataClient(
1438         grpc::testing::EchoTestService::Stub* stub) {
1439       request_.set_message("Hello client ");
1440       stub->experimental_async()->ResponseStream(&context_, &request_, this);
1441     }
1442     bool WaitForReadDone() {
1443       std::unique_lock<std::mutex> l(mu_);
1444       while (!read_done_) {
1445         read_cv_.wait(l);
1446       }
1447       read_done_ = false;
1448       return read_ok_;
1449     }
1450     void Await() {
1451       std::unique_lock<std::mutex> l(mu_);
1452       while (!done_) {
1453         done_cv_.wait(l);
1454       }
1455     }
1456     // RemoveHold under the same lock used for OnDone to make sure that we don't
1457     // call OnDone directly or indirectly from the RemoveHold function.
1458     void RemoveHoldUnderLock() {
1459       std::unique_lock<std::mutex> l(mu_);
1460       RemoveHold();
1461     }
1462     const Status& status() {
1463       std::unique_lock<std::mutex> l(mu_);
1464       return status_;
1465     }
1466
1467    private:
1468     void OnReadDone(bool ok) override {
1469       std::unique_lock<std::mutex> l(mu_);
1470       read_ok_ = ok;
1471       read_done_ = true;
1472       read_cv_.notify_one();
1473     }
1474     void OnDone(const Status& s) override {
1475       std::unique_lock<std::mutex> l(mu_);
1476       done_ = true;
1477       status_ = s;
1478       done_cv_.notify_one();
1479     }
1480
1481     EchoRequest request_;
1482     EchoResponse response_;
1483     ClientContext context_;
1484     bool read_ok_ = false;
1485     bool read_done_ = false;
1486     std::mutex mu_;
1487     std::condition_variable read_cv_;
1488     std::condition_variable done_cv_;
1489     bool done_ = false;
1490     Status status_;
1491   } client{stub_.get()};
1492
1493   int reads_complete = 0;
1494   client.AddHold();
1495   client.StartCall();
1496
1497   EchoResponse response;
1498   bool read_ok = true;
1499   while (read_ok) {
1500     client.StartRead(&response);
1501     read_ok = client.WaitForReadDone();
1502     if (read_ok) {
1503       ++reads_complete;
1504     }
1505   }
1506   client.RemoveHoldUnderLock();
1507   client.Await();
1508
1509   EXPECT_EQ(kServerDefaultResponseStreamsToSend, reads_complete);
1510   EXPECT_EQ(client.status().error_code(), grpc::StatusCode::OK);
1511 }
1512
1513 std::vector<TestScenario> CreateTestScenarios(bool test_insecure) {
1514 #if TARGET_OS_IPHONE
1515   // Workaround Apple CFStream bug
1516   gpr_setenv("grpc_cfstream", "0");
1517 #endif
1518
1519   std::vector<TestScenario> scenarios;
1520   std::vector<std::string> credentials_types{
1521       GetCredentialsProvider()->GetSecureCredentialsTypeList()};
1522   auto insec_ok = [] {
1523     // Only allow insecure credentials type when it is registered with the
1524     // provider. User may create providers that do not have insecure.
1525     return GetCredentialsProvider()->GetChannelCredentials(
1526                kInsecureCredentialsType, nullptr) != nullptr;
1527   };
1528   if (test_insecure && insec_ok()) {
1529     credentials_types.push_back(kInsecureCredentialsType);
1530   }
1531   GPR_ASSERT(!credentials_types.empty());
1532
1533   bool barr[]{false, true};
1534   Protocol parr[]{Protocol::INPROC, Protocol::TCP};
1535   for (Protocol p : parr) {
1536     for (const auto& cred : credentials_types) {
1537       // TODO(vjpai): Test inproc with secure credentials when feasible
1538       if (p == Protocol::INPROC &&
1539           (cred != kInsecureCredentialsType || !insec_ok())) {
1540         continue;
1541       }
1542       for (bool callback_server : barr) {
1543         for (bool use_interceptors : barr) {
1544           scenarios.emplace_back(callback_server, p, use_interceptors, cred);
1545         }
1546       }
1547     }
1548   }
1549   return scenarios;
1550 }
1551
1552 INSTANTIATE_TEST_SUITE_P(ClientCallbackEnd2endTest, ClientCallbackEnd2endTest,
1553                          ::testing::ValuesIn(CreateTestScenarios(true)));
1554
1555 }  // namespace
1556 }  // namespace testing
1557 }  // namespace grpc
1558
1559 int main(int argc, char** argv) {
1560   ::testing::InitGoogleTest(&argc, argv);
1561   grpc::testing::TestEnvironment env(argc, argv);
1562   grpc_init();
1563   int ret = RUN_ALL_TESTS();
1564   grpc_shutdown();
1565   return ret;
1566 }