Imported Upstream version 1.32.0
[platform/upstream/grpc.git] / test / cpp / end2end / filter_end2end_test.cc
1 /*
2  *
3  * Copyright 2016 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18
19 #include <memory>
20 #include <mutex>
21 #include <thread>
22
23 #include <grpc/grpc.h>
24 #include <grpc/support/time.h>
25 #include <grpcpp/channel.h>
26 #include <grpcpp/client_context.h>
27 #include <grpcpp/create_channel.h>
28 #include <grpcpp/generic/async_generic_service.h>
29 #include <grpcpp/generic/generic_stub.h>
30 #include <grpcpp/impl/codegen/proto_utils.h>
31 #include <grpcpp/server.h>
32 #include <grpcpp/server_builder.h>
33 #include <grpcpp/server_context.h>
34 #include <grpcpp/support/config.h>
35 #include <grpcpp/support/slice.h>
36
37 #include "src/cpp/common/channel_filter.h"
38 #include "src/proto/grpc/testing/echo.grpc.pb.h"
39 #include "test/core/util/port.h"
40 #include "test/core/util/test_config.h"
41 #include "test/cpp/util/byte_buffer_proto_helper.h"
42
43 #include <gtest/gtest.h>
44
45 using grpc::testing::EchoRequest;
46 using grpc::testing::EchoResponse;
47 using std::chrono::system_clock;
48
49 namespace grpc {
50 namespace testing {
51 namespace {
52
// Encodes the small integer |i| as an opaque completion-queue tag pointer.
// A named cast is used instead of a C-style cast so the int -> pointer
// conversion is explicit and greppable.
void* tag(int i) { return reinterpret_cast<void*>(static_cast<intptr_t>(i)); }
54
55 void verify_ok(CompletionQueue* cq, int i, bool expect_ok) {
56   bool ok;
57   void* got_tag;
58   EXPECT_TRUE(cq->Next(&got_tag, &ok));
59   EXPECT_EQ(expect_ok, ok);
60   EXPECT_EQ(tag(i), got_tag);
61 }
62
namespace {

// Tallies bumped by the test filter (via the helpers below) and read back by
// the tests. Every read and write is serialized through global_mu so the
// counters can safely be touched from more than one thread.
int global_num_connections = 0;
int global_num_calls = 0;
std::mutex global_mu;

// Adds one to the channel/connection tally.
void IncrementConnectionCounter() {
  std::lock_guard<std::mutex> hold(global_mu);
  global_num_connections += 1;
}

// Sets the channel/connection tally back to zero.
void ResetConnectionCounter() {
  std::lock_guard<std::mutex> hold(global_mu);
  global_num_connections = 0;
}

// Returns the current channel/connection tally.
int GetConnectionCounterValue() {
  std::lock_guard<std::mutex> hold(global_mu);
  return global_num_connections;
}

// Adds one to the call tally.
void IncrementCallCounter() {
  std::lock_guard<std::mutex> hold(global_mu);
  global_num_calls += 1;
}

// Sets the call tally back to zero.
void ResetCallCounter() {
  std::lock_guard<std::mutex> hold(global_mu);
  global_num_calls = 0;
}

// Returns the current call tally.
int GetCallCounterValue() {
  std::lock_guard<std::mutex> hold(global_mu);
  return global_num_calls;
}

}  // namespace
100
101 class ChannelDataImpl : public ChannelData {
102  public:
103   grpc_error* Init(grpc_channel_element* /*elem*/,
104                    grpc_channel_element_args* /*args*/) {
105     IncrementConnectionCounter();
106     return GRPC_ERROR_NONE;
107   }
108 };
109
110 class CallDataImpl : public CallData {
111  public:
112   void StartTransportStreamOpBatch(grpc_call_element* elem,
113                                    TransportStreamOpBatch* op) override {
114     // Incrementing the counter could be done from Init(), but we want
115     // to test that the individual methods are actually called correctly.
116     if (op->recv_initial_metadata() != nullptr) IncrementCallCounter();
117     grpc_call_next_op(elem, op->op());
118   }
119 };
120
121 class FilterEnd2endTest : public ::testing::Test {
122  protected:
123   FilterEnd2endTest() : server_host_("localhost") {}
124
125   static void SetUpTestCase() {
126     // Workaround for
127     // https://github.com/google/google-toolbox-for-mac/issues/242
128     static bool setup_done = false;
129     if (!setup_done) {
130       setup_done = true;
131       grpc::RegisterChannelFilter<ChannelDataImpl, CallDataImpl>(
132           "test-filter", GRPC_SERVER_CHANNEL, INT_MAX, nullptr);
133     }
134   }
135
136   void SetUp() override {
137     int port = grpc_pick_unused_port_or_die();
138     server_address_ << server_host_ << ":" << port;
139     // Setup server
140     ServerBuilder builder;
141     builder.AddListeningPort(server_address_.str(),
142                              InsecureServerCredentials());
143     builder.RegisterAsyncGenericService(&generic_service_);
144     srv_cq_ = builder.AddCompletionQueue();
145     server_ = builder.BuildAndStart();
146   }
147
148   void TearDown() override {
149     server_->Shutdown();
150     void* ignored_tag;
151     bool ignored_ok;
152     cli_cq_.Shutdown();
153     srv_cq_->Shutdown();
154     while (cli_cq_.Next(&ignored_tag, &ignored_ok))
155       ;
156     while (srv_cq_->Next(&ignored_tag, &ignored_ok))
157       ;
158   }
159
160   void ResetStub() {
161     std::shared_ptr<Channel> channel = grpc::CreateChannel(
162         server_address_.str(), InsecureChannelCredentials());
163     generic_stub_.reset(new GenericStub(channel));
164     ResetConnectionCounter();
165     ResetCallCounter();
166   }
167
168   void server_ok(int i) { verify_ok(srv_cq_.get(), i, true); }
169   void client_ok(int i) { verify_ok(&cli_cq_, i, true); }
170   void server_fail(int i) { verify_ok(srv_cq_.get(), i, false); }
171   void client_fail(int i) { verify_ok(&cli_cq_, i, false); }
172
173   void SendRpc(int num_rpcs) {
174     const std::string kMethodName("/grpc.cpp.test.util.EchoTestService/Echo");
175     for (int i = 0; i < num_rpcs; i++) {
176       EchoRequest send_request;
177       EchoRequest recv_request;
178       EchoResponse send_response;
179       EchoResponse recv_response;
180       Status recv_status;
181
182       ClientContext cli_ctx;
183       GenericServerContext srv_ctx;
184       GenericServerAsyncReaderWriter stream(&srv_ctx);
185
186       // The string needs to be long enough to test heap-based slice.
187       send_request.set_message("Hello world. Hello world. Hello world.");
188       std::thread request_call([this]() { server_ok(4); });
189       std::unique_ptr<GenericClientAsyncReaderWriter> call =
190           generic_stub_->PrepareCall(&cli_ctx, kMethodName, &cli_cq_);
191       call->StartCall(tag(1));
192       client_ok(1);
193       std::unique_ptr<ByteBuffer> send_buffer =
194           SerializeToByteBuffer(&send_request);
195       call->Write(*send_buffer, tag(2));
196       // Send ByteBuffer can be destroyed after calling Write.
197       send_buffer.reset();
198       client_ok(2);
199       call->WritesDone(tag(3));
200       client_ok(3);
201
202       generic_service_.RequestCall(&srv_ctx, &stream, srv_cq_.get(),
203                                    srv_cq_.get(), tag(4));
204
205       request_call.join();
206       EXPECT_EQ(server_host_, srv_ctx.host().substr(0, server_host_.length()));
207       EXPECT_EQ(kMethodName, srv_ctx.method());
208       ByteBuffer recv_buffer;
209       stream.Read(&recv_buffer, tag(5));
210       server_ok(5);
211       EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request));
212       EXPECT_EQ(send_request.message(), recv_request.message());
213
214       send_response.set_message(recv_request.message());
215       send_buffer = SerializeToByteBuffer(&send_response);
216       stream.Write(*send_buffer, tag(6));
217       send_buffer.reset();
218       server_ok(6);
219
220       stream.Finish(Status::OK, tag(7));
221       server_ok(7);
222
223       recv_buffer.Clear();
224       call->Read(&recv_buffer, tag(8));
225       client_ok(8);
226       EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_response));
227
228       call->Finish(&recv_status, tag(9));
229       client_ok(9);
230
231       EXPECT_EQ(send_response.message(), recv_response.message());
232       EXPECT_TRUE(recv_status.ok());
233     }
234   }
235
236   CompletionQueue cli_cq_;
237   std::unique_ptr<ServerCompletionQueue> srv_cq_;
238   std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
239   std::unique_ptr<grpc::GenericStub> generic_stub_;
240   std::unique_ptr<Server> server_;
241   AsyncGenericService generic_service_;
242   const std::string server_host_;
243   std::ostringstream server_address_;
244 };
245
246 TEST_F(FilterEnd2endTest, SimpleRpc) {
247   ResetStub();
248   EXPECT_EQ(0, GetConnectionCounterValue());
249   EXPECT_EQ(0, GetCallCounterValue());
250   SendRpc(1);
251   EXPECT_EQ(1, GetConnectionCounterValue());
252   EXPECT_EQ(1, GetCallCounterValue());
253 }
254
255 TEST_F(FilterEnd2endTest, SequentialRpcs) {
256   ResetStub();
257   EXPECT_EQ(0, GetConnectionCounterValue());
258   EXPECT_EQ(0, GetCallCounterValue());
259   SendRpc(10);
260   EXPECT_EQ(1, GetConnectionCounterValue());
261   EXPECT_EQ(10, GetCallCounterValue());
262 }
263
264 // One ping, one pong.
265 TEST_F(FilterEnd2endTest, SimpleBidiStreaming) {
266   ResetStub();
267   EXPECT_EQ(0, GetConnectionCounterValue());
268   EXPECT_EQ(0, GetCallCounterValue());
269
270   const std::string kMethodName(
271       "/grpc.cpp.test.util.EchoTestService/BidiStream");
272   EchoRequest send_request;
273   EchoRequest recv_request;
274   EchoResponse send_response;
275   EchoResponse recv_response;
276   Status recv_status;
277   ClientContext cli_ctx;
278   GenericServerContext srv_ctx;
279   GenericServerAsyncReaderWriter srv_stream(&srv_ctx);
280
281   cli_ctx.set_compression_algorithm(GRPC_COMPRESS_GZIP);
282   send_request.set_message("Hello");
283   std::thread request_call([this]() { server_ok(2); });
284   std::unique_ptr<GenericClientAsyncReaderWriter> cli_stream =
285       generic_stub_->PrepareCall(&cli_ctx, kMethodName, &cli_cq_);
286   cli_stream->StartCall(tag(1));
287   client_ok(1);
288
289   generic_service_.RequestCall(&srv_ctx, &srv_stream, srv_cq_.get(),
290                                srv_cq_.get(), tag(2));
291
292   request_call.join();
293   EXPECT_EQ(server_host_, srv_ctx.host().substr(0, server_host_.length()));
294   EXPECT_EQ(kMethodName, srv_ctx.method());
295
296   std::unique_ptr<ByteBuffer> send_buffer =
297       SerializeToByteBuffer(&send_request);
298   cli_stream->Write(*send_buffer, tag(3));
299   send_buffer.reset();
300   client_ok(3);
301
302   ByteBuffer recv_buffer;
303   srv_stream.Read(&recv_buffer, tag(4));
304   server_ok(4);
305   EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request));
306   EXPECT_EQ(send_request.message(), recv_request.message());
307
308   send_response.set_message(recv_request.message());
309   send_buffer = SerializeToByteBuffer(&send_response);
310   srv_stream.Write(*send_buffer, tag(5));
311   send_buffer.reset();
312   server_ok(5);
313
314   cli_stream->Read(&recv_buffer, tag(6));
315   client_ok(6);
316   EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_response));
317   EXPECT_EQ(send_response.message(), recv_response.message());
318
319   cli_stream->WritesDone(tag(7));
320   client_ok(7);
321
322   srv_stream.Read(&recv_buffer, tag(8));
323   server_fail(8);
324
325   srv_stream.Finish(Status::OK, tag(9));
326   server_ok(9);
327
328   cli_stream->Finish(&recv_status, tag(10));
329   client_ok(10);
330
331   EXPECT_EQ(send_response.message(), recv_response.message());
332   EXPECT_TRUE(recv_status.ok());
333
334   EXPECT_EQ(1, GetCallCounterValue());
335   EXPECT_EQ(1, GetConnectionCounterValue());
336 }
337
338 }  // namespace
339 }  // namespace testing
340 }  // namespace grpc
341
342 int main(int argc, char** argv) {
343   grpc::testing::TestEnvironment env(argc, argv);
344   ::testing::InitGoogleTest(&argc, argv);
345   return RUN_ALL_TESTS();
346 }