From f8778aef78f3e34e78a0cbe4cd2da9e7f1895a15 Mon Sep 17 00:00:00 2001 From: Ahmed Aly Date: Thu, 7 Mar 2019 01:03:51 -0800 Subject: [PATCH] Implement a Caffe2 standalone LSTM operator (#17726) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17726 Pull Request resolved: https://github.com/pytorch/pytorch/pull/17725 Pull Request resolved: https://github.com/pytorch/pytorch/pull/17461 Implementing a standalone LSTM Operator in Caffe2 adopted from this Aten implementation: diffusion/FBS/browse/master/fbcode/caffe2/aten/src/ATen/native/RNN.cpp. The most tricky thing in this exercise was that caffe2::Tensor has no copy constructor that made it necessary to implement a custom templated copy constructor for the different Tensor containers used in the code. Also there was no way to use off-the-shelf C2 operators in my code easily so I had to copy some code that is doing basic matmul, cat, split, transpose and linear as utility functions. Two things missing: - Profiling this implementation against the current ONNXified LSTM op - Make this operator available to use in PyTorch Reviewed By: dzhulgakov Differential Revision: D14351575 fbshipit-source-id: 3b99b53212cf593c7a49e45580b5a07b90809e64 --- caffe2/operators/inference_lstm_op.cc | 71 ++++++ caffe2/operators/inference_lstm_op.h | 310 ++++++++++++++++++++++++++ caffe2/operators/lstm_utils.h | 318 +++++++++++++++++++++++++++ caffe2/python/test/inference_lstm_op_test.py | 72 ++++++ 4 files changed, 771 insertions(+) create mode 100644 caffe2/operators/inference_lstm_op.cc create mode 100644 caffe2/operators/inference_lstm_op.h create mode 100644 caffe2/operators/lstm_utils.h create mode 100644 caffe2/python/test/inference_lstm_op_test.py diff --git a/caffe2/operators/inference_lstm_op.cc b/caffe2/operators/inference_lstm_op.cc new file mode 100644 index 0000000..cfdf21f --- /dev/null +++ b/caffe2/operators/inference_lstm_op.cc @@ -0,0 +1,71 @@ +#include "caffe2/operators/inference_lstm_op.h" + +namespace caffe2 { +namespace { + +bool InferenceLSTMOp::RunOnDevice() { + auto& _input = Input(0); + auto& hidden_0 = Input(1); + auto& hidden_1 = Input(2); + std::vector params; + for (int i = 3; i < InputSize(); i++) { + params.push_back(Input(i).UnsafeSharedInstance()); + } + auto input = batch_first_ ? transpose(_input, 0, 1, &context_) + : _input.UnsafeSharedInstance(); + + auto cell_params = gather_params(params, has_biases_, &context_); + auto results = _lstm_impl( + input, + cell_params, + hidden_0, + hidden_1, + num_layers_, + bidirectional_, + &context_); + + std::vector allOutputs(OutputSize()); + allOutputs.at(0) = copy_ctor(std::get<0>(results)); + if (batch_first_) { + allOutputs.at(0) = transpose(allOutputs.at(0), 0, 1, &context_); + } + allOutputs.at(1) = copy_ctor(std::get<1>(results)); + allOutputs.at(2) = copy_ctor(std::get<2>(results)); + for (int i = 0; i < OutputSize(); i++) { + auto output = XOutput(i, allOutputs.at(i).sizes(), dtype()); + context_.CopyItemsSameDevice( + allOutputs.at(i).dtype(), + allOutputs.at(i).numel(), + allOutputs.at(i).template data(), + output.template mutable_data()); + } + return true; +} + +REGISTER_CPU_OPERATOR(InferenceLSTM, InferenceLSTMOp); +OPERATOR_SCHEMA(InferenceLSTM) + .NumInputs(1, INT_MAX) + .NumOutputs(3) + .Output(0, "output", "the output of the last layer of lstm") + .Output(1, "hidden", "hidden state at t = seq_len") + .Output(2, "cell", "cell state at t = seq_len") + .Arg("num_layers", "(*long*): number of layers in the lstm stack") + .Arg("has_biases", "(*bool*): whether the cells have biases or not") + .Arg("batch_first", "(*bool*): whether the batch is at dim 0") + .Arg("bidirectional", "(*bool*): if bidirectional"); +NO_GRADIENT(InferenceLSTM); +} // namespace +} // namespace caffe2 + +C10_REGISTER_CAFFE2_OPERATOR_CPU( + InferenceLSTM, + (std::vector{ + c10::Argument("input_list", ListType::ofTensors()), + c10::Argument("num_layers", IntType::get()), + c10::Argument("has_biases", BoolType::get()), + c10::Argument("batch_first", BoolType::get()), + c10::Argument("bidirectional", BoolType::get())}), + (std::vector{c10::Argument("output"), + c10::Argument("hidden"), + c10::Argument("cell")}), + caffe2::InferenceLSTMOp); diff --git a/caffe2/operators/inference_lstm_op.h b/caffe2/operators/inference_lstm_op.h new file mode 100644 index 0000000..d907674 --- /dev/null +++ b/caffe2/operators/inference_lstm_op.h @@ -0,0 +1,310 @@ +#ifndef LSTM_OP_H_ +#define LSTM_OP_H_ + +#include +#include +#include +#include +#include +#include "caffe2/core/blob_serialization.h" +#include "caffe2/core/operator.h" +#include "caffe2/core/tensor.h" +#include "caffe2/utils/eigen_utils.h" +#include "caffe2/utils/math.h" +#include "lstm_utils.h" + +C10_DECLARE_CAFFE2_OPERATOR(LSTMOp); + +namespace caffe2 { +namespace { + +using t_tuple = std::tuple; + +struct CellParams { + CellParams( + const Tensor& _w_ih, + const Tensor& _w_hh, + const Tensor& _b_ih, + const Tensor& _b_hh, + CPUContext* _context) { + initParams(_w_ih, _w_hh, _b_ih, _b_hh, _context); + } + + CellParams(const CellParams& rhs) { + initParams(rhs.w_ih, rhs.w_hh, rhs.b_ih, rhs.b_hh, rhs.context); + } + + CellParams& operator=(const CellParams& rhs) { + initParams(rhs.w_ih, rhs.w_hh, rhs.b_ih, rhs.b_hh, rhs.context); + return *this; + } + + void initParams( + const Tensor& _w_ih, + const Tensor& _w_hh, + const Tensor& _b_ih, + const Tensor& _b_hh, + CPUContext* _context) { + w_ih = copy_ctor(_w_ih); + w_hh = copy_ctor(_w_hh); + b_ih = copy_ctor(_b_ih); + b_hh = copy_ctor(_b_hh); + context = _context; + } + + Tensor w_ih; + Tensor w_hh; + Tensor b_ih; /* optional */ + Tensor b_hh; /* optional */ + CPUContext* context; + + Tensor linear_ih(const Tensor& input) const { + return linear(input, w_ih, b_ih, context); + } + Tensor linear_hh(const Tensor& h) const { + return linear(h, w_hh, b_hh, context); + } +}; + +struct LSTMCell { + explicit LSTMCell(CPUContext* context) : context_(context) {} + t_tuple operator()( + const Tensor& input, + const t_tuple& hidden, + const CellParams& params) const { + const auto& hx = std::get<0>(hidden); + const auto& cx = std::get<1>(hidden); + auto linear_ih = params.linear_ih(input); + auto linear_hh = params.linear_hh(hx); + auto gates = add(linear_ih, linear_hh, context_); + auto chunked_gates = chunk(gates, 4, 1, context_); + auto ingate = sigmoid(chunked_gates[0]); + auto forgetgate = sigmoid(chunked_gates[1]); + auto cellgate = tanh(chunked_gates[2], context_); + auto outgate = sigmoid(chunked_gates[3]); + + auto cy = + add(mul(forgetgate, cx, context_), + mul(ingate, cellgate, context_), + context_); + auto hy = mul(outgate, tanh(cy, context_), context_); + return std::make_tuple(std::move(hy), std::move(cy)); + } + CPUContext* context_; +}; + +template +struct LayerOutput { + output_type outputs; + hidden_type final_hidden; + + LayerOutput(const output_type& _outputs, const hidden_type& _hidden) { + outputs = copy_ctor(_outputs); + final_hidden = copy_ctor(_hidden); + } +}; + +template +struct Layer { + using output_type = LayerOutput; + virtual ~Layer() {} + virtual output_type operator()( + const Tensor& input, + const hidden_type& input_hidden, + const param_type& params) const = 0; +}; + +struct FullLSTMLayer : Layer { + FullLSTMLayer(LSTMCell& cell, CPUContext* context) + : cell_(cell), context_(context) {} + + LayerOutput, t_tuple> operator()( + const std::vector& step_inputs, + const std::tuple& input_hidden, + const CellParams& params) const { + std::vector step_outputs; + auto hidden = copy_ctor(input_hidden); + + for (size_t i = 0; i < step_inputs.size(); i++) { + hidden = cell_(step_inputs[i], hidden, params); + step_outputs.push_back(copy_ctor(std::get<0>(hidden))); + } + + return {step_outputs, hidden}; + } + + LayerOutput operator()( + const Tensor& inputs, + const std::tuple& input_hidden, + const CellParams& params) const override { + auto unstacked_output = + (*this)(unbind(inputs, 0, context_), input_hidden, params); + return {stack(unstacked_output.outputs, 0, context_), + unstacked_output.final_hidden}; + } + LSTMCell cell_; + CPUContext* context_; +}; + +struct FullBidirectionalLSTMLayer + : Layer, std::pair> { + using bidir_hidden_type = std::pair; + using param_type = std::pair; + using output_type = LayerOutput; + + FullBidirectionalLSTMLayer(LSTMCell& cell, CPUContext* context) + : layer_(cell, context), context_(context) {} + + output_type operator()( + const Tensor& input, + const bidir_hidden_type& input_hidden, + const param_type& params) const override { + std::vector outputs; + auto step_inputs = unbind(input, 0, context_); + auto fw_result = layer_(step_inputs, input_hidden.first, params.first); + auto fw_output = stack(fw_result.outputs, 0, context_); + outputs.push_back(copy_ctor(fw_output)); + auto rev_step_inputs = reverse(std::move(step_inputs)); + auto rev_result = + layer_(rev_step_inputs, input_hidden.second, params.second); + std::reverse(rev_result.outputs.begin(), rev_result.outputs.end()); + auto rev_output = stack(rev_result.outputs, 0, context_); + outputs.push_back(copy_ctor(rev_output)); + return {cat(outputs, fw_output.dim() - 1, context_), + std::make_pair( + std::move(fw_result.final_hidden), + std::move(rev_result.final_hidden))}; + } + + inline std::vector reverse(std::vector&& x) const { + std::reverse(x.begin(), x.end()); + return std::move(x); + } + + private: + FullLSTMLayer layer_; + CPUContext* context_; +}; + +template +LayerOutput> apply_layer_stack( + const Layer& layer, + const Tensor& input, + const std::vector& hiddens, + const std::vector& weights, + int64_t num_layers) { + CAFFE_ENFORCE( + num_layers == hiddens.size(), + "Expected more hidden states in stacked_rnn"); + CAFFE_ENFORCE( + num_layers == weights.size(), "Expected more weights in stacked_rnn"); + + auto layer_input = input.UnsafeSharedInstance(); + auto hidden_it = hiddens.begin(); + auto weight_it = weights.begin(); + std::vector final_hiddens(num_layers); + for (int64_t l = 0; l < num_layers; ++l) { + auto layer_output = layer(layer_input, *(hidden_it++), *(weight_it++)); + final_hiddens.at(l) = std::move(layer_output.final_hidden); + layer_input = std::move(layer_output.outputs); + } + return {layer_input, final_hiddens}; +} + +std::tuple _lstm_impl( + const Tensor& input, + const std::vector& params, + const Tensor& hx, + const Tensor& cx, + int64_t num_layers, + bool bidirectional, + CPUContext* context) { + using stack_output = LayerOutput>; + auto layer_hx = unbind(hx, 0, context); + auto layer_cx = unbind(cx, 0, context); + int64_t total_layers = layer_hx.size(); + std::vector> hiddens; + hiddens.reserve(total_layers); + for (int64_t i = 0; i < total_layers; ++i) { + hiddens.emplace_back(std::move(layer_hx[i]), std::move(layer_cx[i])); + } + LSTMCell cell(context); + std::shared_ptr stack_output_ptr; + if (bidirectional) { + auto bidir_result = apply_layer_stack( + FullBidirectionalLSTMLayer{cell, context}, + input, + pair_vec(hiddens), + pair_vec(params), + num_layers); + stack_output_ptr.reset(new stack_output( + bidir_result.outputs, + unpair_vec(std::move(bidir_result.final_hidden)))); + } else { + auto result = apply_layer_stack( + FullLSTMLayer{cell, context}, input, hiddens, params, num_layers); + stack_output_ptr = std::make_shared(std::move(result)); + } + + std::vector hy, cy; + hy.reserve(total_layers); + cy.reserve(total_layers); + for (auto& hidden : stack_output_ptr->final_hidden) { + hy.push_back(std::move(std::get<0>(hidden))); + cy.push_back(std::move(std::get<1>(hidden))); + } + return std::make_tuple( + std::move(stack_output_ptr->outputs), + stack(hy, 0, context), + stack(cy, 0, context)); +} + +// Parses a flat list of parameter tensors into a list of CellParams +std::vector gather_params( + const std::vector& params, + bool has_biases, + CPUContext* context) { + Tensor undefined; + std::vector result; + if (has_biases) { + CAFFE_ENFORCE_EQ( + params.size() % 4, 0, "got an incorrect number of LSTM parameters"); + for (size_t i = 0; i < params.size(); i += 4) { + result.emplace_back( + params[i], params[i + 1], params[i + 2], params[i + 3], context); + } + } else { + CAFFE_ENFORCE_EQ( + params.size() % 2, 0, "got an incorrect number of LSTM parameters"); + for (size_t i = 0; i < params.size(); i += 2) { + result.emplace_back( + params[i], params[i + 1], undefined, undefined, context); + } + } + return result; +} + +class InferenceLSTMOp : public Operator { + public: + template + explicit InferenceLSTMOp(Args&&... args) + : Operator(std::forward(args)...), + num_layers_(this->template GetSingleArgument("num_layers", 1)), + bidirectional_( + this->template GetSingleArgument("bidirectional", false)), + has_biases_(this->template GetSingleArgument("has_biases", true)), + batch_first_( + this->template GetSingleArgument("batch_first", false)) {} + + bool RunOnDevice() override; + + protected: + int64_t num_layers_; + bool bidirectional_; + bool has_biases_; + bool batch_first_; +}; + +} // namespace +} // namespace caffe2 +#endif // LSTM_OP_H_ diff --git a/caffe2/operators/lstm_utils.h b/caffe2/operators/lstm_utils.h new file mode 100644 index 0000000..b354764 --- /dev/null +++ b/caffe2/operators/lstm_utils.h @@ -0,0 +1,318 @@ +#include +#include +#include "caffe2/core/tensor.h" +#include "caffe2/utils/eigen_utils.h" +#include "caffe2/utils/math.h" + +namespace caffe2 { +namespace { + +using t_tuple = std::tuple; + +template +T copy_ctor(const T& x) { + return x; +} + +template <> +Tensor copy_ctor(const Tensor& X) { + return X.UnsafeSharedInstance(); +} + +template <> +t_tuple copy_ctor(const t_tuple& X) { + return std::make_tuple(copy_ctor(std::get<0>(X)), copy_ctor(std::get<1>(X))); +} + +template <> +std::pair copy_ctor(const std::pair& X) { + return std::make_pair(copy_ctor(X.first), copy_ctor(X.second)); +} + +template <> +std::vector copy_ctor(const std::vector& X) { + std::vector Y(X.size()); + std::transform(X.begin(), X.end(), Y.begin(), [](const Tensor& x) { + return copy_ctor(x); + }); + return Y; +} + +template <> +std::vector copy_ctor(const std::vector& X) { + std::vector Y(X.size()); + std::transform(X.begin(), X.end(), Y.begin(), [](const t_tuple& x) { + return copy_ctor(x); + }); + return Y; +} + +template <> +std::vector> copy_ctor( + const std::vector>& X) { + std::vector> Y(X.size()); + std::transform( + X.begin(), X.end(), Y.begin(), [](const std::pair& x) { + return copy_ctor(x); + }); + return Y; +} + +// Gathers every two elements of a vector in a vector of pairs +template +static std::vector> pair_vec(const std::vector& vals) { + CAFFE_ENFORCE_EQ( + vals.size() % 2, + 0, + "Odd number of params or hiddens given to a bidirectional RNN"); + std::vector> result; + result.reserve(vals.size() / 2); + for (int64_t i = 0; i < vals.size(); i += 2) { + result.emplace_back(copy_ctor(vals[i]), copy_ctor(vals[i + 1])); + } + return result; +} + +// Flattens a vector of pairs +template +static std::vector unpair_vec(std::vector>&& vals) { + std::vector result; + result.reserve(vals.size() * 2); + for (int64_t i = 0; i < vals.size(); i++) { + result.push_back(std::move(vals[i].first)); + result.push_back(std::move(vals[i].second)); + } + return result; +} + +Tensor matmul(const Tensor& X, const Tensor& W, CPUContext* context) { + const auto canonical_axis = X.canonical_axis_index(1); + const auto M = X.size_to_dim(canonical_axis); + const auto K = X.size_from_dim(canonical_axis); + const auto canonical_axis_w = W.canonical_axis_index(1); + const int N = W.size_to_dim(canonical_axis_w); + auto output_size = X.sizes().vec(); + output_size.resize(canonical_axis + 1); + output_size[canonical_axis] = N; + Tensor C(output_size, CPU); + math::Gemm( + CblasNoTrans, + CblasTrans, + M, + N, + K, + 1, + X.template data(), + W.template data(), + 0, + C.template mutable_data(), + context); + return C; +} + +Tensor +linear(const Tensor& X, const Tensor& W, const Tensor& B, CPUContext* context) { + auto output = matmul(X, W, context); + if (B) { + const auto canonical_axis = X.canonical_axis_index(1); + const auto M = X.size_to_dim(canonical_axis); + const auto canonical_axis_w = W.canonical_axis_index(1); + const int N = W.size_to_dim(canonical_axis_w); + auto bias_multiplier_ = caffe2::empty({M}, CPU); + math::Set( + M, 1, bias_multiplier_.template mutable_data(), context); + math::Gemm( + CblasNoTrans, + CblasNoTrans, + M, + N, + 1, + 1, + bias_multiplier_.template data(), + B.template data(), + 1, + output.template mutable_data(), + context); + } + return output; +} + +std::vector +chunk(const Tensor& input, int chunks, int axis, CPUContext* context) { + int canonical_axis = input.canonical_axis_index(axis); + CAFFE_ENFORCE_LT( + canonical_axis, input.dim(), "Axis not in input ndim range."); + const int input_channels = input.dim32(canonical_axis); + CAFFE_ENFORCE_EQ( + input_channels % chunks, + 0, + "input channels should be divisible by the number of chunks."); + auto split_size = input_channels / chunks; + vector output_dims(input.sizes().vec()); + int before = 1, after = 1; + for (int i = 0; i < canonical_axis; ++i) { + before *= input.dim32(i); + } + for (int i = canonical_axis + 1; i < input.dim(); ++i) { + after *= input.dim32(i); + } + size_t input_offset = 0; + std::vector outputs; + for (int i = 0; i < chunks; ++i) { + auto axis_dim = split_size; + output_dims[canonical_axis] = split_size; + Tensor output(output_dims, CPU); + math::CopyMatrix( + input.itemsize(), + before, + axis_dim * after, + static_cast(input.raw_data()) + input_offset, + input.dim32(canonical_axis) * after, + output.raw_mutable_data(input.dtype()), + axis_dim * after, + context, + input.dtype().copy()); + input_offset += axis_dim * after * input.itemsize(); + outputs.push_back(std::move(output)); + } + return outputs; +} + +std::vector unbind(const Tensor& input, int axis, CPUContext* context) { + // 1 - Chunk the input tensor along the given axis into N chunks where + // N is the dim(axis) + auto chunks = chunk(input, input.sizes()[axis], axis, context); + // 2 - Compute new dimensions + std::vector newDims = input.sizes().vec(); + newDims.erase(newDims.begin() + axis); + + // 3 - Reshape chunks to drop the extra dimension + for (int i = 0; i < chunks.size(); i++) { + CAFFE_ENFORCE_EQ( + chunks[i].sizes()[axis], 1, "Got an unexpected chunk size"); + chunks[i].Reshape(newDims); + } + return chunks; +} + +Tensor +cat(const std::vector& tensorList, int axis, CPUContext* context) { + // Adopted from C2's concat operator + auto input_zero = copy_ctor(tensorList.at(0)); + vector outputDims(input_zero.sizes().vec()); + CAFFE_ENFORCE(outputDims.size() > 0); + for (int i = 1; i < tensorList.size(); i++) { + CAFFE_ENFORCE(input_zero.dtype() == tensorList.at(i).dtype()); + outputDims[axis] += tensorList.at(i).sizes()[axis]; + } + auto output_channels = outputDims[axis]; + Tensor output(outputDims, CPU); + int before = 1, after = 1; + for (int i = 0; i < tensorList.at(0).dim(); ++i) { + if (i == axis) { + continue; + } + int dim = input_zero.dim32(i); + if (i < axis) { + before *= dim; + } else { + after *= dim; + } + } + size_t output_offset = 0; + for (const auto& input : tensorList) { + auto axis_dim = input.dim32(axis); + math::CopyMatrix( + input.itemsize(), + before, + axis_dim * after, + input.raw_data(), + axis_dim * after, + static_cast(output.raw_mutable_data(input_zero.dtype())) + + output_offset, + output_channels * after, + context, + input_zero.dtype().copy()); + output_offset += axis_dim * after * input.itemsize(); + } + + return output; +} + +Tensor +stack(const std::vector& tensorList, int axis, CPUContext* context) { + // 1 - Compute new dimensions + std::vector newDims(tensorList[0].sizes().vec()); + std::vector expandedTensorList; + newDims.insert(newDims.begin() + axis, 1); + for (int i = 0; i < tensorList.size(); i++) { + expandedTensorList.emplace_back(tensorList[i].Clone()); + expandedTensorList.at(i).Reshape(newDims); + } + return cat(expandedTensorList, axis, context); +} + +Tensor sigmoid(const Tensor& X) { + Tensor Y(X.sizes(), CPU); + auto N = X.numel(); + EigenVectorArrayMap(Y.template mutable_data(), N) = 1.0 / + (1.0 + + (-ConstEigenVectorArrayMap(X.template data(), N)).exp()); + return Y; +} + +Tensor tanh(const Tensor& X, CPUContext* context) { + Tensor Y(X.sizes(), CPU); + math::Tanh( + X.numel(), + X.template data(), + Y.template mutable_data(), + context); + return Y; +} + +Tensor add(const Tensor& X, const Tensor& Y, CPUContext* context) { + Tensor Z(X.sizes().vec(), CPU); + math::Add( + X.numel(), + X.template data(), + Y.template data(), + Z.template mutable_data(), + context); + return Z; +} + +Tensor mul(const Tensor& X, const Tensor& Y, CPUContext* context) { + Tensor Z(X.sizes().vec(), CPU); + math::Mul( + X.numel(), + X.template data(), + Y.template data(), + Z.template mutable_data(), + context); + return Z; +} + +Tensor transpose(const Tensor& X, int dim0, int dim1, CPUContext* context) { + int ndim = X.dim(); + CAFFE_ENFORCE(ndim > dim0 && ndim > dim1, "Invalid transpose dimensions"); + std::vector axes(ndim); + std::iota(axes.begin(), axes.end(), 0); + std::swap(axes[dim0], axes[dim1]); + std::vector Y_dims(ndim); + std::vector X_dims(X.sizes().cbegin(), X.sizes().cend()); + for (int i = 0; i < ndim; ++i) { + Y_dims[i] = X_dims[axes[i]]; + } + Tensor Y(Y_dims, CPU); + math::Transpose( + ndim, + X_dims.data(), + axes.data(), + X.template data(), + Y.template mutable_data(), + context); + return Y; +} +} // namespace +} // namespace caffe2 diff --git a/caffe2/python/test/inference_lstm_op_test.py b/caffe2/python/test/inference_lstm_op_test.py new file mode 100644 index 0000000..dc33f52 --- /dev/null +++ b/caffe2/python/test/inference_lstm_op_test.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +import inspect + +import hypothesis.strategies as st +import numpy as np +import torch +from caffe2.python import core, workspace +from caffe2.python.test_util import TestCase +from hypothesis import given +from torch import nn + + +class TestC2LSTM(TestCase): + @given( + bsz=st.integers(1, 5), + seq_lens=st.integers(1, 6), + emb_lens=st.integers(5, 10), + hidden_size=st.integers(3, 7), + num_layers=st.integers(1, 4), + has_biases=st.booleans(), + is_bidirectional=st.booleans(), + batch_first=st.booleans(), + ) + def test_c2_lstm( + self, + bsz, + seq_lens, + emb_lens, + hidden_size, + num_layers, + has_biases, + is_bidirectional, + batch_first, + ): + net = core.Net("test_net") + num_directions = 2 if is_bidirectional else 1 + py_lstm = nn.LSTM( + emb_lens, + hidden_size, + batch_first=batch_first, + bidirectional=is_bidirectional, + bias=has_biases, + num_layers=num_layers, + ) + + hx = np.zeros((num_layers * num_directions, bsz, hidden_size), dtype=np.float32) + + if batch_first: + inputs = np.random.randn(bsz, seq_lens, emb_lens).astype(np.float32) + else: + inputs = np.random.randn(seq_lens, bsz, emb_lens).astype(np.float32) + + py_results = py_lstm(torch.from_numpy(inputs)) + lstm_in = [ + torch.from_numpy(inputs), + torch.from_numpy(hx), + torch.from_numpy(hx), + ] + [param.detach() for param in py_lstm._flat_weights] + + c2_results = torch.ops._caffe2.InferenceLSTM( + lstm_in, num_layers, has_biases, batch_first, is_bidirectional + ) + + np.testing.assert_array_almost_equal( + py_results[0].detach().numpy(), c2_results[0].detach().numpy() + ) + np.testing.assert_array_almost_equal( + py_results[1][0].detach().numpy(), c2_results[1].detach().numpy() + ) + np.testing.assert_array_almost_equal( + py_results[1][1].detach().numpy(), c2_results[2].detach().numpy() + ) -- 2.7.4