From 28990f34d989da98e2cce3e38d7842049735fb97 Mon Sep 17 00:00:00 2001 From: Summer Deng Date: Fri, 5 Apr 2019 12:44:09 -0700 Subject: [PATCH] fix bug when falling back to acc32 when weight is prepacked (#18881) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18881 Pull Request resolved: https://github.com/pytorch/pytorch/pull/18878 When the weight is prepacked and it doesn't contain a prepacked weight for acc32, we shouldn't fallback to acc32. TODO: add unit tests with better coverage Reviewed By: feiyu1990 Differential Revision: D14778810 fbshipit-source-id: d49a8c4b7c815ab29b77feb53ee730ad63780488 --- .../quantization/server/conv_dnnlowp_acc16_op.cc | 157 ++++++++++++--------- .../server/conv_dnnlowp_acc16_op_test.py | 18 +-- .../server/conv_groupwise_dnnlowp_acc16_op_test.py | 22 +-- caffe2/quantization/server/fbgemm_pack_op.cc | 135 +++++++++++++----- 4 files changed, 210 insertions(+), 122 deletions(-) diff --git a/caffe2/quantization/server/conv_dnnlowp_acc16_op.cc b/caffe2/quantization/server/conv_dnnlowp_acc16_op.cc index 454be17..e074b8b 100644 --- a/caffe2/quantization/server/conv_dnnlowp_acc16_op.cc +++ b/caffe2/quantization/server/conv_dnnlowp_acc16_op.cc @@ -17,7 +17,6 @@ C10_DECLARE_int32(caffe2_dnnlowp_nbits_in_non_outlier); C10_DECLARE_int32(caffe2_dnnlowp_copy_to_32bit_frequency); C10_DECLARE_bool(caffe2_dnnlowp_shared_int32_buffer); - // Thresholds to fallback to 32-bit accumulation when 16-bit accumulation // doesn't provide performance benefits. C10_DEFINE_double( @@ -62,35 +61,8 @@ ConvDNNLowPAcc16Op::ConvDNNLowPAcc16Op( template bool ConvDNNLowPAcc16Op::GetQuantizationParameters_() { if (fallback_to_32_bit_accumulation_) { - return true; - } - - if (!BaseType::GetQuantizationParameters_()) { - return false; - } - - if (!Wq_acc16_packed_ && - this->template InputIsType(FILTER)) { - CAFFE_ENFORCE_EQ( - this->order_, - StorageOrder::NHWC, - "Pre-packed weight only works with NHWC layout"); - // If the input is already packed - const auto& packed_filter = - this->template Input(FILTER); - Wq_outlier_ = packed_filter.W_outlier; - Wq_acc16_packed_ = packed_filter.W_acc16; - - if (nbits_in_non_outlier_ != packed_filter.nbits_in_non_outlier) { - LOG(WARNING) - << "nbits_in_non_outlier in packed weight " - << packed_filter.nbits_in_non_outlier - << " doesn't match with nbits_in_non_outlier specified in operator " - << nbits_in_non_outlier_; - } - - first_invocation_ = false; - return true; + // Short cut if we already know we are falling back to acc32 + return BaseType::GetQuantizationParameters_(); } int kernel_dim = this->KernelDim_(); @@ -98,7 +70,17 @@ bool ConvDNNLowPAcc16Op::GetQuantizationParameters_() { int num_out_channels = filter.dim32(0); // Check if we should fallback to 32-bit accumulation - if (this->order_ == StorageOrder::NHWC) { + // We should do this before GetQuantizationParameters_ to make sure + // GetQuantizationParameters_ initialize things like Wq_packed_ for acc32 + // properly. + + // We can't fallback if layout is not NHWC or + // if weight is prepacked and the prepacked weight doesn't have acc32. + bool can_fallback_to_32_bit_accumulation = + this->order_ == StorageOrder::NHWC && + (!this->template InputIsType(FILTER) || + this->template Input(FILTER).W); + if (can_fallback_to_32_bit_accumulation) { const Tensor& X = InputTensorCPU_(INPUT); int N = X.dim32(0); @@ -121,31 +103,71 @@ bool ConvDNNLowPAcc16Op::GetQuantizationParameters_() { } if (N * output_image_size < FLAGS_caffe2_dnnlowp_acc16_m_threshold) { - LOG(INFO) << "M " << N * output_image_size - << " of Conv layer with weight blob " - << this->debug_def().input(1) << " is smaller than threshold " - << FLAGS_caffe2_dnnlowp_acc16_m_threshold - << " . Falling back to acc32"; + LOG_FIRST_N(INFO, 10) + << "M " << N * output_image_size << " of Conv layer with weight blob " + << this->debug_def().input(FILTER) << " is smaller than threshold " + << FLAGS_caffe2_dnnlowp_acc16_m_threshold + << " . Falling back to acc32"; + fallback_to_32_bit_accumulation_ = true; + } + if (!fallback_to_32_bit_accumulation_ && + num_out_channels / group_ < acc16_n_threshold) { + LOG_FIRST_N(INFO, 10) + << "N " << num_out_channels / group_ + << " of Conv layer with weight blob " + << this->debug_def().input(FILTER) << " is smaller than threshold " + << acc16_n_threshold << " . Falling back to acc32"; fallback_to_32_bit_accumulation_ = true; - return true; } - if (num_out_channels / group_ < acc16_n_threshold) { - LOG(INFO) << "N " << num_out_channels / group_ - << " of Conv layer with weight blob " - << this->debug_def().input(1) << " is smaller than threshold " - << acc16_n_threshold << " . Falling back to acc32"; + if (!fallback_to_32_bit_accumulation_ && kernel_dim < acc16_k_threshold) { + LOG_FIRST_N(INFO, 10) + << "K " << kernel_dim << " of Conv layer with weight blob " + << this->debug_def().input(FILTER) << " is smaller than threshold " + << acc16_k_threshold << " . Falling back to acc32"; fallback_to_32_bit_accumulation_ = true; - return true; } - if (kernel_dim < acc16_k_threshold) { - LOG(INFO) << "K " << kernel_dim << " of Conv layer with weight blob " - << this->debug_def().input(1) << " is smaller than threshold " - << acc16_k_threshold << " . Falling back to acc32"; + if (!fallback_to_32_bit_accumulation_ && + this->template InputIsType(FILTER) && + !this->template Input(FILTER) + .W_acc16) { + LOG_FIRST_N(INFO, 10) + << "Falling back to acc32 because packed weight for acc16 is not " + "available"; fallback_to_32_bit_accumulation_ = true; - return true; } } + if (!BaseType::GetQuantizationParameters_()) { + return false; + } + + if (fallback_to_32_bit_accumulation_) { + return true; + } + + if (!Wq_acc16_packed_ && + this->template InputIsType(FILTER)) { + CAFFE_ENFORCE_EQ( + this->order_, + StorageOrder::NHWC, + "Pre-packed weight only works with NHWC layout"); + // If the input is already packed + const auto& packed_filter = + this->template Input(FILTER); + Wq_outlier_ = packed_filter.W_outlier; + Wq_acc16_packed_ = packed_filter.W_acc16; + + if (nbits_in_non_outlier_ != packed_filter.nbits_in_non_outlier) { + LOG_FIRST_N(WARNING, 10) + << "nbits_in_non_outlier in packed weight " + << packed_filter.nbits_in_non_outlier + << " doesn't match with nbits_in_non_outlier specified in operator " + << nbits_in_non_outlier_; + } + first_invocation_ = false; + return true; + } + // Separate out outliers if (!Wq_outlier_ && this->order_ == StorageOrder::NHWC && nbits_in_non_outlier_ < 8) { @@ -159,20 +181,24 @@ bool ConvDNNLowPAcc16Op::GetQuantizationParameters_() { W_quantized_)); int outlier_cnt = Wq_outlier_->ColPtr()[num_out_channels]; - LOG(INFO) << "Proportion of outlier for Conv layer with weight blob " - << this->debug_def().input(1) << " is " - << static_cast(outlier_cnt) / W_quantized_.size(); - LOG(INFO) << "nbits_in_non_outlier " << nbits_in_non_outlier_ - << " copy_to_32bit_frequency " << copy_to_32bit_frequency_; - - if (static_cast(outlier_cnt) / W_quantized_.size() > - FLAGS_caffe2_dnnlowp_acc16_density_threshold) { - LOG(INFO) << "Density of outliers is higher than threshold " - << FLAGS_caffe2_dnnlowp_acc16_density_threshold - << " . Falling back to acc32"; + LOG_FIRST_N(INFO, 10) + << "Proportion of outlier for Conv layer with weight blob " + << this->debug_def().input(FILTER) << " is " + << static_cast(outlier_cnt) / W_quantized_.size(); + LOG_FIRST_N(INFO, 10) << "nbits_in_non_outlier " << nbits_in_non_outlier_ + << " copy_to_32bit_frequency " + << copy_to_32bit_frequency_; + + if (can_fallback_to_32_bit_accumulation && + static_cast(outlier_cnt) / W_quantized_.size() > + FLAGS_caffe2_dnnlowp_acc16_density_threshold) { + LOG_FIRST_N(INFO, 10) << "Density of outliers is higher than threshold " + << FLAGS_caffe2_dnnlowp_acc16_density_threshold + << " . Falling back to acc32"; fallback_to_32_bit_accumulation_ = true; Wq_outlier_.reset(); - return true; + // We need to call GetQuantizationParameters_ again to pack for acc32 + return BaseType::GetQuantizationParameters_(); } } @@ -193,8 +219,9 @@ bool ConvDNNLowPAcc16Op::GetQuantizationParameters_() { static int log_occurences = 0; if (log_occurences < 32) { ++log_occurences; - LOG(WARNING) << "Conv with weight " << this->debug_def().input(FILTER) - << " falls back to slow path because " << reason; + LOG_FIRST_N(WARNING, 10) + << "Conv with weight " << this->debug_def().input(FILTER) + << " falls back to slow path because " << reason; } } } @@ -202,8 +229,8 @@ bool ConvDNNLowPAcc16Op::GetQuantizationParameters_() { static int log_occurences = 0; if (log_occurences < 32) { ++log_occurences; - LOG(WARNING) << "Outlier-aware quantization only supports " - "NHWC layout"; + LOG_FIRST_N(WARNING, 10) << "Outlier-aware quantization only supports " + "NHWC layout"; } } first_invocation_ = false; @@ -359,7 +386,7 @@ bool ConvDNNLowPAcc16Op::RunOnDeviceWithOrderNCHW() { static int log_occurences = 0; if (log_occurences < 32) { ++log_occurences; - LOG(WARNING) + LOG_FIRST_N(WARNING, 10) << "Consider using DNNLOWP instead of DNNLOWP_ACC16 engine since " "we're falling back to a slow path because of NCHW layout"; } diff --git a/caffe2/quantization/server/conv_dnnlowp_acc16_op_test.py b/caffe2/quantization/server/conv_dnnlowp_acc16_op_test.py index 1da2f31..1ddf2ce 100644 --- a/caffe2/quantization/server/conv_dnnlowp_acc16_op_test.py +++ b/caffe2/quantization/server/conv_dnnlowp_acc16_op_test.py @@ -7,15 +7,19 @@ import hypothesis.strategies as st import numpy as np from caffe2.python import core, dyndep, utils, workspace from caffe2.quantization.server import utils as dnnlowp_utils -from dnnlowp_test_utils import ( - check_quantized_results_close, - generate_conv_inputs, -) +from dnnlowp_test_utils import check_quantized_results_close from hypothesis import assume, given dyndep.InitOpsLibrary("//caffe2/caffe2/quantization/server:dnnlowp_ops") -workspace.GlobalInit(["caffe2", "--caffe2_omp_num_threads=11"]) +workspace.GlobalInit( + [ + "caffe2", + "--caffe2_omp_num_threads=11", + # Increase this threshold to test acc16 with randomly generated data + "--caffe2_dnnlowp_acc16_density_threshold=0.9", + ] +) class DNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase): @@ -254,9 +258,7 @@ class DNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase): W_min = -100 W_max = W_min + 255 W = ( - np.random.rand( - output_channels, kernel, kernel, input_channels_per_group - ) + np.random.rand(output_channels, kernel, kernel, input_channels_per_group) * 4 - 2 + W_min diff --git a/caffe2/quantization/server/conv_groupwise_dnnlowp_acc16_op_test.py b/caffe2/quantization/server/conv_groupwise_dnnlowp_acc16_op_test.py index 44f7aad..d542126 100644 --- a/caffe2/quantization/server/conv_groupwise_dnnlowp_acc16_op_test.py +++ b/caffe2/quantization/server/conv_groupwise_dnnlowp_acc16_op_test.py @@ -7,15 +7,19 @@ import hypothesis.strategies as st import numpy as np from caffe2.python import core, dyndep, utils, workspace from caffe2.quantization.server import utils as dnnlowp_utils -from dnnlowp_test_utils import ( - check_quantized_results_close, - generate_conv_inputs, -) +from dnnlowp_test_utils import check_quantized_results_close from hypothesis import assume, given dyndep.InitOpsLibrary("//caffe2/caffe2/quantization/server:dnnlowp_ops") -workspace.GlobalInit(["caffe2", "--caffe2_omp_num_threads=11"]) +workspace.GlobalInit( + [ + "caffe2", + "--caffe2_omp_num_threads=11", + # Increase this threshold to test acc16 with randomly generated data + "--caffe2_dnnlowp_acc16_density_threshold=0.9", + ] +) class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase): @@ -224,9 +228,7 @@ class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase): W_min = -100 W_max = W_min + 255 W = ( - np.random.rand( - output_channels, kernel, kernel, input_channels_per_group - ) + np.random.rand(output_channels, kernel, kernel, input_channels_per_group) * 4 - 2 + W_min @@ -237,9 +239,7 @@ class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase): for g in range(group): W[g * output_channels_per_group, 0, 0, 0] = W_min W[g * output_channels_per_group + 1, 0, 0, 0] = W_max - W[ - g * output_channels_per_group : (g + 1) * output_channels_per_group, - ] += g + W[g * output_channels_per_group : (g + 1) * output_channels_per_group,] += g if order == "NCHW": X = utils.NHWC2NCHW(X) diff --git a/caffe2/quantization/server/fbgemm_pack_op.cc b/caffe2/quantization/server/fbgemm_pack_op.cc index 704d4e1..9e98b01 100644 --- a/caffe2/quantization/server/fbgemm_pack_op.cc +++ b/caffe2/quantization/server/fbgemm_pack_op.cc @@ -5,6 +5,9 @@ #include "caffe2_dnnlowp_utils.h" C10_DECLARE_int32(caffe2_dnnlowp_nbits_in_non_outlier); +C10_DECLARE_double(caffe2_dnnlowp_acc16_density_threshold); +C10_DECLARE_int32(caffe2_dnnlowp_acc16_n_threshold); +C10_DECLARE_int32(caffe2_dnnlowp_acc16_k_threshold); namespace caffe2 { @@ -422,9 +425,44 @@ bool ConvDNNLowPPackWeightOp::RunOnDevice() { ComputeColumnOffsets( kernel_dim, M, W_quantized.data(), Y->qparams, *Y->column_offsets); + // Check if we should fallback to 32-bit accumulation. + // This check is only meaningful when engine is DNNLOWP_ACC16. + bool fallback_to_32_bit_accumulation = false; + if (nbits_in_non_outlier_ == 0) { + LOG(INFO) << "nbits_in_non_outlier == 0 means everything is outlier so we " + "fallback to acc32"; + fallback_to_32_bit_accumulation = true; + } + // In Skylake, acc16 is not faster when N or K is smaller than 128 + // FIXME : code duplication with conv_dnnlowp_acc16_op.cc + constexpr int SKYLAKE_ACC16_N_THRESHOLD_MIN = 128, + SKYLAKE_ACC16_K_THRESHOLD_MIN = 128; + int acc16_n_threshold = FLAGS_caffe2_dnnlowp_acc16_n_threshold; + if (caffe2::GetCpuId().avx512f() && + acc16_n_threshold < SKYLAKE_ACC16_N_THRESHOLD_MIN) { + acc16_n_threshold = SKYLAKE_ACC16_N_THRESHOLD_MIN; + } + int acc16_k_threshold = FLAGS_caffe2_dnnlowp_acc16_k_threshold; + if (caffe2::GetCpuId().avx512f() && + acc16_k_threshold < SKYLAKE_ACC16_K_THRESHOLD_MIN) { + acc16_k_threshold = SKYLAKE_ACC16_K_THRESHOLD_MIN; + } + if (!fallback_to_32_bit_accumulation && M / group_ < acc16_n_threshold) { + LOG(INFO) << "N " << M / group_ << " of weight blob " + << this->debug_def().input(0) << " is smaller than threshold " + << acc16_n_threshold << " . Falling back to acc32"; + fallback_to_32_bit_accumulation = true; + } + if (!fallback_to_32_bit_accumulation && kernel_dim < acc16_k_threshold) { + LOG(INFO) << "K " << kernel_dim << " of weight blob " + << this->debug_def().input(0) << " is smaller than threshold " + << acc16_k_threshold << " . Falling back to acc32"; + fallback_to_32_bit_accumulation = true; + } + // When nbits_in_non_outlier == 0, we fall back to acc32 if (this->debug_def().engine() == "DNNLOWP_ACC16" && - nbits_in_non_outlier_ > 0) { + !fallback_to_32_bit_accumulation) { if (nbits_in_non_outlier_ < 8) { Y->W_outlier.reset(ExtractOutlierMatrix( group_, kernel_dim, M, nbits_in_non_outlier_, W_quantized)); @@ -434,45 +472,66 @@ bool ConvDNNLowPPackWeightOp::RunOnDevice() { << this->debug_def().input(0) << " is " << static_cast(outlier_cnt) / W_quantized.size(); LOG(INFO) << "nbits_in_non_outlier " << nbits_in_non_outlier_; + + if (static_cast(outlier_cnt) / W_quantized.size() > + FLAGS_caffe2_dnnlowp_acc16_density_threshold) { + LOG(INFO) << "Density of outliers is higher than threshold " + << FLAGS_caffe2_dnnlowp_acc16_density_threshold + << " . Falling back to acc32"; + fallback_to_32_bit_accumulation = true; + } } - Y->nbits_in_non_outlier = nbits_in_non_outlier_; - Y->W_acc16.reset(new fbgemm::PackBMatrix( - fbgemm::matrix_op_t::Transpose, - group_ * kernel_dim, - M / group_, - W_quantized.data(), - kernel_dim, - nullptr, // pmat - group_)); - } else if (TakeDepthWise3x3FastPath_()) { - Y->W_depthwise_3x3.reset( - new fbgemm::Packed3x3ConvMatrix(group_, W_quantized.data())); - } else if (TakeDepthWise3x3x3FastPath_()) { - Y->W_depthwise_3x3x3.reset( - new fbgemm::Packed3x3x3ConvMatrix(group_, W_quantized.data())); - } else if (TakeGConvFastPath_()) { - fbgemm::conv_param_t<> conv_p( - 1, - group_ * C_per_group, - M, - {1, 1}, - group_, - {this->kernel_[0], this->kernel_[1]}, - {this->stride_[0], this->stride_[1]}, - {this->pads_[0], this->pads_[1], this->pads_[2], this->pads_[3]}); - - Y->W_gconv.reset(new fbgemm::PackWeightMatrixForGConv( - fbgemm::matrix_op_t::Transpose, conv_p, W_quantized.data())); - } else { - Y->W.reset(new fbgemm::PackBMatrix( - fbgemm::matrix_op_t::Transpose, - group_ * kernel_dim, - M / group_, - W_quantized.data(), - kernel_dim, - nullptr, // pmat - group_)); + if (!fallback_to_32_bit_accumulation) { + Y->nbits_in_non_outlier = nbits_in_non_outlier_; + Y->W_acc16.reset(new fbgemm::PackBMatrix( + fbgemm::matrix_op_t::Transpose, + group_ * kernel_dim, + M / group_, + W_quantized.data(), + kernel_dim, + nullptr, // pmat + group_)); + } + } + + if (fallback_to_32_bit_accumulation) { + Y->W_acc16.reset(); + Y->W_outlier.reset(); + } + + if (this->debug_def().engine() != "DNNLOWP_ACC16" || + fallback_to_32_bit_accumulation) { + // acc32 + if (TakeDepthWise3x3FastPath_()) { + Y->W_depthwise_3x3.reset( + new fbgemm::Packed3x3ConvMatrix(group_, W_quantized.data())); + } else if (TakeDepthWise3x3x3FastPath_()) { + Y->W_depthwise_3x3x3.reset( + new fbgemm::Packed3x3x3ConvMatrix(group_, W_quantized.data())); + } else if (TakeGConvFastPath_()) { + fbgemm::conv_param_t<> conv_p( + 1, + group_ * C_per_group, + M, + {1, 1}, + group_, + {this->kernel_[0], this->kernel_[1]}, + {this->stride_[0], this->stride_[1]}, + {this->pads_[0], this->pads_[1], this->pads_[2], this->pads_[3]}); + + Y->W_gconv.reset(new fbgemm::PackWeightMatrixForGConv( + fbgemm::matrix_op_t::Transpose, conv_p, W_quantized.data())); + } else { + Y->W.reset(new fbgemm::PackBMatrix( + fbgemm::matrix_op_t::Transpose, + group_ * kernel_dim, + M / group_, + W_quantized.data(), + kernel_dim, + nullptr, // pmat + group_)); + } } if (InputSize() >= 2) { -- 2.7.4