3 * Copyright 2016 gRPC authors.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
#include <cstdint>
23 #include <gtest/gtest.h>
25 #include "absl/memory/memory.h"
27 #include <grpc/grpc.h>
28 #include <grpc/support/time.h>
29 #include <grpcpp/channel.h>
30 #include <grpcpp/client_context.h>
31 #include <grpcpp/create_channel.h>
32 #include <grpcpp/generic/async_generic_service.h>
33 #include <grpcpp/generic/generic_stub.h>
34 #include <grpcpp/impl/codegen/proto_utils.h>
35 #include <grpcpp/server.h>
36 #include <grpcpp/server_builder.h>
37 #include <grpcpp/server_context.h>
38 #include <grpcpp/support/config.h>
39 #include <grpcpp/support/slice.h>
41 #include "src/cpp/common/channel_filter.h"
42 #include "src/proto/grpc/testing/echo.grpc.pb.h"
43 #include "test/core/util/port.h"
44 #include "test/core/util/test_config.h"
45 #include "test/cpp/util/byte_buffer_proto_helper.h"
47 using grpc::testing::EchoRequest;
48 using grpc::testing::EchoResponse;
// Turns a small integer into the opaque pointer value used as a
// completion-queue tag by the tests in this file (see verify_ok, which
// compares the tag it receives against tag(i)).
//
// The integer is widened to std::intptr_t before the pointer cast so that the
// reinterpret_cast converts between same-sized types. Casting a 32-bit int
// straight to a 64-bit pointer is a size-mismatched conversion that compilers
// warn about; the resulting pointer value is the same.
//
// @param i  Integer identifying the operation.
// @return   Pointer whose integer value equals `i`; never dereferenced.
void* tag(int i) {
  return reinterpret_cast<void*>(static_cast<std::intptr_t>(i));
}
// Takes the next event from `cq` and checks it with gtest expectations:
// Next() must return true, the event's ok flag must equal `expect_ok`, and the
// event's tag must equal tag(i).
// NOTE(review): the declarations of `got_tag` and `ok` are not visible in
// this listing.
56 void verify_ok(CompletionQueue* cq, int i, bool expect_ok) {
59 EXPECT_TRUE(cq->Next(&got_tag, &ok));
60 EXPECT_EQ(expect_ok, ok);
61 EXPECT_EQ(tag(i), got_tag);
// Counters reported by the test filter. The helper functions in this file
// read and write them while holding global_mu.
// Incremented from ChannelDataImpl::Init() via IncrementConnectionCounter().
66 int global_num_connections = 0;
// Read back by the tests via GetCallCounterValue().
67 int global_num_calls = 0;
// Adds one to global_num_connections while holding global_mu.
70 void IncrementConnectionCounter() {
71 std::unique_lock<std::mutex> lock(global_mu);
72 ++global_num_connections;
// Sets global_num_connections back to zero while holding global_mu.
75 void ResetConnectionCounter() {
76 std::unique_lock<std::mutex> lock(global_mu);
77 global_num_connections = 0;
// Returns the current value of global_num_connections, read while holding
// global_mu.
80 int GetConnectionCounterValue() {
81 std::unique_lock<std::mutex> lock(global_mu);
82 return global_num_connections;
// Acquires global_mu before touching the call counter.
// NOTE(review): the statement that updates the counter is not visible in this
// listing; presumably it increments global_num_calls -- confirm.
85 void IncrementCallCounter() {
86 std::unique_lock<std::mutex> lock(global_mu);
// Acquires global_mu before touching the call counter.
// NOTE(review): the statement that updates the counter is not visible in this
// listing; presumably it sets global_num_calls to zero -- confirm.
90 void ResetCallCounter() {
91 std::unique_lock<std::mutex> lock(global_mu);
// Returns the current value of global_num_calls, read while holding
// global_mu.
95 int GetCallCounterValue() {
96 std::unique_lock<std::mutex> lock(global_mu);
97 return global_num_calls;
// Channel-level half of the test filter.
// Init() bumps the connection counter and reports success, so the tests can
// observe how many times it ran (they expect exactly one after issuing RPCs).
102 class ChannelDataImpl : public ChannelData {
// Both arguments are unused; the only effect is the counter increment.
104 grpc_error_handle Init(grpc_channel_element* /*elem*/,
105 grpc_channel_element_args* /*args*/) override {
106 IncrementConnectionCounter();
107 return GRPC_ERROR_NONE;
// Call-level half of the test filter.
// Each batch that carries a recv_initial_metadata op bumps the call counter;
// every batch is then forwarded unchanged to the next element.
111 class CallDataImpl : public CallData {
113 void StartTransportStreamOpBatch(grpc_call_element* elem,
114 TransportStreamOpBatch* op) override {
115 // Incrementing the counter could be done from Init(), but we want
116 // to test that the individual methods are actually called correctly.
117 if (op->recv_initial_metadata() != nullptr) IncrementCallCounter();
// Pass the underlying batch down the stack so the call proceeds normally.
118 grpc_call_next_op(elem, op->op());
// Test fixture: starts an in-process server exposing an async generic service
// on localhost, with the counting test filter registered, and provides
// helpers to drive echo RPCs through a GenericStub.
// NOTE(review): this listing omits a number of short lines (closing braces,
// some declarations and calls); comments describe only what is visible.
122 class FilterEnd2endTest : public ::testing::Test {
// The server host is fixed to "localhost".
124 FilterEnd2endTest() : server_host_("localhost") {}
// Registers the test filter (ChannelDataImpl + CallDataImpl) under the name
// "test-filter" for GRPC_SERVER_CHANNEL.
// NOTE(review): INT_MAX is presumably the filter priority and nullptr an
// optional inclusion predicate -- confirm against RegisterChannelFilter.
126 static void SetUpTestCase() {
128 // https://github.com/google/google-toolbox-for-mac/issues/242
// NOTE(review): presumably guards the registration so it runs only once; the
// guarding condition is not visible in this listing.
129 static bool setup_done = false;
132 grpc::RegisterChannelFilter<ChannelDataImpl, CallDataImpl>(
133 "test-filter", GRPC_SERVER_CHANNEL, INT_MAX, nullptr);
// Picks an unused port, records "<host>:<port>" in server_address_, and
// builds and starts a server with insecure credentials, the async generic
// service, and one server completion queue.
137 void SetUp() override {
138 int port = grpc_pick_unused_port_or_die();
139 server_address_ << server_host_ << ":" << port;
141 ServerBuilder builder;
142 builder.AddListeningPort(server_address_.str(),
143 InsecureServerCredentials());
144 builder.RegisterAsyncGenericService(&generic_service_);
145 srv_cq_ = builder.AddCompletionQueue();
146 server_ = builder.BuildAndStart();
// Drains the client and server completion queues: Next() is called on each
// until it returns false.
// NOTE(review): any shutdown calls preceding the drains are not visible here.
149 void TearDown() override {
155 while (cli_cq_.Next(&ignored_tag, &ignored_ok)) {
157 while (srv_cq_->Next(&ignored_tag, &ignored_ok)) {
// Body of a stub-reset helper (its signature line is not visible in this
// listing): opens an insecure channel to server_address_, wraps it in a new
// GenericStub, and zeroes the connection counter.
162 std::shared_ptr<Channel> channel = grpc::CreateChannel(
163 server_address_.str(), InsecureChannelCredentials());
164 generic_stub_ = absl::make_unique<GenericStub>(channel);
165 ResetConnectionCounter();
// Shorthands for verify_ok(): pick the server or client completion queue and
// whether the next event is expected to have ok == true or ok == false.
169 void server_ok(int i) { verify_ok(srv_cq_.get(), i, true); }
170 void client_ok(int i) { verify_ok(&cli_cq_, i, true); }
171 void server_fail(int i) { verify_ok(srv_cq_.get(), i, false); }
172 void client_fail(int i) { verify_ok(&cli_cq_, i, false); }
// Runs `num_rpcs` echo exchanges one after another over the generic API.
// Each iteration uses fresh contexts and messages; operations are identified
// by tags 1..9.
// NOTE(review): the completion-queue waits between steps and the join of
// `request_call` are not visible in this listing.
174 void SendRpc(int num_rpcs) {
175 const std::string kMethodName("/grpc.cpp.test.util.EchoTestService/Echo");
176 for (int i = 0; i < num_rpcs; i++) {
177 EchoRequest send_request;
178 EchoRequest recv_request;
179 EchoResponse send_response;
180 EchoResponse recv_response;
183 ClientContext cli_ctx;
184 GenericServerContext srv_ctx;
185 GenericServerAsyncReaderWriter stream(&srv_ctx);
187 // The string needs to be long enough to test heap-based slice.
188 send_request.set_message("Hello world. Hello world. Hello world.");
// Background thread that checks the server queue for tag 4 (the RequestCall
// below) while this thread drives the client side.
189 std::thread request_call([this]() { server_ok(4); });
// Client: start the call (tag 1), write the serialized request (tag 2), then
// signal end of writes (tag 3).
190 std::unique_ptr<GenericClientAsyncReaderWriter> call =
191 generic_stub_->PrepareCall(&cli_ctx, kMethodName, &cli_cq_);
192 call->StartCall(tag(1));
194 std::unique_ptr<ByteBuffer> send_buffer =
195 SerializeToByteBuffer(&send_request);
196 call->Write(*send_buffer, tag(2));
197 // Send ByteBuffer can be destroyed after calling Write.
200 call->WritesDone(tag(3));
// Server: ask for the incoming call (tag 4); srv_cq_ serves as both the call
// queue and the notification queue.
203 generic_service_.RequestCall(&srv_ctx, &stream, srv_cq_.get(),
204 srv_cq_.get(), tag(4));
// The host seen by the server must start with server_host_, and the method
// must match what the client requested.
207 EXPECT_EQ(server_host_, srv_ctx.host().substr(0, server_host_.length()));
208 EXPECT_EQ(kMethodName, srv_ctx.method());
// Server: read the request (tag 5) and check it round-tripped intact.
209 ByteBuffer recv_buffer;
210 stream.Read(&recv_buffer, tag(5));
212 EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request));
213 EXPECT_EQ(send_request.message(), recv_request.message());
// Server: echo the message back (tag 6) and finish with OK (tag 7).
215 send_response.set_message(recv_request.message());
216 send_buffer = SerializeToByteBuffer(&send_response);
217 stream.Write(*send_buffer, tag(6));
221 stream.Finish(Status::OK, tag(7));
// Client: read the response (tag 8), collect the final status (tag 9), and
// check both.
225 call->Read(&recv_buffer, tag(8));
227 EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_response));
229 call->Finish(&recv_status, tag(9));
232 EXPECT_EQ(send_response.message(), recv_response.message());
233 EXPECT_TRUE(recv_status.ok());
// Completion queue used for all client-side operations.
237 CompletionQueue cli_cq_;
// Server completion queue obtained from the ServerBuilder in SetUp().
238 std::unique_ptr<ServerCompletionQueue> srv_cq_;
// NOTE(review): not referenced anywhere in the visible code.
239 std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
// Generic client stub, created by the stub-reset helper.
240 std::unique_ptr<grpc::GenericStub> generic_stub_;
// Server built and started in SetUp().
241 std::unique_ptr<Server> server_;
// Generic service registered with the server; used to request incoming calls.
242 AsyncGenericService generic_service_;
// Host part of the server address ("localhost").
243 const std::string server_host_;
// Accumulates "<host>:<port>"; read back with str().
244 std::ostringstream server_address_;
// Expects both counters to read zero at first and one afterwards.
// NOTE(review): the statements around the two groups of expectations are not
// visible in this listing (presumably a stub reset and a single SendRpc).
247 TEST_F(FilterEnd2endTest, SimpleRpc) {
249 EXPECT_EQ(0, GetConnectionCounterValue());
250 EXPECT_EQ(0, GetCallCounterValue());
252 EXPECT_EQ(1, GetConnectionCounterValue());
253 EXPECT_EQ(1, GetCallCounterValue());
// Expects both counters to read zero at first; afterwards the connection
// counter must be one and the call counter ten.
// NOTE(review): the statements around the two groups of expectations are not
// visible in this listing (presumably a stub reset and SendRpc for ten RPCs).
256 TEST_F(FilterEnd2endTest, SequentialRpcs) {
258 EXPECT_EQ(0, GetConnectionCounterValue());
259 EXPECT_EQ(0, GetCallCounterValue());
261 EXPECT_EQ(1, GetConnectionCounterValue());
262 EXPECT_EQ(10, GetCallCounterValue());
265 // One ping, one pong.
// Drives a single request/response exchange on the BidiStream method through
// the generic API, then checks that the filter counted one call and one
// connection. Operations are identified by tags 1..10.
// NOTE(review): the completion-queue waits between steps and the join of
// `request_call` are not visible in this listing.
266 TEST_F(FilterEnd2endTest, SimpleBidiStreaming) {
// Both counters must read zero before any traffic is sent.
268 EXPECT_EQ(0, GetConnectionCounterValue());
269 EXPECT_EQ(0, GetCallCounterValue());
271 const std::string kMethodName(
272 "/grpc.cpp.test.util.EchoTestService/BidiStream");
273 EchoRequest send_request;
274 EchoRequest recv_request;
275 EchoResponse send_response;
276 EchoResponse recv_response;
278 ClientContext cli_ctx;
279 GenericServerContext srv_ctx;
280 GenericServerAsyncReaderWriter srv_stream(&srv_ctx);
// The client asks for gzip compression on this call.
282 cli_ctx.set_compression_algorithm(GRPC_COMPRESS_GZIP);
283 send_request.set_message("Hello");
// Background thread that checks the server queue for tag 2 (the RequestCall
// below) while this thread starts the client call (tag 1).
284 std::thread request_call([this]() { server_ok(2); });
285 std::unique_ptr<GenericClientAsyncReaderWriter> cli_stream =
286 generic_stub_->PrepareCall(&cli_ctx, kMethodName, &cli_cq_);
287 cli_stream->StartCall(tag(1));
// Server: ask for the incoming call (tag 2); srv_cq_ serves as both the call
// queue and the notification queue.
290 generic_service_.RequestCall(&srv_ctx, &srv_stream, srv_cq_.get(),
291 srv_cq_.get(), tag(2));
// The host seen by the server must start with server_host_, and the method
// must match what the client requested.
294 EXPECT_EQ(server_host_, srv_ctx.host().substr(0, server_host_.length()));
295 EXPECT_EQ(kMethodName, srv_ctx.method());
// Ping: the client writes the serialized request (tag 3).
297 std::unique_ptr<ByteBuffer> send_buffer =
298 SerializeToByteBuffer(&send_request);
299 cli_stream->Write(*send_buffer, tag(3));
// Server: read the request (tag 4) and check it round-tripped intact.
303 ByteBuffer recv_buffer;
304 srv_stream.Read(&recv_buffer, tag(4));
306 EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request));
307 EXPECT_EQ(send_request.message(), recv_request.message());
// Pong: the server echoes the message back (tag 5).
309 send_response.set_message(recv_request.message());
310 send_buffer = SerializeToByteBuffer(&send_response);
311 srv_stream.Write(*send_buffer, tag(5));
// Client: read the response (tag 6) and check it matches what was sent.
315 cli_stream->Read(&recv_buffer, tag(6));
317 EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_response));
318 EXPECT_EQ(send_response.message(), recv_response.message());
// Client signals end of writes (tag 7); the server issues one more read
// (tag 8).
320 cli_stream->WritesDone(tag(7));
323 srv_stream.Read(&recv_buffer, tag(8));
// Server finishes with OK (tag 9); client collects the final status (tag 10).
326 srv_stream.Finish(Status::OK, tag(9));
329 cli_stream->Finish(&recv_status, tag(10));
332 EXPECT_EQ(send_response.message(), recv_response.message());
333 EXPECT_TRUE(recv_status.ok());
// Exactly one call and one connection must have been counted by the filter.
335 EXPECT_EQ(1, GetCallCounterValue());
336 EXPECT_EQ(1, GetConnectionCounterValue());
340 } // namespace testing
// Test entry point: constructs the gRPC TestEnvironment from the command
// line, initializes GoogleTest with the same arguments, and returns the
// result of running all registered tests.
343 int main(int argc, char** argv) {
344 grpc::testing::TestEnvironment env(argc, argv);
345 ::testing::InitGoogleTest(&argc, argv);
346 return RUN_ALL_TESTS();