Imported Upstream version 1.36.0
[platform/upstream/grpc.git] / src / core / tsi / alts / handshaker / alts_handshaker_client.cc
1 /*
2  *
3  * Copyright 2018 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18
19 #include <grpc/support/port_platform.h>
20
21 #include <list>
22
23 #include "src/core/tsi/alts/handshaker/alts_handshaker_client.h"
24
25 #include "upb/upb.hpp"
26
27 #include <grpc/byte_buffer.h>
28 #include <grpc/support/alloc.h>
29 #include <grpc/support/log.h>
30
31 #include "src/core/lib/gprpp/sync.h"
32 #include "src/core/lib/slice/slice_internal.h"
33 #include "src/core/lib/surface/call.h"
34 #include "src/core/lib/surface/channel.h"
35 #include "src/core/tsi/alts/handshaker/alts_shared_resource.h"
36 #include "src/core/tsi/alts/handshaker/alts_tsi_handshaker_private.h"
37 #include "src/core/tsi/alts/handshaker/alts_tsi_utils.h"
38
/* Initial capacity, in bytes, of the client buffer that stores out_frames
 * received from the handshaker service; it is grown when a larger frame
 * arrives. */
#define TSI_ALTS_INITIAL_BUFFER_SIZE 256

/* Capacity of the grpc_op array used to build one batch in
 * continue_make_grpc_call(). */
constexpr int kHandshakerClientOpNum = 4;
42
43 struct alts_handshaker_client {
44   const alts_handshaker_client_vtable* vtable;
45 };
46
47 struct recv_message_result {
48   tsi_result status;
49   const unsigned char* bytes_to_send;
50   size_t bytes_to_send_size;
51   tsi_handshaker_result* result;
52 };
53
54 typedef struct alts_grpc_handshaker_client {
55   alts_handshaker_client base;
56   /* One ref is held by the entity that created this handshaker_client, and
57    * another ref is held by the pending RECEIVE_STATUS_ON_CLIENT op. */
58   gpr_refcount refs;
59   alts_tsi_handshaker* handshaker;
60   grpc_call* call;
61   /* A pointer to a function handling the interaction with handshaker service.
62    * That is, it points to grpc_call_start_batch_and_execute when the handshaker
63    * client is used in a non-testing use case and points to a custom function
64    * that validates the data to be sent to handshaker service in a testing use
65    * case. */
66   alts_grpc_caller grpc_caller;
67   /* A gRPC closure to be scheduled when the response from handshaker service
68    * is received. It will be initialized with the injected grpc RPC callback. */
69   grpc_closure on_handshaker_service_resp_recv;
70   /* Buffers containing information to be sent (or received) to (or from) the
71    * handshaker service. */
72   grpc_byte_buffer* send_buffer = nullptr;
73   grpc_byte_buffer* recv_buffer = nullptr;
74   grpc_status_code status = GRPC_STATUS_OK;
75   /* Initial metadata to be received from handshaker service. */
76   grpc_metadata_array recv_initial_metadata;
77   /* A callback function provided by an application to be invoked when response
78    * is received from handshaker service. */
79   tsi_handshaker_on_next_done_cb cb;
80   void* user_data;
81   /* ALTS credential options passed in from the caller. */
82   grpc_alts_credentials_options* options;
83   /* target name information to be passed to handshaker service for server
84    * authorization check. */
85   grpc_slice target_name;
86   /* boolean flag indicating if the handshaker client is used at client
87    * (is_client = true) or server (is_client = false) side. */
88   bool is_client;
89   /* a temporary store for data received from handshaker service used to extract
90    * unused data. */
91   grpc_slice recv_bytes;
92   /* a buffer containing data to be sent to the grpc client or server's peer. */
93   unsigned char* buffer;
94   size_t buffer_size;
95   /** callback for receiving handshake call status */
96   grpc_closure on_status_received;
97   /** gRPC status code of handshake call */
98   grpc_status_code handshake_status_code = GRPC_STATUS_OK;
99   /** gRPC status details of handshake call */
100   grpc_slice handshake_status_details;
101   /* mu synchronizes all fields below including their internal fields. */
102   grpc_core::Mutex mu;
103   /* indicates if the handshaker call's RECV_STATUS_ON_CLIENT op is done. */
104   bool receive_status_finished = false;
105   /* if non-null, contains arguments to complete a TSI next callback. */
106   recv_message_result* pending_recv_message_result = nullptr;
107   /* Maximum frame size used by frame protector. */
108   size_t max_frame_size;
109 } alts_grpc_handshaker_client;
110
111 static void handshaker_client_send_buffer_destroy(
112     alts_grpc_handshaker_client* client) {
113   GPR_ASSERT(client != nullptr);
114   grpc_byte_buffer_destroy(client->send_buffer);
115   client->send_buffer = nullptr;
116 }
117
118 static bool is_handshake_finished_properly(grpc_gcp_HandshakerResp* resp) {
119   GPR_ASSERT(resp != nullptr);
120   if (grpc_gcp_HandshakerResp_result(resp)) {
121     return true;
122   }
123   return false;
124 }
125
126 static void alts_grpc_handshaker_client_unref(
127     alts_grpc_handshaker_client* client) {
128   if (gpr_unref(&client->refs)) {
129     if (client->base.vtable != nullptr &&
130         client->base.vtable->destruct != nullptr) {
131       client->base.vtable->destruct(&client->base);
132     }
133     grpc_byte_buffer_destroy(client->send_buffer);
134     grpc_byte_buffer_destroy(client->recv_buffer);
135     client->send_buffer = nullptr;
136     client->recv_buffer = nullptr;
137     grpc_metadata_array_destroy(&client->recv_initial_metadata);
138     grpc_slice_unref_internal(client->recv_bytes);
139     grpc_slice_unref_internal(client->target_name);
140     grpc_alts_credentials_options_destroy(client->options);
141     gpr_free(client->buffer);
142     grpc_slice_unref_internal(client->handshake_status_details);
143     delete client;
144   }
145 }
146
147 static void maybe_complete_tsi_next(
148     alts_grpc_handshaker_client* client, bool receive_status_finished,
149     recv_message_result* pending_recv_message_result) {
150   recv_message_result* r;
151   {
152     grpc_core::MutexLock lock(&client->mu);
153     client->receive_status_finished |= receive_status_finished;
154     if (pending_recv_message_result != nullptr) {
155       GPR_ASSERT(client->pending_recv_message_result == nullptr);
156       client->pending_recv_message_result = pending_recv_message_result;
157     }
158     if (client->pending_recv_message_result == nullptr) {
159       return;
160     }
161     const bool have_final_result =
162         client->pending_recv_message_result->result != nullptr ||
163         client->pending_recv_message_result->status != TSI_OK;
164     if (have_final_result && !client->receive_status_finished) {
165       // If we've received the final message from the handshake
166       // server, or we're about to invoke the TSI next callback
167       // with a status other than TSI_OK (which terminates the
168       // handshake), then first wait for the RECV_STATUS op to complete.
169       return;
170     }
171     r = client->pending_recv_message_result;
172     client->pending_recv_message_result = nullptr;
173   }
174   client->cb(r->status, client->user_data, r->bytes_to_send,
175              r->bytes_to_send_size, r->result);
176   gpr_free(r);
177 }
178
179 static void handle_response_done(alts_grpc_handshaker_client* client,
180                                  tsi_result status,
181                                  const unsigned char* bytes_to_send,
182                                  size_t bytes_to_send_size,
183                                  tsi_handshaker_result* result) {
184   recv_message_result* p =
185       static_cast<recv_message_result*>(gpr_zalloc(sizeof(*p)));
186   p->status = status;
187   p->bytes_to_send = bytes_to_send;
188   p->bytes_to_send_size = bytes_to_send_size;
189   p->result = result;
190   maybe_complete_tsi_next(client, false /* receive_status_finished */,
191                           p /* pending_recv_message_result */);
192 }
193
194 void alts_handshaker_client_handle_response(alts_handshaker_client* c,
195                                             bool is_ok) {
196   GPR_ASSERT(c != nullptr);
197   alts_grpc_handshaker_client* client =
198       reinterpret_cast<alts_grpc_handshaker_client*>(c);
199   grpc_byte_buffer* recv_buffer = client->recv_buffer;
200   grpc_status_code status = client->status;
201   alts_tsi_handshaker* handshaker = client->handshaker;
202   /* Invalid input check. */
203   if (client->cb == nullptr) {
204     gpr_log(GPR_ERROR,
205             "client->cb is nullptr in alts_tsi_handshaker_handle_response()");
206     return;
207   }
208   if (handshaker == nullptr) {
209     gpr_log(GPR_ERROR,
210             "handshaker is nullptr in alts_tsi_handshaker_handle_response()");
211     handle_response_done(client, TSI_INTERNAL_ERROR, nullptr, 0, nullptr);
212     return;
213   }
214   /* TSI handshake has been shutdown. */
215   if (alts_tsi_handshaker_has_shutdown(handshaker)) {
216     gpr_log(GPR_ERROR, "TSI handshake shutdown");
217     handle_response_done(client, TSI_HANDSHAKE_SHUTDOWN, nullptr, 0, nullptr);
218     return;
219   }
220   /* Failed grpc call check. */
221   if (!is_ok || status != GRPC_STATUS_OK) {
222     gpr_log(GPR_ERROR, "grpc call made to handshaker service failed");
223     handle_response_done(client, TSI_INTERNAL_ERROR, nullptr, 0, nullptr);
224     return;
225   }
226   if (recv_buffer == nullptr) {
227     gpr_log(GPR_ERROR,
228             "recv_buffer is nullptr in alts_tsi_handshaker_handle_response()");
229     handle_response_done(client, TSI_INTERNAL_ERROR, nullptr, 0, nullptr);
230     return;
231   }
232   upb::Arena arena;
233   grpc_gcp_HandshakerResp* resp =
234       alts_tsi_utils_deserialize_response(recv_buffer, arena.ptr());
235   grpc_byte_buffer_destroy(client->recv_buffer);
236   client->recv_buffer = nullptr;
237   /* Invalid handshaker response check. */
238   if (resp == nullptr) {
239     gpr_log(GPR_ERROR, "alts_tsi_utils_deserialize_response() failed");
240     handle_response_done(client, TSI_DATA_CORRUPTED, nullptr, 0, nullptr);
241     return;
242   }
243   const grpc_gcp_HandshakerStatus* resp_status =
244       grpc_gcp_HandshakerResp_status(resp);
245   if (resp_status == nullptr) {
246     gpr_log(GPR_ERROR, "No status in HandshakerResp");
247     handle_response_done(client, TSI_DATA_CORRUPTED, nullptr, 0, nullptr);
248     return;
249   }
250   upb_strview out_frames = grpc_gcp_HandshakerResp_out_frames(resp);
251   unsigned char* bytes_to_send = nullptr;
252   size_t bytes_to_send_size = 0;
253   if (out_frames.size > 0) {
254     bytes_to_send_size = out_frames.size;
255     while (bytes_to_send_size > client->buffer_size) {
256       client->buffer_size *= 2;
257       client->buffer = static_cast<unsigned char*>(
258           gpr_realloc(client->buffer, client->buffer_size));
259     }
260     memcpy(client->buffer, out_frames.data, bytes_to_send_size);
261     bytes_to_send = client->buffer;
262   }
263   tsi_handshaker_result* result = nullptr;
264   if (is_handshake_finished_properly(resp)) {
265     tsi_result status =
266         alts_tsi_handshaker_result_create(resp, client->is_client, &result);
267     if (status != TSI_OK) {
268       gpr_log(GPR_ERROR, "alts_tsi_handshaker_result_create() failed");
269       handle_response_done(client, status, nullptr, 0, nullptr);
270       return;
271     }
272     alts_tsi_handshaker_result_set_unused_bytes(
273         result, &client->recv_bytes,
274         grpc_gcp_HandshakerResp_bytes_consumed(resp));
275   }
276   grpc_status_code code = static_cast<grpc_status_code>(
277       grpc_gcp_HandshakerStatus_code(resp_status));
278   if (code != GRPC_STATUS_OK) {
279     upb_strview details = grpc_gcp_HandshakerStatus_details(resp_status);
280     if (details.size > 0) {
281       char* error_details = static_cast<char*>(gpr_zalloc(details.size + 1));
282       memcpy(error_details, details.data, details.size);
283       gpr_log(GPR_ERROR, "Error from handshaker service:%s", error_details);
284       gpr_free(error_details);
285     }
286   }
287   // TODO(apolcyn): consider short ciruiting handle_response_done and
288   // invoking the TSI callback directly if we aren't done yet, if
289   // handle_response_done's allocation per message received causes
290   // a performance issue.
291   handle_response_done(client, alts_tsi_utils_convert_to_tsi_result(code),
292                        bytes_to_send, bytes_to_send_size, result);
293 }
294
295 static tsi_result continue_make_grpc_call(alts_grpc_handshaker_client* client,
296                                           bool is_start) {
297   GPR_ASSERT(client != nullptr);
298   grpc_op ops[kHandshakerClientOpNum];
299   memset(ops, 0, sizeof(ops));
300   grpc_op* op = ops;
301   if (is_start) {
302     op->op = GRPC_OP_RECV_STATUS_ON_CLIENT;
303     op->data.recv_status_on_client.trailing_metadata = nullptr;
304     op->data.recv_status_on_client.status = &client->handshake_status_code;
305     op->data.recv_status_on_client.status_details =
306         &client->handshake_status_details;
307     op->flags = 0;
308     op->reserved = nullptr;
309     op++;
310     GPR_ASSERT(op - ops <= kHandshakerClientOpNum);
311     gpr_ref(&client->refs);
312     grpc_call_error call_error =
313         client->grpc_caller(client->call, ops, static_cast<size_t>(op - ops),
314                             &client->on_status_received);
315     // TODO(apolcyn): return the error here instead, as done for other ops?
316     GPR_ASSERT(call_error == GRPC_CALL_OK);
317     memset(ops, 0, sizeof(ops));
318     op = ops;
319     op->op = GRPC_OP_SEND_INITIAL_METADATA;
320     op->data.send_initial_metadata.count = 0;
321     op++;
322     GPR_ASSERT(op - ops <= kHandshakerClientOpNum);
323     op->op = GRPC_OP_RECV_INITIAL_METADATA;
324     op->data.recv_initial_metadata.recv_initial_metadata =
325         &client->recv_initial_metadata;
326     op++;
327     GPR_ASSERT(op - ops <= kHandshakerClientOpNum);
328   }
329   op->op = GRPC_OP_SEND_MESSAGE;
330   op->data.send_message.send_message = client->send_buffer;
331   op++;
332   GPR_ASSERT(op - ops <= kHandshakerClientOpNum);
333   op->op = GRPC_OP_RECV_MESSAGE;
334   op->data.recv_message.recv_message = &client->recv_buffer;
335   op++;
336   GPR_ASSERT(op - ops <= kHandshakerClientOpNum);
337   GPR_ASSERT(client->grpc_caller != nullptr);
338   if (client->grpc_caller(client->call, ops, static_cast<size_t>(op - ops),
339                           &client->on_handshaker_service_resp_recv) !=
340       GRPC_CALL_OK) {
341     gpr_log(GPR_ERROR, "Start batch operation failed");
342     return TSI_INTERNAL_ERROR;
343   }
344   return TSI_OK;
345 }
346
347 // TODO(apolcyn): remove this global queue when we can safely rely
348 // on a MAX_CONCURRENT_STREAMS setting in the ALTS handshake server to
349 // limit the number of concurrent handshakes.
350 namespace {
351
352 class HandshakeQueue {
353  public:
354   explicit HandshakeQueue(size_t max_outstanding_handshakes)
355       : max_outstanding_handshakes_(max_outstanding_handshakes) {}
356
357   void RequestHandshake(alts_grpc_handshaker_client* client) {
358     {
359       grpc_core::MutexLock lock(&mu_);
360       if (outstanding_handshakes_ == max_outstanding_handshakes_) {
361         // Max number already running, add to queue.
362         queued_handshakes_.push_back(client);
363         return;
364       }
365       // Start the handshake immediately.
366       ++outstanding_handshakes_;
367     }
368     continue_make_grpc_call(client, true /* is_start */);
369   }
370
371   void HandshakeDone() {
372     alts_grpc_handshaker_client* client = nullptr;
373     {
374       grpc_core::MutexLock lock(&mu_);
375       if (queued_handshakes_.empty()) {
376         // Nothing more in queue.  Decrement count and return immediately.
377         --outstanding_handshakes_;
378         return;
379       }
380       // Remove next entry from queue and start the handshake.
381       client = queued_handshakes_.front();
382       queued_handshakes_.pop_front();
383     }
384     continue_make_grpc_call(client, true /* is_start */);
385   }
386
387  private:
388   grpc_core::Mutex mu_;
389   std::list<alts_grpc_handshaker_client*> queued_handshakes_;
390   size_t outstanding_handshakes_ = 0;
391   const size_t max_outstanding_handshakes_;
392 };
393
394 gpr_once g_queued_handshakes_init = GPR_ONCE_INIT;
395 /* Using separate queues for client and server handshakes is a
396  * hack that's mainly intended to satisfy the alts_concurrent_connectivity_test,
397  * which runs many concurrent handshakes where both endpoints
398  * are in the same process; this situation is problematic with a
399  * single queue because we have a high chance of using up all outstanding
400  * slots in the queue, such that there aren't any
401  * mutual client/server handshakes outstanding at the same time and
402  * able to make progress. */
403 HandshakeQueue* g_client_handshake_queue;
404 HandshakeQueue* g_server_handshake_queue;
405
406 void DoHandshakeQueuesInit(void) {
407   const size_t per_queue_max_outstanding_handshakes = 40;
408   g_client_handshake_queue =
409       new HandshakeQueue(per_queue_max_outstanding_handshakes);
410   g_server_handshake_queue =
411       new HandshakeQueue(per_queue_max_outstanding_handshakes);
412 }
413
414 void RequestHandshake(alts_grpc_handshaker_client* client, bool is_client) {
415   gpr_once_init(&g_queued_handshakes_init, DoHandshakeQueuesInit);
416   HandshakeQueue* queue =
417       is_client ? g_client_handshake_queue : g_server_handshake_queue;
418   queue->RequestHandshake(client);
419 }
420
421 void HandshakeDone(bool is_client) {
422   HandshakeQueue* queue =
423       is_client ? g_client_handshake_queue : g_server_handshake_queue;
424   queue->HandshakeDone();
425 }
426
427 };  // namespace
428
429 /**
430  * Populate grpc operation data with the fields of ALTS handshaker client and
431  * make a grpc call.
432  */
433 static tsi_result make_grpc_call(alts_handshaker_client* c, bool is_start) {
434   GPR_ASSERT(c != nullptr);
435   alts_grpc_handshaker_client* client =
436       reinterpret_cast<alts_grpc_handshaker_client*>(c);
437   if (is_start) {
438     RequestHandshake(client, client->is_client);
439     return TSI_OK;
440   } else {
441     return continue_make_grpc_call(client, is_start);
442   }
443 }
444
445 static void on_status_received(void* arg, grpc_error* error) {
446   alts_grpc_handshaker_client* client =
447       static_cast<alts_grpc_handshaker_client*>(arg);
448   if (client->handshake_status_code != GRPC_STATUS_OK) {
449     // TODO(apolcyn): consider overriding the handshake result's
450     // status from the final ALTS message with the status here.
451     char* status_details =
452         grpc_slice_to_c_string(client->handshake_status_details);
453     gpr_log(GPR_INFO,
454             "alts_grpc_handshaker_client:%p on_status_received "
455             "status:%d details:|%s| error:|%s|",
456             client, client->handshake_status_code, status_details,
457             grpc_error_string(error));
458     gpr_free(status_details);
459   }
460   maybe_complete_tsi_next(client, true /* receive_status_finished */,
461                           nullptr /* pending_recv_message_result */);
462   HandshakeDone(client->is_client);
463   alts_grpc_handshaker_client_unref(client);
464 }
465
466 /* Serializes a grpc_gcp_HandshakerReq message into a buffer and returns newly
467  * grpc_byte_buffer holding it. */
468 static grpc_byte_buffer* get_serialized_handshaker_req(
469     grpc_gcp_HandshakerReq* req, upb_arena* arena) {
470   size_t buf_length;
471   char* buf = grpc_gcp_HandshakerReq_serialize(req, arena, &buf_length);
472   if (buf == nullptr) {
473     return nullptr;
474   }
475   grpc_slice slice = grpc_slice_from_copied_buffer(buf, buf_length);
476   grpc_byte_buffer* byte_buffer = grpc_raw_byte_buffer_create(&slice, 1);
477   grpc_slice_unref_internal(slice);
478   return byte_buffer;
479 }
480
481 /* Create and populate a client_start handshaker request, then serialize it. */
482 static grpc_byte_buffer* get_serialized_start_client(
483     alts_handshaker_client* c) {
484   GPR_ASSERT(c != nullptr);
485   alts_grpc_handshaker_client* client =
486       reinterpret_cast<alts_grpc_handshaker_client*>(c);
487   upb::Arena arena;
488   grpc_gcp_HandshakerReq* req = grpc_gcp_HandshakerReq_new(arena.ptr());
489   grpc_gcp_StartClientHandshakeReq* start_client =
490       grpc_gcp_HandshakerReq_mutable_client_start(req, arena.ptr());
491   grpc_gcp_StartClientHandshakeReq_set_handshake_security_protocol(
492       start_client, grpc_gcp_ALTS);
493   grpc_gcp_StartClientHandshakeReq_add_application_protocols(
494       start_client, upb_strview_makez(ALTS_APPLICATION_PROTOCOL), arena.ptr());
495   grpc_gcp_StartClientHandshakeReq_add_record_protocols(
496       start_client, upb_strview_makez(ALTS_RECORD_PROTOCOL), arena.ptr());
497   grpc_gcp_RpcProtocolVersions* client_version =
498       grpc_gcp_StartClientHandshakeReq_mutable_rpc_versions(start_client,
499                                                             arena.ptr());
500   grpc_gcp_RpcProtocolVersions_assign_from_struct(
501       client_version, arena.ptr(), &client->options->rpc_versions);
502   grpc_gcp_StartClientHandshakeReq_set_target_name(
503       start_client,
504       upb_strview_make(reinterpret_cast<const char*>(
505                            GRPC_SLICE_START_PTR(client->target_name)),
506                        GRPC_SLICE_LENGTH(client->target_name)));
507   target_service_account* ptr =
508       (reinterpret_cast<grpc_alts_credentials_client_options*>(client->options))
509           ->target_account_list_head;
510   while (ptr != nullptr) {
511     grpc_gcp_Identity* target_identity =
512         grpc_gcp_StartClientHandshakeReq_add_target_identities(start_client,
513                                                                arena.ptr());
514     grpc_gcp_Identity_set_service_account(target_identity,
515                                           upb_strview_makez(ptr->data));
516     ptr = ptr->next;
517   }
518   grpc_gcp_StartClientHandshakeReq_set_max_frame_size(
519       start_client, static_cast<uint32_t>(client->max_frame_size));
520   return get_serialized_handshaker_req(req, arena.ptr());
521 }
522
523 static tsi_result handshaker_client_start_client(alts_handshaker_client* c) {
524   if (c == nullptr) {
525     gpr_log(GPR_ERROR, "client is nullptr in handshaker_client_start_client()");
526     return TSI_INVALID_ARGUMENT;
527   }
528   grpc_byte_buffer* buffer = get_serialized_start_client(c);
529   alts_grpc_handshaker_client* client =
530       reinterpret_cast<alts_grpc_handshaker_client*>(c);
531   if (buffer == nullptr) {
532     gpr_log(GPR_ERROR, "get_serialized_start_client() failed");
533     return TSI_INTERNAL_ERROR;
534   }
535   handshaker_client_send_buffer_destroy(client);
536   client->send_buffer = buffer;
537   tsi_result result = make_grpc_call(&client->base, true /* is_start */);
538   if (result != TSI_OK) {
539     gpr_log(GPR_ERROR, "make_grpc_call() failed");
540   }
541   return result;
542 }
543
544 /* Create and populate a start_server handshaker request, then serialize it. */
545 static grpc_byte_buffer* get_serialized_start_server(
546     alts_handshaker_client* c, grpc_slice* bytes_received) {
547   GPR_ASSERT(c != nullptr);
548   GPR_ASSERT(bytes_received != nullptr);
549   alts_grpc_handshaker_client* client =
550       reinterpret_cast<alts_grpc_handshaker_client*>(c);
551
552   upb::Arena arena;
553   grpc_gcp_HandshakerReq* req = grpc_gcp_HandshakerReq_new(arena.ptr());
554
555   grpc_gcp_StartServerHandshakeReq* start_server =
556       grpc_gcp_HandshakerReq_mutable_server_start(req, arena.ptr());
557   grpc_gcp_StartServerHandshakeReq_add_application_protocols(
558       start_server, upb_strview_makez(ALTS_APPLICATION_PROTOCOL), arena.ptr());
559   grpc_gcp_ServerHandshakeParameters* value =
560       grpc_gcp_ServerHandshakeParameters_new(arena.ptr());
561   grpc_gcp_ServerHandshakeParameters_add_record_protocols(
562       value, upb_strview_makez(ALTS_RECORD_PROTOCOL), arena.ptr());
563   grpc_gcp_StartServerHandshakeReq_handshake_parameters_set(
564       start_server, grpc_gcp_ALTS, value, arena.ptr());
565   grpc_gcp_StartServerHandshakeReq_set_in_bytes(
566       start_server, upb_strview_make(reinterpret_cast<const char*>(
567                                          GRPC_SLICE_START_PTR(*bytes_received)),
568                                      GRPC_SLICE_LENGTH(*bytes_received)));
569   grpc_gcp_RpcProtocolVersions* server_version =
570       grpc_gcp_StartServerHandshakeReq_mutable_rpc_versions(start_server,
571                                                             arena.ptr());
572   grpc_gcp_RpcProtocolVersions_assign_from_struct(
573       server_version, arena.ptr(), &client->options->rpc_versions);
574   grpc_gcp_StartServerHandshakeReq_set_max_frame_size(
575       start_server, static_cast<uint32_t>(client->max_frame_size));
576   return get_serialized_handshaker_req(req, arena.ptr());
577 }
578
579 static tsi_result handshaker_client_start_server(alts_handshaker_client* c,
580                                                  grpc_slice* bytes_received) {
581   if (c == nullptr || bytes_received == nullptr) {
582     gpr_log(GPR_ERROR, "Invalid arguments to handshaker_client_start_server()");
583     return TSI_INVALID_ARGUMENT;
584   }
585   alts_grpc_handshaker_client* client =
586       reinterpret_cast<alts_grpc_handshaker_client*>(c);
587   grpc_byte_buffer* buffer = get_serialized_start_server(c, bytes_received);
588   if (buffer == nullptr) {
589     gpr_log(GPR_ERROR, "get_serialized_start_server() failed");
590     return TSI_INTERNAL_ERROR;
591   }
592   handshaker_client_send_buffer_destroy(client);
593   client->send_buffer = buffer;
594   tsi_result result = make_grpc_call(&client->base, true /* is_start */);
595   if (result != TSI_OK) {
596     gpr_log(GPR_ERROR, "make_grpc_call() failed");
597   }
598   return result;
599 }
600
601 /* Create and populate a next handshaker request, then serialize it. */
602 static grpc_byte_buffer* get_serialized_next(grpc_slice* bytes_received) {
603   GPR_ASSERT(bytes_received != nullptr);
604   upb::Arena arena;
605   grpc_gcp_HandshakerReq* req = grpc_gcp_HandshakerReq_new(arena.ptr());
606   grpc_gcp_NextHandshakeMessageReq* next =
607       grpc_gcp_HandshakerReq_mutable_next(req, arena.ptr());
608   grpc_gcp_NextHandshakeMessageReq_set_in_bytes(
609       next, upb_strview_make(reinterpret_cast<const char*> GRPC_SLICE_START_PTR(
610                                  *bytes_received),
611                              GRPC_SLICE_LENGTH(*bytes_received)));
612   return get_serialized_handshaker_req(req, arena.ptr());
613 }
614
615 static tsi_result handshaker_client_next(alts_handshaker_client* c,
616                                          grpc_slice* bytes_received) {
617   if (c == nullptr || bytes_received == nullptr) {
618     gpr_log(GPR_ERROR, "Invalid arguments to handshaker_client_next()");
619     return TSI_INVALID_ARGUMENT;
620   }
621   alts_grpc_handshaker_client* client =
622       reinterpret_cast<alts_grpc_handshaker_client*>(c);
623   grpc_slice_unref_internal(client->recv_bytes);
624   client->recv_bytes = grpc_slice_ref_internal(*bytes_received);
625   grpc_byte_buffer* buffer = get_serialized_next(bytes_received);
626   if (buffer == nullptr) {
627     gpr_log(GPR_ERROR, "get_serialized_next() failed");
628     return TSI_INTERNAL_ERROR;
629   }
630   handshaker_client_send_buffer_destroy(client);
631   client->send_buffer = buffer;
632   tsi_result result = make_grpc_call(&client->base, false /* is_start */);
633   if (result != TSI_OK) {
634     gpr_log(GPR_ERROR, "make_grpc_call() failed");
635   }
636   return result;
637 }
638
639 static void handshaker_client_shutdown(alts_handshaker_client* c) {
640   GPR_ASSERT(c != nullptr);
641   alts_grpc_handshaker_client* client =
642       reinterpret_cast<alts_grpc_handshaker_client*>(c);
643   if (client->call != nullptr) {
644     grpc_call_cancel_internal(client->call);
645   }
646 }
647
648 static void handshaker_call_unref(void* arg, grpc_error* /* error */) {
649   grpc_call* call = static_cast<grpc_call*>(arg);
650   grpc_call_unref(call);
651 }
652
653 static void handshaker_client_destruct(alts_handshaker_client* c) {
654   if (c == nullptr) {
655     return;
656   }
657   alts_grpc_handshaker_client* client =
658       reinterpret_cast<alts_grpc_handshaker_client*>(c);
659   if (client->call != nullptr) {
660     // Throw this grpc_call_unref over to the ExecCtx so that
661     // we invoke it at the bottom of the call stack and
662     // prevent lock inversion problems due to nested ExecCtx flushing.
663     // TODO(apolcyn): we could remove this indirection and call
664     // grpc_call_unref inline if there was an internal variant of
665     // grpc_call_unref that didn't need to flush an ExecCtx.
666     if (grpc_core::ExecCtx::Get() == nullptr) {
667       // Unref handshaker call if there is no exec_ctx, e.g., in the case of
668       // Envoy ALTS transport socket.
669       grpc_call_unref(client->call);
670     } else {
671       // Using existing exec_ctx to unref handshaker call.
672       grpc_core::ExecCtx::Run(
673           DEBUG_LOCATION,
674           GRPC_CLOSURE_CREATE(handshaker_call_unref, client->call,
675                               grpc_schedule_on_exec_ctx),
676           GRPC_ERROR_NONE);
677     }
678   }
679 }
680
681 static const alts_handshaker_client_vtable vtable = {
682     handshaker_client_start_client, handshaker_client_start_server,
683     handshaker_client_next, handshaker_client_shutdown,
684     handshaker_client_destruct};
685
686 alts_handshaker_client* alts_grpc_handshaker_client_create(
687     alts_tsi_handshaker* handshaker, grpc_channel* channel,
688     const char* handshaker_service_url, grpc_pollset_set* interested_parties,
689     grpc_alts_credentials_options* options, const grpc_slice& target_name,
690     grpc_iomgr_cb_func grpc_cb, tsi_handshaker_on_next_done_cb cb,
691     void* user_data, alts_handshaker_client_vtable* vtable_for_testing,
692     bool is_client, size_t max_frame_size) {
693   if (channel == nullptr || handshaker_service_url == nullptr) {
694     gpr_log(GPR_ERROR, "Invalid arguments to alts_handshaker_client_create()");
695     return nullptr;
696   }
697   alts_grpc_handshaker_client* client = new alts_grpc_handshaker_client();
698   memset(&client->base, 0, sizeof(client->base));
699   client->base.vtable =
700       vtable_for_testing == nullptr ? &vtable : vtable_for_testing;
701   gpr_ref_init(&client->refs, 1);
702   client->handshaker = handshaker;
703   client->grpc_caller = grpc_call_start_batch_and_execute;
704   grpc_metadata_array_init(&client->recv_initial_metadata);
705   client->cb = cb;
706   client->user_data = user_data;
707   client->options = grpc_alts_credentials_options_copy(options);
708   client->target_name = grpc_slice_copy(target_name);
709   client->is_client = is_client;
710   client->recv_bytes = grpc_empty_slice();
711   client->buffer_size = TSI_ALTS_INITIAL_BUFFER_SIZE;
712   client->buffer = static_cast<unsigned char*>(gpr_zalloc(client->buffer_size));
713   client->handshake_status_details = grpc_empty_slice();
714   client->max_frame_size = max_frame_size;
715   grpc_slice slice = grpc_slice_from_copied_string(handshaker_service_url);
716   client->call =
717       strcmp(handshaker_service_url, ALTS_HANDSHAKER_SERVICE_URL_FOR_TESTING) ==
718               0
719           ? nullptr
720           : grpc_channel_create_pollset_set_call(
721                 channel, nullptr, GRPC_PROPAGATE_DEFAULTS, interested_parties,
722                 grpc_slice_from_static_string(ALTS_SERVICE_METHOD), &slice,
723                 GRPC_MILLIS_INF_FUTURE, nullptr);
724   GRPC_CLOSURE_INIT(&client->on_handshaker_service_resp_recv, grpc_cb, client,
725                     grpc_schedule_on_exec_ctx);
726   GRPC_CLOSURE_INIT(&client->on_status_received, on_status_received, client,
727                     grpc_schedule_on_exec_ctx);
728   grpc_slice_unref_internal(slice);
729   return &client->base;
730 }
731
732 namespace grpc_core {
733 namespace internal {
734
735 void alts_handshaker_client_set_grpc_caller_for_testing(
736     alts_handshaker_client* c, alts_grpc_caller caller) {
737   GPR_ASSERT(c != nullptr && caller != nullptr);
738   alts_grpc_handshaker_client* client =
739       reinterpret_cast<alts_grpc_handshaker_client*>(c);
740   client->grpc_caller = caller;
741 }
742
743 grpc_byte_buffer* alts_handshaker_client_get_send_buffer_for_testing(
744     alts_handshaker_client* c) {
745   GPR_ASSERT(c != nullptr);
746   alts_grpc_handshaker_client* client =
747       reinterpret_cast<alts_grpc_handshaker_client*>(c);
748   return client->send_buffer;
749 }
750
751 grpc_byte_buffer** alts_handshaker_client_get_recv_buffer_addr_for_testing(
752     alts_handshaker_client* c) {
753   GPR_ASSERT(c != nullptr);
754   alts_grpc_handshaker_client* client =
755       reinterpret_cast<alts_grpc_handshaker_client*>(c);
756   return &client->recv_buffer;
757 }
758
759 grpc_metadata_array* alts_handshaker_client_get_initial_metadata_for_testing(
760     alts_handshaker_client* c) {
761   GPR_ASSERT(c != nullptr);
762   alts_grpc_handshaker_client* client =
763       reinterpret_cast<alts_grpc_handshaker_client*>(c);
764   return &client->recv_initial_metadata;
765 }
766
767 void alts_handshaker_client_set_recv_bytes_for_testing(
768     alts_handshaker_client* c, grpc_slice* recv_bytes) {
769   GPR_ASSERT(c != nullptr);
770   alts_grpc_handshaker_client* client =
771       reinterpret_cast<alts_grpc_handshaker_client*>(c);
772   client->recv_bytes = grpc_slice_ref_internal(*recv_bytes);
773 }
774
775 void alts_handshaker_client_set_fields_for_testing(
776     alts_handshaker_client* c, alts_tsi_handshaker* handshaker,
777     tsi_handshaker_on_next_done_cb cb, void* user_data,
778     grpc_byte_buffer* recv_buffer, grpc_status_code status) {
779   GPR_ASSERT(c != nullptr);
780   alts_grpc_handshaker_client* client =
781       reinterpret_cast<alts_grpc_handshaker_client*>(c);
782   client->handshaker = handshaker;
783   client->cb = cb;
784   client->user_data = user_data;
785   client->recv_buffer = recv_buffer;
786   client->status = status;
787 }
788
789 void alts_handshaker_client_check_fields_for_testing(
790     alts_handshaker_client* c, tsi_handshaker_on_next_done_cb cb,
791     void* user_data, bool has_sent_start_message, grpc_slice* recv_bytes) {
792   GPR_ASSERT(c != nullptr);
793   alts_grpc_handshaker_client* client =
794       reinterpret_cast<alts_grpc_handshaker_client*>(c);
795   GPR_ASSERT(client->cb == cb);
796   GPR_ASSERT(client->user_data == user_data);
797   if (recv_bytes != nullptr) {
798     GPR_ASSERT(grpc_slice_cmp(client->recv_bytes, *recv_bytes) == 0);
799   }
800   GPR_ASSERT(alts_tsi_handshaker_get_has_sent_start_message_for_testing(
801                  client->handshaker) == has_sent_start_message);
802 }
803
804 void alts_handshaker_client_set_vtable_for_testing(
805     alts_handshaker_client* c, alts_handshaker_client_vtable* vtable) {
806   GPR_ASSERT(c != nullptr);
807   GPR_ASSERT(vtable != nullptr);
808   alts_grpc_handshaker_client* client =
809       reinterpret_cast<alts_grpc_handshaker_client*>(c);
810   client->base.vtable = vtable;
811 }
812
813 alts_tsi_handshaker* alts_handshaker_client_get_handshaker_for_testing(
814     alts_handshaker_client* c) {
815   GPR_ASSERT(c != nullptr);
816   alts_grpc_handshaker_client* client =
817       reinterpret_cast<alts_grpc_handshaker_client*>(c);
818   return client->handshaker;
819 }
820
821 void alts_handshaker_client_set_cb_for_testing(
822     alts_handshaker_client* c, tsi_handshaker_on_next_done_cb cb) {
823   GPR_ASSERT(c != nullptr);
824   alts_grpc_handshaker_client* client =
825       reinterpret_cast<alts_grpc_handshaker_client*>(c);
826   client->cb = cb;
827 }
828
829 grpc_closure* alts_handshaker_client_get_closure_for_testing(
830     alts_handshaker_client* c) {
831   GPR_ASSERT(c != nullptr);
832   alts_grpc_handshaker_client* client =
833       reinterpret_cast<alts_grpc_handshaker_client*>(c);
834   return &client->on_handshaker_service_resp_recv;
835 }
836
837 void alts_handshaker_client_ref_for_testing(alts_handshaker_client* c) {
838   alts_grpc_handshaker_client* client =
839       reinterpret_cast<alts_grpc_handshaker_client*>(c);
840   gpr_ref(&client->refs);
841 }
842
843 void alts_handshaker_client_on_status_received_for_testing(
844     alts_handshaker_client* c, grpc_status_code status, grpc_error* error) {
845   // We first make sure that the handshake queue has been initialized
846   // here because there are tests that use this API that mock out
847   // other parts of the alts_handshaker_client in such a way that the
848   // code path that would normally ensure that the handshake queue
849   // has been initialized isn't taken.
850   gpr_once_init(&g_queued_handshakes_init, DoHandshakeQueuesInit);
851   alts_grpc_handshaker_client* client =
852       reinterpret_cast<alts_grpc_handshaker_client*>(c);
853   client->handshake_status_code = status;
854   client->handshake_status_details = grpc_empty_slice();
855   grpc_core::Closure::Run(DEBUG_LOCATION, &client->on_status_received, error);
856 }
857
858 }  // namespace internal
859 }  // namespace grpc_core
860
861 tsi_result alts_handshaker_client_start_client(alts_handshaker_client* client) {
862   if (client != nullptr && client->vtable != nullptr &&
863       client->vtable->client_start != nullptr) {
864     return client->vtable->client_start(client);
865   }
866   gpr_log(GPR_ERROR,
867           "client or client->vtable has not been initialized properly");
868   return TSI_INVALID_ARGUMENT;
869 }
870
871 tsi_result alts_handshaker_client_start_server(alts_handshaker_client* client,
872                                                grpc_slice* bytes_received) {
873   if (client != nullptr && client->vtable != nullptr &&
874       client->vtable->server_start != nullptr) {
875     return client->vtable->server_start(client, bytes_received);
876   }
877   gpr_log(GPR_ERROR,
878           "client or client->vtable has not been initialized properly");
879   return TSI_INVALID_ARGUMENT;
880 }
881
882 tsi_result alts_handshaker_client_next(alts_handshaker_client* client,
883                                        grpc_slice* bytes_received) {
884   if (client != nullptr && client->vtable != nullptr &&
885       client->vtable->next != nullptr) {
886     return client->vtable->next(client, bytes_received);
887   }
888   gpr_log(GPR_ERROR,
889           "client or client->vtable has not been initialized properly");
890   return TSI_INVALID_ARGUMENT;
891 }
892
893 void alts_handshaker_client_shutdown(alts_handshaker_client* client) {
894   if (client != nullptr && client->vtable != nullptr &&
895       client->vtable->shutdown != nullptr) {
896     client->vtable->shutdown(client);
897   }
898 }
899
900 void alts_handshaker_client_destroy(alts_handshaker_client* c) {
901   if (c != nullptr) {
902     alts_grpc_handshaker_client* client =
903         reinterpret_cast<alts_grpc_handshaker_client*>(c);
904     alts_grpc_handshaker_client_unref(client);
905   }
906 }