[pytorch] Make qlinear weight packing thread safe (#63804)
author John Shen <johnshen@fb.com>
Thu, 9 Sep 2021 16:30:32 +0000 (09:30 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 9 Sep 2021 16:31:48 +0000 (09:31 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63804

Adds a lock around the weight-packing section of qlinear and qlinear_dynamic, since QNNPACK's packing path is not thread safe.
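
For illustration, a minimal sketch of the locking pattern this change applies (hypothetical class and member names, not the real PackedLinearWeightsQnnp): a mutex member guards the lazy check-and-repack step so concurrent callers cannot race on the shared packed-weight state.

#include <mutex>
#include <optional>

// Hypothetical stand-in for a packed-weights holder; the real code guards the
// QNNPACK packing step in PackedLinearWeightsQnnp::apply_impl / apply_dynamic_impl.
struct PackedWeightsSketch {
  float apply(float input, float input_scale) {
    // Serialize the check-and-repack step (and the run that follows);
    // the packing routine mutates shared state and is not thread safe.
    std::lock_guard<std::mutex> lock(mutex_);
    if (!cached_input_scale_.has_value() ||
        cached_input_scale_.value() != input_scale) {
      repack(input_scale);               // rebuilds shared packed state
      cached_input_scale_ = input_scale;
    }
    return run(input);                   // uses the packed state
  }

 private:
  void repack(float /*scale*/) { /* placeholder for the packing step */ }
  float run(float input) { return input; }

  std::mutex mutex_;
  std::optional<float> cached_input_scale_;
};
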

Test Plan: automated tests

Reviewed By: kimishpatel

Differential Revision: D30340957

fbshipit-source-id: 1c9faf796c4ffbc74345396188a6f1154a76bea6

aten/src/ATen/native/quantized/cpu/qlinear.cpp
aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp
aten/src/ATen/native/quantized/cpu/qnnpack_utils.h

diff --git a/aten/src/ATen/native/quantized/cpu/qlinear.cpp b/aten/src/ATen/native/quantized/cpu/qlinear.cpp
index 9f3bb4b..9e6ddc0 100644 (file)
@@ -280,6 +280,8 @@ at::Tensor PackedLinearWeightsQnnp::apply_impl(
   size_t cols_w = input_contig.size(input_contig.dim() - 1);
   auto input_scale = input_contig.q_scale();
 
+  // QNNPACK is not thread safe
+  std::lock_guard<std::mutex> lock(qnnp_mutex_);
   if (!this->input_scale.has_value() ||
       this->input_scale.value() != input_scale) {
     // Get the original weight and adjust it to uint8 from int8
diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp
index 3331a03..09f4228 100644 (file)
@@ -266,6 +266,9 @@ at::Tensor PackedLinearWeightsQnnp::apply_dynamic_impl(at::Tensor input) {
       /*qmin=*/0,
       /*qmax=*/255);
   float* weight_scales_data = w_scales.data_ptr<float>();
+
+  // QNNPACK is not thread safe
+  std::lock_guard<std::mutex> lock(qnnp_mutex_);
   if (!input_scale.has_value() || input_scale.value() != q_params.scale) {
     generate_requantization_scales(
         // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack_utils.h b/aten/src/ATen/native/quantized/cpu/qnnpack_utils.h
index 91ede92..fa42765 100644 (file)
@@ -75,6 +75,7 @@ struct PackedLinearWeightsQnnp : public LinearPackedParamsBase {
       c10::optional<at::Tensor> bias);
 
  private:
+  std::mutex qnnp_mutex_;
   template <bool ReluFused>
   at::Tensor apply_impl(
       at::Tensor input,