Quantized RNNCell modules (#15469)
authorJames Reed <jamesreed@fb.com>
Tue, 15 Jan 2019 18:07:18 +0000 (10:07 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 15 Jan 2019 18:40:51 +0000 (10:40 -0800)
Summary:
Similarly to https://github.com/pytorch/pytorch/pull/13777, we apply post-processing quantization to RNN cell modules (`RNNCell`, `LSTMCell`, and `GRUCell`).

A follow-up PR will quantize the full `RNN`, `GRU`, and `LSTM` modules; that work depends on those modules becoming scriptable as part of the standard-library scripting effort. Note that some infrastructure in this PR, such as `gather_quantized_params`, is currently unused but will be exercised once the full RNN modules are ported over.
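
For illustration, the intended use of the new API is roughly the following (a minimal sketch; the shapes and values are made up, and an FBGEMM-capable CPU is assumed):

    import torch

    cell = torch.nn.LSTMCell(20, 20).float()    # an ordinary float32 cell
    qcell = torch.jit.quantized.quantize_rnn_cell_modules(cell)

    x = torch.randn(3, 20)       # (batch, input_size)
    hx = torch.randn(3, 20)      # hidden state
    cx = torch.randn(3, 20)      # cell state
    hy, cy = qcell(x, (hx, cx))  # same calling convention as torch.nn.LSTMCell
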
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15469

Differential Revision: D13545802

Pulled By: jamesr66a

fbshipit-source-id: ad3b694517842893ea619438e9f5e88fd7b96510

aten/src/ATen/native/RNN.cpp
aten/src/ATen/native/native_functions.yaml
test/test_jit.py
torch/jit/quantized.py

diff --git a/aten/src/ATen/native/RNN.cpp b/aten/src/ATen/native/RNN.cpp
index d0178a1..7ec339c 100644
@@ -42,6 +42,91 @@ struct CellParams {
   const Tensor& w_hh;
   const Tensor& b_ih; /* optional */
   const Tensor& b_hh; /* optional */
+
+  Tensor matmul_ih(Tensor input) const {
+    return at::matmul(input, w_ih.t());
+  }
+  Tensor matmul_hh(Tensor h) const {
+    return at::matmul(h, w_hh.t());
+  }
+  Tensor linear_ih(Tensor input) const {
+    return at::linear(input, w_ih, b_ih);
+  }
+  Tensor linear_hh(Tensor h) const {
+    return at::linear(h, w_hh, b_hh);
+  }
+};
+
+// Run this Python script and pipe to clang-format to generate the constructor
+// and data members:
+//
+// names = ['w', 'b', 'packed', 'col_offsets', 'scale', 'zero_point']
+//
+//
+// get_type = lambda i: 'Scalar' if i == 4 or i == 5 else 'Tensor'
+// member_ref = lambda i: '' if i == 4 or i == 5 else '&'
+//
+// suffixes = ['ih', 'hh']
+//
+// params = []
+// initializers = []
+// members = []
+// for i in range(len(names)*2):
+//     params.append('const {typ}& _{name}_{suffix}'.format(typ=get_type(
+//         i//2), name=names[(i//2) % len(names)], suffix=suffixes[i % 2]))
+//     initializers.append('{name}_{suffix}(_{name}_{suffix})'.format(
+//         name=names[(i//2) % len(names)], suffix=suffixes[i % 2]))
+//     members.append('const {typ}{member_ref} {name}_{suffix};'.format(typ=get_type(
+//         i//2), name=names[(i//2) % len(names)], suffix=suffixes[i % 2], member_ref=member_ref(i//2)))
+//
+// params_str = ', '.join(params)
+// initializers_str = ', '.join(initializers)
+// members_str = '\n'.join(members)
+//
+// ctor = 'QuantizedCellParams(' + params_str + ') : ' + initializers_str + '{}'
+// print('struct QuantizedCellParams {', '\n\n'.join([ctor, members_str]), '};')
+
+struct QuantizedCellParams {
+  QuantizedCellParams(const Tensor &_w_ih, const Tensor &_w_hh,
+                      const Tensor &_b_ih, const Tensor &_b_hh,
+                      const Tensor &_packed_ih, const Tensor &_packed_hh,
+                      const Tensor &_col_offsets_ih,
+                      const Tensor &_col_offsets_hh, const Scalar &_scale_ih,
+                      const Scalar &_scale_hh, const Scalar &_zero_point_ih,
+                      const Scalar &_zero_point_hh)
+      : w_ih(_w_ih), w_hh(_w_hh), b_ih(_b_ih), b_hh(_b_hh),
+        packed_ih(_packed_ih), packed_hh(_packed_hh),
+        col_offsets_ih(_col_offsets_ih), col_offsets_hh(_col_offsets_hh),
+        scale_ih(_scale_ih), scale_hh(_scale_hh), zero_point_ih(_zero_point_ih),
+        zero_point_hh(_zero_point_hh) {}
+
+  const Tensor &w_ih;
+  const Tensor &w_hh;
+  const Tensor &b_ih;
+  const Tensor &b_hh;
+  const Tensor &packed_ih;
+  const Tensor &packed_hh;
+  const Tensor &col_offsets_ih;
+  const Tensor &col_offsets_hh;
+  const Scalar scale_ih;
+  const Scalar scale_hh;
+  const Scalar zero_point_ih;
+  const Scalar zero_point_hh;
+
+  Tensor matmul_ih(Tensor input) const {
+    AT_CHECK(false, "matmul is not supported with quantized cell params");
+  }
+  Tensor matmul_hh(Tensor h) const {
+    AT_CHECK(false, "matmul is not supported with quantized cell params");
+  }
+  Tensor linear_ih(Tensor input) const {
+    return at::fbgemm_linear_int8_weight(
+        input, w_ih, packed_ih, col_offsets_ih, scale_ih, zero_point_ih, b_ih);
+  }
+  Tensor linear_hh(Tensor h) const {
+    return at::fbgemm_linear_int8_weight(
+        h, w_hh, packed_hh, col_offsets_hh, scale_hh, zero_point_hh, b_hh);
+  }
 };
 
 // Gathers every two elements of a vector into a vector of pairs
@@ -86,6 +171,19 @@ static std::vector<CellParams> gather_params(TensorList params, bool has_biases)
   return result;
 }
 
+static std::vector<QuantizedCellParams> gather_quantized_params(TensorList params) {
+  static at::Tensor undefined;
+  std::vector<QuantizedCellParams> result;
+  AT_CHECK(params.size() % 12 == 0, "got an incorrect number of quantized RNN parameters");
+  for (size_t i = 0; i < params.size(); i += 12) {
+    result.emplace_back(params[i], params[i + 1], params[i + 2], params[i + 3],
+                        params[i + 4], params[i + 5], params[i + 6], params[i + 7],
+                        params[i + 8].item(), params[i + 9].item(),
+                        params[i + 10].item(), params[i + 11].item());
+  }
+  return result;
+}
+
 
 ////////////////////////////////////////////////////////////////////////////////
 // HIDDEN STATE FUNCTIONS
@@ -134,35 +232,39 @@ tpair_of<Tensor> hidden_slice(const tpair_of<Tensor>& t, int64_t start, int64_t
 // It's a struct only because functional programming in C++ is a pain, and it's easier
 // to pass around "vtable pointers" than actual function pointers.
 
-template<typename hidden_type_tmpl>
+template<typename hidden_type_tmpl, typename cell_params_tmpl>
 struct Cell {
   using hidden_type = hidden_type_tmpl;
+  using cell_params = cell_params_tmpl;
   virtual ~Cell() {} // This is really dumb, but enables projects with -Wnon-virtual-dtor to compile...
-  virtual hidden_type operator()(const Tensor& input, const hidden_type& hidden, const CellParams& params) const = 0;
+  virtual hidden_type operator()(const Tensor& input, const hidden_type& hidden, const cell_params& params) const = 0;
 };
 
-template<typename nonlinearity>
-struct SimpleCell : Cell<Tensor> {
-  hidden_type operator()(const Tensor& input, const hidden_type& hidden, const CellParams& params) const override {
-    return nonlinearity{}(at::linear(input, params.w_ih, params.b_ih) + at::linear(hidden, params.w_hh, params.b_hh));
+template<typename nonlinearity, typename cell_params>
+struct SimpleCell : Cell<Tensor, cell_params> {
+  using hidden_type = Tensor;
+  Tensor operator()(const Tensor& input, const Tensor& hidden, const cell_params& params) const override {
+    return nonlinearity{}(params.linear_ih(input) + params.linear_hh(hidden));
   }
 };
 
 // TODO: can use inplace ops?
-struct LSTMCell : Cell<std::tuple<Tensor, Tensor>> {
-  hidden_type operator()(const Tensor& input, const hidden_type& hidden, const CellParams& params) const override {
+template <typename cell_params>
+struct LSTMCell : Cell<std::tuple<Tensor, Tensor>, cell_params> {
+  using hidden_type = std::tuple<Tensor, Tensor>;
+  hidden_type operator()(const Tensor& input, const hidden_type& hidden, const cell_params& params) const override {
     auto hx = std::get<0>(hidden);
     auto cx = std::get<1>(hidden);
 
     if (input.is_cuda()) {
-      auto igates = at::matmul(input, params.w_ih.t());
-      auto hgates = at::matmul(hx, params.w_hh.t());
+      auto igates = params.matmul_ih(input);
+      auto hgates = params.matmul_hh(hx);
       auto result = at::_thnn_fused_lstm_cell(igates, hgates, cx, params.b_ih, params.b_hh);
       // Slice off the workspace argument (it's needed only for AD).
       return std::make_tuple(std::get<0>(result), std::get<1>(result));
     }
 
-    auto gates = at::linear(input, params.w_ih, params.b_ih) + at::linear(hx, params.w_hh, params.b_hh);
+    auto gates = params.linear_ih(input) + params.linear_hh(hx);
     auto chunked_gates = gates.chunk(4, 1);
 
     auto ingate = chunked_gates[0].sigmoid();
@@ -177,18 +279,20 @@ struct LSTMCell : Cell<std::tuple<Tensor, Tensor>> {
   }
 };
 
-struct GRUCell : Cell<Tensor> {
-  hidden_type operator()(const Tensor& input, const hidden_type& hidden, const CellParams& params) const override {
+template <typename cell_params>
+struct GRUCell : Cell<Tensor, cell_params> {
+  using hidden_type = Tensor;
+  hidden_type operator()(const Tensor& input, const hidden_type& hidden, const cell_params& params) const override {
     if (input.is_cuda()) {
-      auto igates = at::matmul(input, params.w_ih.t());
-      auto hgates = at::matmul(hidden, params.w_hh.t());
+      auto igates = params.matmul_ih(input);
+      auto hgates = params.matmul_hh(hidden);
       auto result = at::_thnn_fused_gru_cell(igates, hgates, hidden, params.b_ih, params.b_hh);
       // Slice off the workspace argument (it's needed only for AD).
       return std::get<0>(result);
     }
 
-    auto igates = at::linear(input, params.w_ih, params.b_ih);
-    auto hgates = at::linear(hidden, params.w_hh, params.b_hh);
+    auto igates = params.linear_ih(input);
+    auto hgates = params.linear_hh(hidden);
     auto chunked_igates = igates.chunk(3, 1);
     auto chunked_hgates = hgates.chunk(3, 1);
 
@@ -224,15 +328,15 @@ struct Layer {
   virtual output_type operator()(const io_type& input, const hidden_type& input_hidden, const param_type& params) const = 0;
 };
 
-template<typename hidden_type>
-struct FullLayer : Layer<Tensor, hidden_type, CellParams> {
-  using output_type = typename Layer<Tensor, hidden_type, CellParams>::output_type;
+template<typename hidden_type, typename cell_params>
+struct FullLayer : Layer<Tensor, hidden_type, cell_params> {
+  using output_type = typename Layer<Tensor, hidden_type, cell_params>::output_type;
   using unstacked_output_type = LayerOutput<std::vector<Tensor>, hidden_type>;
 
-  FullLayer(Cell<hidden_type>& cell)
+  FullLayer(Cell<hidden_type, cell_params>& cell)
     : cell_(cell) {};
 
-  unstacked_output_type operator()(std::vector<Tensor> step_inputs, const hidden_type& input_hidden, const CellParams& params) const {
+  unstacked_output_type operator()(std::vector<Tensor> step_inputs, const hidden_type& input_hidden, const cell_params& params) const {
     std::vector<Tensor> step_outputs;
     auto hidden = input_hidden;
     for (size_t i = 0; i < step_inputs.size(); i++) {
@@ -242,21 +346,21 @@ struct FullLayer : Layer<Tensor, hidden_type, CellParams> {
     return {step_outputs, hidden};
   }
 
-  output_type operator()(const Tensor& inputs, const hidden_type& input_hidden, const CellParams& params) const override {
+  output_type operator()(const Tensor& inputs, const hidden_type& input_hidden, const cell_params& params) const override {
     auto unstacked_output = (*this)(inputs.unbind(0), input_hidden, params);
     return {at::stack(unstacked_output.outputs, 0), unstacked_output.final_hidden};
   }
 
-  Cell<hidden_type>& cell_;
+  Cell<hidden_type, cell_params>& cell_;
 };
 
-template<typename dir_hidden_type>
-struct FullBidirectionalLayer : Layer<Tensor, pair_of<dir_hidden_type>, pair_of<CellParams>> {
+template<typename dir_hidden_type, typename cell_params>
+struct FullBidirectionalLayer : Layer<Tensor, pair_of<dir_hidden_type>, pair_of<cell_params>> {
   using hidden_type = pair_of<dir_hidden_type>;
-  using param_type = pair_of<CellParams>;
+  using param_type = pair_of<cell_params>;
   using output_type = typename Layer<Tensor, hidden_type, param_type>::output_type;
 
-  FullBidirectionalLayer(Cell<dir_hidden_type>& cell)
+  FullBidirectionalLayer(Cell<dir_hidden_type, cell_params>& cell)
     : layer_(cell) {};
 
   output_type operator()(const Tensor& input, const hidden_type& input_hidden, const param_type& params) const override {
@@ -278,17 +382,17 @@ struct FullBidirectionalLayer : Layer<Tensor, pair_of<dir_hidden_type>, pair_of<
     return std::move(x);
   }
 
-  FullLayer<dir_hidden_type> layer_;
+  FullLayer<dir_hidden_type, cell_params> layer_;
 };
 
-template<typename hidden_type>
-struct PackedLayer : Layer<PackedSequence, hidden_type, CellParams> {
-  using output_type = typename Layer<PackedSequence, hidden_type, CellParams>::output_type;
+template<typename hidden_type, typename cell_params>
+struct PackedLayer : Layer<PackedSequence, hidden_type, cell_params> {
+  using output_type = typename Layer<PackedSequence, hidden_type, cell_params>::output_type;
 
-  PackedLayer(Cell<hidden_type>& cell)
+  PackedLayer(Cell<hidden_type, cell_params>& cell)
     : cell_(cell) {};
 
-  output_type operator()(const PackedSequence& input, const hidden_type& input_hidden, const CellParams& params) const override {
+  output_type operator()(const PackedSequence& input, const hidden_type& input_hidden, const cell_params& params) const override {
     std::vector<at::Tensor> step_outputs;
     std::vector<hidden_type> hiddens;
     int64_t input_offset = 0;
@@ -324,17 +428,17 @@ struct PackedLayer : Layer<PackedSequence, hidden_type, CellParams> {
     return { PackedSequence{ at::cat(step_outputs, 0), input.batch_sizes }, hidden_concat(hiddens) };
   }
 
-  Cell<hidden_type>& cell_;
+  Cell<hidden_type, cell_params>& cell_;
 };
 
-template<typename hidden_type>
-struct ReversedPackedLayer : Layer<PackedSequence, hidden_type, CellParams> {
-  using output_type = typename Layer<PackedSequence, hidden_type, CellParams>::output_type;
+template<typename hidden_type, typename cell_params>
+struct ReversedPackedLayer : Layer<PackedSequence, hidden_type, cell_params> {
+  using output_type = typename Layer<PackedSequence, hidden_type, cell_params>::output_type;
 
-  ReversedPackedLayer(Cell<hidden_type>& cell)
+  ReversedPackedLayer(Cell<hidden_type, cell_params>& cell)
     : cell_(cell) {};
 
-  output_type operator()(const PackedSequence& input, const hidden_type& input_hidden, const CellParams& params) const override {
+  output_type operator()(const PackedSequence& input, const hidden_type& input_hidden, const cell_params& params) const override {
     std::vector<at::Tensor> step_outputs;
     int64_t input_offset = input.data.size(0);
     int64_t num_steps = input.batch_sizes.size(0);
@@ -364,16 +468,16 @@ struct ReversedPackedLayer : Layer<PackedSequence, hidden_type, CellParams> {
     return { PackedSequence{ at::cat(step_outputs, 0), input.batch_sizes }, hidden };
   }
 
-  Cell<hidden_type>& cell_;
+  Cell<hidden_type, cell_params>& cell_;
 };
 
-template<typename dir_hidden_type>
-struct PackedBidirectionalLayer : Layer<PackedSequence, pair_of<dir_hidden_type>, pair_of<CellParams>> {
+template<typename dir_hidden_type, typename cell_params>
+struct PackedBidirectionalLayer : Layer<PackedSequence, pair_of<dir_hidden_type>, pair_of<cell_params>> {
   using hidden_type = pair_of<dir_hidden_type>;
-  using param_type = pair_of<CellParams>;
+  using param_type = pair_of<cell_params>;
   using output_type = typename Layer<PackedSequence, hidden_type, param_type>::output_type;
 
-  PackedBidirectionalLayer(Cell<dir_hidden_type>& cell)
+  PackedBidirectionalLayer(Cell<dir_hidden_type, cell_params>& cell)
     : layer_(cell), rev_layer_(cell) {};
 
   output_type operator()(const PackedSequence& input, const hidden_type& input_hidden, const param_type& params) const override {
@@ -383,8 +487,8 @@ struct PackedBidirectionalLayer : Layer<PackedSequence, pair_of<dir_hidden_type>
     return { output, std::make_pair(fw_result.final_hidden, rev_result.final_hidden) };
   }
 
-  PackedLayer<dir_hidden_type> layer_;
-  ReversedPackedLayer<dir_hidden_type> rev_layer_;
+  PackedLayer<dir_hidden_type, cell_params> layer_;
+  ReversedPackedLayer<dir_hidden_type, cell_params> rev_layer_;
 };
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -432,50 +536,50 @@ apply_layer_stack(const Layer<io_type, hidden_type, weight_type>& layer, const i
 // HELPERS SIMPLIFYING DISPATCH TO FUNCTIONS ABOVE
 ////////////////////////////////////////////////////////////////////////////////
 
-template<typename CellType, template<typename> class LayerT, template<typename> class BidirLayerT, typename io_type>
+template<typename CellType, template<typename,typename> class LayerT, template<typename,typename> class BidirLayerT, typename cell_params, typename io_type>
 LayerOutput<io_type, std::vector<typename CellType::hidden_type>> _rnn_impl(
       const io_type& input,
-      const std::vector<CellParams>& params,
+      const std::vector<cell_params>& params,
       const std::vector<typename CellType::hidden_type>& hiddens,
       int64_t num_layers, double dropout_p, bool train, bool bidirectional) {
   using hidden_type = typename CellType::hidden_type;
   CellType cell;
   if (bidirectional) {
-    using BidirLayer = BidirLayerT<hidden_type>;
+    using BidirLayer = BidirLayerT<hidden_type, cell_params>;
     auto bidir_result = apply_layer_stack(BidirLayer{cell}, input, pair_vec(hiddens), pair_vec(params), num_layers, dropout_p, train);
     return {bidir_result.outputs, unpair_vec(std::move(bidir_result.final_hidden))};
   } else {
-    return apply_layer_stack(LayerT<hidden_type>{cell}, input, hiddens, params, num_layers, dropout_p, train);
+    return apply_layer_stack(LayerT<hidden_type,cell_params>{cell}, input, hiddens, params, num_layers, dropout_p, train);
   }
 }
 
-template<typename CellType, template<typename> class LayerT, template<typename> class BidirLayerT, typename io_type>
+template<typename CellType, template<typename,typename> class LayerT, template<typename,typename> class BidirLayerT, typename cell_params, typename io_type>
 std::tuple<io_type, Tensor> _rnn_impl_with_concat(
       const io_type& input,
-      const std::vector<CellParams>& params,
+      const std::vector<cell_params>& params,
       const std::vector<typename CellType::hidden_type>& hiddens,
       int64_t num_layers, double dropout_p, bool train, bool bidirectional) {
   auto result = _rnn_impl<CellType, LayerT, BidirLayerT>(input, params, hiddens, num_layers, dropout_p, train, bidirectional);
   return std::make_tuple(result.outputs, at::stack(result.final_hidden, 0));
 }
 
-template<template<typename> class LayerT, template<typename> class BidirLayerT, typename io_type>
+template<template<typename,typename> class LayerT, template<typename,typename> class BidirLayerT, typename cell_params, typename io_type>
 std::tuple<io_type, Tensor, Tensor> _lstm_impl(
       const io_type& input,
-      const std::vector<CellParams>& params, const Tensor& hx, const Tensor& cx,
+      const std::vector<cell_params>& params, const Tensor& hx, const Tensor& cx,
       int64_t num_layers, double dropout_p, bool train, bool bidirectional) {
   // It's much more useful for us to work on lists of pairs of hx and cx for each layer, so we need
   // to transpose a pair of those tensors.
   auto layer_hx = hx.unbind(0);
   auto layer_cx = cx.unbind(0);
   int64_t total_layers = layer_hx.size();
-  std::vector<LSTMCell::hidden_type> hiddens;
+  std::vector<typename LSTMCell<cell_params>::hidden_type> hiddens;
   hiddens.reserve(total_layers);
   for (int64_t i = 0; i < total_layers; ++i) {
     hiddens.emplace_back(std::move(layer_hx[i]), std::move(layer_cx[i]));
   }
 
-  auto result = _rnn_impl<LSTMCell, LayerT, BidirLayerT>(input, params, hiddens, num_layers, dropout_p, train, bidirectional);
+  auto result = _rnn_impl<LSTMCell<cell_params>, LayerT, BidirLayerT>(input, params, hiddens, num_layers, dropout_p, train, bidirectional);
 
   // Now, we need to reverse the transpose we performed above.
   std::vector<Tensor> hy, cy;
@@ -539,9 +643,11 @@ std::tuple<Tensor, Tensor> NAME(                                               \
   return std::make_tuple(packed_output.data, std::get<1>(result));             \
 }
 
-ONE_HIDDEN_RNN(gru, GRUCell)
-ONE_HIDDEN_RNN(rnn_tanh, SimpleCell<tanh_f>)
-ONE_HIDDEN_RNN(rnn_relu, SimpleCell<relu_f>)
+ONE_HIDDEN_RNN(gru, GRUCell<CellParams>)
+using tanh_cell_type = SimpleCell<tanh_f, CellParams>;
+ONE_HIDDEN_RNN(rnn_tanh, tanh_cell_type)
+using relu_cell_type = SimpleCell<relu_f, CellParams>;
+ONE_HIDDEN_RNN(rnn_relu, relu_cell_type)
 
 DEFINE_DISPATCH(lstm_cudnn_stub);
 DEFINE_DISPATCH(lstm_packed_cudnn_stub);
@@ -593,25 +699,113 @@ std::tuple<Tensor, Tensor> lstm_cell(
     const Tensor& input, TensorList hx,
     const Tensor& w_ih, const Tensor& w_hh, const Tensor& b_ih, const Tensor& b_hh) {
   AT_CHECK(hx.size() == 2, "lstm_cell expects two hidden states");
-  return LSTMCell{}(input, std::make_tuple(hx[0], hx[1]), CellParams{w_ih, w_hh, b_ih, b_hh});
+  return LSTMCell<CellParams>{}(input, std::make_tuple(hx[0], hx[1]), CellParams{w_ih, w_hh, b_ih, b_hh});
 }
 
 Tensor gru_cell(
     const Tensor& input, const Tensor& hx,
     const Tensor& w_ih, const Tensor& w_hh, const Tensor& b_ih, const Tensor& b_hh) {
-  return GRUCell{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh});
+  return GRUCell<CellParams>{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh});
 }
 
 Tensor rnn_tanh_cell(
     const Tensor& input, const Tensor& hx,
     const Tensor& w_ih, const Tensor& w_hh, const Tensor& b_ih, const Tensor& b_hh) {
-  return SimpleCell<tanh_f>{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh});
+  return SimpleCell<tanh_f, CellParams>{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh});
 }
 
 Tensor rnn_relu_cell(
     const Tensor& input, const Tensor& hx,
     const Tensor& w_ih, const Tensor& w_hh, const Tensor& b_ih, const Tensor& b_hh) {
-  return SimpleCell<relu_f>{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh});
+  return SimpleCell<relu_f, CellParams>{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh});
+}
+
+// Quantized implementations
+//
+// These implementations use FBGEMM to do the i2h and h2h linear layers with
+// an int8 quantized weight. This is advantageous in small-batch-size scenarios
+// where runtime is dominated by memory fetches of the weight matrix.
+
+std::tuple<Tensor, Tensor, Tensor> quantized_lstm(
+      const Tensor& _input, TensorList hx,
+      TensorList _params, bool has_biases,
+      int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
+  AT_CHECK(hx.size() == 2, "lstm expects two hidden states");
+  if (at::cudnn_is_acceptable(_input)) {
+    Tensor output, hy, cy;
+    lstm_cudnn_stub(_input.type().device_type(), output, hy, cy, _input, hx, _params, has_biases,
+            num_layers, dropout_p, train, bidirectional, batch_first);
+    return std::make_tuple(output, hy, cy);
+  }
+  check_device(_input, _params, hx);
+  auto input = batch_first ? _input.transpose(0, 1) : _input;
+  AT_CHECK(has_biases, "quantized LSTM requires biases");
+  auto params = gather_quantized_params(_params);
+  auto results = _lstm_impl<FullLayer, FullBidirectionalLayer>(
+      input, params, hx[0], hx[1], num_layers, dropout_p, train, bidirectional);
+  if (batch_first) {
+    std::get<0>(results) = std::get<0>(results).transpose(0, 1);
+  }
+  return results;
 }
 
+#define DEFINE_QUANTIZED_RNN_CELL(name, hx_type, cell_type, return_type, prepare_hx_fn) \
+return_type name( \
+    const Tensor& input, \
+    hx_type hx, \
+    const Tensor& w_ih, \
+    const Tensor& w_hh, \
+    const Tensor& b_ih, \
+    const Tensor& b_hh, \
+    const Tensor& packed_ih, \
+    const Tensor& packed_hh, \
+    const Tensor& col_offsets_ih, \
+    const Tensor& col_offsets_hh, \
+    const Scalar scale_ih, \
+    const Scalar scale_hh, \
+    const Scalar zero_point_ih, \
+    const Scalar zero_point_hh) { \
+  QuantizedCellParams params( \
+      w_ih, \
+      w_hh, \
+      b_ih, \
+      b_hh, \
+      packed_ih, \
+      packed_hh, \
+      col_offsets_ih, \
+      col_offsets_hh, \
+      scale_ih, \
+      scale_hh, \
+      zero_point_ih, \
+      zero_point_hh); \
+  return cell_type{}( \
+      input, prepare_hx_fn(hx), params); \
+}
+
+// Quantized LSTM cell
+using quantized_lstm_cell_type = LSTMCell<QuantizedCellParams>;
+using quantized_lstm_return_type = std::tuple<Tensor, Tensor>;
+std::tuple<Tensor, Tensor> prepare_quantized_lstm_hx(TensorList hx) {
+  return std::make_tuple(hx[0], hx[1]);
+}
+DEFINE_QUANTIZED_RNN_CELL(quantized_lstm_cell, TensorList, quantized_lstm_cell_type, quantized_lstm_return_type, prepare_quantized_lstm_hx);
+
+// Helpers for simpler cells
+using simple_hx_type = const Tensor&;
+simple_hx_type prepare_quantized_hx(simple_hx_type hx) {
+  return hx;
+}
+
+// Quantized GRU cell
+using quantized_gru_cell_type = GRUCell<QuantizedCellParams>;
+DEFINE_QUANTIZED_RNN_CELL(quantized_gru_cell, simple_hx_type, quantized_gru_cell_type, Tensor, prepare_quantized_hx);
+
+// Quantized RNN w/ ReLU cell
+using quantized_rnn_relu_cell_type = SimpleCell<relu_f, QuantizedCellParams>;
+DEFINE_QUANTIZED_RNN_CELL(quantized_rnn_relu_cell, simple_hx_type, quantized_rnn_relu_cell_type, Tensor, prepare_quantized_hx);
+
+// Quantized RNN w/ tanh cell
+using quantized_rnn_tanh_cell_type = SimpleCell<tanh_f, QuantizedCellParams>;
+DEFINE_QUANTIZED_RNN_CELL(quantized_rnn_tanh_cell, simple_hx_type, quantized_rnn_tanh_cell_type, Tensor, prepare_quantized_hx);
+
 }}  // namespace at::native
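
The `linear_ih`/`linear_hh` paths above bottom out in `at::fbgemm_linear_int8_weight`. For illustration, the same quantize-pack-multiply flow can be sketched from Python with the `torch.fbgemm_*` bindings this PR uses in `torch/jit/quantized.py` (an editorial sketch; values are arbitrary and the achievable error depends on the quantization parameters):

    import torch

    W = torch.randn(8, 4)   # float weight, (out_features, in_features)
    x = torch.randn(2, 4)   # small-batch activation, where int8 weights help most
    b = torch.zeros(8)

    # One-time post-processing: quantize the weight to int8 and prepack it for FBGEMM.
    W_q, col_offsets, scale, zero_point = torch.fbgemm_linear_quantize_weight(W)
    packed = torch.fbgemm_pack_quantized_matrix(W_q, W_q.size(1), W_q.size(0))

    # At run time the activation is quantized on the fly and the GEMM runs in int8.
    y_q = torch.fbgemm_linear_int8_weight(x, W_q, packed, col_offsets, scale, zero_point, b)
    y_ref = torch.nn.functional.linear(x, W, b)
    print((y_q - y_ref).abs().max())  # small quantization error
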
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 7a87353..8d520cd 100644
 
 - func: rnn_relu_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih={}, Tensor? b_hh={}) -> Tensor
 
+# Quantized RNN layers
+- func: quantized_lstm(Tensor input, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor)
+
+# Quantized RNN cells
+- func: quantized_lstm_cell(Tensor input, TensorList hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> (Tensor, Tensor)
+
+- func: quantized_gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor
+
+- func: quantized_rnn_relu_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor
+
+- func: quantized_rnn_tanh_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor
+
+
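
Each `- func:` entry above is code-generated into a native callable (the Python modules below reach them through `torch.nn._VF`). A hedged sketch of the 14-argument cell schema, reusing a quantized cell as in the summary example:

    import torch
    from torch.nn import _VF

    qcell = torch.jit.quantized.quantize_rnn_cell_modules(torch.nn.LSTMCell(2, 2).float())
    x = torch.randn(3, 2)
    hx, cx = torch.randn(3, 2), torch.randn(3, 2)

    hy, cy = _VF.quantized_lstm_cell(
        x, [hx, cx],
        qcell.weight_ih, qcell.weight_hh, qcell.bias_ih, qcell.bias_hh,
        qcell.packed_ih, qcell.packed_hh, qcell.col_offsets_ih, qcell.col_offsets_hh,
        qcell.scale_ih, qcell.scale_hh, qcell.zero_point_ih, qcell.zero_point_hh)
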
 # PackedSequence utilities
 - func: _pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor)
 
diff --git a/test/test_jit.py b/test/test_jit.py
index f28b61a..22e0d4d 100644
@@ -32,6 +32,7 @@ import warnings
 import math
 import types
 import pickle
+import copy
 
 from common_methods_invocations import method_tests as autograd_method_tests
 from common_methods_invocations import create_input, unpack_variables, \
@@ -305,6 +306,32 @@ class JitTestCase(TestCase):
             imported.save(fname)
             return torch.jit.load(fname, map_location=map_location)
 
+    def getExportImportCopyWithPacking(self, m, also_test_file=True, map_location=None):
+        buffer = io.BytesIO()
+        m.apply(lambda s: s._pack() if s._has_method('_pack') else None)
+        torch.jit.save(m, buffer)
+        m.apply(lambda s: s._unpack() if s._has_method('_unpack') else None)
+        buffer.seek(0)
+        imported = torch.jit.load(buffer, map_location=map_location)
+        imported.apply(lambda s: s._unpack() if s._has_method('_unpack') else None)
+
+        if not also_test_file:
+            return imported
+
+        # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
+        # opens the file, and it cannot be opened multiple times in Windows. To support Windows,
+        # close the file after creation and try to remove it manually
+        f = tempfile.NamedTemporaryFile(delete=False)
+        try:
+            f.close()
+            imported.save(f.name)
+            result = torch.jit.load(f.name, map_location=map_location)
+        finally:
+            os.unlink(f.name)
+
+        result.apply(lambda s: s._unpack() if s._has_method('_unpack') else None)
+        return result
+
     def assertGraphContains(self, graph, kind):
         self.assertTrue(any(n.kind() == kind for n in graph.nodes()))
 
@@ -4898,6 +4925,103 @@ a")
         a = A()
         self.assertEqual(a.with_docstring.__doc__, 'test str')
 
+    @unittest.skipIf(TEST_WITH_UBSAN or not torch.fbgemm_is_cpu_supported(),
+                     'Quantized RNN requires FBGEMM. FBGEMM does not play'
+                     ' well with UBSAN at the moment, so we skip the test if'
+                     ' we are in a UBSAN environment.')
+    def test_rnn_cell_quantized(self):
+        d_in, d_hid = 2, 2
+
+        for cell in [
+            torch.nn.LSTMCell(d_in, d_hid).float(),
+            torch.nn.GRUCell(d_in, d_hid).float(),
+            torch.nn.RNNCell(d_in, d_hid).float(),
+        ]:
+            if isinstance(cell, torch.nn.LSTMCell):
+                num_chunks = 4
+            elif isinstance(cell, torch.nn.GRUCell):
+                num_chunks = 3
+            elif isinstance(cell, torch.nn.RNNCell):
+                num_chunks = 1
+
+            # Replace parameter values such that the range of values is
+            # exactly 255; this gives zero quantization error in the
+            # quantized GEMM call and is purely for testing purposes.
+            #
+            # Note that the current implementation does not support
+            # accumulation values outside the range representable by a
+            # 16-bit integer, instead producing a saturated value. We
+            # must take care that in our test we do not end up with a dot
+            # product that overflows the int16 range, e.g.
+            # (255*127+255*127) = 64770. So, we hardcode the test values
+            # here and ensure a mix of signedness.
+            vals = [[100, -155],
+                    [100, -155],
+                    [-155, 100],
+                    [-155, 100],
+                    [100, -155],
+                    [-155, 100],
+                    [-155, 100],
+                    [100, -155]]
+            vals = vals[:d_hid * num_chunks]
+            cell.weight_ih = torch.nn.Parameter(
+                torch.tensor(vals, dtype=torch.float),
+                requires_grad=False)
+            cell.weight_hh = torch.nn.Parameter(
+                torch.tensor(vals, dtype=torch.float),
+                requires_grad=False)
+
+            ref = copy.deepcopy(cell)
+
+            cell = torch.jit.quantized.quantize_rnn_cell_modules(cell)
+            x = torch.tensor([[100, -155],
+                              [-155, 100],
+                              [100, -155]], dtype=torch.float)
+            h0_vals = [[-155, 100],
+                       [-155, 155],
+                       [100, -155]]
+            hx = torch.tensor(h0_vals, dtype=torch.float)
+            if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell):
+                cx = torch.tensor(h0_vals, dtype=torch.float)
+                hiddens = (hx, cx)
+            else:
+                hiddens = hx
+
+            if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell):
+                from typing import Tuple
+
+                class ScriptWrapper(torch.jit.ScriptModule):
+                    def __init__(self, cell):
+                        super(ScriptWrapper, self).__init__()
+                        self.cell = cell
+
+                    @torch.jit.script_method
+                    def forward(self, x, hiddens):
+                        # type: (torch.Tensor, Tuple[torch.Tensor, torch.Tensor])
+                        return self.cell(x, hiddens)
+            else:
+
+                class ScriptWrapper(torch.jit.ScriptModule):
+                    def __init__(self, cell):
+                        super(ScriptWrapper, self).__init__()
+                        self.cell = cell
+
+                    @torch.jit.script_method
+                    def forward(self, x, hiddens):
+                        # type: (torch.Tensor, torch.Tensor)
+                        return self.cell(x, hiddens)
+
+            cell = ScriptWrapper(cell)
+            outs = cell(x, hiddens)
+            cell = self.getExportImportCopyWithPacking(cell)
+
+            outs = cell(x, hiddens)
+            ref_outs = ref(x, hiddens)
+
+            self.assertEqual(len(outs), len(ref_outs))
+            for out, ref_out in zip(outs, ref_outs):
+                torch.testing.assert_allclose(out, ref_out)
+
     def test_script_module(self):
         class M1(torch.jit.ScriptModule):
             def __init__(self):
@@ -5205,10 +5329,7 @@ a")
         # Test save path
         self.assertFalse(sm.pack_called.item())
         self.assertFalse(sm.unpack_called.item())
-        sm.apply(lambda s: s._pack())
-        imported = self.getExportImportCopy(sm)
-        sm.apply(lambda s: s._unpack())
-        imported.apply(lambda s: s._unpack())
+        imported = self.getExportImportCopyWithPacking(sm)
         # ensure pack was called before serialization
         self.assertTrue(sm.pack_called.item())
         # ensure unpack was called after serialization so as to leave the module in an initialized state
@@ -8328,15 +8449,11 @@ a")
             fb_ref = FooBar()
             fb_ref.linear1.weight = torch.nn.Parameter(fb.linear1.weight.clone(), requires_grad=False)
             fb_ref.linear1.bias = torch.nn.Parameter(fb.linear1.bias.clone(), requires_grad=False)
-            torch.jit.quantized.quantize_linear_modules(fb)
+            fb = torch.jit.quantized.quantize_linear_modules(fb)
 
             x = (torch.rand(1, K1).float() - 0.5) / 10.0
             traced = torch.jit.trace(fb, (x,))
-            traced.apply(lambda s: s._pack() if s._has_method('_pack') else None)
-            fb = self.getExportImportCopy(traced)
-            traced.apply(lambda s: s._unpack() if s._has_method('_unpack') else None)
-
-            fb.apply(lambda s: s._unpack() if s._has_method('_unpack') else None)
+            fb = self.getExportImportCopyWithPacking(traced)
 
             x = torch.tensor([[100, -150]], dtype=torch.float)
             y = fb(x)
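
A quick back-of-the-envelope for the int16 saturation note in `test_rnn_cell_quantized` above (plain Python, mirroring the numbers from the comment):

    # Two worst-case same-sign products already exceed the int16 maximum of 32767:
    worst_case = 255 * 127 + 255 * 127   # = 64770, would saturate a 16-bit accumulator
    assert worst_case > 2**15 - 1

    # Mixing signs keeps the partial sums representable:
    mixed = 255 * 127 + 255 * (-128)     # = -255, comfortably inside [-32768, 32767]
    assert -2**15 <= mixed <= 2**15 - 1
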
diff --git a/torch/jit/quantized.py b/torch/jit/quantized.py
index 4eb3a91..25c1740 100644
@@ -1,5 +1,10 @@
 import torch
 import copy
+import numbers
+from typing import List, Tuple
+
+from torch.nn.utils.rnn import PackedSequence
+from torch.nn import _VF
 
 
 class QuantizedLinear(torch.jit.ScriptModule):
@@ -45,10 +50,200 @@ class QuantizedLinear(torch.jit.ScriptModule):
         return repr
 
 
+# Quantized RNN cell implementations
+class QuantizedRNNCellBase(torch.jit.ScriptModule):
+    __constants__ = ['input_size', 'hidden_size', 'bias', 'scale_hh', 'scale_ih',
+                     'zero_point_ih', 'zero_point_hh']
+
+    def __init__(self, other):
+        super(QuantizedRNNCellBase, self).__init__()
+        self.input_size = other.input_size
+        self.hidden_size = other.hidden_size
+        self.bias = other.bias
+        if not self.bias:
+            raise ValueError("Quantized RNN cells require bias terms")
+
+        weight_ih, col_offsets_ih, self.scale_ih, self.zero_point_ih = \
+            torch.fbgemm_linear_quantize_weight(other.weight_ih.clone().float())
+        self.register_buffer('weight_ih', weight_ih)
+        self.register_buffer('col_offsets_ih', col_offsets_ih)
+        weight_hh, col_offsets_hh, self.scale_hh, self.zero_point_hh = \
+            torch.fbgemm_linear_quantize_weight(other.weight_hh.clone().float())
+        self.register_buffer('weight_hh', weight_hh)
+        self.register_buffer('col_offsets_hh', col_offsets_hh)
+
+        packed_ih = torch.fbgemm_pack_quantized_matrix(
+            self.weight_ih, self.weight_ih.size(1), self.weight_ih.size(0))
+        self.register_buffer('packed_ih', packed_ih)
+        packed_hh = torch.fbgemm_pack_quantized_matrix(
+            self.weight_hh, self.weight_hh.size(1), self.weight_hh.size(0))
+        self.register_buffer('packed_hh', packed_hh)
+
+        self.bias_ih = torch.nn.Parameter(other.bias_ih.clone().float(), requires_grad=False)
+        self.bias_hh = torch.nn.Parameter(other.bias_hh.clone().float(), requires_grad=False)
+
+    def extra_repr(self):
+        s = '{input_size}, {hidden_size}'
+        if 'bias' in self.__dict__ and self.bias is not True:
+            s += ', bias={bias}'
+        if 'nonlinearity' in self.__dict__ and self.nonlinearity != "tanh":
+            s += ', nonlinearity={nonlinearity}'
+        return s.format(**self.__dict__)
+
+    @torch.jit.script_method
+    def check_forward_input(self, input):
+        if input.size(1) != self.input_size:
+            raise RuntimeError(
+                "input has inconsistent input_size: got {}, expected {}".format(
+                    input.size(1), self.input_size))
+
+    @torch.jit.script_method
+    def check_forward_hidden(self, input, hx, hidden_label=''):
+        # type: (Tensor, Tensor, str) -> None
+        if input.size(0) != hx.size(0):
+            raise RuntimeError(
+                "Input batch size {} doesn't match hidden{} batch size {}".format(
+                    input.size(0), hidden_label, hx.size(0)))
+
+        if hx.size(1) != self.hidden_size:
+            raise RuntimeError(
+                "hidden{} has inconsistent hidden_size: got {}, expected {}".format(
+                    hidden_label, hx.size(1), self.hidden_size))
+
+    # TODO: for some reason weak_script_method causes a destruction of the
+    # module to occur, which in turn frees the packed_ih object via its DataPtr
+    # deleter. This is bizarre and should probably get fixed.
+    # @torch._jit_internal.weak_script_method
+    @torch.jit.script_method
+    def _unpack(self):
+        self.packed_ih.set_(torch.fbgemm_pack_quantized_matrix(
+            self.weight_ih, self.weight_ih.size(1), self.weight_ih.size(0)))
+        self.packed_hh.set_(
+            torch.fbgemm_pack_quantized_matrix(
+                self.weight_hh, self.weight_hh.size(1), self.weight_hh.size(0)))
+
+    # @torch._jit_internal.weak_script_method
+    @torch.jit.script_method
+    def _pack(self):
+        self.packed_ih.set_(
+            torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach())
+        self.packed_hh.set_(
+            torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach())
+
+
+class QuantizedRNNCell(QuantizedRNNCellBase):
+    __constants__ = ['input_size', 'hidden_size', 'bias', 'scale_hh', 'scale_ih',
+                     'zero_point_ih', 'zero_point_hh', 'nonlinearity']
+
+    def __init__(self, other):
+        super(QuantizedRNNCell, self).__init__(other)
+        self.nonlinearity = other.nonlinearity
+
+    @torch.jit.script_method
+    def forward(self, input, hx=None):
+        # type: (Tensor, Optional[Tensor]) -> Tensor
+        self.check_forward_input(input)
+        if hx is None:
+            _hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+        else:
+            _hx = torch.jit._unwrap_optional(hx)
+        self.check_forward_hidden(input, _hx, '')
+        if self.nonlinearity == "tanh":
+            ret = _VF.quantized_rnn_tanh_cell(
+                input, _hx, self.weight_ih, self.weight_hh, self.bias_ih,
+                self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
+                self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih,
+                self.zero_point_hh
+            )
+        elif self.nonlinearity == "relu":
+            ret = _VF.quantized_rnn_relu_cell(
+                input, _hx, self.weight_ih, self.weight_hh, self.bias_ih,
+                self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
+                self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih,
+                self.zero_point_hh
+            )
+        else:
+            ret = input  # TODO: remove when jit supports exception flow
+            raise RuntimeError(
+                "Unknown nonlinearity: {}".format(self.nonlinearity))
+        return ret
+
+
+class QuantizedLSTMCell(QuantizedRNNCellBase):
+    def __init__(self, other):
+        super(QuantizedLSTMCell, self).__init__(other)
+
+    @torch.jit.script_method
+    def forward(self, input, hx=None):
+        # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
+        self.check_forward_input(input)
+        if hx is None:
+            zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+            _hx = (zeros, zeros)
+        else:
+            _hx = torch.jit._unwrap_optional(hx)
+        self.check_forward_hidden(input, _hx[0], '[0]')
+        self.check_forward_hidden(input, _hx[1], '[1]')
+        return _VF.quantized_lstm_cell(
+            input, _hx, self.weight_ih, self.weight_hh, self.bias_ih,
+            self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
+            self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih,
+            self.zero_point_hh
+        )
+
+
+class QuantizedGRUCell(QuantizedRNNCellBase):
+    def __init__(self, other):
+        super(QuantizedGRUCell, self).__init__(other)
+
+    @torch.jit.script_method
+    def forward(self, input, hx=None):
+        # type: (Tensor, Optional[Tensor]) -> Tensor
+        self.check_forward_input(input)
+        if hx is None:
+            _hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+        else:
+            _hx = torch.jit._unwrap_optional(hx)
+        self.check_forward_hidden(input, _hx, '')
+        return _VF.quantized_gru_cell(
+            input, _hx, self.weight_ih, self.weight_hh, self.bias_ih,
+            self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
+            self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih,
+            self.zero_point_hh
+        )
+
+
+def quantize_rnn_cell_modules(module):
+    reassign = {}
+    for name, mod in module.named_modules():
+        if mod is module:
+            continue
+        new_mod = quantize_rnn_cell_modules(mod)
+        if new_mod is not mod:
+            reassign[name] = new_mod
+    for name, mod in reassign.items():
+        setattr(module, name, mod)
+    if isinstance(module, torch.nn.LSTMCell):
+        return QuantizedLSTMCell(module)
+    if isinstance(module, torch.nn.GRUCell):
+        return QuantizedGRUCell(module)
+    if isinstance(module, torch.nn.RNNCell):
+        return QuantizedRNNCell(module)
+
+    return module
+
+
 def quantize_linear_modules(module):
+    reassign = {}
     for name, mod in module.named_modules():
         if mod is module:
             continue
-        if isinstance(mod, torch.nn.Linear):
-            setattr(module, name, QuantizedLinear(mod))
-        quantize_linear_modules(mod)
+        new_mod = quantize_linear_modules(mod)
+        if new_mod is not mod:
+            reassign[name] = new_mod
+
+    for name, mod in reassign.items():
+        setattr(module, name, mod)
+    if isinstance(module, torch.nn.Linear):
+        return QuantizedLinear(module)
+    return module
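
For reference, the pack/unpack protocol that `getExportImportCopyWithPacking` automates looks like this when done by hand (a sketch under the same assumptions as the tests: `_pack` empties the opaque FBGEMM buffers before serialization, and `_unpack` rebuilds them from the stored int8 weights afterwards):

    import io
    import torch

    qcell = torch.jit.quantized.quantize_rnn_cell_modules(torch.nn.GRUCell(4, 4).float())

    buffer = io.BytesIO()
    qcell.apply(lambda s: s._pack() if s._has_method('_pack') else None)
    torch.jit.save(qcell, buffer)
    qcell.apply(lambda s: s._unpack() if s._has_method('_unpack') else None)  # restore the live copy

    buffer.seek(0)
    loaded = torch.jit.load(buffer)
    loaded.apply(lambda s: s._unpack() if s._has_method('_unpack') else None)  # repack after load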