use fbgemm gconv in dnnlowp (#16020)
authorJongsoo Park <jongsoo@fb.com>
Tue, 15 Jan 2019 07:59:33 +0000 (23:59 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 15 Jan 2019 08:02:31 +0000 (00:02 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16020

Needs to go over more iterations. For conv, I think we need a high level interface that abstracts out low-level details of which code path will be taken (acc16, outlier-aware, depth-wise, group conv, ...) otherwise the client code will be complex as can be seen from DNNLOWP Conv ops. This will also help us to make interface more stable.

Reviewed By: dskhudia, jianyuh

Differential Revision: D13588996

fbshipit-source-id: 9afce9e441bcaf20437fcc2874fb9d4165a46bcb

caffe2/quantization/server/conv_dnnlowp_op.cc
caffe2/quantization/server/conv_dnnlowp_op.h
caffe2/quantization/server/fbgemm_pack_blob.h
caffe2/quantization/server/fbgemm_pack_op.cc
caffe2/quantization/server/fbgemm_pack_op.h

index 7a18c7a..f4de399 100644 (file)
@@ -80,9 +80,8 @@ ConvDNNLowPOp<T, ReluFused>::RequantizationParams(int group_id) {
 template <typename T, bool ReluFused>
 bool ConvDNNLowPOp<T, ReluFused>::TakeDepthWise3x3FastPath_() {
   const Tensor& X = InputTensorCPU_(INPUT);
-  return StorageOrder::NHWC == ConvPoolOpBase<CPUContext>::order_ &&
-      is_same<T, uint8_t>::value && !Acc16() &&
-      group_ == X.dim32(X.dim() - 1) && group_ % 8 == 0 &&
+  return this->order_ == StorageOrder::NHWC && is_same<T, uint8_t>::value &&
+      !Acc16() && group_ == X.dim32(X.dim() - 1) && group_ % 8 == 0 &&
       this->kernel_.size() == 2 && kernel_h() == 3 && kernel_w() == 3 &&
       stride_h() == stride_w() && (stride_h() == 1 || stride_h() == 2) &&
       dilation_h() == 1 && dilation_w() == 1 && pad_t() == 1 && pad_b() == 1 &&
@@ -94,9 +93,8 @@ bool ConvDNNLowPOp<T, ReluFused>::TakeDepthWise3x3FastPath_() {
 template <typename T, bool ReluFused>
 bool ConvDNNLowPOp<T, ReluFused>::TakeDepthWise3x3x3FastPath_() {
   const Tensor& X = InputTensorCPU_(INPUT);
-  bool ret = StorageOrder::NHWC == ConvPoolOpBase<CPUContext>::order_ &&
-      is_same<T, uint8_t>::value && !Acc16() &&
-      group_ == X.dim32(X.dim() - 1) && group_ % 8 == 0 &&
+  bool ret = this->order_ == StorageOrder::NHWC && is_same<T, uint8_t>::value &&
+      !Acc16() && group_ == X.dim32(X.dim() - 1) && group_ % 8 == 0 &&
       this->kernel_.size() == 3 && this->kernel_[0] == 3 &&
       this->kernel_[1] == 3 && this->kernel_[2] == 3 &&
       this->stride_[0] == this->stride_[1] &&
@@ -111,6 +109,30 @@ bool ConvDNNLowPOp<T, ReluFused>::TakeDepthWise3x3x3FastPath_() {
 }
 
 template <typename T, bool ReluFused>
+bool ConvDNNLowPOp<T, ReluFused>::TakeGConvFastPath_() {
+  const Tensor& X = InputTensorCPU_(INPUT);
+  if (this->order_ != StorageOrder::NHWC || !is_same<T, uint8_t>::value ||
+      !X.template IsType<T>() || this->kernel_.size() != 2) {
+    return false;
+  }
+
+  auto& filter = InputTensorCPU_(FILTER);
+  const int N = X.dim32(0), C = X.dim32(X.dim() - 1);
+  const int M = filter.dim32(0);
+  fbgemm::conv_param_t<> conv_p(
+      N,
+      C,
+      M,
+      {X.dim32(1), X.dim32(2)},
+      group_,
+      {this->kernel_[0], this->kernel_[1]},
+      {this->stride_[0], this->stride_[1]},
+      {this->pads_[0], this->pads_[1], this->pads_[2], this->pads_[3]});
+
+  return fbgemm::fbgemmOptimizedGConv(conv_p);
+}
+
+template <typename T, bool ReluFused>
 int ConvDNNLowPOp<T, ReluFused>::KernelDim_() {
   int kernel_dim;
   const Tensor& X = InputTensorCPU_(INPUT);
@@ -156,7 +178,8 @@ bool ConvDNNLowPOp<T, ReluFused>::IsConvGEMM_() const {
 
 template <typename T, bool ReluFused>
 bool ConvDNNLowPOp<T, ReluFused>::NoIm2ColNHWC_() {
-  if (TakeDepthWise3x3FastPath_() || TakeDepthWise3x3x3FastPath_()) {
+  if (TakeDepthWise3x3FastPath_() || TakeDepthWise3x3x3FastPath_() ||
+      TakeGConvFastPath_()) {
     return true;
   }
 
@@ -267,18 +290,23 @@ void ConvDNNLowPOp<T, ReluFused>::QuantizeWeight_() {
       !Acc16() && is_same<T, uint8_t>::value && GetCpuId().avx2() &&
       !FLAGS_caffe2_dnnlowp_force_slow_path;
 
-  bool depthwise_3x3_fast_path = false, depthwise_3x3x3_fast_path = false;
+  bool depthwise_3x3_fast_path = false, depthwise_3x3x3_fast_path = false,
+       gconv_fast_path = false;
   if (TakeDepthWise3x3FastPath_()) {
     depthwise_3x3_fast_path = true;
     packW = false;
   } else if (TakeDepthWise3x3x3FastPath_()) {
     depthwise_3x3x3_fast_path = true;
     packW = false;
+  } else if (TakeGConvFastPath_()) {
+    gconv_fast_path = true;
+    packW = false;
   }
 
   if ((depthwise_3x3_fast_path && !Wq_depthwise_3x3_packed_) ||
       (depthwise_3x3x3_fast_path && !Wq_depthwise_3x3x3_packed_) ||
-      (packW && !Wq_packed_) || (!packW && W_quantized_.empty())) {
+      (gconv_fast_path && !Wq_gconv_packed_) || (packW && !Wq_packed_) ||
+      (!packW && W_quantized_.empty())) {
     if (this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER)) {
       CAFFE_ENFORCE_EQ(
           ConvPoolOpBase<CPUContext>::order_,
@@ -337,6 +365,30 @@ void ConvDNNLowPOp<T, ReluFused>::QuantizeWeight_() {
         Wq_depthwise_3x3x3_packed_.reset(new fbgemm::Packed3x3x3ConvMatrix(
             group_, reinterpret_cast<const int8_t*>(W_quantized_.data())));
       }
+    } else if (gconv_fast_path) {
+      if (this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER)) {
+        const auto& packed_filter =
+            this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
+        Wq_gconv_packed_ = packed_filter.W_gconv;
+      } else {
+        const Tensor& X = InputTensorCPU_(INPUT);
+        const int N = X.dim32(0), C = X.dim32(X.dim() - 1);
+
+        fbgemm::conv_param_t<> conv_p(
+            N,
+            C,
+            M,
+            {X.dim32(1), X.dim32(2)},
+            group_,
+            {this->kernel_[0], this->kernel_[1]},
+            {this->stride_[0], this->stride_[1]},
+            {this->pads_[0], this->pads_[1], this->pads_[2], this->pads_[3]});
+
+        Wq_gconv_packed_.reset(new fbgemm::PackWeightMatrixForGConv<int8_t>(
+            fbgemm::matrix_op_t::Transpose,
+            conv_p,
+            reinterpret_cast<const int8_t*>(W_quantized_.data())));
+      }
     } else if (packW) {
       if (this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER)) {
         const auto& packed_filter =
@@ -1000,6 +1052,8 @@ void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
         this->debug_def().input(FILTER));
   }
 
+  using namespace fbgemm;
+
   if (TakeDepthWise3x3x3FastPath_()) {
     const T* Xdata = X.template data<T>();
     uint8_t* Y_uint8_data =
@@ -1010,7 +1064,7 @@ void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
 #endif
     {
       if (quantize_groupwise_) {
-        fbgemm::depthwise_3x3x3_per_channel_quantization_pad_1(
+        depthwise_3x3x3_per_channel_quantization_pad_1(
             N,
             X.dim32(1),
             X.dim32(2),
@@ -1032,7 +1086,7 @@ void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
             dnnlowp_get_thread_num(),
             dnnlowp_get_num_threads());
       } else {
-        fbgemm::depthwise_3x3x3_pad_1(
+        depthwise_3x3x3_pad_1(
             N,
             X.dim32(1),
             X.dim32(2),
@@ -1068,7 +1122,7 @@ void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
 #endif
     {
       if (quantize_groupwise_) {
-        fbgemm::depthwise_3x3_per_channel_quantization_pad_1(
+        depthwise_3x3_per_channel_quantization_pad_1(
             N,
             H,
             W,
@@ -1088,7 +1142,7 @@ void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
             dnnlowp_get_thread_num(),
             dnnlowp_get_num_threads());
       } else {
-        fbgemm::depthwise_3x3_pad_1(
+        depthwise_3x3_pad_1(
             N,
             H,
             W,
@@ -1111,10 +1165,88 @@ void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
     } // omp parallel
 
     return;
+  } else if (TakeGConvFastPath_()) {
+    const T* Xdata = X.template data<T>();
+    uint8_t* Y_uint8_data =
+        OutputTensorCPU_(0)->template mutable_data<uint8_t>();
+
+    conv_param_t<> conv_p(
+        N,
+        C,
+        M,
+        {X.dim32(1), X.dim32(2)},
+        group_,
+        {this->kernel_[0], this->kernel_[1]},
+        {this->stride_[0], this->stride_[1]},
+        {this->pads_[0], this->pads_[1], this->pads_[2], this->pads_[3]});
+
+    int row_offset_size_per_thread = rowOffsetBufferSizeGConv(conv_p);
+    row_offsets_.resize(dnnlowp_get_max_threads() * row_offset_size_per_thread);
+
+#ifdef _OPENMP
+// TODO: add parallelization once fbgemmGroupwiseConv supports multi-threading
+// #pragma omp parallel
+#endif
+    {
+      int tid = 0; // dnnlowp_get_thread_num();
+      int nthreads = 1; // dnnlowp_get_num_threads();
+
+      DoNothing<> doNothingObj{};
+      if (quantize_groupwise_) {
+        ReQuantizeOutput<false, QuantizationGranularity::GROUP> reqObj(
+            doNothingObj,
+            requantization_multipliers_.data(),
+            out_qparams_.zero_point,
+            in_qparams_[INPUT].zero_point,
+            filter_zero_points_.data(),
+            row_offsets_.data() + tid * row_offset_size_per_thread,
+            column_offsets_->data(),
+            InputSize() == 3 ? b_quantized_data_ : nullptr,
+            conv_p.OC,
+            conv_p.G);
+
+        fbgemmGroupwiseConv(
+            conv_p,
+            reinterpret_cast<const uint8_t*>(Xdata),
+            in_qparams_[INPUT].zero_point,
+            row_offsets_.data() + tid * row_offset_size_per_thread,
+            *Wq_gconv_packed_,
+            Y_uint8_data,
+            Y_int32->data(),
+            reqObj,
+            tid,
+            nthreads);
+      } else {
+        ReQuantizeOutput<false, QuantizationGranularity::TENSOR> reqObj(
+            doNothingObj,
+            requantization_multipliers_.data(),
+            out_qparams_.zero_point,
+            in_qparams_[INPUT].zero_point,
+            filter_zero_points_.data(),
+            row_offsets_.data() + tid * row_offset_size_per_thread,
+            column_offsets_->data(),
+            InputSize() == 3 ? b_quantized_data_ : nullptr,
+            conv_p.OC,
+            conv_p.G);
+
+        fbgemmGroupwiseConv(
+            conv_p,
+            reinterpret_cast<const uint8_t*>(Xdata),
+            in_qparams_[INPUT].zero_point,
+            row_offsets_.data() + tid * row_offset_size_per_thread,
+            *Wq_gconv_packed_,
+            Y_uint8_data,
+            Y_int32->data(),
+            reqObj,
+            tid,
+            nthreads);
+      }
+    } // omp parallel
+
+    return;
   }
 
   // Normal path for non-special (e.g., no depth-wise) convolutions.
-  using namespace fbgemm;
   int row_offset_size_per_thread = -1;
   int x_pack_buf_size_per_thread = -1;
   bool fuse_im2col =
@@ -1351,7 +1483,8 @@ bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNHWC() {
     }
 #endif
 
-    if (Wq_packed_ || Wq_depthwise_3x3_packed_ || Wq_depthwise_3x3x3_packed_) {
+    if (Wq_packed_ || Wq_depthwise_3x3_packed_ || Wq_depthwise_3x3x3_packed_ ||
+        Wq_gconv_packed_) {
       // In fast path with fbgemm except when
       // rescaling quantized numbers should've been already done.
       PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
index dfecc17..c8b6574 100644 (file)
@@ -102,6 +102,7 @@ class ConvDNNLowPOp : public ConvPoolDNNLowPOpBase<T, ConvFp32Op> {
 
   bool TakeDepthWise3x3FastPath_();
   bool TakeDepthWise3x3x3FastPath_();
+  bool TakeGConvFastPath_();
 
   template <typename PackAMatrix, fbgemm::QuantizationGranularity Q_GRAN>
   void DispatchFBGEMM_(
@@ -120,6 +121,9 @@ class ConvDNNLowPOp : public ConvPoolDNNLowPOpBase<T, ConvFp32Op> {
   std::shared_ptr<fbgemm::Packed3x3ConvMatrix> Wq_depthwise_3x3_packed_;
   // For depthwise 3x3x3 conv
   std::shared_ptr<fbgemm::Packed3x3x3ConvMatrix> Wq_depthwise_3x3x3_packed_;
+  // For small gconv
+  std::shared_ptr<fbgemm::PackWeightMatrixForGConv<std::int8_t>>
+      Wq_gconv_packed_;
 
   // pre-computed biases and offsets
   std::shared_ptr<std::vector<std::int32_t>> b_quantized_;
index 56396cc..52b221a 100644 (file)
@@ -39,6 +39,7 @@ struct Int8ConvDNNLowPPackedWeightBlob : public Int8FCDNNLowPPackedWeightBlob {
   // Only for 32-bit accumulation
   std::shared_ptr<fbgemm::Packed3x3ConvMatrix> W_depthwise_3x3;
   std::shared_ptr<fbgemm::Packed3x3x3ConvMatrix> W_depthwise_3x3x3;
+  std::shared_ptr<fbgemm::PackWeightMatrixForGConv<std::int8_t>> W_gconv;
 };
 
 } // namespace caffe2
index f697029..e0852ba 100644 (file)
@@ -349,6 +349,27 @@ bool ConvDNNLowPPackWeightOp::TakeDepthWise3x3x3FastPath_() {
   return ret;
 }
 
+bool ConvDNNLowPPackWeightOp::TakeGConvFastPath_() {
+  if (this->debug_def().engine() == "DNNLOWP_ACC16" ||
+      this->kernel_.size() != 2) {
+    return false;
+  }
+
+  auto& filter = InputTensorCPU_(FILTER);
+  const int M = filter.dim32(0), C = filter.dim32(filter.dim() - 1) * group_;
+  fbgemm::conv_param_t<> conv_p(
+      1,
+      C,
+      M,
+      {1, 1},
+      group_,
+      {this->kernel_[0], this->kernel_[1]},
+      {this->stride_[0], this->stride_[1]},
+      {this->pads_[0], this->pads_[1], this->pads_[2], this->pads_[3]});
+
+  return fbgemm::fbgemmOptimizedGConv(conv_p);
+}
+
 bool ConvDNNLowPPackWeightOp::RunOnDevice() {
   const auto& filter = InputTensorCPU_(FILTER);
 
@@ -426,6 +447,19 @@ bool ConvDNNLowPPackWeightOp::RunOnDevice() {
   } else if (TakeDepthWise3x3x3FastPath_()) {
     Y->W_depthwise_3x3x3.reset(
         new fbgemm::Packed3x3x3ConvMatrix(group_, W_quantized.data()));
+  } else if (TakeGConvFastPath_()) {
+    fbgemm::conv_param_t<> conv_p(
+        1,
+        group_ * C_per_group,
+        M,
+        {1, 1},
+        group_,
+        {this->kernel_[0], this->kernel_[1]},
+        {this->stride_[0], this->stride_[1]},
+        {this->pads_[0], this->pads_[1], this->pads_[2], this->pads_[3]});
+
+    Y->W_gconv.reset(new fbgemm::PackWeightMatrixForGConv<int8_t>(
+        fbgemm::matrix_op_t::Transpose, conv_p, W_quantized.data()));
   } else {
     Y->W.reset(new fbgemm::PackBMatrix<int8_t>(
         fbgemm::matrix_op_t::Transpose,
index af6203b..8d28711 100644 (file)
@@ -53,6 +53,7 @@ class ConvDNNLowPPackWeightOp final
  private:
   bool TakeDepthWise3x3FastPath_();
   bool TakeDepthWise3x3x3FastPath_();
+  bool TakeGConvFastPath_();
 
   bool quantize_groupwise_;
   int nbits_in_non_outlier_; // only for DNNLOWP_ACC16