simplify conv dnnlowp ops by not allowing fp32 in/out (#15758)
author Jongsoo Park <jongsoo@fb.com>
Mon, 7 Jan 2019 23:12:25 +0000 (15:12 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 7 Jan 2019 23:14:59 +0000 (15:14 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15758

The DNNLOWP Conv operators have become very complex due to the many options they support. This diff simplifies them by no longer allowing fp32 inputs/outputs. This is acceptable for Conv operators because they are usually used in deep networks, where quantizing and dequantizing with separate operators adds little overhead.
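
A minimal sketch of the intended usage pattern after this change (blob names, shapes, and the random test data are illustrative, and mirror what the updated tests below do): instead of asking the Conv op to accept fp32 activations or produce fp32 output via dequantize_output, the activation is quantized and dequantized with the separate DNNLOWP Quantize and Dequantize operators around an int8 Conv.

```python
import numpy as np
from caffe2.python import core, workspace

# Illustrative fp32 inputs in NHWC order: X is N x H x W x C, W is M x kH x kW x C.
workspace.FeedBlob("X", np.random.rand(1, 4, 4, 8).astype(np.float32))
workspace.FeedBlob("W", np.random.rand(8, 3, 3, 8).astype(np.float32))
workspace.FeedBlob("b", np.random.rand(8).astype(np.float32))

net = core.Net("dnnlowp_conv_example")
net.Proto().op.extend([
    # Quantize the activation explicitly; the Conv op no longer does this internally.
    core.CreateOperator("Quantize", ["X"], ["X_q"], engine="DNNLOWP"),
    core.CreateOperator(
        "Conv",
        ["X_q", "W", "b"],
        ["Y_q"],
        kernel=3,
        stride=1,
        pad=1,
        order="NHWC",
        engine="DNNLOWP_ACC16",  # 16-bit accumulation; use "DNNLOWP" for 32-bit
    ),
    # Dequantize the int8 output explicitly instead of passing dequantize_output.
    core.CreateOperator("Dequantize", ["Y_q"], ["Y"], engine="DNNLOWP"),
])

workspace.RunNetOnce(net)
Y = workspace.FetchBlob("Y")
```

In a deep network the Quantize/Dequantize pair is amortized over many Conv layers (or elided entirely when adjacent operators are also quantized), which is why dropping fp32 in/out support costs little in practice.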

Reviewed By: csummersea

Differential Revision: D13587341

fbshipit-source-id: e88c919dae79d1c5b7d787ea539edf5bcb064afc

caffe2/quantization/server/conv_dnnlowp_acc16_op.cc
caffe2/quantization/server/conv_dnnlowp_acc16_op.h
caffe2/quantization/server/conv_dnnlowp_acc16_op_test.py
caffe2/quantization/server/conv_dnnlowp_op.cc
caffe2/quantization/server/conv_dnnlowp_op.h
caffe2/quantization/server/conv_dnnlowp_op_test.py
caffe2/quantization/server/conv_groupwise_dnnlowp_acc16_op_test.py
caffe2/quantization/server/conv_groupwise_dnnlowp_op_test.py
caffe2/quantization/server/conv_pool_dnnlowp_op_base.h
caffe2/quantization/server/dnnlowp_op.h

caffe2/quantization/server/conv_dnnlowp_acc16_op.cc
index 6711fcb..b4bf136 100644
@@ -54,35 +54,11 @@ ConvDNNLowPAcc16Op<ReluFused>::ConvDNNLowPAcc16Op(
           FLAGS_caffe2_dnnlowp_copy_to_32bit_frequency)) {}
 
 template <bool ReluFused>
-bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHW() {
-  if (fallback_to_32_bit_accumulation_) {
-    return BaseType::RunOnDeviceWithOrderNCHW();
-  }
-  const Tensor& X = InputTensorCPU_(INPUT);
-  if (X.template IsType<uint8_t>()) {
-    return RunOnDeviceWithOrderNCHWAndType_<uint8_t>();
-  } else {
-    assert(X.template IsType<float>());
-    return RunOnDeviceWithOrderNCHWAndType_<float>();
-  }
-}
-
-template <bool ReluFused>
-bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNHWC() {
+bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
   if (fallback_to_32_bit_accumulation_) {
-    return BaseType::RunOnDeviceWithOrderNHWC();
-  }
-  const Tensor& X = InputTensorCPU_(INPUT);
-  if (X.template IsType<uint8_t>()) {
-    return RunOnDeviceWithOrderNHWCAndType_<uint8_t>();
-  } else {
-    assert(X.template IsType<float>());
-    return RunOnDeviceWithOrderNHWCAndType_<float>();
+    return true;
   }
-}
 
-template <bool ReluFused>
-bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
   if (!BaseType::GetQuantizationParameters_()) {
     return false;
   }
@@ -245,8 +221,7 @@ bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
 }
 
 template <bool ReluFused>
-template <typename InType>
-bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
+bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHW() {
   VLOG(2) << "Running DNNLOWP_ACC16 Conv";
 
   using namespace dnnlowp;
@@ -256,7 +231,7 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
     return false;
   }
   if (fallback_to_32_bit_accumulation_) {
-    return BaseType::template RunOnDeviceWithOrderNCHWAndType_<InType>();
+    return BaseType::RunOnDeviceWithOrderNCHW();
   }
 
   const Tensor& X = InputTensorCPU_(INPUT);
@@ -311,10 +286,10 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
 
   // The col buffer is stored in CHW order as well - kernel_dim, and the
   // height and width.
-  const InType* Xdata = X.template data<InType>();
+  const uint8_t* Xdata = X.template data<uint8_t>();
 
   col_buffer_.Resize(buffer_shape);
-  InType* col_buffer_data = col_buffer_.template mutable_data<InType>();
+  uint8_t* col_buffer_data = col_buffer_.template mutable_data<uint8_t>();
 
   auto f = [&](vector<int32_t>* Y_int32) {
     Y_int32->resize(M * output_image_size * dnnlowp_get_max_threads());
@@ -322,16 +297,7 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
         buffer_shape.begin() + 1, buffer_shape.end());
 
     // Im2Col, followed by gemm.
-    vector<uint8_t> Y_temp;
-    uint8_t* Y_data;
-    float* Y_data_float = nullptr;
-    if (dequantize_output_) {
-      Y_temp.resize(Y->numel());
-      Y_data = Y_temp.data();
-      Y_data_float = Y->template mutable_data<float>();
-    } else {
-      Y_data = Y->template mutable_data<uint8_t>();
-    }
+    uint8_t* Y_data = Y->template mutable_data<uint8_t>();
     this->column_offsets_->resize(
         output_image_size * dnnlowp_get_max_threads());
 
@@ -342,7 +308,7 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
       int tid = dnnlowp_get_thread_num();
       for (int group_id = 0; group_id < group_; ++group_id) {
         if (this->kernel_.size() == 2) {
-          math::Im2ColNCHW<InType>(
+          math::Im2ColNCHW<uint8_t>(
               C / group_,
               input_dims[0],
               input_dims[1],
@@ -359,9 +325,9 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
               Xdata + (group_ * image_id + group_id) * input_offset,
               col_buffer_data + tid * col_buffer_size,
               &context_,
-              X.IsType<uint8_t>() ? in_qparams_[INPUT].zero_point : 0);
+              in_qparams_[INPUT].zero_point);
         } else {
-          math::Im2ColNdNCHW<InType>(
+          math::Im2ColNdNCHW<uint8_t>(
               this->kernel_.size(),
               C * input_image_size,
               col_buffer_size,
@@ -374,26 +340,11 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
               Xdata + (group_ * image_id + group_id) * input_offset,
               col_buffer_data + tid * col_buffer_size,
               &context_,
-              X.IsType<uint8_t>() ? in_qparams_[INPUT].zero_point : 0);
+              in_qparams_[INPUT].zero_point);
         }
 
         // quantize col_buffer
-        uint8_t* col_buffer_quantized_data = nullptr;
-        vector<uint8_t> col_buffer_quantized;
-        if (X.template IsType<uint8_t>()) {
-          col_buffer_quantized_data =
-              reinterpret_cast<uint8_t*>(col_buffer_data) +
-              tid * col_buffer_size;
-        } else {
-          col_buffer_quantized.resize(kernel_dim * output_image_size);
-          fbgemm::Quantize<uint8_t>(
-              reinterpret_cast<const float*>(col_buffer_data) +
-                  tid * col_buffer_size,
-              col_buffer_quantized.data(),
-              col_buffer_quantized.size(),
-              in_qparams_[INPUT]);
-          col_buffer_quantized_data = col_buffer_quantized.data();
-        }
+        uint8_t* col_buffer_private = col_buffer_data + tid * col_buffer_size;
 
         // main GEMM
         int32_t* Y_int32_temp = Y_int32->data() +
@@ -415,7 +366,7 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
             int16_t int16_sum = 0;
             for (int k = 0; k < kernel_dim; ++k) {
               int32_t w = W_quantized_group[i * kernel_dim + k];
-              int32_t x = col_buffer_quantized_data[k * output_image_size + j];
+              int32_t x = col_buffer_private[k * output_image_size + j];
 #ifdef DNNLOWP_ACC16_IN_SLOW_PATH
               int16_sum = std::max<int32_t>(
                   numeric_limits<int16_t>::min(),
@@ -434,23 +385,12 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
           }
         }
 
-        if (dequantize_output_) {
-          this->RunOnDeviceEpilogueNCHW_(
-              col_buffer_quantized_data,
-              Y_int32_temp,
-              Y_data_float +
-                  (M * image_id + M / group_ * group_id) * output_image_size,
-              M / group_ * group_id,
-              group_id);
-        } else {
-          this->RunOnDeviceEpilogueNCHW_(
-              col_buffer_quantized_data,
-              Y_int32_temp,
-              Y_data +
-                  (M * image_id + M / group_ * group_id) * output_image_size,
-              M / group_ * group_id,
-              group_id);
-        }
+        this->RunOnDeviceEpilogueNCHW_(
+            col_buffer_private,
+            Y_int32_temp,
+            Y_data + (M * image_id + M / group_ * group_id) * output_image_size,
+            M / group_ * group_id,
+            group_id);
       } // for each group
     } // for each image_id
   }; // f
@@ -461,9 +401,7 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
     f(&(this->Y_int32_));
   }
 
-  if (!dequantize_output_) {
-    PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
-  }
+  PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
 
   this->MeasureQuantizationError_();
 
@@ -553,17 +491,17 @@ static void conv_nhwc_acc16_ref_(
 }
 
 template <bool ReluFused>
-template <typename PackAMatrix, fbgemm::QuantizationGranularity Q_GRAN>
+template <fbgemm::QuantizationGranularity Q_GRAN>
 void ConvDNNLowPAcc16Op<ReluFused>::DispatchFBGEMM_(
-    PackAMatrix& packA,
-    const uint8_t* col_buffer_quantized_data,
+    fbgemm::PackAWithRowOffset<uint8_t, int16_t>& packA,
+    const uint8_t* col_buffer_data,
     vector<int32_t>* Y_int32,
     uint8_t* Y_uint8_data) {
+  // This function is called within an OpenMP region
   auto& filter = InputTensorCPU_(FILTER);
   const int M = filter.dim32(0);
 
-  bool fuse_output_pipeline = Wq_acc16_packed_ && !dequantize_output_;
-  assert(fuse_output_pipeline);
+  assert(Wq_acc16_packed_.get());
   int kernel_dim = this->KernelDim_();
 
   int nthreads = dnnlowp_get_num_threads();
@@ -589,11 +527,7 @@ void ConvDNNLowPAcc16Op<ReluFused>::DispatchFBGEMM_(
         int32_t,
         ReQuantizeOutput<ReluFused, Q_GRAN>>
         spmdmObj(
-            reqObj,
-            col_buffer_quantized_data,
-            group_ * kernel_dim,
-            *Wq_outlier_,
-            group_);
+            reqObj, col_buffer_data, group_ * kernel_dim, *Wq_outlier_, group_);
 
     fbgemmPacked(
         packA,
@@ -664,8 +598,7 @@ void ConvDNNLowPAcc16Op<ReluFused>::ConvOutlier_(
 }
 
 template <bool ReluFused>
-template <typename InType>
-bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNHWCAndType_() {
+bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNHWC() {
   CAFFE_ENFORCE_LE(
       this->kernel_.size(),
       3,
@@ -686,7 +619,7 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNHWCAndType_() {
   }
 
   if (fallback_to_32_bit_accumulation_) {
-    return BaseType::template RunOnDeviceWithOrderNHWCAndType_<InType>();
+    return BaseType::RunOnDeviceWithOrderNHWC();
   }
 
 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
@@ -719,14 +652,13 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNHWCAndType_() {
     t_begin = chrono::system_clock::now();
 #endif
 
-    bool fuse_output_pipeline = Wq_acc16_packed_ && !dequantize_output_;
     bool no_im2col = this->NoIm2ColNHWC_();
 
     // Im2Col, followed by gemm.
     auto f2 = [&](Tensor* col_buffer_) {
-      const InType* Xdata = X.template data<InType>();
-      const InType* col_buffer_data =
-          no_im2col ? Xdata : this->template Im2ColNHWC_<InType>(col_buffer_);
+      const uint8_t* Xdata = X.template data<uint8_t>();
+      const uint8_t* col_buffer_data =
+          no_im2col ? Xdata : this->Im2ColNHWC_(col_buffer_);
 
 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
       t_end = chrono::system_clock::now();
@@ -735,65 +667,21 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNHWCAndType_() {
       t_begin = chrono::system_clock::now();
 #endif
 
-      // quantize col_buffer
-      const uint8_t* col_buffer_quantized_data = nullptr;
-      vector<uint8_t> col_buffer_quantized;
-      if (X.template IsType<uint8_t>()) {
-        col_buffer_quantized_data =
-            reinterpret_cast<const uint8_t*>(col_buffer_data);
-      } else {
-        col_buffer_quantized.resize(
-            group_ * kernel_dim * output_image_size * N);
-#ifdef _OPENMP
-#pragma omp parallel
-#endif
-        {
-          size_t begin, end;
-          std::tie(begin, end) = Get1DPartition(
-              col_buffer_quantized.size(),
-              dnnlowp_get_num_threads(),
-              dnnlowp_get_thread_num());
-          fbgemm::Quantize<uint8_t>(
-              (const float*)col_buffer_data + begin,
-              col_buffer_quantized.data() + begin,
-              end - begin,
-              in_qparams_[INPUT]);
-        }
-        col_buffer_quantized_data = col_buffer_quantized.data();
-      }
-
-#ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
-      t_end = chrono::system_clock::now();
-      dt = chrono::duration<double>(t_end - t_begin).count();
-      LOG(INFO) << "this=" << this << " quantize col_buf: " << dt * 1e3
-                << " ms";
-      t_begin = chrono::system_clock::now();
-#endif
-
       using namespace fbgemm;
       int row_offset_size_per_thread = -1;
       int x_pack_buf_size_per_thread = -1;
       if (Wq_acc16_packed_) {
-        if (fuse_output_pipeline) {
-          row_offset_size_per_thread =
-              PackAWithRowOffset<uint8_t, int16_t>::rowOffsetBufferSize();
-          x_pack_buf_size_per_thread =
-              PackAWithRowOffset<uint8_t, int16_t>::packedBufferSize();
-          row_offsets_.resize(
-              dnnlowp_get_max_threads() * row_offset_size_per_thread);
-        } else {
-          x_pack_buf_size_per_thread =
-              PackAMatrix<uint8_t, int16_t>::packedBufferSize();
-        }
+        row_offset_size_per_thread =
+            PackAWithRowOffset<uint8_t, int16_t>::rowOffsetBufferSize();
+        x_pack_buf_size_per_thread =
+            PackAWithRowOffset<uint8_t, int16_t>::packedBufferSize();
+        row_offsets_.resize(
+            dnnlowp_get_max_threads() * row_offset_size_per_thread);
         X_pack_buf_.resize(
             dnnlowp_get_max_threads() * x_pack_buf_size_per_thread);
       }
 
-      uint8_t* Y_uint8_data = nullptr;
-      if (!dequantize_output_) {
-        // Output is uint8_t
-        Y_uint8_data = Y->template mutable_data<uint8_t>();
-      }
+      uint8_t* Y_uint8_data = Y->template mutable_data<uint8_t>();
 
       // Main GEMM for non-outlier
       if (Wq_acc16_packed_)
@@ -802,55 +690,26 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNHWCAndType_() {
 #endif
       {
         // fast path
-        int nthreads = dnnlowp_get_num_threads();
         int tid = dnnlowp_get_thread_num();
 
-        if (fuse_output_pipeline) {
-          // no im2col fusion
-          PackAWithRowOffset<uint8_t, int16_t> packA(
-              matrix_op_t::NoTranspose,
-              N * output_image_size,
-              group_ * kernel_dim,
-              col_buffer_quantized_data,
-              group_ * kernel_dim,
-              X_pack_buf_.data() + tid * x_pack_buf_size_per_thread,
-              group_,
-              row_offsets_.data() + tid * row_offset_size_per_thread);
-
-          if (this->quantize_groupwise_) {
-            DispatchFBGEMM_<
-                PackAWithRowOffset<uint8_t, int16_t>,
-                QuantizationGranularity::GROUP>(
-                packA, col_buffer_quantized_data, Y_int32, Y_uint8_data);
-          } else {
-            DispatchFBGEMM_<
-                PackAWithRowOffset<uint8_t, int16_t>,
-                QuantizationGranularity::TENSOR>(
-                packA, col_buffer_quantized_data, Y_int32, Y_uint8_data);
-          }
+        // no im2col fusion
+        PackAWithRowOffset<uint8_t, int16_t> packA(
+            matrix_op_t::NoTranspose,
+            N * output_image_size,
+            group_ * kernel_dim,
+            col_buffer_data,
+            group_ * kernel_dim,
+            X_pack_buf_.data() + tid * x_pack_buf_size_per_thread,
+            group_,
+            row_offsets_.data() + tid * row_offset_size_per_thread);
+
+        if (this->quantize_groupwise_) {
+          DispatchFBGEMM_<QuantizationGranularity::GROUP>(
+              packA, col_buffer_data, Y_int32, Y_uint8_data);
         } else {
-          // !fuse_output_pipeline
-          PackAMatrix<uint8_t, int16_t> packA(
-              matrix_op_t::NoTranspose,
-              N * output_image_size,
-              group_ * kernel_dim,
-              col_buffer_quantized_data,
-              group_ * kernel_dim,
-              X_pack_buf_.data() + tid * x_pack_buf_size_per_thread,
-              group_); // group
-
-          DoNothing<int32_t, int32_t> doNothingObj{};
-          memCopy<> memCopyObj(doNothingObj);
-          fbgemmPacked(
-              packA,
-              *Wq_acc16_packed_,
-              Y_int32->data(),
-              Y_int32->data(),
-              M,
-              memCopyObj,
-              tid, // thread_id
-              nthreads); // num_threads
-        } // omp parallel
+          DispatchFBGEMM_<QuantizationGranularity::TENSOR>(
+              packA, col_buffer_data, Y_int32, Y_uint8_data);
+        }
       } else {
         // slow path
         conv_nhwc_acc16_ref_(
@@ -859,7 +718,7 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNHWCAndType_() {
             output_image_size,
             M,
             kernel_dim,
-            col_buffer_quantized_data,
+            col_buffer_data,
             W_quantized_.data(),
             Y_int32->data()
 #ifdef DNNLOWP_ACC16_IN_SLOW_PATH
@@ -879,8 +738,8 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNHWCAndType_() {
       t_begin = chrono::system_clock::now();
 #endif
 
-      if (!fuse_output_pipeline) {
-        ConvOutlier_(col_buffer_quantized_data, Y_int32);
+      if (!Wq_acc16_packed_) {
+        ConvOutlier_(col_buffer_data, Y_int32);
       }
 
 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
@@ -890,13 +749,10 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNHWCAndType_() {
       t_begin = chrono::system_clock::now();
 #endif
 
-      if (!fuse_output_pipeline) {
-        this->RunOnDeviceEpilogueNHWC_(
-            col_buffer_quantized_data, Y_int32->data());
+      if (!Wq_acc16_packed_) {
+        this->RunOnDeviceEpilogueNHWC_(col_buffer_data, Y_int32->data());
       } else {
-        if (!dequantize_output_) {
-          PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
-        }
+        PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
       }
     }; // f2
 
caffe2/quantization/server/conv_dnnlowp_acc16_op.h
index b28bb38..39dbddc 100644
@@ -18,7 +18,6 @@ class ConvDNNLowPAcc16Op final : public ConvDNNLowPOp<std::uint8_t, ReluFused> {
   using BaseType = ConvDNNLowPOp<std::uint8_t, ReluFused>;
   using BaseType::BIAS;
   using BaseType::col_buffer_;
-  using BaseType::dequantize_output_;
   using BaseType::FILTER;
   using BaseType::in_qparams_;
   using BaseType::INPUT;
@@ -36,15 +35,10 @@ class ConvDNNLowPAcc16Op final : public ConvDNNLowPOp<std::uint8_t, ReluFused> {
 
   bool GetQuantizationParameters_();
 
-  template <typename InType>
-  bool RunOnDeviceWithOrderNCHWAndType_();
-  template <typename InType>
-  bool RunOnDeviceWithOrderNHWCAndType_();
-
-  template <typename PackAMatrix, fbgemm::QuantizationGranularity Q_GRAN>
+  template <fbgemm::QuantizationGranularity Q_GRAN>
   void DispatchFBGEMM_(
-      PackAMatrix& packA,
-      const std::uint8_t* col_buffer_quantized_data,
+      fbgemm::PackAWithRowOffset<std::uint8_t, std::int16_t>& packA,
+      const std::uint8_t* col_buffer_data,
       vector<std::int32_t>* Y_int32,
       uint8_t* Y_uint8_data);
 
caffe2/quantization/server/conv_dnnlowp_acc16_op_test.py
index a725587..1da2f31 100644
@@ -31,8 +31,6 @@ class DNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
         output_channels_per_group=st.integers(2, 16),
         batch_size=st.integers(1, 3),
         order=st.sampled_from(["NCHW", "NHWC"]),
-        in_quantized=st.booleans(),
-        out_quantized=st.booleans(),
         weight_quantized=st.booleans(),
         share_col_buffer=st.booleans(),
         preserve_activation_sparsity=st.booleans(),
@@ -51,8 +49,6 @@ class DNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
         output_channels_per_group,
         batch_size,
         order,
-        in_quantized,
-        out_quantized,
         weight_quantized,
         share_col_buffer,
         preserve_activation_sparsity,
@@ -121,8 +117,8 @@ class DNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
         for op_type, engine in op_engine_list:
             net = core.Net("test_net")
 
-            do_quantize = "DNNLOWP" in engine and in_quantized
-            do_dequantize = "DNNLOWP" in engine and out_quantized
+            do_quantize = "DNNLOWP" in engine
+            do_dequantize = "DNNLOWP" in engine
             do_quantize_weight = (
                 "DNNLOWP" in engine and weight_quantized and len(outputs) > 0
             )
@@ -166,7 +162,6 @@ class DNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
                 dilation=dilation,
                 pad=pad,
                 order=order,
-                dequantize_output=not do_dequantize,
                 shared_buffer=(1 if share_col_buffer else 0),
                 preserve_activation_sparsity=preserve_activation_sparsity,
                 preserve_weight_sparsity=preserve_weight_sparsity,
@@ -210,8 +205,6 @@ class DNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
         output_channels_per_group=st.integers(2, 16),
         batch_size=st.integers(1, 3),
         order=st.sampled_from(["NHWC"]),
-        in_quantized=st.booleans(),
-        out_quantized=st.booleans(),
         weight_quantized=st.booleans(),
         prepack_weight=st.booleans(),
         nbits_in_non_outlier=st.sampled_from((0, 1, 6, 8)),
@@ -232,8 +225,6 @@ class DNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
         output_channels_per_group,
         batch_size,
         order,
-        in_quantized,
-        out_quantized,
         weight_quantized,
         prepack_weight,
         nbits_in_non_outlier,
@@ -295,8 +286,8 @@ class DNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
             init_net = core.Net("test_init_net")
             net = core.Net("test_net")
 
-            do_quantize = "DNNLOWP" in engine and in_quantized
-            do_dequantize = "DNNLOWP" in engine and out_quantized
+            do_quantize = "DNNLOWP" in engine
+            do_dequantize = "DNNLOWP" in engine
             do_quantize_weight = "DNNLOWP" in engine and weight_quantized
             do_prepack_weight = "DNNLOWP" in engine and prepack_weight
 
@@ -357,7 +348,6 @@ class DNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
                 dilation=dilation,
                 pad=pad,
                 order=order,
-                dequantize_output=not do_dequantize,
                 nbits_in_non_outlier=nbits_in_non_outlier,
                 shared_buffer=(1 if share_col_buffer else 0),
                 preserve_activation_sparsity=preserve_activation_sparsity,
caffe2/quantization/server/conv_dnnlowp_op.cc
index 82b8731..e8cb039 100644
@@ -76,45 +76,22 @@ ConvDNNLowPOp<T, ReluFused>::RequantizationParams(int group_id) {
 }
 
 template <typename T, bool ReluFused>
-bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNCHW() {
-  const Tensor& X = InputTensorCPU_(INPUT);
-  if (X.template IsType<T>()) {
-    return RunOnDeviceWithOrderNCHWAndType_<T>();
-  } else {
-    assert(X.template IsType<float>());
-    return RunOnDeviceWithOrderNCHWAndType_<float>();
-  }
-}
-
-template <typename T, bool ReluFused>
-bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNHWC() {
-  const Tensor& X = InputTensorCPU_(INPUT);
-  if (X.template IsType<T>()) {
-    return RunOnDeviceWithOrderNHWCAndType_<T>();
-  } else {
-    assert(X.template IsType<float>());
-    return RunOnDeviceWithOrderNHWCAndType_<float>();
-  }
-}
-
-template <typename T, bool ReluFused>
 bool ConvDNNLowPOp<T, ReluFused>::TakeDepthWise3x3FastPath_() {
   const Tensor& X = InputTensorCPU_(INPUT);
   return StorageOrder::NHWC == ConvPoolOpBase<CPUContext>::order_ &&
-      is_same<T, uint8_t>::value && X.template IsType<T>() && !Acc16() &&
+      is_same<T, uint8_t>::value && !Acc16() &&
       group_ == X.dim32(X.dim() - 1) && group_ % 8 == 0 &&
       this->kernel_.size() == 2 && kernel_h() == 3 && kernel_w() == 3 &&
       stride_h() == stride_w() && (stride_h() == 1 || stride_h() == 2) &&
       dilation_h() == 1 && dilation_w() == 1 && pad_t() == 1 && pad_b() == 1 &&
-      pad_l() == 1 && pad_r() == 1 && !dequantize_output_ &&
-      GetCpuId().avx2() && !quantize_groupwise_;
+      pad_l() == 1 && pad_r() == 1 && GetCpuId().avx2() && !quantize_groupwise_;
 }
 
 template <typename T, bool ReluFused>
 bool ConvDNNLowPOp<T, ReluFused>::TakeDepthWise3x3x3FastPath_() {
   const Tensor& X = InputTensorCPU_(INPUT);
   bool ret = StorageOrder::NHWC == ConvPoolOpBase<CPUContext>::order_ &&
-      is_same<T, uint8_t>::value && X.template IsType<T>() && !Acc16() &&
+      is_same<T, uint8_t>::value && !Acc16() &&
       group_ == X.dim32(X.dim() - 1) && group_ % 8 == 0 &&
       this->kernel_.size() == 3 && this->kernel_[0] == 3 &&
       this->kernel_[1] == 3 && this->kernel_[2] == 3 &&
@@ -125,7 +102,7 @@ bool ConvDNNLowPOp<T, ReluFused>::TakeDepthWise3x3x3FastPath_() {
       this->dilation_[2] == 1 &&
       accumulate(
           this->pads_.begin(), this->pads_.end(), 1, multiplies<int>()) == 1 &&
-      !dequantize_output_ && GetCpuId().avx2() && !quantize_groupwise_;
+      GetCpuId().avx2() && !quantize_groupwise_;
   return ret;
 }
 
@@ -179,13 +156,13 @@ bool ConvDNNLowPOp<T, ReluFused>::NoIm2ColNHWC_() {
     return true;
   }
 
-  const Tensor& X = InputTensorCPU_(INPUT);
-  if (Wq_packed_ && X.template IsType<T>() &&
+  if (Wq_packed_ &&
       accumulate(
           this->dilation_.begin(),
           this->dilation_.end(),
           1,
           multiplies<int>()) == 1) {
+    // im2col fusion
     return true;
   }
 
@@ -225,14 +202,13 @@ void ConvDNNLowPOp<T, ReluFused>::QuantizeBias_() {
 
   // Quantize bias
   if (InputSize() == 3 &&
-      ((!b_quantized_data_ && !b_dequantized_data_) ||
+      (!b_quantized_data_ ||
        in_qparams_[INPUT].scale != in_qparams_scale_old_)) {
     if (this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER) &&
         this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER)
             .bias.get()) {
       const auto& packed_filter =
           this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
-      CAFFE_ENFORCE(!dequantize_output_);
       b_quantized_ = packed_filter.bias;
       b_quantized_data_ = b_quantized_->data();
     } else {
@@ -250,42 +226,27 @@ void ConvDNNLowPOp<T, ReluFused>::QuantizeBias_() {
             1e-4);
         CAFFE_ENFORCE_EQ(bias_qparams.zero_point, 0);
         b_quantized_data_ = bias.template data<int32_t>();
-        if (dequantize_output_) {
-          b_dequantized_.resize(bias.numel());
-#ifdef _OPENMP
-#pragma omp parallel for
-#endif
-          for (int i = 0; i < b_dequantized_.size(); ++i) {
-            b_dequantized_[i] =
-                fbgemm::Dequantize<int32_t>(b_quantized_data_[i], bias_qparams);
-          }
-          b_dequantized_data_ = b_dequantized_.data();
-        }
       } else {
-        b_dequantized_data_ = bias.template data<float>();
-        if (!dequantize_output_) {
-          b_quantized_->resize(bias.numel());
-          for (int g = 0; g < filter_qparams_.size(); ++g) {
-            int i_begin = g * (M / filter_qparams_.size());
-            int i_end = i_begin + (M / filter_qparams_.size());
-            for (int i = i_begin; i < i_end; ++i) {
-              (*b_quantized_)[i] = fbgemm::Quantize<int32_t>(
-                  b_dequantized_data_[i],
-                  0,
-                  in_qparams_[INPUT].scale * FilterQuantizationParams(g).scale,
-                  32,
-                  true /* signed */);
-            }
+        const float* b_data = bias.template data<float>();
+        b_quantized_->resize(bias.numel());
+        for (int g = 0; g < filter_qparams_.size(); ++g) {
+          int i_begin = g * (M / filter_qparams_.size());
+          int i_end = i_begin + (M / filter_qparams_.size());
+          for (int i = i_begin; i < i_end; ++i) {
+            (*b_quantized_)[i] = fbgemm::Quantize<int32_t>(
+                b_data[i],
+                0,
+                in_qparams_[INPUT].scale * FilterQuantizationParams(g).scale,
+                32,
+                true /* signed */);
           }
-          b_quantized_data_ = b_quantized_->data();
         }
+        b_quantized_data_ = b_quantized_->data();
       }
       in_qparams_scale_old_ = in_qparams_[INPUT].scale;
     }
 
-    CAFFE_ENFORCE(
-        (dequantize_output_ && b_dequantized_data_) ||
-        (!dequantize_output_ && b_quantized_data_));
+    CAFFE_ENFORCE(b_quantized_data_);
   }
 }
 
@@ -424,8 +385,13 @@ bool ConvDNNLowPOp<T, ReluFused>::GetQuantizationParameters_() {
   using namespace dnnlowp;
 
   if (!this->arguments_parsed_) {
+    bool dequantize_output;
     ParseDNNLowPOperatorArguments(
-        this, &dequantize_output_, &measure_quantization_error_, &followed_by_);
+        this, &dequantize_output, &measure_quantization_error_, &followed_by_);
+    CAFFE_ENFORCE_EQ(
+        dequantize_output,
+        false,
+        "Conv DNNLOWP operators don't support dequantize_output");
 
     if (ReluFused) {
       // It's actually fused with Relu not followed by but setting this to make
@@ -451,26 +417,23 @@ bool ConvDNNLowPOp<T, ReluFused>::GetQuantizationParameters_() {
   QuantizeBias_();
 
   bool fp32_executed = false;
-  if (!dequantize_output_) {
-    if (HasStaticQuantization(this)) {
-      out_qparams_ = GetStaticQuantizationParamsOf(this, 0);
-    } else {
-      // If quantization parameters are not chosen beforehand, run reference
-      // Conv op in fp32 to choose quantization for Y.
-      Fp32Op_()->DequantizeInput();
-      Fp32Op_()->Get()->RunOnDevice();
-      out_qparams_ = Fp32Op_()->GetOutputQuantizationParams(qfactory_.get());
-      fp32_executed = true;
-    }
+  if (HasStaticQuantization(this)) {
+    out_qparams_ = GetStaticQuantizationParamsOf(this, 0);
+  } else {
+    // If quantization parameters are not chosen beforehand, run reference
+    // Conv op in fp32 to choose quantization for Y.
+    Fp32Op_()->DequantizeInput();
+    Fp32Op_()->Get()->RunOnDevice();
+    out_qparams_ = Fp32Op_()->GetOutputQuantizationParams(qfactory_.get());
+    fp32_executed = true;
+  }
 
-    for (int g = 0; g < filter_qparams_.size(); ++g) {
-      float real_multiplier = in_qparams_[INPUT].scale *
-          FilterQuantizationParams(g).scale / out_qparams_.scale;
-      requantization_params_[g] = qfactory_->ChooseRequantizationMultiplier(
-          real_multiplier, out_qparams_);
-      requantization_multipliers_[g] =
-          requantization_params_[g].real_multiplier;
-    }
+  for (int g = 0; g < filter_qparams_.size(); ++g) {
+    float real_multiplier = in_qparams_[INPUT].scale *
+        FilterQuantizationParams(g).scale / out_qparams_.scale;
+    requantization_params_[g] = qfactory_->ChooseRequantizationMultiplier(
+        real_multiplier, out_qparams_);
+    requantization_multipliers_[g] = requantization_params_[g].real_multiplier;
   }
 
   if (measure_quantization_error_ && Fp32Op_() && !fp32_executed) {
@@ -483,11 +446,10 @@ bool ConvDNNLowPOp<T, ReluFused>::GetQuantizationParameters_() {
 }
 
 template <typename T, bool ReluFused>
-template <typename OutType>
 void ConvDNNLowPOp<T, ReluFused>::RunOnDeviceEpilogueNCHW_(
-    const T* col_buffer_quantized_data,
+    const T* col_buffer_data,
     int32_t* Y_int32,
-    OutType* Y_data,
+    T* Y_data,
     size_t i_offset,
     int group_id) {
   auto& filter = InputTensorCPU_(FILTER);
@@ -507,51 +469,30 @@ void ConvDNNLowPOp<T, ReluFused>::RunOnDeviceEpilogueNCHW_(
   for (int j = 0; j < Y_HxW; ++j) {
     int sum = 0;
     for (int k = 0; k < kernel_dim; ++k) {
-      sum += col_buffer_quantized_data[k * Y_HxW + j];
+      sum += col_buffer_data[k * Y_HxW + j];
     }
     column_offsets[j] = sum * filter_qparams.zero_point;
   }
 
-  if (dequantize_output_) {
-    for (int i = 0; i < M / group_; ++i) {
-      int32_t row_offset = row_offsets_[i_offset + i];
-      row_offset *= -in_qparams_[INPUT].zero_point;
-      for (int j = 0; j < Y_HxW; ++j) {
-        int32_t raw = Y_int32[i * Y_HxW + j] + row_offset - column_offsets[j];
-        float dequantized =
-            raw * in_qparams_[INPUT].scale * filter_qparams.scale;
-        if (InputSize() == 3) {
-          dequantized += b_dequantized_data_[i_offset + i];
-        }
-        if (ReluFused) {
-          dequantized = std::max(0.f, dequantized);
-        }
-        Y_data[i * Y_HxW + j] = dequantized;
-      }
+  for (int i = 0; i < M / group_; ++i) {
+    int32_t row_offset = row_offsets_[i_offset + i];
+    row_offset *= -in_qparams_[INPUT].zero_point;
+    if (InputSize() == 3) {
+      row_offset += b_quantized_data_[i_offset + i];
     }
-  } // dequantize_output_
-  else {
-    for (int i = 0; i < M / group_; ++i) {
-      int32_t row_offset = row_offsets_[i_offset + i];
-      row_offset *= -in_qparams_[INPUT].zero_point;
-      if (InputSize() == 3) {
-        row_offset += b_quantized_data_[i_offset + i];
-      }
-      for (int j = 0; j < Y_HxW; ++j) {
-        int32_t raw = Y_int32[i * Y_HxW + j] + row_offset - column_offsets[j];
-        if (ReluFused) {
-          raw = std::max(0, raw);
-        }
-        Y_data[i * Y_HxW + j] =
-            fbgemm::Requantize<T>(raw, RequantizationParams(group_id));
+    for (int j = 0; j < Y_HxW; ++j) {
+      int32_t raw = Y_int32[i * Y_HxW + j] + row_offset - column_offsets[j];
+      if (ReluFused) {
+        raw = std::max(0, raw);
       }
+      Y_data[i * Y_HxW + j] =
+          fbgemm::Requantize<T>(raw, RequantizationParams(group_id));
     }
-  } // !dequantize_output_
+  }
 }
 
 template <typename T, bool ReluFused>
-template <typename InType>
-bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
+bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNCHW() {
   VLOG(2) << "Running DNNLOWP Conv";
 
   using namespace dnnlowp;
@@ -613,23 +554,17 @@ bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
 
   // The col buffer is stored in CHW order as well - kernel_dim, and the
   // height and width.
-  const InType* Xdata = X.template data<InType>();
+  const T* Xdata = X.template data<T>();
 
   // We must not call mutable_data inside omp region
-  float* Y_data_float = nullptr;
-  T* Y_data_T = nullptr;
-  if (dequantize_output_) {
-    Y_data_float = Y->template mutable_data<float>();
-  } else {
-    Y_data_T = Y->template mutable_data<T>();
-  }
+  T* Y_data_T = Y->template mutable_data<T>();
   column_offsets_->resize(Y_HxW * dnnlowp_get_max_threads());
 
   auto f = [&](Tensor* col_buffer) {
     col_buffer->Resize(buffer_shape);
     vector<int> buffer_shape_per_thread(
         buffer_shape.begin() + 1, buffer_shape.end());
-    InType* col_buffer_data = col_buffer->template mutable_data<InType>();
+    T* col_buffer_data = col_buffer->template mutable_data<T>();
 
     auto f2 = [&](vector<int32_t>* Y_int32) {
       Y_int32->resize(M * Y_HxW * dnnlowp_get_max_threads());
@@ -642,7 +577,7 @@ bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
         int tid = dnnlowp_get_thread_num();
         for (int group_id = 0; group_id < group_; ++group_id) {
           if (this->kernel_.size() == 2) {
-            math::Im2ColNCHW<InType>(
+            math::Im2ColNCHW<T>(
                 C / group_,
                 input_dims[0],
                 input_dims[1],
@@ -659,9 +594,9 @@ bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
                 Xdata + (group_ * image_id + group_id) * input_offset,
                 col_buffer_data + tid * col_buffer_size,
                 &context_,
-                X.IsType<T>() ? in_qparams_[INPUT].zero_point : 0);
+                in_qparams_[INPUT].zero_point);
           } else {
-            math::Im2ColNdNCHW<InType>(
+            math::Im2ColNdNCHW<T>(
                 this->kernel_.size(),
                 C * X_HxW,
                 col_buffer_size,
@@ -674,24 +609,11 @@ bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
                 Xdata + (group_ * image_id + group_id) * input_offset,
                 col_buffer_data + tid * col_buffer_size,
                 &context_,
-                X.IsType<T>() ? in_qparams_[INPUT].zero_point : 0);
+                in_qparams_[INPUT].zero_point);
           }
 
           // quantize col_buffer
-          T* col_buffer_quantized_data = nullptr;
-          vector<T> col_buffer_quantized;
-          if (X.template IsType<T>()) {
-            col_buffer_quantized_data =
-                reinterpret_cast<T*>(col_buffer_data) + tid * col_buffer_size;
-          } else {
-            col_buffer_quantized.resize(kernel_dim * Y_HxW);
-            fbgemm::Quantize<T>(
-                (const float*)col_buffer_data + tid * col_buffer_size,
-                col_buffer_quantized.data(),
-                col_buffer_quantized.size(),
-                in_qparams_[INPUT]);
-            col_buffer_quantized_data = col_buffer_quantized.data();
-          }
+          T* col_buffer_private = col_buffer_data + tid * col_buffer_size;
 
           int32_t* Y_int32_temp =
               Y_int32->data() + ((M / group_) * group_id + M * tid) * Y_HxW;
@@ -703,34 +625,23 @@ bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
               int32_t sum = 0;
               for (int k = 0; k < kernel_dim; ++k) {
                 int w = W_quantized_group[i * kernel_dim + k];
-                int x = col_buffer_quantized_data[k * Y_HxW + j];
+                int x = col_buffer_private[k * Y_HxW + j];
                 sum += w * x;
               }
               Y_int32_temp[i * Y_HxW + j] = sum;
             } // j
           } // i
 
-          if (dequantize_output_) {
-            RunOnDeviceEpilogueNCHW_(
-                col_buffer_quantized_data,
-                Y_int32_temp,
-                Y_data_float + (M * image_id + M / group_ * group_id) * Y_HxW,
-                M / group_ * group_id,
-                group_id);
-          } else {
-            RunOnDeviceEpilogueNCHW_(
-                col_buffer_quantized_data,
-                Y_int32_temp,
-                Y_data_T + (M * image_id + M / group_ * group_id) * Y_HxW,
-                M / group_ * group_id,
-                group_id);
-          }
+          RunOnDeviceEpilogueNCHW_(
+              col_buffer_private,
+              Y_int32_temp,
+              Y_data_T + (M * image_id + M / group_ * group_id) * Y_HxW,
+              M / group_ * group_id,
+              group_id);
         } // for each group
       } // for each image_id
 
-      if (!dequantize_output_) {
-        PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
-      }
+      PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
       MeasureQuantizationError_();
     }; // f2
 
@@ -748,11 +659,11 @@ bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
   }
 
   return true;
-} // RunOnDeviceWithOrderNCHWAndType_
+} // RunOnDeviceWithOrderNCHW
 
 template <typename T, bool ReluFused>
 void ConvDNNLowPOp<T, ReluFused>::RunOnDeviceEpilogueNHWC_(
-    const T* col_buffer_quantized_data,
+    const T* col_buffer_data,
     int32_t* Y_int32) {
   const Tensor& X = InputTensorCPU_(INPUT);
   auto& filter = InputTensorCPU_(FILTER);
@@ -765,177 +676,138 @@ void ConvDNNLowPOp<T, ReluFused>::RunOnDeviceEpilogueNHWC_(
   // Adjust with bias and zero_point and then requantize
   // See batch_matmul_dnnlowp_op.cc to why we compute column_offsets,
   // row_offset, and const_offset in this way.
-  if (dequantize_output_) {
-    float* Ydata = Y->template mutable_data<float>();
+  int32_t A_zero_point = in_qparams_[INPUT].zero_point;
+
+  if (!dnnlowp::HasStaticQuantization(this)) {
+    if (quantize_groupwise_) {
+      static int log_occurences = 0;
+      if (log_occurences < 32) {
+        ++log_occurences;
+        LOG(WARNING) << "Cannot do group-wise quantization without "
+                        "static quantization of activations for "
+                     << OperatorBase::debug_def().output(0);
+      }
+    }
+
+    int32_t Y_min = numeric_limits<int32_t>::max();
+    int32_t Y_max = numeric_limits<int32_t>::min();
 
 #ifdef _OPENMP
-#pragma omp parallel for
+#pragma omp parallel for reduction(min : Y_min), reduction(max : Y_max)
 #endif
     for (int i = 0; i < N * Y_HxW; ++i) {
       for (int group_id = 0; group_id < group_; ++group_id) {
         int32_t row_offset = 0;
         for (int k = 0; k < kernel_dim; ++k) {
-          row_offset += col_buffer_quantized_data
-              [(i * group_ + group_id) * kernel_dim + k];
+          row_offset +=
+              col_buffer_data[(i * group_ + group_id) * kernel_dim + k];
         }
-        row_offset *= FilterQuantizationParams(group_id).zero_point;
+        row_offset *= FilterQuantizationParams(0).zero_point;
 
         for (int j = group_id * (M / group_); j < (group_id + 1) * (M / group_);
              ++j) {
-          Y_int32[i * M + j] -=
-              in_qparams_[INPUT].zero_point * (*column_offsets_)[j] +
-              row_offset;
-          Ydata[i * M + j] = Y_int32[i * M + j] * in_qparams_[INPUT].scale *
-                  FilterQuantizationParams(group_id).scale +
-              ((InputSize() == 3) ? b_dequantized_data_[j] : 0.f);
-          if (ReluFused) {
-            Ydata[i * M + j] = std::max(Ydata[i * M + j], 0.f);
+          int32_t raw = Y_int32[i * M + j] -
+              A_zero_point * (*column_offsets_)[j] - row_offset;
+          if (b_quantized_data_) {
+            raw += b_quantized_data_[j];
           }
+          Y_min = std::min(Y_min, raw);
+          Y_max = std::max(Y_max, raw);
         }
       } // for each group
-    } // for each i
-  } else {
-    int32_t A_zero_point = in_qparams_[INPUT].zero_point;
-
-    if (!dnnlowp::HasStaticQuantization(this)) {
-      if (quantize_groupwise_) {
-        static int log_occurences = 0;
-        if (log_occurences < 32) {
-          ++log_occurences;
-          LOG(WARNING) << "Cannot do group-wise quantization without "
-                          "static quantization of activations for "
-                       << OperatorBase::debug_def().output(0);
-        }
-      }
-
-      int32_t Y_int32_min = numeric_limits<int32_t>::max();
-      int32_t Y_int32_max = numeric_limits<int32_t>::min();
-
-#ifdef _OPENMP
-#pragma omp parallel for reduction(min             \
-                                   : Y_int32_min), \
-    reduction(max                                  \
-              : Y_int32_max)
-#endif
-      for (int i = 0; i < N * Y_HxW; ++i) {
-        for (int group_id = 0; group_id < group_; ++group_id) {
-          int32_t row_offset = 0;
-          for (int k = 0; k < kernel_dim; ++k) {
-            row_offset += col_buffer_quantized_data
-                [(i * group_ + group_id) * kernel_dim + k];
-          }
-          row_offset *= FilterQuantizationParams(0).zero_point;
-
-          for (int j = group_id * (M / group_);
-               j < (group_id + 1) * (M / group_);
-               ++j) {
-            int32_t raw = Y_int32[i * M + j] -
-                A_zero_point * (*column_offsets_)[j] - row_offset;
-            if (b_quantized_data_) {
-              raw += b_quantized_data_[j];
-            }
-            Y_int32_min = std::min(Y_int32_min, raw);
-            Y_int32_max = std::max(Y_int32_max, raw);
-          }
-        } // for each group
-      } // for each row i
+    } // for each row i
 
-      if (ReluFused) {
-        Y_int32_min = std::max(0, Y_int32_min);
-        Y_int32_max = std::max(0, Y_int32_max);
-      }
+    if (ReluFused) {
+      Y_min = std::max(0, Y_min);
+      Y_max = std::max(0, Y_max);
+    }
 
-      float Y_int32_scale =
-          in_qparams_[INPUT].scale * FilterQuantizationParams(0).scale;
-      out_qparams_ = qfactory_->ChooseQuantizationParams(
-          Y_int32_scale * Y_int32_min, Y_int32_scale * Y_int32_max);
+    float Y_scale =
+        in_qparams_[INPUT].scale * FilterQuantizationParams(0).scale;
+    out_qparams_ =
+        qfactory_->ChooseQuantizationParams(Y_scale * Y_min, Y_scale * Y_max);
 
-      float real_multiplier = Y_int32_scale / out_qparams_.scale;
-      requantization_params_[0] = qfactory_->ChooseRequantizationMultiplier(
-          real_multiplier, out_qparams_);
-      requantization_multipliers_[0] =
-          requantization_params_[0].real_multiplier;
-    }
+    float real_multiplier = Y_scale / out_qparams_.scale;
+    requantization_params_[0] = qfactory_->ChooseRequantizationMultiplier(
+        real_multiplier, out_qparams_);
+    requantization_multipliers_[0] = requantization_params_[0].real_multiplier;
+  }
 
-    int32_t C_zero_point = out_qparams_.zero_point;
+  int32_t C_zero_point = out_qparams_.zero_point;
 
-    T* Ydata = Y->template mutable_data<T>();
+  T* Ydata = Y->template mutable_data<T>();
 
-    using namespace fbgemm;
-    if (is_same<T, uint8_t>::value && GetCpuId().avx2()) {
+  using namespace fbgemm;
+  if (is_same<T, uint8_t>::value && GetCpuId().avx2()) {
 #ifdef _OPENMP
 #pragma omp parallel for
 #endif
-      for (int i = 0; i < N * Y_HxW; ++i) {
-        for (int group_id = 0; group_id < group_; ++group_id) {
-          int32_t row_offset;
-          row_offsets_u8acc32_ref(
-              1,
-              kernel_dim,
-              group_ * kernel_dim,
-              reinterpret_cast<const uint8_t*>(
-                  col_buffer_quantized_data +
-                  (i * group_ + group_id) * kernel_dim),
-              &row_offset);
-
-          int32_t B_zero_point = FilterQuantizationParams(group_id).zero_point;
-          float C_multiplier = RequantizationParams(group_id).real_multiplier;
-
-          requantize_u8acc32_ref(
-              1,
-              M / group_,
-              M,
-              Y_int32 + i * M + group_id * (M / group_),
-              reinterpret_cast<uint8_t*>(
-                  Ydata + i * M + group_id * (M / group_)),
-              &C_multiplier,
-              C_zero_point,
-              A_zero_point,
-              &B_zero_point,
-              &row_offset,
-              column_offsets_->data() + group_id * (M / group_),
-              b_quantized_data_ ? b_quantized_data_ + group_id * (M / group_)
-                                : nullptr,
-              M / group_,
-              ReluFused);
-        } // for each group
-      } // for each row i
-    } else {
+    for (int i = 0; i < N * Y_HxW; ++i) {
+      for (int group_id = 0; group_id < group_; ++group_id) {
+        int32_t row_offset;
+        row_offsets_u8acc32_ref(
+            1,
+            kernel_dim,
+            group_ * kernel_dim,
+            reinterpret_cast<const uint8_t*>(
+                col_buffer_data + (i * group_ + group_id) * kernel_dim),
+            &row_offset);
+
+        int32_t B_zero_point = FilterQuantizationParams(group_id).zero_point;
+        float C_multiplier = RequantizationParams(group_id).real_multiplier;
+
+        requantize_u8acc32_ref(
+            1,
+            M / group_,
+            M,
+            Y_int32 + i * M + group_id * (M / group_),
+            reinterpret_cast<uint8_t*>(Ydata + i * M + group_id * (M / group_)),
+            &C_multiplier,
+            C_zero_point,
+            A_zero_point,
+            &B_zero_point,
+            &row_offset,
+            column_offsets_->data() + group_id * (M / group_),
+            b_quantized_data_ ? b_quantized_data_ + group_id * (M / group_)
+                              : nullptr,
+            M / group_,
+            ReluFused);
+      } // for each group
+    } // for each row i
+  } else {
 #ifdef _OPENMP
 #pragma omp parallel for
 #endif
-      for (int i = 0; i < N * Y_HxW; ++i) {
-        for (int group_id = 0; group_id < group_; ++group_id) {
-          int32_t B_zero_point = FilterQuantizationParams(group_id).zero_point;
-          int32_t row_offset = 0;
-          for (int k = 0; k < kernel_dim; ++k) {
-            row_offset += col_buffer_quantized_data
-                [(i * group_ + group_id) * kernel_dim + k];
-          }
-          row_offset *= B_zero_point;
+    for (int i = 0; i < N * Y_HxW; ++i) {
+      for (int group_id = 0; group_id < group_; ++group_id) {
+        int32_t B_zero_point = FilterQuantizationParams(group_id).zero_point;
+        int32_t row_offset = 0;
+        for (int k = 0; k < kernel_dim; ++k) {
+          row_offset +=
+              col_buffer_data[(i * group_ + group_id) * kernel_dim + k];
+        }
+        row_offset *= B_zero_point;
 
-          for (int j = group_id * (M / group_);
-               j < (group_id + 1) * (M / group_);
-               ++j) {
-            int32_t raw = Y_int32[i * M + j] -
-                A_zero_point * (*column_offsets_)[j] - row_offset;
-            if (b_quantized_data_) {
-              raw += b_quantized_data_[j];
-            }
+        for (int j = group_id * (M / group_); j < (group_id + 1) * (M / group_);
+             ++j) {
+          int32_t raw = Y_int32[i * M + j] -
+              A_zero_point * (*column_offsets_)[j] - row_offset;
+          if (b_quantized_data_) {
+            raw += b_quantized_data_[j];
+          }
 
+          Ydata[i * M + j] =
+              fbgemm::Requantize<T>(raw, RequantizationParams(group_id));
+          if (ReluFused) { // static if
             Ydata[i * M + j] =
-                fbgemm::Requantize<T>(raw, RequantizationParams(group_id));
-            if (ReluFused) { // static if
-              Ydata[i * M + j] =
-                  std::max<int32_t>(C_zero_point, Ydata[i * M + j]);
-            }
+                std::max<int32_t>(C_zero_point, Ydata[i * M + j]);
           }
-        } // for each group
-      } // for each row i
-    } // !__AVX2__
+        }
+      } // for each group
+    } // for each row i
+  } // !__AVX2__
 
-    dnnlowp::PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
-  }
+  dnnlowp::PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
 }
 
 template <typename T, bool ReluFused>
@@ -964,8 +836,7 @@ void ConvDNNLowPOp<T, ReluFused>::PartitionGroupedNHWCConv_(
 }
 
 template <typename T, bool ReluFused>
-template <typename InType>
-const InType* ConvDNNLowPOp<T, ReluFused>::Im2ColNHWC_(Tensor* col_buffer) {
+const T* ConvDNNLowPOp<T, ReluFused>::Im2ColNHWC_(Tensor* col_buffer) {
   const Tensor& X = InputTensorCPU_(INPUT);
   Tensor* Y = OutputTensorCPU_(0);
   int ndim = X.dim();
@@ -978,7 +849,7 @@ const InType* ConvDNNLowPOp<T, ReluFused>::Im2ColNHWC_(Tensor* col_buffer) {
   const int input_offset = X_HxW * C;
   const int Y_HxW = this->GetDimsSize(*Y);
 
-  const InType* Xdata = X.template data<InType>();
+  const T* Xdata = X.template data<T>();
 
   vector<int> buffer_shape(ndim);
   for (auto i = 0; i < ndim - 1; ++i) {
@@ -988,14 +859,14 @@ const InType* ConvDNNLowPOp<T, ReluFused>::Im2ColNHWC_(Tensor* col_buffer) {
 
   col_buffer->Resize(buffer_shape);
 
-  InType* col_buffer_data = col_buffer->template mutable_data<InType>();
+  T* col_buffer_data = col_buffer->template mutable_data<T>();
 
 #ifdef _OPENMP
 #pragma omp parallel for if (N > 1)
 #endif
   for (int image_id = 0; image_id < N; ++image_id) {
     if (this->kernel_.size() <= 2) {
-      math::Im2ColNHWC<InType>(
+      math::Im2ColNHWC<T>(
           C,
           X.dim32(1),
           this->kernel_.size() == 2 ? X.dim32(2) : 1,
@@ -1013,9 +884,9 @@ const InType* ConvDNNLowPOp<T, ReluFused>::Im2ColNHWC_(Tensor* col_buffer) {
           col_buffer_data + image_id * group_ * kernel_dim * Y_HxW,
           &context_,
           group_,
-          X.IsType<T>() ? in_qparams_[INPUT].zero_point : 0);
+          in_qparams_[INPUT].zero_point);
     } else {
-      math::Im2Col3DNHWC<InType>(
+      math::Im2Col3DNHWC<T>(
           C,
           X.dim32(1), // num_frames
           X.dim32(2), // H
@@ -1039,11 +910,11 @@ const InType* ConvDNNLowPOp<T, ReluFused>::Im2ColNHWC_(Tensor* col_buffer) {
           col_buffer_data + image_id * group_ * kernel_dim * Y_HxW,
           &context_,
           group_,
-          X.IsType<T>() ? in_qparams_[INPUT].zero_point : 0);
+          in_qparams_[INPUT].zero_point);
     }
   }
 
-  return col_buffer->template data<InType>();
+  return col_buffer->template data<T>();
 }
 
 template <typename T, typename T_signed>
@@ -1077,8 +948,8 @@ template <typename PackAMatrix, fbgemm::QuantizationGranularity Q_GRAN>
 void ConvDNNLowPOp<T, ReluFused>::DispatchFBGEMM_(
     PackAMatrix& packA,
     vector<int32_t>* Y_int32,
-    uint8_t* Y_uint8_data,
-    float* Y_float_data) {
+    uint8_t* Y_uint8_data) {
+  // This function is called within an OpenMP region
   auto& filter = InputTensorCPU_(FILTER);
   const int M = filter.dim32(0);
 
@@ -1086,60 +957,33 @@ void ConvDNNLowPOp<T, ReluFused>::DispatchFBGEMM_(
   int tid = dnnlowp_get_thread_num();
 
   using namespace fbgemm;
-  if (Y_uint8_data) {
-    DoNothing<> doNothingObj{};
-    ReQuantizeOutput<ReluFused, Q_GRAN> outputProcObj(
-        doNothingObj,
-        requantization_multipliers_.data(),
-        out_qparams_.zero_point,
-        in_qparams_[INPUT].zero_point,
-        filter_zero_points_.data(),
-        packA.getRowOffsetBuffer(),
-        column_offsets_->data(),
-        InputSize() == 3 ? b_quantized_data_ : nullptr,
-        M,
-        group_);
+  DoNothing<> doNothingObj{};
+  ReQuantizeOutput<ReluFused, Q_GRAN> outputProcObj(
+      doNothingObj,
+      requantization_multipliers_.data(),
+      out_qparams_.zero_point,
+      in_qparams_[INPUT].zero_point,
+      filter_zero_points_.data(),
+      packA.getRowOffsetBuffer(),
+      column_offsets_->data(),
+      InputSize() == 3 ? b_quantized_data_ : nullptr,
+      M,
+      group_);
 
-    fbgemmPacked(
-        packA,
-        *Wq_packed_,
-        Y_uint8_data,
-        Y_int32->data(),
-        M,
-        outputProcObj,
-        tid,
-        nthreads);
-  } else {
-    DoNothing<float, float> doNothingObj{};
-    ReQuantizeForFloat<ReluFused, Q_GRAN> outputProcObj(
-        doNothingObj,
-        in_qparams_[INPUT].scale,
-        filter_scales_.data(),
-        in_qparams_[INPUT].zero_point,
-        filter_zero_points_.data(),
-        packA.getRowOffsetBuffer(),
-        column_offsets_->data(),
-        InputSize() == 3 ? b_dequantized_data_ : nullptr,
-        M,
-        group_);
-
-    fbgemmPacked(
-        packA,
-        *Wq_packed_,
-        Y_float_data,
-        reinterpret_cast<int32_t*>(Y_float_data),
-        M,
-        outputProcObj,
-        tid,
-        nthreads);
-  }
+  fbgemmPacked(
+      packA,
+      *Wq_packed_,
+      Y_uint8_data,
+      Y_int32->data(),
+      M,
+      outputProcObj,
+      tid,
+      nthreads);
 }
 
 template <typename T, bool ReluFused>
-template <typename InType>
 void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
-    const InType* col_buffer_data,
-    const T* col_buffer_quantized_data,
+    const T* col_buffer_data,
     vector<int32_t>* Y_int32) {
   const Tensor& X = InputTensorCPU_(INPUT);
   auto& filter = InputTensorCPU_(FILTER);
@@ -1154,7 +998,7 @@ void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
     StoreMatrixInMatrixMarketFormat(
         N * Y_HxW * group_,
         kernel_dim,
-        col_buffer_quantized_data,
+        col_buffer_data,
         OperatorBase::debug_def().input(INPUT));
 
     // Dump weight
@@ -1166,7 +1010,7 @@ void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
   }
 
   if (TakeDepthWise3x3x3FastPath_()) {
-    const InType* Xdata = X.template data<InType>();
+    const T* Xdata = X.template data<T>();
     uint8_t* Y_uint8_data =
         OutputTensorCPU_(0)->template mutable_data<uint8_t>();
 
@@ -1198,7 +1042,7 @@ void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
     return;
   } else if (TakeDepthWise3x3FastPath_()) {
     const int H = X.dim32(1), W = X.dim32(2);
-    const InType* Xdata = X.template data<InType>();
+    const T* Xdata = X.template data<T>();
     uint8_t* Y_uint8_data =
         OutputTensorCPU_(0)->template mutable_data<uint8_t>();
 
@@ -1231,8 +1075,8 @@ void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
   using namespace fbgemm;
   int row_offset_size_per_thread = -1;
   int x_pack_buf_size_per_thread = -1;
-  bool fuse_im2col = Wq_packed_ && X.template IsType<T>() &&
-      X.template data<T>() == col_buffer_quantized_data && !IsConvGEMM_();
+  bool fuse_im2col =
+      Wq_packed_ && X.template data<T>() == col_buffer_data && !IsConvGEMM_();
   if (Wq_packed_) {
     if (fuse_im2col) {
       row_offset_size_per_thread =
@@ -1248,15 +1092,7 @@ void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
     X_pack_buf_.resize(dnnlowp_get_max_threads() * x_pack_buf_size_per_thread);
   }
 
-  uint8_t* Y_uint8_data = nullptr;
-  float* Y_float_data = nullptr;
-  if (dequantize_output_) {
-    // Output is float
-    Y_float_data = Y->template mutable_data<float>();
-  } else {
-    // Output is uint8_t
-    Y_uint8_data = Y->template mutable_data<uint8_t>();
-  }
+  uint8_t* Y_uint8_data = Y->template mutable_data<uint8_t>();
 
   if (Wq_packed_)
 #ifdef _OPENMP
@@ -1285,7 +1121,7 @@ void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
 
         PackAWithIm2Col<uint8_t> packA(
             conv_p,
-            reinterpret_cast<const uint8_t*>(col_buffer_quantized_data),
+            reinterpret_cast<const uint8_t*>(col_buffer_data),
             // buffer for packed matrix
             X_pack_buf_.data() + tid * x_pack_buf_size_per_thread,
             in_qparams_[INPUT].zero_point,
@@ -1294,13 +1130,11 @@ void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
         if (quantize_groupwise_) {
           DispatchFBGEMM_<
               PackAWithIm2Col<uint8_t>,
-              QuantizationGranularity::GROUP>(
-              packA, Y_int32, Y_uint8_data, Y_float_data);
+              QuantizationGranularity::GROUP>(packA, Y_int32, Y_uint8_data);
         } else {
           DispatchFBGEMM_<
               PackAWithIm2Col<uint8_t>,
-              QuantizationGranularity::TENSOR>(
-              packA, Y_int32, Y_uint8_data, Y_float_data);
+              QuantizationGranularity::TENSOR>(packA, Y_int32, Y_uint8_data);
         }
       } else {
         // 3D
@@ -1321,7 +1155,7 @@ void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
 
         PackAWithIm2Col<uint8_t, int32_t, 3> packA(
             conv_p,
-            reinterpret_cast<const uint8_t*>(col_buffer_quantized_data),
+            reinterpret_cast<const uint8_t*>(col_buffer_data),
             // buffer for packed matrix
             X_pack_buf_.data() + tid * x_pack_buf_size_per_thread,
             in_qparams_[INPUT].zero_point,
@@ -1330,13 +1164,11 @@ void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
         if (quantize_groupwise_) {
           DispatchFBGEMM_<
               PackAWithIm2Col<uint8_t, int32_t, 3>,
-              QuantizationGranularity::GROUP>(
-              packA, Y_int32, Y_uint8_data, Y_float_data);
+              QuantizationGranularity::GROUP>(packA, Y_int32, Y_uint8_data);
         } else {
           DispatchFBGEMM_<
               PackAWithIm2Col<uint8_t, int32_t, 3>,
-              QuantizationGranularity::TENSOR>(
-              packA, Y_int32, Y_uint8_data, Y_float_data);
+              QuantizationGranularity::TENSOR>(packA, Y_int32, Y_uint8_data);
         }
       } // 3D
     } else {
@@ -1345,7 +1177,7 @@ void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
           matrix_op_t::NoTranspose,
           N * Y_HxW,
           group_ * kernel_dim,
-          reinterpret_cast<const uint8_t*>(col_buffer_quantized_data),
+          reinterpret_cast<const uint8_t*>(col_buffer_data),
           group_ * kernel_dim,
           // buffer for packed matrix
           X_pack_buf_.data() + tid * x_pack_buf_size_per_thread,
@@ -1355,13 +1187,11 @@ void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
       if (quantize_groupwise_) {
         DispatchFBGEMM_<
             PackAWithRowOffset<uint8_t>,
-            QuantizationGranularity::GROUP>(
-            packA, Y_int32, Y_uint8_data, Y_float_data);
+            QuantizationGranularity::GROUP>(packA, Y_int32, Y_uint8_data);
       } else {
         DispatchFBGEMM_<
             PackAWithRowOffset<uint8_t>,
-            QuantizationGranularity::TENSOR>(
-            packA, Y_int32, Y_uint8_data, Y_float_data);
+            QuantizationGranularity::TENSOR>(packA, Y_int32, Y_uint8_data);
       }
     } // no im2col fusion
   } else {
@@ -1374,7 +1204,7 @@ void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
           N * Y_HxW,
           M,
           kernel_dim,
-          col_buffer_quantized_data,
+          col_buffer_data,
           W_quantized_.data(),
           Y_int32->data());
     }
@@ -1382,8 +1212,7 @@ void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
 }
 
 template <typename T, bool ReluFused>
-template <typename InType>
-bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNHWCAndType_() {
+bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNHWC() {
   CAFFE_ENFORCE_LE(
       this->kernel_.size(),
       3,
@@ -1415,7 +1244,7 @@ bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNHWCAndType_() {
   const Tensor& X = InputTensorCPU_(INPUT);
   auto& filter = InputTensorCPU_(FILTER);
   Tensor* Y = OutputTensorCPU_(0);
-  const int N = X.dim32(0), C = X.dim32(X.dim() - 1);
+  const int C = X.dim32(X.dim() - 1);
   const int G = group_;
   CAFFE_ENFORCE_EQ(X.dim(), filter.dim());
   const int M = filter.dim32(0);
@@ -1432,10 +1261,7 @@ bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNHWCAndType_() {
       M % G, 0, "The number of output channels is not divisible by group.");
 
   ConvPoolOpBase<CPUContext>::SetOutputSize(X, Y, filter.dim32(0));
-  // The dimension of each kernel
-  const int kernel_dim = KernelDim_();
-  // The output image size is the spatial size of the output.
-  const int Y_HxW = this->GetDimsSize(*Y);
+
   // The col buffer is stored in HWC order as well - kernel_dim, and the height
   // and width.
 
@@ -1451,9 +1277,8 @@ bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNHWCAndType_() {
 
     // Im2col, followed by gemm.
     auto f2 = [&](Tensor* col_buffer_) {
-      const InType* Xdata = X.template data<InType>();
-      const InType* col_buffer_data =
-          no_im2col ? Xdata : Im2ColNHWC_<InType>(col_buffer_);
+      const T* Xdata = X.template data<T>();
+      const T* col_buffer_data = no_im2col ? Xdata : Im2ColNHWC_(col_buffer_);
 
 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
       /*if (VLOG_IS_ON(3)) */ {
@@ -1464,21 +1289,6 @@ bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNHWCAndType_() {
       }
 #endif
 
-      // quantize col_buffer
-      const T* col_buffer_quantized_data = nullptr;
-      vector<T> col_buffer_quantized;
-      if (X.template IsType<T>()) {
-        col_buffer_quantized_data = reinterpret_cast<const T*>(col_buffer_data);
-      } else {
-        col_buffer_quantized.resize(G * kernel_dim * Y_HxW * N);
-        fbgemm::Quantize<T>(
-            reinterpret_cast<const float*>(col_buffer_data),
-            col_buffer_quantized.data(),
-            col_buffer_quantized.size(),
-            in_qparams_[INPUT]);
-        col_buffer_quantized_data = col_buffer_quantized.data();
-      }
-
 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
       /*if (VLOG_IS_ON(3)) */ {
         t_end = chrono::system_clock::now();
@@ -1489,7 +1299,7 @@ bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNHWCAndType_() {
       }
 #endif
 
-      ConvNHWCCore_(col_buffer_data, col_buffer_quantized_data, Y_int32);
+      ConvNHWCCore_(col_buffer_data, Y_int32);
 
 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
       /*if (VLOG_IS_ON(3)) */ {
@@ -1504,11 +1314,9 @@ bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNHWCAndType_() {
           Wq_depthwise_3x3x3_packed_) {
         // In fast path with fbgemm except when
         // rescaling quantized numbers should've been already done.
-        if (!dequantize_output_) {
-          PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
-        }
+        PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
       } else {
-        RunOnDeviceEpilogueNHWC_(col_buffer_quantized_data, Y_int32->data());
+        RunOnDeviceEpilogueNHWC_(col_buffer_data, Y_int32->data());
       }
     }; // f2
 
@@ -1527,6 +1335,12 @@ bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNHWCAndType_() {
 
 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
   /*if (VLOG_IS_ON(3)) */ {
+    const int N = X.dim32(0);
+    // The dimension of each kernel
+    const int kernel_dim = KernelDim_();
+    // The output image size is the spatial size of the output.
+    const int Y_HxW = this->GetDimsSize(*Y);
+
     t_end = chrono::system_clock::now();
     double dt = chrono::duration<double>(t_end - t_begin).count();
     LOG(INFO) << "this=" << this << " prologue: " << dt * 1e3 << " ms";
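
For reference, the ReQuantizeOutput epilogue built above (fed requantization_multipliers_, column_offsets_, packA.getRowOffsetBuffer(), and the optional quantized bias) boils down to the standard zero-point correction and rescaling of a quantized GEMM. The NumPy sketch below is illustrative only, not the fbgemm code: it assumes per-tensor quantization, a single group, and invented function/variable names; exactly where fbgemm folds the row/column sums and the K * x_zp * w_zp constant into its precomputed buffers is an implementation detail.

    import numpy as np

    def requantize_reference(X_q, W_q, bias_q, x_zp, w_zp, y_zp, multiplier):
        # X_q: (rows, K) uint8 activations, W_q: (K, cols) quantized weights.
        K = X_q.shape[1]
        acc32 = X_q.astype(np.int32) @ W_q.astype(np.int32)
        # Zero-point corrections; fbgemm takes the column sums from the
        # precomputed column_offsets_ blob and the row sums from
        # packA.getRowOffsetBuffer(), here they are recomputed directly.
        row_sums = X_q.astype(np.int32).sum(axis=1, keepdims=True)
        col_sums = W_q.astype(np.int32).sum(axis=0, keepdims=True)
        acc = acc32 - x_zp * col_sums - w_zp * row_sums + K * x_zp * w_zp
        if bias_q is not None:
            acc = acc + bias_q[None, :]
        # Rescale to the output quantization and clamp to uint8; with
        # ReluFused the low end is additionally clamped at y_zp.
        y = np.rint(acc * multiplier) + y_zp
        return np.clip(y, 0, 255).astype(np.uint8)

Each requantization_multipliers_ entry plays the role of multiplier here, i.e. roughly in_scale * filter_scale / out_scale for its group.
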
index 07bdae9..dfecc17 100644 (file)
@@ -26,11 +26,6 @@ class ConvDNNLowPOp : public ConvPoolDNNLowPOpBase<T, ConvFp32Op> {
   bool RunOnDeviceWithOrderNCHW() override;
   bool RunOnDeviceWithOrderNHWC() override;
 
-  template <typename InType>
-  bool RunOnDeviceWithOrderNCHWAndType_();
-  template <typename InType>
-  bool RunOnDeviceWithOrderNHWCAndType_();
-
   bool GetQuantizationParameters_();
 
   /**
@@ -41,8 +36,7 @@ class ConvDNNLowPOp : public ConvPoolDNNLowPOpBase<T, ConvFp32Op> {
   bool NoIm2ColNHWC_();
   int KernelDim_();
 
-  template <typename InType>
-  const InType* Im2ColNHWC_(Tensor* col_buffer);
+  const T* Im2ColNHWC_(Tensor* col_buffer);
 
   dnnlowp::TensorQuantizationParams& FilterQuantizationParams(int group_id);
   dnnlowp::RequantizationParams& RequantizationParams(int group_id);
@@ -83,15 +77,14 @@ class ConvDNNLowPOp : public ConvPoolDNNLowPOpBase<T, ConvFp32Op> {
 
   std::vector<std::uint8_t> X_pack_buf_;
 
-  template <typename OutType>
   void RunOnDeviceEpilogueNCHW_(
-      const T* col_buffer_quantized_data,
+      const T* col_buffer_data,
       std::int32_t* Y_int32,
-      OutType* Y_data,
+      T* Y_data,
       std::size_t i_offset,
       int group_id);
   void RunOnDeviceEpilogueNHWC_(
-      const T* col_buffer_quantized_data,
+      const T* col_buffer_data,
       std::int32_t* Y_int32);
 
   std::vector<std::int32_t> Y_int32_;
@@ -114,14 +107,9 @@ class ConvDNNLowPOp : public ConvPoolDNNLowPOpBase<T, ConvFp32Op> {
   void DispatchFBGEMM_(
       PackAMatrix& packA,
       vector<std::int32_t>* Y_int32,
-      uint8_t* Y_uint8_data,
-      float* Y_float_data);
+      uint8_t* Y_uint8_data);
 
-  template <typename InType>
-  void ConvNHWCCore_(
-      const InType* col_buffer_data,
-      const T* col_buffer_quantized_data,
-      vector<std::int32_t>* Y_int32);
+  void ConvNHWCCore_(const T* col_buffer_data, vector<std::int32_t>* Y_int32);
 
   std::vector<dnnlowp::RequantizationParams> requantization_params_;
 
@@ -136,11 +124,6 @@ class ConvDNNLowPOp : public ConvPoolDNNLowPOpBase<T, ConvFp32Op> {
   // pre-computed biases and offsets
   std::shared_ptr<std::vector<std::int32_t>> b_quantized_;
 
-  // Dequantized bias populated when input bias is quantized and
-  // dequantized_output_ == true
-  std::vector<float> b_dequantized_;
-  const float* b_dequantized_data_{nullptr};
-
   float in_qparams_scale_old_ = 0;
 }; // class ConvDNNLowPOp
 
index d5497fd..03d31b5 100644 (file)
@@ -31,8 +31,6 @@ class DNNLowPOpConvTest(hu.HypothesisTestCase):
         output_channels_per_group=st.integers(2, 16),
         batch_size=st.integers(1, 3),
         order=st.sampled_from(["NCHW", "NHWC"]),
-        in_quantized=st.booleans(),
-        out_quantized=st.booleans(),
         weight_quantized=st.booleans(),
         prepack_weight=st.booleans(),
         share_col_buffer=st.booleans(),
@@ -52,8 +50,6 @@ class DNNLowPOpConvTest(hu.HypothesisTestCase):
         output_channels_per_group,
         batch_size,
         order,
-        in_quantized,
-        out_quantized,
         weight_quantized,
         prepack_weight,
         share_col_buffer,
@@ -94,8 +90,8 @@ class DNNLowPOpConvTest(hu.HypothesisTestCase):
             init_net = core.Net("test_init_net")
             net = core.Net("test_net")
 
-            do_quantize = "DNNLOWP" in engine and in_quantized
-            do_dequantize = "DNNLOWP" in engine and out_quantized
+            do_quantize = "DNNLOWP" in engine
+            do_dequantize = "DNNLOWP" in engine
             # If output scale/zp aren't set, it gets computed from ref fp32 op
             # in DNNLOWP, which isn't possible when we quantize input weights.
             # Make sure at least one output is collected to compute output
@@ -159,7 +155,6 @@ class DNNLowPOpConvTest(hu.HypothesisTestCase):
                 dilation=dilation,
                 pad=pad,
                 order=order,
-                dequantize_output=not do_dequantize,
                 shared_buffer=(1 if share_col_buffer else 0),
                 preserve_activation_sparsity=preserve_activation_sparsity,
                 preserve_weight_sparsity=preserve_weight_sparsity,
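
With in_quantized/out_quantized gone, the DNNLOWP tests above always insert explicit Quantize and Dequantize operators around the conv, which is also the wiring user nets need now that fp32 in/out is disallowed. A minimal sketch of that pattern, with made-up blob names and shapes (the operator types and engine string are the ones these tests use):

    from caffe2.python import core

    net = core.Net("dnnlowp_conv_example")
    quantize = core.CreateOperator("Quantize", ["X"], ["X_q"], engine="DNNLOWP")
    conv = core.CreateOperator(
        "Conv",
        ["X_q", "W", "b"],
        ["Y_q"],
        kernel=3,
        stride=1,
        pad=1,
        order="NHWC",
        engine="DNNLOWP",
    )
    dequantize = core.CreateOperator("Dequantize", ["Y_q"], ["Y"], engine="DNNLOWP")
    net.Proto().op.extend([quantize, conv, dequantize])
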
index f969a0a..44f7aad 100644 (file)
@@ -31,8 +31,6 @@ class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
         output_channels_per_group=st.integers(2, 16),
         batch_size=st.integers(1, 3),
         order=st.sampled_from(["NCHW", "NHWC"]),
-        in_quantized=st.booleans(),
-        out_quantized=st.booleans(),
         share_col_buffer=st.booleans(),
         preserve_activation_sparsity=st.booleans(),
         preserve_weight_sparsity=st.booleans(),
@@ -50,8 +48,6 @@ class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
         output_channels_per_group,
         batch_size,
         order,
-        in_quantized,
-        out_quantized,
         share_col_buffer,
         preserve_activation_sparsity,
         preserve_weight_sparsity,
@@ -124,8 +120,8 @@ class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
         for op_type, engine in op_engine_list:
             net = core.Net("test_net")
 
-            do_quantize = "DNNLOWP" in engine and in_quantized
-            do_dequantize = "DNNLOWP" in engine and out_quantized
+            do_quantize = "DNNLOWP" in engine
+            do_dequantize = "DNNLOWP" in engine
 
             if do_quantize:
                 quantize = core.CreateOperator(
@@ -147,7 +143,6 @@ class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
                 dilation=dilation,
                 pad=pad,
                 order=order,
-                dequantize_output=not do_dequantize,
                 shared_buffer=(1 if share_col_buffer else 0),
                 preserve_activation_sparsity=preserve_activation_sparsity,
                 preserve_weight_sparsity=preserve_weight_sparsity,
@@ -190,8 +185,6 @@ class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
         output_channels_per_group=st.integers(2, 16),
         batch_size=st.integers(1, 3),
         order=st.sampled_from(["NHWC"]),
-        in_quantized=st.booleans(),
-        out_quantized=st.booleans(),
         prepack_weight=st.booleans(),
         nbits_in_non_outlier=st.sampled_from((0, 1, 6, 8)),
         share_col_buffer=st.booleans(),
@@ -209,8 +202,6 @@ class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
         output_channels_per_group,
         batch_size,
         order,
-        in_quantized,
-        out_quantized,
         prepack_weight,
         nbits_in_non_outlier,
         share_col_buffer,
@@ -269,8 +260,8 @@ class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
             init_net = core.Net("test_init_net")
             net = core.Net("test_net")
 
-            do_quantize = "DNNLOWP" in engine and in_quantized
-            do_dequantize = "DNNLOWP" in engine and out_quantized
+            do_quantize = "DNNLOWP" in engine
+            do_dequantize = "DNNLOWP" in engine
             do_prepack_weight = "DNNLOWP" in engine and prepack_weight
 
             if do_quantize:
@@ -309,7 +300,6 @@ class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
                 dilation=dilation,
                 pad=pad,
                 order=order,
-                dequantize_output=not do_dequantize,
                 nbits_in_non_outlier=nbits_in_non_outlier,
                 shared_buffer=(1 if share_col_buffer else 0),
                 engine=engine,
index e3ed280..f582d9e 100644 (file)
@@ -28,8 +28,6 @@ class GroupWiseDNNLowPOpConvTest(hu.HypothesisTestCase):
         output_channels_per_group=st.integers(2, 16),
         batch_size=st.integers(1, 3),
         order=st.sampled_from(["NCHW", "NHWC"]),
-        in_quantized=st.booleans(),
-        out_quantized=st.booleans(),
         prepack_weight=st.booleans(),
         preserve_activation_sparsity=st.booleans(),
         preserve_weight_sparsity=st.booleans(),
@@ -47,8 +45,6 @@ class GroupWiseDNNLowPOpConvTest(hu.HypothesisTestCase):
         output_channels_per_group,
         batch_size,
         order,
-        in_quantized,
-        out_quantized,
         prepack_weight,
         preserve_activation_sparsity,
         preserve_weight_sparsity,
@@ -88,8 +84,8 @@ class GroupWiseDNNLowPOpConvTest(hu.HypothesisTestCase):
             init_net = core.Net("test_init_net")
             net = core.Net("test_net")
 
-            do_quantize = "DNNLOWP" in engine and in_quantized
-            do_dequantize = "DNNLOWP" in engine and out_quantized
+            do_quantize = "DNNLOWP" in engine
+            do_dequantize = "DNNLOWP" in engine
             do_prepack_weight = engine == "DNNLOWP" and prepack_weight
 
             if do_quantize:
@@ -135,7 +131,6 @@ class GroupWiseDNNLowPOpConvTest(hu.HypothesisTestCase):
                 dilation=dilation,
                 pad=pad,
                 order=order,
-                dequantize_output=not do_dequantize,
                 preserve_activation_sparsity=preserve_activation_sparsity,
                 preserve_weight_sparsity=preserve_weight_sparsity,
                 engine=engine,
index 1b511af..10ee04a 100644 (file)
@@ -55,20 +55,11 @@ class ConvPoolDNNLowPOpBase : public ConvPoolOpBase<CPUContext> {
   }
 
   TensorCPU* OutputTensorCPU_(int idx) {
-    if (dequantize_output_) {
-      return Output(idx);
-    } else {
-      return &Outputs()[idx]->template GetMutable<int8::Int8TensorCPU>()->t;
-    }
+    return &Outputs()[idx]->template GetMutable<int8::Int8TensorCPU>()->t;
   }
 
   T* GetQuantizedOutputData_() {
-    if (dequantize_output_) {
-      out_temp_.resize(Output(0)->size());
-      return out_temp_.data();
-    } else {
-      return OutputTensorCPU_(0)->template mutable_data<T>();
-    }
+    return OutputTensorCPU_(0)->template mutable_data<T>();
   }
 
   void MeasureQuantizationError_() {
@@ -104,26 +95,23 @@ class ConvPoolDNNLowPOpBase : public ConvPoolOpBase<CPUContext> {
   }
 
   void RunOnDeviceEpilogue_() {
-    if (dequantize_output_) {
-      fbgemm::Dequantize<T>(
-          out_temp_.data(),
-          OutputTensorCPU_(0)->template mutable_data<float>(),
-          OutputTensorCPU_(0)->size(),
-          out_qparams_);
-    } else {
-      dnnlowp::PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
-    }
+    dnnlowp::PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
 
     MeasureQuantizationError_();
   }
 
   void ParseDNNLowPOperatorArguments_() {
     if (!arguments_parsed_) {
+      bool dequantize_output;
       dnnlowp::ParseDNNLowPOperatorArguments(
           this,
-          &dequantize_output_,
+          &dequantize_output,
           &measure_quantization_error_,
           &followed_by_);
+      CAFFE_ENFORCE_EQ(
+          dequantize_output,
+          false,
+          "Conv DNNLOWP operators don't support dequantize_output");
       arguments_parsed_ = true;
     }
   }
@@ -183,7 +171,7 @@ class ConvPoolDNNLowPOpBase : public ConvPoolOpBase<CPUContext> {
     f(Y_int32);
   }
 
-  bool dequantize_output_{false}, measure_quantization_error_{false};
+  bool measure_quantization_error_{false};
   std::string followed_by_;
 
   std::vector<dnnlowp::TensorQuantizationParams> in_qparams_;
@@ -210,7 +198,6 @@ class ConvPoolDNNLowPOpBase : public ConvPoolOpBase<CPUContext> {
   /* using override */ using BaseType::MeasureQuantizationError_;          \
   /* using override */ using BaseType::OutputTensorCPU_;                   \
   /* using override */ using BaseType::RunOnDeviceEpilogue_;               \
-  /* using override */ using BaseType::dequantize_output_;                 \
   /* using override */ using BaseType::followed_by_;                       \
   /* using override */ using BaseType::in_qparams_;                        \
   /* using override */ using BaseType::measure_quantization_error_;        \
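
Since OutputTensorCPU_ now always returns the tensor inside an int8::Int8TensorCPU, callers that previously relied on dequantize_output to get fp32 straight out of the op either add a Dequantize operator (as in the sketch above) or dequantize by hand. An illustrative snippet, assuming a Caffe2 build where workspace.FetchInt8Blob is available and a blob named "Y_q" holds the op's Int8 output:

    import numpy as np
    from caffe2.python import workspace

    # FetchInt8Blob returns the raw uint8 data along with the scale and zero
    # point propagated by PropagateOutputTensorQuantizationParams.
    y_q = workspace.FetchInt8Blob("Y_q")
    y_fp32 = y_q.scale * (np.asarray(y_q.data, dtype=np.float32) - y_q.zero_point)
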
index 72d4005..5690675 100644 (file)
@@ -45,7 +45,12 @@ namespace caffe2 {
  *        C2 operators with DNNLOWP engine have the following arguments:
  *        - dequantize_output (default=false): when true, output is dequantized
  *          as fp32. Useful when we're only quantizing individual operators
- *          rather than doing end-to-end quantization.
+ *          rather than doing end-to-end quantization. As an exception, Conv
+ *          operators don't support the dequantize_output option because
+ *          supporting it complicates the implementation significantly, and a
+ *          separate Dequantize operator adds little overhead since Conv ops
+ *          are usually used in deep networks where regions of quantization
+ *          are long chains.
  *        - followed_by (default=null): can be relu, sigmoid, or tanh. When
  *          specified, the current operator is only followed by relu, sigmoid,
  *          or tanh, and this fact can be used for more accurate output