pre-pack operation of dnnlowp conv with 16-bit accumulation (#14881)
author Jongsoo Park <jongsoo@fb.com>
Mon, 10 Dec 2018 09:06:17 +0000 (01:06 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 10 Dec 2018 09:08:21 +0000 (01:08 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14881

This diff allows us to pre-quantize and pre-pack the weight matrix used by DNNLOWP_ACC16.
The intended use pattern is to run Int8ConvPackWeight in the init_net, which generates a packed weight, and then have Int8Conv with the DNNLOWP_ACC16 engine use that packed weight.
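
For illustration, a minimal sketch of that pattern in Caffe2 Python, modeled on the test changes in this diff. The blob names (X, W, b, W_packed), tensor shapes, and the static output quantization arguments (Y_scale, Y_zero_point) are illustrative assumptions rather than values from a real model, and the DNNLOWP operator library is assumed to be registered (the unit tests load it via dyndep).

import numpy as np
from caffe2.python import core, workspace
from caffe2.quantization.server import utils as dnnlowp_utils

# Illustrative NHWC tensors; pre-packed weights only work with NHWC layout.
X = np.random.rand(1, 4, 4, 8).astype(np.float32)  # N, H, W, C
W = np.random.rand(8, 3, 3, 8).astype(np.float32)  # M, kH, kW, C
b = np.random.rand(8).astype(np.float32)

# Quantization parameters the input X is expected to use at run time,
# chosen the same way as in the tests.
x_q_param = dnnlowp_utils.choose_quantization_params(X.min(), X.max())

# init_net: quantize and pack the weight (and bias) once.
init_net = core.Net("init_net")
init_net.Proto().op.extend([
    core.CreateOperator(
        "Int8ConvPackWeight",
        ["W", "b"],
        ["W_packed"],
        group=1,
        nbits_in_non_outlier=6,  # outlier-aware packing for 16-bit accumulation
        in_scale=x_q_param.scale,
        engine="DNNLOWP_ACC16",
    )
])

# predict_net: the conv consumes the packed weight instead of the float W.
predict_net = core.Net("predict_net")
predict_net.Proto().op.extend([
    core.CreateOperator("Quantize", ["X"], ["X_q"], engine="DNNLOWP_ACC16"),
    core.CreateOperator(
        "Int8Conv",
        ["X_q", "W_packed", "b"],
        ["Y_q"],
        kernel=3,
        stride=1,
        pad=1,
        order="NHWC",
        group=1,
        Y_scale=0.1,       # assumed static output quantization parameters
        Y_zero_point=0,
        engine="DNNLOWP_ACC16",
    ),
    core.CreateOperator("Dequantize", ["Y_q"], ["Y"], engine="DNNLOWP_ACC16"),
])

for name, blob in [("X", X), ("W", W), ("b", b)]:
    workspace.FeedBlob(name, blob)
workspace.RunNetOnce(init_net)
workspace.RunNetOnce(predict_net)
print(workspace.FetchBlob("Y").shape)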

Reviewed By: csummersea

Differential Revision: D13374662

fbshipit-source-id: dd02b9a4eb7af1fe208aa857fcd0b445e6e395af

26 files changed:
caffe2/quantization/server/CMakeLists.txt
caffe2/quantization/server/conv_depthwise_dnnlowp_op_test.py
caffe2/quantization/server/conv_dnnlowp_acc16_op.cc
caffe2/quantization/server/conv_dnnlowp_acc16_op.h
caffe2/quantization/server/conv_dnnlowp_acc16_op_test.py
caffe2/quantization/server/conv_dnnlowp_op.cc
caffe2/quantization/server/conv_dnnlowp_op.h
caffe2/quantization/server/conv_dnnlowp_op_test.py
caffe2/quantization/server/conv_groupwise_dnnlowp_acc16_op_test.py
caffe2/quantization/server/conv_groupwise_dnnlowp_op_test.py
caffe2/quantization/server/conv_pool_dnnlowp_op_base.h
caffe2/quantization/server/dnnlowp_op.h
caffe2/quantization/server/fbgemm_pack_blob.h [new file with mode: 0644]
caffe2/quantization/server/fbgemm_pack_op.cc [new file with mode: 0644]
caffe2/quantization/server/fbgemm_pack_op.h [new file with mode: 0644]
caffe2/quantization/server/fully_connected_dnnlowp_acc16_op.cc
caffe2/quantization/server/fully_connected_dnnlowp_acc16_op.h
caffe2/quantization/server/fully_connected_dnnlowp_acc16_op_test.py
caffe2/quantization/server/fully_connected_dnnlowp_op.cc
caffe2/quantization/server/fully_connected_dnnlowp_op.h
caffe2/quantization/server/fully_connected_dnnlowp_op_test.py
caffe2/quantization/server/fully_connected_rowwise_dnnlowp_op.cc
caffe2/quantization/server/fully_connected_rowwise_dnnlowp_op.h
caffe2/quantization/server/fully_connected_rowwise_dnnlowp_op_test.py
caffe2/quantization/server/group_norm_dnnlowp_op_test.py
caffe2/quantization/server/utils.py

index 465a466..5d85281 100644 (file)
@@ -26,6 +26,7 @@ list(APPEND Caffe2_CPU_SRCS
   "${CMAKE_CURRENT_SOURCE_DIR}/elementwise_sum_dnnlowp_op.cc"
   "${CMAKE_CURRENT_SOURCE_DIR}/elementwise_sum_relu_op.cc"
   "${CMAKE_CURRENT_SOURCE_DIR}/fbgemm_pack_matrix_cache.cc"
+  "${CMAKE_CURRENT_SOURCE_DIR}/fbgemm_pack_op.cc"
   "${CMAKE_CURRENT_SOURCE_DIR}/fully_connected_dnnlowp_acc16_op.cc"
   "${CMAKE_CURRENT_SOURCE_DIR}/fully_connected_dnnlowp_op.cc"
   "${CMAKE_CURRENT_SOURCE_DIR}/fully_connected_fake_lowp_op.cc"
index 4057437..4c1953d 100644 (file)
@@ -26,6 +26,7 @@ class DNNLowPOpConvDepthWiseTest(hu.HypothesisTestCase):
         # depthwise 3x3 fast path only works for a multiple of 8
         group=st.sampled_from([8, 32, 40]),
         batch_size=st.integers(1, 3),
+        prepack_weight=st.booleans(),
         share_col_buffer=st.booleans(),
         preserve_activation_sparsity=st.booleans(),
         preserve_weight_sparsity=st.booleans(),
@@ -38,6 +39,7 @@ class DNNLowPOpConvDepthWiseTest(hu.HypothesisTestCase):
         size,
         group,
         batch_size,
+        prepack_weight,
         share_col_buffer,
         preserve_activation_sparsity,
         preserve_weight_sparsity,
@@ -84,25 +86,42 @@ class DNNLowPOpConvDepthWiseTest(hu.HypothesisTestCase):
             ]
 
         for op_type, engine in op_engine_list:
+            init_net = core.Net("test_init_net")
             net = core.Net("test_net")
 
             do_quantize = "DNNLOWP" in engine
             do_dequantize = "DNNLOWP" in engine
-
-            preserve_activation_sparsity_int = 1 if preserve_activation_sparsity else 0
-            preserve_weight_sparsity_int = 1 if preserve_weight_sparsity else 0
+            do_prepack_weight = engine == "DNNLOWP" and prepack_weight
 
             if do_quantize:
                 quantize = core.CreateOperator(
                     "Quantize",
                     ["X"],
                     ["X_q"],
-                    preserve_activation_sparsity=preserve_activation_sparsity_int,
+                    preserve_activation_sparsity=preserve_activation_sparsity,
                     engine=engine,
                     device_option=gc,
                 )
                 net.Proto().op.extend([quantize])
 
+            if do_prepack_weight:
+                x_q_param = dnnlowp_utils.choose_quantization_params(
+                    X.min(), X.max(), preserve_activation_sparsity
+                )
+                inputs = ["W"]
+                if do_dequantize:
+                    inputs += ["b"]
+                pack = core.CreateOperator(
+                    "Int8ConvPackWeight",
+                    inputs,
+                    ["W_packed"],
+                    group=group,
+                    preserve_weight_sparsity=preserve_weight_sparsity,
+                    in_scale=x_q_param.scale,
+                    engine=engine,
+                )
+                init_net.Proto().op.extend([pack])
+
             conv = core.CreateOperator(
                 op_type,
                 ["X_q" if do_quantize else "X", "W", "b"],
@@ -113,13 +132,13 @@ class DNNLowPOpConvDepthWiseTest(hu.HypothesisTestCase):
                 pad=pad,
                 order=order,
                 shared_buffer=(1 if share_col_buffer else 0),
-                preserve_activation_sparsity=preserve_activation_sparsity_int,
-                preserve_weight_sparsity=preserve_weight_sparsity_int,
+                preserve_activation_sparsity=preserve_activation_sparsity,
+                preserve_weight_sparsity=preserve_weight_sparsity,
                 engine=engine,
                 group=group,
                 device_option=gc,
             )
-            if do_dequantize:
+            if do_dequantize or do_prepack_weight:
                 dnnlowp_utils.add_quantization_param_args(
                     conv, outputs[0][0], preserve_activation_sparsity
                 )
@@ -139,6 +158,7 @@ class DNNLowPOpConvDepthWiseTest(hu.HypothesisTestCase):
             self.ws.create_blob("X").feed(X, device_option=gc)
             self.ws.create_blob("W").feed(W, device_option=gc)
             self.ws.create_blob("b").feed(b, device_option=gc)
+            self.ws.run(init_net)
             self.ws.run(net)
             Y = self.ws.blobs["Y"].fetch()
             outputs.append(Output(Y=Y, op_type=op_type, engine=engine, order=order))
@@ -151,6 +171,7 @@ class DNNLowPOpConvDepthWiseTest(hu.HypothesisTestCase):
         # depthwise 3x3 fast path only works for a multiple of 8
         group=st.sampled_from([8, 32, 40]),
         batch_size=st.integers(1, 3),
+        prepack_weight=st.booleans(),
         fuse_relu=st.booleans(),
         share_col_buffer=st.booleans(),
         preserve_activation_sparsity=st.booleans(),
@@ -163,6 +184,7 @@ class DNNLowPOpConvDepthWiseTest(hu.HypothesisTestCase):
         size,
         group,
         batch_size,
+        prepack_weight,
         fuse_relu,
         share_col_buffer,
         preserve_activation_sparsity,
@@ -199,31 +221,48 @@ class DNNLowPOpConvDepthWiseTest(hu.HypothesisTestCase):
         op_engine_list = [(op, ""), (op, "DNNLOWP"), ("Int8" + op, "DNNLOWP")]
 
         for op_type, engine in op_engine_list:
+            init_net = core.Net("test_init_net")
             net = core.Net("test_net")
 
+            # TODO: no fall back to NCHW
             fall_back_to_NCHW = "DNNLOWP" not in engine
 
             if fall_back_to_NCHW:
                 X_nchw = nhwc2nchw(X)
                 W_nchw = nhwc2nchw(W)
-
             do_quantize = "DNNLOWP" in engine
             do_dequantize = "DNNLOWP" in engine
-
-            preserve_activation_sparsity_int = 1 if preserve_activation_sparsity else 0
-            preserve_weight_sparsity_int = 1 if preserve_weight_sparsity else 0
+            do_prepack_weight = engine == "DNNLOWP" and prepack_weight
 
             if do_quantize:
                 quantize = core.CreateOperator(
                     "Quantize",
                     ["X"],
                     ["X_q"],
-                    preserve_activation_sparsity=preserve_activation_sparsity_int,
+                    preserve_activation_sparsity=preserve_activation_sparsity,
                     engine=engine,
                     device_option=gc,
                 )
                 net.Proto().op.extend([quantize])
 
+            if do_prepack_weight:
+                x_q_param = dnnlowp_utils.choose_quantization_params(
+                    X.min(), X.max(), preserve_activation_sparsity
+                )
+                inputs = ["W"]
+                if do_dequantize:
+                    inputs += ["b"]
+                pack = core.CreateOperator(
+                    "Int8ConvPackWeight",
+                    inputs,
+                    ["W_packed"],
+                    group=group,
+                    preserve_weight_sparsity=preserve_weight_sparsity,
+                    in_scale=x_q_param.scale,
+                    engine=engine,
+                )
+                init_net.Proto().op.extend([pack])
+
             conv = core.CreateOperator(
                 op_type,
                 ["X_q" if do_quantize else "X", "W", "b"],
@@ -234,13 +273,13 @@ class DNNLowPOpConvDepthWiseTest(hu.HypothesisTestCase):
                 pads=[pad] * (3 * 2),
                 order="NCHW" if fall_back_to_NCHW else order,
                 shared_buffer=(1 if share_col_buffer else 0),
-                preserve_activation_sparsity=preserve_activation_sparsity_int,
-                preserve_weight_sparsity=preserve_weight_sparsity_int,
+                preserve_activation_sparsity=preserve_activation_sparsity,
+                preserve_weight_sparsity=preserve_weight_sparsity,
                 engine=engine,
                 group=group,
                 device_option=gc,
             )
-            if do_dequantize:
+            if do_dequantize or do_prepack_weight:
                 dnnlowp_utils.add_quantization_param_args(
                     conv, outputs[0][0], preserve_activation_sparsity
                 )
@@ -259,6 +298,7 @@ class DNNLowPOpConvDepthWiseTest(hu.HypothesisTestCase):
                 W_nchw if fall_back_to_NCHW else W, device_option=gc
             )
             self.ws.create_blob("b").feed(b, device_option=gc)
+            self.ws.run(init_net)
             self.ws.run(net)
             Y = self.ws.blobs["Y"].fetch()
             if fall_back_to_NCHW:
index f6f14cf..58c4306 100644 (file)
@@ -1,5 +1,4 @@
 #include "conv_dnnlowp_acc16_op.h"
-#include "dnnlowp_op.h"
 
 // #define DNNLOWP_ACC16_IN_SLOW_PATH
 // #define DNNLOWP_MEASURE_TIME_BREAKDOWN
@@ -10,7 +9,9 @@
 #include <omp.h>
 #endif
 
+#include "dnnlowp_op.h"
 #include "dnnlowp_partition.h"
+#include "fbgemm_pack_op.h"
 #include "im2col_dnnlowp.h"
 
 C10_DECLARE_int32(dnnlowp_nbits_in_non_outlier);
@@ -61,6 +62,30 @@ bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
     return false;
   }
 
+  if (!Wq_acc16_packed_ &&
+      this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER)) {
+    CAFFE_ENFORCE_EQ(
+        ConvPoolOpBase<CPUContext>::order_,
+        StorageOrder::NHWC,
+        "Pre-packed weight only works with NHWC layout");
+    // If the input is already packed
+    const auto& packed_filter =
+        this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
+    Wq_outlier_ = packed_filter.W_outlier;
+    Wq_acc16_packed_ = packed_filter.W_acc16;
+
+    if (nbits_in_non_outlier_ != packed_filter.nbits_in_non_outlier) {
+      LOG(WARNING)
+          << "nbits_in_non_outlier in packed weight "
+          << packed_filter.nbits_in_non_outlier
+          << " doesn't match with nbits_in_non_outlier specified in operator "
+          << nbits_in_non_outlier_;
+    }
+
+    first_invocation_ = false;
+    return true;
+  }
+
   int kernel_dim = this->KernelDim_();
   const auto& filter = InputTensorCPU_(FILTER);
   int M = filter.dim32(0);
@@ -71,46 +96,9 @@ bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
       nbits_in_non_outlier_ < 8) {
     CAFFE_ENFORCE(!W_quantized_.empty());
 
-    int outlier_cnt = 0;
-    for (int group_id = 0; group_id < group_; ++group_id) {
-      for (int i = 0; i < (M / group_) * kernel_dim; ++i) {
-        int8_t w = W_quantized_[group_id * (M / group_) * kernel_dim + i];
-        bool is_outlier = nbits_in_non_outlier_ == 0 ||
-            w < -(1 << (nbits_in_non_outlier_ - 1)) ||
-            w >= (1 << (nbits_in_non_outlier_ - 1));
-        if (is_outlier) {
-          ++outlier_cnt;
-        }
-      }
-    }
-
-    Wq_outlier_.reset(new fbgemm::CompressedSparseColumn(kernel_dim, M));
-    Wq_outlier_->RowIdx().resize(outlier_cnt);
-    Wq_outlier_->Values().resize(outlier_cnt);
-
-    outlier_cnt = 0;
-    for (int group_id = 0; group_id < group_; ++group_id) {
-      for (int j = 0; j < M / group_; ++j) {
-        Wq_outlier_->ColPtr()[group_id * (M / group_) + j] = outlier_cnt;
-
-        for (int k = 0; k < kernel_dim; ++k) {
-          int8_t w =
-              W_quantized_[(group_id * (M / group_) + j) * kernel_dim + k];
-          bool is_outlier = nbits_in_non_outlier_ == 0 ||
-              w < -(1 << (nbits_in_non_outlier_ - 1)) ||
-              w >= (1 << (nbits_in_non_outlier_ - 1));
-          if (is_outlier) {
-            CAFFE_ENFORCE_LE(k, numeric_limits<int16_t>::max());
-            Wq_outlier_->RowIdx()[outlier_cnt] = k;
-            Wq_outlier_->Values()[outlier_cnt] = w;
-            ++outlier_cnt;
-
-            W_quantized_[(group_id * (M / group_) + j) * kernel_dim + k] = 0;
-          }
-        }
-      }
-    } // for each group
-    Wq_outlier_->ColPtr()[M] = outlier_cnt;
+    Wq_outlier_.reset(ExtractOutlierMatrix(
+        group_, kernel_dim, M, nbits_in_non_outlier_, W_quantized_));
+    int outlier_cnt = Wq_outlier_->ColPtr()[M];
 
     LOG(INFO) << "Proportion of outlier for Conv layer with weight blob "
               << OperatorBase::debug_def().input(1) << " is "
@@ -255,7 +243,8 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
     } else {
       Y_data = Y->template mutable_data<uint8_t>();
     }
-    this->column_offsets_.resize(output_image_size * dnnlowp_get_max_threads());
+    this->column_offsets_->resize(
+        output_image_size * dnnlowp_get_max_threads());
 
 #ifdef _OPENMP
 #pragma omp parallel for
@@ -496,7 +485,7 @@ void ConvDNNLowPAcc16Op<ReluFused>::DispatchFBGEMM(
       in_qparams_[INPUT].zero_point,
       this->filter_zero_points_.data(),
       packA.getRowOffsetBuffer(),
-      this->column_offsets_.data(),
+      this->column_offsets_->data(),
       InputSize() == 3 ? this->b_quantized_data_ : nullptr,
       M,
       group_);
@@ -569,7 +558,7 @@ void ConvDNNLowPAcc16Op<ReluFused>::ConvOutlier_(
           dnnlowp_get_thread_num());
 
       for (int group_id = group_begin; group_id < group_end; ++group_id) {
-        assert(Wq_outlier_->NumOfRows() == kernel_dim);
+        CAFFE_ENFORCE_EQ(Wq_outlier_->NumOfRows(), kernel_dim);
         // Dense-matrix times sparse-matrix multiplication for outlier
         fbgemm::block_type_t block = {
             0, i_end - i_begin, group_id * (M / group_), M / group_};
index b978129..61440f1 100644 (file)
@@ -52,11 +52,11 @@ class ConvDNNLowPAcc16Op final : public ConvDNNLowPOp<std::uint8_t, ReluFused> {
       const std::uint8_t* col_buffer,
       vector<std::int32_t>* Y_int32);
 
-  std::unique_ptr<fbgemm::PackBMatrix<std::int8_t, std::int16_t>>
+  std::shared_ptr<fbgemm::PackBMatrix<std::int8_t, std::int16_t>>
       Wq_acc16_packed_;
 
   // Wq outlier in CSC format
-  std::unique_ptr<fbgemm::CompressedSparseColumn> Wq_outlier_;
+  std::shared_ptr<fbgemm::CompressedSparseColumn> Wq_outlier_;
 
   // Threshold to decide whether a weight is outlier.
   // For example, if nbits_in_non_outlier_ == 7, w is an outlier if w < -64 or
index 6dc6be8..cfa82d2 100644 (file)
@@ -6,7 +6,6 @@ import caffe2.python.hypothesis_test_util as hu
 import hypothesis.strategies as st
 import numpy as np
 from caffe2.python import core, dyndep
-from caffe2.python.fb import hardcode_scale_zp
 from caffe2.quantization.server import utils as dnnlowp_utils
 from dnnlowp_test_utils import (
     check_quantized_results_close,
@@ -146,8 +145,8 @@ class DNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
                 net.Proto().op.extend([int8_given_tensor_fill])
 
                 # Bias
-                x_q_param = hardcode_scale_zp.choose_quantization_params(
-                    X.min(), X.max()
+                x_q_param = dnnlowp_utils.choose_quantization_params(
+                    X.min(), X.max(), preserve_activation_sparsity
                 )
                 int8_bias_tensor_fill = dnnlowp_utils.create_int8_bias_tensor_fill(
                     b, "b_q", x_q_param, w_q_param
@@ -214,7 +213,8 @@ class DNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
         in_quantized=st.booleans(),
         out_quantized=st.booleans(),
         weight_quantized=st.booleans(),
-        nbits_in_non_outlier=st.sampled_from((0, 6)),
+        prepack_weight=st.booleans(),
+        nbits_in_non_outlier=st.sampled_from((6, 8)),
         share_col_buffer=st.booleans(),
         preserve_activation_sparsity=st.booleans(),
         preserve_weight_sparsity=st.booleans(),
@@ -235,6 +235,7 @@ class DNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
         in_quantized,
         out_quantized,
         weight_quantized,
+        prepack_weight,
         nbits_in_non_outlier,
         share_col_buffer,
         preserve_activation_sparsity,
@@ -291,10 +292,6 @@ class DNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
             W[1, 0, 0, 0] = W_max
             W[..., 1] = W_min + 128
 
-            if order == "NCHW":
-                X = nhwc2nchw(X)
-                W = nhwc2nchw(W)
-
             # No input quantization error in bias
             b = np.round(np.random.randn(output_channels)).astype(np.float32)
 
@@ -308,11 +305,13 @@ class DNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
         ]
 
         for op_type, engine in op_engine_list:
+            init_net = core.Net("test_init_net")
             net = core.Net("test_net")
 
             do_quantize = "DNNLOWP" in engine and in_quantized
             do_dequantize = "DNNLOWP" in engine and out_quantized
             do_quantize_weight = "DNNLOWP" in engine and weight_quantized
+            do_prepack_weight = "DNNLOWP" in engine and prepack_weight
 
             if do_quantize:
                 quantize = core.CreateOperator(
@@ -325,26 +324,44 @@ class DNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
                 )
                 net.Proto().op.extend([quantize])
 
+            x_q_param = dnnlowp_utils.choose_quantization_params(
+                X.min(), X.max(), preserve_activation_sparsity
+            )
             if do_quantize_weight:
                 int8_given_tensor_fill, w_q_param = dnnlowp_utils.create_int8_given_tensor_fill(
                     W, "W_q", preserve_weight_sparsity
                 )
-                net.Proto().op.extend([int8_given_tensor_fill])
+                init_net.Proto().op.extend([int8_given_tensor_fill])
 
                 # Bias
-                x_q_param = hardcode_scale_zp.choose_quantization_params(
-                    X.min(), X.max()
-                )
                 int8_bias_tensor_fill = dnnlowp_utils.create_int8_bias_tensor_fill(
                     b, "b_q", x_q_param, w_q_param
                 )
-                net.Proto().op.extend([int8_bias_tensor_fill])
+                init_net.Proto().op.extend([int8_bias_tensor_fill])
+
+            if do_prepack_weight:
+                inputs = ["W_q" if do_quantize_weight else "W"]
+                if do_dequantize:
+                    inputs += ["b_q" if do_quantize_weight else "b"]
+                pack = core.CreateOperator(
+                    "Int8ConvPackWeight",
+                    inputs,
+                    ["W_packed"],
+                    group=group,
+                    nbits_in_non_outlier=nbits_in_non_outlier,
+                    preserve_weight_sparsity=preserve_weight_sparsity,
+                    in_scale=x_q_param.scale,
+                    engine=engine,
+                )
+                init_net.Proto().op.extend([pack])
 
             conv = core.CreateOperator(
                 op_type,
                 [
                     "X_q" if do_quantize else "X",
-                    "W_q" if do_quantize_weight else "W",
+                    "W_packed"
+                    if do_prepack_weight
+                    else ("W_q" if do_quantize_weight else "W"),
                     "b_q" if do_quantize_weight else "b",
                 ],
                 ["Y_q" if do_dequantize else "Y"],
@@ -362,7 +379,7 @@ class DNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
                 group=group,
                 device_option=gc,
             )
-            if do_dequantize or do_quantize_weight:
+            if do_dequantize or do_quantize_weight or do_prepack_weight:
                 # When quantized weight is provided, we can't rescale the
                 # output dynamically by looking at the range of output of each
                 # batch, so here we provide the range of output observed from
@@ -381,6 +398,7 @@ class DNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
             self.ws.create_blob("X").feed(X, device_option=gc)
             self.ws.create_blob("W").feed(W, device_option=gc)
             self.ws.create_blob("b").feed(b, device_option=gc)
+            self.ws.run(init_net)
             self.ws.run(net)
             Y = self.ws.blobs["Y"].fetch()
             outputs.append(Output(Y=Y, op_type=op_type, engine=engine, order=order))
index 0cbb1d7..3df909c 100644 (file)
@@ -1,5 +1,4 @@
 #include "conv_dnnlowp_op.h"
-#include "dnnlowp_op.h"
 
 // #define DNNLOWP_MEASURE_TIME_BREAKDOWN
 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
@@ -15,7 +14,9 @@
 
 #include <fbgemm/src/RefImplementations.h>
 
+#include "dnnlowp_op.h"
 #include "dnnlowp_partition.h"
+#include "fbgemm_pack_op.h"
 #include "im2col_dnnlowp.h"
 #include "mmio.h"
 
@@ -40,7 +41,9 @@ template <typename T, bool ReluFused>
 ConvDNNLowPOp<T, ReluFused>::ConvDNNLowPOp(
     const OperatorDef& operator_def,
     Workspace* ws)
-    : BaseType(operator_def, ws) {
+    : BaseType(operator_def, ws),
+      column_offsets_(make_shared<vector<int32_t>>()),
+      b_quantized_(make_shared<vector<int32_t>>()) {
   in_qparams_.resize(1);
 
   // Create shared buffer mutex in the constructor
@@ -99,7 +102,7 @@ bool ConvDNNLowPOp<T, ReluFused>::TakeDepthWise3x3FastPath_() {
   const Tensor& X = InputTensorCPU_(INPUT);
   return StorageOrder::NHWC == ConvPoolOpBase<CPUContext>::order_ &&
       is_same<T, uint8_t>::value && X.template IsType<T>() &&
-      OperatorBase::debug_def().engine() != "DNNLOWP_ACC16" &&
+      this->debug_def().engine() != "DNNLOWP_ACC16" &&
       group_ == X.dim32(X.dim() - 1) && group_ % 8 == 0 &&
       this->kernel_.size() == 2 && kernel_h() == 3 && kernel_w() == 3 &&
       stride_h() == stride_w() && (stride_h() == 1 || stride_h() == 2) &&
@@ -113,7 +116,7 @@ bool ConvDNNLowPOp<T, ReluFused>::TakeDepthWise3x3x3FastPath_() {
   const Tensor& X = InputTensorCPU_(INPUT);
   bool ret = StorageOrder::NHWC == ConvPoolOpBase<CPUContext>::order_ &&
       is_same<T, uint8_t>::value && X.template IsType<T>() &&
-      OperatorBase::debug_def().engine() != "DNNLOWP_ACC16" &&
+      this->debug_def().engine() != "DNNLOWP_ACC16" &&
       group_ == X.dim32(X.dim() - 1) && group_ % 8 == 0 &&
       this->kernel_.size() == 3 && this->kernel_[0] == 3 &&
       this->kernel_[1] == 3 && this->kernel_[2] == 3 &&
@@ -201,20 +204,16 @@ void ConvDNNLowPOp<T, ReluFused>::PreComputeRowColumnOffsets_() {
   vector<int>& offsets =
       StorageOrder::NCHW == ConvPoolOpBase<CPUContext>::order_
       ? row_offsets_
-      : column_offsets_;
+      : *column_offsets_;
 
   if (offsets.empty()) {
-    offsets.resize(M);
-    for (int g = 0; g < filter_qparams_.size(); ++g) {
-      int i_begin = g * (M / filter_qparams_.size());
-      int i_end = i_begin + (M / filter_qparams_.size());
-      for (int i = i_begin; i < i_end; ++i) {
-        int32_t sum = 0;
-        for (int k = 0; k < kernel_dim; ++k) {
-          sum += W_quantized_[i * kernel_dim + k];
-        }
-        offsets[i] = sum - FilterQuantizationParams(g).zero_point * kernel_dim;
-      }
+    if (this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER)) {
+      const auto& packed_filter =
+          this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
+      column_offsets_ = packed_filter.column_offsets;
+    } else {
+      ComputeColumnOffsets<T_signed>(
+          kernel_dim, M, W_quantized_.data(), filter_qparams_, offsets);
     }
   }
 }
@@ -230,50 +229,61 @@ void ConvDNNLowPOp<T, ReluFused>::QuantizeBias_() {
   if (InputSize() == 3 &&
       ((!b_quantized_data_ && !b_dequantized_data_) ||
        in_qparams_[INPUT].scale != in_qparams_scale_old_)) {
-    const auto& bias = InputTensorCPU_(BIAS);
-    if (OperatorBase::InputIsType<int8::Int8TensorCPU>(BIAS)) {
-      TensorQuantizationParams bias_qparams;
-      bias_qparams.scale = OperatorBase::Input<int8::Int8TensorCPU>(BIAS).scale;
-      bias_qparams.zero_point =
-          OperatorBase::Input<int8::Int8TensorCPU>(BIAS).zero_point;
-      CAFFE_ENFORCE_LE(
-          std::abs(
-              bias_qparams.scale -
-              in_qparams_[INPUT].scale * FilterQuantizationParams(0).scale),
-          1e-4);
-      CAFFE_ENFORCE_EQ(bias_qparams.zero_point, 0);
-      b_quantized_data_ = bias.template data<int32_t>();
-      if (dequantize_output_) {
-        b_dequantized_.resize(bias.numel());
+    if (this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER) &&
+        this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER)
+            .bias.get()) {
+      const auto& packed_filter =
+          this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
+      CAFFE_ENFORCE(!dequantize_output_);
+      b_quantized_ = packed_filter.bias;
+      b_quantized_data_ = b_quantized_->data();
+    } else {
+      const auto& bias = InputTensorCPU_(BIAS);
+      if (OperatorBase::InputIsType<int8::Int8TensorCPU>(BIAS)) {
+        TensorQuantizationParams bias_qparams;
+        bias_qparams.scale =
+            OperatorBase::Input<int8::Int8TensorCPU>(BIAS).scale;
+        bias_qparams.zero_point =
+            OperatorBase::Input<int8::Int8TensorCPU>(BIAS).zero_point;
+        CAFFE_ENFORCE_LE(
+            std::abs(
+                bias_qparams.scale -
+                in_qparams_[INPUT].scale * FilterQuantizationParams(0).scale),
+            1e-4);
+        CAFFE_ENFORCE_EQ(bias_qparams.zero_point, 0);
+        b_quantized_data_ = bias.template data<int32_t>();
+        if (dequantize_output_) {
+          b_dequantized_.resize(bias.numel());
 #ifdef _OPENMP
 #pragma omp parallel for
 #endif
-        for (int i = 0; i < b_dequantized_.size(); ++i) {
-          b_dequantized_[i] =
-              fbgemm::Dequantize<int32_t>(b_quantized_data_[i], bias_qparams);
+          for (int i = 0; i < b_dequantized_.size(); ++i) {
+            b_dequantized_[i] =
+                fbgemm::Dequantize<int32_t>(b_quantized_data_[i], bias_qparams);
+          }
+          b_dequantized_data_ = b_dequantized_.data();
         }
-        b_dequantized_data_ = b_dequantized_.data();
-      }
-    } else {
-      b_dequantized_data_ = bias.template data<float>();
-      if (!dequantize_output_) {
-        b_quantized_.resize(bias.numel());
-        for (int g = 0; g < filter_qparams_.size(); ++g) {
-          int i_begin = g * (M / filter_qparams_.size());
-          int i_end = i_begin + (M / filter_qparams_.size());
-          for (int i = i_begin; i < i_end; ++i) {
-            b_quantized_[i] = fbgemm::Quantize<int32_t>(
-                b_dequantized_data_[i],
-                0,
-                in_qparams_[INPUT].scale * FilterQuantizationParams(g).scale,
-                32,
-                true /* signed */);
+      } else {
+        b_dequantized_data_ = bias.template data<float>();
+        if (!dequantize_output_) {
+          b_quantized_->resize(bias.numel());
+          for (int g = 0; g < filter_qparams_.size(); ++g) {
+            int i_begin = g * (M / filter_qparams_.size());
+            int i_end = i_begin + (M / filter_qparams_.size());
+            for (int i = i_begin; i < i_end; ++i) {
+              (*b_quantized_)[i] = fbgemm::Quantize<int32_t>(
+                  b_dequantized_data_[i],
+                  0,
+                  in_qparams_[INPUT].scale * FilterQuantizationParams(g).scale,
+                  32,
+                  true /* signed */);
+            }
           }
+          b_quantized_data_ = b_quantized_->data();
         }
-        b_quantized_data_ = b_quantized_.data();
       }
+      in_qparams_scale_old_ = in_qparams_[INPUT].scale;
     }
-    in_qparams_scale_old_ = in_qparams_[INPUT].scale;
 
     CAFFE_ENFORCE(
         (dequantize_output_ && b_dequantized_data_) ||
@@ -307,84 +317,80 @@ void ConvDNNLowPOp<T, ReluFused>::QuantizeWeight_() {
   if ((depthwise_3x3_fast_path && !Wq_depthwise_3x3_packed_) ||
       (depthwise_3x3x3_fast_path && !Wq_depthwise_3x3x3_packed_) ||
       (packW && !Wq_packed_) || (!packW && W_quantized_.empty())) {
-    W_quantized_.resize(filter.numel());
-    if (quantize_groupwise_) {
-      filter_qparams_.resize(group_);
-      filter_scales_.resize(group_);
-      filter_zero_points_.resize(group_);
-      requantization_params_.resize(group_);
-      requantization_multipliers_.resize(group_);
+    if (this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER)) {
+      CAFFE_ENFORCE_EQ(
+          ConvPoolOpBase<CPUContext>::order_,
+          StorageOrder::NHWC,
+          "Pre-packed weight only works with NHWC layout");
+
+      const auto& packed_filter =
+          this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
+      filter_qparams_ = packed_filter.qparams;
     } else {
-      filter_qparams_.resize(1);
-      filter_scales_.resize(1);
-      filter_zero_points_.resize(1);
-      requantization_params_.resize(1);
-      requantization_multipliers_.resize(1);
-    }
+      filter_qparams_.resize(quantize_groupwise_ ? group_ : 1);
+      QuantizeWeight<T>(
+          InputBlob(FILTER),
+          kernel_dim,
+          M,
+          filter_qparams_,
+          W_quantized_,
+          qfactory_.get());
 
-    int signed_min = 1 << (qfactory_->GetWeightPrecision() - 1);
-    if (OperatorBase::InputIsType<int8::Int8TensorCPU>(FILTER)) {
-      if (quantize_groupwise_) {
+      if (this->template InputIsType<int8::Int8TensorCPU>(FILTER) &&
+          quantize_groupwise_) {
         static int log_occurences = 0;
         if (log_occurences < 32) {
           ++log_occurences;
           LOG(WARNING) << "Cannot do group-wise quantization for "
                           "pre-quantized weight "
-                       << OperatorBase::debug_def().input(FILTER);
+                       << this->debug_def().input(FILTER);
         }
       }
-      FilterQuantizationParams(0).scale =
-          OperatorBase::Input<int8::Int8TensorCPU>(FILTER).scale;
-      FilterQuantizationParams(0).zero_point =
-          OperatorBase::Input<int8::Int8TensorCPU>(FILTER).zero_point -
-          signed_min;
-
-      const auto& W = InputTensorCPU_(FILTER);
-      const T* W_data = W.template data<T>();
-      for (auto i = 0; i < W.numel(); ++i) {
-        W_quantized_[i] = W_data[i] - signed_min;
-      }
-    } else {
-      for (int g = 0; g < filter_qparams_.size(); ++g) {
-        size_t offset = g * (M / filter_qparams_.size()) * kernel_dim;
-        filter_qparams_[g] = qfactory_->ChooseQuantizationParams(
-            filter.template data<float>() + offset,
-            (M / filter_qparams_.size()) * kernel_dim,
-            true /*weight*/);
-
-        // filter_qparams_[g] is computed for unsigned type.
-        // Adjust for the fact that weight will actually use signed.
-        FilterQuantizationParams(g).zero_point -= signed_min;
-
-        fbgemm::Quantize<T_signed>(
-            filter.template data<float>() + offset,
-            W_quantized_.data() + offset,
-            (M / filter_qparams_.size()) * kernel_dim,
-            FilterQuantizationParams(g));
-      }
     }
 
+    filter_scales_.resize(filter_qparams_.size());
+    filter_zero_points_.resize(filter_qparams_.size());
+    requantization_params_.resize(filter_qparams_.size());
+    requantization_multipliers_.resize(filter_qparams_.size());
     for (int i = 0; i < filter_qparams_.size(); ++i) {
       filter_scales_[i] = filter_qparams_[i].scale;
       filter_zero_points_[i] = filter_qparams_[i].zero_point;
     }
 
     if (depthwise_3x3_fast_path) {
-      Wq_depthwise_3x3_packed_.reset(new fbgemm::Packed3x3ConvMatrix(
-          group_, reinterpret_cast<const int8_t*>(W_quantized_.data())));
+      if (this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER)) {
+        const auto& packed_filter =
+            this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
+        Wq_depthwise_3x3_packed_ = packed_filter.W_depthwise_3x3;
+      } else {
+        Wq_depthwise_3x3_packed_.reset(new fbgemm::Packed3x3ConvMatrix(
+            group_, reinterpret_cast<const int8_t*>(W_quantized_.data())));
+      }
     } else if (depthwise_3x3x3_fast_path) {
-      Wq_depthwise_3x3x3_packed_.reset(new fbgemm::Packed3x3x3ConvMatrix(
-          group_, reinterpret_cast<const int8_t*>(W_quantized_.data())));
+      if (this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER)) {
+        const auto& packed_filter =
+            this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
+        Wq_depthwise_3x3x3_packed_ = packed_filter.W_depthwise_3x3x3;
+      } else {
+        Wq_depthwise_3x3x3_packed_.reset(new fbgemm::Packed3x3x3ConvMatrix(
+            group_, reinterpret_cast<const int8_t*>(W_quantized_.data())));
+      }
     } else if (packW) {
-      // fast path using fbgemm
-      Wq_packed_.reset(new fbgemm::PackBMatrix<int8_t>(
-          fbgemm::matrix_op_t::Transpose,
-          group_ * kernel_dim,
-          M / group_,
-          reinterpret_cast<const int8_t*>(W_quantized_.data()),
-          kernel_dim, // ld
-          nullptr, // pmat
-          group_));
+      if (this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER)) {
+        const auto& packed_filter =
+            this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
+        Wq_packed_ = packed_filter.W;
+      } else {
+        // fast path using fbgemm
+        Wq_packed_.reset(new fbgemm::PackBMatrix<int8_t>(
+            fbgemm::matrix_op_t::Transpose,
+            group_ * kernel_dim,
+            M / group_,
+            reinterpret_cast<const int8_t*>(W_quantized_.data()),
+            kernel_dim, // ld
+            nullptr, // pmat
+            group_));
+      }
     } else {
       string reason;
       if (ConvPoolOpBase<CPUContext>::order_ != StorageOrder::NHWC) {
@@ -499,7 +505,7 @@ void ConvDNNLowPOp<T, ReluFused>::RunOnDeviceEpilogueNCHW_(
   // See batch_matmul_dnnlowp_op.cc to why we compute column_offsets,
   // row_offset, and const_offset in this way.
   int tid = dnnlowp_get_thread_num();
-  int32_t* column_offsets = column_offsets_.data() + tid * Y_HxW;
+  int32_t* column_offsets = column_offsets_->data() + tid * Y_HxW;
 
   const dnnlowp::TensorQuantizationParams& filter_qparams =
       FilterQuantizationParams(group_id);
@@ -622,7 +628,7 @@ bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
   } else {
     Y_data_T = Y->template mutable_data<T>();
   }
-  column_offsets_.resize(Y_HxW * dnnlowp_get_max_threads());
+  column_offsets_->resize(Y_HxW * dnnlowp_get_max_threads());
 
   auto f = [&](Tensor* col_buffer) {
     col_buffer->Resize(buffer_shape);
@@ -782,7 +788,8 @@ void ConvDNNLowPOp<T, ReluFused>::RunOnDeviceEpilogueNHWC_(
         for (int j = group_id * (M / group_); j < (group_id + 1) * (M / group_);
              ++j) {
           Y_int32[i * M + j] -=
-              in_qparams_[INPUT].zero_point * column_offsets_[j] + row_offset;
+              in_qparams_[INPUT].zero_point * (*column_offsets_)[j] +
+              row_offset;
           Ydata[i * M + j] = Y_int32[i * M + j] * in_qparams_[INPUT].scale *
                   FilterQuantizationParams(group_id).scale +
               ((InputSize() == 3) ? b_dequantized_data_[j] : 0.f);
@@ -828,7 +835,7 @@ void ConvDNNLowPOp<T, ReluFused>::RunOnDeviceEpilogueNHWC_(
                j < (group_id + 1) * (M / group_);
                ++j) {
             int32_t raw = Y_int32[i * M + j] -
-                A_zero_point * column_offsets_[j] - row_offset;
+                A_zero_point * (*column_offsets_)[j] - row_offset;
             if (b_quantized_data_) {
               raw += b_quantized_data_[j];
             }
@@ -891,7 +898,7 @@ void ConvDNNLowPOp<T, ReluFused>::RunOnDeviceEpilogueNHWC_(
               A_zero_point,
               &B_zero_point,
               &row_offset,
-              column_offsets_.data() + group_id * (M / group_),
+              column_offsets_->data() + group_id * (M / group_),
               b_quantized_data_ ? b_quantized_data_ + group_id * (M / group_)
                                 : nullptr,
               M / group_,
@@ -916,7 +923,7 @@ void ConvDNNLowPOp<T, ReluFused>::RunOnDeviceEpilogueNHWC_(
                j < (group_id + 1) * (M / group_);
                ++j) {
             int32_t raw = Y_int32[i * M + j] -
-                A_zero_point * column_offsets_[j] - row_offset;
+                A_zero_point * (*column_offsets_)[j] - row_offset;
             if (b_quantized_data_) {
               raw += b_quantized_data_[j];
             }
@@ -1093,7 +1100,7 @@ void ConvDNNLowPOp<T, ReluFused>::DispatchFBGEMM(
         in_qparams_[INPUT].zero_point,
         filter_zero_points_.data(),
         packA.getRowOffsetBuffer(),
-        column_offsets_.data(),
+        column_offsets_->data(),
         InputSize() == 3 ? b_quantized_data_ : nullptr,
         M,
         group_);
@@ -1116,7 +1123,7 @@ void ConvDNNLowPOp<T, ReluFused>::DispatchFBGEMM(
         in_qparams_[INPUT].zero_point,
         filter_zero_points_.data(),
         packA.getRowOffsetBuffer(),
-        column_offsets_.data(),
+        column_offsets_->data(),
         InputSize() == 3 ? b_dequantized_data_ : nullptr,
         M,
         group_);
@@ -1187,7 +1194,7 @@ void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
         requantization_params_[0].real_multiplier,
         out_qparams_.zero_point,
         Y_uint8_data,
-        column_offsets_.data(),
+        column_offsets_->data(),
         b_quantized_data_,
         ReluFused,
         dnnlowp_get_thread_num(),
@@ -1217,7 +1224,7 @@ void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
         requantization_params_[0].real_multiplier,
         out_qparams_.zero_point,
         Y_uint8_data,
-        column_offsets_.data(),
+        column_offsets_->data(),
         b_quantized_data_,
         dnnlowp_get_thread_num(),
         dnnlowp_get_num_threads(),
index 1c9c8ee..7807035 100644 (file)
@@ -68,7 +68,7 @@ class ConvDNNLowPOp : public ConvPoolDNNLowPOpBase<T, ConvFp32Op> {
   std::vector<T_signed> W_quantized_;
 
   // pre-computed biases and offsets
-  std::vector<std::int32_t> column_offsets_;
+  std::shared_ptr<std::vector<std::int32_t>> column_offsets_;
   std::vector<std::int32_t> row_offsets_;
   const std::int32_t* b_quantized_data_{nullptr};
 
@@ -122,15 +122,15 @@ class ConvDNNLowPOp : public ConvPoolDNNLowPOpBase<T, ConvFp32Op> {
   std::vector<dnnlowp::RequantizationParams> requantization_params_;
 
   // used in fast path for T == uint8_t
-  std::unique_ptr<fbgemm::PackBMatrix<std::int8_t>> Wq_packed_;
+  std::shared_ptr<fbgemm::PackBMatrix<std::int8_t>> Wq_packed_;
 
   // For depthwise 3x3 conv
-  std::unique_ptr<fbgemm::Packed3x3ConvMatrix> Wq_depthwise_3x3_packed_;
+  std::shared_ptr<fbgemm::Packed3x3ConvMatrix> Wq_depthwise_3x3_packed_;
   // For depthwise 3x3x3 conv
-  std::unique_ptr<fbgemm::Packed3x3x3ConvMatrix> Wq_depthwise_3x3x3_packed_;
+  std::shared_ptr<fbgemm::Packed3x3x3ConvMatrix> Wq_depthwise_3x3x3_packed_;
 
   // pre-computed biases and offsets
-  std::vector<std::int32_t> b_quantized_;
+  std::shared_ptr<std::vector<std::int32_t>> b_quantized_;
 
   // Dequantized bias populated when input bias is quantized and
   // dequantized_output_ == true
index 0bc5b75..3678780 100644 (file)
@@ -5,7 +5,6 @@ import collections
 import caffe2.python.hypothesis_test_util as hu
 import hypothesis.strategies as st
 from caffe2.python import core, dyndep
-from caffe2.python.fb import hardcode_scale_zp
 from caffe2.quantization.server import utils as dnnlowp_utils
 from dnnlowp_test_utils import (
     check_quantized_results_close,
@@ -36,6 +35,7 @@ class DNNLowPOpConvTest(hu.HypothesisTestCase):
         in_quantized=st.booleans(),
         out_quantized=st.booleans(),
         weight_quantized=st.booleans(),
+        prepack_weight=st.booleans(),
         share_col_buffer=st.booleans(),
         preserve_activation_sparsity=st.booleans(),
         preserve_weight_sparsity=st.booleans(),
@@ -56,6 +56,7 @@ class DNNLowPOpConvTest(hu.HypothesisTestCase):
         in_quantized,
         out_quantized,
         weight_quantized,
+        prepack_weight,
         share_col_buffer,
         preserve_activation_sparsity,
         preserve_weight_sparsity,
@@ -63,6 +64,7 @@ class DNNLowPOpConvTest(hu.HypothesisTestCase):
         dc,
     ):
         assume(group == 1 or dilation == 1)
+        assume((not prepack_weight) or order == "NHWC")
 
         X, W, b = generate_conv_inputs(
             stride,
@@ -90,6 +92,7 @@ class DNNLowPOpConvTest(hu.HypothesisTestCase):
         ]
 
         for op_type, engine in op_engine_list:
+            init_net = core.Net("test_init_net")
             net = core.Net("test_net")
 
             do_quantize = "DNNLOWP" in engine and in_quantized
@@ -101,6 +104,7 @@ class DNNLowPOpConvTest(hu.HypothesisTestCase):
             do_quantize_weight = (
                 engine == "DNNLOWP" and weight_quantized and len(outputs) > 0
             )
+            do_prepack_weight = engine == "DNNLOWP" and prepack_weight
 
             if do_quantize:
                 quantize = core.CreateOperator(
@@ -113,26 +117,41 @@ class DNNLowPOpConvTest(hu.HypothesisTestCase):
                 )
                 net.Proto().op.extend([quantize])
 
+            x_q_param = dnnlowp_utils.choose_quantization_params(X.min(), X.max(), preserve_activation_sparsity)  # noqa
             if do_quantize_weight:
                 int8_given_tensor_fill, w_q_param = dnnlowp_utils.create_int8_given_tensor_fill(
                     W, "W_q", preserve_weight_sparsity
                 )
-                net.Proto().op.extend([int8_given_tensor_fill])
+                init_net.Proto().op.extend([int8_given_tensor_fill])
 
                 # Bias
-                x_q_param = hardcode_scale_zp.choose_quantization_params(
-                    X.min(), X.max()
-                )
                 int8_bias_tensor_fill = dnnlowp_utils.create_int8_bias_tensor_fill(
                     b, "b_q", x_q_param, w_q_param
                 )
-                net.Proto().op.extend([int8_bias_tensor_fill])
+                init_net.Proto().op.extend([int8_bias_tensor_fill])
+
+            if do_prepack_weight:
+                inputs = ["W_q" if do_quantize_weight else "W"]
+                if do_dequantize:
+                    inputs += ["b_q" if do_quantize_weight else "b"]
+                pack = core.CreateOperator(
+                    "Int8ConvPackWeight",
+                    inputs,
+                    ["W_packed"],
+                    group=group,
+                    preserve_weight_sparsity=preserve_weight_sparsity,
+                    in_scale=x_q_param.scale,
+                    engine=engine,
+                )
+                init_net.Proto().op.extend([pack])
 
             conv = core.CreateOperator(
                 op_type,
                 [
                     "X_q" if do_quantize else "X",
-                    "W_q" if do_quantize_weight else "W",
+                    "W_packed"
+                    if do_prepack_weight
+                    else ("W_q" if do_quantize_weight else "W"),
                     "b_q" if do_quantize_weight else "b",
                 ],
                 ["Y_q" if do_dequantize else "Y"],
@@ -149,7 +168,7 @@ class DNNLowPOpConvTest(hu.HypothesisTestCase):
                 group=group,
                 device_option=gc,
             )
-            if do_quantize_weight:
+            if do_quantize_weight or do_prepack_weight:
                 # When quantized weight is provided, we can't rescale the
                 # output dynamically by looking at the range of output of each
                 # batch, so here we provide the range of output observed from
@@ -168,6 +187,7 @@ class DNNLowPOpConvTest(hu.HypothesisTestCase):
             self.ws.create_blob("X").feed(X, device_option=gc)
             self.ws.create_blob("W").feed(W, device_option=gc)
             self.ws.create_blob("b").feed(b, device_option=gc)
+            self.ws.run(init_net)
             self.ws.run(net)
             Y = self.ws.blobs["Y"].fetch()
             outputs.append(Output(Y=Y, op_type=op_type, engine=engine, order=order))
@@ -302,10 +322,12 @@ class DNNLowPOpConvTest(hu.HypothesisTestCase):
         output_channels_per_group,
         batch_size,
         order,
+        prepack_weight,
         gc,
         dc,
     ):
         assume(group == 1 or dilation == 1)
+        assume((not prepack_weight) or order == "NHWC")
         ndim = len(kernels)
 
         X, W, b = generate_convnd_inputs(
@@ -327,6 +349,7 @@ class DNNLowPOpConvTest(hu.HypothesisTestCase):
         op_engine_list = [("Conv", ""), ("Conv", "DNNLOWP_16"), ("Int8Conv", "DNNLOWP")]
 
         for op_type, engine in op_engine_list:
+            init_net = core.Net("test_init_net")
             net = core.Net("test_net")
 
             fall_back_to_NCHW = "DNNLOWP" not in engine and order == "NHWC"
@@ -342,6 +365,7 @@ class DNNLowPOpConvTest(hu.HypothesisTestCase):
             # Make sure atleast one output is collected to compute output
             # scale/zp.
             do_quantize_weight = engine == "DNNLOWP" and len(outputs) > 0
+            do_prepack_weight = engine == "DNNLOWP" and prepack_weight
 
             if do_quantize:
                 quantize = core.CreateOperator(
@@ -349,26 +373,40 @@ class DNNLowPOpConvTest(hu.HypothesisTestCase):
                 )
                 net.Proto().op.extend([quantize])
 
+            x_q_param = dnnlowp_utils.choose_quantization_params(X.min(), X.max())
             if do_quantize_weight:
                 int8_given_tensor_fill, w_q_param = dnnlowp_utils.create_int8_given_tensor_fill(
                     W, "W_q"
                 )
-                net.Proto().op.extend([int8_given_tensor_fill])
+                init_net.Proto().op.extend([int8_given_tensor_fill])
 
                 # Bias
-                x_q_param = hardcode_scale_zp.choose_quantization_params(
-                    X.min(), X.max()
-                )
                 int8_bias_tensor_fill = dnnlowp_utils.create_int8_bias_tensor_fill(
                     b, "b_q", x_q_param, w_q_param
                 )
-                net.Proto().op.extend([int8_bias_tensor_fill])
+                init_net.Proto().op.extend([int8_bias_tensor_fill])
+
+            if do_prepack_weight:
+                inputs = ["W_q" if do_quantize_weight else "W"]
+                if do_dequantize:
+                    inputs += ["b_q" if do_quantize_weight else "b"]
+                pack = core.CreateOperator(
+                    "Int8ConvPackWeight",
+                    inputs,
+                    ["W_packed"],
+                    group=group,
+                    in_scale=x_q_param.scale,
+                    engine=engine,
+                )
+                init_net.Proto().op.extend([pack])
 
             conv = core.CreateOperator(
                 op_type,
                 [
                     "X_q" if do_quantize else "X",
-                    "W_q" if do_quantize_weight else "W",
+                    "W_packed"
+                    if do_prepack_weight
+                    else ("W_q" if do_quantize_weight else "W"),
                     "b_q" if do_quantize_weight else "b",
                 ],
                 ["Y_q" if do_dequantize else "Y"],
@@ -382,7 +420,7 @@ class DNNLowPOpConvTest(hu.HypothesisTestCase):
                 group=group,
                 device_option=gc,
             )
-            if do_quantize_weight:
+            if do_quantize_weight or do_prepack_weight:
                 # When quantized weight is provided, we can't rescale the
                 # output dynamically by looking at the range of output of each
                 # batch, so here we provide the range of output observed from
@@ -403,6 +441,7 @@ class DNNLowPOpConvTest(hu.HypothesisTestCase):
                 W_nchw if fall_back_to_NCHW else W, device_option=gc
             )
             self.ws.create_blob("b").feed(b, device_option=gc)
+            self.ws.run(init_net)
             self.ws.run(net)
             Y = self.ws.blobs["Y"].fetch()
             if fall_back_to_NCHW:
@@ -423,6 +462,7 @@ class DNNLowPOpConvTest(hu.HypothesisTestCase):
         output_channels_per_group=st.sampled_from([2, 3]),
         batch_size=st.integers(1, 2),
         order=st.sampled_from(["NCHW", "NHWC"]),
+        prepack_weight=st.booleans(),
         **hu.gcs_cpu_only
     )
     def test_dnnlowp_conv3d_int(
@@ -438,6 +478,7 @@ class DNNLowPOpConvTest(hu.HypothesisTestCase):
         output_channels_per_group,
         batch_size,
         order,
+        prepack_weight,
         gc,
         dc,
     ):
@@ -452,6 +493,7 @@ class DNNLowPOpConvTest(hu.HypothesisTestCase):
             output_channels_per_group,
             batch_size,
             order,
+            prepack_weight,
             gc,
             dc,
         )
@@ -467,6 +509,7 @@ class DNNLowPOpConvTest(hu.HypothesisTestCase):
         output_channels_per_group=st.sampled_from([2, 3]),
         batch_size=st.integers(1, 2),
         order=st.sampled_from(["NCHW", "NHWC"]),
+        prepack_weight=st.booleans(),
         **hu.gcs_cpu_only
     )
     def test_dnnlowp_conv1d_int(
@@ -481,6 +524,7 @@ class DNNLowPOpConvTest(hu.HypothesisTestCase):
         output_channels_per_group,
         batch_size,
         order,
+        prepack_weight,
         gc,
         dc,
     ):
@@ -495,6 +539,7 @@ class DNNLowPOpConvTest(hu.HypothesisTestCase):
             output_channels_per_group,
             batch_size,
             order,
+            prepack_weight,
             gc,
             dc,
         )
index 3eeae96..032c04f 100644 (file)
@@ -192,7 +192,8 @@ class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
         order=st.sampled_from(["NHWC"]),
         in_quantized=st.booleans(),
         out_quantized=st.booleans(),
-        nbits_in_non_outlier=st.sampled_from((0, 6)),
+        prepack_weight=st.booleans(),
+        nbits_in_non_outlier=st.sampled_from((6, 8)),
         share_col_buffer=st.booleans(),
         **hu.gcs_cpu_only
     )
@@ -210,6 +211,7 @@ class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
         order,
         in_quantized,
         out_quantized,
+        prepack_weight,
         nbits_in_non_outlier,
         share_col_buffer,
         gc,
@@ -280,10 +282,12 @@ class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
         ]
 
         for op_type, engine in op_engine_list:
+            init_net = core.Net("test_init_net")
             net = core.Net("test_net")
 
             do_quantize = "DNNLOWP" in engine and in_quantized
             do_dequantize = "DNNLOWP" in engine and out_quantized
+            do_prepack_weight = "DNNLOWP" in engine and prepack_weight
 
             if do_quantize:
                 quantize = core.CreateOperator(
@@ -291,9 +295,30 @@ class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
                 )
                 net.Proto().op.extend([quantize])
 
+            if do_prepack_weight:
+                x_q_param = dnnlowp_utils.choose_quantization_params(X.min(), X.max())
+                inputs = ["W"]
+                if do_dequantize:
+                    inputs += ["b"]
+                pack = core.CreateOperator(
+                    "Int8ConvPackWeight",
+                    inputs,
+                    ["W_packed"],
+                    group=group,
+                    quantize_groupwise=1,
+                    nbits_in_non_outlier=nbits_in_non_outlier,
+                    in_scale=x_q_param.scale,
+                    engine=engine,
+                )
+                init_net.Proto().op.extend([pack])
+
             conv = core.CreateOperator(
                 op_type,
-                ["X_q" if do_quantize else "X", "W", "b"],
+                [
+                    "X_q" if do_quantize else "X",
+                    "W_packed" if do_prepack_weight else "W",
+                    "b",
+                ],
                 ["Y_q" if do_dequantize else "Y"],
                 stride=stride,
                 kernel=kernel,
@@ -308,7 +333,7 @@ class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
                 quantize_groupwise=1,
                 device_option=gc,
             )
-            if do_dequantize:
+            if do_dequantize or do_prepack_weight:
                 # groupwise quantization only works with static quantization
                 # so we need to set quantization parameters
                 dnnlowp_utils.add_quantization_param_args(conv, outputs[0][0])
@@ -323,6 +348,7 @@ class GroupWiseDNNLowPOpConvAcc16OpTest(hu.HypothesisTestCase):
             self.ws.create_blob("X").feed(X, device_option=gc)
             self.ws.create_blob("W").feed(W, device_option=gc)
             self.ws.create_blob("b").feed(b, device_option=gc)
+            self.ws.run(init_net)
             self.ws.run(net)
             Y = self.ws.blobs["Y"].fetch()
             outputs.append(Output(Y=Y, op_type=op_type, engine=engine, order=order))
index e20bbee..9ca2461 100644 (file)
@@ -5,11 +5,9 @@ import collections
 import caffe2.python.hypothesis_test_util as hu
 import hypothesis.strategies as st
 from caffe2.python import core, dyndep
+from caffe2.python.fb import hardcode_scale_zp
 from caffe2.quantization.server import utils as dnnlowp_utils
-from dnnlowp_test_utils import (
-    check_quantized_results_close,
-    generate_conv_inputs,
-)
+from dnnlowp_test_utils import check_quantized_results_close, generate_conv_inputs
 from hypothesis import assume, given
 
 
@@ -31,6 +29,7 @@ class GroupWiseDNNLowPOpConvTest(hu.HypothesisTestCase):
         order=st.sampled_from(["NCHW", "NHWC"]),
         in_quantized=st.booleans(),
         out_quantized=st.booleans(),
+        prepack_weight=st.booleans(),
         preserve_activation_sparsity=st.booleans(),
         preserve_weight_sparsity=st.booleans(),
         **hu.gcs_cpu_only
@@ -49,12 +48,14 @@ class GroupWiseDNNLowPOpConvTest(hu.HypothesisTestCase):
         order,
         in_quantized,
         out_quantized,
+        prepack_weight,
         preserve_activation_sparsity,
         preserve_weight_sparsity,
         gc,
         dc,
     ):
         assume(group == 1 or dilation == 1)
+        assume((not prepack_weight) or order == "NHWC")
 
         X, W, b = generate_conv_inputs(
             stride,
@@ -83,10 +84,12 @@ class GroupWiseDNNLowPOpConvTest(hu.HypothesisTestCase):
         ]
 
         for op_type, engine in op_engine_list:
+            init_net = core.Net("test_init_net")
             net = core.Net("test_net")
 
             do_quantize = "DNNLOWP" in engine and in_quantized
             do_dequantize = "DNNLOWP" in engine and out_quantized
+            do_prepack_weight = engine == "DNNLOWP" and prepack_weight
 
             if do_quantize:
                 quantize = core.CreateOperator(
@@ -99,6 +102,25 @@ class GroupWiseDNNLowPOpConvTest(hu.HypothesisTestCase):
                 )
                 net.Proto().op.extend([quantize])
 
+            if do_prepack_weight:
+                x_q_param = hardcode_scale_zp.choose_quantization_params(
+                    X.min(), X.max()
+                )
+                inputs = ["W"]
+                if do_dequantize:
+                    inputs += ["b"]
+                pack = core.CreateOperator(
+                    "Int8ConvPackWeight",
+                    inputs,
+                    ["W_packed"],
+                    group=group,
+                    quantize_groupwise=1,
+                    preserve_weight_sparsity=preserve_weight_sparsity,
+                    in_scale=x_q_param.scale,
+                    engine=engine,
+                )
+                init_net.Proto().op.extend([pack])
+
             conv = core.CreateOperator(
                 op_type,
                 ["X_q" if do_quantize else "X", "W", "b"],
@@ -116,7 +138,7 @@ class GroupWiseDNNLowPOpConvTest(hu.HypothesisTestCase):
                 quantize_groupwise=1,
                 device_option=gc,
             )
-            if do_dequantize:
+            if do_dequantize or do_prepack_weight:
                 # groupwise quantization only works with static quantization
                 # so we need to set quantization parameters
                 dnnlowp_utils.add_quantization_param_args(
@@ -138,6 +160,7 @@ class GroupWiseDNNLowPOpConvTest(hu.HypothesisTestCase):
             self.ws.create_blob("X").feed(X, device_option=gc)
             self.ws.create_blob("W").feed(W, device_option=gc)
             self.ws.create_blob("b").feed(b, device_option=gc)
+            self.ws.run(init_net)
             self.ws.run(net)
             Y = self.ws.blobs["Y"].fetch()
             outputs.append(Output(Y=Y, op_type=op_type, engine=engine, order=order))
index e5606d4..9b553d5 100644 (file)
@@ -6,6 +6,7 @@
 
 #include "caffe2/core/tensor_int8.h"
 #include "caffe2/operators/conv_pool_op_base.h"
+#include "caffe2/quantization/server/fbgemm_pack_blob.h"
 #include "caffe2/quantization/server/op_wrapper.h"
 
 #ifdef _OPENMP
@@ -44,9 +45,13 @@ class ConvPoolDNNLowPOpBase : public ConvPoolOpBase<CPUContext> {
 
  protected:
   const TensorCPU& InputTensorCPU_(int idx) {
-    return InputIsType<int8::Int8TensorCPU>(idx)
-        ? OperatorBase::Input<int8::Int8TensorCPU>(idx).t
-        : Input(idx);
+    if (InputIsType<int8::Int8TensorCPU>(idx)) {
+      return this->Input<int8::Int8TensorCPU>(idx).t;
+    } else if (InputIsType<Int8ConvDNNLowPPackedWeightBlob>(idx)) {
+      return this->Input<Int8ConvDNNLowPPackedWeightBlob>(idx).original_tensor;
+    } else {
+      return Input(idx);
+    }
   }
 
   TensorCPU* OutputTensorCPU_(int idx) {
index 1ff5156..b67fa1a 100644 (file)
@@ -8,6 +8,7 @@
 #include "caffe2/core/tensor_int8.h"
 #include "caffe2/quantization/server/caffe2_dnnlowp_utils.h"
 #include "caffe2/quantization/server/dnnlowp.h"
+#include "caffe2/quantization/server/fbgemm_pack_blob.h"
 #include "caffe2/quantization/server/op_wrapper.h"
 #include "caffe2/quantization/server/sigmoid.h"
 #include "caffe2/quantization/server/tanh.h"
@@ -92,9 +93,13 @@ class DNNLowPOp : public Operator<CPUContext> {
 
  protected:
   const TensorCPU& InputTensorCPU_(int idx) {
-    return InputIsType<int8::Int8TensorCPU>(idx)
-        ? OperatorBase::Input<int8::Int8TensorCPU>(idx).t
-        : Input(idx);
+    if (InputIsType<int8::Int8TensorCPU>(idx)) {
+      return this->Input<int8::Int8TensorCPU>(idx).t;
+    } else if (InputIsType<Int8FCDNNLowPPackedWeightBlob>(idx)) {
+      return this->Input<Int8FCDNNLowPPackedWeightBlob>(idx).original_tensor;
+    } else {
+      return Input(idx);
+    }
   }
 
   TensorCPU* OutputTensorCPU_(int idx) {
diff --git a/caffe2/quantization/server/fbgemm_pack_blob.h b/caffe2/quantization/server/fbgemm_pack_blob.h
new file mode 100644 (file)
index 0000000..56396cc
--- /dev/null
@@ -0,0 +1,44 @@
+#pragma once
+
+#include <memory>
+
+#include <fbgemm/Fbgemm.h>
+#include <fbgemm/src/FbgemmI8DepthwiseAvx2.h>
+
+#include "caffe2/quantization/server/dnnlowp.h"
+
+namespace caffe2 {
+
+/**
+ * Packed weight matrix for DNNLOWP Int8FC operator
+ */
+struct Int8FCDNNLowPPackedWeightBlob {
+  std::vector<dnnlowp::TensorQuantizationParams> qparams;
+  std::shared_ptr<std::vector<std::int32_t>> column_offsets;
+
+  // The original tensor before packing
+  Tensor original_tensor{CPU};
+
+  std::shared_ptr<std::vector<std::int32_t>> bias;
+
+  // Only for 32-bit accumulation
+  std::shared_ptr<fbgemm::PackBMatrix<std::int8_t>> W;
+
+  // Only for 16-bit accumulation
+  // Dense matrix holding common values
+  std::shared_ptr<fbgemm::PackBMatrix<std::int8_t, std::int16_t>> W_acc16;
+  // Sparse matrix holding outliers
+  std::shared_ptr<fbgemm::CompressedSparseColumn> W_outlier;
+  int nbits_in_non_outlier;
+};
+
+/**
+ * Packed weight matrix for DNNLOWP Int8Conv operator
+ */
+struct Int8ConvDNNLowPPackedWeightBlob : public Int8FCDNNLowPPackedWeightBlob {
+  // Only for 32-bit accumulation
+  std::shared_ptr<fbgemm::Packed3x3ConvMatrix> W_depthwise_3x3;
+  std::shared_ptr<fbgemm::Packed3x3x3ConvMatrix> W_depthwise_3x3x3;
+};
+
+} // namespace caffe2
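The column_offsets field above is what lets the consumer operators fold the activation zero point out of the raw int32 GEMM result (later hunks subtract in_qparams_[0].zero_point * (*column_offsets_)[j] plus a row offset). Below is a NumPy sketch of the quantity that ComputeColumnOffsets in fbgemm_pack_op.cc stores here, with made-up toy shapes and zero points; the C++ helper lays the filter out as one contiguous row of length num_rows per output column, while the sketch uses an ordinary (num_rows, num_cols) matrix, which yields the same per-column values.

import numpy as np

def compute_column_offsets(W_q, zero_points):
    # W_q: (num_rows, num_cols) quantized weights; the columns are split
    # evenly into one quantization group per entry of zero_points.
    num_rows, num_cols = W_q.shape
    cols_per_group = num_cols // len(zero_points)
    col_offsets = np.empty(num_cols, dtype=np.int32)
    for g, zp in enumerate(zero_points):
        j0, j1 = g * cols_per_group, (g + 1) * cols_per_group
        col_offsets[j0:j1] = (
            W_q[:, j0:j1].astype(np.int32).sum(axis=0) - zp * num_rows
        )
    return col_offsets

# 4x4 toy weight, two quantization groups of two columns each
W_q = np.arange(-8, 8, dtype=np.int8).reshape(4, 4)
print(compute_column_offsets(W_q, zero_points=[1, -2]))  # -> -12, -8, 8, 12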
diff --git a/caffe2/quantization/server/fbgemm_pack_op.cc b/caffe2/quantization/server/fbgemm_pack_op.cc
new file mode 100644 (file)
index 0000000..2a0ce89
--- /dev/null
@@ -0,0 +1,496 @@
+#include "fbgemm_pack_op.h"
+
+#include "caffe2/core/tensor_int8.h"
+
+#include "caffe2_dnnlowp_utils.h"
+
+C10_DECLARE_int32(dnnlowp_nbits_in_non_outlier);
+
+namespace caffe2 {
+
+using namespace std;
+using dnnlowp::TensorQuantizationParams;
+
+// Helper functions
+
+template <typename T>
+void QuantizeWeight(
+    const Blob& blob,
+    int kernel_dim,
+    int M,
+    vector<TensorQuantizationParams>& qparams,
+    vector<typename make_signed<T>::type>& W_quantized,
+    dnnlowp::QuantizationFactory* qfactory) {
+  using T_signed = typename make_signed<T>::type;
+
+  const auto& filter = blob.IsType<int8::Int8TensorCPU>()
+      ? blob.Get<int8::Int8TensorCPU>().t
+      : blob.Get<TensorCPU>();
+
+  W_quantized.resize(filter.numel());
+
+  int signed_min = -(1 << (qfactory->GetWeightPrecision() - 1));
+  if (blob.IsType<int8::Int8TensorCPU>()) {
+    qparams[0].scale = blob.Get<int8::Int8TensorCPU>().scale;
+    qparams[0].zero_point =
+        blob.Get<int8::Int8TensorCPU>().zero_point + signed_min;
+
+    const T* W_data = filter.data<T>();
+    for (auto i = 0; i < filter.numel(); ++i) {
+      W_quantized[i] = W_data[i] + signed_min;
+    }
+  } else {
+    for (int g = 0; g < qparams.size(); ++g) {
+      size_t offset = g * (M / qparams.size()) * kernel_dim;
+      qparams[g] = qfactory->ChooseQuantizationParams(
+          filter.data<float>() + offset,
+          (M / qparams.size()) * kernel_dim,
+          true /*weight*/);
+
+      // qparams[g] is computed for unsigned type.
+      // Adjust for the fact that weight will actually use signed.
+      qparams[g].zero_point += signed_min;
+
+      fbgemm::Quantize<T_signed>(
+          filter.data<float>() + offset,
+          W_quantized.data() + offset,
+          (M / qparams.size()) * kernel_dim,
+          qparams[g]);
+    }
+  }
+}
+
+template void QuantizeWeight<uint8_t>(
+    const Blob& blob,
+    int kernel_dim,
+    int M,
+    vector<TensorQuantizationParams>& qparams,
+    vector<int8_t>& W_quantized,
+    dnnlowp::QuantizationFactory* qfactory);
+
+template void QuantizeWeight<uint16_t>(
+    const Blob& blob,
+    int kernel_dim,
+    int M,
+    vector<TensorQuantizationParams>& qparams,
+    vector<int16_t>& W_quantized,
+    dnnlowp::QuantizationFactory* qfactory);
+
+// TODO: reuse col_offsets_with_zero_pt_s8acc32_ref in fbgemm's
+// RefImplementations.cc. We can't do this now because W_quantized is
+// not transposed here.
+template <typename T>
+void ComputeColumnOffsets(
+    int num_rows,
+    int num_cols,
+    const T* W,
+    const vector<TensorQuantizationParams>& qparams,
+    vector<int32_t>& col_offsets) {
+  col_offsets.resize(num_cols);
+  int num_quant_groups = qparams.size();
+  for (int g = 0; g < num_quant_groups; ++g) {
+    int j_begin = g * (num_cols / num_quant_groups);
+    int j_end = j_begin + (num_cols / num_quant_groups);
+    for (int j = j_begin; j < j_end; ++j) {
+      int32_t sum = 0;
+      for (int k = 0; k < num_rows; ++k) {
+        sum += W[j * num_rows + k];
+      }
+      col_offsets[j] = sum - qparams[g].zero_point * num_rows;
+    }
+  }
+}
+
+template void ComputeColumnOffsets<int8_t>(
+    int num_rows,
+    int num_cols,
+    const int8_t* W,
+    const vector<TensorQuantizationParams>& qparams,
+    vector<int32_t>& col_offsets);
+
+template void ComputeColumnOffsets<int16_t>(
+    int num_rows,
+    int num_cols,
+    const int16_t* W,
+    const vector<TensorQuantizationParams>& qparams,
+    vector<int32_t>& col_offsets);
+
+fbgemm::CompressedSparseColumn* ExtractOutlierMatrix(
+    int groups,
+    int kernel_dim,
+    int M,
+    int nbits_in_non_outlier,
+    vector<int8_t>& W_quantized) {
+  int outlier_cnt = 0;
+  for (int group_id = 0; group_id < groups; ++group_id) {
+    for (int i = 0; i < (M / groups) * kernel_dim; ++i) {
+      int8_t w = W_quantized[group_id * (M / groups) * kernel_dim + i];
+      bool is_outlier = nbits_in_non_outlier == 0 ||
+          w < -(1 << (nbits_in_non_outlier - 1)) ||
+          w >= (1 << (nbits_in_non_outlier - 1));
+      if (is_outlier) {
+        ++outlier_cnt;
+      }
+    }
+  }
+
+  fbgemm::CompressedSparseColumn* Wq_outlier =
+      new fbgemm::CompressedSparseColumn(kernel_dim, M);
+  Wq_outlier->RowIdx().resize(outlier_cnt);
+  Wq_outlier->Values().resize(outlier_cnt);
+
+  outlier_cnt = 0;
+  for (int group_id = 0; group_id < groups; ++group_id) {
+    for (int j = 0; j < M / groups; ++j) {
+      Wq_outlier->ColPtr()[group_id * (M / groups) + j] = outlier_cnt;
+
+      for (int k = 0; k < kernel_dim; ++k) {
+        int8_t w = W_quantized[(group_id * (M / groups) + j) * kernel_dim + k];
+        bool is_outlier = nbits_in_non_outlier == 0 ||
+            w < -(1 << (nbits_in_non_outlier - 1)) ||
+            w >= (1 << (nbits_in_non_outlier - 1));
+        if (is_outlier) {
+          CAFFE_ENFORCE_LE(k, numeric_limits<int16_t>::max());
+          Wq_outlier->RowIdx()[outlier_cnt] = k;
+          Wq_outlier->Values()[outlier_cnt] = w;
+          ++outlier_cnt;
+
+          W_quantized[(group_id * (M / groups) + j) * kernel_dim + k] = 0;
+        }
+      }
+    }
+  } // for each group
+  Wq_outlier->ColPtr()[M] = outlier_cnt;
+
+  return Wq_outlier;
+}
+
+// FIXME: code duplication with ConvDNNLowPOp::QuantizeBias_
+static void QuantizeConvBias(
+    const Blob& blob,
+    int M,
+    const TensorQuantizationParams& in_qparams,
+    const vector<TensorQuantizationParams>& filter_qparams,
+    vector<int32_t>& b_quantized) {
+  const auto& bias = blob.IsType<int8::Int8TensorCPU>()
+      ? blob.Get<int8::Int8TensorCPU>().t
+      : blob.Get<TensorCPU>();
+  if (blob.IsType<int8::Int8TensorCPU>()) {
+    TensorQuantizationParams bias_qparams;
+    bias_qparams.scale = blob.Get<int8::Int8TensorCPU>().scale;
+    bias_qparams.zero_point = blob.Get<int8::Int8TensorCPU>().zero_point;
+    CAFFE_ENFORCE_LE(
+        std::abs(
+            bias_qparams.scale - in_qparams.scale * filter_qparams[0].scale),
+        1e-4);
+    CAFFE_ENFORCE_EQ(bias_qparams.zero_point, 0);
+    b_quantized.resize(bias.numel());
+    b_quantized.assign(
+        bias.data<int32_t>(), bias.data<int32_t>() + bias.numel());
+  } else {
+    const float* bdata = bias.data<float>();
+    b_quantized.resize(bias.numel());
+    for (int g = 0; g < filter_qparams.size(); ++g) {
+      int i_begin = g * (M / filter_qparams.size());
+      int i_end = i_begin + (M / filter_qparams.size());
+      for (int i = i_begin; i < i_end; ++i) {
+        b_quantized[i] = fbgemm::Quantize<int32_t>(
+            bdata[i],
+            0,
+            in_qparams.scale * filter_qparams[g].scale,
+            32,
+            true /* signed */);
+      }
+    }
+  }
+}
+
+// FullyConnectedDNNLowPPackWeightOp
+
+FullyConnectedDNNLowPPackWeightOp::FullyConnectedDNNLowPPackWeightOp(
+    const OperatorDef& operator_def,
+    Workspace* ws)
+    : DNNLowPOp<uint8_t, FCFp32Op>(operator_def, ws),
+      axis_w_(this->GetSingleArgument<int32_t>("axis_w", 1)) {
+  if (this->debug_def().engine() == "DNNLOWP_ACC16") {
+    nbits_in_non_outlier_ = this->GetSingleArgument<int>(
+        "nbits_in_non_outlier", FLAGS_dnnlowp_nbits_in_non_outlier);
+  }
+}
+
+bool FullyConnectedDNNLowPPackWeightOp::RunOnDevice() {
+  const auto& filter = InputTensorCPU_(0);
+  const auto canonical_axis_w = filter.canonical_axis_index(axis_w_);
+  const auto K = filter.size_from_dim(canonical_axis_w);
+  const auto N = filter.size_to_dim(canonical_axis_w);
+
+  auto* Y = this->Output<Int8FCDNNLowPPackedWeightBlob>(0);
+
+  // Create a tensor with the same shape as the filter; this should not
+  // actually allocate memory for the tensor.
+  // It is just a convenient way to pass along the original shape information.
+  Y->original_tensor.ResizeLike(filter);
+
+  Y->qparams.resize((this->debug_def().engine() == "DNNLOWP_ROWWISE") ? N : 1);
+
+  vector<int8_t> W_quantized;
+  QuantizeWeight<uint8_t>(
+      InputBlob(0), K, N, Y->qparams, W_quantized, qfactory_.get());
+
+  if (this->InputIsType<int8::Int8TensorCPU>(0) &&
+      this->debug_def().engine() == "DNNLOWP_ROWWISE") {
+    static int log_occurences = 0;
+    if (log_occurences < 32) {
+      ++log_occurences;
+      LOG(WARNING) << "Cannot do row-wise quantization for "
+                      "pre-quantized weight "
+                   << this->debug_def().input(0);
+    }
+  }
+
+  // Pre-compute column offsets
+  // This should happen before ExtractOutlierMatrix because W_quantized is
+  // changed in ExtractOutlierMatrix.
+  Y->column_offsets.reset(new vector<int32_t>());
+  ComputeColumnOffsets(
+      K, N, W_quantized.data(), Y->qparams, *Y->column_offsets);
+
+  if (this->debug_def().engine() == "DNNLOWP_ACC16") {
+    if (nbits_in_non_outlier_ < 8) {
+      Y->W_outlier.reset(
+          ExtractOutlierMatrix(1, K, N, nbits_in_non_outlier_, W_quantized));
+      int outlier_cnt = Y->W_outlier->ColPtr()[N];
+
+      LOG(INFO) << "Proportion of outlier for Conv layer with weight blob "
+                << this->debug_def().input(0) << " is "
+                << static_cast<float>(outlier_cnt) / W_quantized.size();
+      LOG(INFO) << "nbits_in_non_outlier " << nbits_in_non_outlier_;
+    }
+
+    Y->nbits_in_non_outlier = nbits_in_non_outlier_;
+    Y->W_acc16.reset(new fbgemm::PackBMatrix<int8_t, int16_t>(
+        fbgemm::matrix_op_t::Transpose,
+        K,
+        N,
+        W_quantized.data(),
+        K,
+        nullptr, // pmat
+        1)); // group
+  } else {
+    Y->W.reset(new fbgemm::PackBMatrix<int8_t>(
+        fbgemm::matrix_op_t::Transpose,
+        K,
+        N,
+        W_quantized.data(),
+        K,
+        nullptr, // pmat
+        1)); // group
+  }
+
+  // Quantize bias
+  if (InputSize() >= 2) {
+    TensorQuantizationParams in_qparams;
+    CAFFE_ENFORCE(HasSingleArgumentOfType<float>("in_scale"));
+    in_qparams.scale = GetSingleArgument<float>("in_scale", 0);
+    Y->bias.reset(new vector<int32_t>());
+    QuantizeConvBias(InputBlob(1), N, in_qparams, Y->qparams, *Y->bias);
+  } else {
+    Y->bias = nullptr;
+  }
+
+  return true;
+}
+
+// ConvDNNLowPPackWeightOp
+
+ConvDNNLowPPackWeightOp::ConvDNNLowPPackWeightOp(
+    const OperatorDef& operator_def,
+    Workspace* ws)
+    : ConvPoolDNNLowPOpBase<uint8_t, ConvFp32Op>(operator_def, ws),
+      quantize_groupwise_(
+          this->GetSingleArgument<bool>("quantize_groupwise", false)) {
+  if (this->debug_def().engine() == "DNNLOWP_ACC16") {
+    nbits_in_non_outlier_ = this->GetSingleArgument<int>(
+        "nbits_in_non_outlier", FLAGS_dnnlowp_nbits_in_non_outlier);
+  }
+}
+
+bool ConvDNNLowPPackWeightOp::TakeDepthWise3x3FastPath_() {
+  const auto& filter = this->InputTensorCPU_(FILTER);
+  // The number of output channels
+  int M = filter.dim32(0);
+  // The number of input channels per group
+  int C_per_group = filter.dim32(filter.dim() - 1);
+  return this->debug_def().engine() != "DNNLOWP_ACC16" && group_ == M &&
+      C_per_group == 1 && group_ % 8 == 0 && this->kernel_.size() == 2 &&
+      kernel_h() == 3 && kernel_w() == 3 && stride_h() == stride_w() &&
+      (stride_h() == 1 || stride_h() == 2) && dilation_h() == 1 &&
+      dilation_w() == 1 && pad_t() == 1 && pad_b() == 1 && pad_l() == 1 &&
+      pad_r() == 1 && GetCpuId().avx2() && !quantize_groupwise_;
+}
+
+bool ConvDNNLowPPackWeightOp::TakeDepthWise3x3x3FastPath_() {
+  const auto& filter = this->InputTensorCPU_(FILTER);
+  // The number of output channels
+  int M = filter.dim32(0);
+  // The number of input channels per group
+  int C_per_group = filter.dim32(filter.dim() - 1);
+  bool ret = this->debug_def().engine() != "DNNLOWP_ACC16" && group_ == M &&
+      C_per_group == 1 && group_ % 8 == 0 && this->kernel_.size() == 3 &&
+      this->kernel_[0] == 3 && this->kernel_[1] == 3 && this->kernel_[2] == 3 &&
+      this->stride_[0] == this->stride_[1] &&
+      this->stride_[0] == this->stride_[2] &&
+      (this->stride_[0] == 1 || this->stride_[0] == 2) &&
+      this->dilation_[0] == 1 && this->dilation_[1] == 1 &&
+      this->dilation_[2] == 1 &&
+      accumulate(
+          this->pads_.begin(), this->pads_.end(), 1, multiplies<int>()) == 1 &&
+      GetCpuId().avx2() && !quantize_groupwise_;
+  return ret;
+}
+
+bool ConvDNNLowPPackWeightOp::RunOnDevice() {
+  const auto& filter = InputTensorCPU_(FILTER);
+
+  auto* Y = this->Output<Int8ConvDNNLowPPackedWeightBlob>(0);
+  // Create a tensor with the same shape as the filter; this should not
+  // actually allocate memory for the tensor.
+  // It is just a convenient way to pass along the original shape information.
+  Y->original_tensor.ResizeLike(filter);
+
+  // Assume KRSC layout
+  // The number of output channels
+  int M = filter.dim32(0);
+  // The number of input channels per group
+  int C_per_group = filter.dim32(filter.dim() - 1);
+
+  int kernel_dims_size = 1;
+  for (int i = 0; i < filter.dim() - 2; ++i) {
+    kernel_dims_size *= filter.dim32(i + 1);
+  }
+  int kernel_dim = C_per_group * kernel_dims_size;
+
+  vector<int8_t> W_quantized;
+  Y->qparams.resize(quantize_groupwise_ ? group_ : 1);
+  QuantizeWeight<uint8_t>(
+      InputBlob(FILTER),
+      kernel_dim,
+      M,
+      Y->qparams,
+      W_quantized,
+      qfactory_.get());
+
+  if (this->InputIsType<int8::Int8TensorCPU>(FILTER) && quantize_groupwise_) {
+    static int log_occurences = 0;
+    if (log_occurences < 32) {
+      ++log_occurences;
+      LOG(WARNING) << "Cannot do group-wise quantization for "
+                      "pre-quantized weight "
+                   << this->debug_def().input(0);
+    }
+  }
+
+  // Pre-compute column offsets
+  // This should happen before ExtractOutlierMatrix because W_quantized is
+  // changed in ExtractOutlierMatrix.
+  Y->column_offsets.reset(new vector<int32_t>());
+  ComputeColumnOffsets(
+      kernel_dim, M, W_quantized.data(), Y->qparams, *Y->column_offsets);
+
+  if (this->debug_def().engine() == "DNNLOWP_ACC16") {
+    if (nbits_in_non_outlier_ < 8) {
+      Y->W_outlier.reset(ExtractOutlierMatrix(
+          group_, kernel_dim, M, nbits_in_non_outlier_, W_quantized));
+      int outlier_cnt = Y->W_outlier->ColPtr()[M];
+
+      LOG(INFO) << "Proportion of outlier for Conv layer with weight blob "
+                << this->debug_def().input(0) << " is "
+                << static_cast<float>(outlier_cnt) / W_quantized.size();
+      LOG(INFO) << "nbits_in_non_outlier " << nbits_in_non_outlier_;
+    }
+
+    Y->nbits_in_non_outlier = nbits_in_non_outlier_;
+    Y->W_acc16.reset(new fbgemm::PackBMatrix<int8_t, int16_t>(
+        fbgemm::matrix_op_t::Transpose,
+        group_ * kernel_dim,
+        M / group_,
+        W_quantized.data(),
+        kernel_dim,
+        nullptr, // pmat
+        group_));
+  } else if (TakeDepthWise3x3FastPath_()) {
+    Y->W_depthwise_3x3.reset(
+        new fbgemm::Packed3x3ConvMatrix(group_, W_quantized.data()));
+  } else if (TakeDepthWise3x3x3FastPath_()) {
+    Y->W_depthwise_3x3x3.reset(
+        new fbgemm::Packed3x3x3ConvMatrix(group_, W_quantized.data()));
+  } else {
+    Y->W.reset(new fbgemm::PackBMatrix<int8_t>(
+        fbgemm::matrix_op_t::Transpose,
+        group_ * kernel_dim,
+        M / group_,
+        W_quantized.data(),
+        kernel_dim,
+        nullptr, // pmat
+        group_));
+  }
+
+  if (InputSize() >= 2) {
+    TensorQuantizationParams in_qparams;
+    CAFFE_ENFORCE(HasSingleArgumentOfType<float>("in_scale"));
+    in_qparams.scale = GetSingleArgument<float>("in_scale", 0);
+    Y->bias.reset(new vector<int32_t>());
+    QuantizeConvBias(InputBlob(BIAS), M, in_qparams, Y->qparams, *Y->bias);
+  } else {
+    Y->bias = nullptr;
+  }
+
+  return true;
+}
+
+// Explicitly register TypeMeta
+CAFFE_KNOWN_TYPE(Int8FCDNNLowPPackedWeightBlob);
+CAFFE_KNOWN_TYPE(Int8ConvDNNLowPPackedWeightBlob);
+
+REGISTER_CPU_OPERATOR_WITH_ENGINE(
+    Int8FCPackWeight,
+    DNNLOWP,
+    FullyConnectedDNNLowPPackWeightOp);
+
+REGISTER_CPU_OPERATOR_WITH_ENGINE(
+    Int8FCPackWeight,
+    DNNLOWP_ACC16,
+    FullyConnectedDNNLowPPackWeightOp);
+
+REGISTER_CPU_OPERATOR_WITH_ENGINE(
+    Int8FCPackWeight,
+    DNNLOWP_ROWWISE,
+    FullyConnectedDNNLowPPackWeightOp);
+
+OPERATOR_SCHEMA(Int8FCPackWeight)
+    .NumInputs(1, 2)
+    .NumOutputs(1)
+    .SetDoc(R"DOC(Prepack weight for Int8FC)DOC")
+    .Input(0, "W", "Weight tensor in KRSC layout")
+    .Input(1, "b", "Bias tensor")
+    .Output(0, "W_q", "Weight/bias tensor in a packed format");
+
+REGISTER_CPU_OPERATOR_WITH_ENGINE(
+    Int8ConvPackWeight,
+    DNNLOWP,
+    ConvDNNLowPPackWeightOp);
+
+REGISTER_CPU_OPERATOR_WITH_ENGINE(
+    Int8ConvPackWeight,
+    DNNLOWP_ACC16,
+    ConvDNNLowPPackWeightOp);
+
+OPERATOR_SCHEMA(Int8ConvPackWeight)
+    .NumInputs(1, 2)
+    .NumOutputs(1)
+    .SetDoc(R"DOC(Prepack weight for Int8Conv)DOC")
+    .Input(0, "W", "Weight tensor in KRSC layout")
+    .Input(1, "b", "Bias tensor")
+    .Output(0, "W_q", "Weight/bias tensor in a packed format");
+
+} // namespace caffe2
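To connect the registrations and schemas above to the intended workflow, here is a rough Python sketch of Int8FCPackWeight in an init net feeding an Int8FC op run with the DNNLOWP_ACC16 engine, in the spirit of the unit tests later in this diff. The blob names, tensor shapes, and the 6-bit setting are illustrative, and dequantize_output=1 is used only so the example does not need static output quantization parameters.

import numpy as np
from caffe2.python import core, dyndep, workspace

dyndep.InitOpsLibrary("//caffe2/caffe2/quantization/server:dnnlowp_ops")

K, N, batch = 16, 8, 4
X = (np.random.rand(batch, K) * 4).astype(np.float32)
W = (np.random.rand(N, K) - 0.5).astype(np.float32)
b = (np.random.rand(N) - 0.5).astype(np.float32)

# init net: quantize and pack W once (6-bit dense part plus sparse outliers)
init_net = core.Net("init_net")
init_net.Proto().op.extend([
    core.CreateOperator(
        "Int8FCPackWeight",
        ["W"],
        ["W_packed"],
        nbits_in_non_outlier=6,
        engine="DNNLOWP_ACC16",
    )
])

# predict net: the FC op picks up the pre-packed weight instead of re-packing
predict_net = core.Net("predict_net")
predict_net.Proto().op.extend([
    core.CreateOperator(
        "Int8FC",
        ["X", "W_packed", "b"],
        ["Y"],
        nbits_in_non_outlier=6,
        dequantize_output=1,
        engine="DNNLOWP_ACC16",
    )
])

for name, blob in (("X", X), ("W", W), ("b", b)):
    workspace.FeedBlob(name, blob)
workspace.RunNetOnce(init_net)
workspace.RunNetOnce(predict_net)
print(workspace.FetchBlob("Y").shape)  # (4, 8)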
diff --git a/caffe2/quantization/server/fbgemm_pack_op.h b/caffe2/quantization/server/fbgemm_pack_op.h
new file mode 100644 (file)
index 0000000..af6203b
--- /dev/null
@@ -0,0 +1,93 @@
+#pragma once
+
+#include "caffe2/core/operator.h"
+#include "caffe2/operators/conv_op.h"
+#include "caffe2/quantization/server/conv_pool_dnnlowp_op_base.h"
+#include "caffe2/quantization/server/fbgemm_pack_blob.h"
+#include "caffe2/quantization/server/fully_connected_dnnlowp_op.h"
+
+namespace caffe2 {
+
+using FCFp32Op = FullyConnectedOp<CPUContext>;
+
+class FullyConnectedDNNLowPPackWeightOp final
+    : public DNNLowPOp<std::uint8_t, FCFp32Op> {
+ public:
+  FullyConnectedDNNLowPPackWeightOp(
+      const OperatorDef& operator_def,
+      Workspace* ws);
+  USE_OPERATOR_FUNCTIONS(CPUContext);
+
+  bool RunOnDevice() override;
+
+ private:
+  int axis_w_;
+  int nbits_in_non_outlier_; // only for DNNLOWP_ACC16
+
+  INPUT_TAGS(FILTER, BIAS);
+};
+
+using ConvFp32Op = ConvOp<float, CPUContext>;
+
+/**
+ * Pack a weight matrix that can be used by DNNLOWP Int8Conv operators.
+ * DNNLOWP operators can pack the weight matrix on demand during their first
+ * invocation, but calling this operator to pre-pack has benefits such as
+ * saving memory when multiple operators share the same weight.
+ * This operator should be part of the init net and run once to populate the
+ * packed blob consumed by Int8Conv DNNLOWP operators in the predictor net.
+ *
+ * This operator can optionally pre-quantize the bias as well; in that case,
+ * the scale of the input activation tensor must be provided as the in_scale
+ * argument.
+ */
+class ConvDNNLowPPackWeightOp final
+    : public ConvPoolDNNLowPOpBase<std::uint8_t, ConvFp32Op> {
+ public:
+  USE_CONV_POOL_BASE_FUNCTIONS(CPUContext);
+  USE_CONV_POOL_DNNLOWP_OPERATOR_BASE_FUNCTIONS(std::uint8_t, ConvFp32Op);
+  ConvDNNLowPPackWeightOp(const OperatorDef& operator_def, Workspace* ws);
+
+  bool RunOnDevice() override;
+
+ private:
+  bool TakeDepthWise3x3FastPath_();
+  bool TakeDepthWise3x3x3FastPath_();
+
+  bool quantize_groupwise_;
+  int nbits_in_non_outlier_; // only for DNNLOWP_ACC16
+
+  INPUT_TAGS(FILTER, BIAS);
+};
+
+// Helper functions for packing weights that can be used by
+// ConvDNNLowPAcc16PackWeightOp, ConvDNNLowPOp, and ConvDNNLowPAcc16Op
+
+template <typename T>
+void QuantizeWeight(
+    const Blob& blob,
+    int kernel_dim,
+    int M,
+    vector<dnnlowp::TensorQuantizationParams>& qparams,
+    vector<typename std::make_signed<T>::type>& w_quantized,
+    dnnlowp::QuantizationFactory* qfactory);
+
+template <typename T>
+void ComputeColumnOffsets(
+    int num_rows,
+    int num_cols,
+    const T* W,
+    const vector<dnnlowp::TensorQuantizationParams>& qparams,
+    vector<int32_t>& col_offsets);
+
+/**
+ * @param W_quantized input quantized weight that is not packed yet
+ */
+fbgemm::CompressedSparseColumn* ExtractOutlierMatrix(
+    int groups,
+    int kernel_dim,
+    int M,
+    int nbits_in_non_outlier,
+    vector<std::int8_t>& W_quantized);
+
+} // namespace caffe2
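The class comment above describes the init-net / predictor-net split; a rough Python sketch of that wiring follows, modeled on the unit tests in this diff. Blob names, tensor shapes, and the output Y_scale/Y_zero_point values are placeholders (the tests derive the output parameters from a reference fp32 run via dnnlowp_utils.add_quantization_param_args), and in_scale is the activation scale the pack op needs in order to pre-quantize the bias.

import numpy as np
from caffe2.python import core, dyndep, workspace
from caffe2.quantization.server import utils as dnnlowp_utils

dyndep.InitOpsLibrary("//caffe2/caffe2/quantization/server:dnnlowp_ops")

X = (np.random.rand(1, 8, 8, 4) * 4).astype(np.float32)    # NHWC activation
W = (np.random.rand(8, 3, 3, 4) - 0.5).astype(np.float32)  # KRSC filter
b = (np.random.rand(8) - 0.5).astype(np.float32)

x_q_param = dnnlowp_utils.choose_quantization_params(X.min(), X.max())

# init net: quantize and pack the filter once; passing b also pre-quantizes
# the bias, which is why the activation scale is supplied as in_scale
init_net = core.Net("init_net")
init_net.Proto().op.extend([
    core.CreateOperator(
        "Int8ConvPackWeight",
        ["W", "b"],
        ["W_packed"],
        in_scale=x_q_param.scale,
        engine="DNNLOWP_ACC16",
    )
])

# predictor net: Int8Conv with the DNNLOWP_ACC16 engine consumes W_packed
predict_net = core.Net("predict_net")
predict_net.Proto().op.extend([
    core.CreateOperator(
        "Int8Conv",
        ["X", "W_packed", "b"],
        ["Y_q"],
        stride=1, kernel=3, pad=1, order="NHWC",
        engine="DNNLOWP_ACC16",
        # placeholder static output params; normally derived from a reference
        # fp32 run via dnnlowp_utils.add_quantization_param_args
        Y_scale=0.05, Y_zero_point=0,
    )
])

for name, blob in (("X", X), ("W", W), ("b", b)):
    workspace.FeedBlob(name, blob)
workspace.RunNetOnce(init_net)
workspace.RunNetOnce(predict_net)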
index 36b2cc7..70ebc62 100644 (file)
@@ -2,6 +2,8 @@
 
 #include <fbgemm/src/RefImplementations.h>
 
+#include "fbgemm_pack_op.h"
+
 C10_DECLARE_int32(dnnlowp_nbits_in_non_outlier);
 C10_DECLARE_int32(dnnlowp_copy_to_32bit_frequency);
 
@@ -45,66 +47,52 @@ bool FullyConnectedDNNLowPAcc16Op::RunOnDevice() {
 
   // Pack W if needed
   if (!Wq_acc16_packed_ || !is_weight_constant_) {
-    if (!Wq_acc16_packed_ && nbits_in_non_outlier_ < 8) {
-      static int log_occurences = 0;
-      if (log_occurences < 32) {
-        ++log_occurences;
-        LOG(WARNING) << "FC DNNLOWP_ACC16 using outlier-aware quantization";
+    if (this->template InputIsType<Int8FCDNNLowPPackedWeightBlob>(1)) {
+      // If the input is already packed
+      const auto& packed_filter =
+          this->template Input<Int8FCDNNLowPPackedWeightBlob>(1);
+      Wq_outlier_ = packed_filter.W_outlier;
+      Wq_acc16_packed_ = packed_filter.W_acc16;
+
+      if (nbits_in_non_outlier_ != packed_filter.nbits_in_non_outlier) {
+        LOG(WARNING)
+            << "nbits_in_non_outlier in packed weight "
+            << packed_filter.nbits_in_non_outlier
+            << " doesn't match with nbits_in_non_outlier specified in operator "
+            << nbits_in_non_outlier_;
       }
-
-      // Separate out outliers
-      CAFFE_ENFORCE(!W_quantized_.empty());
-
-      int32_t outlier_cnt = 0;
-      for (int i = 0; i < W_quantized_.size(); ++i) {
-        int8_t w = W_quantized_[i];
-        bool is_outlier = nbits_in_non_outlier_ == 0 ||
-            w < -(1 << (nbits_in_non_outlier_ - 1)) ||
-            w >= (1 << (nbits_in_non_outlier_ - 1));
-        if (is_outlier) {
-          ++outlier_cnt;
+    } else {
+      if (!Wq_acc16_packed_ && nbits_in_non_outlier_ < 8) {
+        static int log_occurences = 0;
+        if (log_occurences < 32) {
+          ++log_occurences;
+          LOG(WARNING) << "FC DNNLOWP_ACC16 using outlier-aware quantization";
         }
-      }
 
-      Wq_outlier_.reset(new fbgemm::CompressedSparseColumn(K, N));
-      Wq_outlier_->RowIdx().resize(outlier_cnt);
-      Wq_outlier_->Values().resize(outlier_cnt);
-
-      outlier_cnt = 0;
-      for (int j = 0; j < N; ++j) {
-        Wq_outlier_->ColPtr()[j] = outlier_cnt;
-        for (int16_t k = 0; k < K; ++k) {
-          int8_t w = W_quantized_[j * K + k];
-          bool is_outlier = nbits_in_non_outlier_ == 0 ||
-              w < -(1 << (nbits_in_non_outlier_ - 1)) ||
-              w >= (1 << (nbits_in_non_outlier_ - 1));
-          if (is_outlier) {
-            CAFFE_ENFORCE_LE(k, numeric_limits<int16_t>::max());
-            Wq_outlier_->RowIdx()[outlier_cnt] = k;
-            Wq_outlier_->Values()[outlier_cnt] = w;
-            ++outlier_cnt;
-            W_quantized_[j * K + k] = 0;
-          }
-        }
-      }
-      Wq_outlier_->ColPtr()[N] = outlier_cnt;
+        // Separate out outliers
+        CAFFE_ENFORCE(!W_quantized_.empty());
 
-      LOG(INFO) << "Proportion of outlier for FC layer with weight blob "
-                << OperatorBase::debug_def().input(1) << " is "
-                << (float)outlier_cnt / W_quantized_.size();
+        Wq_outlier_.reset(
+            ExtractOutlierMatrix(1, K, N, nbits_in_non_outlier_, W_quantized_));
+        int outlier_cnt = Wq_outlier_->ColPtr()[N];
 
-      LOG(INFO) << "copy_to_32bit_frequency " << copy_to_32bit_frequency_;
-    }
+        LOG(INFO) << "Proportion of outlier for FC layer with weight blob "
+                  << OperatorBase::debug_def().input(1) << " is "
+                  << (float)outlier_cnt / W_quantized_.size();
 
-    Wq_acc16_packed_.reset(new fbgemm::PackBMatrix<int8_t, int16_t>(
-        fbgemm::matrix_op_t::Transpose,
-        K,
-        N,
-        reinterpret_cast<const int8_t*>(W_quantized_.data()),
-        K));
+        LOG(INFO) << "copy_to_32bit_frequency " << copy_to_32bit_frequency_;
+      }
 
-    if (is_weight_constant_) {
-      vector<T_signed>().swap(W_quantized_);
+      Wq_acc16_packed_.reset(new fbgemm::PackBMatrix<int8_t, int16_t>(
+          fbgemm::matrix_op_t::Transpose,
+          K,
+          N,
+          reinterpret_cast<const int8_t*>(W_quantized_.data()),
+          K));
+
+      if (is_weight_constant_) {
+        vector<T_signed>().swap(W_quantized_);
+      }
     }
   }
 
@@ -145,7 +133,7 @@ bool FullyConnectedDNNLowPAcc16Op::RunOnDevice() {
           in_qparams_[0].zero_point,
           &in_qparams_[1].zero_point,
           packA.getRowOffsetBuffer(),
-          column_offsets_.data(),
+          column_offsets_->data(),
           this->b_quantized_data_,
           N); // ncols per quant group
 
@@ -185,7 +173,7 @@ bool FullyConnectedDNNLowPAcc16Op::RunOnDevice() {
           in_qparams_[0].zero_point,
           &in_qparams_[1].zero_point,
           packA.getRowOffsetBuffer(),
-          column_offsets_.data(),
+          column_offsets_->data(),
           this->b_dequantized_data_,
           N); // ncols per quant group
 
@@ -235,7 +223,7 @@ bool FullyConnectedDNNLowPAcc16Op::RunOnDevice() {
 
         for (int j = 0; j < N; ++j) {
           Y_int32_[i * N + j] -=
-              in_qparams_[0].zero_point * column_offsets_[j] + row_offset;
+              in_qparams_[0].zero_point * (*column_offsets_)[j] + row_offset;
           Ydata_float[i * N + j] = Y_int32_[i * N + j] * in_qparams_[0].scale *
                   in_qparams_[1].scale +
               b_dequantized_data_[j];
@@ -261,8 +249,8 @@ bool FullyConnectedDNNLowPAcc16Op::RunOnDevice() {
             in_qparams_[0].zero_point,
             &in_qparams_[1].zero_point,
             &row_offset,
-            column_offsets_.data(),
-            b_quantized_.data(),
+            column_offsets_->data(),
+            b_quantized_->data(),
             N); // ncols per quant group
       }
     }
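The inlined outlier extraction removed in this hunk now lives in ExtractOutlierMatrix (fbgemm_pack_op.cc), so the same split is applied whether the weight is packed here on the fly or ahead of time by Int8FCPackWeight. The predicate is simple: with nbits_in_non_outlier bits, any quantized weight outside [-2^(n-1), 2^(n-1)) is stored in the sparse CSC matrix and zeroed in the dense copy, so the 16-bit-accumulation GEMM only ever sees small-magnitude values. A toy NumPy sketch of that split (illustrative values; the dense and sparse containers are simplified):

import numpy as np

nbits_in_non_outlier = 6
W_q = np.array([3, -90, 17, 64, -31, 120, 0, -33], dtype=np.int8)

lo = -(1 << (nbits_in_non_outlier - 1))   # -32
hi = 1 << (nbits_in_non_outlier - 1)      #  32
is_outlier = (W_q < lo) | (W_q >= hi)

W_dense = np.where(is_outlier, 0, W_q)    # packed into the acc16 dense matrix
outlier_idx = np.flatnonzero(is_outlier)  # row indices kept in the CSC matrix
outlier_val = W_q[is_outlier]             # values kept in the CSC matrix

print(W_dense.tolist())      # [3, 0, 17, 0, -31, 0, 0, 0]
print(outlier_idx.tolist())  # [1, 3, 5, 7]
print(outlier_val.tolist())  # [-90, 64, 120, -33]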
index f83e925..53a140b 100644 (file)
@@ -25,11 +25,11 @@ class FullyConnectedDNNLowPAcc16Op final
   using BaseType::W_quantized_;
 
  private:
-  std::unique_ptr<fbgemm::PackBMatrix<std::int8_t, std::int16_t>>
+  std::shared_ptr<fbgemm::PackBMatrix<std::int8_t, std::int16_t>>
       Wq_acc16_packed_;
 
   // Wq outlier in CSC format
-  std::unique_ptr<fbgemm::CompressedSparseColumn> Wq_outlier_;
+  std::shared_ptr<fbgemm::CompressedSparseColumn> Wq_outlier_;
   int nbits_in_non_outlier_;
   int copy_to_32bit_frequency_;
 }; // class FullyConnectedDNNLowPAcc16Op
index 9fda926..ce3598e 100644 (file)
@@ -6,6 +6,7 @@ import caffe2.python.hypothesis_test_util as hu
 import hypothesis.strategies as st
 import numpy as np
 from caffe2.python import core, dyndep
+from caffe2.quantization.server import utils as dnnlowp_utils
 from dnnlowp_test_utils import check_quantized_results_close
 from hypothesis import given
 
@@ -119,6 +120,7 @@ class DNNLowPFullyConnectedAcc16OpTest(hu.HypothesisTestCase):
         nbits_in_non_outlier=st.sampled_from((0, 6)),
         in_quantized=st.booleans(),
         out_quantized=st.booleans(),
+        prepack_weight=st.booleans(),
         **hu.gcs_cpu_only
     )
     def test_dnnlowp_fully_connected_acc16_outlier(
@@ -129,6 +131,7 @@ class DNNLowPFullyConnectedAcc16OpTest(hu.HypothesisTestCase):
         nbits_in_non_outlier,
         in_quantized,
         out_quantized,
+        prepack_weight,
         gc,
         dc,
     ):
@@ -171,10 +174,12 @@ class DNNLowPFullyConnectedAcc16OpTest(hu.HypothesisTestCase):
         ]
 
         for op_type, engine in op_engine_list:
+            init_net = core.Net("test_init_net")
             net = core.Net("test_net")
 
             do_quantize = "DNNLOWP" in engine and in_quantized
             do_dequantize = "DNNLOWP" in engine and out_quantized
+            do_prepack_weight = engine == "DNNLOWP" and prepack_weight
 
             if do_quantize:
                 quantize = core.CreateOperator(
@@ -182,9 +187,28 @@ class DNNLowPFullyConnectedAcc16OpTest(hu.HypothesisTestCase):
                 )
                 net.Proto().op.extend([quantize])
 
+            x_q_param = dnnlowp_utils.choose_quantization_params(X.min(), X.max())
+
+            if do_prepack_weight:
+                inputs = ["W"]
+                if do_dequantize:
+                    inputs += ["b"]
+                pack = core.CreateOperator(
+                    "Int8FCPackWeight",
+                    inputs,
+                    ["W_packed"],
+                    in_scale=x_q_param.scale,
+                    engine=engine,
+                )
+                init_net.Proto().op.extend([pack])
+
             fc = core.CreateOperator(
                 op_type,
-                ["X_q" if do_quantize else "X", "W", "b"],
+                [
+                    "X_q" if do_quantize else "X",
+                    "W_packed" if do_prepack_weight else "W",
+                    "b",
+                ],
                 ["Y_q" if do_dequantize else "Y"],
                 dequantize_output=(0 if do_dequantize else 1),
                 engine=engine,
@@ -202,6 +226,7 @@ class DNNLowPFullyConnectedAcc16OpTest(hu.HypothesisTestCase):
             self.ws.create_blob("X").feed(X, device_option=gc)
             self.ws.create_blob("W").feed(W, device_option=gc)
             self.ws.create_blob("b").feed(b, device_option=gc)
+            self.ws.run(init_net)
             self.ws.run(net)
             outputs.append(
                 Output(Y=self.ws.blobs["Y"].fetch(), op_type=op_type, engine=engine)
index 11ba4e9..d32beff 100644 (file)
@@ -6,6 +6,7 @@
 #include "caffe2/core/tensor_int8.h"
 #include "caffe2/utils/cpuid.h"
 #include "fbgemm_pack_matrix_cache.h"
+#include "fbgemm_pack_op.h"
 #include "mmio.h"
 
 C10_DEFINE_bool(
@@ -28,6 +29,8 @@ FullyConnectedDNNLowPOp<T>::FullyConnectedDNNLowPOp(
     : BaseType(operator_def, ws),
       axis_(OperatorBase::GetSingleArgument<int32_t>("axis", 1)),
       axis_w_(OperatorBase::GetSingleArgument<int32_t>("axis_w", 1)),
+      b_quantized_(make_shared<vector<int32_t>>()),
+      column_offsets_(make_shared<vector<int32_t>>()),
       is_weight_constant_(
           OperatorBase::GetSingleArgument<bool>("constant_weight", true)) {
   if (!is_weight_constant_) {
@@ -171,7 +174,7 @@ bool FullyConnectedDNNLowPOp<T>::RunOnDevice() {
           in_qparams_[0].zero_point,
           &in_qparams_[1].zero_point,
           packA.getRowOffsetBuffer(),
-          column_offsets_.data(),
+          column_offsets_->data(),
           b_quantized_data_,
           N); // ncols per quant group
 
@@ -216,7 +219,7 @@ bool FullyConnectedDNNLowPOp<T>::RunOnDevice() {
             in_qparams_[0].zero_point,
             &in_qparams_[1].zero_point,
             packA.getRowOffsetBuffer(),
-            column_offsets_.data(),
+            column_offsets_->data(),
             b_dequantized_data_, // bias
             N); // ncols per quant group
 
@@ -251,7 +254,7 @@ bool FullyConnectedDNNLowPOp<T>::RunOnDevice() {
             in_qparams_[0].zero_point,
             &in_qparams_[1].zero_point,
             packA.getRowOffsetBuffer(),
-            column_offsets_.data(),
+            column_offsets_->data(),
             b_dequantized_data_, // bias
             N); // ncols per quant group
 
@@ -327,7 +330,7 @@ bool FullyConnectedDNNLowPOp<T>::RunOnDevice() {
 
         for (int j = 0; j < N; ++j) {
           Y_int32_[i * N + j] -=
-              in_qparams_[0].zero_point * column_offsets_[j] + row_offset;
+              in_qparams_[0].zero_point * (*column_offsets_)[j] + row_offset;
           Ydata[i * N + j] = Y_int32_[i * N + j] * in_qparams_[0].scale *
                   in_qparams_[1].scale +
               b_dequantized_data_[j];
@@ -346,7 +349,7 @@ bool FullyConnectedDNNLowPOp<T>::RunOnDevice() {
 
         for (int j = 0; j < N; ++j) {
           Y_int32_[i * N + j] -=
-              in_qparams_[0].zero_point * column_offsets_[j] + row_offset;
+              in_qparams_[0].zero_point * (*column_offsets_)[j] + row_offset;
           Y_int32_[i * N + j] += b_quantized_data_[j];
 
           Ydata[i * N + j] = fbgemm::Requantize<T>(
@@ -412,43 +415,34 @@ bool FullyConnectedDNNLowPOp<T>::GetQuantizationParameters_() {
         OperatorBase::debug_def().engine() != "DNNLOWP_ACC16";
 
     if ((fast_path && !Wq_packed_) || (!fast_path && W_quantized_.empty())) {
-      W_quantized_.resize(W.size());
-
-      if (OperatorBase::InputIsType<int8::Int8TensorCPU>(1)) {
-        in_qparams_[1].scale =
-            OperatorBase::Input<int8::Int8TensorCPU>(1).scale;
-        in_qparams_[1].zero_point =
-            OperatorBase::Input<int8::Int8TensorCPU>(1).zero_point + signed_min;
-
-        const T* W_data = W.template data<T>();
-        for (auto i = 0; i < W.size(); ++i) {
-          W_quantized_[i] = W_data[i] + signed_min;
-        }
+      if (this->template InputIsType<Int8FCDNNLowPPackedWeightBlob>(1)) {
+        const auto& packed_filter =
+            this->template Input<Int8FCDNNLowPPackedWeightBlob>(1);
+        CAFFE_ENFORCE_EQ(packed_filter.qparams.size(), 1);
+        in_qparams_[1] = packed_filter.qparams[0];
       } else {
-        in_qparams_[1] = qfactory_->ChooseQuantizationParams(
-            W.template data<float>(), W.size(), true /*weight*/);
-
-        // in_qparams_[1] is computed for unsigned type.
-        // Adjust for the fact that weight will actually use signed.
-        in_qparams_[1].zero_point += signed_min;
-
-        fbgemm::Quantize<T_signed>(
-            W.template data<float>(),
-            W_quantized_.data(),
-            W_quantized_.size(),
-            in_qparams_[1]);
+        vector<TensorQuantizationParams> temp_qparams(1);
+        QuantizeWeight<T>(
+            InputBlob(1), K, N, temp_qparams, W_quantized_, qfactory_.get());
+        in_qparams_[1] = temp_qparams[0];
       }
 
       if (fast_path) {
         // fast path using fbgemm
-        Wq_packed_ = GetOrCreateFbgemmPackBMatrix<int32_t>(
-            fbgemm::matrix_op_t::Transpose,
-            K,
-            N,
-            W.raw_data(),
-            reinterpret_cast<const int8_t*>(W_quantized_.data()),
-            K, // ld
-            in_qparams_[1].zero_point);
+        if (this->template InputIsType<Int8FCDNNLowPPackedWeightBlob>(1)) {
+          const auto& packed_filter =
+              this->template Input<Int8FCDNNLowPPackedWeightBlob>(1);
+          Wq_packed_ = packed_filter.W;
+        } else {
+          Wq_packed_ = GetOrCreateFbgemmPackBMatrix<int32_t>(
+              fbgemm::matrix_op_t::Transpose,
+              K,
+              N,
+              W.raw_data(),
+              reinterpret_cast<const int8_t*>(W_quantized_.data()),
+              K, // ld
+              in_qparams_[1].zero_point);
+        }
       } else {
         string reason;
         if (!is_same<T, uint8_t>::value) {
@@ -489,14 +483,16 @@ bool FullyConnectedDNNLowPOp<T>::GetQuantizationParameters_() {
     t_begin = chrono::system_clock::now();
   }
   // Pre-compute column_offset
-  if (!is_weight_constant_ || column_offsets_.empty()) {
-    column_offsets_.resize(N);
-    for (int j = 0; j < N; ++j) {
-      int32_t sum = 0;
-      for (int k = 0; k < K; ++k) {
-        sum += W_quantized_[j * K + k];
-      }
-      column_offsets_[j] = sum - in_qparams_[1].zero_point * K;
+  if (!is_weight_constant_ || column_offsets_->empty()) {
+    if (this->template InputIsType<Int8FCDNNLowPPackedWeightBlob>(1)) {
+      const auto& packed_filter =
+          this->template Input<Int8FCDNNLowPPackedWeightBlob>(1);
+      column_offsets_ = packed_filter.column_offsets;
+    } else {
+      vector<TensorQuantizationParams> temp_qparams;
+      temp_qparams.push_back(in_qparams_[1]);
+      ComputeColumnOffsets<T_signed>(
+          K, N, W_quantized_.data(), temp_qparams, *column_offsets_);
     }
   }
   if (VLOG_IS_ON(3)) {
@@ -514,43 +510,51 @@ bool FullyConnectedDNNLowPOp<T>::GetQuantizationParameters_() {
   // Quantize bias
   if (!is_weight_constant_ || (!b_quantized_data_ && !b_dequantized_data_) ||
       in_qparams_[0].scale != in_qparams0_scale_old_) {
-    const auto& bias = InputTensorCPU_(2);
-    if (OperatorBase::InputIsType<int8::Int8TensorCPU>(2)) {
-      in_qparams_[2].scale = OperatorBase::Input<int8::Int8TensorCPU>(2).scale;
-      in_qparams_[2].zero_point =
-          OperatorBase::Input<int8::Int8TensorCPU>(2).zero_point;
-      CAFFE_ENFORCE_LE(
-          std::abs(
-              in_qparams_[2].scale -
-              in_qparams_[0].scale * in_qparams_[1].scale),
-          1e-4);
-      CAFFE_ENFORCE_EQ(in_qparams_[2].zero_point, 0);
-      b_quantized_data_ = bias.template data<int32_t>();
-      if (dequantize_output_) {
-        b_dequantized_.resize(N);
-        for (int j = 0; j < N; ++j) {
-          b_dequantized_[j] =
-              fbgemm::Dequantize<int32_t>(b_quantized_data_[j], in_qparams_[2]);
-        }
-        b_dequantized_data_ = b_dequantized_.data();
-      }
+    if (this->template InputIsType<Int8FCDNNLowPPackedWeightBlob>(2) &&
+        this->template Input<Int8FCDNNLowPPackedWeightBlob>(2).bias.get()) {
+      const auto& packed_filter =
+          this->template Input<Int8FCDNNLowPPackedWeightBlob>(2);
+      CAFFE_ENFORCE(!dequantize_output_);
+      b_quantized_ = packed_filter.bias;
+      b_quantized_data_ = b_quantized_->data();
     } else {
-      in_qparams_[2].scale = in_qparams_[0].scale * in_qparams_[1].scale;
-      in_qparams_[2].zero_point = 0;
-      b_dequantized_data_ = bias.template data<float>();
-      if (!dequantize_output_) {
-        b_quantized_.resize(N);
-        for (int j = 0; j < N; ++j) {
-          b_quantized_[j] = fbgemm::Quantize<int32_t>(
-              b_dequantized_data_[j],
-              in_qparams_[2].zero_point,
-              in_qparams_[2].scale,
-              32);
+      const auto& bias = InputTensorCPU_(2);
+      if (this->template InputIsType<int8::Int8TensorCPU>(2)) {
+        TensorQuantizationParams bias_qparams;
+        bias_qparams.scale = this->template Input<int8::Int8TensorCPU>(2).scale;
+        bias_qparams.zero_point =
+            this->template Input<int8::Int8TensorCPU>(2).zero_point;
+        CAFFE_ENFORCE_LE(
+            std::abs(
+                bias_qparams.scale -
+                in_qparams_[0].scale * in_qparams_[1].scale),
+            1e-4);
+        CAFFE_ENFORCE_EQ(bias_qparams.zero_point, 0);
+        b_quantized_data_ = bias.template data<int32_t>();
+        if (dequantize_output_) {
+          b_dequantized_.resize(N);
+          for (int j = 0; j < N; ++j) {
+            b_dequantized_[j] = fbgemm::Dequantize<int32_t>(
+                b_quantized_data_[j], in_qparams_[2]);
+          }
+          b_dequantized_data_ = b_dequantized_.data();
+        }
+      } else {
+        b_dequantized_data_ = bias.template data<float>();
+        if (!dequantize_output_) {
+          b_quantized_->resize(N);
+          for (int j = 0; j < N; ++j) {
+            (*b_quantized_)[j] = fbgemm::Quantize<int32_t>(
+                b_dequantized_data_[j],
+                0,
+                in_qparams_[0].scale * in_qparams_[1].scale,
+                32);
+          }
+          b_quantized_data_ = b_quantized_->data();
         }
-        b_quantized_data_ = b_quantized_.data();
       }
+      in_qparams0_scale_old_ = in_qparams_[0].scale;
     }
-    in_qparams0_scale_old_ = in_qparams_[0].scale;
 
     CAFFE_ENFORCE(
         (dequantize_output_ && b_dequantized_data_) ||
index fcde873..398ed8b 100644 (file)
@@ -40,9 +40,10 @@ class FullyConnectedDNNLowPOp
   std::vector<T_signed> W_quantized_;
 
   // pre-computed biases and offsets
-  std::vector<std::int32_t> b_quantized_;
+  std::shared_ptr<std::vector<std::int32_t>> b_quantized_;
   const std::int32_t* b_quantized_data_{nullptr};
-  std::vector<std::int32_t> row_offsets_, column_offsets_;
+  std::vector<std::int32_t> row_offsets_;
+  std::shared_ptr<std::vector<std::int32_t>> column_offsets_;
 
   // Dequantized bias populated when input bias is quantized and
   // dequantized_output_ == true
index 58ca716..edbf38e 100644 (file)
@@ -6,7 +6,6 @@ import caffe2.python.hypothesis_test_util as hu
 import hypothesis.strategies as st
 import numpy as np
 from caffe2.python import core, dyndep
-from caffe2.python.fb import hardcode_scale_zp
 from caffe2.quantization.server import utils as dnnlowp_utils
 from dnnlowp_test_utils import (
     avoid_vpmaddubsw_overflow_fc,
@@ -14,6 +13,7 @@ from dnnlowp_test_utils import (
 )
 from hypothesis import given
 
+
 dyndep.InitOpsLibrary("//caffe2/caffe2/quantization/server:dnnlowp_ops")
 
 
@@ -26,6 +26,7 @@ class DNNLowPFullyConnectedOpTest(hu.HypothesisTestCase):
         in_quantized=st.booleans(),
         out_quantized=st.booleans(),
         weight_quantized=st.booleans(),
+        prepack_weight=st.booleans(),
         **hu.gcs_cpu_only
     )
     def test_dnnlowp_fully_connected_int(
@@ -36,6 +37,7 @@ class DNNLowPFullyConnectedOpTest(hu.HypothesisTestCase):
         in_quantized,
         out_quantized,
         weight_quantized,
+        prepack_weight,
         gc,
         dc,
     ):
@@ -87,6 +89,7 @@ class DNNLowPFullyConnectedOpTest(hu.HypothesisTestCase):
         ]
 
         for op_type, engine in op_engine_list:
+            init_net = core.Net("test_init_net")
             net = core.Net("test_net")
 
             do_quantize = "DNNLOWP" in engine and in_quantized
@@ -94,6 +97,7 @@ class DNNLowPFullyConnectedOpTest(hu.HypothesisTestCase):
             do_quantize_weight = (
                 engine == "DNNLOWP" and weight_quantized and len(outputs) > 0
             )
+            do_prepack_weight = engine == "DNNLOWP" and prepack_weight
 
             if do_quantize:
                 quantize = core.CreateOperator(
@@ -101,26 +105,39 @@ class DNNLowPFullyConnectedOpTest(hu.HypothesisTestCase):
                 )
                 net.Proto().op.extend([quantize])
 
+            x_q_param = dnnlowp_utils.choose_quantization_params(X.min(), X.max())
             if do_quantize_weight:
                 int8_given_tensor_fill, w_q_param = dnnlowp_utils.create_int8_given_tensor_fill(
                     W, "W_q"
                 )
-                net.Proto().op.extend([int8_given_tensor_fill])
+                init_net.Proto().op.extend([int8_given_tensor_fill])
 
                 # Bias
-                x_q_param = hardcode_scale_zp.choose_quantization_params(
-                    X.min(), X.max()
-                )
                 int8_bias_tensor_fill = dnnlowp_utils.create_int8_bias_tensor_fill(
                     b, "b_q", x_q_param, w_q_param
                 )
-                net.Proto().op.extend([int8_bias_tensor_fill])
+                init_net.Proto().op.extend([int8_bias_tensor_fill])
+
+            if do_prepack_weight:
+                inputs = ["W_q" if do_quantize_weight else "W"]
+                if do_dequantize:
+                    inputs += ["b_q" if do_quantize_weight else "b"]
+                pack = core.CreateOperator(
+                    "Int8FCPackWeight",
+                    inputs,
+                    ["W_packed"],
+                    in_scale=x_q_param.scale,
+                    engine=engine,
+                )
+                init_net.Proto().op.extend([pack])
 
             fc = core.CreateOperator(
                 op_type,
                 [
                     "X_q" if do_quantize else "X",
-                    "W_q" if do_quantize_weight else "W",
+                    "W_packed"
+                    if do_prepack_weight
+                    else ("W_q" if do_quantize_weight else "W"),
                     "b_q" if do_quantize_weight else "b",
                 ],
                 ["Y_q" if do_dequantize else "Y"],
@@ -128,7 +145,7 @@ class DNNLowPFullyConnectedOpTest(hu.HypothesisTestCase):
                 engine=engine,
                 device_option=gc,
             )
-            if do_quantize_weight:
+            if do_quantize_weight or do_prepack_weight:
                 # When quantized weight is provided, we can't rescale the
                 # output dynamically by looking at the range of output of each
                 # batch, so here we provide the range of output observed from
@@ -145,6 +162,7 @@ class DNNLowPFullyConnectedOpTest(hu.HypothesisTestCase):
             self.ws.create_blob("X").feed(X, device_option=gc)
             self.ws.create_blob("W").feed(W, device_option=gc)
             self.ws.create_blob("b").feed(b, device_option=gc)
+            self.ws.run(init_net)
             self.ws.run(net)
             outputs.append(
                 Output(Y=self.ws.blobs["Y"].fetch(), op_type=op_type, engine=engine)
index 0f64d22..00fbb87 100644 (file)
@@ -1,7 +1,10 @@
 #include "fully_connected_rowwise_dnnlowp_op.h"
+
 #include <fbgemm/src/RefImplementations.h>
 #include <chrono>
 
+#include "fbgemm_pack_op.h"
+
 namespace caffe2 {
 
 using namespace std;
@@ -13,6 +16,8 @@ FullyConnectedRowWiseDNNLowPOp<T>::FullyConnectedRowWiseDNNLowPOp(
     : BaseType(operator_def, ws),
       axis_(OperatorBase::GetSingleArgument<int32_t>("axis", 1)),
       axis_w_(OperatorBase::GetSingleArgument<int32_t>("axis_w", 1)),
+      b_quantized_(make_shared<vector<int32_t>>()),
+      column_offsets_(make_shared<vector<int32_t>>()),
       is_weight_constant_(
           OperatorBase::GetSingleArgument<bool>("constant_weight", true)) {
   using namespace dnnlowp;
@@ -131,9 +136,9 @@ bool FullyConnectedRowWiseDNNLowPOp<T>::RunOnDevice() {
       for (int i = 0; i < M; ++i) {
         for (int j = 0; j < N; ++j) {
           Y_int32_[i * N + j] -=
-              in_qparams_[0].zero_point * column_offsets_[j] +
+              in_qparams_[0].zero_point * (*column_offsets_)[j] +
               rowwise_qparams_[j].zero_point * row_offsets_[i];
-          Y_int32_[i * N + j] += b_quantized_[j];
+          Y_int32_[i * N + j] += (*b_quantized_)[j];
           Ydata[i * N + j] = Y_int32_[i * N + j] * rowwise_qparams_[j].scale *
                   in_qparams_[0].scale +
               b_data[j];
@@ -144,9 +149,9 @@ bool FullyConnectedRowWiseDNNLowPOp<T>::RunOnDevice() {
       for (int i = 0; i < M; ++i) {
         for (int j = 0; j < N; ++j) {
           Y_int32_[i * N + j] -=
-              in_qparams_[0].zero_point * column_offsets_[j] +
+              in_qparams_[0].zero_point * (*column_offsets_)[j] +
               rowwise_qparams_[j].zero_point * row_offsets_[i];
-          Y_int32_[i * N + j] += b_quantized_[j];
+          Y_int32_[i * N + j] += (*b_quantized_)[j];
           Ydata[i * N + j] = Requantize<T>(
               Y_int32_[i * N + j], rowwise_requantization_params_[j]);
         }
@@ -184,7 +189,7 @@ bool FullyConnectedRowWiseDNNLowPOp<T>::RunOnDevice() {
         }
         for (int j = 0; j < N; ++j) {
           Y_int32_[i * N + j] -=
-              in_qparams_[0].zero_point * column_offsets_[j] +
+              in_qparams_[0].zero_point * (*column_offsets_)[j] +
               rowwise_qparams_[j].zero_point * row_offset;
           Ydata[i * N + j] = Y_int32_[i * N + j] * rowwise_qparams_[j].scale *
                   in_qparams_[0].scale +
@@ -200,9 +205,9 @@ bool FullyConnectedRowWiseDNNLowPOp<T>::RunOnDevice() {
         }
         for (int j = 0; j < N; ++j) {
           Y_int32_[i * N + j] -=
-              in_qparams_[0].zero_point * column_offsets_[j] +
+              in_qparams_[0].zero_point * (*column_offsets_)[j] +
               rowwise_qparams_[j].zero_point * row_offset;
-          Y_int32_[i * N + j] += b_quantized_[j];
+          Y_int32_[i * N + j] += (*b_quantized_)[j];
           Ydata[i * N + j] = fbgemm::Requantize<T>(
               Y_int32_[i * N + j], rowwise_requantization_params_[j]);
         }
@@ -250,33 +255,51 @@ bool FullyConnectedRowWiseDNNLowPOp<T>::GetQuantizationParameters_() {
     if ((fast_path && !Wq_packed_) || (!fast_path && W_quantized_.empty())) {
       LOG(INFO) << "Choose rowwise quantization params";
       if (rowwise_qparams_.empty()) {
-        // choose rowwise quantization params
-        rowwise_qparams_.resize(N);
-        W_quantized_.resize(W.size());
-        for (int i = 0; i < N; ++i) {
-          rowwise_qparams_[i] = qfactory_->ChooseQuantizationParams(
-              W.template data<float>() + K * i, K, true /*weight*/);
-          rowwise_qparams_[i].zero_point -=
-              (1 << (qfactory_->GetWeightPrecision() - 1));
-          fbgemm::Quantize<T_signed>(
-              W.template data<float>() + K * i,
-              W_quantized_.data() + K * i,
+        if (this->template InputIsType<Int8FCDNNLowPPackedWeightBlob>(1)) {
+          const auto& packed_filter =
+              this->template Input<Int8FCDNNLowPPackedWeightBlob>(1);
+          CAFFE_ENFORCE_EQ(packed_filter.qparams.size(), N);
+          // TODO: optimize the overhead of copy
+          rowwise_qparams_ = packed_filter.qparams;
+        } else {
+          // choose rowwise quantization params
+          if (this->template InputIsType<int8::Int8TensorCPU>(1)) {
+            static int log_occurences = 0;
+            if (log_occurences < 32) {
+              ++log_occurences;
+              LOG(WARNING) << "Cannot do row-wise quantization for "
+                              "pre-quantized weight "
+                           << this->debug_def().input(1);
+            }
+          }
+          rowwise_qparams_.resize(N);
+          QuantizeWeight<T>(
+              InputBlob(1),
               K,
-              rowwise_qparams_[i]);
+              N,
+              rowwise_qparams_,
+              W_quantized_,
+              qfactory_.get());
         }
       }
       if (fast_path) {
         // fast path using fbgemm
         LOG(INFO)
             << "Using fast path with int8 fbgemm and generating Wq_packed_";
-        Wq_packed_.reset(new fbgemm::PackBMatrix<int8_t>(
-            fbgemm::matrix_op_t::Transpose,
-            K,
-            N,
-            reinterpret_cast<const int8_t*>(W_quantized_.data()),
-            K, // ld
-            nullptr, // pmat
-            1)); // groups
+        if (this->template InputIsType<Int8FCDNNLowPPackedWeightBlob>(1)) {
+          const auto& packed_filter =
+              this->template Input<Int8FCDNNLowPPackedWeightBlob>(1);
+          Wq_packed_ = packed_filter.W;
+        } else {
+          Wq_packed_.reset(new fbgemm::PackBMatrix<int8_t>(
+              fbgemm::matrix_op_t::Transpose,
+              K,
+              N,
+              reinterpret_cast<const int8_t*>(W_quantized_.data()),
+              K, // ld
+              nullptr, // pmat
+              1)); // groups
+        }
       } else {
         LOG(WARNING)
             << "Falling back to slow path because fbgemm doesn't support "
@@ -301,29 +324,36 @@ bool FullyConnectedRowWiseDNNLowPOp<T>::GetQuantizationParameters_() {
     }
   }
 
-  if (!is_weight_constant_ || column_offsets_.empty()) {
-    // pre-compute column_offsets_
-    column_offsets_.resize(N);
-    for (int j = 0; j < N; ++j) {
-      int32_t sum = 0;
-      for (int k = 0; k < K; ++k) {
-        sum += W_quantized_[j * K + k];
-      }
-      column_offsets_[j] = sum - rowwise_qparams_[j].zero_point * K;
+  if (!is_weight_constant_ || column_offsets_->empty()) {
+    if (this->template InputIsType<Int8FCDNNLowPPackedWeightBlob>(1)) {
+      const auto& packed_filter =
+          this->template Input<Int8FCDNNLowPPackedWeightBlob>(1);
+      column_offsets_ = packed_filter.column_offsets;
+    } else {
+      ComputeColumnOffsets<T_signed>(
+          K, N, W_quantized_.data(), rowwise_qparams_, *column_offsets_);
     }
   }
 
   if (Wq_packed_) {
     vector<T_signed>().swap(W_quantized_);
   }
-  if (!is_weight_constant_ || b_quantized_.empty()) {
+  if (!is_weight_constant_ || b_quantized_->empty()) {
     // Quantize bias
-    b_quantized_.resize(N);
-    const auto& b = InputTensorCPU_(2);
-    const float* b_data = b.template data<float>();
-    for (int j = 0; j < N; ++j) {
-      b_quantized_[j] = fbgemm::Quantize<int32_t>(
-          b_data[j], 0, in_qparams_[0].scale * rowwise_qparams_[j].scale, 32);
+    if (this->template InputIsType<Int8FCDNNLowPPackedWeightBlob>(2) &&
+        this->template Input<Int8FCDNNLowPPackedWeightBlob>(2).bias.get()) {
+      const auto& packed_filter =
+          this->template Input<Int8FCDNNLowPPackedWeightBlob>(2);
+      CAFFE_ENFORCE(!dequantize_output_);
+      b_quantized_ = packed_filter.bias;
+    } else {
+      b_quantized_->resize(N);
+      const auto& b = InputTensorCPU_(2);
+      const float* b_data = b.template data<float>();
+      for (int j = 0; j < N; ++j) {
+        (*b_quantized_)[j] = fbgemm::Quantize<int32_t>(
+            b_data[j], 0, in_qparams_[0].scale * rowwise_qparams_[j].scale, 32);
+      }
     }
   }
   if (!dequantize_output_) {
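
The hunk above delegates two small computations to shared helpers or to the
pre-packed blob: the per-column offsets (the sum of each quantized weight row
minus its zero point times K) and the 32-bit bias quantization at scale
in_qparams_[0].scale * rowwise_qparams_[j].scale with zero point 0. A NumPy
restatement of that arithmetic, for illustration only (the helper names are
invented; this is not the ComputeColumnOffsets / fbgemm::Quantize code):

# Hedged sketch of the arithmetic in the removed loops above. W_q has shape
# (N, K); row_zero_points and row_scales hold one entry per output row;
# in_scale is the activation (input) scale.
import numpy as np

def column_offsets(W_q, row_zero_points):
    # offset[j] = sum_k W_q[j, k] - zero_point[j] * K
    N, K = W_q.shape
    return (W_q.astype(np.int64).sum(axis=1)
            - np.asarray(row_zero_points, dtype=np.int64) * K)

def quantize_bias(b, in_scale, row_scales):
    # b_q[j] = round(b[j] / (in_scale * row_scale[j])), zero point 0, 32 bits
    b_q = np.round(b / (in_scale * np.asarray(row_scales)))
    return np.clip(b_q, -(2 ** 31), 2 ** 31 - 1).astype(np.int32)
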
index 0eba2fe..c756f98 100644 (file)
@@ -31,14 +31,14 @@ class FullyConnectedRowWiseDNNLowPOp final
   using T_signed = typename std::make_signed<T>::type;
 
   // used in fast path for T == uint8_t
-  std::unique_ptr<fbgemm::PackBMatrix<std::int8_t>> Wq_packed_;
+  std::shared_ptr<fbgemm::PackBMatrix<std::int8_t>> Wq_packed_;
   std::vector<std::uint8_t> X_pack_buf_;
 
   // used in slow path for T != uint8_t
   std::vector<T_signed> W_quantized_;
-  std::vector<std::int32_t> b_quantized_;
+  std::shared_ptr<std::vector<std::int32_t>> b_quantized_;
 
-  std::vector<std::int32_t> column_offsets_;
+  std::shared_ptr<std::vector<std::int32_t>> column_offsets_;
   std::vector<std::int32_t> row_offsets_;
   std::vector<std::int32_t> Y_int32_;
 
index 195d847..550ed7b 100644 (file)
@@ -6,6 +6,7 @@ import caffe2.python.hypothesis_test_util as hu
 import hypothesis.strategies as st
 import numpy as np
 from caffe2.python import core, dyndep
+from caffe2.quantization.server import utils as dnnlowp_utils
 from dnnlowp_test_utils import (
     avoid_vpmaddubsw_overflow_fc,
     check_quantized_results_close,
@@ -24,6 +25,7 @@ class RowWiseDNNLowPFullyConnectedOpTest(hu.HypothesisTestCase):
         batch_size=st.integers(1, 16),
         in_quantized=st.booleans(),
         out_quantized=st.booleans(),
+        prepack_weight=st.booleans(),
         **hu.gcs_cpu_only
     )
     def test_rowwise_dnnlowp_fully_connected_int(
@@ -33,6 +35,7 @@ class RowWiseDNNLowPFullyConnectedOpTest(hu.HypothesisTestCase):
         batch_size,
         in_quantized,
         out_quantized,
+        prepack_weight,
         gc,
         dc,
     ):
@@ -90,10 +93,12 @@ class RowWiseDNNLowPFullyConnectedOpTest(hu.HypothesisTestCase):
         ]
 
         for op_type, engine in op_engine_list:
+            init_net = core.Net("test_init_net")
             net = core.Net("test_net")
 
             do_quantize = "DNNLOWP" in engine and in_quantized
             do_dequantize = "DNNLOWP" in engine and out_quantized
+            do_prepack_weight = engine == "DNNLOWP_ROWWISE" and prepack_weight
 
             if do_quantize:
                 quantize = core.CreateOperator(
@@ -101,14 +106,39 @@ class RowWiseDNNLowPFullyConnectedOpTest(hu.HypothesisTestCase):
                 )
                 net.Proto().op.extend([quantize])
 
+            x_q_param = dnnlowp_utils.choose_quantization_params(X.min(), X.max())
+
+            if do_prepack_weight:
+                inputs = ["W"]
+                if do_dequantize:
+                    inputs += ["b"]
+                pack = core.CreateOperator(
+                    "Int8FCPackWeight",
+                    inputs,
+                    ["W_packed"],
+                    in_scale=x_q_param.scale,
+                    engine=engine,
+                )
+                init_net.Proto().op.extend([pack])
+
             fc = core.CreateOperator(
                 op_type,
-                ["X_q" if do_quantize else "X", "W", "b"],
+                [
+                    "X_q" if do_quantize else "X",
+                    "W_packed" if do_prepack_weight else "W",
+                    "b",
+                ],
                 ["Y_q" if do_dequantize else "Y"],
                 dequantize_output=not do_dequantize,
                 engine=engine,
                 device_option=gc,
             )
+            if do_prepack_weight:
+                # When a pre-packed quantized weight is provided, we can't
+                # rescale the output dynamically by looking at the output range
+                # of each batch, so here we provide the output range observed
+                # from the fp32 reference implementation.
+                dnnlowp_utils.add_quantization_param_args(fc, outputs[0][0])
             net.Proto().op.extend([fc])
 
             if do_dequantize:
@@ -120,6 +150,7 @@ class RowWiseDNNLowPFullyConnectedOpTest(hu.HypothesisTestCase):
             self.ws.create_blob("X").feed(X, device_option=gc)
             self.ws.create_blob("W").feed(W, device_option=gc)
             self.ws.create_blob("b").feed(b, device_option=gc)
+            self.ws.run(init_net)
             self.ws.run(net)
             outputs.append(
                 Output(Y=self.ws.blobs["Y"].fetch(), op_type=op_type, engine=engine)
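
The test above follows the intended deployment split: Int8FCPackWeight runs
once in an init net with the activation scale, and the row-wise FC then
consumes the packed blob, with the output range pinned from an fp32 reference
because it can no longer be chosen per batch. A condensed standalone sketch of
that wiring follows; the blob names, shapes and the "FC" op type are
placeholders, and it assumes a caffe2 build in which the dnnlowp operators are
already registered (the test does this via dyndep).

# Hedged sketch of the pre-pack usage pattern exercised by the test above.
import numpy as np
from caffe2.python import core, workspace
from caffe2.quantization.server import utils as dnnlowp_utils

X = np.random.rand(16, 64).astype(np.float32)
W = np.random.rand(32, 64).astype(np.float32) - 0.5
b = np.random.rand(32).astype(np.float32)
Y_ref = X.dot(W.T) + b  # fp32 reference, used only to fix the output range

x_q_param = dnnlowp_utils.choose_quantization_params(X.min(), X.max())

init_net = core.Net("init_net")
init_net.Proto().op.extend([
    core.CreateOperator(
        "Int8FCPackWeight",
        ["W"],
        ["W_packed"],
        in_scale=x_q_param.scale,
        engine="DNNLOWP_ROWWISE",
    )
])

fc = core.CreateOperator(
    "FC",
    ["X", "W_packed", "b"],
    ["Y"],
    dequantize_output=1,
    engine="DNNLOWP_ROWWISE",
)
# With a pre-packed weight the output range cannot be chosen per batch, so it
# is taken from the fp32 reference, as the test comment above explains.
dnnlowp_utils.add_quantization_param_args(fc, Y_ref)

predict_net = core.Net("predict_net")
predict_net.Proto().op.extend([fc])

for name, value in (("X", X), ("W", W), ("b", b)):
    workspace.FeedBlob(name, value)
workspace.RunNetOnce(init_net)
workspace.RunNetOnce(predict_net)
print(workspace.FetchBlob("Y")[:2])
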
index d6ef057..d386ea1 100644 (file)
@@ -6,7 +6,6 @@ import caffe2.python.hypothesis_test_util as hu
 import hypothesis.strategies as st
 import numpy as np
 from caffe2.python import core, dyndep
-from caffe2.python.fb import hardcode_scale_zp
 from caffe2.quantization.server import utils as dnnlowp_utils
 from dnnlowp_test_utils import check_quantized_results_close
 from hypothesis import given
@@ -80,9 +79,7 @@ class DNNLowPOpGroupNormTest(hu.HypothesisTestCase):
                 )
                 net.Proto().op.extend([int8_given_tensor_fill])
 
-                X_q_param = hardcode_scale_zp.choose_quantization_params(
-                    X.min(), X.max()
-                )
+                X_q_param = dnnlowp_utils.choose_quantization_params(X.min(), X.max())
                 int8_bias_tensor_fill = dnnlowp_utils.create_int8_bias_tensor_fill(
                     beta, "beta_q", X_q_param, gamma_q_param
                 )
index 3aa361c..6be2679 100644 (file)
@@ -352,10 +352,7 @@ def add_quantization_param_args_(op, q_param):
     )
 
 
-def add_quantization_param_args(op, tensor, preserve_sparsity=False):
-    tensor_min = 0 if tensor.size == 0 else tensor.min()
-    tensor_max = 0 if tensor.size == 0 else tensor.max()
-
+def choose_quantization_params(tensor_min, tensor_max, preserve_sparsity=False):
     if tensor_min < 0 and tensor_max > 0 and preserve_sparsity:
         symmetric_qmin = -(255 // 2 + 1)
         symmetric_qmax = 255 // 2
@@ -370,6 +367,15 @@ def add_quantization_param_args(op, tensor, preserve_sparsity=False):
     if tensor_min < 0 and tensor_max > 0 and preserve_sparsity:
         q_param = hardcode_scale_zp.QuantizationParam(q_param.scale, 128)
 
+    return q_param
+
+
+def add_quantization_param_args(op, tensor, preserve_sparsity=False):
+    tensor_min = 0 if tensor.size == 0 else tensor.min()
+    tensor_max = 0 if tensor.size == 0 else tensor.max()
+
+    q_param = choose_quantization_params(tensor_min, tensor_max, preserve_sparsity)
+
     add_quantization_param_args_(op, q_param)
     return q_param
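
With choose_quantization_params factored out above, callers such as the
pre-pack tests can compute quantization parameters straight from an observed
(min, max) range without touching an operator, while add_quantization_param_args
keeps its previous tensor-plus-op behavior. A small usage sketch, assuming a
caffe2 build where caffe2/quantization/server is importable:

# Hedged usage sketch of the refactored helpers above.
import numpy as np
from caffe2.python import core
from caffe2.quantization.server import utils as dnnlowp_utils

X = np.random.randn(4, 8).astype(np.float32)

# New entry point: derive quantization params from an observed (min, max)
# range, e.g. to pass in_scale to Int8FCPackWeight as in the test above.
x_q_param = dnnlowp_utils.choose_quantization_params(X.min(), X.max())
print(x_q_param.scale)

# Unchanged one-step behavior: compute params from a tensor and attach them
# to an operator's arguments (any OperatorDef works for illustration).
op = core.CreateOperator("Relu", ["X"], ["Y"])
dnnlowp_utils.add_quantization_param_args(op, X)
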