const Tensor& w_hh;
const Tensor& b_ih; /* optional */
const Tensor& b_hh; /* optional */
+
+ Tensor matmul_ih(Tensor input) const {
+ return at::matmul(input, w_ih.t());
+ }
+ Tensor matmul_hh(Tensor h) const {
+ return at::matmul(h, w_hh.t());
+ }
+ Tensor linear_ih(Tensor input) const {
+ return at::linear(input, w_ih, b_ih);
+ }
+ Tensor linear_hh(Tensor h) const {
+ return at::linear(h, w_hh, b_hh);
+ }
+};
+
+// Run this Python script and pipe to clang-format to generate the constructor
+// and data members:
+//
+// names = ['w', 'b', 'packed', 'col_offsets', 'scale', 'zero_point']
+//
+//
+// get_type = lambda i: 'Scalar' if i == 4 or i == 5 else 'Tensor'
+// member_ref = lambda i: '' if i == 4 or i == 5 else '&'
+//
+// suffixes = ['ih', 'hh']
+//
+// params = []
+// initializers = []
+// members = []
+// for i in range(len(names)*2):
+// params.append('const {typ}& _{name}_{suffix}'.format(typ=get_type(
+// i//2), name=names[(i//2) % len(names)], suffix=suffixes[i % 2]))
+// initializers.append('{name}_{suffix}(_{name}_{suffix})'.format(
+// name=names[(i//2) % len(names)], suffix=suffixes[i % 2]))
+// members.append('const {typ}{member_ref} {name}_{suffix};'.format(typ=get_type(
+// i//2), name=names[(i//2) % len(names)], suffix=suffixes[i % 2], member_ref=member_ref(i//2)))
+//
+// params_str = ', '.join(params)
+// initializers_str = ', '.join(initializers)
+// members_str = '\n'.join(members)
+//
+// ctor = 'QuantizedCellParams(' + params_str + ') : ' + initializers_str + '{}'
+// print('struct QuantizedCellParams {', '\n\n'.join([ctor, members_str]), '};')
+
+struct QuantizedCellParams {
+ QuantizedCellParams(const Tensor &_w_ih, const Tensor &_w_hh,
+ const Tensor &_b_ih, const Tensor &_b_hh,
+ const Tensor &_packed_ih, const Tensor &_packed_hh,
+ const Tensor &_col_offsets_ih,
+ const Tensor &_col_offsets_hh, const Scalar &_scale_ih,
+ const Scalar &_scale_hh, const Scalar &_zero_point_ih,
+ const Scalar &_zero_point_hh)
+ : w_ih(_w_ih), w_hh(_w_hh), b_ih(_b_ih), b_hh(_b_hh),
+ packed_ih(_packed_ih), packed_hh(_packed_hh),
+ col_offsets_ih(_col_offsets_ih), col_offsets_hh(_col_offsets_hh),
+ scale_ih(_scale_ih), scale_hh(_scale_hh), zero_point_ih(_zero_point_ih),
+ zero_point_hh(_zero_point_hh) {}
+
+ const Tensor &w_ih;
+ const Tensor &w_hh;
+ const Tensor &b_ih;
+ const Tensor &b_hh;
+ const Tensor &packed_ih;
+ const Tensor &packed_hh;
+ const Tensor &col_offsets_ih;
+ const Tensor &col_offsets_hh;
+ const Scalar scale_ih;
+ const Scalar scale_hh;
+ const Scalar zero_point_ih;
+ const Scalar zero_point_hh;
+
+ Tensor matmul_ih(Tensor input) const {
+ AT_CHECK(false, "matmul is not supported with quantized cell params");
+ }
+ Tensor matmul_hh(Tensor h) const {
+ AT_CHECK(false, "matmul is not supported with quantized cell params");
+ }
+ Tensor linear_ih(Tensor input) const {
+ return at::fbgemm_linear_int8_weight(
+ input, w_ih, packed_ih, col_offsets_ih, scale_ih, zero_point_ih, b_ih);
+ }
+ Tensor linear_hh(Tensor h) const {
+ return at::fbgemm_linear_int8_weight(
+ h, w_hh, packed_hh, col_offsets_hh, scale_hh, zero_point_hh, b_hh);
+ }
};
// Gathers every two elements of a vector in a vector of pairs
return result;
}
+static std::vector<QuantizedCellParams> gather_quantized_params(TensorList params) {
+ static at::Tensor undefined;
+ std::vector<QuantizedCellParams> result;
+ AT_CHECK(params.size() % 12 == 0, "got an incorrect number of quantized RNN parameters");
+ for (size_t i = 0; i < params.size(); i += 12) {
+ result.emplace_back(params[i], params[i + 1], params[i + 2], params[i + 3],
+ params[i + 4], params[i + 5], params[i + 6], params[i + 7],
+ params[i + 8].item(), params[i + 9].item(),
+ params[i + 10].item(), params[i + 11].item());
+ }
+ return result;
+}
+
////////////////////////////////////////////////////////////////////////////////
// HIDDEN STATE FUNCTIONS
// It's a struct only because functional programming in C++ is a pain, and it's easier
// to pass around "vtable pointers" than actual function pointers.
-template<typename hidden_type_tmpl>
+template<typename hidden_type_tmpl, typename cell_params_tmpl>
struct Cell {
using hidden_type = hidden_type_tmpl;
+ using cell_params = cell_params_tmpl;
virtual ~Cell() {} // This is really dumb, but enables projects with -Wnon-virtual-dtor to compile...
- virtual hidden_type operator()(const Tensor& input, const hidden_type& hidden, const CellParams& params) const = 0;
+ virtual hidden_type operator()(const Tensor& input, const hidden_type& hidden, const cell_params& params) const = 0;
};
-template<typename nonlinearity>
-struct SimpleCell : Cell<Tensor> {
- hidden_type operator()(const Tensor& input, const hidden_type& hidden, const CellParams& params) const override {
- return nonlinearity{}(at::linear(input, params.w_ih, params.b_ih) + at::linear(hidden, params.w_hh, params.b_hh));
+template<typename nonlinearity, typename cell_params>
+struct SimpleCell : Cell<Tensor, cell_params> {
+ using hidden_type = Tensor;
+ Tensor operator()(const Tensor& input, const Tensor& hidden, const cell_params& params) const override {
+ return nonlinearity{}(params.linear_ih(input) + params.linear_hh(hidden));
}
};
// TODO: can use inplace ops?
-struct LSTMCell : Cell<std::tuple<Tensor, Tensor>> {
- hidden_type operator()(const Tensor& input, const hidden_type& hidden, const CellParams& params) const override {
+template <typename cell_params>
+struct LSTMCell : Cell<std::tuple<Tensor, Tensor>, cell_params> {
+ using hidden_type = std::tuple<Tensor, Tensor>;
+ hidden_type operator()(const Tensor& input, const hidden_type& hidden, const cell_params& params) const override {
auto hx = std::get<0>(hidden);
auto cx = std::get<1>(hidden);
if (input.is_cuda()) {
- auto igates = at::matmul(input, params.w_ih.t());
- auto hgates = at::matmul(hx, params.w_hh.t());
+ auto igates = params.matmul_ih(input);
+ auto hgates = params.matmul_hh(hx);
auto result = at::_thnn_fused_lstm_cell(igates, hgates, cx, params.b_ih, params.b_hh);
// Slice off the workspace argument (it's needed only for AD).
return std::make_tuple(std::get<0>(result), std::get<1>(result));
}
- auto gates = at::linear(input, params.w_ih, params.b_ih) + at::linear(hx, params.w_hh, params.b_hh);
+ auto gates = params.linear_ih(input) + params.linear_hh(hx);
auto chunked_gates = gates.chunk(4, 1);
auto ingate = chunked_gates[0].sigmoid();
}
};
-struct GRUCell : Cell<Tensor> {
- hidden_type operator()(const Tensor& input, const hidden_type& hidden, const CellParams& params) const override {
+template <typename cell_params>
+struct GRUCell : Cell<Tensor, cell_params> {
+ using hidden_type = Tensor;
+ hidden_type operator()(const Tensor& input, const hidden_type& hidden, const cell_params& params) const override {
if (input.is_cuda()) {
- auto igates = at::matmul(input, params.w_ih.t());
- auto hgates = at::matmul(hidden, params.w_hh.t());
+ auto igates = params.matmul_ih(input);
+ auto hgates = params.matmul_hh(hidden);
auto result = at::_thnn_fused_gru_cell(igates, hgates, hidden, params.b_ih, params.b_hh);
// Slice off the workspace argument (it's needed only for AD).
return std::get<0>(result);
}
- auto igates = at::linear(input, params.w_ih, params.b_ih);
- auto hgates = at::linear(hidden, params.w_hh, params.b_hh);
+ auto igates = params.linear_ih(input);
+ auto hgates = params.linear_hh(hidden);
auto chunked_igates = igates.chunk(3, 1);
auto chunked_hgates = hgates.chunk(3, 1);
virtual output_type operator()(const io_type& input, const hidden_type& input_hidden, const param_type& params) const = 0;
};
-template<typename hidden_type>
-struct FullLayer : Layer<Tensor, hidden_type, CellParams> {
- using output_type = typename Layer<Tensor, hidden_type, CellParams>::output_type;
+template<typename hidden_type, typename cell_params>
+struct FullLayer : Layer<Tensor, hidden_type, cell_params> {
+ using output_type = typename Layer<Tensor, hidden_type, cell_params>::output_type;
using unstacked_output_type = LayerOutput<std::vector<Tensor>, hidden_type>;
- FullLayer(Cell<hidden_type>& cell)
+ FullLayer(Cell<hidden_type, cell_params>& cell)
: cell_(cell) {};
- unstacked_output_type operator()(std::vector<Tensor> step_inputs, const hidden_type& input_hidden, const CellParams& params) const {
+ unstacked_output_type operator()(std::vector<Tensor> step_inputs, const hidden_type& input_hidden, const cell_params& params) const {
std::vector<Tensor> step_outputs;
auto hidden = input_hidden;
for (size_t i = 0; i < step_inputs.size(); i++) {
return {step_outputs, hidden};
}
- output_type operator()(const Tensor& inputs, const hidden_type& input_hidden, const CellParams& params) const override {
+ output_type operator()(const Tensor& inputs, const hidden_type& input_hidden, const cell_params& params) const override {
auto unstacked_output = (*this)(inputs.unbind(0), input_hidden, params);
return {at::stack(unstacked_output.outputs, 0), unstacked_output.final_hidden};
}
- Cell<hidden_type>& cell_;
+ Cell<hidden_type, cell_params>& cell_;
};
-template<typename dir_hidden_type>
-struct FullBidirectionalLayer : Layer<Tensor, pair_of<dir_hidden_type>, pair_of<CellParams>> {
+template<typename dir_hidden_type, typename cell_params>
+struct FullBidirectionalLayer : Layer<Tensor, pair_of<dir_hidden_type>, pair_of<cell_params>> {
using hidden_type = pair_of<dir_hidden_type>;
- using param_type = pair_of<CellParams>;
+ using param_type = pair_of<cell_params>;
using output_type = typename Layer<Tensor, hidden_type, param_type>::output_type;
- FullBidirectionalLayer(Cell<dir_hidden_type>& cell)
+ FullBidirectionalLayer(Cell<dir_hidden_type, cell_params>& cell)
: layer_(cell) {};
output_type operator()(const Tensor& input, const hidden_type& input_hidden, const param_type& params) const override {
return std::move(x);
}
- FullLayer<dir_hidden_type> layer_;
+ FullLayer<dir_hidden_type, cell_params> layer_;
};
-template<typename hidden_type>
-struct PackedLayer : Layer<PackedSequence, hidden_type, CellParams> {
- using output_type = typename Layer<PackedSequence, hidden_type, CellParams>::output_type;
+template<typename hidden_type, typename cell_params>
+struct PackedLayer : Layer<PackedSequence, hidden_type, cell_params> {
+ using output_type = typename Layer<PackedSequence, hidden_type, cell_params>::output_type;
- PackedLayer(Cell<hidden_type>& cell)
+ PackedLayer(Cell<hidden_type, cell_params>& cell)
: cell_(cell) {};
- output_type operator()(const PackedSequence& input, const hidden_type& input_hidden, const CellParams& params) const override {
+ output_type operator()(const PackedSequence& input, const hidden_type& input_hidden, const cell_params& params) const override {
std::vector<at::Tensor> step_outputs;
std::vector<hidden_type> hiddens;
int64_t input_offset = 0;
return { PackedSequence{ at::cat(step_outputs, 0), input.batch_sizes }, hidden_concat(hiddens) };
}
- Cell<hidden_type>& cell_;
+ Cell<hidden_type, cell_params>& cell_;
};
-template<typename hidden_type>
-struct ReversedPackedLayer : Layer<PackedSequence, hidden_type, CellParams> {
- using output_type = typename Layer<PackedSequence, hidden_type, CellParams>::output_type;
+template<typename hidden_type, typename cell_params>
+struct ReversedPackedLayer : Layer<PackedSequence, hidden_type, cell_params> {
+ using output_type = typename Layer<PackedSequence, hidden_type, cell_params>::output_type;
- ReversedPackedLayer(Cell<hidden_type>& cell)
+ ReversedPackedLayer(Cell<hidden_type, cell_params>& cell)
: cell_(cell) {};
- output_type operator()(const PackedSequence& input, const hidden_type& input_hidden, const CellParams& params) const override {
+ output_type operator()(const PackedSequence& input, const hidden_type& input_hidden, const cell_params& params) const override {
std::vector<at::Tensor> step_outputs;
int64_t input_offset = input.data.size(0);
int64_t num_steps = input.batch_sizes.size(0);
return { PackedSequence{ at::cat(step_outputs, 0), input.batch_sizes }, hidden };
}
- Cell<hidden_type>& cell_;
+ Cell<hidden_type, cell_params>& cell_;
};
-template<typename dir_hidden_type>
-struct PackedBidirectionalLayer : Layer<PackedSequence, pair_of<dir_hidden_type>, pair_of<CellParams>> {
+template<typename dir_hidden_type, typename cell_params>
+struct PackedBidirectionalLayer : Layer<PackedSequence, pair_of<dir_hidden_type>, pair_of<cell_params>> {
using hidden_type = pair_of<dir_hidden_type>;
- using param_type = pair_of<CellParams>;
+ using param_type = pair_of<cell_params>;
using output_type = typename Layer<PackedSequence, hidden_type, param_type>::output_type;
- PackedBidirectionalLayer(Cell<dir_hidden_type>& cell)
+ PackedBidirectionalLayer(Cell<dir_hidden_type, cell_params>& cell)
: layer_(cell), rev_layer_(cell) {};
output_type operator()(const PackedSequence& input, const hidden_type& input_hidden, const param_type& params) const override {
return { output, std::make_pair(fw_result.final_hidden, rev_result.final_hidden) };
}
- PackedLayer<dir_hidden_type> layer_;
- ReversedPackedLayer<dir_hidden_type> rev_layer_;
+ PackedLayer<dir_hidden_type, cell_params> layer_;
+ ReversedPackedLayer<dir_hidden_type, cell_params> rev_layer_;
};
////////////////////////////////////////////////////////////////////////////////
// HELPERS SIMPLIFYING DISPATCH TO FUNCTIONS ABOVE
////////////////////////////////////////////////////////////////////////////////
-template<typename CellType, template<typename> class LayerT, template<typename> class BidirLayerT, typename io_type>
+template<typename CellType, template<typename,typename> class LayerT, template<typename,typename> class BidirLayerT, typename cell_params, typename io_type>
LayerOutput<io_type, std::vector<typename CellType::hidden_type>> _rnn_impl(
const io_type& input,
- const std::vector<CellParams>& params,
+ const std::vector<cell_params>& params,
const std::vector<typename CellType::hidden_type>& hiddens,
int64_t num_layers, double dropout_p, bool train, bool bidirectional) {
using hidden_type = typename CellType::hidden_type;
CellType cell;
if (bidirectional) {
- using BidirLayer = BidirLayerT<hidden_type>;
+ using BidirLayer = BidirLayerT<hidden_type, cell_params>;
auto bidir_result = apply_layer_stack(BidirLayer{cell}, input, pair_vec(hiddens), pair_vec(params), num_layers, dropout_p, train);
return {bidir_result.outputs, unpair_vec(std::move(bidir_result.final_hidden))};
} else {
- return apply_layer_stack(LayerT<hidden_type>{cell}, input, hiddens, params, num_layers, dropout_p, train);
+ return apply_layer_stack(LayerT<hidden_type,cell_params>{cell}, input, hiddens, params, num_layers, dropout_p, train);
}
}
-template<typename CellType, template<typename> class LayerT, template<typename> class BidirLayerT, typename io_type>
+template<typename CellType, template<typename,typename> class LayerT, template<typename,typename> class BidirLayerT, typename cell_params, typename io_type>
std::tuple<io_type, Tensor> _rnn_impl_with_concat(
const io_type& input,
- const std::vector<CellParams>& params,
+ const std::vector<cell_params>& params,
const std::vector<typename CellType::hidden_type>& hiddens,
int64_t num_layers, double dropout_p, bool train, bool bidirectional) {
auto result = _rnn_impl<CellType, LayerT, BidirLayerT>(input, params, hiddens, num_layers, dropout_p, train, bidirectional);
return std::make_tuple(result.outputs, at::stack(result.final_hidden, 0));
}
-template<template<typename> class LayerT, template<typename> class BidirLayerT, typename io_type>
+template<template<typename,typename> class LayerT, template<typename,typename> class BidirLayerT, typename cell_params, typename io_type>
std::tuple<io_type, Tensor, Tensor> _lstm_impl(
const io_type& input,
- const std::vector<CellParams>& params, const Tensor& hx, const Tensor& cx,
+ const std::vector<cell_params>& params, const Tensor& hx, const Tensor& cx,
int64_t num_layers, double dropout_p, bool train, bool bidirectional) {
// It's much more useful for us to work on lists of pairs of hx and cx for each layer, so we need
// to transpose a pair of those tensors.
auto layer_hx = hx.unbind(0);
auto layer_cx = cx.unbind(0);
int64_t total_layers = layer_hx.size();
- std::vector<LSTMCell::hidden_type> hiddens;
+ std::vector<typename LSTMCell<cell_params>::hidden_type> hiddens;
hiddens.reserve(total_layers);
for (int64_t i = 0; i < total_layers; ++i) {
hiddens.emplace_back(std::move(layer_hx[i]), std::move(layer_cx[i]));
}
- auto result = _rnn_impl<LSTMCell, LayerT, BidirLayerT>(input, params, hiddens, num_layers, dropout_p, train, bidirectional);
+ auto result = _rnn_impl<LSTMCell<cell_params>, LayerT, BidirLayerT>(input, params, hiddens, num_layers, dropout_p, train, bidirectional);
// Now, we need to reverse the transposed we performed above.
std::vector<Tensor> hy, cy;
return std::make_tuple(packed_output.data, std::get<1>(result)); \
}
-ONE_HIDDEN_RNN(gru, GRUCell)
-ONE_HIDDEN_RNN(rnn_tanh, SimpleCell<tanh_f>)
-ONE_HIDDEN_RNN(rnn_relu, SimpleCell<relu_f>)
+ONE_HIDDEN_RNN(gru, GRUCell<CellParams>)
+using tanf_cell_type = SimpleCell<tanh_f, CellParams>;
+ONE_HIDDEN_RNN(rnn_tanh, tanf_cell_type)
+using relu_cell_type = SimpleCell<relu_f, CellParams>;
+ONE_HIDDEN_RNN(rnn_relu, relu_cell_type);
DEFINE_DISPATCH(lstm_cudnn_stub);
DEFINE_DISPATCH(lstm_packed_cudnn_stub);
const Tensor& input, TensorList hx,
const Tensor& w_ih, const Tensor& w_hh, const Tensor& b_ih, const Tensor& b_hh) {
AT_CHECK(hx.size() == 2, "lstm_cell expects two hidden states");
- return LSTMCell{}(input, std::make_tuple(hx[0], hx[1]), CellParams{w_ih, w_hh, b_ih, b_hh});
+ return LSTMCell<CellParams>{}(input, std::make_tuple(hx[0], hx[1]), CellParams{w_ih, w_hh, b_ih, b_hh});
}
Tensor gru_cell(
const Tensor& input, const Tensor& hx,
const Tensor& w_ih, const Tensor& w_hh, const Tensor& b_ih, const Tensor& b_hh) {
- return GRUCell{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh});
+ return GRUCell<CellParams>{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh});
}
Tensor rnn_tanh_cell(
const Tensor& input, const Tensor& hx,
const Tensor& w_ih, const Tensor& w_hh, const Tensor& b_ih, const Tensor& b_hh) {
- return SimpleCell<tanh_f>{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh});
+ return SimpleCell<tanh_f, CellParams>{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh});
}
Tensor rnn_relu_cell(
const Tensor& input, const Tensor& hx,
const Tensor& w_ih, const Tensor& w_hh, const Tensor& b_ih, const Tensor& b_hh) {
- return SimpleCell<relu_f>{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh});
+ return SimpleCell<relu_f, CellParams>{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh});
+}
+
+// Quantized implementations
+//
+// These implementations use FBGEMM to do the i2h and h2h linear layers with
+// an int8 quantized weight. This is advantageous in small-batch-size scenarios
+// where runtime is dominated by memory fetches of the weight matrix.
+
+std::tuple<Tensor, Tensor, Tensor> quantized_lstm(
+ const Tensor& _input, TensorList hx,
+ TensorList _params, bool has_biases,
+ int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
+ AT_CHECK(hx.size() == 2, "lstm expects two hidden states");
+ if (at::cudnn_is_acceptable(_input)) {
+ Tensor output, hy, cy;
+ lstm_cudnn_stub(_input.type().device_type(), output, hy, cy, _input, hx, _params, has_biases,
+ num_layers, dropout_p, train, bidirectional, batch_first);
+ return std::make_tuple(output, hy, cy);
+ }
+ check_device(_input, _params, hx);
+ auto input = batch_first ? _input.transpose(0, 1) : _input;
+ AT_CHECK(has_biases, "quantized LSTM requires biases");
+ auto params = gather_quantized_params(_params);
+ auto results = _lstm_impl<FullLayer, FullBidirectionalLayer>(
+ input, params, hx[0], hx[1], num_layers, dropout_p, train, bidirectional);
+ if (batch_first) {
+ std::get<0>(results) = std::get<0>(results).transpose(0, 1);
+ }
+ return results;
}
+#define DEFINE_QUANTIZED_RNN_CELL(name, hx_type, cell_type, return_type, prepare_hx_fn) \
+return_type name( \
+ const Tensor& input, \
+ hx_type hx, \
+ const Tensor& w_ih, \
+ const Tensor& w_hh, \
+ const Tensor& b_ih, \
+ const Tensor& b_hh, \
+ const Tensor& packed_ih, \
+ const Tensor& packed_hh, \
+ const Tensor& col_offsets_ih, \
+ const Tensor& col_offsets_hh, \
+ const Scalar scale_ih, \
+ const Scalar scale_hh, \
+ const Scalar zero_point_ih, \
+ const Scalar zero_point_hh) { \
+ QuantizedCellParams params( \
+ w_ih, \
+ w_hh, \
+ b_ih, \
+ b_hh, \
+ packed_ih, \
+ packed_hh, \
+ col_offsets_ih, \
+ col_offsets_hh, \
+ scale_ih, \
+ scale_hh, \
+ zero_point_ih, \
+ zero_point_hh); \
+ return cell_type{}( \
+ input, prepare_hx_fn(hx), params); \
+}
+
+// Quantized LSTM cell
+using quantized_lstm_cell_type = LSTMCell<QuantizedCellParams>;
+using quantized_lstm_return_type = std::tuple<Tensor, Tensor>;
+std::tuple<Tensor, Tensor> prepare_quantized_lstm_hx(TensorList hx) {
+ return std::make_tuple(hx[0], hx[1]);
+}
+DEFINE_QUANTIZED_RNN_CELL(quantized_lstm_cell, TensorList, quantized_lstm_cell_type, quantized_lstm_return_type, prepare_quantized_lstm_hx);
+
+// Helpers for simpler cells
+using simple_hx_type = const Tensor&;
+simple_hx_type prepare_quantized_hx(simple_hx_type hx) {
+ return hx;
+}
+
+// Quantized GRU cell
+using quantized_gru_cell_type = GRUCell<QuantizedCellParams>;
+DEFINE_QUANTIZED_RNN_CELL(quantized_gru_cell, simple_hx_type, quantized_gru_cell_type, Tensor, prepare_quantized_hx);
+
+// Quantized RNN w/ ReLU cell
+using quantized_rnn_relu_cell_type = SimpleCell<relu_f, QuantizedCellParams>;
+DEFINE_QUANTIZED_RNN_CELL(quantized_rnn_relu_cell, simple_hx_type, quantized_rnn_relu_cell_type, Tensor, prepare_quantized_hx);
+
+// Quantized RNN w/ tanh cell
+using quantized_rnn_tanh_cell_type = SimpleCell<tanh_f, QuantizedCellParams>;
+DEFINE_QUANTIZED_RNN_CELL(quantized_rnn_tanh_cell, simple_hx_type, quantized_rnn_tanh_cell_type, Tensor, prepare_quantized_hx);
+
}} // namespace at::native
import math
import types
import pickle
+import copy
from common_methods_invocations import method_tests as autograd_method_tests
from common_methods_invocations import create_input, unpack_variables, \
imported.save(fname)
return torch.jit.load(fname, map_location=map_location)
+ def getExportImportCopyWithPacking(self, m, also_test_file=True, map_location=None):
+ buffer = io.BytesIO()
+ m.apply(lambda s: s._pack() if s._has_method('_pack') else None)
+ torch.jit.save(m, buffer)
+ m.apply(lambda s: s._unpack() if s._has_method('_unpack') else None)
+ buffer.seek(0)
+ imported = torch.jit.load(buffer, map_location=map_location)
+ imported.apply(lambda s: s._unpack() if s._has_method('_unpack') else None)
+
+ if not also_test_file:
+ return imported
+
+ # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
+ # opens the file, and it cannot be opened multiple times in Windows. To support Windows,
+ # close the file after creation and try to remove it manually
+ f = tempfile.NamedTemporaryFile(delete=False)
+ try:
+ f.close()
+ imported.save(f.name)
+ result = torch.jit.load(f.name, map_location=map_location)
+ finally:
+ os.unlink(f.name)
+
+ result.apply(lambda s: s._unpack() if s._has_method('_unpack') else None)
+ return result
+
def assertGraphContains(self, graph, kind):
self.assertTrue(any(n.kind() == kind for n in graph.nodes()))
a = A()
self.assertEqual(a.with_docstring.__doc__, 'test str')
+ @unittest.skipIf(TEST_WITH_UBSAN or not torch.fbgemm_is_cpu_supported(),
+ 'Quantized RNN requires FBGEMM. FBGEMM does not play'
+ ' well with UBSAN at the moment, so we skip the test if'
+ ' we are in a UBSAN environment.')
+ def test_rnn_cell_quantized(self):
+ d_in, d_hid = 2, 2
+
+ for cell in [
+ torch.nn.LSTMCell(d_in, d_hid).float(),
+ torch.nn.GRUCell(d_in, d_hid).float(),
+ torch.nn.RNNCell(d_in, d_hid).float(),
+ ]:
+ if isinstance(cell, torch.nn.LSTMCell):
+ num_chunks = 4
+ elif isinstance(cell, torch.nn.GRUCell):
+ num_chunks = 3
+ elif isinstance(cell, torch.nn.RNNCell):
+ num_chunks = 1
+
+ # Replace parameter values s.t. the range of values is exactly
+ # 255, thus we will have 0 quantization error in the quantized
+ # GEMM call. This i s for testing purposes.
+ #
+ # Note that the current implementation does not support
+ # accumulation values outside of the range representable by a
+ # 16 bit integer, instead resulting in a saturated value. We
+ # must take care that in our test we do not end up with a dot
+ # product that overflows the int16 range, e.g.
+ # (255*127+255*127) = 64770. So, we hardcode the test values
+ # here and ensure a mix of signedness.
+ vals = [[100, -155],
+ [100, -155],
+ [-155, 100],
+ [-155, 100],
+ [100, -155],
+ [-155, 100],
+ [-155, 100],
+ [100, -155]]
+ vals = vals[:d_hid * num_chunks]
+ cell.weight_ih = torch.nn.Parameter(
+ torch.tensor(vals, dtype=torch.float),
+ requires_grad=False)
+ cell.weight_hh = torch.nn.Parameter(
+ torch.tensor(vals, dtype=torch.float),
+ requires_grad=False)
+
+ ref = copy.deepcopy(cell)
+
+ cell = torch.jit.quantized.quantize_rnn_cell_modules(cell)
+ x = torch.tensor([[100, -155],
+ [-155, 100],
+ [100, -155]], dtype=torch.float)
+ h0_vals = [[-155, 100],
+ [-155, 155],
+ [100, -155]]
+ hx = torch.tensor(h0_vals, dtype=torch.float)
+ if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell):
+ cx = torch.tensor(h0_vals, dtype=torch.float)
+ hiddens = (hx, cx)
+ else:
+ hiddens = hx
+
+ if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell):
+ from typing import Tuple
+
+ class ScriptWrapper(torch.jit.ScriptModule):
+ def __init__(self, cell):
+ super(ScriptWrapper, self).__init__()
+ self.cell = cell
+
+ @torch.jit.script_method
+ def forward(self, x, hiddens):
+ # type: (torch.Tensor, Tuple[torch.Tensor, torch.Tensor])
+ return self.cell(x, hiddens)
+ else:
+
+ class ScriptWrapper(torch.jit.ScriptModule):
+ def __init__(self, cell):
+ super(ScriptWrapper, self).__init__()
+ self.cell = cell
+
+ @torch.jit.script_method
+ def forward(self, x, hiddens):
+ # type: (torch.Tensor, torch.Tensor)
+ return self.cell(x, hiddens)
+
+ cell = ScriptWrapper(cell)
+ outs = cell(x, hiddens)
+ cell = self.getExportImportCopyWithPacking(cell)
+
+ outs = cell(x, hiddens)
+ ref_outs = ref(x, hiddens)
+
+ self.assertEqual(len(outs), len(ref_outs))
+ for out, ref_out in zip(outs, ref_outs):
+ torch.testing.assert_allclose(out, ref_out)
+
def test_script_module(self):
class M1(torch.jit.ScriptModule):
def __init__(self):
# Test save path
self.assertFalse(sm.pack_called.item())
self.assertFalse(sm.unpack_called.item())
- sm.apply(lambda s: s._pack())
- imported = self.getExportImportCopy(sm)
- sm.apply(lambda s: s._unpack())
- imported.apply(lambda s: s._unpack())
+ imported = self.getExportImportCopyWithPacking(sm)
# ensure pack was called before serialization
self.assertTrue(sm.pack_called.item())
# ensure unpack was called after serialization so as to leave the module in an initialized state
fb_ref = FooBar()
fb_ref.linear1.weight = torch.nn.Parameter(fb.linear1.weight.clone(), requires_grad=False)
fb_ref.linear1.bias = torch.nn.Parameter(fb.linear1.bias.clone(), requires_grad=False)
- torch.jit.quantized.quantize_linear_modules(fb)
+ fb = torch.jit.quantized.quantize_linear_modules(fb)
x = (torch.rand(1, K1).float() - 0.5) / 10.0
traced = torch.jit.trace(fb, (x,))
- traced.apply(lambda s: s._pack() if s._has_method('_pack') else None)
- fb = self.getExportImportCopy(traced)
- traced.apply(lambda s: s._unpack() if s._has_method('_unpack') else None)
-
- fb.apply(lambda s: s._unpack() if s._has_method('_unpack') else None)
+ fb = self.getExportImportCopyWithPacking(traced)
x = torch.tensor([[100, -150]], dtype=torch.float)
y = fb(x)
import torch
import copy
+import numbers
+from typing import Tuple
+
+from torch.nn.utils.rnn import PackedSequence
+from torch.nn import _VF
class QuantizedLinear(torch.jit.ScriptModule):
return repr
+# Quantized RNN cell implementations
+class QuantizedRNNCellBase(torch.jit.ScriptModule):
+ __constants__ = ['input_size', 'hidden_size', 'bias', 'scale_hh', 'scale_ih',
+ 'zero_point_ih', 'zero_point_hh']
+
+ def __init__(self, other):
+ super(QuantizedRNNCellBase, self).__init__()
+ self.input_size = other.input_size
+ self.hidden_size = other.hidden_size
+ self.bias = other.bias
+ if not self.bias:
+ raise ValueError("Quantized RNN cells require bias terms")
+
+ weight_ih, col_offsets_ih, self.scale_ih, self.zero_point_ih = \
+ torch.fbgemm_linear_quantize_weight(other.weight_ih.clone().float())
+ self.register_buffer('weight_ih', weight_ih)
+ self.register_buffer('col_offsets_ih', col_offsets_ih)
+ weight_hh, col_offsets_hh, self.scale_hh, self.zero_point_hh = \
+ torch.fbgemm_linear_quantize_weight(other.weight_hh.clone().float())
+ self.register_buffer('weight_hh', weight_hh)
+ self.register_buffer('col_offsets_hh', col_offsets_hh)
+
+ packed_ih = torch.fbgemm_pack_quantized_matrix(
+ self.weight_ih, self.weight_ih.size(1), self.weight_ih.size(0))
+ self.register_buffer('packed_ih', packed_ih)
+ packed_hh = torch.fbgemm_pack_quantized_matrix(
+ self.weight_hh, self.weight_hh.size(1), self.weight_hh.size(0))
+ self.register_buffer('packed_hh', packed_hh)
+
+ self.bias_ih = torch.nn.Parameter(other.bias_ih.clone().float(), requires_grad=False)
+ self.bias_hh = torch.nn.Parameter(other.bias_hh.clone().float(), requires_grad=False)
+
+ def extra_repr(self):
+ s = '{input_size}, {hidden_size}'
+ if 'bias' in self.__dict__ and self.bias is not True:
+ s += ', bias={bias}'
+ if 'nonlinearity' in self.__dict__ and self.nonlinearity != "tanh":
+ s += ', nonlinearity={nonlinearity}'
+ return s.format(**self.__dict__)
+
+ @torch.jit.script_method
+ def check_forward_input(self, input):
+ if input.size(1) != self.input_size:
+ raise RuntimeError(
+ "input has inconsistent input_size: got {}, expected {}".format(
+ input.size(1), self.input_size))
+
+ @torch.jit.script_method
+ def check_forward_hidden(self, input, hx, hidden_label=''):
+ # type: (Tensor, Tensor, str) -> None
+ if input.size(0) != hx.size(0):
+ raise RuntimeError(
+ "Input batch size {} doesn't match hidden{} batch size {}".format(
+ input.size(0), hidden_label, hx.size(0)))
+
+ if hx.size(1) != self.hidden_size:
+ raise RuntimeError(
+ "hidden{} has inconsistent hidden_size: got {}, expected {}".format(
+ hidden_label, hx.size(1), self.hidden_size))
+
+ # TODO: for some reason weak_script_method causes a destruction of the
+ # module to occur, which in turn frees the packed_ih object via its DataPtr
+ # deleter. This is bizarre and should probably get fixed.
+ # @torch._jit_internal.weak_script_method
+ @torch.jit.script_method
+ def _unpack(self):
+ self.packed_ih.set_(torch.fbgemm_pack_quantized_matrix(
+ self.weight_ih, self.weight_ih.size(1), self.weight_ih.size(0)))
+ self.packed_hh.set_(
+ torch.fbgemm_pack_quantized_matrix(
+ self.weight_hh, self.weight_hh.size(1), self.weight_hh.size(0)))
+
+ # @torch._jit_internal.weak_script_method
+ @torch.jit.script_method
+ def _pack(self):
+ self.packed_ih.set_(
+ torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach())
+ self.packed_hh.set_(
+ torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach())
+
+
+class QuantizedRNNCell(QuantizedRNNCellBase):
+ __constants__ = ['input_size', 'hidden_size', 'bias', 'scale_hh', 'scale_ih',
+ 'zero_point_ih', 'zero_point_hh', 'nonlinearity']
+
+ def __init__(self, other):
+ super(QuantizedRNNCell, self).__init__(other)
+ self.nonlinearity = other.nonlinearity
+
+ @torch.jit.script_method
+ def forward(self, input, hx=None):
+ # type: (Tensor, Optional[Tensor]) -> Tensor
+ self.check_forward_input(input)
+ if hx is None:
+ _hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+ else:
+ _hx = torch.jit._unwrap_optional(hx)
+ self.check_forward_hidden(input, _hx, '')
+ if self.nonlinearity == "tanh":
+ ret = _VF.quantized_rnn_tanh_cell(
+ input, _hx, self.weight_ih, self.weight_hh, self.bias_ih,
+ self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
+ self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih,
+ self.zero_point_hh
+ )
+ elif self.nonlinearity == "relu":
+ ret = _VF.quantized_rnn_relu_cell(
+ input, _hx, self.weight_ih, self.weight_hh, self.bias_ih,
+ self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
+ self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih,
+ self.zero_point_hh
+ )
+ else:
+ ret = input # TODO: remove when jit supports exception flow
+ raise RuntimeError(
+ "Unknown nonlinearity: {}".format(self.nonlinearity))
+ return ret
+
+
+class QuantizedLSTMCell(QuantizedRNNCellBase):
+ def __init__(self, other):
+ super(QuantizedLSTMCell, self).__init__(other)
+
+ @torch.jit.script_method
+ def forward(self, input, hx=None):
+ # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
+ self.check_forward_input(input)
+ if hx is None:
+ zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+ _hx = (zeros, zeros)
+ else:
+ _hx = torch.jit._unwrap_optional(hx)
+ self.check_forward_hidden(input, _hx[0], '[0]')
+ self.check_forward_hidden(input, _hx[1], '[1]')
+ return _VF.quantized_lstm_cell(
+ input, _hx, self.weight_ih, self.weight_hh, self.bias_ih,
+ self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
+ self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih,
+ self.zero_point_hh
+ )
+
+
+class QuantizedGRUCell(QuantizedRNNCellBase):
+ def __init__(self, other):
+ super(QuantizedGRUCell, self).__init__(other)
+
+ @torch.jit.script_method
+ def forward(self, input, hx=None):
+ # type: (Tensor, Optional[Tensor]) -> Tensor
+ self.check_forward_input(input)
+ if hx is None:
+ _hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+ else:
+ _hx = torch.jit._unwrap_optional(hx)
+ self.check_forward_hidden(input, _hx, '')
+ return _VF.quantized_gru_cell(
+ input, _hx, self.weight_ih, self.weight_hh, self.bias_ih,
+ self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
+ self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih,
+ self.zero_point_hh
+ )
+
+
+def quantize_rnn_cell_modules(module):
+ reassign = {}
+ for name, mod in module.named_modules():
+ if mod is module:
+ continue
+ new_mod = quantize_rnn_cell_modules(mod)
+ if new_mod is not mod:
+ reassign[name] = new_mod
+ for name, mod in reassign.items():
+ setattr(module, name, mod)
+ if isinstance(module, torch.nn.LSTMCell):
+ return QuantizedLSTMCell(mod)
+ if isinstance(module, torch.nn.GRUCell):
+ return QuantizedGRUCell(mod)
+ if isinstance(module, torch.nn.RNNCell):
+ return QuantizedRNNCell(mod)
+
+ return module
+
+
def quantize_linear_modules(module):
+ reassign = {}
for name, mod in module.named_modules():
if mod is module:
continue
- if isinstance(mod, torch.nn.Linear):
- setattr(module, name, QuantizedLinear(mod))
- quantize_linear_modules(mod)
+ new_mod = quantize_linear_modules(mod)
+ if new_mod is not mod:
+ reassign[name] = new_mod
+
+ for name, mod in reassign.items():
+ setattr(module, name, mod)
+ if isinstance(mod, torch.nn.Linear):
+ return QuantizedLinear(mod)
+ return module