9c8c8d9e1885c929b0c2e9c47631885f8d3ef7d4
[platform/upstream/grpc.git] / src / core / ext / filters / http / message_compress / message_compress_filter.cc
1 /*
2  *
3  * Copyright 2015 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18
19 #include <grpc/support/port_platform.h>
20
21 #include <assert.h>
22 #include <string.h>
23
24 #include <grpc/compression.h>
25 #include <grpc/slice_buffer.h>
26 #include <grpc/support/alloc.h>
27 #include <grpc/support/log.h>
28
29 #include "src/core/ext/filters/http/message_compress/message_compress_filter.h"
30 #include "src/core/lib/channel/channel_args.h"
31 #include "src/core/lib/compression/algorithm_metadata.h"
32 #include "src/core/lib/compression/compression_internal.h"
33 #include "src/core/lib/compression/message_compress.h"
34 #include "src/core/lib/gpr/string.h"
35 #include "src/core/lib/gprpp/manual_constructor.h"
36 #include "src/core/lib/profiling/timers.h"
37 #include "src/core/lib/slice/slice_internal.h"
38 #include "src/core/lib/slice/slice_string_helpers.h"
39 #include "src/core/lib/surface/call.h"
40 #include "src/core/lib/transport/static_metadata.h"
41
42 static void start_send_message_batch(void* arg, grpc_error* unused);
43 static void send_message_on_complete(void* arg, grpc_error* error);
44 static void on_send_message_next_done(void* arg, grpc_error* error);
45
namespace {
// Records whether send_initial_metadata has been processed for this call
// and, if so, what process_send_initial_metadata() reported.
enum initial_metadata_state {
  // Initial metadata not yet seen.
  INITIAL_METADATA_UNSEEN = 0,
  // Initial metadata seen; compression algorithm set.
  HAS_COMPRESSION_ALGORITHM,
  // Initial metadata seen; no compression algorithm set.
  NO_COMPRESSION_ALGORITHM,
};

// Per-call state of the message compression filter. Placement-constructed by
// init_call_elem() and explicitly destroyed by destroy_call_elem().
struct call_data {
  // Binds the three closures used by the send_message path to `elem`; all run
  // via grpc_schedule_on_exec_ctx.
  call_data(grpc_call_element* elem, const grpc_call_element_args& args)
      : call_combiner(args.call_combiner) {
    GRPC_CLOSURE_INIT(&start_send_message_batch_in_call_combiner,
                      start_send_message_batch, elem,
                      grpc_schedule_on_exec_ctx);
    grpc_slice_buffer_init(&slices);
    GRPC_CLOSURE_INIT(&send_message_on_complete, ::send_message_on_complete,
                      elem, grpc_schedule_on_exec_ctx);
    GRPC_CLOSURE_INIT(&on_send_message_next_done, ::on_send_message_next_done,
                      elem, grpc_schedule_on_exec_ctx);
  }

  // Releases any buffered slices and the saved cancellation error.
  ~call_data() {
    grpc_slice_buffer_destroy_internal(&slices);
    GRPC_ERROR_UNREF(cancel_error);
  }

  // Call combiner of this call; stopped/restarted while a send_message batch
  // is held waiting for send_initial_metadata.
  grpc_call_combiner* call_combiner;
  // Storage for the metadata elements this filter appends to
  // send_initial_metadata (encoding hint and accept-encoding headers).
  grpc_linked_mdelem compression_algorithm_storage;
  grpc_linked_mdelem stream_compression_algorithm_storage;
  grpc_linked_mdelem accept_encoding_storage;
  grpc_linked_mdelem accept_stream_encoding_storage;
  /** Compression algorithm we'll try to use. It may be given by incoming
   * metadata, or by the channel's default compression settings. */
  grpc_message_compression_algorithm message_compression_algorithm =
      GRPC_MESSAGE_COMPRESS_NONE;
  // Advanced once send_initial_metadata has been processed.
  initial_metadata_state send_initial_metadata_state = INITIAL_METADATA_UNSEEN;
  // Error from the most recent cancel_stream op (owned); later batches are
  // failed with it.
  grpc_error* cancel_error = GRPC_ERROR_NONE;
  // Runs start_send_message_batch() after re-entering the call combiner.
  grpc_closure start_send_message_batch_in_call_combiner;
  // The send_message batch currently held or being processed, if any.
  grpc_transport_stream_op_batch* send_message_batch = nullptr;
  grpc_slice_buffer slices; /**< Buffers up input slices to be compressed */
  // Byte stream (backed by `slices`) that replaces the batch's original
  // send_message stream in finish_send_message().
  grpc_core::ManualConstructor<grpc_core::SliceBufferByteStream>
      replacement_stream;
  // The batch's own on_complete, run by send_message_on_complete.
  // NOTE(review): not initialized until finish_send_message() sets it.
  grpc_closure* original_send_message_on_complete;
  // Installed as the batch's on_complete once compression is done.
  grpc_closure send_message_on_complete;
  // Completion callback passed to ByteStream::Next().
  grpc_closure on_send_message_next_done;
};

// Per-channel compression configuration, filled in by init_channel_elem().
struct channel_data {
  /** The default, channel-level, compression algorithm */
  grpc_compression_algorithm default_compression_algorithm;
  /** Bitset of enabled compression algorithms */
  uint32_t enabled_algorithms_bitset;
  /** Supported compression algorithms */
  uint32_t supported_message_compression_algorithms;
  /** Supported stream compression algorithms */
  uint32_t supported_stream_compression_algorithms;
};
}  // namespace
106
107 static bool skip_compression(grpc_call_element* elem, uint32_t flags,
108                              bool has_compression_algorithm) {
109   call_data* calld = static_cast<call_data*>(elem->call_data);
110   channel_data* channeld = static_cast<channel_data*>(elem->channel_data);
111
112   if (flags & (GRPC_WRITE_NO_COMPRESS | GRPC_WRITE_INTERNAL_COMPRESS)) {
113     return true;
114   }
115   if (has_compression_algorithm) {
116     if (calld->message_compression_algorithm == GRPC_MESSAGE_COMPRESS_NONE) {
117       return true;
118     }
119     return false; /* we have an actual call-specific algorithm */
120   }
121   /* no per-call compression override */
122   return channeld->default_compression_algorithm == GRPC_COMPRESS_NONE;
123 }
124
125 /** Filter initial metadata */
126 static grpc_error* process_send_initial_metadata(
127     grpc_call_element* elem, grpc_metadata_batch* initial_metadata,
128     bool* has_compression_algorithm) GRPC_MUST_USE_RESULT;
129 static grpc_error* process_send_initial_metadata(
130     grpc_call_element* elem, grpc_metadata_batch* initial_metadata,
131     bool* has_compression_algorithm) {
132   call_data* calld = static_cast<call_data*>(elem->call_data);
133   channel_data* channeld = static_cast<channel_data*>(elem->channel_data);
134   *has_compression_algorithm = false;
135   grpc_compression_algorithm compression_algorithm;
136   grpc_stream_compression_algorithm stream_compression_algorithm =
137       GRPC_STREAM_COMPRESS_NONE;
138   if (initial_metadata->idx.named.grpc_internal_encoding_request != nullptr) {
139     grpc_mdelem md =
140         initial_metadata->idx.named.grpc_internal_encoding_request->md;
141     if (GPR_UNLIKELY(!grpc_compression_algorithm_parse(
142             GRPC_MDVALUE(md), &compression_algorithm))) {
143       char* val = grpc_slice_to_c_string(GRPC_MDVALUE(md));
144       gpr_log(GPR_ERROR,
145               "Invalid compression algorithm: '%s' (unknown). Ignoring.", val);
146       gpr_free(val);
147       calld->message_compression_algorithm = GRPC_MESSAGE_COMPRESS_NONE;
148       stream_compression_algorithm = GRPC_STREAM_COMPRESS_NONE;
149     }
150     if (GPR_UNLIKELY(!GPR_BITGET(channeld->enabled_algorithms_bitset,
151                                  compression_algorithm))) {
152       char* val = grpc_slice_to_c_string(GRPC_MDVALUE(md));
153       gpr_log(GPR_ERROR,
154               "Invalid compression algorithm: '%s' (previously disabled). "
155               "Ignoring.",
156               val);
157       gpr_free(val);
158       calld->message_compression_algorithm = GRPC_MESSAGE_COMPRESS_NONE;
159       stream_compression_algorithm = GRPC_STREAM_COMPRESS_NONE;
160     }
161     *has_compression_algorithm = true;
162     grpc_metadata_batch_remove(
163         initial_metadata,
164         initial_metadata->idx.named.grpc_internal_encoding_request);
165     calld->message_compression_algorithm =
166         grpc_compression_algorithm_to_message_compression_algorithm(
167             compression_algorithm);
168     stream_compression_algorithm =
169         grpc_compression_algorithm_to_stream_compression_algorithm(
170             compression_algorithm);
171   } else {
172     /* If no algorithm was found in the metadata and we aren't
173      * exceptionally skipping compression, fall back to the channel
174      * default */
175     if (channeld->default_compression_algorithm != GRPC_COMPRESS_NONE) {
176       calld->message_compression_algorithm =
177           grpc_compression_algorithm_to_message_compression_algorithm(
178               channeld->default_compression_algorithm);
179       stream_compression_algorithm =
180           grpc_compression_algorithm_to_stream_compression_algorithm(
181               channeld->default_compression_algorithm);
182     }
183     *has_compression_algorithm = true;
184   }
185
186   grpc_error* error = GRPC_ERROR_NONE;
187   /* hint compression algorithm */
188   if (stream_compression_algorithm != GRPC_STREAM_COMPRESS_NONE) {
189     error = grpc_metadata_batch_add_tail(
190         initial_metadata, &calld->stream_compression_algorithm_storage,
191         grpc_stream_compression_encoding_mdelem(stream_compression_algorithm));
192   } else if (calld->message_compression_algorithm !=
193              GRPC_MESSAGE_COMPRESS_NONE) {
194     error = grpc_metadata_batch_add_tail(
195         initial_metadata, &calld->compression_algorithm_storage,
196         grpc_message_compression_encoding_mdelem(
197             calld->message_compression_algorithm));
198   }
199
200   if (error != GRPC_ERROR_NONE) return error;
201
202   /* convey supported compression algorithms */
203   error = grpc_metadata_batch_add_tail(
204       initial_metadata, &calld->accept_encoding_storage,
205       GRPC_MDELEM_ACCEPT_ENCODING_FOR_ALGORITHMS(
206           channeld->supported_message_compression_algorithms));
207
208   if (error != GRPC_ERROR_NONE) return error;
209
210   /* Do not overwrite accept-encoding header if it already presents (e.g. added
211    * by some proxy). */
212   if (!initial_metadata->idx.named.accept_encoding) {
213     error = grpc_metadata_batch_add_tail(
214         initial_metadata, &calld->accept_stream_encoding_storage,
215         GRPC_MDELEM_ACCEPT_STREAM_ENCODING_FOR_ALGORITHMS(
216             channeld->supported_stream_compression_algorithms));
217   }
218
219   return error;
220 }
221
222 static void send_message_on_complete(void* arg, grpc_error* error) {
223   grpc_call_element* elem = static_cast<grpc_call_element*>(arg);
224   call_data* calld = static_cast<call_data*>(elem->call_data);
225   grpc_slice_buffer_reset_and_unref_internal(&calld->slices);
226   GRPC_CLOSURE_RUN(calld->original_send_message_on_complete,
227                    GRPC_ERROR_REF(error));
228 }
229
230 static void send_message_batch_continue(grpc_call_element* elem) {
231   call_data* calld = static_cast<call_data*>(elem->call_data);
232   // Note: The call to grpc_call_next_op() results in yielding the
233   // call combiner, so we need to clear calld->send_message_batch
234   // before we do that.
235   grpc_transport_stream_op_batch* send_message_batch =
236       calld->send_message_batch;
237   calld->send_message_batch = nullptr;
238   grpc_call_next_op(elem, send_message_batch);
239 }
240
241 static void finish_send_message(grpc_call_element* elem) {
242   call_data* calld = static_cast<call_data*>(elem->call_data);
243   // Compress the data if appropriate.
244   grpc_slice_buffer tmp;
245   grpc_slice_buffer_init(&tmp);
246   uint32_t send_flags =
247       calld->send_message_batch->payload->send_message.send_message->flags();
248   bool did_compress = grpc_msg_compress(calld->message_compression_algorithm,
249                                         &calld->slices, &tmp);
250   if (did_compress) {
251     if (grpc_compression_trace.enabled()) {
252       const char* algo_name;
253       const size_t before_size = calld->slices.length;
254       const size_t after_size = tmp.length;
255       const float savings_ratio = 1.0f - static_cast<float>(after_size) /
256                                              static_cast<float>(before_size);
257       GPR_ASSERT(grpc_message_compression_algorithm_name(
258           calld->message_compression_algorithm, &algo_name));
259       gpr_log(GPR_INFO,
260               "Compressed[%s] %" PRIuPTR " bytes vs. %" PRIuPTR
261               " bytes (%.2f%% savings)",
262               algo_name, before_size, after_size, 100 * savings_ratio);
263     }
264     grpc_slice_buffer_swap(&calld->slices, &tmp);
265     send_flags |= GRPC_WRITE_INTERNAL_COMPRESS;
266   } else {
267     if (grpc_compression_trace.enabled()) {
268       const char* algo_name;
269       GPR_ASSERT(grpc_message_compression_algorithm_name(
270           calld->message_compression_algorithm, &algo_name));
271       gpr_log(GPR_INFO,
272               "Algorithm '%s' enabled but decided not to compress. Input size: "
273               "%" PRIuPTR,
274               algo_name, calld->slices.length);
275     }
276   }
277   grpc_slice_buffer_destroy_internal(&tmp);
278   // Swap out the original byte stream with our new one and send the
279   // batch down.
280   calld->replacement_stream.Init(&calld->slices, send_flags);
281   calld->send_message_batch->payload->send_message.send_message.reset(
282       calld->replacement_stream.get());
283   calld->original_send_message_on_complete =
284       calld->send_message_batch->on_complete;
285   calld->send_message_batch->on_complete = &calld->send_message_on_complete;
286   send_message_batch_continue(elem);
287 }
288
289 static void fail_send_message_batch_in_call_combiner(void* arg,
290                                                      grpc_error* error) {
291   call_data* calld = static_cast<call_data*>(arg);
292   if (calld->send_message_batch != nullptr) {
293     grpc_transport_stream_op_batch_finish_with_failure(
294         calld->send_message_batch, GRPC_ERROR_REF(error), calld->call_combiner);
295     calld->send_message_batch = nullptr;
296   }
297 }
298
299 // Pulls a slice from the send_message byte stream and adds it to calld->slices.
300 static grpc_error* pull_slice_from_send_message(call_data* calld) {
301   grpc_slice incoming_slice;
302   grpc_error* error =
303       calld->send_message_batch->payload->send_message.send_message->Pull(
304           &incoming_slice);
305   if (error == GRPC_ERROR_NONE) {
306     grpc_slice_buffer_add(&calld->slices, incoming_slice);
307   }
308   return error;
309 }
310
311 // Reads as many slices as possible from the send_message byte stream.
312 // If all data has been read, invokes finish_send_message().  Otherwise,
313 // an async call to ByteStream::Next() has been started, which will
314 // eventually result in calling on_send_message_next_done().
315 static void continue_reading_send_message(grpc_call_element* elem) {
316   call_data* calld = static_cast<call_data*>(elem->call_data);
317   while (calld->send_message_batch->payload->send_message.send_message->Next(
318       ~static_cast<size_t>(0), &calld->on_send_message_next_done)) {
319     grpc_error* error = pull_slice_from_send_message(calld);
320     if (error != GRPC_ERROR_NONE) {
321       // Closure callback; does not take ownership of error.
322       fail_send_message_batch_in_call_combiner(calld, error);
323       GRPC_ERROR_UNREF(error);
324       return;
325     }
326     if (calld->slices.length == calld->send_message_batch->payload->send_message
327                                     .send_message->length()) {
328       finish_send_message(elem);
329       break;
330     }
331   }
332 }
333
334 // Async callback for ByteStream::Next().
335 static void on_send_message_next_done(void* arg, grpc_error* error) {
336   grpc_call_element* elem = static_cast<grpc_call_element*>(arg);
337   call_data* calld = static_cast<call_data*>(elem->call_data);
338   if (error != GRPC_ERROR_NONE) {
339     // Closure callback; does not take ownership of error.
340     fail_send_message_batch_in_call_combiner(calld, error);
341     return;
342   }
343   error = pull_slice_from_send_message(calld);
344   if (error != GRPC_ERROR_NONE) {
345     // Closure callback; does not take ownership of error.
346     fail_send_message_batch_in_call_combiner(calld, error);
347     GRPC_ERROR_UNREF(error);
348     return;
349   }
350   if (calld->slices.length ==
351       calld->send_message_batch->payload->send_message.send_message->length()) {
352     finish_send_message(elem);
353   } else {
354     continue_reading_send_message(elem);
355   }
356 }
357
358 static void start_send_message_batch(void* arg, grpc_error* unused) {
359   grpc_call_element* elem = static_cast<grpc_call_element*>(arg);
360   call_data* calld = static_cast<call_data*>(elem->call_data);
361   if (skip_compression(
362           elem,
363           calld->send_message_batch->payload->send_message.send_message
364               ->flags(),
365           calld->send_initial_metadata_state == HAS_COMPRESSION_ALGORITHM)) {
366     send_message_batch_continue(elem);
367   } else {
368     continue_reading_send_message(elem);
369   }
370 }
371
/* Entry point for every batch passing through this filter.
 *
 * - cancel_stream: stores (a ref to) the cancel error and fails any held
 *   send_message batch. If send_initial_metadata has not been seen, the batch
 *   is failed inside the call combiner; otherwise its byte stream is shut
 *   down.
 * - Once a cancel error is stored, any later non-cancel batch is failed
 *   immediately with it.
 * - send_initial_metadata: chooses the compression algorithm. If a
 *   send_message batch was held waiting for it, that batch is restarted by
 *   re-entering the call combiner.
 * - send_message: held (dropping the call combiner) until
 *   send_initial_metadata has been seen, then handed to
 *   start_send_message_batch().
 * - Batches without send_message are passed on via grpc_call_next_op(). */
static void compress_start_transport_stream_op_batch(
    grpc_call_element* elem, grpc_transport_stream_op_batch* batch) {
  GPR_TIMER_SCOPE("compress_start_transport_stream_op_batch", 0);
  call_data* calld = static_cast<call_data*>(elem->call_data);
  // Handle cancel_stream.
  if (batch->cancel_stream) {
    GRPC_ERROR_UNREF(calld->cancel_error);
    calld->cancel_error =
        GRPC_ERROR_REF(batch->payload->cancel_stream.cancel_error);
    if (calld->send_message_batch != nullptr) {
      if (calld->send_initial_metadata_state == INITIAL_METADATA_UNSEEN) {
        // The held batch dropped the call combiner while waiting; fail it
        // from inside the combiner via a heap-created closure.
        GRPC_CALL_COMBINER_START(
            calld->call_combiner,
            GRPC_CLOSURE_CREATE(fail_send_message_batch_in_call_combiner, calld,
                                grpc_schedule_on_exec_ctx),
            GRPC_ERROR_REF(calld->cancel_error), "failing send_message op");
      } else {
        // Reading may be in progress; shutting the stream down presumably
        // makes the pending Next()/Pull() fail with cancel_error — confirm.
        calld->send_message_batch->payload->send_message.send_message->Shutdown(
            GRPC_ERROR_REF(calld->cancel_error));
      }
    }
  } else if (calld->cancel_error != GRPC_ERROR_NONE) {
    grpc_transport_stream_op_batch_finish_with_failure(
        batch, GRPC_ERROR_REF(calld->cancel_error), calld->call_combiner);
    return;
  }
  // Handle send_initial_metadata.
  if (batch->send_initial_metadata) {
    GPR_ASSERT(calld->send_initial_metadata_state == INITIAL_METADATA_UNSEEN);
    bool has_compression_algorithm;
    grpc_error* error = process_send_initial_metadata(
        elem, batch->payload->send_initial_metadata.send_initial_metadata,
        &has_compression_algorithm);
    if (error != GRPC_ERROR_NONE) {
      grpc_transport_stream_op_batch_finish_with_failure(batch, error,
                                                         calld->call_combiner);
      return;
    }
    calld->send_initial_metadata_state = has_compression_algorithm
                                             ? HAS_COMPRESSION_ALGORITHM
                                             : NO_COMPRESSION_ALGORITHM;
    // If we had previously received a batch containing a send_message op,
    // handle it now.  Note that we need to re-enter the call combiner
    // for this, since we can't send two batches down while holding the
    // call combiner, since the connected_channel filter (at the bottom of
    // the call stack) will release the call combiner for each batch it sees.
    if (calld->send_message_batch != nullptr) {
      GRPC_CALL_COMBINER_START(
          calld->call_combiner,
          &calld->start_send_message_batch_in_call_combiner, GRPC_ERROR_NONE,
          "starting send_message after send_initial_metadata");
    }
  }
  // Handle send_message.
  if (batch->send_message) {
    GPR_ASSERT(calld->send_message_batch == nullptr);
    calld->send_message_batch = batch;
    // If we have not yet seen send_initial_metadata, then we have to
    // wait.  We save the batch in calld and then drop the call
    // combiner, which we'll have to pick up again later when we get
    // send_initial_metadata.
    if (calld->send_initial_metadata_state == INITIAL_METADATA_UNSEEN) {
      GRPC_CALL_COMBINER_STOP(
          calld->call_combiner,
          "send_message batch pending send_initial_metadata");
      return;
    }
    start_send_message_batch(elem, GRPC_ERROR_NONE);
  } else {
    // Pass control down the stack.
    grpc_call_next_op(elem, batch);
  }
}
445
446 /* Constructor for call_data */
447 static grpc_error* init_call_elem(grpc_call_element* elem,
448                                   const grpc_call_element_args* args) {
449   new (elem->call_data) call_data(elem, *args);
450   return GRPC_ERROR_NONE;
451 }
452
453 /* Destructor for call_data */
454 static void destroy_call_elem(grpc_call_element* elem,
455                               const grpc_call_final_info* final_info,
456                               grpc_closure* ignored) {
457   call_data* calld = static_cast<call_data*>(elem->call_data);
458   calld->~call_data();
459 }
460
461 /* Constructor for channel_data */
462 static grpc_error* init_channel_elem(grpc_channel_element* elem,
463                                      grpc_channel_element_args* args) {
464   channel_data* channeld = static_cast<channel_data*>(elem->channel_data);
465
466   channeld->enabled_algorithms_bitset =
467       grpc_channel_args_compression_algorithm_get_states(args->channel_args);
468   channeld->default_compression_algorithm =
469       grpc_channel_args_get_compression_algorithm(args->channel_args);
470
471   /* Make sure the default isn't disabled. */
472   if (!GPR_BITGET(channeld->enabled_algorithms_bitset,
473                   channeld->default_compression_algorithm)) {
474     gpr_log(GPR_DEBUG,
475             "compression algorithm %d not enabled: switching to none",
476             channeld->default_compression_algorithm);
477     channeld->default_compression_algorithm = GRPC_COMPRESS_NONE;
478   }
479
480   uint32_t supported_compression_algorithms =
481       (((1u << GRPC_COMPRESS_ALGORITHMS_COUNT) - 1) &
482        channeld->enabled_algorithms_bitset) |
483       1u;
484
485   channeld->supported_message_compression_algorithms =
486       grpc_compression_bitset_to_message_bitset(
487           supported_compression_algorithms);
488
489   channeld->supported_stream_compression_algorithms =
490       grpc_compression_bitset_to_stream_bitset(
491           supported_compression_algorithms);
492
493   GPR_ASSERT(!args->is_last);
494   return GRPC_ERROR_NONE;
495 }
496
497 /* Destructor for channel data */
498 static void destroy_channel_elem(grpc_channel_element* elem) {}
499
/* Channel filter table for message compression. Entries are positional;
 * the order is assumed to match the grpc_channel_filter declaration
 * (confirm against channel_stack.h). Transport ops, pollset handling and
 * channel info are delegated to the generic pass-through helpers. */
const grpc_channel_filter grpc_message_compress_filter = {
    compress_start_transport_stream_op_batch,
    grpc_channel_next_op,
    sizeof(call_data),
    init_call_elem,
    grpc_call_stack_ignore_set_pollset_or_pollset_set,
    destroy_call_elem,
    sizeof(channel_data),
    init_channel_elem,
    destroy_channel_elem,
    grpc_channel_next_get_info,
    "message_compress"};