5f8d77944145990c94d2cbc6a149102d4db046f8
[platform/upstream/grpc.git] / src / core / ext / filters / fault_injection / fault_injection_filter.cc
1 //
2 // Copyright 2021 gRPC authors.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16
17 #include <grpc/support/port_platform.h>
18
19 #include "src/core/ext/filters/fault_injection/fault_injection_filter.h"
20
21 #include "absl/strings/numbers.h"
22
23 #include <grpc/support/alloc.h>
24 #include <grpc/support/log.h>
25
26 #include "src/core/ext/filters/client_channel/service_config.h"
27 #include "src/core/ext/filters/client_channel/service_config_call_data.h"
28 #include "src/core/ext/filters/fault_injection/service_config_parser.h"
29 #include "src/core/lib/channel/channel_stack.h"
30 #include "src/core/lib/channel/status_util.h"
31 #include "src/core/lib/gprpp/atomic.h"
32 #include "src/core/lib/gprpp/sync.h"
33 #include "src/core/lib/iomgr/closure.h"
34 #include "src/core/lib/iomgr/timer.h"
35 #include "src/core/lib/transport/status_conversion.h"
36
37 namespace grpc_core {
38
39 TraceFlag grpc_fault_injection_filter_trace(false, "fault_injection_filter");
40
41 namespace {
42
43 Atomic<uint32_t> g_active_faults{0};
44 static_assert(
45     std::is_trivially_destructible<Atomic<uint32_t>>::value,
46     "the active fault counter needs to have a trivially destructible type");
47
48 inline int GetLinkedMetadatumValueInt(grpc_linked_mdelem* md) {
49   int res;
50   if (absl::SimpleAtoi(StringViewFromSlice(GRPC_MDVALUE(md->md)), &res)) {
51     return res;
52   } else {
53     return -1;
54   }
55 }
56
57 inline uint32_t GetLinkedMetadatumValueUnsignedInt(grpc_linked_mdelem* md) {
58   uint32_t res;
59   if (absl::SimpleAtoi(StringViewFromSlice(GRPC_MDVALUE(md->md)), &res)) {
60     return res;
61   } else {
62     return -1;
63   }
64 }
65
66 inline int64_t GetLinkedMetadatumValueInt64(grpc_linked_mdelem* md) {
67   int64_t res;
68   if (absl::SimpleAtoi(StringViewFromSlice(GRPC_MDVALUE(md->md)), &res)) {
69     return res;
70   } else {
71     return -1;
72   }
73 }
74
75 inline bool UnderFraction(const uint32_t numerator,
76                           const uint32_t denominator) {
77   if (numerator <= 0) return false;
78   if (numerator >= denominator) return true;
79   // Generate a random number in [0, denominator).
80   const uint32_t random_number = rand() % denominator;
81   return random_number < numerator;
82 }
83
84 class ChannelData {
85  public:
86   static grpc_error_handle Init(grpc_channel_element* elem,
87                                 grpc_channel_element_args* args);
88   static void Destroy(grpc_channel_element* elem);
89
90   int index() const { return index_; }
91
92  private:
93   ChannelData(grpc_channel_element* elem, grpc_channel_element_args* args);
94   ~ChannelData() = default;
95
96   // The relative index of instances of the same filter.
97   int index_;
98 };
99
100 class CallData {
101  public:
102   static grpc_error_handle Init(grpc_call_element* elem,
103                                 const grpc_call_element_args* args);
104
105   static void Destroy(grpc_call_element* elem,
106                       const grpc_call_final_info* /*final_info*/,
107                       grpc_closure* /*then_schedule_closure*/);
108
109   static void StartTransportStreamOpBatch(
110       grpc_call_element* elem, grpc_transport_stream_op_batch* batch);
111
112  private:
113   class ResumeBatchCanceller;
114
115   CallData(grpc_call_element* elem, const grpc_call_element_args* args);
116   ~CallData();
117
118   void DecideWhetherToInjectFaults(grpc_metadata_batch* initial_metadata);
119
120   // Checks if current active faults exceed the allowed max faults.
121   bool HaveActiveFaultsQuota(bool increment);
122
123   // Returns true if this RPC needs to be delayed. If so, this call will be
124   // counted as an active fault.
125   bool MaybeDelay();
126
127   // Returns the aborted RPC status if this RPC needs to be aborted. If so,
128   // this call will be counted as an active fault. Otherwise, it returns
129   // GRPC_ERROR_NONE.
130   // If this call is already been delay injected, skip the active faults
131   // quota check.
132   grpc_error_handle MaybeAbort();
133
134   // Delays the stream operations batch.
135   void DelayBatch(grpc_call_element* elem,
136                   grpc_transport_stream_op_batch* batch);
137
138   // Cancels the delay timer.
139   void CancelDelayTimer() { grpc_timer_cancel(&delay_timer_); }
140
141   // Finishes the fault injection, should only be called once.
142   void FaultInjectionFinished() {
143     g_active_faults.FetchSub(1, MemoryOrder::RELAXED);
144   }
145
146   // This is a callback that will be invoked after the delay timer is up.
147   static void ResumeBatch(void* arg, grpc_error_handle error);
148
149   // This is a callback invoked upon completion of recv_trailing_metadata.
150   // Injects the abort_error_ to the recv_trailing_metadata batch if needed.
151   static void HijackedRecvTrailingMetadataReady(void* arg, grpc_error_handle);
152
153   // Used to track the policy structs that needs to be destroyed in dtor.
154   bool fi_policy_owned_ = false;
155   const FaultInjectionMethodParsedConfig::FaultInjectionPolicy* fi_policy_;
156   grpc_call_stack* owning_call_;
157   Arena* arena_;
158   CallCombiner* call_combiner_;
159
160   // Indicates whether we are doing a delay and/or an abort for this call.
161   bool delay_request_ = false;
162   bool abort_request_ = false;
163
164   // Delay states
165   grpc_timer delay_timer_ ABSL_GUARDED_BY(delay_mu_);
166   ResumeBatchCanceller* resume_batch_canceller_ ABSL_GUARDED_BY(delay_mu_);
167   grpc_transport_stream_op_batch* delayed_batch_ ABSL_GUARDED_BY(delay_mu_);
168   // Abort states
169   grpc_error_handle abort_error_ = GRPC_ERROR_NONE;
170   grpc_closure recv_trailing_metadata_ready_;
171   grpc_closure* original_recv_trailing_metadata_ready_;
172   // Protects the asynchronous delay, resume, and cancellation.
173   Mutex delay_mu_;
174 };
175
176 // ChannelData
177
178 grpc_error_handle ChannelData::Init(grpc_channel_element* elem,
179                                     grpc_channel_element_args* args) {
180   GPR_ASSERT(elem->filter == &FaultInjectionFilterVtable);
181   new (elem->channel_data) ChannelData(elem, args);
182   return GRPC_ERROR_NONE;
183 }
184
185 void ChannelData::Destroy(grpc_channel_element* elem) {
186   auto* chand = static_cast<ChannelData*>(elem->channel_data);
187   chand->~ChannelData();
188 }
189
190 ChannelData::ChannelData(grpc_channel_element* elem,
191                          grpc_channel_element_args* args)
192     : index_(grpc_channel_stack_filter_instance_number(args->channel_stack,
193                                                        elem)) {}
194
195 // CallData::ResumeBatchCanceller
196
197 class CallData::ResumeBatchCanceller {
198  public:
199   explicit ResumeBatchCanceller(grpc_call_element* elem) : elem_(elem) {
200     auto* calld = static_cast<CallData*>(elem->call_data);
201     GRPC_CALL_STACK_REF(calld->owning_call_, "ResumeBatchCanceller");
202     GRPC_CLOSURE_INIT(&closure_, &Cancel, this, grpc_schedule_on_exec_ctx);
203     calld->call_combiner_->SetNotifyOnCancel(&closure_);
204   }
205
206  private:
207   static void Cancel(void* arg, grpc_error_handle error) {
208     auto* self = static_cast<ResumeBatchCanceller*>(arg);
209     auto* chand = static_cast<ChannelData*>(self->elem_->channel_data);
210     auto* calld = static_cast<CallData*>(self->elem_->call_data);
211     {
212       MutexLock lock(&calld->delay_mu_);
213       if (GRPC_TRACE_FLAG_ENABLED(grpc_fault_injection_filter_trace)) {
214         gpr_log(GPR_INFO,
215                 "chand=%p calld=%p: cancelling schdueled pick: "
216                 "error=%s self=%p calld->resume_batch_canceller_=%p",
217                 chand, calld, grpc_error_std_string(error).c_str(), self,
218                 calld->resume_batch_canceller_);
219       }
220       if (error != GRPC_ERROR_NONE && calld->resume_batch_canceller_ == self) {
221         // Cancel the delayed pick.
222         calld->CancelDelayTimer();
223         calld->FaultInjectionFinished();
224         // Fail pending batches on the call.
225         grpc_transport_stream_op_batch_finish_with_failure(
226             calld->delayed_batch_, GRPC_ERROR_REF(error),
227             calld->call_combiner_);
228       }
229     }
230     GRPC_CALL_STACK_UNREF(calld->owning_call_, "ResumeBatchCanceller");
231     delete self;
232   }
233
234   grpc_call_element* elem_;
235   grpc_closure closure_;
236 };
237
238 // CallData
239
240 grpc_error_handle CallData::Init(grpc_call_element* elem,
241                                  const grpc_call_element_args* args) {
242   auto* calld = new (elem->call_data) CallData(elem, args);
243   if (calld->fi_policy_ == nullptr) {
244     return GRPC_ERROR_CREATE_FROM_STATIC_STRING(
245         "failed to find fault injection policy");
246   }
247   return GRPC_ERROR_NONE;
248 }
249
250 void CallData::Destroy(grpc_call_element* elem,
251                        const grpc_call_final_info* /*final_info*/,
252                        grpc_closure* /*then_schedule_closure*/) {
253   auto* calld = static_cast<CallData*>(elem->call_data);
254   calld->~CallData();
255 }
256
257 void CallData::StartTransportStreamOpBatch(
258     grpc_call_element* elem, grpc_transport_stream_op_batch* batch) {
259   auto* calld = static_cast<CallData*>(elem->call_data);
260   // There should only be one send_initial_metdata op, and fault injection also
261   // only need to be enforced once.
262   if (batch->send_initial_metadata) {
263     calld->DecideWhetherToInjectFaults(
264         batch->payload->send_initial_metadata.send_initial_metadata);
265     if (GRPC_TRACE_FLAG_ENABLED(grpc_fault_injection_filter_trace)) {
266       gpr_log(GPR_INFO,
267               "chand=%p calld=%p: Fault injection triggered delay=%d abort=%d",
268               elem->channel_data, calld, calld->delay_request_,
269               calld->abort_request_);
270     }
271     if (calld->MaybeDelay()) {
272       // Delay the batch, and pass down the batch in the scheduled closure.
273       calld->DelayBatch(elem, batch);
274       return;
275     }
276     grpc_error_handle abort_error = calld->MaybeAbort();
277     if (abort_error != GRPC_ERROR_NONE) {
278       calld->abort_error_ = abort_error;
279       grpc_transport_stream_op_batch_finish_with_failure(
280           batch, GRPC_ERROR_REF(calld->abort_error_), calld->call_combiner_);
281       return;
282     }
283   } else {
284     if (batch->recv_trailing_metadata) {
285       // Intercept recv_trailing_metadata callback so that we can inject the
286       // failure when aborting streaming calls, because their
287       // recv_trailing_metatdata op may not be on the same batch as the
288       // send_initial_metadata op.
289       calld->original_recv_trailing_metadata_ready_ =
290           batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready;
291       batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready =
292           &calld->recv_trailing_metadata_ready_;
293     }
294     if (calld->abort_error_ != GRPC_ERROR_NONE) {
295       // If we already decided to abort, then immediately fail this batch.
296       grpc_transport_stream_op_batch_finish_with_failure(
297           batch, GRPC_ERROR_REF(calld->abort_error_), calld->call_combiner_);
298       return;
299     }
300   }
301   // Chain to the next filter.
302   grpc_call_next_op(elem, batch);
303 }
304
305 CallData::CallData(grpc_call_element* elem, const grpc_call_element_args* args)
306     : owning_call_(args->call_stack),
307       arena_(args->arena),
308       call_combiner_(args->call_combiner) {
309   auto* chand = static_cast<ChannelData*>(elem->channel_data);
310   // Fetch the fault injection policy from the service config, based on the
311   // relative index for which policy should this CallData use.
312   auto* service_config_call_data = static_cast<ServiceConfigCallData*>(
313       args->context[GRPC_CONTEXT_SERVICE_CONFIG_CALL_DATA].value);
314   auto* method_params = static_cast<FaultInjectionMethodParsedConfig*>(
315       service_config_call_data->GetMethodParsedConfig(
316           FaultInjectionServiceConfigParser::ParserIndex()));
317   if (method_params != nullptr) {
318     fi_policy_ = method_params->fault_injection_policy(chand->index());
319   }
320   GRPC_CLOSURE_INIT(&recv_trailing_metadata_ready_,
321                     HijackedRecvTrailingMetadataReady, elem,
322                     grpc_schedule_on_exec_ctx);
323 }
324
325 CallData::~CallData() {
326   if (fi_policy_owned_) {
327     fi_policy_->~FaultInjectionPolicy();
328   }
329   GRPC_ERROR_UNREF(abort_error_);
330 }
331
332 void CallData::DecideWhetherToInjectFaults(
333     grpc_metadata_batch* initial_metadata) {
334   FaultInjectionMethodParsedConfig::FaultInjectionPolicy* copied_policy =
335       nullptr;
336   // Update the policy with values in initial metadata.
337   if (!fi_policy_->abort_code_header.empty() ||
338       !fi_policy_->abort_percentage_header.empty() ||
339       !fi_policy_->delay_header.empty() ||
340       !fi_policy_->delay_percentage_header.empty()) {
341     // Defer the actual copy until the first matched header.
342     auto maybe_copy_policy_func = [this, &copied_policy]() {
343       if (copied_policy == nullptr) {
344         copied_policy =
345             arena_->New<FaultInjectionMethodParsedConfig::FaultInjectionPolicy>(
346                 *fi_policy_);
347       }
348     };
349     for (grpc_linked_mdelem* md = initial_metadata->list.head; md != nullptr;
350          md = md->next) {
351       absl::string_view key = StringViewFromSlice(GRPC_MDKEY(md->md));
352       // Only perform string comparison if:
353       //   1. Needs to check this header;
354       //   2. The value is not been filled before.
355       if (!fi_policy_->abort_code_header.empty() &&
356           (copied_policy == nullptr ||
357            copied_policy->abort_code == GRPC_STATUS_OK) &&
358           key == fi_policy_->abort_code_header) {
359         maybe_copy_policy_func();
360         grpc_status_code_from_int(GetLinkedMetadatumValueInt(md),
361                                   &copied_policy->abort_code);
362       }
363       if (!fi_policy_->abort_percentage_header.empty() &&
364           key == fi_policy_->abort_percentage_header) {
365         maybe_copy_policy_func();
366         copied_policy->abort_percentage_numerator =
367             GPR_MIN(GetLinkedMetadatumValueUnsignedInt(md),
368                     fi_policy_->abort_percentage_numerator);
369       }
370       if (!fi_policy_->delay_header.empty() &&
371           (copied_policy == nullptr || copied_policy->delay == 0) &&
372           key == fi_policy_->delay_header) {
373         maybe_copy_policy_func();
374         copied_policy->delay = static_cast<grpc_millis>(
375             GPR_MAX(GetLinkedMetadatumValueInt64(md), 0));
376       }
377       if (!fi_policy_->delay_percentage_header.empty() &&
378           key == fi_policy_->delay_percentage_header) {
379         maybe_copy_policy_func();
380         copied_policy->delay_percentage_numerator =
381             GPR_MIN(GetLinkedMetadatumValueUnsignedInt(md),
382                     fi_policy_->delay_percentage_numerator);
383       }
384     }
385     if (copied_policy != nullptr) fi_policy_ = copied_policy;
386   }
387   // Roll the dice
388   delay_request_ = fi_policy_->delay != 0 &&
389                    UnderFraction(fi_policy_->delay_percentage_numerator,
390                                  fi_policy_->delay_percentage_denominator);
391   abort_request_ = fi_policy_->abort_code != GRPC_STATUS_OK &&
392                    UnderFraction(fi_policy_->abort_percentage_numerator,
393                                  fi_policy_->abort_percentage_denominator);
394   if (!delay_request_ && !abort_request_) {
395     if (copied_policy != nullptr) copied_policy->~FaultInjectionPolicy();
396     // No fault injection for this call
397   } else {
398     fi_policy_owned_ = copied_policy != nullptr;
399   }
400 }
401
402 bool CallData::HaveActiveFaultsQuota(bool increment) {
403   if (g_active_faults.Load(MemoryOrder::ACQUIRE) >= fi_policy_->max_faults) {
404     return false;
405   }
406   if (increment) g_active_faults.FetchAdd(1, MemoryOrder::RELAXED);
407   return true;
408 }
409
410 bool CallData::MaybeDelay() {
411   if (delay_request_) {
412     return HaveActiveFaultsQuota(true);
413   }
414   return false;
415 }
416
417 grpc_error_handle CallData::MaybeAbort() {
418   if (abort_request_ && (delay_request_ || HaveActiveFaultsQuota(false))) {
419     return grpc_error_set_int(
420         GRPC_ERROR_CREATE_FROM_COPIED_STRING(fi_policy_->abort_message.c_str()),
421         GRPC_ERROR_INT_GRPC_STATUS, fi_policy_->abort_code);
422   }
423   return GRPC_ERROR_NONE;
424 }
425
426 void CallData::DelayBatch(grpc_call_element* elem,
427                           grpc_transport_stream_op_batch* batch) {
428   MutexLock lock(&delay_mu_);
429   delayed_batch_ = batch;
430   resume_batch_canceller_ = new ResumeBatchCanceller(elem);
431   grpc_millis resume_time = ExecCtx::Get()->Now() + fi_policy_->delay;
432   GRPC_CLOSURE_INIT(&batch->handler_private.closure, ResumeBatch, elem,
433                     grpc_schedule_on_exec_ctx);
434   grpc_timer_init(&delay_timer_, resume_time, &batch->handler_private.closure);
435 }
436
437 void CallData::ResumeBatch(void* arg, grpc_error_handle error) {
438   grpc_call_element* elem = static_cast<grpc_call_element*>(arg);
439   auto* calld = static_cast<CallData*>(elem->call_data);
440   MutexLock lock(&calld->delay_mu_);
441   // Cancelled or canceller has already run
442   if (error == GRPC_ERROR_CANCELLED ||
443       calld->resume_batch_canceller_ == nullptr) {
444     return;
445   }
446   if (GRPC_TRACE_FLAG_ENABLED(grpc_fault_injection_filter_trace)) {
447     gpr_log(GPR_INFO, "chand=%p calld=%p: Resuming delayed stream op batch %p",
448             elem->channel_data, calld, calld->delayed_batch_);
449   }
450   // Lame the canceller
451   calld->resume_batch_canceller_ = nullptr;
452   // Finish fault injection.
453   calld->FaultInjectionFinished();
454   // Abort if needed.
455   error = calld->MaybeAbort();
456   if (error != GRPC_ERROR_NONE) {
457     grpc_transport_stream_op_batch_finish_with_failure(
458         calld->delayed_batch_, error, calld->call_combiner_);
459     return;
460   }
461   // Chain to the next filter.
462   grpc_call_next_op(elem, calld->delayed_batch_);
463 }
464
465 void CallData::HijackedRecvTrailingMetadataReady(void* arg,
466                                                  grpc_error_handle error) {
467   grpc_call_element* elem = static_cast<grpc_call_element*>(arg);
468   auto* calld = static_cast<CallData*>(elem->call_data);
469   if (calld->abort_error_ != GRPC_ERROR_NONE) {
470     error = grpc_error_add_child(GRPC_ERROR_REF(error),
471                                  GRPC_ERROR_REF(calld->abort_error_));
472   } else {
473     error = GRPC_ERROR_REF(error);
474   }
475   Closure::Run(DEBUG_LOCATION, calld->original_recv_trailing_metadata_ready_,
476                error);
477 }
478
479 }  // namespace
480
481 extern const grpc_channel_filter FaultInjectionFilterVtable = {
482     CallData::StartTransportStreamOpBatch,
483     grpc_channel_next_op,
484     sizeof(CallData),
485     CallData::Init,
486     grpc_call_stack_ignore_set_pollset_or_pollset_set,
487     CallData::Destroy,
488     sizeof(ChannelData),
489     ChannelData::Init,
490     ChannelData::Destroy,
491     grpc_channel_next_get_info,
492     "fault_injection_filter",
493 };
494
495 void FaultInjectionFilterInit(void) {
496   grpc_core::FaultInjectionServiceConfigParser::Register();
497 }
498
499 void FaultInjectionFilterShutdown(void) {}
500
501 }  // namespace grpc_core