disallow nbits_in_non_outlier == 0 in acc16 conv; option to fallback to acc32 (#15708)
authorJongsoo Park <jongsoo@fb.com>
Fri, 4 Jan 2019 04:28:09 +0000 (20:28 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 4 Jan 2019 04:31:33 +0000 (20:31 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15708

nbits_in_non_outlier == 0 doesn't make sense because it means everything is outlier and we can just use 32-bit accumulation.
Depending on architecture, break-even point between acc16 and acc32 can be different. Adding thresholds for falling back to acc32.

Reviewed By: jianyuh

Differential Revision: D13574832

fbshipit-source-id: b7a37aacbfdc7867e31838dafcdd5f7c2ac282af

caffe2/quantization/server/conv_dnnlowp_acc16_op.cc
caffe2/quantization/server/conv_dnnlowp_acc16_op.h
caffe2/quantization/server/conv_dnnlowp_acc16_op_test.py
caffe2/quantization/server/conv_dnnlowp_op.cc
caffe2/quantization/server/conv_dnnlowp_op.h
caffe2/quantization/server/conv_groupwise_dnnlowp_acc16_op_test.py

index dfd72e0..6711fcb 100644 (file)
@@ -18,6 +18,25 @@ C10_DECLARE_int32(caffe2_dnnlowp_nbits_in_non_outlier);
 C10_DECLARE_int32(caffe2_dnnlowp_copy_to_32bit_frequency);
 C10_DECLARE_bool(caffe2_dnnlowp_shared_int32_buffer);
 
+// Thresholds to fallback to 32-bit accumulation when 16-bit accumulation
+// doesn't provide performance benefits.
+C10_DEFINE_double(
+    caffe2_dnnlowp_fallback_to_32_bit_accumulation_density_threshold,
+    0.05,
+    "If density of outlier is higher than this, fallback to 32-bit accumulation");
+C10_DEFINE_int32(
+    caffe2_dnnlowp_fallback_to_32_bit_accumulation_m_threshold,
+    0,
+    "If m is smaller than this, fallback to 32-bit accumulation");
+C10_DEFINE_int32(
+    caffe2_dnnlowp_fallback_to_32_bit_accumulation_n_threshold,
+    0,
+    "If n is smaller than this, fallback to 32-bit accumulation");
+C10_DEFINE_int32(
+    caffe2_dnnlowp_fallback_to_32_bit_accumulation_k_threshold,
+    0,
+    "If k is smaller than this, fallback to 32-bit accumulation");
+
 namespace caffe2 {
 
 using namespace std;
@@ -36,6 +55,9 @@ ConvDNNLowPAcc16Op<ReluFused>::ConvDNNLowPAcc16Op(
 
 template <bool ReluFused>
 bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHW() {
+  if (fallback_to_32_bit_accumulation_) {
+    return BaseType::RunOnDeviceWithOrderNCHW();
+  }
   const Tensor& X = InputTensorCPU_(INPUT);
   if (X.template IsType<uint8_t>()) {
     return RunOnDeviceWithOrderNCHWAndType_<uint8_t>();
@@ -47,6 +69,9 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHW() {
 
 template <bool ReluFused>
 bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNHWC() {
+  if (fallback_to_32_bit_accumulation_) {
+    return BaseType::RunOnDeviceWithOrderNHWC();
+  }
   const Tensor& X = InputTensorCPU_(INPUT);
   if (X.template IsType<uint8_t>()) {
     return RunOnDeviceWithOrderNHWCAndType_<uint8_t>();
@@ -88,23 +113,84 @@ bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
 
   int kernel_dim = this->KernelDim_();
   const auto& filter = InputTensorCPU_(FILTER);
-  int M = filter.dim32(0);
+  int num_out_channels = filter.dim32(0);
+
+  // Check if we should fallback to 32-bit accumulation
+  if (this->order_ == StorageOrder::NHWC) {
+    const Tensor& X = InputTensorCPU_(INPUT);
+    int N = X.dim32(0);
+
+    Tensor* Y = OutputTensorCPU_(0);
+    this->SetOutputSize(X, Y, filter.dim32(0));
+    const int output_image_size = this->GetDimsSize(*Y);
+
+    if (N * output_image_size <
+        FLAGS_caffe2_dnnlowp_fallback_to_32_bit_accumulation_m_threshold) {
+      LOG(INFO)
+          << "M " << N * output_image_size << " is smaller than threshold "
+          << FLAGS_caffe2_dnnlowp_fallback_to_32_bit_accumulation_m_threshold
+          << " . Falling back to acc32";
+      fallback_to_32_bit_accumulation_ = true;
+      return true;
+    }
+    if (num_out_channels / group_ <
+        FLAGS_caffe2_dnnlowp_fallback_to_32_bit_accumulation_n_threshold) {
+      LOG(INFO)
+          << "N " << num_out_channels / group_ << " is smaller than threshold "
+          << FLAGS_caffe2_dnnlowp_fallback_to_32_bit_accumulation_n_threshold
+          << " . Falling back to acc32";
+      fallback_to_32_bit_accumulation_ = true;
+      return true;
+    }
+    if (kernel_dim <
+        FLAGS_caffe2_dnnlowp_fallback_to_32_bit_accumulation_k_threshold) {
+      LOG(INFO)
+          << "K " << kernel_dim << " is smaller than threshold "
+          << FLAGS_caffe2_dnnlowp_fallback_to_32_bit_accumulation_k_threshold
+          << " . Falling back to acc32";
+      fallback_to_32_bit_accumulation_ = true;
+      return true;
+    }
+  }
+
+  if (nbits_in_non_outlier_ == 0) {
+    // nbits_in_non_outlier_ == 0 means everything is outlier and we can just
+    // use 32-bit accumulation.
+    LOG(INFO) << "nbits_in_non_outlier == 0 means everything is outlier so we "
+                 "fallback to acc32";
+    fallback_to_32_bit_accumulation_ = true;
+    return true;
+  }
 
   // Separate out outliers
-  if (!Wq_outlier_ &&
-      ConvPoolOpBase<CPUContext>::order_ == StorageOrder::NHWC &&
+  if (!Wq_outlier_ && this->order_ == StorageOrder::NHWC &&
       nbits_in_non_outlier_ < 8) {
     CAFFE_ENFORCE(!W_quantized_.empty());
 
     Wq_outlier_.reset(ExtractOutlierMatrix(
-        group_, kernel_dim, M, nbits_in_non_outlier_, W_quantized_));
-    int outlier_cnt = Wq_outlier_->ColPtr()[M];
+        group_,
+        kernel_dim,
+        num_out_channels,
+        nbits_in_non_outlier_,
+        W_quantized_));
+    int outlier_cnt = Wq_outlier_->ColPtr()[num_out_channels];
 
     LOG(INFO) << "Proportion of outlier for Conv layer with weight blob "
               << OperatorBase::debug_def().input(1) << " is "
-              << (float)outlier_cnt / W_quantized_.size();
+              << static_cast<float>(outlier_cnt) / W_quantized_.size();
     LOG(INFO) << "nbits_in_non_outlier " << nbits_in_non_outlier_
               << " copy_to_32bit_frequency " << copy_to_32bit_frequency_;
+
+    if (static_cast<float>(outlier_cnt) / W_quantized_.size() >
+        FLAGS_caffe2_dnnlowp_fallback_to_32_bit_accumulation_density_threshold) {
+      LOG(INFO)
+          << "Density of outliers is higher than threshold "
+          << FLAGS_caffe2_dnnlowp_fallback_to_32_bit_accumulation_density_threshold
+          << " . Falling back to acc32";
+      fallback_to_32_bit_accumulation_ = true;
+      Wq_outlier_.reset();
+      return true;
+    }
   }
 
   bool packW = ConvPoolOpBase<CPUContext>::order_ == StorageOrder::NHWC &&
@@ -147,7 +233,7 @@ bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
     Wq_acc16_packed_.reset(new fbgemm::PackBMatrix<int8_t, int16_t>(
         fbgemm::matrix_op_t::Transpose,
         group_ * kernel_dim,
-        M / group_,
+        num_out_channels / group_,
         W_quantized_.data(),
         kernel_dim, // ld
         nullptr, // pmat
@@ -169,6 +255,9 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
   if (!GetQuantizationParameters_()) {
     return false;
   }
+  if (fallback_to_32_bit_accumulation_) {
+    return BaseType::template RunOnDeviceWithOrderNCHWAndType_<InType>();
+  }
 
   const Tensor& X = InputTensorCPU_(INPUT);
   auto& filter = InputTensorCPU_(FILTER);
@@ -190,7 +279,7 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
       0,
       "The number of output channels is not divisible by group.");
 
-  ConvPoolOpBase<CPUContext>::SetOutputSize(X, Y, filter.dim32(0));
+  this->SetOutputSize(X, Y, filter.dim32(0));
 
   const vector<int> input_dims = GetDims(X);
   const vector<int> output_dims = GetDims(*Y);
@@ -293,11 +382,13 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
         vector<uint8_t> col_buffer_quantized;
         if (X.template IsType<uint8_t>()) {
           col_buffer_quantized_data =
-              (uint8_t*)col_buffer_data + tid * col_buffer_size;
+              reinterpret_cast<uint8_t*>(col_buffer_data) +
+              tid * col_buffer_size;
         } else {
           col_buffer_quantized.resize(kernel_dim * output_image_size);
           fbgemm::Quantize<uint8_t>(
-              (const float*)col_buffer_data + tid * col_buffer_size,
+              reinterpret_cast<const float*>(col_buffer_data) +
+                  tid * col_buffer_size,
               col_buffer_quantized.data(),
               col_buffer_quantized.size(),
               in_qparams_[INPUT]);
@@ -463,7 +554,7 @@ static void conv_nhwc_acc16_ref_(
 
 template <bool ReluFused>
 template <typename PackAMatrix, fbgemm::QuantizationGranularity Q_GRAN>
-void ConvDNNLowPAcc16Op<ReluFused>::DispatchFBGEMM(
+void ConvDNNLowPAcc16Op<ReluFused>::DispatchFBGEMM_(
     PackAMatrix& packA,
     const uint8_t* col_buffer_quantized_data,
     vector<int32_t>* Y_int32,
@@ -471,6 +562,8 @@ void ConvDNNLowPAcc16Op<ReluFused>::DispatchFBGEMM(
   auto& filter = InputTensorCPU_(FILTER);
   const int M = filter.dim32(0);
 
+  bool fuse_output_pipeline = Wq_acc16_packed_ && !dequantize_output_;
+  assert(fuse_output_pipeline);
   int kernel_dim = this->KernelDim_();
 
   int nthreads = dnnlowp_get_num_threads();
@@ -538,10 +631,6 @@ void ConvDNNLowPAcc16Op<ReluFused>::ConvOutlier_(
     const int kernel_dim = this->KernelDim_();
     const int output_image_size = this->GetDimsSize(*Y);
 
-    if (nbits_in_non_outlier_ == 0) {
-      memset(Y_int32->data(), 0, sizeof((*Y_int32)[0]) * M * N);
-    }
-
 #ifdef _OPENMP
 #pragma omp parallel
 #endif
@@ -596,6 +685,10 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNHWCAndType_() {
     return false;
   }
 
+  if (fallback_to_32_bit_accumulation_) {
+    return BaseType::template RunOnDeviceWithOrderNHWCAndType_<InType>();
+  }
+
 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
   t_end = chrono::system_clock::now();
   double dt = chrono::duration<double>(t_end - t_begin).count();
@@ -611,7 +704,7 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNHWCAndType_() {
   const int M = filter.dim32(0);
   CAFFE_ENFORCE_EQ(filter.dim32(filter.ndim() - 1), C / group_);
 
-  ConvPoolOpBase<CPUContext>::SetOutputSize(X, Y, filter.dim32(0));
+  this->SetOutputSize(X, Y, filter.dim32(0));
   // The dimension of each kernel
   const int kernel_dim = this->KernelDim_();
   // The output image size is the spatial size of the output.
@@ -626,6 +719,7 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNHWCAndType_() {
     t_begin = chrono::system_clock::now();
 #endif
 
+    bool fuse_output_pipeline = Wq_acc16_packed_ && !dequantize_output_;
     bool no_im2col = this->NoIm2ColNHWC_();
 
     // Im2Col, followed by gemm.
@@ -642,10 +736,11 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNHWCAndType_() {
 #endif
 
       // quantize col_buffer
-      uint8_t* col_buffer_quantized_data = nullptr;
+      const uint8_t* col_buffer_quantized_data = nullptr;
       vector<uint8_t> col_buffer_quantized;
       if (X.template IsType<uint8_t>()) {
-        col_buffer_quantized_data = (uint8_t*)col_buffer_data;
+        col_buffer_quantized_data =
+            reinterpret_cast<const uint8_t*>(col_buffer_data);
       } else {
         col_buffer_quantized.resize(
             group_ * kernel_dim * output_image_size * N);
@@ -675,9 +770,6 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNHWCAndType_() {
       t_begin = chrono::system_clock::now();
 #endif
 
-      bool fuse_output_pipeline =
-          Wq_acc16_packed_ && nbits_in_non_outlier_ > 0 && !dequantize_output_;
-
       using namespace fbgemm;
       int row_offset_size_per_thread = -1;
       int x_pack_buf_size_per_thread = -1;
@@ -703,81 +795,79 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNHWCAndType_() {
         Y_uint8_data = Y->template mutable_data<uint8_t>();
       }
 
-      if (nbits_in_non_outlier_ > 0) {
-        // Main GEMM for non-outlier
-        if (Wq_acc16_packed_) {
-          // fast path
+      // Main GEMM for non-outlier
+      if (Wq_acc16_packed_)
 #ifdef _OPENMP
 #pragma omp parallel
 #endif
-          {
-            int nthreads = dnnlowp_get_num_threads();
-            int tid = dnnlowp_get_thread_num();
-
-            if (fuse_output_pipeline) {
-              PackAWithRowOffset<uint8_t, int16_t> packA(
-                  matrix_op_t::NoTranspose,
-                  N * output_image_size,
-                  group_ * kernel_dim,
-                  col_buffer_quantized_data,
-                  group_ * kernel_dim,
-                  X_pack_buf_.data() + tid * x_pack_buf_size_per_thread,
-                  group_,
-                  row_offsets_.data() + tid * row_offset_size_per_thread);
-
-              if (this->quantize_groupwise_) {
-                DispatchFBGEMM<
-                    PackAWithRowOffset<uint8_t, int16_t>,
-                    QuantizationGranularity::GROUP>(
-                    packA, col_buffer_quantized_data, Y_int32, Y_uint8_data);
-              } else {
-                DispatchFBGEMM<
-                    PackAWithRowOffset<uint8_t, int16_t>,
-                    QuantizationGranularity::TENSOR>(
-                    packA, col_buffer_quantized_data, Y_int32, Y_uint8_data);
-              }
-            } else {
-              // !fuse_output_pipeline
-              PackAMatrix<uint8_t, int16_t> packA(
-                  matrix_op_t::NoTranspose,
-                  N * output_image_size,
-                  group_ * kernel_dim,
-                  col_buffer_quantized_data,
-                  group_ * kernel_dim,
-                  X_pack_buf_.data() + tid * x_pack_buf_size_per_thread,
-                  group_); // group
-
-              DoNothing<int32_t, int32_t> doNothingObj{};
-              memCopy<> memCopyObj(doNothingObj);
-              fbgemmPacked(
-                  packA,
-                  *Wq_acc16_packed_,
-                  Y_int32->data(),
-                  Y_int32->data(),
-                  M,
-                  memCopyObj,
-                  tid, // thread_id
-                  nthreads); // num_threads
-            }
-          } // omp parallel
-        } else {
-          // slow path
-          conv_nhwc_acc16_ref_(
+      {
+        // fast path
+        int nthreads = dnnlowp_get_num_threads();
+        int tid = dnnlowp_get_thread_num();
+
+        if (fuse_output_pipeline) {
+          // no im2col fusion
+          PackAWithRowOffset<uint8_t, int16_t> packA(
+              matrix_op_t::NoTranspose,
+              N * output_image_size,
+              group_ * kernel_dim,
+              col_buffer_quantized_data,
+              group_ * kernel_dim,
+              X_pack_buf_.data() + tid * x_pack_buf_size_per_thread,
               group_,
-              N,
-              output_image_size,
-              M,
-              kernel_dim,
+              row_offsets_.data() + tid * row_offset_size_per_thread);
+
+          if (this->quantize_groupwise_) {
+            DispatchFBGEMM_<
+                PackAWithRowOffset<uint8_t, int16_t>,
+                QuantizationGranularity::GROUP>(
+                packA, col_buffer_quantized_data, Y_int32, Y_uint8_data);
+          } else {
+            DispatchFBGEMM_<
+                PackAWithRowOffset<uint8_t, int16_t>,
+                QuantizationGranularity::TENSOR>(
+                packA, col_buffer_quantized_data, Y_int32, Y_uint8_data);
+          }
+        } else {
+          // !fuse_output_pipeline
+          PackAMatrix<uint8_t, int16_t> packA(
+              matrix_op_t::NoTranspose,
+              N * output_image_size,
+              group_ * kernel_dim,
               col_buffer_quantized_data,
-              W_quantized_.data(),
-              Y_int32->data()
+              group_ * kernel_dim,
+              X_pack_buf_.data() + tid * x_pack_buf_size_per_thread,
+              group_); // group
+
+          DoNothing<int32_t, int32_t> doNothingObj{};
+          memCopy<> memCopyObj(doNothingObj);
+          fbgemmPacked(
+              packA,
+              *Wq_acc16_packed_,
+              Y_int32->data(),
+              Y_int32->data(),
+              M,
+              memCopyObj,
+              tid, // thread_id
+              nthreads); // num_threads
+        } // omp parallel
+      } else {
+        // slow path
+        conv_nhwc_acc16_ref_(
+            group_,
+            N,
+            output_image_size,
+            M,
+            kernel_dim,
+            col_buffer_quantized_data,
+            W_quantized_.data(),
+            Y_int32->data()
 #ifdef DNNLOWP_ACC16_IN_SLOW_PATH
-                  ,
-              this
+                ,
+            this
 #endif
-          );
-        } // slow path
-      } // nbits_in_non_outlier_ > 0
+        );
+      } // slow path
 
 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
       t_end = chrono::system_clock::now();
index 61440f1..b28bb38 100644 (file)
@@ -34,7 +34,7 @@ class ConvDNNLowPAcc16Op final : public ConvDNNLowPOp<std::uint8_t, ReluFused> {
   bool RunOnDeviceWithOrderNCHW() override;
   bool RunOnDeviceWithOrderNHWC() override;
 
-  bool GetQuantizationParameters_() override;
+  bool GetQuantizationParameters_();
 
   template <typename InType>
   bool RunOnDeviceWithOrderNCHWAndType_();
@@ -42,7 +42,7 @@ class ConvDNNLowPAcc16Op final : public ConvDNNLowPOp<std::uint8_t, ReluFused> {
   bool RunOnDeviceWithOrderNHWCAndType_();
 
   template <typename PackAMatrix, fbgemm::QuantizationGranularity Q_GRAN>
-  void DispatchFBGEMM(
+  void DispatchFBGEMM_(
       PackAMatrix& packA,
       const std::uint8_t* col_buffer_quantized_data,
       vector<std::int32_t>* Y_int32,
@@ -52,6 +52,10 @@ class ConvDNNLowPAcc16Op final : public ConvDNNLowPOp<std::uint8_t, ReluFused> {
       const std::uint8_t* col_buffer,
       vector<std::int32_t>* Y_int32);
 
+  virtual bool Acc16() const override {
+    return !fallback_to_32_bit_accumulation_;
+  }
+
   std::shared_ptr<fbgemm::PackBMatrix<std::int8_t, std::int16_t>>
       Wq_acc16_packed_;
 
@@ -66,7 +70,11 @@ class ConvDNNLowPAcc16Op final : public ConvDNNLowPOp<std::uint8_t, ReluFused> {
   int nbits_in_non_outlier_;
   int copy_to_32bit_frequency_;
 
-  bool first_invocation_ = true;
+  bool first_invocation_{true};
+  // If outlier matrix is not sparse enough, using 16-bit accumulation won't
+  // give speedup due to too much overhead of sparse matrix multiplication or
+  // sparse convolution anyway, so fallback to 32-bit accumulation
+  bool fallback_to_32_bit_accumulation_{false};
 }; // class ConvDNNLowPAcc16Op
 
 } // namespace caffe2
index ebd752b..a725587 100644 (file)
@@ -214,7 +214,7 @@ class DNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
         out_quantized=st.booleans(),
         weight_quantized=st.booleans(),
         prepack_weight=st.booleans(),
-        nbits_in_non_outlier=st.sampled_from((6, 8)),
+        nbits_in_non_outlier=st.sampled_from((0, 1, 6, 8)),
         share_col_buffer=st.booleans(),
         preserve_activation_sparsity=st.booleans(),
         preserve_weight_sparsity=st.booleans(),
@@ -249,51 +249,38 @@ class DNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
         input_channels = input_channels_per_group * group
         output_channels = output_channels_per_group * group
 
-        if nbits_in_non_outlier == 0:
-            X, W, b = generate_conv_inputs(
-                stride,
-                pad,
-                kernel,
-                dilation,
-                size,
-                group,
-                input_channels_per_group,
-                output_channels_per_group,
-                batch_size,
-                order,
-                preserve_activation_sparsity=preserve_activation_sparsity,
-                preserve_weight_sparsity=preserve_weight_sparsity,
-            )
+        X_min = 0 if preserve_activation_sparsity else -77
+        X_max = X_min + 255
+        X = np.random.rand(batch_size, size, size, input_channels) * 4 + X_min
+        X = np.round(X).astype(np.float32)
+        X[..., 0] = X_min
+        X[0, 0, 0, 1] = X_max
+
+        if preserve_weight_sparsity:
+            W_min = -128
+            W_max = 100
         else:
-            X_min = 0 if preserve_activation_sparsity else -77
-            X_max = X_min + 255
-            X = np.random.rand(batch_size, size, size, input_channels) * 4 + X_min
-            X = np.round(X).astype(np.float32)
-            X[..., 0] = X_min
-            X[0, 0, 0, 1] = X_max
-
-            if preserve_weight_sparsity:
-                W_min = -128
-                W_max = 100
-            else:
-                W_min = -100
-                W_max = W_min + 255
-            W = (
-                np.random.rand(
-                    output_channels, kernel, kernel, input_channels_per_group
-                )
-                * 4
-                - 2
-                + W_min
-                + 128
+            W_min = -100
+            W_max = W_min + 255
+        W = (
+            np.random.rand(
+                output_channels, kernel, kernel, input_channels_per_group
             )
-            W = np.round(W).astype(np.float32)
-            W[0, 0, 0, 0] = W_min
-            W[1, 0, 0, 0] = W_max
-            W[..., 1] = W_min + 128
+            * 4
+            - 2
+            + W_min
+            + 128
+        )
+        W = np.round(W).astype(np.float32)
+        W[0, 0, 0, 0] = W_min
+        W[1, 0, 0, 0] = W_max
+        W[..., 1] = W_min + 128  # "zeros"
 
-            # No input quantization error in bias
-            b = np.round(np.random.randn(output_channels)).astype(np.float32)
+        if order == "NCHW":
+            X = utils.NHWC2NCHW(X)
+            W = utils.NHWC2NCHW(W)
+
+        b = np.round(np.random.randn(output_channels)).astype(np.float32)
 
         Output = collections.namedtuple("Output", ["Y", "op_type", "engine", "order"])
         outputs = []
index 3df909c..82b8731 100644 (file)
@@ -101,8 +101,7 @@ template <typename T, bool ReluFused>
 bool ConvDNNLowPOp<T, ReluFused>::TakeDepthWise3x3FastPath_() {
   const Tensor& X = InputTensorCPU_(INPUT);
   return StorageOrder::NHWC == ConvPoolOpBase<CPUContext>::order_ &&
-      is_same<T, uint8_t>::value && X.template IsType<T>() &&
-      this->debug_def().engine() != "DNNLOWP_ACC16" &&
+      is_same<T, uint8_t>::value && X.template IsType<T>() && !Acc16() &&
       group_ == X.dim32(X.dim() - 1) && group_ % 8 == 0 &&
       this->kernel_.size() == 2 && kernel_h() == 3 && kernel_w() == 3 &&
       stride_h() == stride_w() && (stride_h() == 1 || stride_h() == 2) &&
@@ -115,8 +114,7 @@ template <typename T, bool ReluFused>
 bool ConvDNNLowPOp<T, ReluFused>::TakeDepthWise3x3x3FastPath_() {
   const Tensor& X = InputTensorCPU_(INPUT);
   bool ret = StorageOrder::NHWC == ConvPoolOpBase<CPUContext>::order_ &&
-      is_same<T, uint8_t>::value && X.template IsType<T>() &&
-      this->debug_def().engine() != "DNNLOWP_ACC16" &&
+      is_same<T, uint8_t>::value && X.template IsType<T>() && !Acc16() &&
       group_ == X.dim32(X.dim() - 1) && group_ % 8 == 0 &&
       this->kernel_.size() == 3 && this->kernel_[0] == 3 &&
       this->kernel_[1] == 3 && this->kernel_[2] == 3 &&
@@ -301,8 +299,7 @@ void ConvDNNLowPOp<T, ReluFused>::QuantizeWeight_() {
   int M = filter.dim32(0);
 
   bool packW = ConvPoolOpBase<CPUContext>::order_ == StorageOrder::NHWC &&
-      OperatorBase::debug_def().engine() != "DNNLOWP_ACC16" &&
-      is_same<T, uint8_t>::value && GetCpuId().avx2() &&
+      !Acc16() && is_same<T, uint8_t>::value && GetCpuId().avx2() &&
       !FLAGS_caffe2_dnnlowp_force_slow_path;
 
   bool depthwise_3x3_fast_path = false, depthwise_3x3x3_fast_path = false;
@@ -399,9 +396,7 @@ void ConvDNNLowPOp<T, ReluFused>::QuantizeWeight_() {
         reason = "fbgemm only supports 8-bit integers";
       } else if (!GetCpuId().avx2()) {
         reason = "fbgemm only supports AVX2+";
-      } else if (
-          OperatorBase::debug_def().engine() == "DNNLOWP_ACC16" ||
-          depthwise_3x3_fast_path) {
+      } else if (Acc16()) {
         reason = "";
       } else if (FLAGS_caffe2_dnnlowp_force_slow_path) {
         reason = "slow path enforced";
@@ -1079,7 +1074,7 @@ static void conv_nhwc_ref_(
 
 template <typename T, bool ReluFused>
 template <typename PackAMatrix, fbgemm::QuantizationGranularity Q_GRAN>
-void ConvDNNLowPOp<T, ReluFused>::DispatchFBGEMM(
+void ConvDNNLowPOp<T, ReluFused>::DispatchFBGEMM_(
     PackAMatrix& packA,
     vector<int32_t>* Y_int32,
     uint8_t* Y_uint8_data,
@@ -1297,12 +1292,12 @@ void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
             row_offsets_.data() + tid * row_offset_size_per_thread);
 
         if (quantize_groupwise_) {
-          DispatchFBGEMM<
+          DispatchFBGEMM_<
               PackAWithIm2Col<uint8_t>,
               QuantizationGranularity::GROUP>(
               packA, Y_int32, Y_uint8_data, Y_float_data);
         } else {
-          DispatchFBGEMM<
+          DispatchFBGEMM_<
               PackAWithIm2Col<uint8_t>,
               QuantizationGranularity::TENSOR>(
               packA, Y_int32, Y_uint8_data, Y_float_data);
@@ -1333,12 +1328,12 @@ void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
             row_offsets_.data() + tid * row_offset_size_per_thread);
 
         if (quantize_groupwise_) {
-          DispatchFBGEMM<
+          DispatchFBGEMM_<
               PackAWithIm2Col<uint8_t, int32_t, 3>,
               QuantizationGranularity::GROUP>(
               packA, Y_int32, Y_uint8_data, Y_float_data);
         } else {
-          DispatchFBGEMM<
+          DispatchFBGEMM_<
               PackAWithIm2Col<uint8_t, int32_t, 3>,
               QuantizationGranularity::TENSOR>(
               packA, Y_int32, Y_uint8_data, Y_float_data);
@@ -1358,12 +1353,12 @@ void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
           row_offsets_.data() + tid * row_offset_size_per_thread);
 
       if (quantize_groupwise_) {
-        DispatchFBGEMM<
+        DispatchFBGEMM_<
             PackAWithRowOffset<uint8_t>,
             QuantizationGranularity::GROUP>(
             packA, Y_int32, Y_uint8_data, Y_float_data);
       } else {
-        DispatchFBGEMM<
+        DispatchFBGEMM_<
             PackAWithRowOffset<uint8_t>,
             QuantizationGranularity::TENSOR>(
             packA, Y_int32, Y_uint8_data, Y_float_data);
index 7807035..07bdae9 100644 (file)
@@ -26,7 +26,12 @@ class ConvDNNLowPOp : public ConvPoolDNNLowPOpBase<T, ConvFp32Op> {
   bool RunOnDeviceWithOrderNCHW() override;
   bool RunOnDeviceWithOrderNHWC() override;
 
-  virtual bool GetQuantizationParameters_();
+  template <typename InType>
+  bool RunOnDeviceWithOrderNCHWAndType_();
+  template <typename InType>
+  bool RunOnDeviceWithOrderNHWCAndType_();
+
+  bool GetQuantizationParameters_();
 
   /**
    * @return true if convolution is basically a GEMM point-wise (e.g., 1x1)
@@ -52,6 +57,10 @@ class ConvDNNLowPOp : public ConvPoolDNNLowPOpBase<T, ConvFp32Op> {
       int nthreads,
       int thread_id);
 
+  virtual bool Acc16() const {
+    return false;
+  }
+
   Tensor col_buffer_{CPU};
   Tensor img_shape_device_{CPU};
   Tensor col_buffer_shape_device_{CPU};
@@ -101,13 +110,8 @@ class ConvDNNLowPOp : public ConvPoolDNNLowPOpBase<T, ConvFp32Op> {
   bool TakeDepthWise3x3FastPath_();
   bool TakeDepthWise3x3x3FastPath_();
 
-  template <typename InType>
-  bool RunOnDeviceWithOrderNCHWAndType_();
-  template <typename InType>
-  bool RunOnDeviceWithOrderNHWCAndType_();
-
   template <typename PackAMatrix, fbgemm::QuantizationGranularity Q_GRAN>
-  void DispatchFBGEMM(
+  void DispatchFBGEMM_(
       PackAMatrix& packA,
       vector<std::int32_t>* Y_int32,
       uint8_t* Y_uint8_data,
index 29bace6..f969a0a 100644 (file)
@@ -193,7 +193,7 @@ class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
         in_quantized=st.booleans(),
         out_quantized=st.booleans(),
         prepack_weight=st.booleans(),
-        nbits_in_non_outlier=st.sampled_from((6, 8)),
+        nbits_in_non_outlier=st.sampled_from((0, 1, 6, 8)),
         share_col_buffer=st.booleans(),
         **hu.gcs_cpu_only
     )
@@ -223,54 +223,38 @@ class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
         input_channels = input_channels_per_group * group
         output_channels = output_channels_per_group * group
 
-        if nbits_in_non_outlier == 0:
-            X, W, b = generate_conv_inputs(
-                stride,
-                pad,
-                kernel,
-                dilation,
-                size,
-                group,
-                input_channels_per_group,
-                output_channels_per_group,
-                batch_size,
-                order,
-                True,  # group-wise
-            )
-        else:
-            X_min = -77
-            X_max = X_min + 255
-            X = np.random.rand(batch_size, size, size, input_channels) * 4 + X_min
-            X = np.round(X).astype(np.float32)
-            X[..., 0] = X_min
-            X[0, 0, 0, 1] = X_max
+        X_min = -77
+        X_max = X_min + 255
+        X = np.random.rand(batch_size, size, size, input_channels) * 4 + X_min
+        X = np.round(X).astype(np.float32)
+        X[..., 0] = X_min
+        X[0, 0, 0, 1] = X_max
 
-            W_min = -100
-            W_max = W_min + 255
-            W = (
-                np.random.rand(
-                    output_channels, kernel, kernel, input_channels_per_group
-                )
-                * 4
-                - 2
-                + W_min
-                + 128
+        W_min = -100
+        W_max = W_min + 255
+        W = (
+            np.random.rand(
+                output_channels, kernel, kernel, input_channels_per_group
             )
-            W = np.round(W).astype(np.float32)
-            W[..., 1] = W_min + 128  # "zeros"
-            for g in range(group):
-                W[g * output_channels_per_group, 0, 0, 0] = W_min
-                W[g * output_channels_per_group + 1, 0, 0, 0] = W_max
-                W[
-                    g * output_channels_per_group : (g + 1) * output_channels_per_group,
-                ] += g
+            * 4
+            - 2
+            + W_min
+            + 128
+        )
+        W = np.round(W).astype(np.float32)
+        W[..., 1] = W_min + 128  # "zeros"
+        for g in range(group):
+            W[g * output_channels_per_group, 0, 0, 0] = W_min
+            W[g * output_channels_per_group + 1, 0, 0, 0] = W_max
+            W[
+                g * output_channels_per_group : (g + 1) * output_channels_per_group,
+            ] += g
 
-            if order == "NCHW":
-                X = utils.NHWC2NCHW(X)
-                W = utils.NHWC2NCHW(W)
+        if order == "NCHW":
+            X = utils.NHWC2NCHW(X)
+            W = utils.NHWC2NCHW(W)
 
-            # No input quantization error in bias
-            b = np.round(np.random.randn(output_channels)).astype(np.float32)
+        b = np.round(np.random.randn(output_channels)).astype(np.float32)
 
         Output = collections.namedtuple("Output", ["Y", "op_type", "engine", "order"])
         outputs = []