3 * Copyright 2018 gRPC authors.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
22 #include <grpcpp/channel.h>
23 #include <grpcpp/client_context.h>
24 #include <grpcpp/create_channel.h>
25 #include <grpcpp/generic/generic_stub.h>
26 #include <grpcpp/impl/codegen/proto_utils.h>
27 #include <grpcpp/server.h>
28 #include <grpcpp/server_builder.h>
29 #include <grpcpp/server_context.h>
30 #include <grpcpp/support/client_interceptor.h>
32 #include "src/proto/grpc/testing/echo.grpc.pb.h"
33 #include "test/core/util/port.h"
34 #include "test/core/util/test_config.h"
35 #include "test/cpp/end2end/interceptors_util.h"
36 #include "test/cpp/end2end/test_service_impl.h"
37 #include "test/cpp/util/byte_buffer_proto_helper.h"
38 #include "test/cpp/util/string_ref_helper.h"
40 #include <gtest/gtest.h>
52 kAsyncCQClientStreaming,
53 kAsyncCQServerStreaming,
54 kAsyncCQBidiStreaming,
57 /* Hijacks Echo RPC and fills in the expected values */
// Client interceptor used by the unary hijacking tests. The constructor checks
// that it is attached to the unary /grpc.testing.EchoTestService/Echo method.
// Intercept() inspects every hook point present in the batch: on the send side
// it validates what the application supplied, on the PRE_RECV_* side it
// fabricates the response data, and on the POST_RECV_* side it validates (and
// for the message, rewrites) what was fabricated.
58 class HijackingInterceptor : public experimental::Interceptor {
60 HijackingInterceptor(experimental::ClientRpcInfo* info) {
62 // Make sure it is the right method
63 EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0);
64 EXPECT_EQ(info->type(), experimental::ClientRpcInfo::Type::UNARY);
// Called once per intercepted batch; `methods` exposes the batch contents and
// which hook points are active.
67 virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
// Outgoing initial metadata must contain exactly the testkey/testvalue pair.
69 if (methods->QueryInterceptionHookPoint(
70 experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
71 auto* map = methods->GetSendInitialMetadata();
72 // Check that we can see the test metadata
73 ASSERT_EQ(map->size(), static_cast<unsigned>(1));
74 auto iterator = map->begin();
75 EXPECT_EQ("testkey", iterator->first);
76 EXPECT_EQ("testvalue", iterator->second);
// Outgoing request: deserialize a copy of the serialized buffer and expect
// the message "Hello".
79 if (methods->QueryInterceptionHookPoint(
80 experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
82 auto* buffer = methods->GetSerializedSendMessage();
// NOTE(review): the copy is presumably made so deserialization does not
// disturb the buffer owned by the batch -- confirm.
83 auto copied_buffer = *buffer;
85 SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
87 EXPECT_EQ(req.message(), "Hello");
89 if (methods->QueryInterceptionHookPoint(
90 experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
91 // Got nothing to do here for now
// Received initial metadata is expected to be empty.
93 if (methods->QueryInterceptionHookPoint(
94 experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
95 auto* map = methods->GetRecvInitialMetadata();
96 // Got nothing better to do here for now
97 EXPECT_EQ(map->size(), static_cast<unsigned>(0));
// The response must carry the "Hello1" placeholder written at
// PRE_RECV_MESSAGE below; it is then replaced with "Hello".
99 if (methods->QueryInterceptionHookPoint(
100 experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
102 static_cast<EchoResponse*>(methods->GetRecvMessage());
103 // Check that we got the hijacked message, and re-insert the expected
105 EXPECT_EQ(resp->message(), "Hello1");
106 resp->set_message("Hello");
// Trailing metadata must contain the testkey/testvalue pair inserted at
// PRE_RECV_STATUS below, and the status must be OK.
108 if (methods->QueryInterceptionHookPoint(
109 experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
110 auto* map = methods->GetRecvTrailingMetadata();
112 // Check that we received the metadata as an echo
113 for (const auto& pair : *map) {
114 found = pair.first.starts_with("testkey") &&
115 pair.second.starts_with("testvalue");
118 EXPECT_EQ(found, true);
119 auto* status = methods->GetRecvStatus();
120 EXPECT_EQ(status->ok(), true);
// Before initial metadata is delivered the receive map is expected to be
// empty; nothing is added to it.
122 if (methods->QueryInterceptionHookPoint(
123 experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) {
124 auto* map = methods->GetRecvInitialMetadata();
125 // Got nothing better to do here at the moment
126 EXPECT_EQ(map->size(), static_cast<unsigned>(0));
// Fabricate the response message with a deliberately different value
// ("Hello1"); POST_RECV_MESSAGE above checks for it.
128 if (methods->QueryInterceptionHookPoint(
129 experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
130 // Insert a different message than expected
132 static_cast<EchoResponse*>(methods->GetRecvMessage());
133 resp->set_message("Hello1");
// Fabricate the trailing metadata and an OK status.
135 if (methods->QueryInterceptionHookPoint(
136 experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
137 auto* map = methods->GetRecvTrailingMetadata();
138 // insert the metadata that we want
139 EXPECT_EQ(map->size(), static_cast<unsigned>(0));
140 map->insert(std::make_pair("testkey", "testvalue"));
141 auto* status = methods->GetRecvStatus();
142 *status = Status(StatusCode::OK, "");
// RPC info pointer kept as a member.
152 experimental::ClientRpcInfo* info_;
// Factory that creates a new HijackingInterceptor for the given RPC info. The
// interceptor is returned as a raw pointer allocated with `new`.
155 class HijackingInterceptorFactory
156 : public experimental::ClientInterceptorFactoryInterface {
158 virtual experimental::Interceptor* CreateClientInterceptor(
159 experimental::ClientRpcInfo* info) override {
160 return new HijackingInterceptor(info);
// Client interceptor for the Echo method that, while handling
// PRE_SEND_MESSAGE, issues a second Echo RPC through a stub created on the
// intercepted channel. The response of that second RPC (resp_) is what gets
// copied into the application's response at PRE_RECV_MESSAGE.
164 class HijackingInterceptorMakesAnotherCall : public experimental::Interceptor {
166 HijackingInterceptorMakesAnotherCall(experimental::ClientRpcInfo* info) {
168 // Make sure it is the right method
169 EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0);
// Called once per intercepted batch; acts on each hook point that is present.
172 virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
// Validate the outgoing metadata and keep a copy so it can be attached to the
// second RPC's context later.
173 if (methods->QueryInterceptionHookPoint(
174 experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
175 auto* map = methods->GetSendInitialMetadata();
176 // Check that we can see the test metadata
177 ASSERT_EQ(map->size(), static_cast<unsigned>(1));
178 auto iterator = map->begin();
179 EXPECT_EQ("testkey", iterator->first);
180 EXPECT_EQ("testvalue", iterator->second);
181 // Make a copy of the map
182 metadata_map_ = *map;
// Validate the outgoing request, then start the second Echo RPC on the
// intercepted channel using the callback (experimental_async) API.
184 if (methods->QueryInterceptionHookPoint(
185 experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
187 auto* buffer = methods->GetSerializedSendMessage();
188 auto copied_buffer = *buffer;
190 SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
192 EXPECT_EQ(req.message(), "Hello");
194 stub_ = grpc::testing::EchoTestService::NewStub(
195 methods->GetInterceptedChannel());
// Forward the first saved metadata entry on the new call.
196 ctx_.AddMetadata(metadata_map_.begin()->first,
197 metadata_map_.begin()->second);
// The completion callback captures `this` and `methods`; it checks that the
// inner RPC succeeded and echoed "Hello".
198 stub_->experimental_async()->Echo(&ctx_, &req_, &resp_,
199 [this, methods](Status s) {
200 EXPECT_EQ(s.ok(), true);
201 EXPECT_EQ(resp_.message(), "Hello");
204 // This is a Unary RPC and we have got nothing interesting to do in the
205 // PRE_SEND_CLOSE interception hook point for this interceptor, so let's
206 // return here. (We do not want to call methods->Proceed(). When the new
207 // RPC returns, we will call methods->Hijack() instead.)
210 if (methods->QueryInterceptionHookPoint(
211 experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
212 // Got nothing to do here for now
// Received initial metadata is expected to be empty.
214 if (methods->QueryInterceptionHookPoint(
215 experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
216 auto* map = methods->GetRecvInitialMetadata();
217 // Got nothing better to do here for now
218 EXPECT_EQ(map->size(), static_cast<unsigned>(0));
// The delivered response must already read "Hello" (copied from resp_ at
// PRE_RECV_MESSAGE below).
220 if (methods->QueryInterceptionHookPoint(
221 experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
223 static_cast<EchoResponse*>(methods->GetRecvMessage());
224 // Check that we got the hijacked message, and re-insert the expected
226 EXPECT_EQ(resp->message(), "Hello");
// Trailing metadata must contain testkey/testvalue and the status must be OK.
228 if (methods->QueryInterceptionHookPoint(
229 experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
230 auto* map = methods->GetRecvTrailingMetadata();
232 // Check that we received the metadata as an echo
233 for (const auto& pair : *map) {
234 found = pair.first.starts_with("testkey") &&
235 pair.second.starts_with("testvalue");
238 EXPECT_EQ(found, true);
239 auto* status = methods->GetRecvStatus();
240 EXPECT_EQ(status->ok(), true);
// Before initial metadata is delivered the receive map is expected to be
// empty.
242 if (methods->QueryInterceptionHookPoint(
243 experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) {
244 auto* map = methods->GetRecvInitialMetadata();
245 // Got nothing better to do here at the moment
246 EXPECT_EQ(map->size(), static_cast<unsigned>(0));
// Hand the inner RPC's response message to the application's response object.
248 if (methods->QueryInterceptionHookPoint(
249 experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
250 // Insert a different message than expected
252 static_cast<EchoResponse*>(methods->GetRecvMessage());
253 resp->set_message(resp_.message());
// Fabricate the trailing metadata and an OK status.
255 if (methods->QueryInterceptionHookPoint(
256 experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
257 auto* map = methods->GetRecvTrailingMetadata();
258 // insert the metadata that we want
259 EXPECT_EQ(map->size(), static_cast<unsigned>(0));
260 map->insert(std::make_pair("testkey", "testvalue"));
261 auto* status = methods->GetRecvStatus();
262 *status = Status(StatusCode::OK, "");
// RPC info pointer kept as a member.
269 experimental::ClientRpcInfo* info_;
// Copy of the outgoing initial metadata taken at PRE_SEND_INITIAL_METADATA.
270 std::multimap<std::string, std::string> metadata_map_;
// Stub used for the second RPC; created lazily at PRE_SEND_MESSAGE.
274 std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
// Factory that creates a new HijackingInterceptorMakesAnotherCall for the
// given RPC info, returned as a raw pointer allocated with `new`.
277 class HijackingInterceptorMakesAnotherCallFactory
278 : public experimental::ClientInterceptorFactoryInterface {
280 virtual experimental::Interceptor* CreateClientInterceptor(
281 experimental::ClientRpcInfo* info) override {
282 return new HijackingInterceptorMakesAnotherCall(info);
// Client interceptor for the bidi-streaming hijacking scenario. It validates
// outgoing metadata and messages, writes the response message at
// PRE_RECV_MESSAGE, and fabricates trailing metadata plus an OK status at
// PRE_RECV_STATUS.
286 class BidiStreamingRpcHijackingInterceptor : public experimental::Interceptor {
288 BidiStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
// Called once per intercepted batch; acts on each hook point that is present.
292 virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
// Outgoing initial metadata is checked for the testkey/testvalue pair via the
// CheckMetadata helper.
294 if (methods->QueryInterceptionHookPoint(
295 experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
296 CheckMetadata(*methods->GetSendInitialMetadata(), "testkey", "testvalue");
// Each outgoing message is deserialized from a copy of the serialized buffer
// and must start with "Hello".
299 if (methods->QueryInterceptionHookPoint(
300 experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
302 auto* buffer = methods->GetSerializedSendMessage();
303 auto copied_buffer = *buffer;
305 SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
307 EXPECT_EQ(req.message().find("Hello"), 0u);
310 if (methods->QueryInterceptionHookPoint(
311 experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
312 // Got nothing to do here for now
// Trailing metadata is checked via CheckMetadata and the status must be OK.
314 if (methods->QueryInterceptionHookPoint(
315 experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
316 CheckMetadata(*methods->GetRecvTrailingMetadata(), "testkey",
318 auto* status = methods->GetRecvStatus();
319 EXPECT_EQ(status->ok(), true);
// Write the response message into the application's response object.
321 if (methods->QueryInterceptionHookPoint(
322 experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
324 static_cast<EchoResponse*>(methods->GetRecvMessage());
325 resp->set_message(msg);
// Check the message that was delivered.
327 if (methods->QueryInterceptionHookPoint(
328 experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
329 EXPECT_EQ(static_cast<EchoResponse*>(methods->GetRecvMessage())
// Fabricate the trailing metadata and an OK status.
334 if (methods->QueryInterceptionHookPoint(
335 experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
336 auto* map = methods->GetRecvTrailingMetadata();
337 // insert the metadata that we want
338 EXPECT_EQ(map->size(), static_cast<unsigned>(0));
339 map->insert(std::make_pair("testkey", "testvalue"));
340 auto* status = methods->GetRecvStatus();
341 *status = Status(StatusCode::OK, "");
// RPC info pointer kept as a member.
351 experimental::ClientRpcInfo* info_;
// Client interceptor for the client-streaming hijacking scenario. It can fail
// a hijacked send via FailHijackedSendMessage(), records in a class-wide
// static flag whether a failed send was observed, and finishes the RPC with an
// UNAVAILABLE status.
355 class ClientStreamingRpcHijackingInterceptor
356 : public experimental::Interceptor {
358 ClientStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
// Called once per intercepted batch; acts on each hook point that is present.
361 virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
363 if (methods->QueryInterceptionHookPoint(
364 experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
367 if (methods->QueryInterceptionHookPoint(
368 experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
370 methods->FailHijackedSendMessage();
// After a send completes: no failure may have been seen before, then record
// whether this send reported failure.
373 if (methods->QueryInterceptionHookPoint(
374 experimental::InterceptionHookPoints::POST_SEND_MESSAGE)) {
375 EXPECT_FALSE(got_failed_send_);
376 got_failed_send_ = !methods->GetSendMessageStatus();
// Fabricate a non-OK final status for the hijacked RPC.
378 if (methods->QueryInterceptionHookPoint(
379 experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
380 auto* status = methods->GetRecvStatus();
381 *status = Status(StatusCode::UNAVAILABLE, "Done sending 10 messages");
// Returns whether any instance has observed a failed send.
390 static bool GotFailedSend() { return got_failed_send_; }
// RPC info pointer kept as a member.
393 experimental::ClientRpcInfo* info_;
// Shared across all instances (static); set at POST_SEND_MESSAGE.
395 static bool got_failed_send_;
// Out-of-class definition of the static flag; starts out false.
398 bool ClientStreamingRpcHijackingInterceptor::got_failed_send_ = false;
// Factory that creates a new ClientStreamingRpcHijackingInterceptor for the
// given RPC info, returned as a raw pointer allocated with `new`.
400 class ClientStreamingRpcHijackingInterceptorFactory
401 : public experimental::ClientInterceptorFactoryInterface {
403 virtual experimental::Interceptor* CreateClientInterceptor(
404 experimental::ClientRpcInfo* info) override {
405 return new ClientStreamingRpcHijackingInterceptor(info);
// Client interceptor for the server-streaming hijacking scenario. It validates
// the outgoing request, supplies "Hello" response messages (or fails the
// hijacked receive via FailHijackedRecvMessage()), records in a class-wide
// static flag whether a failed (null) message was observed, and fabricates
// trailing metadata plus an OK status.
409 class ServerStreamingRpcHijackingInterceptor
410 : public experimental::Interceptor {
412 ServerStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
// Each new interceptor instance resets the shared failure flag.
414 got_failed_message_ = false;
// Called once per intercepted batch; acts on each hook point that is present.
417 virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
// Outgoing initial metadata must contain exactly the testkey/testvalue pair.
419 if (methods->QueryInterceptionHookPoint(
420 experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
421 auto* map = methods->GetSendInitialMetadata();
422 // Check that we can see the test metadata
423 ASSERT_EQ(map->size(), static_cast<unsigned>(1));
424 auto iterator = map->begin();
425 EXPECT_EQ("testkey", iterator->first);
426 EXPECT_EQ("testvalue", iterator->second);
// Outgoing request: deserialize a copy of the serialized buffer and expect
// the message "Hello".
429 if (methods->QueryInterceptionHookPoint(
430 experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
432 auto* buffer = methods->GetSerializedSendMessage();
433 auto copied_buffer = *buffer;
435 SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
437 EXPECT_EQ(req.message(), "Hello");
439 if (methods->QueryInterceptionHookPoint(
440 experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
441 // Got nothing to do here for now
// Trailing metadata must contain testkey/testvalue and the status must be OK.
443 if (methods->QueryInterceptionHookPoint(
444 experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
445 auto* map = methods->GetRecvTrailingMetadata();
447 // Check that we received the metadata as an echo
448 for (const auto& pair : *map) {
449 found = pair.first.starts_with("testkey") &&
450 pair.second.starts_with("testvalue");
453 EXPECT_EQ(found, true);
454 auto* status = methods->GetRecvStatus();
455 EXPECT_EQ(status->ok(), true);
// Either fail the hijacked receive or fill the response with "Hello".
457 if (methods->QueryInterceptionHookPoint(
458 experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
460 methods->FailHijackedRecvMessage();
463 static_cast<EchoResponse*>(methods->GetRecvMessage());
464 resp->set_message("Hello");
// A null received message is treated as the failed one; no failure may have
// been recorded before this point.
466 if (methods->QueryInterceptionHookPoint(
467 experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
468 // Only the last message will be a failure
469 EXPECT_FALSE(got_failed_message_);
470 got_failed_message_ = methods->GetRecvMessage() == nullptr;
// Fabricate the trailing metadata and an OK status.
472 if (methods->QueryInterceptionHookPoint(
473 experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
474 auto* map = methods->GetRecvTrailingMetadata();
475 // insert the metadata that we want
476 EXPECT_EQ(map->size(), static_cast<unsigned>(0));
477 map->insert(std::make_pair("testkey", "testvalue"));
478 auto* status = methods->GetRecvStatus();
479 *status = Status(StatusCode::OK, "");
// Returns whether a failed (null) received message has been observed.
488 static bool GotFailedMessage() { return got_failed_message_; }
// RPC info pointer kept as a member.
491 experimental::ClientRpcInfo* info_;
// Shared across all instances (static); reset by the constructor.
492 static bool got_failed_message_;
// Out-of-class definition of the static flag; starts out false.
496 bool ServerStreamingRpcHijackingInterceptor::got_failed_message_ = false;
// Factory that creates a new ServerStreamingRpcHijackingInterceptor for the
// given RPC info, returned as a raw pointer allocated with `new`.
498 class ServerStreamingRpcHijackingInterceptorFactory
499 : public experimental::ClientInterceptorFactoryInterface {
501 virtual experimental::Interceptor* CreateClientInterceptor(
502 experimental::ClientRpcInfo* info) override {
503 return new ServerStreamingRpcHijackingInterceptor(info);
// Factory that creates a new BidiStreamingRpcHijackingInterceptor for the
// given RPC info, returned as a raw pointer allocated with `new`.
507 class BidiStreamingRpcHijackingInterceptorFactory
508 : public experimental::ClientInterceptorFactoryInterface {
510 virtual experimental::Interceptor* CreateClientInterceptor(
511 experimental::ClientRpcInfo* info) override {
512 return new BidiStreamingRpcHijackingInterceptor(info);
516 // The logging interceptor is for testing purposes only. It is used to verify
517 // that all the appropriate hook points are invoked for an RPC. The counts are
518 // reset each time a new object of LoggingInterceptor is created, so only a
519 // single RPC should be made on the channel before calling the Verify methods.
//
// All bookkeeping lives in static members, so the static Verify* functions can
// be called by a test without access to the interceptor instance.
520 class LoggingInterceptor : public experimental::Interceptor {
// Resets every static flag and counter.
522 LoggingInterceptor(experimental::ClientRpcInfo* /*info*/) {
523 pre_send_initial_metadata_ = false;
524 pre_send_message_count_ = 0;
525 pre_send_close_ = false;
526 post_recv_initial_metadata_ = false;
527 post_recv_message_count_ = 0;
528 post_recv_status_ = false;
// Called once per intercepted batch; validates the data at each hook point
// that is present and records that the hook point was seen.
531 virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
// Initial metadata must be exactly testkey/testvalue and this hook point must
// not have been seen before.
532 if (methods->QueryInterceptionHookPoint(
533 experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
534 auto* map = methods->GetSendInitialMetadata();
535 // Check that we can see the test metadata
536 ASSERT_EQ(map->size(), static_cast<unsigned>(1));
537 auto iterator = map->begin();
538 EXPECT_EQ("testkey", iterator->first);
539 EXPECT_EQ("testvalue", iterator->second);
540 ASSERT_FALSE(pre_send_initial_metadata_);
541 pre_send_initial_metadata_ = true;
// Outgoing message: check the non-serialized form when available, otherwise
// fall back to deserializing the serialized buffer; then count the message.
543 if (methods->QueryInterceptionHookPoint(
544 experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
546 auto* send_msg = methods->GetSendMessage();
547 if (send_msg == nullptr) {
548 // We did not get the non-serialized form of the message. Get the
550 auto* buffer = methods->GetSerializedSendMessage();
551 auto copied_buffer = *buffer;
554 SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
556 EXPECT_EQ(req.message(), "Hello");
559 static_cast<const EchoRequest*>(send_msg)->message().find("Hello"),
// The serialized form is checked as well: it must start with "Hello".
562 auto* buffer = methods->GetSerializedSendMessage();
563 auto copied_buffer = *buffer;
565 SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
567 EXPECT_TRUE(req.message().find("Hello") == 0u);
568 pre_send_message_count_++;
570 if (methods->QueryInterceptionHookPoint(
571 experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
572 // Got nothing to do here for now
573 pre_send_close_ = true;
// Received initial metadata is expected to be empty.
575 if (methods->QueryInterceptionHookPoint(
576 experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
577 auto* map = methods->GetRecvInitialMetadata();
578 // Got nothing better to do here for now
579 EXPECT_EQ(map->size(), static_cast<unsigned>(0));
580 post_recv_initial_metadata_ = true;
// Only non-null received messages are checked (must start with "Hello") and
// counted.
582 if (methods->QueryInterceptionHookPoint(
583 experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
585 static_cast<EchoResponse*>(methods->GetRecvMessage());
586 if (resp != nullptr) {
587 EXPECT_TRUE(resp->message().find("Hello") == 0u);
588 post_recv_message_count_++;
// Trailing metadata must contain testkey/testvalue and the status must be OK.
591 if (methods->QueryInterceptionHookPoint(
592 experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
593 auto* map = methods->GetRecvTrailingMetadata();
595 // Check that we received the metadata as an echo
596 for (const auto& pair : *map) {
597 found = pair.first.starts_with("testkey") &&
598 pair.second.starts_with("testvalue");
601 EXPECT_EQ(found, true);
602 auto* status = methods->GetRecvStatus();
603 EXPECT_EQ(status->ok(), true);
604 post_recv_status_ = true;
// Dispatches to the verification routine matching the RPC type; sync and
// AsyncCQ variants of the same shape share a routine.
609 static void VerifyCall(RPCType type) {
611 case RPCType::kSyncUnary:
612 case RPCType::kAsyncCQUnary:
615 case RPCType::kSyncClientStreaming:
616 case RPCType::kAsyncCQClientStreaming:
617 VerifyClientStreamingCall();
619 case RPCType::kSyncServerStreaming:
620 case RPCType::kAsyncCQServerStreaming:
621 VerifyServerStreamingCall();
623 case RPCType::kSyncBidiStreaming:
624 case RPCType::kAsyncCQBidiStreaming:
625 VerifyBidiStreamingCall();
// Expects that every once-per-RPC hook point was recorded.
630 static void VerifyCallCommon() {
631 EXPECT_TRUE(pre_send_initial_metadata_);
632 EXPECT_TRUE(pre_send_close_);
633 EXPECT_TRUE(post_recv_initial_metadata_);
634 EXPECT_TRUE(post_recv_status_);
// Unary: one message sent, one received.
637 static void VerifyUnaryCall() {
639 EXPECT_EQ(pre_send_message_count_, 1);
640 EXPECT_EQ(post_recv_message_count_, 1);
// Client streaming: kNumStreamingMessages sent, one received.
643 static void VerifyClientStreamingCall() {
645 EXPECT_EQ(pre_send_message_count_, kNumStreamingMessages);
646 EXPECT_EQ(post_recv_message_count_, 1);
// Server streaming: one message sent, kNumStreamingMessages received.
649 static void VerifyServerStreamingCall() {
651 EXPECT_EQ(pre_send_message_count_, 1);
652 EXPECT_EQ(post_recv_message_count_, kNumStreamingMessages);
// Bidi streaming: kNumStreamingMessages in each direction.
655 static void VerifyBidiStreamingCall() {
657 EXPECT_EQ(pre_send_message_count_, kNumStreamingMessages);
658 EXPECT_EQ(post_recv_message_count_, kNumStreamingMessages);
// Static bookkeeping: one flag per once-per-RPC hook point and one counter per
// message direction.
662 static bool pre_send_initial_metadata_;
663 static int pre_send_message_count_;
664 static bool pre_send_close_;
665 static bool post_recv_initial_metadata_;
666 static int post_recv_message_count_;
667 static bool post_recv_status_;
// Out-of-class definitions of the static members. No initializers are given
// here; the constructor assigns all of them.
670 bool LoggingInterceptor::pre_send_initial_metadata_;
671 int LoggingInterceptor::pre_send_message_count_;
672 bool LoggingInterceptor::pre_send_close_;
673 bool LoggingInterceptor::post_recv_initial_metadata_;
674 int LoggingInterceptor::post_recv_message_count_;
675 bool LoggingInterceptor::post_recv_status_;
// Factory that creates a new LoggingInterceptor for the given RPC info,
// returned as a raw pointer allocated with `new`. Constructing the interceptor
// resets LoggingInterceptor's static bookkeeping.
677 class LoggingInterceptorFactory
678 : public experimental::ClientInterceptorFactoryInterface {
680 virtual experimental::Interceptor* CreateClientInterceptor(
681 experimental::ClientRpcInfo* info) override {
682 return new LoggingInterceptor(info);
// Builds a scenario that stores the given RPC type in type_.
688 explicit TestScenario(const RPCType& type) : type_(type) {}
// Returns the RPC type stored by the constructor.
690 RPCType type() const { return type_; }
// Builds the list of scenarios fed to the parameterized test suite. The lines
// shown add the four sync RPC shapes plus AsyncCQ unary and AsyncCQ server
// streaming; AsyncCQ client streaming and AsyncCQ bidi streaming are not among
// them.
696 std::vector<TestScenario> CreateTestScenarios() {
697 std::vector<TestScenario> scenarios;
698 scenarios.emplace_back(RPCType::kSyncUnary);
699 scenarios.emplace_back(RPCType::kSyncClientStreaming);
700 scenarios.emplace_back(RPCType::kSyncServerStreaming);
701 scenarios.emplace_back(RPCType::kSyncBidiStreaming);
702 scenarios.emplace_back(RPCType::kAsyncCQUnary);
703 scenarios.emplace_back(RPCType::kAsyncCQServerStreaming);
// Value-parameterized fixture (parameter: TestScenario). Each test instance
// starts its own server on an unused localhost port and shuts it down on
// destruction.
707 class ParameterizedClientInterceptorsEnd2endTest
708 : public ::testing::TestWithParam<TestScenario> {
// Picks an unused port, registers service_ and starts the server with insecure
// credentials.
710 ParameterizedClientInterceptorsEnd2endTest() {
711 int port = grpc_pick_unused_port_or_die();
713 ServerBuilder builder;
714 server_address_ = "localhost:" + std::to_string(port);
715 builder.AddListeningPort(server_address_, InsecureServerCredentials());
716 builder.RegisterService(&service_);
717 server_ = builder.BuildAndStart();
720 ~ParameterizedClientInterceptorsEnd2endTest() { server_->Shutdown(); }
// Issues one RPC on `channel`, choosing the call helper from the scenario's
// RPC type. The AsyncCQ client-streaming and bidi-streaming cases are still
// TODOs.
722 void SendRPC(const std::shared_ptr<Channel>& channel) {
723 switch (GetParam().type()) {
724 case RPCType::kSyncUnary:
727 case RPCType::kSyncClientStreaming:
728 MakeClientStreamingCall(channel);
730 case RPCType::kSyncServerStreaming:
731 MakeServerStreamingCall(channel);
733 case RPCType::kSyncBidiStreaming:
734 MakeBidiStreamingCall(channel);
736 case RPCType::kAsyncCQUnary:
737 MakeAsyncCQCall(channel);
739 case RPCType::kAsyncCQClientStreaming:
740 // TODO(yashykt) : Fill this out
742 case RPCType::kAsyncCQServerStreaming:
743 MakeAsyncCQServerStreamingCall(channel);
745 case RPCType::kAsyncCQBidiStreaming:
746 // TODO(yashykt) : Fill this out
// "localhost:<port>" address the server listens on.
751 std::string server_address_;
// Service implementation registered with the server.
752 EchoTestServiceStreamingImpl service_;
753 std::unique_ptr<Server> server_;
// Installs one LoggingInterceptorFactory followed by 20
// DummyInterceptorFactory instances on a channel to the test server, then
// verifies the hook points recorded by LoggingInterceptor for the scenario's
// RPC type and that the dummy interceptors ran 20 times in total.
756 TEST_P(ParameterizedClientInterceptorsEnd2endTest,
757 ClientInterceptorLoggingTest) {
758 ChannelArguments args;
// Clear the dummy interceptors' run counter before the test.
759 DummyInterceptor::Reset();
760 std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
762 creators.push_back(std::unique_ptr<LoggingInterceptorFactory>(
763 new LoggingInterceptorFactory()));
764 // Add 20 dummy interceptors
765 for (auto i = 0; i < 20; i++) {
766 creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
767 new DummyInterceptorFactory()));
769 auto channel = experimental::CreateCustomChannelWithInterceptors(
770 server_address_, InsecureChannelCredentials(), args, std::move(creators));
772 LoggingInterceptor::VerifyCall(GetParam().type());
773 // Make sure all 20 dummy interceptors were run
774 EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
// Instantiates the parameterized suite once per scenario returned by
// CreateTestScenarios().
777 INSTANTIATE_TEST_SUITE_P(ParameterizedClientInterceptorsEnd2end,
778 ParameterizedClientInterceptorsEnd2endTest,
779 ::testing::ValuesIn(CreateTestScenarios()));
// Fixture for the unary interceptor tests. Each test instance starts its own
// server (TestServiceImpl) on an unused localhost port and shuts it down on
// destruction.
// NOTE(review): the fixture derives from TestWithParam<TestScenario>, yet the
// tests visible in this file use it through TEST_F and never read a parameter
// -- confirm whether plain ::testing::Test was intended.
781 class ClientInterceptorsEnd2endTest
782 : public ::testing::TestWithParam<TestScenario> {
// Picks an unused port, registers service_ and starts the server with insecure
// credentials.
784 ClientInterceptorsEnd2endTest() {
785 int port = grpc_pick_unused_port_or_die();
787 ServerBuilder builder;
788 server_address_ = "localhost:" + std::to_string(port);
789 builder.AddListeningPort(server_address_, InsecureServerCredentials());
790 builder.RegisterService(&service_);
791 server_ = builder.BuildAndStart();
794 ~ClientInterceptorsEnd2endTest() { server_->Shutdown(); }
// "localhost:<port>" address the server listens on.
796 std::string server_address_;
797 TestServiceImpl service_;
798 std::unique_ptr<Server> server_;
// Creates a channel with a single HijackingInterceptorFactory while passing
// nullptr as the channel credentials (the "lame channel" of the test name).
801 TEST_F(ClientInterceptorsEnd2endTest,
802 LameChannelClientInterceptorHijackingTest) {
803 ChannelArguments args;
804 std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
806 creators.push_back(std::unique_ptr<HijackingInterceptorFactory>(
807 new HijackingInterceptorFactory()));
808 auto channel = experimental::CreateCustomChannelWithInterceptors(
809 server_address_, nullptr, args, std::move(creators));
// Places a HijackingInterceptorFactory between two groups of 20 dummy
// interceptor factories and expects the dummy run counter to read 20
// afterwards, i.e. only one of the two groups ran.
813 TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorHijackingTest) {
814 ChannelArguments args;
// Clear the dummy interceptors' run counter before the test.
815 DummyInterceptor::Reset();
816 std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
818 // Add 20 dummy interceptors before hijacking interceptor
// NOTE(review): reserve(20) covers only the first group; 41 factories are
// pushed in total, so the vector still reallocates.
819 creators.reserve(20);
820 for (auto i = 0; i < 20; i++) {
821 creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
822 new DummyInterceptorFactory()));
824 creators.push_back(std::unique_ptr<HijackingInterceptorFactory>(
825 new HijackingInterceptorFactory()));
826 // Add 20 dummy interceptors after hijacking interceptor
827 for (auto i = 0; i < 20; i++) {
828 creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
829 new DummyInterceptorFactory()));
831 auto channel = experimental::CreateCustomChannelWithInterceptors(
832 server_address_, InsecureChannelCredentials(), args, std::move(creators));
834 // Make sure only 20 dummy interceptors were run
835 EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
// Installs a LoggingInterceptorFactory ahead of a HijackingInterceptorFactory
// and checks that the logging interceptor still records the hook points of a
// complete unary call.
838 TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLogThenHijackTest) {
839 ChannelArguments args;
840 std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
842 creators.push_back(std::unique_ptr<LoggingInterceptorFactory>(
843 new LoggingInterceptorFactory()));
844 creators.push_back(std::unique_ptr<HijackingInterceptorFactory>(
845 new HijackingInterceptorFactory()));
846 auto channel = experimental::CreateCustomChannelWithInterceptors(
847 server_address_, InsecureChannelCredentials(), args, std::move(creators));
849 LoggingInterceptor::VerifyUnaryCall();
// Places a HijackingInterceptorMakesAnotherCallFactory between 5 and 7 dummy
// interceptor factories on an in-process channel and expects 12 dummy runs in
// total.
852 TEST_F(ClientInterceptorsEnd2endTest,
853 ClientInterceptorHijackingMakesAnotherCallTest) {
854 ChannelArguments args;
// Clear the dummy interceptors' run counter before the test.
855 DummyInterceptor::Reset();
856 std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
858 // Add 5 dummy interceptors before hijacking interceptor
860 for (auto i = 0; i < 5; i++) {
861 creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
862 new DummyInterceptorFactory()));
865 std::unique_ptr<experimental::ClientInterceptorFactoryInterface>(
866 new HijackingInterceptorMakesAnotherCallFactory()));
867 // Add 7 dummy interceptors after hijacking interceptor
868 for (auto i = 0; i < 7; i++) {
869 creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
870 new DummyInterceptorFactory()));
// The channel is created in-process from the server rather than by address.
872 auto channel = server_->experimental().InProcessChannelWithInterceptors(
873 args, std::move(creators));
876 // Make sure all interceptors were run once, since the hijacking interceptor
877 // makes an RPC on the intercepted channel
878 EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 12);
// Fixture for the callback-API interceptor tests. Each test instance starts
// its own server (TestServiceImpl) on an unused localhost port and shuts it
// down on destruction.
881 class ClientInterceptorsCallbackEnd2endTest : public ::testing::Test {
// Picks an unused port, registers service_ and starts the server with insecure
// credentials.
883 ClientInterceptorsCallbackEnd2endTest() {
884 int port = grpc_pick_unused_port_or_die();
886 ServerBuilder builder;
887 server_address_ = "localhost:" + std::to_string(port);
888 builder.AddListeningPort(server_address_, InsecureServerCredentials());
889 builder.RegisterService(&service_);
890 server_ = builder.BuildAndStart();
893 ~ClientInterceptorsCallbackEnd2endTest() { server_->Shutdown(); }
// "localhost:<port>" address the server listens on.
895 std::string server_address_;
896 TestServiceImpl service_;
897 std::unique_ptr<Server> server_;
// Installs one LoggingInterceptorFactory and 20 dummy interceptor factories on
// an in-process channel, makes a call via MakeCallbackCall(), and verifies the
// unary hook points plus 20 dummy runs.
900 TEST_F(ClientInterceptorsCallbackEnd2endTest,
901 ClientInterceptorLoggingTestWithCallback) {
902 ChannelArguments args;
// Clear the dummy interceptors' run counter before the test.
903 DummyInterceptor::Reset();
904 std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
906 creators.push_back(std::unique_ptr<LoggingInterceptorFactory>(
907 new LoggingInterceptorFactory()));
908 // Add 20 dummy interceptors
909 for (auto i = 0; i < 20; i++) {
910 creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
911 new DummyInterceptorFactory()));
913 auto channel = server_->experimental().InProcessChannelWithInterceptors(
914 args, std::move(creators));
915 MakeCallbackCall(channel);
916 LoggingInterceptor::VerifyUnaryCall();
917 // Make sure all 20 dummy interceptors were run
918 EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
// Mixes NullInterceptorFactory instances in with the dummy interceptor
// factories on an in-process channel and checks that the call still completes:
// the unary hook points are recorded and the dummy interceptors ran 20 times.
921 TEST_F(ClientInterceptorsCallbackEnd2endTest,
922 ClientInterceptorFactoryAllowsNullptrReturn) {
923 ChannelArguments args;
// Clear the dummy interceptors' run counter before the test.
924 DummyInterceptor::Reset();
925 std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
927 creators.push_back(std::unique_ptr<LoggingInterceptorFactory>(
928 new LoggingInterceptorFactory()));
929 // Add 20 dummy interceptors and 20 null interceptors
930 for (auto i = 0; i < 20; i++) {
931 creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
932 new DummyInterceptorFactory()));
934 std::unique_ptr<NullInterceptorFactory>(new NullInterceptorFactory()));
936 auto channel = server_->experimental().InProcessChannelWithInterceptors(
937 args, std::move(creators));
938 MakeCallbackCall(channel);
939 LoggingInterceptor::VerifyUnaryCall();
940 // Make sure all 20 dummy interceptors were run
941 EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
// Fixture for the streaming interceptor tests. Each test instance starts its
// own server (EchoTestServiceStreamingImpl) on an unused localhost port and
// shuts it down on destruction.
944 class ClientInterceptorsStreamingEnd2endTest : public ::testing::Test {
// Picks an unused port, registers service_ and starts the server with insecure
// credentials.
946 ClientInterceptorsStreamingEnd2endTest() {
947 int port = grpc_pick_unused_port_or_die();
949 ServerBuilder builder;
950 server_address_ = "localhost:" + std::to_string(port);
951 builder.AddListeningPort(server_address_, InsecureServerCredentials());
952 builder.RegisterService(&service_);
953 server_ = builder.BuildAndStart();
956 ~ClientInterceptorsStreamingEnd2endTest() { server_->Shutdown(); }
// "localhost:<port>" address the server listens on.
958 std::string server_address_;
959 EchoTestServiceStreamingImpl service_;
960 std::unique_ptr<Server> server_;
// Makes a client-streaming call through one logging interceptor and 20 dummy
// interceptors, then verifies the client-streaming hook-point counts and that
// the dummy interceptors ran 20 times.
963 TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingTest) {
964 ChannelArguments args;
// Clear the dummy interceptors' run counter before the test.
965 DummyInterceptor::Reset();
966 std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
968 creators.push_back(std::unique_ptr<LoggingInterceptorFactory>(
969 new LoggingInterceptorFactory()));
970 // Add 20 dummy interceptors
971 for (auto i = 0; i < 20; i++) {
972 creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
973 new DummyInterceptorFactory()));
975 auto channel = experimental::CreateCustomChannelWithInterceptors(
976 server_address_, InsecureChannelCredentials(), args, std::move(creators));
977 MakeClientStreamingCall(channel);
978 LoggingInterceptor::VerifyClientStreamingCall();
979 // Make sure all 20 dummy interceptors were run
980 EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
// Makes a server-streaming call through one logging interceptor and 20 dummy
// interceptors, then verifies the server-streaming hook-point counts and that
// the dummy interceptors ran 20 times.
983 TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) {
984 ChannelArguments args;
// Clear the dummy interceptors' run counter before the test.
985 DummyInterceptor::Reset();
986 std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
988 creators.push_back(std::unique_ptr<LoggingInterceptorFactory>(
989 new LoggingInterceptorFactory()));
990 // Add 20 dummy interceptors
991 for (auto i = 0; i < 20; i++) {
992 creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
993 new DummyInterceptorFactory()));
995 auto channel = experimental::CreateCustomChannelWithInterceptors(
996 server_address_, InsecureChannelCredentials(), args, std::move(creators));
997 MakeServerStreamingCall(channel);
998 LoggingInterceptor::VerifyServerStreamingCall();
999 // Make sure all 20 dummy interceptors were run
1000 EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
1003 TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) {
1004 ChannelArguments args;
1005 std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
1008 std::unique_ptr<ClientStreamingRpcHijackingInterceptorFactory>(
1009 new ClientStreamingRpcHijackingInterceptorFactory()));
1010 auto channel = experimental::CreateCustomChannelWithInterceptors(
1011 server_address_, InsecureChannelCredentials(), args, std::move(creators));
1013 auto stub = grpc::testing::EchoTestService::NewStub(channel);
1017 req.mutable_param()->set_echo_metadata(true);
1018 req.set_message("Hello");
1019 string expected_resp = "";
1020 auto writer = stub->RequestStream(&ctx, &resp);
1021 for (int i = 0; i < 10; i++) {
1022 EXPECT_TRUE(writer->Write(req));
1023 expected_resp += "Hello";
1025 // The interceptor will reject the 11th message
1027 Status s = writer->Finish();
1028 EXPECT_EQ(s.ok(), false);
1029 EXPECT_TRUE(ClientStreamingRpcHijackingInterceptor::GotFailedSend());
1032 TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) {
1033 ChannelArguments args;
1034 DummyInterceptor::Reset();
1035 std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
1038 std::unique_ptr<ServerStreamingRpcHijackingInterceptorFactory>(
1039 new ServerStreamingRpcHijackingInterceptorFactory()));
1040 auto channel = experimental::CreateCustomChannelWithInterceptors(
1041 server_address_, InsecureChannelCredentials(), args, std::move(creators));
1042 MakeServerStreamingCall(channel);
1043 EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage());
1046 TEST_F(ClientInterceptorsStreamingEnd2endTest,
1047 AsyncCQServerStreamingHijackingTest) {
1048 ChannelArguments args;
1049 DummyInterceptor::Reset();
1050 std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
1053 std::unique_ptr<ServerStreamingRpcHijackingInterceptorFactory>(
1054 new ServerStreamingRpcHijackingInterceptorFactory()));
1055 auto channel = experimental::CreateCustomChannelWithInterceptors(
1056 server_address_, InsecureChannelCredentials(), args, std::move(creators));
1057 MakeAsyncCQServerStreamingCall(channel);
1058 EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage());
1061 TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingHijackingTest) {
1062 ChannelArguments args;
1063 DummyInterceptor::Reset();
1064 std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
1067 std::unique_ptr<BidiStreamingRpcHijackingInterceptorFactory>(
1068 new BidiStreamingRpcHijackingInterceptorFactory()));
1069 auto channel = experimental::CreateCustomChannelWithInterceptors(
1070 server_address_, InsecureChannelCredentials(), args, std::move(creators));
1071 MakeBidiStreamingCall(channel);
1074 TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) {
1075 ChannelArguments args;
1076 DummyInterceptor::Reset();
1077 std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
1079 creators.push_back(std::unique_ptr<LoggingInterceptorFactory>(
1080 new LoggingInterceptorFactory()));
1081 // Add 20 dummy interceptors
1082 for (auto i = 0; i < 20; i++) {
1083 creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
1084 new DummyInterceptorFactory()));
1086 auto channel = experimental::CreateCustomChannelWithInterceptors(
1087 server_address_, InsecureChannelCredentials(), args, std::move(creators));
1088 MakeBidiStreamingCall(channel);
1089 LoggingInterceptor::VerifyBidiStreamingCall();
1090 // Make sure all 20 dummy interceptors were run
1091 EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
1094 class ClientGlobalInterceptorEnd2endTest : public ::testing::Test {
1096 ClientGlobalInterceptorEnd2endTest() {
1097 int port = grpc_pick_unused_port_or_die();
1099 ServerBuilder builder;
1100 server_address_ = "localhost:" + std::to_string(port);
1101 builder.AddListeningPort(server_address_, InsecureServerCredentials());
1102 builder.RegisterService(&service_);
1103 server_ = builder.BuildAndStart();
1106 ~ClientGlobalInterceptorEnd2endTest() { server_->Shutdown(); }
1108 std::string server_address_;
1109 TestServiceImpl service_;
1110 std::unique_ptr<Server> server_;
1113 TEST_F(ClientGlobalInterceptorEnd2endTest, DummyGlobalInterceptor) {
1114 // We should ideally be registering a global interceptor only once per
1115 // process, but for the purposes of testing, it should be fine to modify the
1116 // registered global interceptor when there are no ongoing gRPC operations
1117 DummyInterceptorFactory global_factory;
1118 experimental::RegisterGlobalClientInterceptorFactory(&global_factory);
1119 ChannelArguments args;
1120 DummyInterceptor::Reset();
1121 std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
1123 // Add 20 dummy interceptors
1124 creators.reserve(20);
1125 for (auto i = 0; i < 20; i++) {
1126 creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
1127 new DummyInterceptorFactory()));
1129 auto channel = experimental::CreateCustomChannelWithInterceptors(
1130 server_address_, InsecureChannelCredentials(), args, std::move(creators));
1132 // Make sure all 20 dummy interceptors were run with the global interceptor
1133 EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 21);
1134 experimental::TestOnlyResetGlobalClientInterceptorFactory();
1137 TEST_F(ClientGlobalInterceptorEnd2endTest, LoggingGlobalInterceptor) {
1138 // We should ideally be registering a global interceptor only once per
1139 // process, but for the purposes of testing, it should be fine to modify the
1140 // registered global interceptor when there are no ongoing gRPC operations
1141 LoggingInterceptorFactory global_factory;
1142 experimental::RegisterGlobalClientInterceptorFactory(&global_factory);
1143 ChannelArguments args;
1144 DummyInterceptor::Reset();
1145 std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
1147 // Add 20 dummy interceptors
1148 creators.reserve(20);
1149 for (auto i = 0; i < 20; i++) {
1150 creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
1151 new DummyInterceptorFactory()));
1153 auto channel = experimental::CreateCustomChannelWithInterceptors(
1154 server_address_, InsecureChannelCredentials(), args, std::move(creators));
1156 LoggingInterceptor::VerifyUnaryCall();
1157 // Make sure all 20 dummy interceptors were run
1158 EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
1159 experimental::TestOnlyResetGlobalClientInterceptorFactory();
1162 TEST_F(ClientGlobalInterceptorEnd2endTest, HijackingGlobalInterceptor) {
1163 // We should ideally be registering a global interceptor only once per
1164 // process, but for the purposes of testing, it should be fine to modify the
1165 // registered global interceptor when there are no ongoing gRPC operations
1166 HijackingInterceptorFactory global_factory;
1167 experimental::RegisterGlobalClientInterceptorFactory(&global_factory);
1168 ChannelArguments args;
1169 DummyInterceptor::Reset();
1170 std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
1172 // Add 20 dummy interceptors
1173 creators.reserve(20);
1174 for (auto i = 0; i < 20; i++) {
1175 creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
1176 new DummyInterceptorFactory()));
1178 auto channel = experimental::CreateCustomChannelWithInterceptors(
1179 server_address_, InsecureChannelCredentials(), args, std::move(creators));
1181 // Make sure all 20 dummy interceptors were run
1182 EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
1183 experimental::TestOnlyResetGlobalClientInterceptorFactory();
1187 } // namespace testing
1190 int main(int argc, char** argv) {
1191 grpc::testing::TestEnvironment env(argc, argv);
1192 ::testing::InitGoogleTest(&argc, argv);
1193 return RUN_ALL_TESTS();