fix bug when falling back to acc32 when weight is prepacked (#18974)
authorSummer Deng <summerdeng@fb.com>
Sun, 7 Apr 2019 04:50:28 +0000 (21:50 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sun, 7 Apr 2019 04:53:08 +0000 (21:53 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18974

When the weight is prepacked and it doesn't contain a prepacked weight for acc32, we shouldn't fallback to acc32.

Reviewed By: bddppq

Differential Revision: D14814067

fbshipit-source-id: aec917322de695e283f0aca1e930c5603d196404

caffe2/quantization/server/conv_dnnlowp_acc16_op.cc
caffe2/quantization/server/conv_dnnlowp_acc16_op_test.py
caffe2/quantization/server/conv_groupwise_dnnlowp_acc16_op_test.py
caffe2/quantization/server/fbgemm_pack_op.cc

index 454be17..3bce760 100644 (file)
@@ -9,6 +9,7 @@
 #include <omp.h>
 #endif
 
+#include "caffe2/core/logging.h"
 #include "dnnlowp_op.h"
 #include "dnnlowp_partition.h"
 #include "fbgemm_pack_op.h"
@@ -17,7 +18,6 @@
 C10_DECLARE_int32(caffe2_dnnlowp_nbits_in_non_outlier);
 C10_DECLARE_int32(caffe2_dnnlowp_copy_to_32bit_frequency);
 C10_DECLARE_bool(caffe2_dnnlowp_shared_int32_buffer);
-
 // Thresholds to fallback to 32-bit accumulation when 16-bit accumulation
 // doesn't provide performance benefits.
 C10_DEFINE_double(
@@ -62,35 +62,8 @@ ConvDNNLowPAcc16Op<ReluFused>::ConvDNNLowPAcc16Op(
 template <bool ReluFused>
 bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
   if (fallback_to_32_bit_accumulation_) {
-    return true;
-  }
-
-  if (!BaseType::GetQuantizationParameters_()) {
-    return false;
-  }
-
-  if (!Wq_acc16_packed_ &&
-      this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER)) {
-    CAFFE_ENFORCE_EQ(
-        this->order_,
-        StorageOrder::NHWC,
-        "Pre-packed weight only works with NHWC layout");
-    // If the input is already packed
-    const auto& packed_filter =
-        this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
-    Wq_outlier_ = packed_filter.W_outlier;
-    Wq_acc16_packed_ = packed_filter.W_acc16;
-
-    if (nbits_in_non_outlier_ != packed_filter.nbits_in_non_outlier) {
-      LOG(WARNING)
-          << "nbits_in_non_outlier in packed weight "
-          << packed_filter.nbits_in_non_outlier
-          << " doesn't match with nbits_in_non_outlier specified in operator "
-          << nbits_in_non_outlier_;
-    }
-
-    first_invocation_ = false;
-    return true;
+    // Short cut if we already know we are falling back to acc32
+    return BaseType::GetQuantizationParameters_();
   }
 
   int kernel_dim = this->KernelDim_();
@@ -98,7 +71,17 @@ bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
   int num_out_channels = filter.dim32(0);
 
   // Check if we should fallback to 32-bit accumulation
-  if (this->order_ == StorageOrder::NHWC) {
+  // We should do this before GetQuantizationParameters_ to make sure
+  // GetQuantizationParameters_ initialize things like Wq_packed_ for acc32
+  // properly.
+
+  // We can't fallback if layout is not NHWC or
+  // if weight is prepacked and the prepacked weight doesn't have acc32.
+  bool can_fallback_to_32_bit_accumulation =
+      this->order_ == StorageOrder::NHWC &&
+      (!this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER) ||
+       this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER).W);
+  if (can_fallback_to_32_bit_accumulation) {
     const Tensor& X = InputTensorCPU_(INPUT);
     int N = X.dim32(0);
 
@@ -121,31 +104,71 @@ bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
     }
 
     if (N * output_image_size < FLAGS_caffe2_dnnlowp_acc16_m_threshold) {
-      LOG(INFO) << "M " << N * output_image_size
-                << " of Conv layer with weight blob "
-                << this->debug_def().input(1) << " is smaller than threshold "
-                << FLAGS_caffe2_dnnlowp_acc16_m_threshold
-                << " . Falling back to acc32";
+      C10_LOG_FIRST_N(INFO, 10)
+          << "M " << N * output_image_size << " of Conv layer with weight blob "
+          << this->debug_def().input(FILTER) << " is smaller than threshold "
+          << FLAGS_caffe2_dnnlowp_acc16_m_threshold
+          << " . Falling back to acc32";
+      fallback_to_32_bit_accumulation_ = true;
+    }
+    if (!fallback_to_32_bit_accumulation_ &&
+        num_out_channels / group_ < acc16_n_threshold) {
+      C10_LOG_FIRST_N(INFO, 10)
+          << "N " << num_out_channels / group_
+          << " of Conv layer with weight blob "
+          << this->debug_def().input(FILTER) << " is smaller than threshold "
+          << acc16_n_threshold << " . Falling back to acc32";
       fallback_to_32_bit_accumulation_ = true;
-      return true;
     }
-    if (num_out_channels / group_ < acc16_n_threshold) {
-      LOG(INFO) << "N " << num_out_channels / group_
-                << " of Conv layer with weight blob "
-                << this->debug_def().input(1) << " is smaller than threshold "
-                << acc16_n_threshold << " . Falling back to acc32";
+    if (!fallback_to_32_bit_accumulation_ && kernel_dim < acc16_k_threshold) {
+      C10_LOG_FIRST_N(INFO, 10)
+          << "K " << kernel_dim << " of Conv layer with weight blob "
+          << this->debug_def().input(FILTER) << " is smaller than threshold "
+          << acc16_k_threshold << " . Falling back to acc32";
       fallback_to_32_bit_accumulation_ = true;
-      return true;
     }
-    if (kernel_dim < acc16_k_threshold) {
-      LOG(INFO) << "K " << kernel_dim << " of Conv layer with weight blob "
-                << this->debug_def().input(1) << " is smaller than threshold "
-                << acc16_k_threshold << " . Falling back to acc32";
+    if (!fallback_to_32_bit_accumulation_ &&
+        this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER) &&
+        !this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER)
+             .W_acc16) {
+      C10_LOG_FIRST_N(INFO, 10)
+          << "Falling back to acc32 because packed weight for acc16 is not "
+             "available";
       fallback_to_32_bit_accumulation_ = true;
-      return true;
     }
   }
 
+  if (!BaseType::GetQuantizationParameters_()) {
+    return false;
+  }
+
+  if (fallback_to_32_bit_accumulation_) {
+    return true;
+  }
+
+  if (!Wq_acc16_packed_ &&
+      this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER)) {
+    CAFFE_ENFORCE_EQ(
+        this->order_,
+        StorageOrder::NHWC,
+        "Pre-packed weight only works with NHWC layout");
+    // If the input is already packed
+    const auto& packed_filter =
+        this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
+    Wq_outlier_ = packed_filter.W_outlier;
+    Wq_acc16_packed_ = packed_filter.W_acc16;
+
+    if (nbits_in_non_outlier_ != packed_filter.nbits_in_non_outlier) {
+      C10_LOG_FIRST_N(WARNING, 10)
+          << "nbits_in_non_outlier in packed weight "
+          << packed_filter.nbits_in_non_outlier
+          << " doesn't match with nbits_in_non_outlier specified in operator "
+          << nbits_in_non_outlier_;
+    }
+    first_invocation_ = false;
+    return true;
+  }
+
   // Separate out outliers
   if (!Wq_outlier_ && this->order_ == StorageOrder::NHWC &&
       nbits_in_non_outlier_ < 8) {
@@ -159,20 +182,25 @@ bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
         W_quantized_));
     int outlier_cnt = Wq_outlier_->ColPtr()[num_out_channels];
 
-    LOG(INFO) << "Proportion of outlier for Conv layer with weight blob "
-              << this->debug_def().input(1) << " is "
-              << static_cast<float>(outlier_cnt) / W_quantized_.size();
-    LOG(INFO) << "nbits_in_non_outlier " << nbits_in_non_outlier_
-              << " copy_to_32bit_frequency " << copy_to_32bit_frequency_;
-
-    if (static_cast<float>(outlier_cnt) / W_quantized_.size() >
-        FLAGS_caffe2_dnnlowp_acc16_density_threshold) {
-      LOG(INFO) << "Density of outliers is higher than threshold "
-                << FLAGS_caffe2_dnnlowp_acc16_density_threshold
-                << " . Falling back to acc32";
+    C10_LOG_FIRST_N(INFO, 10)
+        << "Proportion of outlier for Conv layer with weight blob "
+        << this->debug_def().input(FILTER) << " is "
+        << static_cast<float>(outlier_cnt) / W_quantized_.size();
+    C10_LOG_FIRST_N(INFO, 10)
+        << "nbits_in_non_outlier " << nbits_in_non_outlier_
+        << " copy_to_32bit_frequency " << copy_to_32bit_frequency_;
+
+    if (can_fallback_to_32_bit_accumulation &&
+        static_cast<float>(outlier_cnt) / W_quantized_.size() >
+            FLAGS_caffe2_dnnlowp_acc16_density_threshold) {
+      C10_LOG_FIRST_N(INFO, 10)
+          << "Density of outliers is higher than threshold "
+          << FLAGS_caffe2_dnnlowp_acc16_density_threshold
+          << " . Falling back to acc32";
       fallback_to_32_bit_accumulation_ = true;
       Wq_outlier_.reset();
-      return true;
+      // We need to call GetQuantizationParameters_ again to pack for acc32
+      return BaseType::GetQuantizationParameters_();
     }
   }
 
@@ -193,8 +221,9 @@ bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
         static int log_occurences = 0;
         if (log_occurences < 32) {
           ++log_occurences;
-          LOG(WARNING) << "Conv with weight " << this->debug_def().input(FILTER)
-                       << " falls back to slow path because " << reason;
+          C10_LOG_FIRST_N(WARNING, 10)
+              << "Conv with weight " << this->debug_def().input(FILTER)
+              << " falls back to slow path because " << reason;
         }
       }
     }
@@ -202,8 +231,9 @@ bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
       static int log_occurences = 0;
       if (log_occurences < 32) {
         ++log_occurences;
-        LOG(WARNING) << "Outlier-aware quantization only supports "
-                        "NHWC layout";
+        C10_LOG_FIRST_N(WARNING, 10)
+            << "Outlier-aware quantization only supports "
+               "NHWC layout";
       }
     }
     first_invocation_ = false;
@@ -359,7 +389,7 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHW() {
         static int log_occurences = 0;
         if (log_occurences < 32) {
           ++log_occurences;
-          LOG(WARNING)
+          C10_LOG_FIRST_N(WARNING, 10)
               << "Consider using DNNLOWP instead of DNNLOWP_ACC16 engine since "
                  "we're falling back to a slow path because of NCHW layout";
         }
index 1da2f31..1ddf2ce 100644 (file)
@@ -7,15 +7,19 @@ import hypothesis.strategies as st
 import numpy as np
 from caffe2.python import core, dyndep, utils, workspace
 from caffe2.quantization.server import utils as dnnlowp_utils
-from dnnlowp_test_utils import (
-    check_quantized_results_close,
-    generate_conv_inputs,
-)
+from dnnlowp_test_utils import check_quantized_results_close
 from hypothesis import assume, given
 
 
 dyndep.InitOpsLibrary("//caffe2/caffe2/quantization/server:dnnlowp_ops")
-workspace.GlobalInit(["caffe2", "--caffe2_omp_num_threads=11"])
+workspace.GlobalInit(
+    [
+        "caffe2",
+        "--caffe2_omp_num_threads=11",
+        # Increase this threshold to test acc16 with randomly generated data
+        "--caffe2_dnnlowp_acc16_density_threshold=0.9",
+    ]
+)
 
 
 class DNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
@@ -254,9 +258,7 @@ class DNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
             W_min = -100
             W_max = W_min + 255
         W = (
-            np.random.rand(
-                output_channels, kernel, kernel, input_channels_per_group
-            )
+            np.random.rand(output_channels, kernel, kernel, input_channels_per_group)
             * 4
             - 2
             + W_min
index 44f7aad..d542126 100644 (file)
@@ -7,15 +7,19 @@ import hypothesis.strategies as st
 import numpy as np
 from caffe2.python import core, dyndep, utils, workspace
 from caffe2.quantization.server import utils as dnnlowp_utils
-from dnnlowp_test_utils import (
-    check_quantized_results_close,
-    generate_conv_inputs,
-)
+from dnnlowp_test_utils import check_quantized_results_close
 from hypothesis import assume, given
 
 
 dyndep.InitOpsLibrary("//caffe2/caffe2/quantization/server:dnnlowp_ops")
-workspace.GlobalInit(["caffe2", "--caffe2_omp_num_threads=11"])
+workspace.GlobalInit(
+    [
+        "caffe2",
+        "--caffe2_omp_num_threads=11",
+        # Increase this threshold to test acc16 with randomly generated data
+        "--caffe2_dnnlowp_acc16_density_threshold=0.9",
+    ]
+)
 
 
 class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
@@ -224,9 +228,7 @@ class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
         W_min = -100
         W_max = W_min + 255
         W = (
-            np.random.rand(
-                output_channels, kernel, kernel, input_channels_per_group
-            )
+            np.random.rand(output_channels, kernel, kernel, input_channels_per_group)
             * 4
             - 2
             + W_min
@@ -237,9 +239,7 @@ class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
         for g in range(group):
             W[g * output_channels_per_group, 0, 0, 0] = W_min
             W[g * output_channels_per_group + 1, 0, 0, 0] = W_max
-            W[
-                g * output_channels_per_group : (g + 1) * output_channels_per_group,
-            ] += g
+            W[g * output_channels_per_group : (g + 1) * output_channels_per_group,] += g
 
         if order == "NCHW":
             X = utils.NHWC2NCHW(X)
index 704d4e1..9e98b01 100644 (file)
@@ -5,6 +5,9 @@
 #include "caffe2_dnnlowp_utils.h"
 
 C10_DECLARE_int32(caffe2_dnnlowp_nbits_in_non_outlier);
+C10_DECLARE_double(caffe2_dnnlowp_acc16_density_threshold);
+C10_DECLARE_int32(caffe2_dnnlowp_acc16_n_threshold);
+C10_DECLARE_int32(caffe2_dnnlowp_acc16_k_threshold);
 
 namespace caffe2 {
 
@@ -422,9 +425,44 @@ bool ConvDNNLowPPackWeightOp::RunOnDevice() {
   ComputeColumnOffsets(
       kernel_dim, M, W_quantized.data(), Y->qparams, *Y->column_offsets);
 
+  // Check if we should fallback to 32-bit accumulation.
+  // This check is only meaningful when engine is DNNLOWP_ACC16.
+  bool fallback_to_32_bit_accumulation = false;
+  if (nbits_in_non_outlier_ == 0) {
+    LOG(INFO) << "nbits_in_non_outlier == 0 means everything is outlier so we "
+                 "fallback to acc32";
+    fallback_to_32_bit_accumulation = true;
+  }
+  // In Skylake, acc16 is not faster when N or K is smaller than 128
+  // FIXME : code duplication with conv_dnnlowp_acc16_op.cc
+  constexpr int SKYLAKE_ACC16_N_THRESHOLD_MIN = 128,
+                SKYLAKE_ACC16_K_THRESHOLD_MIN = 128;
+  int acc16_n_threshold = FLAGS_caffe2_dnnlowp_acc16_n_threshold;
+  if (caffe2::GetCpuId().avx512f() &&
+      acc16_n_threshold < SKYLAKE_ACC16_N_THRESHOLD_MIN) {
+    acc16_n_threshold = SKYLAKE_ACC16_N_THRESHOLD_MIN;
+  }
+  int acc16_k_threshold = FLAGS_caffe2_dnnlowp_acc16_k_threshold;
+  if (caffe2::GetCpuId().avx512f() &&
+      acc16_k_threshold < SKYLAKE_ACC16_K_THRESHOLD_MIN) {
+    acc16_k_threshold = SKYLAKE_ACC16_K_THRESHOLD_MIN;
+  }
+  if (!fallback_to_32_bit_accumulation && M / group_ < acc16_n_threshold) {
+    LOG(INFO) << "N " << M / group_ << " of weight blob "
+              << this->debug_def().input(0) << " is smaller than threshold "
+              << acc16_n_threshold << " . Falling back to acc32";
+    fallback_to_32_bit_accumulation = true;
+  }
+  if (!fallback_to_32_bit_accumulation && kernel_dim < acc16_k_threshold) {
+    LOG(INFO) << "K " << kernel_dim << " of weight blob "
+              << this->debug_def().input(0) << " is smaller than threshold "
+              << acc16_k_threshold << " . Falling back to acc32";
+    fallback_to_32_bit_accumulation = true;
+  }
+
   // When nbits_in_non_outlier == 0, we fall back to acc32
   if (this->debug_def().engine() == "DNNLOWP_ACC16" &&
-      nbits_in_non_outlier_ > 0) {
+      !fallback_to_32_bit_accumulation) {
     if (nbits_in_non_outlier_ < 8) {
       Y->W_outlier.reset(ExtractOutlierMatrix(
           group_, kernel_dim, M, nbits_in_non_outlier_, W_quantized));
@@ -434,45 +472,66 @@ bool ConvDNNLowPPackWeightOp::RunOnDevice() {
                 << this->debug_def().input(0) << " is "
                 << static_cast<float>(outlier_cnt) / W_quantized.size();
       LOG(INFO) << "nbits_in_non_outlier " << nbits_in_non_outlier_;
+
+      if (static_cast<float>(outlier_cnt) / W_quantized.size() >
+          FLAGS_caffe2_dnnlowp_acc16_density_threshold) {
+        LOG(INFO) << "Density of outliers is higher than threshold "
+                  << FLAGS_caffe2_dnnlowp_acc16_density_threshold
+                  << " . Falling back to acc32";
+        fallback_to_32_bit_accumulation = true;
+      }
     }
 
-    Y->nbits_in_non_outlier = nbits_in_non_outlier_;
-    Y->W_acc16.reset(new fbgemm::PackBMatrix<int8_t, int16_t>(
-        fbgemm::matrix_op_t::Transpose,
-        group_ * kernel_dim,
-        M / group_,
-        W_quantized.data(),
-        kernel_dim,
-        nullptr, // pmat
-        group_));
-  } else if (TakeDepthWise3x3FastPath_()) {
-    Y->W_depthwise_3x3.reset(
-        new fbgemm::Packed3x3ConvMatrix(group_, W_quantized.data()));
-  } else if (TakeDepthWise3x3x3FastPath_()) {
-    Y->W_depthwise_3x3x3.reset(
-        new fbgemm::Packed3x3x3ConvMatrix(group_, W_quantized.data()));
-  } else if (TakeGConvFastPath_()) {
-    fbgemm::conv_param_t<> conv_p(
-        1,
-        group_ * C_per_group,
-        M,
-        {1, 1},
-        group_,
-        {this->kernel_[0], this->kernel_[1]},
-        {this->stride_[0], this->stride_[1]},
-        {this->pads_[0], this->pads_[1], this->pads_[2], this->pads_[3]});
-
-    Y->W_gconv.reset(new fbgemm::PackWeightMatrixForGConv<int8_t>(
-        fbgemm::matrix_op_t::Transpose, conv_p, W_quantized.data()));
-  } else {
-    Y->W.reset(new fbgemm::PackBMatrix<int8_t>(
-        fbgemm::matrix_op_t::Transpose,
-        group_ * kernel_dim,
-        M / group_,
-        W_quantized.data(),
-        kernel_dim,
-        nullptr, // pmat
-        group_));
+    if (!fallback_to_32_bit_accumulation) {
+      Y->nbits_in_non_outlier = nbits_in_non_outlier_;
+      Y->W_acc16.reset(new fbgemm::PackBMatrix<int8_t, int16_t>(
+          fbgemm::matrix_op_t::Transpose,
+          group_ * kernel_dim,
+          M / group_,
+          W_quantized.data(),
+          kernel_dim,
+          nullptr, // pmat
+          group_));
+    }
+  }
+
+  if (fallback_to_32_bit_accumulation) {
+    Y->W_acc16.reset();
+    Y->W_outlier.reset();
+  }
+
+  if (this->debug_def().engine() != "DNNLOWP_ACC16" ||
+      fallback_to_32_bit_accumulation) {
+    // acc32
+    if (TakeDepthWise3x3FastPath_()) {
+      Y->W_depthwise_3x3.reset(
+          new fbgemm::Packed3x3ConvMatrix(group_, W_quantized.data()));
+    } else if (TakeDepthWise3x3x3FastPath_()) {
+      Y->W_depthwise_3x3x3.reset(
+          new fbgemm::Packed3x3x3ConvMatrix(group_, W_quantized.data()));
+    } else if (TakeGConvFastPath_()) {
+      fbgemm::conv_param_t<> conv_p(
+          1,
+          group_ * C_per_group,
+          M,
+          {1, 1},
+          group_,
+          {this->kernel_[0], this->kernel_[1]},
+          {this->stride_[0], this->stride_[1]},
+          {this->pads_[0], this->pads_[1], this->pads_[2], this->pads_[3]});
+
+      Y->W_gconv.reset(new fbgemm::PackWeightMatrixForGConv<int8_t>(
+          fbgemm::matrix_op_t::Transpose, conv_p, W_quantized.data()));
+    } else {
+      Y->W.reset(new fbgemm::PackBMatrix<int8_t>(
+          fbgemm::matrix_op_t::Transpose,
+          group_ * kernel_dim,
+          M / group_,
+          W_quantized.data(),
+          kernel_dim,
+          nullptr, // pmat
+          group_));
+    }
   }
 
   if (InputSize() >= 2) {