Imported Upstream version 1.34.0
[platform/upstream/grpc.git] / test / cpp / end2end / filter_end2end_test.cc
1 /*
2  *
3  * Copyright 2016 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18
19 #include <memory>
20 #include <mutex>
21 #include <thread>
22
23 #include <grpc/grpc.h>
24 #include <grpc/support/time.h>
25 #include <grpcpp/channel.h>
26 #include <grpcpp/client_context.h>
27 #include <grpcpp/create_channel.h>
28 #include <grpcpp/generic/async_generic_service.h>
29 #include <grpcpp/generic/generic_stub.h>
30 #include <grpcpp/impl/codegen/proto_utils.h>
31 #include <grpcpp/server.h>
32 #include <grpcpp/server_builder.h>
33 #include <grpcpp/server_context.h>
34 #include <grpcpp/support/config.h>
35 #include <grpcpp/support/slice.h>
36
37 #include "absl/memory/memory.h"
38
39 #include "src/cpp/common/channel_filter.h"
40 #include "src/proto/grpc/testing/echo.grpc.pb.h"
41 #include "test/core/util/port.h"
42 #include "test/core/util/test_config.h"
43 #include "test/cpp/util/byte_buffer_proto_helper.h"
44
45 #include <gtest/gtest.h>
46
47 using grpc::testing::EchoRequest;
48 using grpc::testing::EchoResponse;
49
50 namespace grpc {
51 namespace testing {
52 namespace {
53
// Encodes the small integer `i` as an opaque completion-queue tag. The pointer
// is never dereferenced; it only has to compare equal for equal `i`.
void* tag(int i) { return reinterpret_cast<void*>(static_cast<intptr_t>(i)); }
55
56 void verify_ok(CompletionQueue* cq, int i, bool expect_ok) {
57   bool ok;
58   void* got_tag;
59   EXPECT_TRUE(cq->Next(&got_tag, &ok));
60   EXPECT_EQ(expect_ok, ok);
61   EXPECT_EQ(tag(i), got_tag);
62 }
63
namespace {

// Counters bumped by the test filter below and inspected by the tests.
// Every read and write happens while holding global_mu.
int global_num_connections = 0;
int global_num_calls = 0;
std::mutex global_mu;

void IncrementConnectionCounter() {
  std::lock_guard<std::mutex> guard(global_mu);
  global_num_connections += 1;
}

void ResetConnectionCounter() {
  std::lock_guard<std::mutex> guard(global_mu);
  global_num_connections = 0;
}

int GetConnectionCounterValue() {
  std::lock_guard<std::mutex> guard(global_mu);
  const int snapshot = global_num_connections;
  return snapshot;
}

void IncrementCallCounter() {
  std::lock_guard<std::mutex> guard(global_mu);
  global_num_calls += 1;
}

void ResetCallCounter() {
  std::lock_guard<std::mutex> guard(global_mu);
  global_num_calls = 0;
}

int GetCallCounterValue() {
  std::lock_guard<std::mutex> guard(global_mu);
  const int snapshot = global_num_calls;
  return snapshot;
}

}  // namespace
101
102 class ChannelDataImpl : public ChannelData {
103  public:
104   grpc_error* Init(grpc_channel_element* /*elem*/,
105                    grpc_channel_element_args* /*args*/) override {
106     IncrementConnectionCounter();
107     return GRPC_ERROR_NONE;
108   }
109 };
110
111 class CallDataImpl : public CallData {
112  public:
113   void StartTransportStreamOpBatch(grpc_call_element* elem,
114                                    TransportStreamOpBatch* op) override {
115     // Incrementing the counter could be done from Init(), but we want
116     // to test that the individual methods are actually called correctly.
117     if (op->recv_initial_metadata() != nullptr) IncrementCallCounter();
118     grpc_call_next_op(elem, op->op());
119   }
120 };
121
122 class FilterEnd2endTest : public ::testing::Test {
123  protected:
124   FilterEnd2endTest() : server_host_("localhost") {}
125
126   static void SetUpTestCase() {
127     // Workaround for
128     // https://github.com/google/google-toolbox-for-mac/issues/242
129     static bool setup_done = false;
130     if (!setup_done) {
131       setup_done = true;
132       grpc::RegisterChannelFilter<ChannelDataImpl, CallDataImpl>(
133           "test-filter", GRPC_SERVER_CHANNEL, INT_MAX, nullptr);
134     }
135   }
136
137   void SetUp() override {
138     int port = grpc_pick_unused_port_or_die();
139     server_address_ << server_host_ << ":" << port;
140     // Setup server
141     ServerBuilder builder;
142     builder.AddListeningPort(server_address_.str(),
143                              InsecureServerCredentials());
144     builder.RegisterAsyncGenericService(&generic_service_);
145     srv_cq_ = builder.AddCompletionQueue();
146     server_ = builder.BuildAndStart();
147   }
148
149   void TearDown() override {
150     server_->Shutdown();
151     void* ignored_tag;
152     bool ignored_ok;
153     cli_cq_.Shutdown();
154     srv_cq_->Shutdown();
155     while (cli_cq_.Next(&ignored_tag, &ignored_ok)) {
156     }
157     while (srv_cq_->Next(&ignored_tag, &ignored_ok)) {
158     }
159   }
160
161   void ResetStub() {
162     std::shared_ptr<Channel> channel = grpc::CreateChannel(
163         server_address_.str(), InsecureChannelCredentials());
164     generic_stub_ = absl::make_unique<GenericStub>(channel);
165     ResetConnectionCounter();
166     ResetCallCounter();
167   }
168
169   void server_ok(int i) { verify_ok(srv_cq_.get(), i, true); }
170   void client_ok(int i) { verify_ok(&cli_cq_, i, true); }
171   void server_fail(int i) { verify_ok(srv_cq_.get(), i, false); }
172   void client_fail(int i) { verify_ok(&cli_cq_, i, false); }
173
174   void SendRpc(int num_rpcs) {
175     const std::string kMethodName("/grpc.cpp.test.util.EchoTestService/Echo");
176     for (int i = 0; i < num_rpcs; i++) {
177       EchoRequest send_request;
178       EchoRequest recv_request;
179       EchoResponse send_response;
180       EchoResponse recv_response;
181       Status recv_status;
182
183       ClientContext cli_ctx;
184       GenericServerContext srv_ctx;
185       GenericServerAsyncReaderWriter stream(&srv_ctx);
186
187       // The string needs to be long enough to test heap-based slice.
188       send_request.set_message("Hello world. Hello world. Hello world.");
189       std::thread request_call([this]() { server_ok(4); });
190       std::unique_ptr<GenericClientAsyncReaderWriter> call =
191           generic_stub_->PrepareCall(&cli_ctx, kMethodName, &cli_cq_);
192       call->StartCall(tag(1));
193       client_ok(1);
194       std::unique_ptr<ByteBuffer> send_buffer =
195           SerializeToByteBuffer(&send_request);
196       call->Write(*send_buffer, tag(2));
197       // Send ByteBuffer can be destroyed after calling Write.
198       send_buffer.reset();
199       client_ok(2);
200       call->WritesDone(tag(3));
201       client_ok(3);
202
203       generic_service_.RequestCall(&srv_ctx, &stream, srv_cq_.get(),
204                                    srv_cq_.get(), tag(4));
205
206       request_call.join();
207       EXPECT_EQ(server_host_, srv_ctx.host().substr(0, server_host_.length()));
208       EXPECT_EQ(kMethodName, srv_ctx.method());
209       ByteBuffer recv_buffer;
210       stream.Read(&recv_buffer, tag(5));
211       server_ok(5);
212       EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request));
213       EXPECT_EQ(send_request.message(), recv_request.message());
214
215       send_response.set_message(recv_request.message());
216       send_buffer = SerializeToByteBuffer(&send_response);
217       stream.Write(*send_buffer, tag(6));
218       send_buffer.reset();
219       server_ok(6);
220
221       stream.Finish(Status::OK, tag(7));
222       server_ok(7);
223
224       recv_buffer.Clear();
225       call->Read(&recv_buffer, tag(8));
226       client_ok(8);
227       EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_response));
228
229       call->Finish(&recv_status, tag(9));
230       client_ok(9);
231
232       EXPECT_EQ(send_response.message(), recv_response.message());
233       EXPECT_TRUE(recv_status.ok());
234     }
235   }
236
237   CompletionQueue cli_cq_;
238   std::unique_ptr<ServerCompletionQueue> srv_cq_;
239   std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
240   std::unique_ptr<grpc::GenericStub> generic_stub_;
241   std::unique_ptr<Server> server_;
242   AsyncGenericService generic_service_;
243   const std::string server_host_;
244   std::ostringstream server_address_;
245 };
246
// One generic RPC through the filtered server should be seen by the filter as
// exactly one connection (ChannelDataImpl::Init) and one call (a batch with
// recv_initial_metadata reaching CallDataImpl).
TEST_F(FilterEnd2endTest, SimpleRpc) {
  // ResetStub() zeroes both counters, so nothing is counted yet.
  ResetStub();
  EXPECT_EQ(0, GetConnectionCounterValue());
  EXPECT_EQ(0, GetCallCounterValue());
  SendRpc(1);
  EXPECT_EQ(1, GetConnectionCounterValue());
  EXPECT_EQ(1, GetCallCounterValue());
}
255
// Ten back-to-back RPCs on the same stub: the filter is expected to count a
// single connection but every one of the ten calls.
TEST_F(FilterEnd2endTest, SequentialRpcs) {
  // ResetStub() zeroes both counters, so nothing is counted yet.
  ResetStub();
  EXPECT_EQ(0, GetConnectionCounterValue());
  EXPECT_EQ(0, GetCallCounterValue());
  SendRpc(10);
  EXPECT_EQ(1, GetConnectionCounterValue());
  EXPECT_EQ(10, GetCallCounterValue());
}
264
265 // One ping, one pong.
266 TEST_F(FilterEnd2endTest, SimpleBidiStreaming) {
267   ResetStub();
268   EXPECT_EQ(0, GetConnectionCounterValue());
269   EXPECT_EQ(0, GetCallCounterValue());
270
271   const std::string kMethodName(
272       "/grpc.cpp.test.util.EchoTestService/BidiStream");
273   EchoRequest send_request;
274   EchoRequest recv_request;
275   EchoResponse send_response;
276   EchoResponse recv_response;
277   Status recv_status;
278   ClientContext cli_ctx;
279   GenericServerContext srv_ctx;
280   GenericServerAsyncReaderWriter srv_stream(&srv_ctx);
281
282   cli_ctx.set_compression_algorithm(GRPC_COMPRESS_GZIP);
283   send_request.set_message("Hello");
284   std::thread request_call([this]() { server_ok(2); });
285   std::unique_ptr<GenericClientAsyncReaderWriter> cli_stream =
286       generic_stub_->PrepareCall(&cli_ctx, kMethodName, &cli_cq_);
287   cli_stream->StartCall(tag(1));
288   client_ok(1);
289
290   generic_service_.RequestCall(&srv_ctx, &srv_stream, srv_cq_.get(),
291                                srv_cq_.get(), tag(2));
292
293   request_call.join();
294   EXPECT_EQ(server_host_, srv_ctx.host().substr(0, server_host_.length()));
295   EXPECT_EQ(kMethodName, srv_ctx.method());
296
297   std::unique_ptr<ByteBuffer> send_buffer =
298       SerializeToByteBuffer(&send_request);
299   cli_stream->Write(*send_buffer, tag(3));
300   send_buffer.reset();
301   client_ok(3);
302
303   ByteBuffer recv_buffer;
304   srv_stream.Read(&recv_buffer, tag(4));
305   server_ok(4);
306   EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request));
307   EXPECT_EQ(send_request.message(), recv_request.message());
308
309   send_response.set_message(recv_request.message());
310   send_buffer = SerializeToByteBuffer(&send_response);
311   srv_stream.Write(*send_buffer, tag(5));
312   send_buffer.reset();
313   server_ok(5);
314
315   cli_stream->Read(&recv_buffer, tag(6));
316   client_ok(6);
317   EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_response));
318   EXPECT_EQ(send_response.message(), recv_response.message());
319
320   cli_stream->WritesDone(tag(7));
321   client_ok(7);
322
323   srv_stream.Read(&recv_buffer, tag(8));
324   server_fail(8);
325
326   srv_stream.Finish(Status::OK, tag(9));
327   server_ok(9);
328
329   cli_stream->Finish(&recv_status, tag(10));
330   client_ok(10);
331
332   EXPECT_EQ(send_response.message(), recv_response.message());
333   EXPECT_TRUE(recv_status.ok());
334
335   EXPECT_EQ(1, GetCallCounterValue());
336   EXPECT_EQ(1, GetConnectionCounterValue());
337 }
338
339 }  // namespace
340 }  // namespace testing
341 }  // namespace grpc
342
343 int main(int argc, char** argv) {
344   grpc::testing::TestEnvironment env(argc, argv);
345   ::testing::InitGoogleTest(&argc, argv);
346   return RUN_ALL_TESTS();
347 }