Revert D14778810: [caffe2/int8] fix bug when falling back to acc32 when weight is...
author    Junjie Bai <jbai@fb.com>
Fri, 5 Apr 2019 20:56:34 +0000 (13:56 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 5 Apr 2019 21:01:33 +0000 (14:01 -0700)
Differential Revision: D14778810

Original commit changeset: d49a8c4b7c81

fbshipit-source-id: 15568b084848de74437582548bec42aadc74080d

caffe2/quantization/server/conv_dnnlowp_acc16_op.cc
caffe2/quantization/server/conv_dnnlowp_acc16_op_test.py
caffe2/quantization/server/conv_groupwise_dnnlowp_acc16_op_test.py
caffe2/quantization/server/fbgemm_pack_op.cc

diff --git a/caffe2/quantization/server/conv_dnnlowp_acc16_op.cc b/caffe2/quantization/server/conv_dnnlowp_acc16_op.cc
index e074b8b..454be17 100644
--- a/caffe2/quantization/server/conv_dnnlowp_acc16_op.cc
+++ b/caffe2/quantization/server/conv_dnnlowp_acc16_op.cc
@@ -17,6 +17,7 @@
 C10_DECLARE_int32(caffe2_dnnlowp_nbits_in_non_outlier);
 C10_DECLARE_int32(caffe2_dnnlowp_copy_to_32bit_frequency);
 C10_DECLARE_bool(caffe2_dnnlowp_shared_int32_buffer);
+
 // Thresholds to fallback to 32-bit accumulation when 16-bit accumulation
 // doesn't provide performance benefits.
 C10_DEFINE_double(
@@ -61,8 +62,35 @@ ConvDNNLowPAcc16Op<ReluFused>::ConvDNNLowPAcc16Op(
 template <bool ReluFused>
 bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
   if (fallback_to_32_bit_accumulation_) {
-    // Short cut if we already know we are falling back to acc32
-    return BaseType::GetQuantizationParameters_();
+    return true;
+  }
+
+  if (!BaseType::GetQuantizationParameters_()) {
+    return false;
+  }
+
+  if (!Wq_acc16_packed_ &&
+      this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER)) {
+    CAFFE_ENFORCE_EQ(
+        this->order_,
+        StorageOrder::NHWC,
+        "Pre-packed weight only works with NHWC layout");
+    // If the input is already packed
+    const auto& packed_filter =
+        this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
+    Wq_outlier_ = packed_filter.W_outlier;
+    Wq_acc16_packed_ = packed_filter.W_acc16;
+
+    if (nbits_in_non_outlier_ != packed_filter.nbits_in_non_outlier) {
+      LOG(WARNING)
+          << "nbits_in_non_outlier in packed weight "
+          << packed_filter.nbits_in_non_outlier
+          << " doesn't match with nbits_in_non_outlier specified in operator "
+          << nbits_in_non_outlier_;
+    }
+
+    first_invocation_ = false;
+    return true;
   }
 
   int kernel_dim = this->KernelDim_();
@@ -70,17 +98,7 @@ bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
   int num_out_channels = filter.dim32(0);
 
   // Check if we should fallback to 32-bit accumulation
-  // We should do this before GetQuantizationParameters_ to make sure
-  // GetQuantizationParameters_ initialize things like Wq_packed_ for acc32
-  // properly.
-
-  // We can't fallback if layout is not NHWC or
-  // if weight is prepacked and the prepacked weight doesn't have acc32.
-  bool can_fallback_to_32_bit_accumulation =
-      this->order_ == StorageOrder::NHWC &&
-      (!this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER) ||
-       this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER).W);
-  if (can_fallback_to_32_bit_accumulation) {
+  if (this->order_ == StorageOrder::NHWC) {
     const Tensor& X = InputTensorCPU_(INPUT);
     int N = X.dim32(0);
 
@@ -103,71 +121,31 @@ bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
     }
 
     if (N * output_image_size < FLAGS_caffe2_dnnlowp_acc16_m_threshold) {
-      LOG_FIRST_N(INFO, 10)
-          << "M " << N * output_image_size << " of Conv layer with weight blob "
-          << this->debug_def().input(FILTER) << " is smaller than threshold "
-          << FLAGS_caffe2_dnnlowp_acc16_m_threshold
-          << " . Falling back to acc32";
-      fallback_to_32_bit_accumulation_ = true;
-    }
-    if (!fallback_to_32_bit_accumulation_ &&
-        num_out_channels / group_ < acc16_n_threshold) {
-      LOG_FIRST_N(INFO, 10)
-          << "N " << num_out_channels / group_
-          << " of Conv layer with weight blob "
-          << this->debug_def().input(FILTER) << " is smaller than threshold "
-          << acc16_n_threshold << " . Falling back to acc32";
+      LOG(INFO) << "M " << N * output_image_size
+                << " of Conv layer with weight blob "
+                << this->debug_def().input(1) << " is smaller than threshold "
+                << FLAGS_caffe2_dnnlowp_acc16_m_threshold
+                << " . Falling back to acc32";
       fallback_to_32_bit_accumulation_ = true;
+      return true;
     }
-    if (!fallback_to_32_bit_accumulation_ && kernel_dim < acc16_k_threshold) {
-      LOG_FIRST_N(INFO, 10)
-          << "K " << kernel_dim << " of Conv layer with weight blob "
-          << this->debug_def().input(FILTER) << " is smaller than threshold "
-          << acc16_k_threshold << " . Falling back to acc32";
+    if (num_out_channels / group_ < acc16_n_threshold) {
+      LOG(INFO) << "N " << num_out_channels / group_
+                << " of Conv layer with weight blob "
+                << this->debug_def().input(1) << " is smaller than threshold "
+                << acc16_n_threshold << " . Falling back to acc32";
       fallback_to_32_bit_accumulation_ = true;
+      return true;
     }
-    if (!fallback_to_32_bit_accumulation_ &&
-        this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER) &&
-        !this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER)
-             .W_acc16) {
-      LOG_FIRST_N(INFO, 10)
-          << "Falling back to acc32 because packed weight for acc16 is not "
-             "available";
+    if (kernel_dim < acc16_k_threshold) {
+      LOG(INFO) << "K " << kernel_dim << " of Conv layer with weight blob "
+                << this->debug_def().input(1) << " is smaller than threshold "
+                << acc16_k_threshold << " . Falling back to acc32";
       fallback_to_32_bit_accumulation_ = true;
+      return true;
     }
   }
 
-  if (!BaseType::GetQuantizationParameters_()) {
-    return false;
-  }
-
-  if (fallback_to_32_bit_accumulation_) {
-    return true;
-  }
-
-  if (!Wq_acc16_packed_ &&
-      this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER)) {
-    CAFFE_ENFORCE_EQ(
-        this->order_,
-        StorageOrder::NHWC,
-        "Pre-packed weight only works with NHWC layout");
-    // If the input is already packed
-    const auto& packed_filter =
-        this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
-    Wq_outlier_ = packed_filter.W_outlier;
-    Wq_acc16_packed_ = packed_filter.W_acc16;
-
-    if (nbits_in_non_outlier_ != packed_filter.nbits_in_non_outlier) {
-      LOG_FIRST_N(WARNING, 10)
-          << "nbits_in_non_outlier in packed weight "
-          << packed_filter.nbits_in_non_outlier
-          << " doesn't match with nbits_in_non_outlier specified in operator "
-          << nbits_in_non_outlier_;
-    }
-    first_invocation_ = false;
-    return true;
-  }
-
   // Separate out outliers
   if (!Wq_outlier_ && this->order_ == StorageOrder::NHWC &&
       nbits_in_non_outlier_ < 8) {
@@ -181,24 +159,20 @@ bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
         W_quantized_));
     int outlier_cnt = Wq_outlier_->ColPtr()[num_out_channels];
 
-    LOG_FIRST_N(INFO, 10)
-        << "Proportion of outlier for Conv layer with weight blob "
-        << this->debug_def().input(FILTER) << " is "
-        << static_cast<float>(outlier_cnt) / W_quantized_.size();
-    LOG_FIRST_N(INFO, 10) << "nbits_in_non_outlier " << nbits_in_non_outlier_
-                          << " copy_to_32bit_frequency "
-                          << copy_to_32bit_frequency_;
-
-    if (can_fallback_to_32_bit_accumulation &&
-        static_cast<float>(outlier_cnt) / W_quantized_.size() >
-            FLAGS_caffe2_dnnlowp_acc16_density_threshold) {
-      LOG_FIRST_N(INFO, 10) << "Density of outliers is higher than threshold "
-                            << FLAGS_caffe2_dnnlowp_acc16_density_threshold
-                            << " . Falling back to acc32";
+    LOG(INFO) << "Proportion of outlier for Conv layer with weight blob "
+              << this->debug_def().input(1) << " is "
+              << static_cast<float>(outlier_cnt) / W_quantized_.size();
+    LOG(INFO) << "nbits_in_non_outlier " << nbits_in_non_outlier_
+              << " copy_to_32bit_frequency " << copy_to_32bit_frequency_;
+
+    if (static_cast<float>(outlier_cnt) / W_quantized_.size() >
+        FLAGS_caffe2_dnnlowp_acc16_density_threshold) {
+      LOG(INFO) << "Density of outliers is higher than threshold "
+                << FLAGS_caffe2_dnnlowp_acc16_density_threshold
+                << " . Falling back to acc32";
       fallback_to_32_bit_accumulation_ = true;
       Wq_outlier_.reset();
-      // We need to call GetQuantizationParameters_ again to pack for acc32
-      return BaseType::GetQuantizationParameters_();
+      return true;
     }
   }
 
@@ -219,9 +193,8 @@ bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
         static int log_occurences = 0;
         if (log_occurences < 32) {
           ++log_occurences;
-          LOG_FIRST_N(WARNING, 10)
-              << "Conv with weight " << this->debug_def().input(FILTER)
-              << " falls back to slow path because " << reason;
+          LOG(WARNING) << "Conv with weight " << this->debug_def().input(FILTER)
+                       << " falls back to slow path because " << reason;
         }
       }
     }
@@ -229,8 +202,8 @@ bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
       static int log_occurences = 0;
       if (log_occurences < 32) {
         ++log_occurences;
-        LOG_FIRST_N(WARNING, 10) << "Outlier-aware quantization only supports "
-                                    "NHWC layout";
+        LOG(WARNING) << "Outlier-aware quantization only supports "
+                        "NHWC layout";
       }
     }
     first_invocation_ = false;
@@ -386,7 +359,7 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHW() {
         static int log_occurences = 0;
         if (log_occurences < 32) {
           ++log_occurences;
-          LOG_FIRST_N(WARNING, 10)
+          LOG(WARNING)
               << "Consider using DNNLOWP instead of DNNLOWP_ACC16 engine since "
                  "we're falling back to a slow path because of NCHW layout";
         }
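
Note on the .cc diff above: the hunks restore the early-return form of the acc16 -> acc32 fallback checks in GetQuantizationParameters_(). As a rough summary of what those checks compute, here is a minimal Python sketch; it is not the Caffe2 implementation, the function name is made up, and the threshold defaults are illustrative stand-ins for the caffe2_dnnlowp_acc16_* flags.

    def should_fall_back_to_acc32(
        m,                        # N * output_image_size
        n_per_group,              # num_out_channels / group_
        k,                        # kernel_dim
        outlier_density,          # outlier_cnt / W_quantized.size()
        nbits_in_non_outlier=8,
        has_avx512f=False,
        m_threshold=0,            # --caffe2_dnnlowp_acc16_m_threshold
        n_threshold=0,            # --caffe2_dnnlowp_acc16_n_threshold
        k_threshold=0,            # --caffe2_dnnlowp_acc16_k_threshold
        density_threshold=0.05,   # --caffe2_dnnlowp_acc16_density_threshold
    ):
        # On Skylake (AVX-512), acc16 is not faster when N or K is below 128,
        # so the effective thresholds are raised (see the fbgemm_pack_op.cc diff).
        if has_avx512f:
            n_threshold = max(n_threshold, 128)
            k_threshold = max(k_threshold, 128)
        if m < m_threshold or n_per_group < n_threshold or k < k_threshold:
            return True
        # The density check only applies when outliers are actually extracted.
        return nbits_in_non_outlier < 8 and outlier_density > density_threshold

    # Example: N per group below the Skylake threshold -> fall back to acc32.
    print(should_fall_back_to_acc32(m=1024, n_per_group=64, k=27,
                                    outlier_density=0.01, has_avx512f=True))  # True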
diff --git a/caffe2/quantization/server/conv_dnnlowp_acc16_op_test.py b/caffe2/quantization/server/conv_dnnlowp_acc16_op_test.py
index 1ddf2ce..1da2f31 100644
--- a/caffe2/quantization/server/conv_dnnlowp_acc16_op_test.py
+++ b/caffe2/quantization/server/conv_dnnlowp_acc16_op_test.py
@@ -7,19 +7,15 @@ import hypothesis.strategies as st
 import numpy as np
 from caffe2.python import core, dyndep, utils, workspace
 from caffe2.quantization.server import utils as dnnlowp_utils
-from dnnlowp_test_utils import check_quantized_results_close
+from dnnlowp_test_utils import (
+    check_quantized_results_close,
+    generate_conv_inputs,
+)
 from hypothesis import assume, given
 
 
 dyndep.InitOpsLibrary("//caffe2/caffe2/quantization/server:dnnlowp_ops")
-workspace.GlobalInit(
-    [
-        "caffe2",
-        "--caffe2_omp_num_threads=11",
-        # Increase this threshold to test acc16 with randomly generated data
-        "--caffe2_dnnlowp_acc16_density_threshold=0.9",
-    ]
-)
+workspace.GlobalInit(["caffe2", "--caffe2_omp_num_threads=11"])
 
 
 class DNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
@@ -258,7 +254,9 @@ class DNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
             W_min = -100
             W_max = W_min + 255
         W = (
-            np.random.rand(output_channels, kernel, kernel, input_channels_per_group)
+            np.random.rand(
+                output_channels, kernel, kernel, input_channels_per_group
+            )
             * 4
             - 2
             + W_min
diff --git a/caffe2/quantization/server/conv_groupwise_dnnlowp_acc16_op_test.py b/caffe2/quantization/server/conv_groupwise_dnnlowp_acc16_op_test.py
index d542126..44f7aad 100644
--- a/caffe2/quantization/server/conv_groupwise_dnnlowp_acc16_op_test.py
+++ b/caffe2/quantization/server/conv_groupwise_dnnlowp_acc16_op_test.py
@@ -7,19 +7,15 @@ import hypothesis.strategies as st
 import numpy as np
 from caffe2.python import core, dyndep, utils, workspace
 from caffe2.quantization.server import utils as dnnlowp_utils
-from dnnlowp_test_utils import check_quantized_results_close
+from dnnlowp_test_utils import (
+    check_quantized_results_close,
+    generate_conv_inputs,
+)
 from hypothesis import assume, given
 
 
 dyndep.InitOpsLibrary("//caffe2/caffe2/quantization/server:dnnlowp_ops")
-workspace.GlobalInit(
-    [
-        "caffe2",
-        "--caffe2_omp_num_threads=11",
-        # Increase this threshold to test acc16 with randomly generated data
-        "--caffe2_dnnlowp_acc16_density_threshold=0.9",
-    ]
-)
+workspace.GlobalInit(["caffe2", "--caffe2_omp_num_threads=11"])
 
 
 class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
@@ -228,7 +224,9 @@ class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
         W_min = -100
         W_max = W_min + 255
         W = (
-            np.random.rand(output_channels, kernel, kernel, input_channels_per_group)
+            np.random.rand(
+                output_channels, kernel, kernel, input_channels_per_group
+            )
             * 4
             - 2
             + W_min
@@ -239,7 +237,9 @@ class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
         for g in range(group):
             W[g * output_channels_per_group, 0, 0, 0] = W_min
             W[g * output_channels_per_group + 1, 0, 0, 0] = W_max
-            W[g * output_channels_per_group : (g + 1) * output_channels_per_group,] += g
+            W[
+                g * output_channels_per_group : (g + 1) * output_channels_per_group,
+            ] += g
 
         if order == "NCHW":
             X = utils.NHWC2NCHW(X)
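
Both test diffs above drop the --caffe2_dnnlowp_acc16_density_threshold=0.9 argument (the removed comment notes it was raised so acc16 is still exercised with randomly generated data) and go back to importing generate_conv_inputs. For reference, the removed GlobalInit pattern looks like this; the sketch simply restates those "-" lines:

    from caffe2.python import workspace

    # Raising the outlier-density threshold keeps randomly generated weights on
    # the acc16 path instead of triggering the acc32 fallback.
    workspace.GlobalInit(
        [
            "caffe2",
            "--caffe2_omp_num_threads=11",
            "--caffe2_dnnlowp_acc16_density_threshold=0.9",
        ]
    )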
diff --git a/caffe2/quantization/server/fbgemm_pack_op.cc b/caffe2/quantization/server/fbgemm_pack_op.cc
index 9e98b01..704d4e1 100644
--- a/caffe2/quantization/server/fbgemm_pack_op.cc
+++ b/caffe2/quantization/server/fbgemm_pack_op.cc
@@ -5,9 +5,6 @@
 #include "caffe2_dnnlowp_utils.h"
 
 C10_DECLARE_int32(caffe2_dnnlowp_nbits_in_non_outlier);
-C10_DECLARE_double(caffe2_dnnlowp_acc16_density_threshold);
-C10_DECLARE_int32(caffe2_dnnlowp_acc16_n_threshold);
-C10_DECLARE_int32(caffe2_dnnlowp_acc16_k_threshold);
 
 namespace caffe2 {
 
@@ -425,44 +422,9 @@ bool ConvDNNLowPPackWeightOp::RunOnDevice() {
   ComputeColumnOffsets(
       kernel_dim, M, W_quantized.data(), Y->qparams, *Y->column_offsets);
 
-  // Check if we should fallback to 32-bit accumulation.
-  // This check is only meaningful when engine is DNNLOWP_ACC16.
-  bool fallback_to_32_bit_accumulation = false;
-  if (nbits_in_non_outlier_ == 0) {
-    LOG(INFO) << "nbits_in_non_outlier == 0 means everything is outlier so we "
-                 "fallback to acc32";
-    fallback_to_32_bit_accumulation = true;
-  }
-  // In Skylake, acc16 is not faster when N or K is smaller than 128
-  // FIXME : code duplication with conv_dnnlowp_acc16_op.cc
-  constexpr int SKYLAKE_ACC16_N_THRESHOLD_MIN = 128,
-                SKYLAKE_ACC16_K_THRESHOLD_MIN = 128;
-  int acc16_n_threshold = FLAGS_caffe2_dnnlowp_acc16_n_threshold;
-  if (caffe2::GetCpuId().avx512f() &&
-      acc16_n_threshold < SKYLAKE_ACC16_N_THRESHOLD_MIN) {
-    acc16_n_threshold = SKYLAKE_ACC16_N_THRESHOLD_MIN;
-  }
-  int acc16_k_threshold = FLAGS_caffe2_dnnlowp_acc16_k_threshold;
-  if (caffe2::GetCpuId().avx512f() &&
-      acc16_k_threshold < SKYLAKE_ACC16_K_THRESHOLD_MIN) {
-    acc16_k_threshold = SKYLAKE_ACC16_K_THRESHOLD_MIN;
-  }
-  if (!fallback_to_32_bit_accumulation && M / group_ < acc16_n_threshold) {
-    LOG(INFO) << "N " << M / group_ << " of weight blob "
-              << this->debug_def().input(0) << " is smaller than threshold "
-              << acc16_n_threshold << " . Falling back to acc32";
-    fallback_to_32_bit_accumulation = true;
-  }
-  if (!fallback_to_32_bit_accumulation && kernel_dim < acc16_k_threshold) {
-    LOG(INFO) << "K " << kernel_dim << " of weight blob "
-              << this->debug_def().input(0) << " is smaller than threshold "
-              << acc16_k_threshold << " . Falling back to acc32";
-    fallback_to_32_bit_accumulation = true;
-  }
-
   // When nbits_in_non_outlier == 0, we fall back to acc32
   if (this->debug_def().engine() == "DNNLOWP_ACC16" &&
-      !fallback_to_32_bit_accumulation) {
+      nbits_in_non_outlier_ > 0) {
     if (nbits_in_non_outlier_ < 8) {
       Y->W_outlier.reset(ExtractOutlierMatrix(
           group_, kernel_dim, M, nbits_in_non_outlier_, W_quantized));
@@ -472,66 +434,45 @@ bool ConvDNNLowPPackWeightOp::RunOnDevice() {
                 << this->debug_def().input(0) << " is "
                 << static_cast<float>(outlier_cnt) / W_quantized.size();
       LOG(INFO) << "nbits_in_non_outlier " << nbits_in_non_outlier_;
-
-      if (static_cast<float>(outlier_cnt) / W_quantized.size() >
-          FLAGS_caffe2_dnnlowp_acc16_density_threshold) {
-        LOG(INFO) << "Density of outliers is higher than threshold "
-                  << FLAGS_caffe2_dnnlowp_acc16_density_threshold
-                  << " . Falling back to acc32";
-        fallback_to_32_bit_accumulation = true;
-      }
-    }
-
-    if (!fallback_to_32_bit_accumulation) {
-      Y->nbits_in_non_outlier = nbits_in_non_outlier_;
-      Y->W_acc16.reset(new fbgemm::PackBMatrix<int8_t, int16_t>(
-          fbgemm::matrix_op_t::Transpose,
-          group_ * kernel_dim,
-          M / group_,
-          W_quantized.data(),
-          kernel_dim,
-          nullptr, // pmat
-          group_));
     }
-  }
-
-  if (fallback_to_32_bit_accumulation) {
-    Y->W_acc16.reset();
-    Y->W_outlier.reset();
-  }
 
-  if (this->debug_def().engine() != "DNNLOWP_ACC16" ||
-      fallback_to_32_bit_accumulation) {
-    // acc32
-    if (TakeDepthWise3x3FastPath_()) {
-      Y->W_depthwise_3x3.reset(
-          new fbgemm::Packed3x3ConvMatrix(group_, W_quantized.data()));
-    } else if (TakeDepthWise3x3x3FastPath_()) {
-      Y->W_depthwise_3x3x3.reset(
-          new fbgemm::Packed3x3x3ConvMatrix(group_, W_quantized.data()));
-    } else if (TakeGConvFastPath_()) {
-      fbgemm::conv_param_t<> conv_p(
-          1,
-          group_ * C_per_group,
-          M,
-          {1, 1},
-          group_,
-          {this->kernel_[0], this->kernel_[1]},
-          {this->stride_[0], this->stride_[1]},
-          {this->pads_[0], this->pads_[1], this->pads_[2], this->pads_[3]});
-
-      Y->W_gconv.reset(new fbgemm::PackWeightMatrixForGConv<int8_t>(
-          fbgemm::matrix_op_t::Transpose, conv_p, W_quantized.data()));
-    } else {
-      Y->W.reset(new fbgemm::PackBMatrix<int8_t>(
-          fbgemm::matrix_op_t::Transpose,
-          group_ * kernel_dim,
-          M / group_,
-          W_quantized.data(),
-          kernel_dim,
-          nullptr, // pmat
-          group_));
-    }
+    Y->nbits_in_non_outlier = nbits_in_non_outlier_;
+    Y->W_acc16.reset(new fbgemm::PackBMatrix<int8_t, int16_t>(
+        fbgemm::matrix_op_t::Transpose,
+        group_ * kernel_dim,
+        M / group_,
+        W_quantized.data(),
+        kernel_dim,
+        nullptr, // pmat
+        group_));
+  } else if (TakeDepthWise3x3FastPath_()) {
+    Y->W_depthwise_3x3.reset(
+        new fbgemm::Packed3x3ConvMatrix(group_, W_quantized.data()));
+  } else if (TakeDepthWise3x3x3FastPath_()) {
+    Y->W_depthwise_3x3x3.reset(
+        new fbgemm::Packed3x3x3ConvMatrix(group_, W_quantized.data()));
+  } else if (TakeGConvFastPath_()) {
+    fbgemm::conv_param_t<> conv_p(
+        1,
+        group_ * C_per_group,
+        M,
+        {1, 1},
+        group_,
+        {this->kernel_[0], this->kernel_[1]},
+        {this->stride_[0], this->stride_[1]},
+        {this->pads_[0], this->pads_[1], this->pads_[2], this->pads_[3]});
+
+    Y->W_gconv.reset(new fbgemm::PackWeightMatrixForGConv<int8_t>(
+        fbgemm::matrix_op_t::Transpose, conv_p, W_quantized.data()));
+  } else {
+    Y->W.reset(new fbgemm::PackBMatrix<int8_t>(
+        fbgemm::matrix_op_t::Transpose,
+        group_ * kernel_dim,
+        M / group_,
+        W_quantized.data(),
+        kernel_dim,
+        nullptr, // pmat
+        group_));
   }
 
   if (InputSize() >= 2) {
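
After this revert, ConvDNNLowPPackWeightOp picks a single packing path per weight blob: the acc16 path (plus an outlier matrix when nbits_in_non_outlier < 8) when the engine is DNNLOWP_ACC16 and nbits_in_non_outlier > 0, otherwise one of the depthwise/groupwise fast paths or the generic acc32 PackBMatrix. A minimal Python sketch of that dispatch follows; packed_blobs and the strings it returns are illustrative labels for the fbgemm calls in the hunk above, not a real API.

    def packed_blobs(engine, nbits_in_non_outlier,
                     depthwise_3x3=False, depthwise_3x3x3=False, gconv=False):
        blobs = []
        if engine == "DNNLOWP_ACC16" and nbits_in_non_outlier > 0:
            # acc16 path: optionally split outliers, then pack for int16 accumulation.
            if nbits_in_non_outlier < 8:
                blobs.append("W_outlier (ExtractOutlierMatrix)")
            blobs.append("W_acc16 (PackBMatrix<int8_t, int16_t>)")
        elif depthwise_3x3:
            blobs.append("W_depthwise_3x3 (Packed3x3ConvMatrix)")
        elif depthwise_3x3x3:
            blobs.append("W_depthwise_3x3x3 (Packed3x3x3ConvMatrix)")
        elif gconv:
            blobs.append("W_gconv (PackWeightMatrixForGConv<int8_t>)")
        else:
            blobs.append("W (PackBMatrix<int8_t>)")
        return blobs

    # Example: DNNLOWP_ACC16 with 4-bit non-outliers packs the outlier and acc16 blobs.
    print(packed_blobs("DNNLOWP_ACC16", nbits_in_non_outlier=4))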