3 * Copyright 2016 gRPC authors.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
22 #include <grpc/grpc.h>
23 #include <grpc/support/time.h>
24 #include <grpcpp/channel.h>
25 #include <grpcpp/client_context.h>
26 #include <grpcpp/create_channel.h>
27 #include <grpcpp/generic/async_generic_service.h>
28 #include <grpcpp/generic/generic_stub.h>
29 #include <grpcpp/impl/codegen/proto_utils.h>
30 #include <grpcpp/server.h>
31 #include <grpcpp/server_builder.h>
32 #include <grpcpp/server_context.h>
33 #include <grpcpp/support/config.h>
34 #include <grpcpp/support/slice.h>
36 #include "src/cpp/common/channel_filter.h"
37 #include "src/proto/grpc/testing/echo.grpc.pb.h"
38 #include "test/core/util/port.h"
39 #include "test/core/util/test_config.h"
40 #include "test/cpp/util/byte_buffer_proto_helper.h"
42 #include <gtest/gtest.h>
44 using grpc::testing::EchoRequest;
45 using grpc::testing::EchoResponse;
46 using std::chrono::system_clock;
// Encodes the small integer `i` as an opaque completion-queue tag.
// Uses named casts (int -> intptr_t -> void*) instead of a C-style cast;
// the value round-trips back through reinterpret_cast<intptr_t>.
void* tag(int i) { return reinterpret_cast<void*>(static_cast<intptr_t>(i)); }
54 void verify_ok(CompletionQueue* cq, int i, bool expect_ok) {
57 EXPECT_TRUE(cq->Next(&got_tag, &ok));
58 EXPECT_EQ(expect_ok, ok);
59 EXPECT_EQ(tag(i), got_tag);
// Process-wide counters updated by the test filter and read back by the
// tests. Every access is serialized through global_mu, presumably because the
// filter hooks may run on gRPC-owned threads rather than the test thread.
int global_num_connections = 0;
int global_num_calls = 0;
std::mutex global_mu;

// Records one channel initialization seen by the filter.
void IncrementConnectionCounter() {
  std::lock_guard<std::mutex> lock(global_mu);
  ++global_num_connections;
}

// Sets the connection counter back to zero.
void ResetConnectionCounter() {
  std::lock_guard<std::mutex> lock(global_mu);
  global_num_connections = 0;
}

// Returns the current connection count.
int GetConnectionCounterValue() {
  std::lock_guard<std::mutex> lock(global_mu);
  return global_num_connections;
}

// Records one call seen by the filter.
void IncrementCallCounter() {
  std::lock_guard<std::mutex> lock(global_mu);
  ++global_num_calls;
}

// Sets the call counter back to zero.
void ResetCallCounter() {
  std::lock_guard<std::mutex> lock(global_mu);
  global_num_calls = 0;
}

// Returns the current call count.
int GetCallCounterValue() {
  std::lock_guard<std::mutex> lock(global_mu);
  return global_num_calls;
}
100 class ChannelDataImpl : public ChannelData {
102 grpc_error* Init(grpc_channel_element* elem,
103 grpc_channel_element_args* args) {
104 IncrementConnectionCounter();
105 return GRPC_ERROR_NONE;
109 class CallDataImpl : public CallData {
111 void StartTransportStreamOpBatch(grpc_call_element* elem,
112 TransportStreamOpBatch* op) override {
113 // Incrementing the counter could be done from Init(), but we want
114 // to test that the individual methods are actually called correctly.
115 if (op->recv_initial_metadata() != nullptr) IncrementCallCounter();
116 grpc_call_next_op(elem, op->op());
120 class FilterEnd2endTest : public ::testing::Test {
122 FilterEnd2endTest() : server_host_("localhost") {}
124 static void SetUpTestCase() {
126 // https://github.com/google/google-toolbox-for-mac/issues/242
127 static bool setup_done = false;
130 grpc::RegisterChannelFilter<ChannelDataImpl, CallDataImpl>(
131 "test-filter", GRPC_SERVER_CHANNEL, INT_MAX, nullptr);
135 void SetUp() override {
136 int port = grpc_pick_unused_port_or_die();
137 server_address_ << server_host_ << ":" << port;
139 ServerBuilder builder;
140 builder.AddListeningPort(server_address_.str(),
141 InsecureServerCredentials());
142 builder.RegisterAsyncGenericService(&generic_service_);
143 srv_cq_ = builder.AddCompletionQueue();
144 server_ = builder.BuildAndStart();
147 void TearDown() override {
153 while (cli_cq_.Next(&ignored_tag, &ignored_ok))
155 while (srv_cq_->Next(&ignored_tag, &ignored_ok))
160 std::shared_ptr<Channel> channel = grpc::CreateChannel(
161 server_address_.str(), InsecureChannelCredentials());
162 generic_stub_.reset(new GenericStub(channel));
163 ResetConnectionCounter();
167 void server_ok(int i) { verify_ok(srv_cq_.get(), i, true); }
168 void client_ok(int i) { verify_ok(&cli_cq_, i, true); }
169 void server_fail(int i) { verify_ok(srv_cq_.get(), i, false); }
170 void client_fail(int i) { verify_ok(&cli_cq_, i, false); }
172 void SendRpc(int num_rpcs) {
173 const grpc::string kMethodName("/grpc.cpp.test.util.EchoTestService/Echo");
174 for (int i = 0; i < num_rpcs; i++) {
175 EchoRequest send_request;
176 EchoRequest recv_request;
177 EchoResponse send_response;
178 EchoResponse recv_response;
181 ClientContext cli_ctx;
182 GenericServerContext srv_ctx;
183 GenericServerAsyncReaderWriter stream(&srv_ctx);
185 // The string needs to be long enough to test heap-based slice.
186 send_request.set_message("Hello world. Hello world. Hello world.");
187 std::unique_ptr<GenericClientAsyncReaderWriter> call =
188 generic_stub_->PrepareCall(&cli_ctx, kMethodName, &cli_cq_);
189 call->StartCall(tag(1));
191 std::unique_ptr<ByteBuffer> send_buffer =
192 SerializeToByteBuffer(&send_request);
193 call->Write(*send_buffer, tag(2));
194 // Send ByteBuffer can be destroyed after calling Write.
197 call->WritesDone(tag(3));
200 generic_service_.RequestCall(&srv_ctx, &stream, srv_cq_.get(),
201 srv_cq_.get(), tag(4));
203 verify_ok(srv_cq_.get(), 4, true);
204 EXPECT_EQ(server_host_, srv_ctx.host().substr(0, server_host_.length()));
205 EXPECT_EQ(kMethodName, srv_ctx.method());
206 ByteBuffer recv_buffer;
207 stream.Read(&recv_buffer, tag(5));
209 EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request));
210 EXPECT_EQ(send_request.message(), recv_request.message());
212 send_response.set_message(recv_request.message());
213 send_buffer = SerializeToByteBuffer(&send_response);
214 stream.Write(*send_buffer, tag(6));
218 stream.Finish(Status::OK, tag(7));
222 call->Read(&recv_buffer, tag(8));
224 EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_response));
226 call->Finish(&recv_status, tag(9));
229 EXPECT_EQ(send_response.message(), recv_response.message());
230 EXPECT_TRUE(recv_status.ok());
234 CompletionQueue cli_cq_;
235 std::unique_ptr<ServerCompletionQueue> srv_cq_;
236 std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
237 std::unique_ptr<grpc::GenericStub> generic_stub_;
238 std::unique_ptr<Server> server_;
239 AsyncGenericService generic_service_;
240 const grpc::string server_host_;
241 std::ostringstream server_address_;
244 TEST_F(FilterEnd2endTest, SimpleRpc) {
246 EXPECT_EQ(0, GetConnectionCounterValue());
247 EXPECT_EQ(0, GetCallCounterValue());
249 EXPECT_EQ(1, GetConnectionCounterValue());
250 EXPECT_EQ(1, GetCallCounterValue());
253 TEST_F(FilterEnd2endTest, SequentialRpcs) {
255 EXPECT_EQ(0, GetConnectionCounterValue());
256 EXPECT_EQ(0, GetCallCounterValue());
258 EXPECT_EQ(1, GetConnectionCounterValue());
259 EXPECT_EQ(10, GetCallCounterValue());
262 // One ping, one pong.
263 TEST_F(FilterEnd2endTest, SimpleBidiStreaming) {
265 EXPECT_EQ(0, GetConnectionCounterValue());
266 EXPECT_EQ(0, GetCallCounterValue());
268 const grpc::string kMethodName(
269 "/grpc.cpp.test.util.EchoTestService/BidiStream");
270 EchoRequest send_request;
271 EchoRequest recv_request;
272 EchoResponse send_response;
273 EchoResponse recv_response;
275 ClientContext cli_ctx;
276 GenericServerContext srv_ctx;
277 GenericServerAsyncReaderWriter srv_stream(&srv_ctx);
279 cli_ctx.set_compression_algorithm(GRPC_COMPRESS_GZIP);
280 send_request.set_message("Hello");
281 std::unique_ptr<GenericClientAsyncReaderWriter> cli_stream =
282 generic_stub_->PrepareCall(&cli_ctx, kMethodName, &cli_cq_);
283 cli_stream->StartCall(tag(1));
286 generic_service_.RequestCall(&srv_ctx, &srv_stream, srv_cq_.get(),
287 srv_cq_.get(), tag(2));
289 verify_ok(srv_cq_.get(), 2, true);
290 EXPECT_EQ(server_host_, srv_ctx.host().substr(0, server_host_.length()));
291 EXPECT_EQ(kMethodName, srv_ctx.method());
293 std::unique_ptr<ByteBuffer> send_buffer =
294 SerializeToByteBuffer(&send_request);
295 cli_stream->Write(*send_buffer, tag(3));
299 ByteBuffer recv_buffer;
300 srv_stream.Read(&recv_buffer, tag(4));
302 EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_request));
303 EXPECT_EQ(send_request.message(), recv_request.message());
305 send_response.set_message(recv_request.message());
306 send_buffer = SerializeToByteBuffer(&send_response);
307 srv_stream.Write(*send_buffer, tag(5));
311 cli_stream->Read(&recv_buffer, tag(6));
313 EXPECT_TRUE(ParseFromByteBuffer(&recv_buffer, &recv_response));
314 EXPECT_EQ(send_response.message(), recv_response.message());
316 cli_stream->WritesDone(tag(7));
319 srv_stream.Read(&recv_buffer, tag(8));
322 srv_stream.Finish(Status::OK, tag(9));
325 cli_stream->Finish(&recv_status, tag(10));
328 EXPECT_EQ(send_response.message(), recv_response.message());
329 EXPECT_TRUE(recv_status.ok());
331 EXPECT_EQ(1, GetCallCounterValue());
332 EXPECT_EQ(1, GetConnectionCounterValue());
336 } // namespace testing
339 int main(int argc, char** argv) {
340 grpc::testing::TestEnvironment env(argc, argv);
341 ::testing::InitGoogleTest(&argc, argv);
342 return RUN_ALL_TESTS();