Imported Upstream version 1.21.0
[platform/upstream/grpc.git] / test / cpp / end2end / client_interceptors_end2end_test.cc
1 /*
2  *
3  * Copyright 2018 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18
19 #include <memory>
20 #include <vector>
21
22 #include <grpcpp/channel.h>
23 #include <grpcpp/client_context.h>
24 #include <grpcpp/create_channel.h>
25 #include <grpcpp/generic/generic_stub.h>
26 #include <grpcpp/impl/codegen/proto_utils.h>
27 #include <grpcpp/server.h>
28 #include <grpcpp/server_builder.h>
29 #include <grpcpp/server_context.h>
30 #include <grpcpp/support/client_interceptor.h>
31
32 #include "src/proto/grpc/testing/echo.grpc.pb.h"
33 #include "test/core/util/port.h"
34 #include "test/core/util/test_config.h"
35 #include "test/cpp/end2end/interceptors_util.h"
36 #include "test/cpp/end2end/test_service_impl.h"
37 #include "test/cpp/util/byte_buffer_proto_helper.h"
38 #include "test/cpp/util/string_ref_helper.h"
39
40 #include <gtest/gtest.h>
41
42 namespace grpc {
43 namespace testing {
44 namespace {
45
46 /* Hijacks Echo RPC and fills in the expected values */
47 class HijackingInterceptor : public experimental::Interceptor {
48  public:
49   HijackingInterceptor(experimental::ClientRpcInfo* info) {
50     info_ = info;
51     // Make sure it is the right method
52     EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0);
53     EXPECT_EQ(info->type(), experimental::ClientRpcInfo::Type::UNARY);
54   }
55
56   virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
57     bool hijack = false;
58     if (methods->QueryInterceptionHookPoint(
59             experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
60       auto* map = methods->GetSendInitialMetadata();
61       // Check that we can see the test metadata
62       ASSERT_EQ(map->size(), static_cast<unsigned>(1));
63       auto iterator = map->begin();
64       EXPECT_EQ("testkey", iterator->first);
65       EXPECT_EQ("testvalue", iterator->second);
66       hijack = true;
67     }
68     if (methods->QueryInterceptionHookPoint(
69             experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
70       EchoRequest req;
71       auto* buffer = methods->GetSerializedSendMessage();
72       auto copied_buffer = *buffer;
73       EXPECT_TRUE(
74           SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
75               .ok());
76       EXPECT_EQ(req.message(), "Hello");
77     }
78     if (methods->QueryInterceptionHookPoint(
79             experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
80       // Got nothing to do here for now
81     }
82     if (methods->QueryInterceptionHookPoint(
83             experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
84       auto* map = methods->GetRecvInitialMetadata();
85       // Got nothing better to do here for now
86       EXPECT_EQ(map->size(), static_cast<unsigned>(0));
87     }
88     if (methods->QueryInterceptionHookPoint(
89             experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
90       EchoResponse* resp =
91           static_cast<EchoResponse*>(methods->GetRecvMessage());
92       // Check that we got the hijacked message, and re-insert the expected
93       // message
94       EXPECT_EQ(resp->message(), "Hello1");
95       resp->set_message("Hello");
96     }
97     if (methods->QueryInterceptionHookPoint(
98             experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
99       auto* map = methods->GetRecvTrailingMetadata();
100       bool found = false;
101       // Check that we received the metadata as an echo
102       for (const auto& pair : *map) {
103         found = pair.first.starts_with("testkey") &&
104                 pair.second.starts_with("testvalue");
105         if (found) break;
106       }
107       EXPECT_EQ(found, true);
108       auto* status = methods->GetRecvStatus();
109       EXPECT_EQ(status->ok(), true);
110     }
111     if (methods->QueryInterceptionHookPoint(
112             experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) {
113       auto* map = methods->GetRecvInitialMetadata();
114       // Got nothing better to do here at the moment
115       EXPECT_EQ(map->size(), static_cast<unsigned>(0));
116     }
117     if (methods->QueryInterceptionHookPoint(
118             experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
119       // Insert a different message than expected
120       EchoResponse* resp =
121           static_cast<EchoResponse*>(methods->GetRecvMessage());
122       resp->set_message("Hello1");
123     }
124     if (methods->QueryInterceptionHookPoint(
125             experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
126       auto* map = methods->GetRecvTrailingMetadata();
127       // insert the metadata that we want
128       EXPECT_EQ(map->size(), static_cast<unsigned>(0));
129       map->insert(std::make_pair("testkey", "testvalue"));
130       auto* status = methods->GetRecvStatus();
131       *status = Status(StatusCode::OK, "");
132     }
133     if (hijack) {
134       methods->Hijack();
135     } else {
136       methods->Proceed();
137     }
138   }
139
140  private:
141   experimental::ClientRpcInfo* info_;
142 };
143
144 class HijackingInterceptorFactory
145     : public experimental::ClientInterceptorFactoryInterface {
146  public:
147   virtual experimental::Interceptor* CreateClientInterceptor(
148       experimental::ClientRpcInfo* info) override {
149     return new HijackingInterceptor(info);
150   }
151 };
152
153 class HijackingInterceptorMakesAnotherCall : public experimental::Interceptor {
154  public:
155   HijackingInterceptorMakesAnotherCall(experimental::ClientRpcInfo* info) {
156     info_ = info;
157     // Make sure it is the right method
158     EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0);
159   }
160
161   virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
162     if (methods->QueryInterceptionHookPoint(
163             experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
164       auto* map = methods->GetSendInitialMetadata();
165       // Check that we can see the test metadata
166       ASSERT_EQ(map->size(), static_cast<unsigned>(1));
167       auto iterator = map->begin();
168       EXPECT_EQ("testkey", iterator->first);
169       EXPECT_EQ("testvalue", iterator->second);
170       // Make a copy of the map
171       metadata_map_ = *map;
172     }
173     if (methods->QueryInterceptionHookPoint(
174             experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
175       EchoRequest req;
176       auto* buffer = methods->GetSerializedSendMessage();
177       auto copied_buffer = *buffer;
178       EXPECT_TRUE(
179           SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
180               .ok());
181       EXPECT_EQ(req.message(), "Hello");
182       req_ = req;
183       stub_ = grpc::testing::EchoTestService::NewStub(
184           methods->GetInterceptedChannel());
185       ctx_.AddMetadata(metadata_map_.begin()->first,
186                        metadata_map_.begin()->second);
187       stub_->experimental_async()->Echo(&ctx_, &req_, &resp_,
188                                         [this, methods](Status s) {
189                                           EXPECT_EQ(s.ok(), true);
190                                           EXPECT_EQ(resp_.message(), "Hello");
191                                           methods->Hijack();
192                                         });
193       // There isn't going to be any other interesting operation in this batch,
194       // so it is fine to return
195       return;
196     }
197     if (methods->QueryInterceptionHookPoint(
198             experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
199       // Got nothing to do here for now
200     }
201     if (methods->QueryInterceptionHookPoint(
202             experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
203       auto* map = methods->GetRecvInitialMetadata();
204       // Got nothing better to do here for now
205       EXPECT_EQ(map->size(), static_cast<unsigned>(0));
206     }
207     if (methods->QueryInterceptionHookPoint(
208             experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
209       EchoResponse* resp =
210           static_cast<EchoResponse*>(methods->GetRecvMessage());
211       // Check that we got the hijacked message, and re-insert the expected
212       // message
213       EXPECT_EQ(resp->message(), "Hello");
214     }
215     if (methods->QueryInterceptionHookPoint(
216             experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
217       auto* map = methods->GetRecvTrailingMetadata();
218       bool found = false;
219       // Check that we received the metadata as an echo
220       for (const auto& pair : *map) {
221         found = pair.first.starts_with("testkey") &&
222                 pair.second.starts_with("testvalue");
223         if (found) break;
224       }
225       EXPECT_EQ(found, true);
226       auto* status = methods->GetRecvStatus();
227       EXPECT_EQ(status->ok(), true);
228     }
229     if (methods->QueryInterceptionHookPoint(
230             experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) {
231       auto* map = methods->GetRecvInitialMetadata();
232       // Got nothing better to do here at the moment
233       EXPECT_EQ(map->size(), static_cast<unsigned>(0));
234     }
235     if (methods->QueryInterceptionHookPoint(
236             experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
237       // Insert a different message than expected
238       EchoResponse* resp =
239           static_cast<EchoResponse*>(methods->GetRecvMessage());
240       resp->set_message(resp_.message());
241     }
242     if (methods->QueryInterceptionHookPoint(
243             experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
244       auto* map = methods->GetRecvTrailingMetadata();
245       // insert the metadata that we want
246       EXPECT_EQ(map->size(), static_cast<unsigned>(0));
247       map->insert(std::make_pair("testkey", "testvalue"));
248       auto* status = methods->GetRecvStatus();
249       *status = Status(StatusCode::OK, "");
250     }
251
252     methods->Proceed();
253   }
254
255  private:
256   experimental::ClientRpcInfo* info_;
257   std::multimap<grpc::string, grpc::string> metadata_map_;
258   ClientContext ctx_;
259   EchoRequest req_;
260   EchoResponse resp_;
261   std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
262 };
263
264 class HijackingInterceptorMakesAnotherCallFactory
265     : public experimental::ClientInterceptorFactoryInterface {
266  public:
267   virtual experimental::Interceptor* CreateClientInterceptor(
268       experimental::ClientRpcInfo* info) override {
269     return new HijackingInterceptorMakesAnotherCall(info);
270   }
271 };
272
273 class BidiStreamingRpcHijackingInterceptor : public experimental::Interceptor {
274  public:
275   BidiStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
276     info_ = info;
277   }
278
279   virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
280     bool hijack = false;
281     if (methods->QueryInterceptionHookPoint(
282             experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
283       CheckMetadata(*methods->GetSendInitialMetadata(), "testkey", "testvalue");
284       hijack = true;
285     }
286     if (methods->QueryInterceptionHookPoint(
287             experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
288       EchoRequest req;
289       auto* buffer = methods->GetSerializedSendMessage();
290       auto copied_buffer = *buffer;
291       EXPECT_TRUE(
292           SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
293               .ok());
294       EXPECT_EQ(req.message().find("Hello"), 0u);
295       msg = req.message();
296     }
297     if (methods->QueryInterceptionHookPoint(
298             experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
299       // Got nothing to do here for now
300     }
301     if (methods->QueryInterceptionHookPoint(
302             experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
303       CheckMetadata(*methods->GetRecvTrailingMetadata(), "testkey",
304                     "testvalue");
305       auto* status = methods->GetRecvStatus();
306       EXPECT_EQ(status->ok(), true);
307     }
308     if (methods->QueryInterceptionHookPoint(
309             experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
310       EchoResponse* resp =
311           static_cast<EchoResponse*>(methods->GetRecvMessage());
312       resp->set_message(msg);
313     }
314     if (methods->QueryInterceptionHookPoint(
315             experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
316       EXPECT_EQ(static_cast<EchoResponse*>(methods->GetRecvMessage())
317                     ->message()
318                     .find("Hello"),
319                 0u);
320     }
321     if (methods->QueryInterceptionHookPoint(
322             experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
323       auto* map = methods->GetRecvTrailingMetadata();
324       // insert the metadata that we want
325       EXPECT_EQ(map->size(), static_cast<unsigned>(0));
326       map->insert(std::make_pair("testkey", "testvalue"));
327       auto* status = methods->GetRecvStatus();
328       *status = Status(StatusCode::OK, "");
329     }
330     if (hijack) {
331       methods->Hijack();
332     } else {
333       methods->Proceed();
334     }
335   }
336
337  private:
338   experimental::ClientRpcInfo* info_;
339   grpc::string msg;
340 };
341
342 class ClientStreamingRpcHijackingInterceptor
343     : public experimental::Interceptor {
344  public:
345   ClientStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
346     info_ = info;
347   }
348   virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
349     bool hijack = false;
350     if (methods->QueryInterceptionHookPoint(
351             experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
352       hijack = true;
353     }
354     if (methods->QueryInterceptionHookPoint(
355             experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
356       if (++count_ > 10) {
357         methods->FailHijackedSendMessage();
358       }
359     }
360     if (methods->QueryInterceptionHookPoint(
361             experimental::InterceptionHookPoints::POST_SEND_MESSAGE)) {
362       EXPECT_FALSE(got_failed_send_);
363       got_failed_send_ = !methods->GetSendMessageStatus();
364     }
365     if (methods->QueryInterceptionHookPoint(
366             experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
367       auto* status = methods->GetRecvStatus();
368       *status = Status(StatusCode::UNAVAILABLE, "Done sending 10 messages");
369     }
370     if (hijack) {
371       methods->Hijack();
372     } else {
373       methods->Proceed();
374     }
375   }
376
377   static bool GotFailedSend() { return got_failed_send_; }
378
379  private:
380   experimental::ClientRpcInfo* info_;
381   int count_ = 0;
382   static bool got_failed_send_;
383 };
384
385 bool ClientStreamingRpcHijackingInterceptor::got_failed_send_ = false;
386
387 class ClientStreamingRpcHijackingInterceptorFactory
388     : public experimental::ClientInterceptorFactoryInterface {
389  public:
390   virtual experimental::Interceptor* CreateClientInterceptor(
391       experimental::ClientRpcInfo* info) override {
392     return new ClientStreamingRpcHijackingInterceptor(info);
393   }
394 };
395
396 class ServerStreamingRpcHijackingInterceptor
397     : public experimental::Interceptor {
398  public:
399   ServerStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
400     info_ = info;
401   }
402
403   virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
404     bool hijack = false;
405     if (methods->QueryInterceptionHookPoint(
406             experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
407       auto* map = methods->GetSendInitialMetadata();
408       // Check that we can see the test metadata
409       ASSERT_EQ(map->size(), static_cast<unsigned>(1));
410       auto iterator = map->begin();
411       EXPECT_EQ("testkey", iterator->first);
412       EXPECT_EQ("testvalue", iterator->second);
413       hijack = true;
414     }
415     if (methods->QueryInterceptionHookPoint(
416             experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
417       EchoRequest req;
418       auto* buffer = methods->GetSerializedSendMessage();
419       auto copied_buffer = *buffer;
420       EXPECT_TRUE(
421           SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
422               .ok());
423       EXPECT_EQ(req.message(), "Hello");
424     }
425     if (methods->QueryInterceptionHookPoint(
426             experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
427       // Got nothing to do here for now
428     }
429     if (methods->QueryInterceptionHookPoint(
430             experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
431       auto* map = methods->GetRecvTrailingMetadata();
432       bool found = false;
433       // Check that we received the metadata as an echo
434       for (const auto& pair : *map) {
435         found = pair.first.starts_with("testkey") &&
436                 pair.second.starts_with("testvalue");
437         if (found) break;
438       }
439       EXPECT_EQ(found, true);
440       auto* status = methods->GetRecvStatus();
441       EXPECT_EQ(status->ok(), true);
442     }
443     if (methods->QueryInterceptionHookPoint(
444             experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
445       if (++count_ > 10) {
446         methods->FailHijackedRecvMessage();
447       }
448       EchoResponse* resp =
449           static_cast<EchoResponse*>(methods->GetRecvMessage());
450       resp->set_message("Hello");
451     }
452     if (methods->QueryInterceptionHookPoint(
453             experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
454       // Only the last message will be a failure
455       EXPECT_FALSE(got_failed_message_);
456       got_failed_message_ = methods->GetRecvMessage() == nullptr;
457     }
458     if (methods->QueryInterceptionHookPoint(
459             experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
460       auto* map = methods->GetRecvTrailingMetadata();
461       // insert the metadata that we want
462       EXPECT_EQ(map->size(), static_cast<unsigned>(0));
463       map->insert(std::make_pair("testkey", "testvalue"));
464       auto* status = methods->GetRecvStatus();
465       *status = Status(StatusCode::OK, "");
466     }
467     if (hijack) {
468       methods->Hijack();
469     } else {
470       methods->Proceed();
471     }
472   }
473
474   static bool GotFailedMessage() { return got_failed_message_; }
475
476  private:
477   experimental::ClientRpcInfo* info_;
478   static bool got_failed_message_;
479   int count_ = 0;
480 };
481
482 bool ServerStreamingRpcHijackingInterceptor::got_failed_message_ = false;
483
484 class ServerStreamingRpcHijackingInterceptorFactory
485     : public experimental::ClientInterceptorFactoryInterface {
486  public:
487   virtual experimental::Interceptor* CreateClientInterceptor(
488       experimental::ClientRpcInfo* info) override {
489     return new ServerStreamingRpcHijackingInterceptor(info);
490   }
491 };
492
493 class BidiStreamingRpcHijackingInterceptorFactory
494     : public experimental::ClientInterceptorFactoryInterface {
495  public:
496   virtual experimental::Interceptor* CreateClientInterceptor(
497       experimental::ClientRpcInfo* info) override {
498     return new BidiStreamingRpcHijackingInterceptor(info);
499   }
500 };
501
502 class LoggingInterceptor : public experimental::Interceptor {
503  public:
504   LoggingInterceptor(experimental::ClientRpcInfo* info) { info_ = info; }
505
506   virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
507     if (methods->QueryInterceptionHookPoint(
508             experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
509       auto* map = methods->GetSendInitialMetadata();
510       // Check that we can see the test metadata
511       ASSERT_EQ(map->size(), static_cast<unsigned>(1));
512       auto iterator = map->begin();
513       EXPECT_EQ("testkey", iterator->first);
514       EXPECT_EQ("testvalue", iterator->second);
515     }
516     if (methods->QueryInterceptionHookPoint(
517             experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
518       EchoRequest req;
519       EXPECT_EQ(static_cast<const EchoRequest*>(methods->GetSendMessage())
520                     ->message()
521                     .find("Hello"),
522                 0u);
523       auto* buffer = methods->GetSerializedSendMessage();
524       auto copied_buffer = *buffer;
525       EXPECT_TRUE(
526           SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
527               .ok());
528       EXPECT_TRUE(req.message().find("Hello") == 0u);
529     }
530     if (methods->QueryInterceptionHookPoint(
531             experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
532       // Got nothing to do here for now
533     }
534     if (methods->QueryInterceptionHookPoint(
535             experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
536       auto* map = methods->GetRecvInitialMetadata();
537       // Got nothing better to do here for now
538       EXPECT_EQ(map->size(), static_cast<unsigned>(0));
539     }
540     if (methods->QueryInterceptionHookPoint(
541             experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
542       EchoResponse* resp =
543           static_cast<EchoResponse*>(methods->GetRecvMessage());
544       EXPECT_TRUE(resp->message().find("Hello") == 0u);
545     }
546     if (methods->QueryInterceptionHookPoint(
547             experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
548       auto* map = methods->GetRecvTrailingMetadata();
549       bool found = false;
550       // Check that we received the metadata as an echo
551       for (const auto& pair : *map) {
552         found = pair.first.starts_with("testkey") &&
553                 pair.second.starts_with("testvalue");
554         if (found) break;
555       }
556       EXPECT_EQ(found, true);
557       auto* status = methods->GetRecvStatus();
558       EXPECT_EQ(status->ok(), true);
559     }
560     methods->Proceed();
561   }
562
563  private:
564   experimental::ClientRpcInfo* info_;
565 };
566
567 class LoggingInterceptorFactory
568     : public experimental::ClientInterceptorFactoryInterface {
569  public:
570   virtual experimental::Interceptor* CreateClientInterceptor(
571       experimental::ClientRpcInfo* info) override {
572     return new LoggingInterceptor(info);
573   }
574 };
575
576 class ClientInterceptorsEnd2endTest : public ::testing::Test {
577  protected:
578   ClientInterceptorsEnd2endTest() {
579     int port = grpc_pick_unused_port_or_die();
580
581     ServerBuilder builder;
582     server_address_ = "localhost:" + std::to_string(port);
583     builder.AddListeningPort(server_address_, InsecureServerCredentials());
584     builder.RegisterService(&service_);
585     server_ = builder.BuildAndStart();
586   }
587
588   ~ClientInterceptorsEnd2endTest() { server_->Shutdown(); }
589
590   std::string server_address_;
591   TestServiceImpl service_;
592   std::unique_ptr<Server> server_;
593 };
594
595 TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLoggingTest) {
596   ChannelArguments args;
597   DummyInterceptor::Reset();
598   std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
599       creators;
600   creators.push_back(std::unique_ptr<LoggingInterceptorFactory>(
601       new LoggingInterceptorFactory()));
602   // Add 20 dummy interceptors
603   for (auto i = 0; i < 20; i++) {
604     creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
605         new DummyInterceptorFactory()));
606   }
607   auto channel = experimental::CreateCustomChannelWithInterceptors(
608       server_address_, InsecureChannelCredentials(), args, std::move(creators));
609   MakeCall(channel);
610   // Make sure all 20 dummy interceptors were run
611   EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
612 }
613
614 TEST_F(ClientInterceptorsEnd2endTest,
615        LameChannelClientInterceptorHijackingTest) {
616   ChannelArguments args;
617   std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
618       creators;
619   creators.push_back(std::unique_ptr<HijackingInterceptorFactory>(
620       new HijackingInterceptorFactory()));
621   auto channel = experimental::CreateCustomChannelWithInterceptors(
622       server_address_, nullptr, args, std::move(creators));
623   MakeCall(channel);
624 }
625
626 TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorHijackingTest) {
627   ChannelArguments args;
628   DummyInterceptor::Reset();
629   std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
630       creators;
631   // Add 20 dummy interceptors before hijacking interceptor
632   creators.reserve(20);
633   for (auto i = 0; i < 20; i++) {
634     creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
635         new DummyInterceptorFactory()));
636   }
637   creators.push_back(std::unique_ptr<HijackingInterceptorFactory>(
638       new HijackingInterceptorFactory()));
639   // Add 20 dummy interceptors after hijacking interceptor
640   for (auto i = 0; i < 20; i++) {
641     creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
642         new DummyInterceptorFactory()));
643   }
644   auto channel = experimental::CreateCustomChannelWithInterceptors(
645       server_address_, InsecureChannelCredentials(), args, std::move(creators));
646
647   MakeCall(channel);
648   // Make sure only 20 dummy interceptors were run
649   EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
650 }
651
652 TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLogThenHijackTest) {
653   ChannelArguments args;
654   std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
655       creators;
656   creators.push_back(std::unique_ptr<LoggingInterceptorFactory>(
657       new LoggingInterceptorFactory()));
658   creators.push_back(std::unique_ptr<HijackingInterceptorFactory>(
659       new HijackingInterceptorFactory()));
660   auto channel = experimental::CreateCustomChannelWithInterceptors(
661       server_address_, InsecureChannelCredentials(), args, std::move(creators));
662
663   MakeCall(channel);
664 }
665
666 TEST_F(ClientInterceptorsEnd2endTest,
667        ClientInterceptorHijackingMakesAnotherCallTest) {
668   ChannelArguments args;
669   DummyInterceptor::Reset();
670   std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
671       creators;
672   // Add 5 dummy interceptors before hijacking interceptor
673   creators.reserve(5);
674   for (auto i = 0; i < 5; i++) {
675     creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
676         new DummyInterceptorFactory()));
677   }
678   creators.push_back(
679       std::unique_ptr<experimental::ClientInterceptorFactoryInterface>(
680           new HijackingInterceptorMakesAnotherCallFactory()));
681   // Add 7 dummy interceptors after hijacking interceptor
682   for (auto i = 0; i < 7; i++) {
683     creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
684         new DummyInterceptorFactory()));
685   }
686   auto channel = server_->experimental().InProcessChannelWithInterceptors(
687       args, std::move(creators));
688
689   MakeCall(channel);
690   // Make sure all interceptors were run once, since the hijacking interceptor
691   // makes an RPC on the intercepted channel
692   EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 12);
693 }
694
695 TEST_F(ClientInterceptorsEnd2endTest,
696        ClientInterceptorLoggingTestWithCallback) {
697   ChannelArguments args;
698   DummyInterceptor::Reset();
699   std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
700       creators;
701   creators.push_back(std::unique_ptr<LoggingInterceptorFactory>(
702       new LoggingInterceptorFactory()));
703   // Add 20 dummy interceptors
704   for (auto i = 0; i < 20; i++) {
705     creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
706         new DummyInterceptorFactory()));
707   }
708   auto channel = server_->experimental().InProcessChannelWithInterceptors(
709       args, std::move(creators));
710   MakeCallbackCall(channel);
711   // Make sure all 20 dummy interceptors were run
712   EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
713 }
714
715 TEST_F(ClientInterceptorsEnd2endTest,
716        ClientInterceptorFactoryAllowsNullptrReturn) {
717   ChannelArguments args;
718   DummyInterceptor::Reset();
719   std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
720       creators;
721   creators.push_back(std::unique_ptr<LoggingInterceptorFactory>(
722       new LoggingInterceptorFactory()));
723   // Add 20 dummy interceptors and 20 null interceptors
724   for (auto i = 0; i < 20; i++) {
725     creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
726         new DummyInterceptorFactory()));
727     creators.push_back(
728         std::unique_ptr<NullInterceptorFactory>(new NullInterceptorFactory()));
729   }
730   auto channel = server_->experimental().InProcessChannelWithInterceptors(
731       args, std::move(creators));
732   MakeCallbackCall(channel);
733   // Make sure all 20 dummy interceptors were run
734   EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
735 }
736
737 class ClientInterceptorsStreamingEnd2endTest : public ::testing::Test {
738  protected:
739   ClientInterceptorsStreamingEnd2endTest() {
740     int port = grpc_pick_unused_port_or_die();
741
742     ServerBuilder builder;
743     server_address_ = "localhost:" + std::to_string(port);
744     builder.AddListeningPort(server_address_, InsecureServerCredentials());
745     builder.RegisterService(&service_);
746     server_ = builder.BuildAndStart();
747   }
748
749   ~ClientInterceptorsStreamingEnd2endTest() { server_->Shutdown(); }
750
751   std::string server_address_;
752   EchoTestServiceStreamingImpl service_;
753   std::unique_ptr<Server> server_;
754 };
755
756 TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingTest) {
757   ChannelArguments args;
758   DummyInterceptor::Reset();
759   std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
760       creators;
761   creators.push_back(std::unique_ptr<LoggingInterceptorFactory>(
762       new LoggingInterceptorFactory()));
763   // Add 20 dummy interceptors
764   for (auto i = 0; i < 20; i++) {
765     creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
766         new DummyInterceptorFactory()));
767   }
768   auto channel = experimental::CreateCustomChannelWithInterceptors(
769       server_address_, InsecureChannelCredentials(), args, std::move(creators));
770   MakeClientStreamingCall(channel);
771   // Make sure all 20 dummy interceptors were run
772   EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
773 }
774
775 TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) {
776   ChannelArguments args;
777   DummyInterceptor::Reset();
778   std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
779       creators;
780   creators.push_back(std::unique_ptr<LoggingInterceptorFactory>(
781       new LoggingInterceptorFactory()));
782   // Add 20 dummy interceptors
783   for (auto i = 0; i < 20; i++) {
784     creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
785         new DummyInterceptorFactory()));
786   }
787   auto channel = experimental::CreateCustomChannelWithInterceptors(
788       server_address_, InsecureChannelCredentials(), args, std::move(creators));
789   MakeServerStreamingCall(channel);
790   // Make sure all 20 dummy interceptors were run
791   EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
792 }
793
794 TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) {
795   ChannelArguments args;
796   std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
797       creators;
798   creators.push_back(
799       std::unique_ptr<ClientStreamingRpcHijackingInterceptorFactory>(
800           new ClientStreamingRpcHijackingInterceptorFactory()));
801   auto channel = experimental::CreateCustomChannelWithInterceptors(
802       server_address_, InsecureChannelCredentials(), args, std::move(creators));
803
804   auto stub = grpc::testing::EchoTestService::NewStub(channel);
805   ClientContext ctx;
806   EchoRequest req;
807   EchoResponse resp;
808   req.mutable_param()->set_echo_metadata(true);
809   req.set_message("Hello");
810   string expected_resp = "";
811   auto writer = stub->RequestStream(&ctx, &resp);
812   for (int i = 0; i < 10; i++) {
813     EXPECT_TRUE(writer->Write(req));
814     expected_resp += "Hello";
815   }
816   // The interceptor will reject the 11th message
817   writer->Write(req);
818   Status s = writer->Finish();
819   EXPECT_EQ(s.ok(), false);
820   EXPECT_TRUE(ClientStreamingRpcHijackingInterceptor::GotFailedSend());
821 }
822
823 TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) {
824   ChannelArguments args;
825   DummyInterceptor::Reset();
826   std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
827       creators;
828   creators.push_back(
829       std::unique_ptr<ServerStreamingRpcHijackingInterceptorFactory>(
830           new ServerStreamingRpcHijackingInterceptorFactory()));
831   auto channel = experimental::CreateCustomChannelWithInterceptors(
832       server_address_, InsecureChannelCredentials(), args, std::move(creators));
833   MakeServerStreamingCall(channel);
834   EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage());
835 }
836
837 TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingHijackingTest) {
838   ChannelArguments args;
839   DummyInterceptor::Reset();
840   std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
841       creators;
842   creators.push_back(
843       std::unique_ptr<BidiStreamingRpcHijackingInterceptorFactory>(
844           new BidiStreamingRpcHijackingInterceptorFactory()));
845   auto channel = experimental::CreateCustomChannelWithInterceptors(
846       server_address_, InsecureChannelCredentials(), args, std::move(creators));
847   MakeBidiStreamingCall(channel);
848 }
849
850 TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) {
851   ChannelArguments args;
852   DummyInterceptor::Reset();
853   std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
854       creators;
855   creators.push_back(std::unique_ptr<LoggingInterceptorFactory>(
856       new LoggingInterceptorFactory()));
857   // Add 20 dummy interceptors
858   for (auto i = 0; i < 20; i++) {
859     creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
860         new DummyInterceptorFactory()));
861   }
862   auto channel = experimental::CreateCustomChannelWithInterceptors(
863       server_address_, InsecureChannelCredentials(), args, std::move(creators));
864   MakeBidiStreamingCall(channel);
865   // Make sure all 20 dummy interceptors were run
866   EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
867 }
868
869 class ClientGlobalInterceptorEnd2endTest : public ::testing::Test {
870  protected:
871   ClientGlobalInterceptorEnd2endTest() {
872     int port = grpc_pick_unused_port_or_die();
873
874     ServerBuilder builder;
875     server_address_ = "localhost:" + std::to_string(port);
876     builder.AddListeningPort(server_address_, InsecureServerCredentials());
877     builder.RegisterService(&service_);
878     server_ = builder.BuildAndStart();
879   }
880
881   ~ClientGlobalInterceptorEnd2endTest() { server_->Shutdown(); }
882
883   std::string server_address_;
884   TestServiceImpl service_;
885   std::unique_ptr<Server> server_;
886 };
887
888 TEST_F(ClientGlobalInterceptorEnd2endTest, DummyGlobalInterceptor) {
889   // We should ideally be registering a global interceptor only once per
890   // process, but for the purposes of testing, it should be fine to modify the
891   // registered global interceptor when there are no ongoing gRPC operations
892   DummyInterceptorFactory global_factory;
893   experimental::RegisterGlobalClientInterceptorFactory(&global_factory);
894   ChannelArguments args;
895   DummyInterceptor::Reset();
896   std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
897       creators;
898   // Add 20 dummy interceptors
899   creators.reserve(20);
900   for (auto i = 0; i < 20; i++) {
901     creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
902         new DummyInterceptorFactory()));
903   }
904   auto channel = experimental::CreateCustomChannelWithInterceptors(
905       server_address_, InsecureChannelCredentials(), args, std::move(creators));
906   MakeCall(channel);
907   // Make sure all 20 dummy interceptors were run with the global interceptor
908   EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 21);
909   experimental::TestOnlyResetGlobalClientInterceptorFactory();
910 }
911
912 TEST_F(ClientGlobalInterceptorEnd2endTest, LoggingGlobalInterceptor) {
913   // We should ideally be registering a global interceptor only once per
914   // process, but for the purposes of testing, it should be fine to modify the
915   // registered global interceptor when there are no ongoing gRPC operations
916   LoggingInterceptorFactory global_factory;
917   experimental::RegisterGlobalClientInterceptorFactory(&global_factory);
918   ChannelArguments args;
919   DummyInterceptor::Reset();
920   std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
921       creators;
922   // Add 20 dummy interceptors
923   creators.reserve(20);
924   for (auto i = 0; i < 20; i++) {
925     creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
926         new DummyInterceptorFactory()));
927   }
928   auto channel = experimental::CreateCustomChannelWithInterceptors(
929       server_address_, InsecureChannelCredentials(), args, std::move(creators));
930   MakeCall(channel);
931   // Make sure all 20 dummy interceptors were run
932   EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
933   experimental::TestOnlyResetGlobalClientInterceptorFactory();
934 }
935
936 TEST_F(ClientGlobalInterceptorEnd2endTest, HijackingGlobalInterceptor) {
937   // We should ideally be registering a global interceptor only once per
938   // process, but for the purposes of testing, it should be fine to modify the
939   // registered global interceptor when there are no ongoing gRPC operations
940   HijackingInterceptorFactory global_factory;
941   experimental::RegisterGlobalClientInterceptorFactory(&global_factory);
942   ChannelArguments args;
943   DummyInterceptor::Reset();
944   std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
945       creators;
946   // Add 20 dummy interceptors
947   creators.reserve(20);
948   for (auto i = 0; i < 20; i++) {
949     creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
950         new DummyInterceptorFactory()));
951   }
952   auto channel = experimental::CreateCustomChannelWithInterceptors(
953       server_address_, InsecureChannelCredentials(), args, std::move(creators));
954   MakeCall(channel);
955   // Make sure all 20 dummy interceptors were run
956   EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
957   experimental::TestOnlyResetGlobalClientInterceptorFactory();
958 }
959
960 }  // namespace
961 }  // namespace testing
962 }  // namespace grpc
963
964 int main(int argc, char** argv) {
965   grpc::testing::TestEnvironment env(argc, argv);
966   ::testing::InitGoogleTest(&argc, argv);
967   return RUN_ALL_TESTS();
968 }