From: Jongsoo Park Date: Wed, 6 Feb 2019 23:14:17 +0000 (-0800) Subject: int8 SpatialBN (#16796) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~1439 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=8105aaca8611acd9e33707d6b57c6b1f144e4ab4;p=platform%2Fupstream%2Fpytorch.git int8 SpatialBN (#16796) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/16796 SpatialBN int8 version Reviewed By: dskhudia Differential Revision: D13971224 fbshipit-source-id: e55fd608c161069daaa4e62c618bc14b01f32cb7 --- diff --git a/caffe2/quantization/server/CMakeLists.txt b/caffe2/quantization/server/CMakeLists.txt index 8aedc5a..b21eab5 100644 --- a/caffe2/quantization/server/CMakeLists.txt +++ b/caffe2/quantization/server/CMakeLists.txt @@ -37,6 +37,7 @@ list(APPEND Caffe2_CPU_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/quantize_dnnlowp_op.cc" "${CMAKE_CURRENT_SOURCE_DIR}/relu_dnnlowp_op.cc" "${CMAKE_CURRENT_SOURCE_DIR}/sigmoid_dnnlowp_op.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/spatial_batch_norm_dnnlowp_op.cc" "${CMAKE_CURRENT_SOURCE_DIR}/tanh_dnnlowp_op.cc" "${CMAKE_CURRENT_SOURCE_DIR}/utility_dnnlowp_ops.cc" "${CMAKE_CURRENT_SOURCE_DIR}/dynamic_histogram.cc" diff --git a/caffe2/quantization/server/spatial_batch_norm_dnnlowp_op.cc b/caffe2/quantization/server/spatial_batch_norm_dnnlowp_op.cc new file mode 100644 index 0000000..414089f --- /dev/null +++ b/caffe2/quantization/server/spatial_batch_norm_dnnlowp_op.cc @@ -0,0 +1,141 @@ +#include "caffe2/quantization/server/spatial_batch_norm_dnnlowp_op.h" + +#include "caffe2/quantization/server/caffe2_dnnlowp_utils.h" + +namespace caffe2 { + +template +SpatialBNDNNLowPOp::SpatialBNDNNLowPOp( + const OperatorDef& operator_def, + Workspace* ws) + : DNNLowPOp>(operator_def, ws), + OP_SINGLE_ARG(double, "epsilon", epsilon_, 1e-5), + order_(StringToStorageOrder( + this->template GetSingleArgument("order", "NCHW"))) { + bool is_test = this->template GetSingleArgument("is_test", false); + OPERATOR_NEEDS_FEATURE( + is_test, "SpatialBN DNNLOWP op only works for inference."); + CAFFE_ENFORCE_NE( + order_, + StorageOrder::UNKNOWN, + "order should be either \"NCHW\" or \"NHWC\"."); + CAFFE_ENFORCE(OutputSize() == 1); + CAFFE_ENFORCE_GT(epsilon_, 0); +} + +template +void SpatialBNDNNLowPOp::ComputeFusedParam_( + const int C, + const float* scale, + const float* bias, + const float* mean, + const float* var, + float* alpha, + float* beta) { + EigenVectorArrayMap alpha_arr(alpha, C); + EigenVectorArrayMap beta_arr(beta, C); + alpha_arr = ConstEigenVectorArrayMap(scale, C) * + (ConstEigenVectorArrayMap(var, C) + epsilon_).rsqrt(); + beta_arr = ConstEigenVectorArrayMap(bias, C) - + alpha_arr * ConstEigenVectorArrayMap(mean, C); + + // Adjust alpha and beta considering quantization scales + alpha_arr = alpha_arr * (in_qparams_[0].scale / out_qparams_.scale); + beta_arr = beta_arr / out_qparams_.scale; +} + +template +bool SpatialBNDNNLowPOp::RunOnDevice() { + const auto& X = InputTensorCPU_(INPUT); + const auto& scale = Input(SCALE); + const auto& bias = Input(BIAS); + + const int ndim = X.dim(); + CAFFE_ENFORCE_GE(ndim, 3); + const int N = X.dim32(0); + const int C = (order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1)); + const std::vector X_dims(X.sizes().cbegin(), X.sizes().cend()); + const int HxW = + std::accumulate( + X_dims.cbegin() + 1, X_dims.cend(), 1, std::multiplies()) / + C; + CAFFE_ENFORCE_EQ(scale.numel(), C); + CAFFE_ENFORCE_EQ(bias.numel(), C); + + GetOutputQuantizationParams_(); + + in_qparams_[0] = GetInputTensorQuantizationParamsOf(this, 0, qfactory_.get()); + + const float* scale_data = scale.template data(); + const float* bias_data = bias.template data(); + ReinitializeTensor( + &alpha_, {C}, at::dtype().device(CPUContext::GetDeviceType())); + ReinitializeTensor( + &beta_, {C}, at::dtype().device(CPUContext::GetDeviceType())); + float* alpha_data = alpha_.template mutable_data(); + float* beta_data = beta_.template mutable_data(); + if (N == 0) { + return true; + } + const auto& mean = Input(EST_MEAN); + const auto& var = Input(EST_VAR); + CAFFE_ENFORCE_EQ(mean.numel(), C); + CAFFE_ENFORCE_EQ(var.numel(), C); + ComputeFusedParam_( + C, + scale_data, + bias_data, + mean.template data(), + var.template data(), + alpha_data, + beta_data); + + vector X_temp; + const T* X_data = + dnnlowp::QuantizeInputIfNeeded(this, 0, in_qparams_[0], X_temp); + auto* Y = OutputTensorCPU_(OUTPUT); + Y->Resize(X.sizes()); + T* Y_data = GetQuantizedOutputData_(); + + if (order_ == StorageOrder::NCHW) { + for (int c = 0; c < C; ++c) { + for (int i = 0; i < N; ++i) { + for (int j = 0; j < HxW; ++j) { + long quantized_down = out_qparams_.zero_point + + std::lrintf(alpha_data[c] * + (X_data[(i * C + c) * HxW + j] - + in_qparams_[0].zero_point) + + beta_data[c]); + Y_data[(i * C + c) * HxW + j] = + fbgemm::clamp(quantized_down, 8); + } + } + } + } else { + for (int i = 0; i < N * HxW; ++i) { + for (int c = 0; c < C; ++c) { + long quantized_down = out_qparams_.zero_point + + std::lrintf(alpha_data[c] * + (X_data[i * C + c] - in_qparams_[0].zero_point) + + beta_data[c]); + Y_data[i * C + c] = fbgemm::clamp(quantized_down, 8); + } + } + } + + RunOnDeviceEpilogue_(); + + return true; +} + +REGISTER_CPU_OPERATOR_WITH_ENGINE( + SpatialBN, + DNNLOWP, + SpatialBNDNNLowPOp); + +REGISTER_CPU_OPERATOR_WITH_ENGINE( + Int8SpatialBN, + DNNLOWP, + SpatialBNDNNLowPOp); + +} // namespace caffe2 diff --git a/caffe2/quantization/server/spatial_batch_norm_dnnlowp_op.h b/caffe2/quantization/server/spatial_batch_norm_dnnlowp_op.h new file mode 100644 index 0000000..076e91d --- /dev/null +++ b/caffe2/quantization/server/spatial_batch_norm_dnnlowp_op.h @@ -0,0 +1,43 @@ +#pragma once + +#include "caffe2/operators/spatial_batch_norm_op.h" +#include "caffe2/quantization/server/dnnlowp_op.h" + +namespace caffe2 { + +/** + * Note this implementation assumes SCALE, BIAS, EST_MEAN, and EST_VAR inputs + * are still in fp32, so is epsilon argument + */ +template +class SpatialBNDNNLowPOp final : public DNNLowPOp> { + public: + USE_OPERATOR_FUNCTIONS(CPUContext); + USE_DNNLOWP_OPERATOR_BASE_FUNCTIONS(T, SpatialBNOp); + SpatialBNDNNLowPOp(const OperatorDef& operator_def, Workspace* ws); + + virtual ~SpatialBNDNNLowPOp() override = default; + + bool RunOnDevice() override; + + private: + void ComputeFusedParam_( + const int C, + const float* scale, + const float* bias, + const float* mean, + const float* var, + float* alpha, + float* beta); + + double epsilon_; + const StorageOrder order_; + + Tensor alpha_; + Tensor beta_; + + INPUT_TAGS(INPUT, SCALE, BIAS, EST_MEAN, EST_VAR); + OUTPUT_TAGS(OUTPUT); +}; + +} // namespace caffe2 diff --git a/caffe2/quantization/server/spatial_batch_norm_dnnlowp_op_test.py b/caffe2/quantization/server/spatial_batch_norm_dnnlowp_op_test.py new file mode 100644 index 0000000..5dbcfd7 --- /dev/null +++ b/caffe2/quantization/server/spatial_batch_norm_dnnlowp_op_test.py @@ -0,0 +1,110 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +import collections + +import caffe2.python.hypothesis_test_util as hu +import hypothesis.strategies as st +import numpy as np +from caffe2.python import core, dyndep, utils, workspace +from caffe2.quantization.server import utils as dnnlowp_utils +from dnnlowp_test_utils import check_quantized_results_close +from hypothesis import given + + +dyndep.InitOpsLibrary("//caffe2/caffe2/quantization/server:dnnlowp_ops") +workspace.GlobalInit(["caffe2", "--caffe2_omp_num_threads=11"]) + + +class DNNLowPOpSpatialBNTest(hu.HypothesisTestCase): + # correctness test with no quantization error in inputs + @given( + size=st.integers(10, 16), + input_channels=st.integers(2, 16), + output_channels=st.integers(2, 16), + batch_size=st.integers(1, 3), + order=st.sampled_from(["NCHW", "NHWC"]), + in_quantized=st.booleans(), + out_quantized=st.booleans(), + **hu.gcs_cpu_only + ) + def test_dnnlowp_spatial_bn_int( + self, + size, + input_channels, + output_channels, + batch_size, + order, + in_quantized, + out_quantized, + gc, + dc, + ): + X_min = -77 + X_max = X_min + 255 + X = np.round(np.random.rand(batch_size, size, size, input_channels)).astype( + np.float32 + ) + X[0, 0, 0, 0] = X_min + X[0, 0, 0, 1] = X_max + + epsilon = np.abs(np.random.rand()) + scale = np.random.rand(input_channels).astype(np.float32) + bias = np.random.rand(input_channels).astype(np.float32) + mean = np.random.rand(input_channels).astype(np.float32) + var = np.random.rand(input_channels).astype(np.float32) + + if order == "NCHW": + X = utils.NHWC2NCHW(X) + + Output = collections.namedtuple("Output", ["Y", "op_type", "engine"]) + outputs = [] + + op_engine_list = [ + ("SpatialBN", ""), + ("SpatialBN", "DNNLOWP"), + ("Int8SpatialBN", "DNNLOWP"), + ] + + for op_type, engine in op_engine_list: + net = core.Net("test_net") + + do_quantize = "DNNLOWP" in engine and in_quantized + do_dequantize = "DNNLOWP" in engine and out_quantized + + if do_quantize: + quantize = core.CreateOperator( + "Quantize", ["X"], ["X_q"], engine=engine + ) + net.Proto().op.extend([quantize]) + + bn = core.CreateOperator( + op_type, + ["X_q" if do_quantize else "X", "scale", "bias", "mean", "var"], + ["Y_q" if do_dequantize else "Y"], + is_test=True, + epsilon=epsilon, + order=order, + engine=engine, + dequantize_output=not do_dequantize, + ) + net.Proto().op.extend([bn]) + if "DNNLOWP" in engine: + dnnlowp_utils.add_quantization_param_args(bn, outputs[0][0]) + + if do_dequantize: + dequantize = core.CreateOperator( + "Dequantize", ["Y_q"], ["Y"], engine=engine + ) + net.Proto().op.extend([dequantize]) + + self.ws.create_blob("X").feed(X, device_option=gc) + self.ws.create_blob("scale").feed(scale, device_option=gc) + self.ws.create_blob("bias").feed(bias, device_option=gc) + self.ws.create_blob("mean").feed(mean, device_option=gc) + self.ws.create_blob("var").feed(var, device_option=gc) + self.ws.run(net) + outputs.append( + Output(Y=self.ws.blobs["Y"].fetch(), op_type=op_type, engine=engine) + ) + + check_quantized_results_close(outputs)