int8 SpatialBN (#16796)
author Jongsoo Park <jongsoo@fb.com>
Wed, 6 Feb 2019 23:14:17 +0000 (15:14 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 6 Feb 2019 23:32:01 +0000 (15:32 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16796

SpatialBN int8 version: adds a DNNLOWP (quantized uint8) SpatialBN operator for inference, registered as both SpatialBN and Int8SpatialBN under the DNNLOWP engine. The fp32 scale, bias, and estimated mean/variance are folded into per-channel alpha/beta at run time, and the quantized input is transformed and requantized in a single pass for both NCHW and NHWC layouts.
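
A minimal numpy sketch of that math, for orientation only (the function and
variable names here are illustrative, not part of the patch):

    import numpy as np

    def int8_spatial_bn_reference(X_q, in_scale, in_zp, out_scale, out_zp,
                                  scale, bias, mean, var, eps):
        # Fold BN statistics into per-channel alpha/beta ...
        alpha = scale / np.sqrt(var + eps)
        beta = bias - alpha * mean
        # ... then absorb the input/output quantization scales.
        alpha *= in_scale / out_scale
        beta /= out_scale
        # Requantize: X_q is an NHWC uint8 array; alpha/beta broadcast over C.
        y = out_zp + np.rint(alpha * (X_q.astype(np.float32) - in_zp) + beta)
        return np.clip(y, 0, 255).astype(np.uint8)

The operator below does the same computation element by element in
ComputeFusedParam_ and RunOnDevice, using fbgemm::clamp for the final 8-bit
saturation.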

Reviewed By: dskhudia

Differential Revision: D13971224

fbshipit-source-id: e55fd608c161069daaa4e62c618bc14b01f32cb7

caffe2/quantization/server/CMakeLists.txt
caffe2/quantization/server/spatial_batch_norm_dnnlowp_op.cc [new file with mode: 0644]
caffe2/quantization/server/spatial_batch_norm_dnnlowp_op.h [new file with mode: 0644]
caffe2/quantization/server/spatial_batch_norm_dnnlowp_op_test.py [new file with mode: 0644]

diff --git a/caffe2/quantization/server/CMakeLists.txt b/caffe2/quantization/server/CMakeLists.txt
index 8aedc5a..b21eab5 100644 (file)
@@ -37,6 +37,7 @@ list(APPEND Caffe2_CPU_SRCS
   "${CMAKE_CURRENT_SOURCE_DIR}/quantize_dnnlowp_op.cc"
   "${CMAKE_CURRENT_SOURCE_DIR}/relu_dnnlowp_op.cc"
   "${CMAKE_CURRENT_SOURCE_DIR}/sigmoid_dnnlowp_op.cc"
+  "${CMAKE_CURRENT_SOURCE_DIR}/spatial_batch_norm_dnnlowp_op.cc"
   "${CMAKE_CURRENT_SOURCE_DIR}/tanh_dnnlowp_op.cc"
   "${CMAKE_CURRENT_SOURCE_DIR}/utility_dnnlowp_ops.cc"
   "${CMAKE_CURRENT_SOURCE_DIR}/dynamic_histogram.cc"
diff --git a/caffe2/quantization/server/spatial_batch_norm_dnnlowp_op.cc b/caffe2/quantization/server/spatial_batch_norm_dnnlowp_op.cc
new file mode 100644 (file)
index 0000000..414089f
--- /dev/null
@@ -0,0 +1,141 @@
+#include "caffe2/quantization/server/spatial_batch_norm_dnnlowp_op.h"
+
+#include "caffe2/quantization/server/caffe2_dnnlowp_utils.h"
+
+namespace caffe2 {
+
+template <typename T>
+SpatialBNDNNLowPOp<T>::SpatialBNDNNLowPOp(
+    const OperatorDef& operator_def,
+    Workspace* ws)
+    : DNNLowPOp<T, SpatialBNOp<CPUContext>>(operator_def, ws),
+      OP_SINGLE_ARG(double, "epsilon", epsilon_, 1e-5),
+      order_(StringToStorageOrder(
+          this->template GetSingleArgument<std::string>("order", "NCHW"))) {
+  bool is_test = this->template GetSingleArgument<bool>("is_test", false);
+  OPERATOR_NEEDS_FEATURE(
+      is_test, "SpatialBN DNNLOWP op only works for inference.");
+  CAFFE_ENFORCE_NE(
+      order_,
+      StorageOrder::UNKNOWN,
+      "order should be either \"NCHW\" or \"NHWC\".");
+  CAFFE_ENFORCE(OutputSize() == 1);
+  CAFFE_ENFORCE_GT(epsilon_, 0);
+}
+
+template <typename T>
+void SpatialBNDNNLowPOp<T>::ComputeFusedParam_(
+    const int C,
+    const float* scale,
+    const float* bias,
+    const float* mean,
+    const float* var,
+    float* alpha,
+    float* beta) {
+  EigenVectorArrayMap<float> alpha_arr(alpha, C);
+  EigenVectorArrayMap<float> beta_arr(beta, C);
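+  // Fold BN statistics: alpha = scale / sqrt(var + epsilon), beta = bias - alpha * mean.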
+  alpha_arr = ConstEigenVectorArrayMap<float>(scale, C) *
+      (ConstEigenVectorArrayMap<float>(var, C) + epsilon_).rsqrt();
+  beta_arr = ConstEigenVectorArrayMap<float>(bias, C) -
+      alpha_arr * ConstEigenVectorArrayMap<float>(mean, C);
+
+  // Adjust alpha and beta considering quantization scales
+  alpha_arr = alpha_arr * (in_qparams_[0].scale / out_qparams_.scale);
+  beta_arr = beta_arr / out_qparams_.scale;
+}
+
+template <typename T>
+bool SpatialBNDNNLowPOp<T>::RunOnDevice() {
+  const auto& X = InputTensorCPU_(INPUT);
+  const auto& scale = Input(SCALE);
+  const auto& bias = Input(BIAS);
+
+  const int ndim = X.dim();
+  CAFFE_ENFORCE_GE(ndim, 3);
+  const int N = X.dim32(0);
+  const int C = (order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1));
+  const std::vector<int> X_dims(X.sizes().cbegin(), X.sizes().cend());
+  const int HxW =
+      std::accumulate(
+          X_dims.cbegin() + 1, X_dims.cend(), 1, std::multiplies<int>()) /
+      C;
+  CAFFE_ENFORCE_EQ(scale.numel(), C);
+  CAFFE_ENFORCE_EQ(bias.numel(), C);
+
+  GetOutputQuantizationParams_();
+
+  in_qparams_[0] = GetInputTensorQuantizationParamsOf(this, 0, qfactory_.get());
+
+  const float* scale_data = scale.template data<float>();
+  const float* bias_data = bias.template data<float>();
+  ReinitializeTensor(
+      &alpha_, {C}, at::dtype<float>().device(CPUContext::GetDeviceType()));
+  ReinitializeTensor(
+      &beta_, {C}, at::dtype<float>().device(CPUContext::GetDeviceType()));
+  float* alpha_data = alpha_.template mutable_data<float>();
+  float* beta_data = beta_.template mutable_data<float>();
+  if (N == 0) {
+    return true;
+  }
+  const auto& mean = Input(EST_MEAN);
+  const auto& var = Input(EST_VAR);
+  CAFFE_ENFORCE_EQ(mean.numel(), C);
+  CAFFE_ENFORCE_EQ(var.numel(), C);
+  ComputeFusedParam_(
+      C,
+      scale_data,
+      bias_data,
+      mean.template data<float>(),
+      var.template data<float>(),
+      alpha_data,
+      beta_data);
+
+  vector<T> X_temp;
+  const T* X_data =
+      dnnlowp::QuantizeInputIfNeeded(this, 0, in_qparams_[0], X_temp);
+  auto* Y = OutputTensorCPU_(OUTPUT);
+  Y->Resize(X.sizes());
+  T* Y_data = GetQuantizedOutputData_();
+
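+  // Per-element requantization:
+  // Y_q = clamp_to_8_bits(out_zero_point + round(alpha[c] * (X_q - in_zero_point) + beta[c]))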
+  if (order_ == StorageOrder::NCHW) {
+    for (int c = 0; c < C; ++c) {
+      for (int i = 0; i < N; ++i) {
+        for (int j = 0; j < HxW; ++j) {
+          long quantized_down = out_qparams_.zero_point +
+              std::lrintf(alpha_data[c] *
+                              (X_data[(i * C + c) * HxW + j] -
+                               in_qparams_[0].zero_point) +
+                          beta_data[c]);
+          Y_data[(i * C + c) * HxW + j] =
+              fbgemm::clamp<long, T>(quantized_down, 8);
+        }
+      }
+    }
+  } else {
+    for (int i = 0; i < N * HxW; ++i) {
+      for (int c = 0; c < C; ++c) {
+        long quantized_down = out_qparams_.zero_point +
+            std::lrintf(alpha_data[c] *
+                            (X_data[i * C + c] - in_qparams_[0].zero_point) +
+                        beta_data[c]);
+        Y_data[i * C + c] = fbgemm::clamp<long, T>(quantized_down, 8);
+      }
+    }
+  }
+
+  RunOnDeviceEpilogue_();
+
+  return true;
+}
+
+REGISTER_CPU_OPERATOR_WITH_ENGINE(
+    SpatialBN,
+    DNNLOWP,
+    SpatialBNDNNLowPOp<uint8_t>);
+
+REGISTER_CPU_OPERATOR_WITH_ENGINE(
+    Int8SpatialBN,
+    DNNLOWP,
+    SpatialBNDNNLowPOp<uint8_t>);
+
+} // namespace caffe2
diff --git a/caffe2/quantization/server/spatial_batch_norm_dnnlowp_op.h b/caffe2/quantization/server/spatial_batch_norm_dnnlowp_op.h
new file mode 100644 (file)
index 0000000..076e91d
--- /dev/null
@@ -0,0 +1,43 @@
+#pragma once
+
+#include "caffe2/operators/spatial_batch_norm_op.h"
+#include "caffe2/quantization/server/dnnlowp_op.h"
+
+namespace caffe2 {
+
+/**
+ * Note: this implementation assumes the SCALE, BIAS, EST_MEAN, and EST_VAR
+ * inputs are still in fp32, as is the epsilon argument.
+ */
+template <typename T>
+class SpatialBNDNNLowPOp final : public DNNLowPOp<T, SpatialBNOp<CPUContext>> {
+ public:
+  USE_OPERATOR_FUNCTIONS(CPUContext);
+  USE_DNNLOWP_OPERATOR_BASE_FUNCTIONS(T, SpatialBNOp<CPUContext>);
+  SpatialBNDNNLowPOp(const OperatorDef& operator_def, Workspace* ws);
+
+  virtual ~SpatialBNDNNLowPOp() override = default;
+
+  bool RunOnDevice() override;
+
+ private:
+  void ComputeFusedParam_(
+      const int C,
+      const float* scale,
+      const float* bias,
+      const float* mean,
+      const float* var,
+      float* alpha,
+      float* beta);
+
+  double epsilon_;
+  const StorageOrder order_;
+
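+  // Per-channel fused parameters computed by ComputeFusedParam_ on each run.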
+  Tensor alpha_;
+  Tensor beta_;
+
+  INPUT_TAGS(INPUT, SCALE, BIAS, EST_MEAN, EST_VAR);
+  OUTPUT_TAGS(OUTPUT);
+};
+
+} // namespace caffe2
diff --git a/caffe2/quantization/server/spatial_batch_norm_dnnlowp_op_test.py b/caffe2/quantization/server/spatial_batch_norm_dnnlowp_op_test.py
new file mode 100644 (file)
index 0000000..5dbcfd7
--- /dev/null
@@ -0,0 +1,110 @@
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import collections
+
+import caffe2.python.hypothesis_test_util as hu
+import hypothesis.strategies as st
+import numpy as np
+from caffe2.python import core, dyndep, utils, workspace
+from caffe2.quantization.server import utils as dnnlowp_utils
+from dnnlowp_test_utils import check_quantized_results_close
+from hypothesis import given
+
+
+dyndep.InitOpsLibrary("//caffe2/caffe2/quantization/server:dnnlowp_ops")
+workspace.GlobalInit(["caffe2", "--caffe2_omp_num_threads=11"])
+
+
+class DNNLowPOpSpatialBNTest(hu.HypothesisTestCase):
+    # correctness test with no quantization error in inputs
+    @given(
+        size=st.integers(10, 16),
+        input_channels=st.integers(2, 16),
+        output_channels=st.integers(2, 16),
+        batch_size=st.integers(1, 3),
+        order=st.sampled_from(["NCHW", "NHWC"]),
+        in_quantized=st.booleans(),
+        out_quantized=st.booleans(),
+        **hu.gcs_cpu_only
+    )
+    def test_dnnlowp_spatial_bn_int(
+        self,
+        size,
+        input_channels,
+        output_channels,
+        batch_size,
+        order,
+        in_quantized,
+        out_quantized,
+        gc,
+        dc,
+    ):
+        X_min = -77
+        X_max = X_min + 255
+        X = np.round(np.random.rand(batch_size, size, size, input_channels)).astype(
+            np.float32
+        )
+        X[0, 0, 0, 0] = X_min
+        X[0, 0, 0, 1] = X_max
+
+        epsilon = np.abs(np.random.rand())
+        scale = np.random.rand(input_channels).astype(np.float32)
+        bias = np.random.rand(input_channels).astype(np.float32)
+        mean = np.random.rand(input_channels).astype(np.float32)
+        var = np.random.rand(input_channels).astype(np.float32)
+
+        if order == "NCHW":
+            X = utils.NHWC2NCHW(X)
+
+        Output = collections.namedtuple("Output", ["Y", "op_type", "engine"])
+        outputs = []
+
+        op_engine_list = [
+            ("SpatialBN", ""),
+            ("SpatialBN", "DNNLOWP"),
+            ("Int8SpatialBN", "DNNLOWP"),
+        ]
+
+        for op_type, engine in op_engine_list:
+            net = core.Net("test_net")
+
+            do_quantize = "DNNLOWP" in engine and in_quantized
+            do_dequantize = "DNNLOWP" in engine and out_quantized
+
+            if do_quantize:
+                quantize = core.CreateOperator(
+                    "Quantize", ["X"], ["X_q"], engine=engine
+                )
+                net.Proto().op.extend([quantize])
+
+            bn = core.CreateOperator(
+                op_type,
+                ["X_q" if do_quantize else "X", "scale", "bias", "mean", "var"],
+                ["Y_q" if do_dequantize else "Y"],
+                is_test=True,
+                epsilon=epsilon,
+                order=order,
+                engine=engine,
+                dequantize_output=not do_dequantize,
+            )
+            net.Proto().op.extend([bn])
+            if "DNNLOWP" in engine:
+                dnnlowp_utils.add_quantization_param_args(bn, outputs[0][0])
+
+            if do_dequantize:
+                dequantize = core.CreateOperator(
+                    "Dequantize", ["Y_q"], ["Y"], engine=engine
+                )
+                net.Proto().op.extend([dequantize])
+
+            self.ws.create_blob("X").feed(X, device_option=gc)
+            self.ws.create_blob("scale").feed(scale, device_option=gc)
+            self.ws.create_blob("bias").feed(bias, device_option=gc)
+            self.ws.create_blob("mean").feed(mean, device_option=gc)
+            self.ws.create_blob("var").feed(var, device_option=gc)
+            self.ws.run(net)
+            outputs.append(
+                Output(Y=self.ws.blobs["Y"].fetch(), op_type=op_type, engine=engine)
+            )
+
+        check_quantized_results_close(outputs)