Imported Upstream version 1.36.0
[platform/upstream/grpc.git] / test / core / tsi / alts / fake_handshaker / fake_handshaker_server.cc
1 /*
2  *
3  * Copyright 2018 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18 #include "test/core/tsi/alts/fake_handshaker/fake_handshaker_server.h"
19
20 #include <memory>
21 #include <sstream>
22 #include <string>
23
24 #include <grpc/grpc.h>
25 #include <grpc/support/log.h>
26 #include <grpcpp/impl/codegen/async_stream.h>
27 #include <grpcpp/impl/codegen/sync.h>
28 #include <grpcpp/security/server_credentials.h>
29 #include <grpcpp/server.h>
30 #include <grpcpp/server_builder.h>
31 #include <grpcpp/server_context.h>
32
33 #include "test/core/tsi/alts/fake_handshaker/handshaker.grpc.pb.h"
34 #include "test/core/tsi/alts/fake_handshaker/handshaker.pb.h"
35 #include "test/core/tsi/alts/fake_handshaker/transport_security_common.pb.h"
36
namespace {

// Payloads of the fake 3-message handshake, listed in protocol order: the
// client's opening frame, the server's combined reply, and the client's
// closing frame.
constexpr char kClientInitFrame[] = "ClientInit";
constexpr char kServerFrame[] = "ServerInitAndFinished";
constexpr char kClientFinishFrame[] = "ClientFinished";

// Detail strings attached to the error statuses the fake handshaker returns.
constexpr char kInvalidFrameError[] = "Invalid input frame.";
constexpr char kWrongStateError[] = "Wrong handshake state.";

}  // namespace
44
45 namespace grpc {
46 namespace gcp {
47
48 // FakeHandshakeService implements a fake handshaker service using a fake key
49 // exchange protocol. The fake key exchange protocol is a 3-message protocol:
50 // - Client first sends ClientInit message to Server.
51 // - Server then sends ServerInitAndFinished message back to Client.
52 // - Client finally sends ClientFinished message to Server.
53 // This fake handshaker service is intended for ALTS integration testing without
54 // relying on real ALTS handshaker service inside GCE.
55 // It is thread-safe.
56 class FakeHandshakerService : public HandshakerService::Service {
57  public:
58   explicit FakeHandshakerService(int expected_max_concurrent_rpcs)
59       : expected_max_concurrent_rpcs_(expected_max_concurrent_rpcs) {}
60
61   Status DoHandshake(
62       ServerContext* /*server_context*/,
63       ServerReaderWriter<HandshakerResp, HandshakerReq>* stream) override {
64     ConcurrentRpcsCheck concurrent_rpcs_check(this);
65     Status status;
66     HandshakerContext context;
67     HandshakerReq request;
68     HandshakerResp response;
69     gpr_log(GPR_DEBUG, "Start a new handshake.");
70     while (stream->Read(&request)) {
71       status = ProcessRequest(&context, request, &response);
72       if (!status.ok()) return WriteErrorResponse(stream, status);
73       stream->Write(response);
74       if (context.state == COMPLETED) return Status::OK;
75       request.Clear();
76     }
77     return Status::OK;
78   }
79
80  private:
81   // HandshakeState is used by fake handshaker server to keep track of client's
82   // handshake status. In the beginning of a handshake, the state is INITIAL.
83   // If start_client or start_server request is called, the state becomes at
84   // least STARTED. When the handshaker server produces the first fame, the
85   // state becomes SENT. After the handshaker server processes the final frame
86   // from the peer, the state becomes COMPLETED.
87   enum HandshakeState { INITIAL, STARTED, SENT, COMPLETED };
88
89   struct HandshakerContext {
90     bool is_client = true;
91     HandshakeState state = INITIAL;
92   };
93
94   Status ProcessRequest(HandshakerContext* context,
95                         const HandshakerReq& request,
96                         HandshakerResp* response) {
97     GPR_ASSERT(context != nullptr && response != nullptr);
98     response->Clear();
99     if (request.has_client_start()) {
100       gpr_log(GPR_DEBUG, "Process client start request.");
101       return ProcessClientStart(context, request.client_start(), response);
102     } else if (request.has_server_start()) {
103       gpr_log(GPR_DEBUG, "Process server start request.");
104       return ProcessServerStart(context, request.server_start(), response);
105     } else if (request.has_next()) {
106       gpr_log(GPR_DEBUG, "Process next request.");
107       return ProcessNext(context, request.next(), response);
108     }
109     return Status(StatusCode::INVALID_ARGUMENT, "Request is empty.");
110   }
111
112   Status ProcessClientStart(HandshakerContext* context,
113                             const StartClientHandshakeReq& request,
114                             HandshakerResp* response) {
115     GPR_ASSERT(context != nullptr && response != nullptr);
116     // Checks request.
117     if (context->state != INITIAL) {
118       return Status(StatusCode::FAILED_PRECONDITION, kWrongStateError);
119     }
120     if (request.application_protocols_size() == 0) {
121       return Status(StatusCode::INVALID_ARGUMENT,
122                     "At least one application protocol needed.");
123     }
124     if (request.record_protocols_size() == 0) {
125       return Status(StatusCode::INVALID_ARGUMENT,
126                     "At least one record protocol needed.");
127     }
128     // Sets response.
129     response->set_out_frames(kClientInitFrame);
130     response->set_bytes_consumed(0);
131     response->mutable_status()->set_code(StatusCode::OK);
132     // Updates handshaker context.
133     context->is_client = true;
134     context->state = SENT;
135     return Status::OK;
136   }
137
138   Status ProcessServerStart(HandshakerContext* context,
139                             const StartServerHandshakeReq& request,
140                             HandshakerResp* response) {
141     GPR_ASSERT(context != nullptr && response != nullptr);
142     // Checks request.
143     if (context->state != INITIAL) {
144       return Status(StatusCode::FAILED_PRECONDITION, kWrongStateError);
145     }
146     if (request.application_protocols_size() == 0) {
147       return Status(StatusCode::INVALID_ARGUMENT,
148                     "At least one application protocol needed.");
149     }
150     if (request.handshake_parameters().empty()) {
151       return Status(StatusCode::INVALID_ARGUMENT,
152                     "At least one set of handshake parameters needed.");
153     }
154     // Sets response.
155     if (request.in_bytes().empty()) {
156       // start_server request does not have in_bytes.
157       response->set_bytes_consumed(0);
158       context->state = STARTED;
159     } else {
160       // start_server request has in_bytes.
161       if (request.in_bytes() == kClientInitFrame) {
162         response->set_out_frames(kServerFrame);
163         response->set_bytes_consumed(strlen(kClientInitFrame));
164         context->state = SENT;
165       } else {
166         return Status(StatusCode::UNKNOWN, kInvalidFrameError);
167       }
168     }
169     response->mutable_status()->set_code(StatusCode::OK);
170     context->is_client = false;
171     return Status::OK;
172   }
173
174   Status ProcessNext(HandshakerContext* context,
175                      const NextHandshakeMessageReq& request,
176                      HandshakerResp* response) {
177     GPR_ASSERT(context != nullptr && response != nullptr);
178     if (context->is_client) {
179       // Processes next request on client side.
180       if (context->state != SENT) {
181         return Status(StatusCode::FAILED_PRECONDITION, kWrongStateError);
182       }
183       if (request.in_bytes() != kServerFrame) {
184         return Status(StatusCode::UNKNOWN, kInvalidFrameError);
185       }
186       response->set_out_frames(kClientFinishFrame);
187       response->set_bytes_consumed(strlen(kServerFrame));
188       context->state = COMPLETED;
189     } else {
190       // Processes next request on server side.
191       HandshakeState current_state = context->state;
192       if (current_state == STARTED) {
193         if (request.in_bytes() != kClientInitFrame) {
194           return Status(StatusCode::UNKNOWN, kInvalidFrameError);
195         }
196         response->set_out_frames(kServerFrame);
197         response->set_bytes_consumed(strlen(kClientInitFrame));
198         context->state = SENT;
199       } else if (current_state == SENT) {
200         // Client finish frame may be sent along with the first payload from the
201         // client, handshaker only consumes the client finish frame.
202         if (request.in_bytes().substr(0, strlen(kClientFinishFrame)) !=
203             kClientFinishFrame) {
204           return Status(StatusCode::UNKNOWN, kInvalidFrameError);
205         }
206         response->set_bytes_consumed(strlen(kClientFinishFrame));
207         context->state = COMPLETED;
208       } else {
209         return Status(StatusCode::FAILED_PRECONDITION, kWrongStateError);
210       }
211     }
212     // At this point, processing next request succeeded.
213     response->mutable_status()->set_code(StatusCode::OK);
214     if (context->state == COMPLETED) {
215       *response->mutable_result() = GetHandshakerResult();
216     }
217     return Status::OK;
218   }
219
220   Status WriteErrorResponse(
221       ServerReaderWriter<HandshakerResp, HandshakerReq>* stream,
222       const Status& status) {
223     GPR_ASSERT(!status.ok());
224     HandshakerResp response;
225     response.mutable_status()->set_code(status.error_code());
226     response.mutable_status()->set_details(status.error_message());
227     stream->Write(response);
228     return status;
229   }
230
231   HandshakerResult GetHandshakerResult() {
232     HandshakerResult result;
233     result.set_application_protocol("grpc");
234     result.set_record_protocol("ALTSRP_GCM_AES128_REKEY");
235     result.mutable_peer_identity()->set_service_account("peer_identity");
236     result.mutable_local_identity()->set_service_account("local_identity");
237     string key(1024, '\0');
238     result.set_key_data(key);
239     result.mutable_peer_rpc_versions()->mutable_max_rpc_version()->set_major(2);
240     result.mutable_peer_rpc_versions()->mutable_max_rpc_version()->set_minor(1);
241     result.mutable_peer_rpc_versions()->mutable_min_rpc_version()->set_major(2);
242     result.mutable_peer_rpc_versions()->mutable_min_rpc_version()->set_minor(1);
243     return result;
244   }
245
246   class ConcurrentRpcsCheck {
247    public:
248     explicit ConcurrentRpcsCheck(FakeHandshakerService* parent)
249         : parent_(parent) {
250       if (parent->expected_max_concurrent_rpcs_ > 0) {
251         grpc::internal::MutexLock lock(
252             &parent->expected_max_concurrent_rpcs_mu_);
253         if (++parent->concurrent_rpcs_ >
254             parent->expected_max_concurrent_rpcs_) {
255           gpr_log(GPR_ERROR,
256                   "FakeHandshakerService:%p concurrent_rpcs_:%d "
257                   "expected_max_concurrent_rpcs:%d",
258                   parent, parent->concurrent_rpcs_,
259                   parent->expected_max_concurrent_rpcs_);
260           abort();
261         }
262       }
263     }
264
265     ~ConcurrentRpcsCheck() {
266       if (parent_->expected_max_concurrent_rpcs_ > 0) {
267         grpc::internal::MutexLock lock(
268             &parent_->expected_max_concurrent_rpcs_mu_);
269         parent_->concurrent_rpcs_--;
270       }
271     }
272
273    private:
274     FakeHandshakerService* parent_;
275   };
276
277   grpc::internal::Mutex expected_max_concurrent_rpcs_mu_;
278   int concurrent_rpcs_ = 0;
279   const int expected_max_concurrent_rpcs_;
280 };
281
282 std::unique_ptr<grpc::Service> CreateFakeHandshakerService(
283     int expected_max_concurrent_rpcs) {
284   return std::unique_ptr<grpc::Service>{
285       new grpc::gcp::FakeHandshakerService(expected_max_concurrent_rpcs)};
286 }
287
288 }  // namespace gcp
289 }  // namespace grpc