add Int8FCRelu (#18673)
author    Jongsoo Park <jongsoo@fb.com>
          Tue, 2 Apr 2019 06:45:01 +0000 (23:45 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
          Tue, 2 Apr 2019 06:50:30 +0000 (23:50 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18673

Add a fused FC + Relu operator (Int8FCRelu) for the DNNLOWP and DNNLOWP_ROWWISE engines, implemented by templating FullyConnectedDNNLowPOp on a ReluFused flag.
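
Below is a minimal usage sketch from the Caffe2 Python frontend, assuming a build that includes the DNNLOWP quantization operators. The blob names and shapes are illustrative, and the dequantize_output argument mirrors how the existing Int8FC tests drive these ops; none of this snippet is part of the patch itself.

    from caffe2.python import core, workspace
    import numpy as np

    # Illustrative fp32 inputs; the op quantizes internally when asked to
    # dequantize its output (assumption, based on existing Int8FC behavior).
    workspace.FeedBlob("X", np.random.rand(4, 16).astype(np.float32))
    workspace.FeedBlob("W", np.random.rand(8, 16).astype(np.float32))
    workspace.FeedBlob("b", np.random.rand(8).astype(np.float32))

    net = core.Net("int8_fc_relu_example")
    net.Proto().op.extend([
        core.CreateOperator(
            "Int8FCRelu",          # fused FC + Relu, 3 inputs / 1 output
            ["X", "W", "b"],
            ["Y"],
            engine="DNNLOWP",      # also registered for DNNLOWP_ROWWISE
            dequantize_output=1,   # assumption: return fp32 Y, as with Int8FC
        ),
    ])
    workspace.RunNetOnce(net)
    print(workspace.FetchBlob("Y"))  # every entry should be >= 0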

Reviewed By: csummersea

Differential Revision: D14667055

fbshipit-source-id: d88fefba008fc0ca450291532d2b320694c6b785

caffe2/quantization/server/dnnlowp_op.h
caffe2/quantization/server/fully_connected_dnnlowp_op.cc
caffe2/quantization/server/fully_connected_dnnlowp_op.h
caffe2/quantization/server/fully_connected_dnnlowp_op_test.py

diff --git a/caffe2/quantization/server/dnnlowp_op.h b/caffe2/quantization/server/dnnlowp_op.h
index 9db43c3..da5b8f6 100644
@@ -122,7 +122,8 @@ class DNNLowPOp : public Operator<CPUContext> {
     }
   }
 
-  Tensor* OutputTensorCPU_(int idx, at::IntArrayRef dims, at::TensorOptions options) {
+  Tensor*
+  OutputTensorCPU_(int idx, at::IntArrayRef dims, at::TensorOptions options) {
     if (dequantize_output_) {
       return Output(idx, dims, options.device(CPU));
     } else {
diff --git a/caffe2/quantization/server/fully_connected_dnnlowp_op.cc b/caffe2/quantization/server/fully_connected_dnnlowp_op.cc
index 5eee0a9..e5ff334 100644
@@ -6,6 +6,7 @@
 
 #include "caffe2/core/flags.h"
 #include "caffe2/core/tensor_int8.h"
+#include "caffe2/operators/fc_inference.h"
 #include "caffe2/utils/cpuid.h"
 #include "fbgemm_pack_matrix_cache.h"
 #include "fbgemm_pack_op.h"
@@ -24,8 +25,8 @@ namespace caffe2 {
 
 using namespace std;
 
-template <typename T>
-FullyConnectedDNNLowPOp<T>::FullyConnectedDNNLowPOp(
+template <typename T, bool ReluFused>
+FullyConnectedDNNLowPOp<T, ReluFused>::FullyConnectedDNNLowPOp(
     const OperatorDef& operator_def,
     Workspace* ws)
     : BaseType(operator_def, ws),
@@ -50,12 +51,17 @@ FullyConnectedDNNLowPOp<T>::FullyConnectedDNNLowPOp(
   VLOG(2) << "DNNLOWP FC with output " << operator_def.output(0);
 }
 
-template <typename T>
-bool FullyConnectedDNNLowPOp<T>::RunOnDevice() {
+template <typename T, bool ReluFused>
+bool FullyConnectedDNNLowPOp<T, ReluFused>::RunOnDevice() {
   using namespace std;
   using namespace dnnlowp;
 
+  bool first_invocation = !this->arguments_parsed_;
   this->ParseDNNLowPOperatorArguments_();
+  if (first_invocation && ReluFused) {
+    followed_by_ = "Relu";
+    AdjustOutputTensorQuantizationParamsWithFollowedBy(this, followed_by_);
+  }
 
   if ((!GetCpuId().avx2() || FLAGS_caffe2_dnnlowp_enforce_default_operators) &&
       dequantize_output_) {
@@ -199,9 +205,7 @@ bool FullyConnectedDNNLowPOp<T>::RunOnDevice() {
             row_offsets_.data());
 
         if (quantize_channelwise_) {
-          ReQuantizeOutput<
-              false /* FUSE_RELU */,
-              QuantizationGranularity::OUT_CHANNEL>
+          ReQuantizeOutput<ReluFused, QuantizationGranularity::OUT_CHANNEL>
               outputProcObj(
                   doNothingObj,
                   requantization_multipliers_.data(),
@@ -224,7 +228,7 @@ bool FullyConnectedDNNLowPOp<T>::RunOnDevice() {
               0, // thread_id
               1); // num_threads
         } else {
-          ReQuantizeOutput<false /* FUSE_RELU */> outputProcObj(
+          ReQuantizeOutput<ReluFused> outputProcObj(
               doNothingObj,
               requantization_multipliers_.data(),
               out_qparams_.zero_point,
@@ -258,7 +262,7 @@ bool FullyConnectedDNNLowPOp<T>::RunOnDevice() {
             X_pack_buf_.data(), // buffer for packed matrix
             1); // group
 
-        ReQuantizeOutput<false /* FUSE_RELU */> outputProcObj(
+        ReQuantizeOutput<ReluFused> outputProcObj(
             doNothingObj,
             requantization_multipliers_.data(),
             out_qparams_.zero_point,
@@ -305,9 +309,7 @@ bool FullyConnectedDNNLowPOp<T>::RunOnDevice() {
         DoNothing<float, float> doNothingObj{};
 
         if (quantize_channelwise_) {
-          ReQuantizeForFloat<
-              false /* FUSE_RELU*/,
-              QuantizationGranularity::OUT_CHANNEL>
+          ReQuantizeForFloat<ReluFused, QuantizationGranularity::OUT_CHANNEL>
               outputProcObj(
                   doNothingObj,
                   in_qparams_[0].scale,
@@ -329,7 +331,7 @@ bool FullyConnectedDNNLowPOp<T>::RunOnDevice() {
               0, // thread_id
               1); // num_threads
         } else {
-          ReQuantizeForFloat<false /* FUSE_RELU*/> outputProcObj(
+          ReQuantizeForFloat<ReluFused> outputProcObj(
               doNothingObj,
               in_qparams_[0].scale,
               filter_scales_.data(),
@@ -367,9 +369,7 @@ bool FullyConnectedDNNLowPOp<T>::RunOnDevice() {
         DoNothing<float, float> doNothingObj{};
 
         if (quantize_channelwise_) {
-          ReQuantizeForFloat<
-              false /* FUSE_RELU*/,
-              QuantizationGranularity::OUT_CHANNEL>
+          ReQuantizeForFloat<ReluFused, QuantizationGranularity::OUT_CHANNEL>
               outputProcObj(
                   doNothingObj,
                   in_qparams_[0].scale,
@@ -391,7 +391,7 @@ bool FullyConnectedDNNLowPOp<T>::RunOnDevice() {
               0, // thread_id
               1); // num_threads
         } else {
-          ReQuantizeForFloat<false /* FUSE_RELU*/> outputProcObj(
+          ReQuantizeForFloat<ReluFused> outputProcObj(
               doNothingObj,
               in_qparams_[0].scale,
               filter_scales_.data(),
@@ -491,6 +491,9 @@ bool FullyConnectedDNNLowPOp<T>::RunOnDevice() {
           Ydata[i * N + j] = Y_int32_[i * N + j] * in_qparams_[0].scale *
                   filter_qparams_[quant_group].scale +
               b_dequantized_data_[j];
+          if (ReluFused) {
+            Ydata[i * N + j] = std::max(Ydata[i * N + j], 0.0f);
+          }
         }
       }
     }
@@ -516,6 +519,10 @@ bool FullyConnectedDNNLowPOp<T>::RunOnDevice() {
 
           Ydata[i * N + j] = fbgemm::Requantize<T>(
               Y_int32_[i * N + j], requantization_params_[quant_group]);
+          if (ReluFused) {
+            Ydata[i * N + j] =
+                std::max<T>(out_qparams_.zero_point, Ydata[i * N + j]);
+          }
         }
       }
     }
@@ -546,8 +553,8 @@ bool FullyConnectedDNNLowPOp<T>::RunOnDevice() {
   return true;
 }
 
-template <typename T>
-bool FullyConnectedDNNLowPOp<T>::GetQuantizationParameters_() {
+template <typename T, bool ReluFused>
+bool FullyConnectedDNNLowPOp<T, ReluFused>::GetQuantizationParameters_() {
   using namespace dnnlowp;
 
 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
@@ -855,4 +862,20 @@ REGISTER_CPU_OPERATOR_WITH_ENGINE(
     DNNLOWP_ROWWISE,
     FullyConnectedDNNLowPOp<uint8_t>);
 
+REGISTER_CPU_OPERATOR_WITH_ENGINE(
+    Int8FCRelu,
+    DNNLOWP,
+    FullyConnectedDNNLowPOp<uint8_t, true>);
+REGISTER_CPU_OPERATOR_WITH_ENGINE(
+    Int8FCRelu,
+    DNNLOWP_ROWWISE,
+    FullyConnectedDNNLowPOp<uint8_t, true>);
+
+using namespace std::placeholders;
+OPERATOR_SCHEMA(Int8FCRelu)
+    .NumInputs(3)
+    .NumOutputs(1)
+    .TensorInferenceFunction(std::bind(FCShapeInference, _1, _2, false))
+    .CostInferenceFunction(std::bind(CostInferenceForFC, _1, _2, false));
+
 } // namespace caffe2
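
The fusion shows up in two places above. On the first invocation, followed_by_ is set to "Relu" and AdjustOutputTensorQuantizationParamsWithFollowedBy adjusts the output quantization parameters for a following Relu (so the representable range is not wasted on negative values), and the fbgemm ReQuantizeOutput / ReQuantizeForFloat pipelines receive ReluFused as their FUSE_RELU template argument. In the non-AVX2 fallback paths the Relu is applied directly: max with 0.0f for dequantized float output, and a clamp at out_qparams_.zero_point for quantized output. The following numpy sketch (not code from the patch; the scale/zero_point values are made up) shows why clamping at the zero point is the quantized-domain equivalent of Relu:

    import numpy as np

    # real = scale * (q - zero_point), so real >= 0 exactly when q >= zero_point.
    scale, zero_point = 0.05, 3   # illustrative uint8 output quantization params
    q = np.array([0, 1, 3, 7, 200], dtype=np.int32)

    # Relu applied in real space, then re-quantized.
    relu_then_requantize = (
        np.round(np.maximum(scale * (q - zero_point), 0.0) / scale) + zero_point
    ).clip(0, 255).astype(np.uint8)

    # Relu applied directly in the quantized domain (what the fallback path does).
    clamp_at_zero_point = np.maximum(q, zero_point).clip(0, 255).astype(np.uint8)

    assert np.array_equal(relu_then_requantize, clamp_at_zero_point)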
diff --git a/caffe2/quantization/server/fully_connected_dnnlowp_op.h b/caffe2/quantization/server/fully_connected_dnnlowp_op.h
index 39c6db6..5dd90e1 100644
@@ -6,7 +6,7 @@
 
 namespace caffe2 {
 
-template <typename T>
+template <typename T, bool ReluFused = false>
 class FullyConnectedDNNLowPOp
     : public DNNLowPOp<T, FullyConnectedOp<CPUContext>> {
  public:
diff --git a/caffe2/quantization/server/fully_connected_dnnlowp_op_test.py b/caffe2/quantization/server/fully_connected_dnnlowp_op_test.py
index 064d16b..04f42eb 100644
@@ -30,6 +30,7 @@ class DNNLowPFullyConnectedOpTest(hu.HypothesisTestCase):
         prepack_weight=st.booleans(),
         preserve_activation_sparsity=st.booleans(),
         preserve_weight_sparsity=st.booleans(),
+        fuse_relu=st.booleans(),
         **hu.gcs_cpu_only
     )
     def test_dnnlowp_fully_connected_int(
@@ -43,6 +44,7 @@ class DNNLowPFullyConnectedOpTest(hu.HypothesisTestCase):
         prepack_weight,
         preserve_activation_sparsity,
         preserve_weight_sparsity,
+        fuse_relu,
         gc,
         dc,
     ):
@@ -92,10 +94,17 @@ class DNNLowPFullyConnectedOpTest(hu.HypothesisTestCase):
 
         op_engine_list = [
             ("FC", ""),
-            ("FC", "DNNLOWP"),
-            ("FC", "DNNLOWP_16"),
-            ("Int8FC", "DNNLOWP"),
         ]
+        if fuse_relu:
+            op_engine_list += [
+                ("Int8FCRelu", "DNNLOWP"),
+            ]
+        else:
+            op_engine_list += [
+                ("FC", "DNNLOWP"),
+                ("FC", "DNNLOWP_16"),
+                ("Int8FC", "DNNLOWP"),
+            ]
 
         for op_type, engine in op_engine_list:
             init_net = core.Net("test_init_net")
@@ -173,6 +182,8 @@ class DNNLowPFullyConnectedOpTest(hu.HypothesisTestCase):
                     fc, outputs[0][0], preserve_activation_sparsity
                 )
             net.Proto().op.extend([fc])
+            if fuse_relu and "DNNLOWP" not in engine:
+                net.Relu(["Y"], "Y")
 
             if do_dequantize:
                 dequantize = core.CreateOperator(