Implement a Caffe2 standalone LSTM operator (#17726)
authorAhmed Aly <ahhegazy@fb.com>
Thu, 7 Mar 2019 09:03:51 +0000 (01:03 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 7 Mar 2019 09:08:49 +0000 (01:08 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17726

Pull Request resolved: https://github.com/pytorch/pytorch/pull/17725

Pull Request resolved: https://github.com/pytorch/pytorch/pull/17461

Implementing a standalone LSTM Operator in Caffe2 adopted from this Aten implementation: diffusion/FBS/browse/master/fbcode/caffe2/aten/src/ATen/native/RNN.cpp. The most tricky thing in this exercise was that caffe2::Tensor has no copy constructor that made it necessary to implement a custom templated copy constructor for the different Tensor containers used in the code. Also there was no way to use off-the-shelf C2 operators in my code easily so I had to copy some code that is doing basic matmul, cat, split, transpose and linear as utility functions.

Two things missing:

- Profiling this implementation against the current ONNXified LSTM op
- Make this operator available to use in PyTorch

Reviewed By: dzhulgakov

Differential Revision: D14351575

fbshipit-source-id: 3b99b53212cf593c7a49e45580b5a07b90809e64

caffe2/operators/inference_lstm_op.cc [new file with mode: 0644]
caffe2/operators/inference_lstm_op.h [new file with mode: 0644]
caffe2/operators/lstm_utils.h [new file with mode: 0644]
caffe2/python/test/inference_lstm_op_test.py [new file with mode: 0644]

diff --git a/caffe2/operators/inference_lstm_op.cc b/caffe2/operators/inference_lstm_op.cc
new file mode 100644 (file)
index 0000000..cfdf21f
--- /dev/null
@@ -0,0 +1,71 @@
+#include "caffe2/operators/inference_lstm_op.h"
+
+namespace caffe2 {
+namespace {
+
+bool InferenceLSTMOp::RunOnDevice() {
+  auto& _input = Input(0);
+  auto& hidden_0 = Input(1);
+  auto& hidden_1 = Input(2);
+  std::vector<Tensor> params;
+  for (int i = 3; i < InputSize(); i++) {
+    params.push_back(Input(i).UnsafeSharedInstance());
+  }
+  auto input = batch_first_ ? transpose(_input, 0, 1, &context_)
+                            : _input.UnsafeSharedInstance();
+
+  auto cell_params = gather_params(params, has_biases_, &context_);
+  auto results = _lstm_impl(
+      input,
+      cell_params,
+      hidden_0,
+      hidden_1,
+      num_layers_,
+      bidirectional_,
+      &context_);
+
+  std::vector<Tensor> allOutputs(OutputSize());
+  allOutputs.at(0) = copy_ctor(std::get<0>(results));
+  if (batch_first_) {
+    allOutputs.at(0) = transpose(allOutputs.at(0), 0, 1, &context_);
+  }
+  allOutputs.at(1) = copy_ctor(std::get<1>(results));
+  allOutputs.at(2) = copy_ctor(std::get<2>(results));
+  for (int i = 0; i < OutputSize(); i++) {
+    auto output = XOutput(i, allOutputs.at(i).sizes(), dtype<float>());
+    context_.CopyItemsSameDevice(
+        allOutputs.at(i).dtype(),
+        allOutputs.at(i).numel(),
+        allOutputs.at(i).template data<float>(),
+        output.template mutable_data<float>());
+  }
+  return true;
+}
+
+REGISTER_CPU_OPERATOR(InferenceLSTM, InferenceLSTMOp);
+OPERATOR_SCHEMA(InferenceLSTM)
+    .NumInputs(1, INT_MAX)
+    .NumOutputs(3)
+    .Output(0, "output", "the output of the last layer of lstm")
+    .Output(1, "hidden", "hidden state at t = seq_len")
+    .Output(2, "cell", "cell state at t = seq_len")
+    .Arg("num_layers", "(*long*): number of layers in the lstm stack")
+    .Arg("has_biases", "(*bool*): whether the cells have biases or not")
+    .Arg("batch_first", "(*bool*): whether the batch is at dim 0")
+    .Arg("bidirectional", "(*bool*): if bidirectional");
+NO_GRADIENT(InferenceLSTM);
+} // namespace
+} // namespace caffe2
+
+C10_REGISTER_CAFFE2_OPERATOR_CPU(
+    InferenceLSTM,
+    (std::vector<c10::Argument>{
+        c10::Argument("input_list", ListType::ofTensors()),
+        c10::Argument("num_layers", IntType::get()),
+        c10::Argument("has_biases", BoolType::get()),
+        c10::Argument("batch_first", BoolType::get()),
+        c10::Argument("bidirectional", BoolType::get())}),
+    (std::vector<c10::Argument>{c10::Argument("output"),
+                                c10::Argument("hidden"),
+                                c10::Argument("cell")}),
+    caffe2::InferenceLSTMOp);
diff --git a/caffe2/operators/inference_lstm_op.h b/caffe2/operators/inference_lstm_op.h
new file mode 100644 (file)
index 0000000..d907674
--- /dev/null
@@ -0,0 +1,310 @@
+#ifndef LSTM_OP_H_
+#define LSTM_OP_H_
+
+#include <c10/core/Tensor.h>
+#include <algorithm>
+#include <sstream>
+#include <unordered_map>
+#include <vector>
+#include "caffe2/core/blob_serialization.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/core/tensor.h"
+#include "caffe2/utils/eigen_utils.h"
+#include "caffe2/utils/math.h"
+#include "lstm_utils.h"
+
+C10_DECLARE_CAFFE2_OPERATOR(LSTMOp);
+
+namespace caffe2 {
+namespace {
+
+using t_tuple = std::tuple<Tensor, Tensor>;
+
+struct CellParams {
+  CellParams(
+      const Tensor& _w_ih,
+      const Tensor& _w_hh,
+      const Tensor& _b_ih,
+      const Tensor& _b_hh,
+      CPUContext* _context) {
+    initParams(_w_ih, _w_hh, _b_ih, _b_hh, _context);
+  }
+
+  CellParams(const CellParams& rhs) {
+    initParams(rhs.w_ih, rhs.w_hh, rhs.b_ih, rhs.b_hh, rhs.context);
+  }
+
+  CellParams& operator=(const CellParams& rhs) {
+    initParams(rhs.w_ih, rhs.w_hh, rhs.b_ih, rhs.b_hh, rhs.context);
+    return *this;
+  }
+
+  void initParams(
+      const Tensor& _w_ih,
+      const Tensor& _w_hh,
+      const Tensor& _b_ih,
+      const Tensor& _b_hh,
+      CPUContext* _context) {
+    w_ih = copy_ctor(_w_ih);
+    w_hh = copy_ctor(_w_hh);
+    b_ih = copy_ctor(_b_ih);
+    b_hh = copy_ctor(_b_hh);
+    context = _context;
+  }
+
+  Tensor w_ih;
+  Tensor w_hh;
+  Tensor b_ih; /* optional */
+  Tensor b_hh; /* optional */
+  CPUContext* context;
+
+  Tensor linear_ih(const Tensor& input) const {
+    return linear(input, w_ih, b_ih, context);
+  }
+  Tensor linear_hh(const Tensor& h) const {
+    return linear(h, w_hh, b_hh, context);
+  }
+};
+
+struct LSTMCell {
+  explicit LSTMCell(CPUContext* context) : context_(context) {}
+  t_tuple operator()(
+      const Tensor& input,
+      const t_tuple& hidden,
+      const CellParams& params) const {
+    const auto& hx = std::get<0>(hidden);
+    const auto& cx = std::get<1>(hidden);
+    auto linear_ih = params.linear_ih(input);
+    auto linear_hh = params.linear_hh(hx);
+    auto gates = add(linear_ih, linear_hh, context_);
+    auto chunked_gates = chunk(gates, 4, 1, context_);
+    auto ingate = sigmoid(chunked_gates[0]);
+    auto forgetgate = sigmoid(chunked_gates[1]);
+    auto cellgate = tanh(chunked_gates[2], context_);
+    auto outgate = sigmoid(chunked_gates[3]);
+
+    auto cy =
+        add(mul(forgetgate, cx, context_),
+            mul(ingate, cellgate, context_),
+            context_);
+    auto hy = mul(outgate, tanh(cy, context_), context_);
+    return std::make_tuple(std::move(hy), std::move(cy));
+  }
+  CPUContext* context_;
+};
+
+template <typename output_type, typename hidden_type>
+struct LayerOutput {
+  output_type outputs;
+  hidden_type final_hidden;
+
+  LayerOutput(const output_type& _outputs, const hidden_type& _hidden) {
+    outputs = copy_ctor(_outputs);
+    final_hidden = copy_ctor(_hidden);
+  }
+};
+
+template <typename hidden_type, typename param_type>
+struct Layer {
+  using output_type = LayerOutput<Tensor, hidden_type>;
+  virtual ~Layer() {}
+  virtual output_type operator()(
+      const Tensor& input,
+      const hidden_type& input_hidden,
+      const param_type& params) const = 0;
+};
+
+struct FullLSTMLayer : Layer<t_tuple, CellParams> {
+  FullLSTMLayer(LSTMCell& cell, CPUContext* context)
+      : cell_(cell), context_(context) {}
+
+  LayerOutput<std::vector<Tensor>, t_tuple> operator()(
+      const std::vector<Tensor>& step_inputs,
+      const std::tuple<Tensor, Tensor>& input_hidden,
+      const CellParams& params) const {
+    std::vector<Tensor> step_outputs;
+    auto hidden = copy_ctor(input_hidden);
+
+    for (size_t i = 0; i < step_inputs.size(); i++) {
+      hidden = cell_(step_inputs[i], hidden, params);
+      step_outputs.push_back(copy_ctor(std::get<0>(hidden)));
+    }
+
+    return {step_outputs, hidden};
+  }
+
+  LayerOutput<Tensor, t_tuple> operator()(
+      const Tensor& inputs,
+      const std::tuple<Tensor, Tensor>& input_hidden,
+      const CellParams& params) const override {
+    auto unstacked_output =
+        (*this)(unbind(inputs, 0, context_), input_hidden, params);
+    return {stack(unstacked_output.outputs, 0, context_),
+            unstacked_output.final_hidden};
+  }
+  LSTMCell cell_;
+  CPUContext* context_;
+};
+
+struct FullBidirectionalLSTMLayer
+    : Layer<std::pair<t_tuple, t_tuple>, std::pair<CellParams, CellParams>> {
+  using bidir_hidden_type = std::pair<t_tuple, t_tuple>;
+  using param_type = std::pair<CellParams, CellParams>;
+  using output_type = LayerOutput<Tensor, bidir_hidden_type>;
+
+  FullBidirectionalLSTMLayer(LSTMCell& cell, CPUContext* context)
+      : layer_(cell, context), context_(context) {}
+
+  output_type operator()(
+      const Tensor& input,
+      const bidir_hidden_type& input_hidden,
+      const param_type& params) const override {
+    std::vector<Tensor> outputs;
+    auto step_inputs = unbind(input, 0, context_);
+    auto fw_result = layer_(step_inputs, input_hidden.first, params.first);
+    auto fw_output = stack(fw_result.outputs, 0, context_);
+    outputs.push_back(copy_ctor(fw_output));
+    auto rev_step_inputs = reverse(std::move(step_inputs));
+    auto rev_result =
+        layer_(rev_step_inputs, input_hidden.second, params.second);
+    std::reverse(rev_result.outputs.begin(), rev_result.outputs.end());
+    auto rev_output = stack(rev_result.outputs, 0, context_);
+    outputs.push_back(copy_ctor(rev_output));
+    return {cat(outputs, fw_output.dim() - 1, context_),
+            std::make_pair(
+                std::move(fw_result.final_hidden),
+                std::move(rev_result.final_hidden))};
+  }
+
+  inline std::vector<Tensor> reverse(std::vector<Tensor>&& x) const {
+    std::reverse(x.begin(), x.end());
+    return std::move(x);
+  }
+
+ private:
+  FullLSTMLayer layer_;
+  CPUContext* context_;
+};
+
+template <typename hidden_type, typename weight_type>
+LayerOutput<Tensor, std::vector<hidden_type>> apply_layer_stack(
+    const Layer<hidden_type, weight_type>& layer,
+    const Tensor& input,
+    const std::vector<hidden_type>& hiddens,
+    const std::vector<weight_type>& weights,
+    int64_t num_layers) {
+  CAFFE_ENFORCE(
+      num_layers == hiddens.size(),
+      "Expected more hidden states in stacked_rnn");
+  CAFFE_ENFORCE(
+      num_layers == weights.size(), "Expected more weights in stacked_rnn");
+
+  auto layer_input = input.UnsafeSharedInstance();
+  auto hidden_it = hiddens.begin();
+  auto weight_it = weights.begin();
+  std::vector<hidden_type> final_hiddens(num_layers);
+  for (int64_t l = 0; l < num_layers; ++l) {
+    auto layer_output = layer(layer_input, *(hidden_it++), *(weight_it++));
+    final_hiddens.at(l) = std::move(layer_output.final_hidden);
+    layer_input = std::move(layer_output.outputs);
+  }
+  return {layer_input, final_hiddens};
+}
+
+std::tuple<Tensor, Tensor, Tensor> _lstm_impl(
+    const Tensor& input,
+    const std::vector<CellParams>& params,
+    const Tensor& hx,
+    const Tensor& cx,
+    int64_t num_layers,
+    bool bidirectional,
+    CPUContext* context) {
+  using stack_output = LayerOutput<Tensor, std::vector<t_tuple>>;
+  auto layer_hx = unbind(hx, 0, context);
+  auto layer_cx = unbind(cx, 0, context);
+  int64_t total_layers = layer_hx.size();
+  std::vector<std::tuple<Tensor, Tensor>> hiddens;
+  hiddens.reserve(total_layers);
+  for (int64_t i = 0; i < total_layers; ++i) {
+    hiddens.emplace_back(std::move(layer_hx[i]), std::move(layer_cx[i]));
+  }
+  LSTMCell cell(context);
+  std::shared_ptr<stack_output> stack_output_ptr;
+  if (bidirectional) {
+    auto bidir_result = apply_layer_stack(
+        FullBidirectionalLSTMLayer{cell, context},
+        input,
+        pair_vec(hiddens),
+        pair_vec(params),
+        num_layers);
+    stack_output_ptr.reset(new stack_output(
+        bidir_result.outputs,
+        unpair_vec(std::move(bidir_result.final_hidden))));
+  } else {
+    auto result = apply_layer_stack(
+        FullLSTMLayer{cell, context}, input, hiddens, params, num_layers);
+    stack_output_ptr = std::make_shared<stack_output>(std::move(result));
+  }
+
+  std::vector<Tensor> hy, cy;
+  hy.reserve(total_layers);
+  cy.reserve(total_layers);
+  for (auto& hidden : stack_output_ptr->final_hidden) {
+    hy.push_back(std::move(std::get<0>(hidden)));
+    cy.push_back(std::move(std::get<1>(hidden)));
+  }
+  return std::make_tuple(
+      std::move(stack_output_ptr->outputs),
+      stack(hy, 0, context),
+      stack(cy, 0, context));
+}
+
+// Parses a flat list of parameter tensors into a list of CellParams
+std::vector<CellParams> gather_params(
+    const std::vector<Tensor>& params,
+    bool has_biases,
+    CPUContext* context) {
+  Tensor undefined;
+  std::vector<CellParams> result;
+  if (has_biases) {
+    CAFFE_ENFORCE_EQ(
+        params.size() % 4, 0, "got an incorrect number of LSTM parameters");
+    for (size_t i = 0; i < params.size(); i += 4) {
+      result.emplace_back(
+          params[i], params[i + 1], params[i + 2], params[i + 3], context);
+    }
+  } else {
+    CAFFE_ENFORCE_EQ(
+        params.size() % 2, 0, "got an incorrect number of LSTM parameters");
+    for (size_t i = 0; i < params.size(); i += 2) {
+      result.emplace_back(
+          params[i], params[i + 1], undefined, undefined, context);
+    }
+  }
+  return result;
+}
+
+class InferenceLSTMOp : public Operator<CPUContext> {
+ public:
+  template <class... Args>
+  explicit InferenceLSTMOp(Args&&... args)
+      : Operator(std::forward<Args>(args)...),
+        num_layers_(this->template GetSingleArgument<int64_t>("num_layers", 1)),
+        bidirectional_(
+            this->template GetSingleArgument<bool>("bidirectional", false)),
+        has_biases_(this->template GetSingleArgument<bool>("has_biases", true)),
+        batch_first_(
+            this->template GetSingleArgument<bool>("batch_first", false)) {}
+
+  bool RunOnDevice() override;
+
+ protected:
+  int64_t num_layers_;
+  bool bidirectional_;
+  bool has_biases_;
+  bool batch_first_;
+};
+
+} // namespace
+} // namespace caffe2
+#endif // LSTM_OP_H_
diff --git a/caffe2/operators/lstm_utils.h b/caffe2/operators/lstm_utils.h
new file mode 100644 (file)
index 0000000..b354764
--- /dev/null
@@ -0,0 +1,318 @@
+#include <algorithm>
+#include <vector>
+#include "caffe2/core/tensor.h"
+#include "caffe2/utils/eigen_utils.h"
+#include "caffe2/utils/math.h"
+
+namespace caffe2 {
+namespace {
+
+using t_tuple = std::tuple<Tensor, Tensor>;
+
+template <typename T>
+T copy_ctor(const T& x) {
+  return x;
+}
+
+template <>
+Tensor copy_ctor(const Tensor& X) {
+  return X.UnsafeSharedInstance();
+}
+
+template <>
+t_tuple copy_ctor(const t_tuple& X) {
+  return std::make_tuple(copy_ctor(std::get<0>(X)), copy_ctor(std::get<1>(X)));
+}
+
+template <>
+std::pair<t_tuple, t_tuple> copy_ctor(const std::pair<t_tuple, t_tuple>& X) {
+  return std::make_pair(copy_ctor(X.first), copy_ctor(X.second));
+}
+
+template <>
+std::vector<Tensor> copy_ctor(const std::vector<Tensor>& X) {
+  std::vector<Tensor> Y(X.size());
+  std::transform(X.begin(), X.end(), Y.begin(), [](const Tensor& x) {
+    return copy_ctor(x);
+  });
+  return Y;
+}
+
+template <>
+std::vector<t_tuple> copy_ctor(const std::vector<t_tuple>& X) {
+  std::vector<t_tuple> Y(X.size());
+  std::transform(X.begin(), X.end(), Y.begin(), [](const t_tuple& x) {
+    return copy_ctor(x);
+  });
+  return Y;
+}
+
+template <>
+std::vector<std::pair<t_tuple, t_tuple>> copy_ctor(
+    const std::vector<std::pair<t_tuple, t_tuple>>& X) {
+  std::vector<std::pair<t_tuple, t_tuple>> Y(X.size());
+  std::transform(
+      X.begin(), X.end(), Y.begin(), [](const std::pair<t_tuple, t_tuple>& x) {
+        return copy_ctor(x);
+      });
+  return Y;
+}
+
+// Gathers every two elements of a vector in a vector of pairs
+template <typename T>
+static std::vector<std::pair<T, T>> pair_vec(const std::vector<T>& vals) {
+  CAFFE_ENFORCE_EQ(
+      vals.size() % 2,
+      0,
+      "Odd number of params or hiddens given to a bidirectional RNN");
+  std::vector<std::pair<T, T>> result;
+  result.reserve(vals.size() / 2);
+  for (int64_t i = 0; i < vals.size(); i += 2) {
+    result.emplace_back(copy_ctor(vals[i]), copy_ctor(vals[i + 1]));
+  }
+  return result;
+}
+
+// Flattens a vector of pairs
+template <typename T>
+static std::vector<T> unpair_vec(std::vector<std::pair<T, T>>&& vals) {
+  std::vector<T> result;
+  result.reserve(vals.size() * 2);
+  for (int64_t i = 0; i < vals.size(); i++) {
+    result.push_back(std::move(vals[i].first));
+    result.push_back(std::move(vals[i].second));
+  }
+  return result;
+}
+
+Tensor matmul(const Tensor& X, const Tensor& W, CPUContext* context) {
+  const auto canonical_axis = X.canonical_axis_index(1);
+  const auto M = X.size_to_dim(canonical_axis);
+  const auto K = X.size_from_dim(canonical_axis);
+  const auto canonical_axis_w = W.canonical_axis_index(1);
+  const int N = W.size_to_dim(canonical_axis_w);
+  auto output_size = X.sizes().vec();
+  output_size.resize(canonical_axis + 1);
+  output_size[canonical_axis] = N;
+  Tensor C(output_size, CPU);
+  math::Gemm<float, CPUContext>(
+      CblasNoTrans,
+      CblasTrans,
+      M,
+      N,
+      K,
+      1,
+      X.template data<float>(),
+      W.template data<float>(),
+      0,
+      C.template mutable_data<float>(),
+      context);
+  return C;
+}
+
+Tensor
+linear(const Tensor& X, const Tensor& W, const Tensor& B, CPUContext* context) {
+  auto output = matmul(X, W, context);
+  if (B) {
+    const auto canonical_axis = X.canonical_axis_index(1);
+    const auto M = X.size_to_dim(canonical_axis);
+    const auto canonical_axis_w = W.canonical_axis_index(1);
+    const int N = W.size_to_dim(canonical_axis_w);
+    auto bias_multiplier_ = caffe2::empty({M}, CPU);
+    math::Set<float, CPUContext>(
+        M, 1, bias_multiplier_.template mutable_data<float>(), context);
+    math::Gemm<float, CPUContext>(
+        CblasNoTrans,
+        CblasNoTrans,
+        M,
+        N,
+        1,
+        1,
+        bias_multiplier_.template data<float>(),
+        B.template data<float>(),
+        1,
+        output.template mutable_data<float>(),
+        context);
+  }
+  return output;
+}
+
+std::vector<Tensor>
+chunk(const Tensor& input, int chunks, int axis, CPUContext* context) {
+  int canonical_axis = input.canonical_axis_index(axis);
+  CAFFE_ENFORCE_LT(
+      canonical_axis, input.dim(), "Axis not in input ndim range.");
+  const int input_channels = input.dim32(canonical_axis);
+  CAFFE_ENFORCE_EQ(
+      input_channels % chunks,
+      0,
+      "input channels should be divisible by the number of chunks.");
+  auto split_size = input_channels / chunks;
+  vector<int64_t> output_dims(input.sizes().vec());
+  int before = 1, after = 1;
+  for (int i = 0; i < canonical_axis; ++i) {
+    before *= input.dim32(i);
+  }
+  for (int i = canonical_axis + 1; i < input.dim(); ++i) {
+    after *= input.dim32(i);
+  }
+  size_t input_offset = 0;
+  std::vector<Tensor> outputs;
+  for (int i = 0; i < chunks; ++i) {
+    auto axis_dim = split_size;
+    output_dims[canonical_axis] = split_size;
+    Tensor output(output_dims, CPU);
+    math::CopyMatrix<CPUContext>(
+        input.itemsize(),
+        before,
+        axis_dim * after,
+        static_cast<const char*>(input.raw_data()) + input_offset,
+        input.dim32(canonical_axis) * after,
+        output.raw_mutable_data(input.dtype()),
+        axis_dim * after,
+        context,
+        input.dtype().copy());
+    input_offset += axis_dim * after * input.itemsize();
+    outputs.push_back(std::move(output));
+  }
+  return outputs;
+}
+
+std::vector<Tensor> unbind(const Tensor& input, int axis, CPUContext* context) {
+  // 1 - Chunk the input tensor along the given axis into N chunks where
+  // N is the dim(axis)
+  auto chunks = chunk(input, input.sizes()[axis], axis, context);
+  // 2 - Compute new dimensions
+  std::vector<int64_t> newDims = input.sizes().vec();
+  newDims.erase(newDims.begin() + axis);
+
+  // 3 - Reshape chunks to drop the extra dimension
+  for (int i = 0; i < chunks.size(); i++) {
+    CAFFE_ENFORCE_EQ(
+        chunks[i].sizes()[axis], 1, "Got an unexpected chunk size");
+    chunks[i].Reshape(newDims);
+  }
+  return chunks;
+}
+
+Tensor
+cat(const std::vector<Tensor>& tensorList, int axis, CPUContext* context) {
+  // Adopted from C2's concat operator
+  auto input_zero = copy_ctor(tensorList.at(0));
+  vector<int64_t> outputDims(input_zero.sizes().vec());
+  CAFFE_ENFORCE(outputDims.size() > 0);
+  for (int i = 1; i < tensorList.size(); i++) {
+    CAFFE_ENFORCE(input_zero.dtype() == tensorList.at(i).dtype());
+    outputDims[axis] += tensorList.at(i).sizes()[axis];
+  }
+  auto output_channels = outputDims[axis];
+  Tensor output(outputDims, CPU);
+  int before = 1, after = 1;
+  for (int i = 0; i < tensorList.at(0).dim(); ++i) {
+    if (i == axis) {
+      continue;
+    }
+    int dim = input_zero.dim32(i);
+    if (i < axis) {
+      before *= dim;
+    } else {
+      after *= dim;
+    }
+  }
+  size_t output_offset = 0;
+  for (const auto& input : tensorList) {
+    auto axis_dim = input.dim32(axis);
+    math::CopyMatrix<CPUContext>(
+        input.itemsize(),
+        before,
+        axis_dim * after,
+        input.raw_data(),
+        axis_dim * after,
+        static_cast<char*>(output.raw_mutable_data(input_zero.dtype())) +
+            output_offset,
+        output_channels * after,
+        context,
+        input_zero.dtype().copy());
+    output_offset += axis_dim * after * input.itemsize();
+  }
+
+  return output;
+}
+
+Tensor
+stack(const std::vector<Tensor>& tensorList, int axis, CPUContext* context) {
+  // 1 - Compute new dimensions
+  std::vector<int64_t> newDims(tensorList[0].sizes().vec());
+  std::vector<Tensor> expandedTensorList;
+  newDims.insert(newDims.begin() + axis, 1);
+  for (int i = 0; i < tensorList.size(); i++) {
+    expandedTensorList.emplace_back(tensorList[i].Clone());
+    expandedTensorList.at(i).Reshape(newDims);
+  }
+  return cat(expandedTensorList, axis, context);
+}
+
+Tensor sigmoid(const Tensor& X) {
+  Tensor Y(X.sizes(), CPU);
+  auto N = X.numel();
+  EigenVectorArrayMap<float>(Y.template mutable_data<float>(), N) = 1.0 /
+      (1.0 +
+       (-ConstEigenVectorArrayMap<float>(X.template data<float>(), N)).exp());
+  return Y;
+}
+
+Tensor tanh(const Tensor& X, CPUContext* context) {
+  Tensor Y(X.sizes(), CPU);
+  math::Tanh<float, CPUContext>(
+      X.numel(),
+      X.template data<float>(),
+      Y.template mutable_data<float>(),
+      context);
+  return Y;
+}
+
+Tensor add(const Tensor& X, const Tensor& Y, CPUContext* context) {
+  Tensor Z(X.sizes().vec(), CPU);
+  math::Add<float, CPUContext>(
+      X.numel(),
+      X.template data<float>(),
+      Y.template data<float>(),
+      Z.template mutable_data<float>(),
+      context);
+  return Z;
+}
+
+Tensor mul(const Tensor& X, const Tensor& Y, CPUContext* context) {
+  Tensor Z(X.sizes().vec(), CPU);
+  math::Mul<float, CPUContext>(
+      X.numel(),
+      X.template data<float>(),
+      Y.template data<float>(),
+      Z.template mutable_data<float>(),
+      context);
+  return Z;
+}
+
+Tensor transpose(const Tensor& X, int dim0, int dim1, CPUContext* context) {
+  int ndim = X.dim();
+  CAFFE_ENFORCE(ndim > dim0 && ndim > dim1, "Invalid transpose dimensions");
+  std::vector<int> axes(ndim);
+  std::iota(axes.begin(), axes.end(), 0);
+  std::swap(axes[dim0], axes[dim1]);
+  std::vector<int> Y_dims(ndim);
+  std::vector<int> X_dims(X.sizes().cbegin(), X.sizes().cend());
+  for (int i = 0; i < ndim; ++i) {
+    Y_dims[i] = X_dims[axes[i]];
+  }
+  Tensor Y(Y_dims, CPU);
+  math::Transpose<float, CPUContext>(
+      ndim,
+      X_dims.data(),
+      axes.data(),
+      X.template data<float>(),
+      Y.template mutable_data<float>(),
+      context);
+  return Y;
+}
+} // namespace
+} // namespace caffe2
diff --git a/caffe2/python/test/inference_lstm_op_test.py b/caffe2/python/test/inference_lstm_op_test.py
new file mode 100644 (file)
index 0000000..dc33f52
--- /dev/null
@@ -0,0 +1,72 @@
+#!/usr/bin/env python3
+import inspect
+
+import hypothesis.strategies as st
+import numpy as np
+import torch
+from caffe2.python import core, workspace
+from caffe2.python.test_util import TestCase
+from hypothesis import given
+from torch import nn
+
+
+class TestC2LSTM(TestCase):
+    @given(
+        bsz=st.integers(1, 5),
+        seq_lens=st.integers(1, 6),
+        emb_lens=st.integers(5, 10),
+        hidden_size=st.integers(3, 7),
+        num_layers=st.integers(1, 4),
+        has_biases=st.booleans(),
+        is_bidirectional=st.booleans(),
+        batch_first=st.booleans(),
+    )
+    def test_c2_lstm(
+        self,
+        bsz,
+        seq_lens,
+        emb_lens,
+        hidden_size,
+        num_layers,
+        has_biases,
+        is_bidirectional,
+        batch_first,
+    ):
+        net = core.Net("test_net")
+        num_directions = 2 if is_bidirectional else 1
+        py_lstm = nn.LSTM(
+            emb_lens,
+            hidden_size,
+            batch_first=batch_first,
+            bidirectional=is_bidirectional,
+            bias=has_biases,
+            num_layers=num_layers,
+        )
+
+        hx = np.zeros((num_layers * num_directions, bsz, hidden_size), dtype=np.float32)
+
+        if batch_first:
+            inputs = np.random.randn(bsz, seq_lens, emb_lens).astype(np.float32)
+        else:
+            inputs = np.random.randn(seq_lens, bsz, emb_lens).astype(np.float32)
+
+        py_results = py_lstm(torch.from_numpy(inputs))
+        lstm_in = [
+            torch.from_numpy(inputs),
+            torch.from_numpy(hx),
+            torch.from_numpy(hx),
+        ] + [param.detach() for param in py_lstm._flat_weights]
+
+        c2_results = torch.ops._caffe2.InferenceLSTM(
+            lstm_in, num_layers, has_biases, batch_first, is_bidirectional
+        )
+
+        np.testing.assert_array_almost_equal(
+            py_results[0].detach().numpy(), c2_results[0].detach().numpy()
+        )
+        np.testing.assert_array_almost_equal(
+            py_results[1][0].detach().numpy(), c2_results[1].detach().numpy()
+        )
+        np.testing.assert_array_almost_equal(
+            py_results[1][1].detach().numpy(), c2_results[2].detach().numpy()
+        )