Add LSTMLayer and LSTMUnitLayer, with tests
author Jeff Donahue <jeff.donahue@gmail.com>
Tue, 5 Apr 2016 16:56:04 +0000 (09:56 -0700)
committer Jeff Donahue <jeff.donahue@gmail.com>
Wed, 1 Jun 2016 22:29:52 +0000 (15:29 -0700)
include/caffe/layers/lstm_layer.hpp [new file with mode: 0644]
src/caffe/layers/lstm_layer.cpp [new file with mode: 0644]
src/caffe/layers/lstm_unit_layer.cpp [new file with mode: 0644]
src/caffe/layers/lstm_unit_layer.cu [new file with mode: 0644]
src/caffe/test/test_lstm_layer.cpp [new file with mode: 0644]

diff --git a/include/caffe/layers/lstm_layer.hpp b/include/caffe/layers/lstm_layer.hpp
new file mode 100644 (file)
index 0000000..a0e67c9
--- /dev/null
@@ -0,0 +1,154 @@
+#ifndef CAFFE_LSTM_LAYER_HPP_
+#define CAFFE_LSTM_LAYER_HPP_
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/layers/recurrent_layer.hpp"
+#include "caffe/net.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+namespace caffe {
+
+template <typename Dtype> class RecurrentLayer;
+
+/**
+ * @brief Processes sequential inputs using a "Long Short-Term Memory" (LSTM)
+ *        [1] style recurrent neural network (RNN). Implemented by unrolling
+ *        the LSTM computation through time.
+ *
+ * The specific architecture used in this implementation is as described in
+ * "Learning to Execute" [2], reproduced below:
+ *     i_t := \sigmoid[ W_{hi} * h_{t-1} + W_{xi} * x_t + b_i ]
+ *     f_t := \sigmoid[ W_{hf} * h_{t-1} + W_{xf} * x_t + b_f ]
+ *     o_t := \sigmoid[ W_{ho} * h_{t-1} + W_{xo} * x_t + b_o ]
+ *     g_t :=    \tanh[ W_{hg} * h_{t-1} + W_{xg} * x_t + b_g ]
+ *     c_t := (f_t .* c_{t-1}) + (i_t .* g_t)
+ *     h_t := o_t .* \tanh[c_t]
+ * In the implementation, the i, f, o, and g computations are performed as a
+ * single inner product.
+ *
+ * Notably, this implementation lacks the "diagonal" gates, as used in the
+ * LSTM architectures described by Alex Graves [3] and others.
+ *
+ * [1] Hochreiter, Sepp, and Schmidhuber, Jürgen. "Long short-term memory."
+ *     Neural Computation 9, no. 8 (1997): 1735-1780.
+ *
+ * [2] Zaremba, Wojciech, and Sutskever, Ilya. "Learning to execute."
+ *     arXiv preprint arXiv:1410.4615 (2014).
+ *
+ * [3] Graves, Alex. "Generating sequences with recurrent neural networks."
+ *     arXiv preprint arXiv:1308.0850 (2013).
+ */
+template <typename Dtype>
+class LSTMLayer : public RecurrentLayer<Dtype> {
+ public:
+  explicit LSTMLayer(const LayerParameter& param)
+      : RecurrentLayer<Dtype>(param) {}
+
+  virtual inline const char* type() const { return "LSTM"; }
+
+ protected:
+  virtual void FillUnrolledNet(NetParameter* net_param) const;
+  virtual void RecurrentInputBlobNames(vector<string>* names) const;
+  virtual void RecurrentOutputBlobNames(vector<string>* names) const;
+  virtual void RecurrentInputShapes(vector<BlobShape>* shapes) const;
+  virtual void OutputBlobNames(vector<string>* names) const;
+};
+
+/**
+ * @brief A helper for LSTMLayer: computes a single timestep of the
+ *        non-linearity of the LSTM, producing the updated cell and hidden
+ *        states.
+ */
+template <typename Dtype>
+class LSTMUnitLayer : public Layer<Dtype> {
+ public:
+  explicit LSTMUnitLayer(const LayerParameter& param)
+      : Layer<Dtype>(param) {}
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+
+  virtual inline const char* type() const { return "LSTMUnit"; }
+  virtual inline int ExactNumBottomBlobs() const { return 3; }
+  virtual inline int ExactNumTopBlobs() const { return 2; }
+
+  virtual inline bool AllowForceBackward(const int bottom_index) const {
+    // Can't propagate to sequence continuation indicators.
+    return bottom_index != 2;
+  }
+
+ protected:
+  /**
+   * @param bottom input Blob vector (length 3)
+   *   -# @f$ (1 \times N \times D) @f$
+   *      the previous timestep cell state @f$ c_{t-1} @f$
+   *   -# @f$ (1 \times N \times 4D) @f$
+   *      the "gate inputs" @f$ [i_t', f_t', o_t', g_t'] @f$
+   *   -# @f$ (1 \times N) @f$
+   *      the sequence continuation indicators  @f$ \delta_t @f$
+   * @param top output Blob vector (length 2)
+   *   -# @f$ (1 \times N \times D) @f$
+   *      the updated cell state @f$ c_t @f$, computed as:
+   *          i_t := \sigmoid[i_t']
+   *          f_t := \sigmoid[f_t']
+   *          o_t := \sigmoid[o_t']
+   *          g_t := \tanh[g_t']
+   *          c_t := cont_t * (f_t .* c_{t-1}) + (i_t .* g_t)
+   *   -# @f$ (1 \times N \times D) @f$
+   *      the updated hidden state @f$ h_t @f$, computed as:
+   *          h_t := o_t .* \tanh[c_t]
+   */
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+
+  /**
+   * @brief Computes the error gradient w.r.t. the LSTMUnit inputs.
+   *
+   * @param top output Blob vector (length 2), providing the error gradient with
+   *        respect to the outputs
+   *   -# @f$ (1 \times N \times D) @f$:
+   *      containing error gradients @f$ \frac{\partial E}{\partial c_t} @f$
+   *      with respect to the updated cell state @f$ c_t @f$
+   *   -# @f$ (1 \times N \times D) @f$:
+   *      containing error gradients @f$ \frac{\partial E}{\partial h_t} @f$
+   *      with respect to the updated hidden state @f$ h_t @f$
+   * @param propagate_down see Layer::Backward.
+   * @param bottom input Blob vector (length 3), into which the error gradients
+   *        with respect to the LSTMUnit inputs @f$ c_{t-1} @f$ and the gate
+   *        inputs are computed.  Computation of the error gradients w.r.t.
+   *        the sequence indicators is not implemented.
+   *   -# @f$ (1 \times N \times D) @f$
+   *      the error gradient w.r.t. the previous timestep cell state
+   *      @f$ c_{t-1} @f$
+   *   -# @f$ (1 \times N \times 4D) @f$
+   *      the error gradient w.r.t. the "gate inputs"
+   *      @f$ [
+   *          \frac{\partial E}{\partial i_t'}
+   *          \frac{\partial E}{\partial f_t'}
+   *          \frac{\partial E}{\partial o_t'}
+   *          \frac{\partial E}{\partial g_t'}
+   *          ] @f$
+   *   -# @f$ (1 \times N) @f$
+   *      the gradient w.r.t. the sequence continuation indicators
+   *      @f$ \delta_t @f$ is currently not computed.
+   */
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+
+  /// @brief The hidden and output dimension.
+  int hidden_dim_;
+  Blob<Dtype> X_acts_;
+};
+
+}  // namespace caffe
+
+#endif  // CAFFE_LSTM_LAYER_HPP_
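
For reference, here is a minimal standalone C++ sketch (not part of the patch) of one LSTMUnit forward step. It mirrors the equations documented in the class comment above, assuming the [i', f', o', g'] packing of gate pre-activations along the last axis and a binary cont indicator; names and shapes are illustrative only.

    // Illustrative single-timestep LSTMUnit forward reference (not Caffe code).
    #include <cmath>
    #include <cstdio>
    #include <vector>

    static double sigmoid(double x) { return 1.0 / (1.0 + std::exp(-x)); }

    // c_prev: N*D, x: N*4D gate pre-activations, cont: N entries in {0, 1};
    // outputs c and h are each N*D.
    void LSTMUnitForwardRef(int N, int D,
                            const std::vector<double>& c_prev,
                            const std::vector<double>& x,
                            const std::vector<double>& cont,
                            std::vector<double>* c, std::vector<double>* h) {
      for (int n = 0; n < N; ++n) {
        for (int d = 0; d < D; ++d) {
          const double i = sigmoid(x[n * 4 * D + 0 * D + d]);
          // cont == 0 zeroes the forget gate, flushing the previous cell state.
          const double f = cont[n] * sigmoid(x[n * 4 * D + 1 * D + d]);
          const double o = sigmoid(x[n * 4 * D + 2 * D + d]);
          const double g = std::tanh(x[n * 4 * D + 3 * D + d]);
          const double c_t = f * c_prev[n * D + d] + i * g;
          (*c)[n * D + d] = c_t;
          (*h)[n * D + d] = o * std::tanh(c_t);
        }
      }
    }

    int main() {
      const int N = 2, D = 3;
      std::vector<double> c_prev(N * D, 0.5), x(N * 4 * D, 0.1), cont(N, 1.0);
      cont[1] = 0.0;  // the second instance starts a new sequence here
      std::vector<double> c(N * D), h(N * D);
      LSTMUnitForwardRef(N, D, c_prev, x, cont, &c, &h);
      std::printf("c[0]=%f h[0]=%f  c[D]=%f h[D]=%f\n", c[0], h[0], c[D], h[D]);
      return 0;
    }
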
diff --git a/src/caffe/layers/lstm_layer.cpp b/src/caffe/layers/lstm_layer.cpp
new file mode 100644 (file)
index 0000000..da48dba
--- /dev/null
@@ -0,0 +1,244 @@
+#include <string>
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/layers/lstm_layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void LSTMLayer<Dtype>::RecurrentInputBlobNames(vector<string>* names) const {
+  names->resize(2);
+  (*names)[0] = "h_0";
+  (*names)[1] = "c_0";
+}
+
+template <typename Dtype>
+void LSTMLayer<Dtype>::RecurrentOutputBlobNames(vector<string>* names) const {
+  names->resize(2);
+  (*names)[0] = "h_" + format_int(this->T_);
+  (*names)[1] = "c_T";
+}
+
+template <typename Dtype>
+void LSTMLayer<Dtype>::RecurrentInputShapes(vector<BlobShape>* shapes) const {
+  const int num_output = this->layer_param_.recurrent_param().num_output();
+  const int num_blobs = 2;
+  shapes->resize(num_blobs);
+  for (int i = 0; i < num_blobs; ++i) {
+    (*shapes)[i].Clear();
+    (*shapes)[i].add_dim(1);  // a single timestep
+    (*shapes)[i].add_dim(this->N_);
+    (*shapes)[i].add_dim(num_output);
+  }
+}
+
+template <typename Dtype>
+void LSTMLayer<Dtype>::OutputBlobNames(vector<string>* names) const {
+  names->resize(1);
+  (*names)[0] = "h";
+}
+
+template <typename Dtype>
+void LSTMLayer<Dtype>::FillUnrolledNet(NetParameter* net_param) const {
+  const int num_output = this->layer_param_.recurrent_param().num_output();
+  CHECK_GT(num_output, 0) << "num_output must be positive";
+  const FillerParameter& weight_filler =
+      this->layer_param_.recurrent_param().weight_filler();
+  const FillerParameter& bias_filler =
+      this->layer_param_.recurrent_param().bias_filler();
+
+  // Add generic LayerParameters (without bottoms/tops) for the layer types
+  // we'll use, to avoid writing redundant code.
+  LayerParameter hidden_param;
+  hidden_param.set_type("InnerProduct");
+  hidden_param.mutable_inner_product_param()->set_num_output(num_output * 4);
+  hidden_param.mutable_inner_product_param()->set_bias_term(false);
+  hidden_param.mutable_inner_product_param()->set_axis(2);
+  hidden_param.mutable_inner_product_param()->
+      mutable_weight_filler()->CopyFrom(weight_filler);
+
+  LayerParameter biased_hidden_param(hidden_param);
+  biased_hidden_param.mutable_inner_product_param()->set_bias_term(true);
+  biased_hidden_param.mutable_inner_product_param()->
+      mutable_bias_filler()->CopyFrom(bias_filler);
+
+  LayerParameter sum_param;
+  sum_param.set_type("Eltwise");
+  sum_param.mutable_eltwise_param()->set_operation(
+      EltwiseParameter_EltwiseOp_SUM);
+
+  LayerParameter scale_param;
+  scale_param.set_type("Scale");
+  scale_param.mutable_scale_param()->set_axis(0);
+
+  LayerParameter slice_param;
+  slice_param.set_type("Slice");
+  slice_param.mutable_slice_param()->set_axis(0);
+
+  LayerParameter split_param;
+  split_param.set_type("Split");
+
+  vector<BlobShape> input_shapes;
+  RecurrentInputShapes(&input_shapes);
+  CHECK_EQ(2, input_shapes.size());
+
+  LayerParameter* input_layer_param = net_param->add_layer();
+  input_layer_param->set_type("Input");
+  InputParameter* input_param = input_layer_param->mutable_input_param();
+
+  input_layer_param->add_top("c_0");
+  input_param->add_shape()->CopyFrom(input_shapes[0]);
+
+  input_layer_param->add_top("h_0");
+  input_param->add_shape()->CopyFrom(input_shapes[1]);
+
+  LayerParameter* cont_slice_param = net_param->add_layer();
+  cont_slice_param->CopyFrom(slice_param);
+  cont_slice_param->set_name("cont_slice");
+  cont_slice_param->add_bottom("cont");
+  cont_slice_param->mutable_slice_param()->set_axis(0);
+
+  // Add layer to transform all timesteps of x to the hidden state dimension.
+  //     W_xc_x = W_xc * x + b_c
+  {
+    LayerParameter* x_transform_param = net_param->add_layer();
+    x_transform_param->CopyFrom(biased_hidden_param);
+    x_transform_param->set_name("x_transform");
+    x_transform_param->add_param()->set_name("W_xc");
+    x_transform_param->add_param()->set_name("b_c");
+    x_transform_param->add_bottom("x");
+    x_transform_param->add_top("W_xc_x");
+    x_transform_param->add_propagate_down(true);
+  }
+
+  if (this->static_input_) {
+    // Add layer to transform x_static to the gate dimension.
+    //     W_xc_x_static = W_xc_static * x_static
+    LayerParameter* x_static_transform_param = net_param->add_layer();
+    x_static_transform_param->CopyFrom(hidden_param);
+    x_static_transform_param->mutable_inner_product_param()->set_axis(1);
+    x_static_transform_param->set_name("W_xc_x_static");
+    x_static_transform_param->add_param()->set_name("W_xc_static");
+    x_static_transform_param->add_bottom("x_static");
+    x_static_transform_param->add_top("W_xc_x_static_preshape");
+    x_static_transform_param->add_propagate_down(true);
+
+    LayerParameter* reshape_param = net_param->add_layer();
+    reshape_param->set_type("Reshape");
+    BlobShape* new_shape =
+         reshape_param->mutable_reshape_param()->mutable_shape();
+    new_shape->add_dim(1);  // One timestep.
+    // Infer this dimension (-1) so the reshape works for any batch size (this->N_).
+    new_shape->add_dim(-1);
+    new_shape->add_dim(
+        x_static_transform_param->inner_product_param().num_output());
+    reshape_param->set_name("W_xc_x_static_reshape");
+    reshape_param->add_bottom("W_xc_x_static_preshape");
+    reshape_param->add_top("W_xc_x_static");
+  }
+
+  LayerParameter* x_slice_param = net_param->add_layer();
+  x_slice_param->CopyFrom(slice_param);
+  x_slice_param->add_bottom("W_xc_x");
+  x_slice_param->set_name("W_xc_x_slice");
+
+  LayerParameter output_concat_layer;
+  output_concat_layer.set_name("h_concat");
+  output_concat_layer.set_type("Concat");
+  output_concat_layer.add_top("h");
+  output_concat_layer.mutable_concat_param()->set_axis(0);
+
+  for (int t = 1; t <= this->T_; ++t) {
+    string tm1s = format_int(t - 1);
+    string ts = format_int(t);
+
+    cont_slice_param->add_top("cont_" + ts);
+    x_slice_param->add_top("W_xc_x_" + ts);
+
+    // Add layers to flush the hidden state when beginning a new
+    // sequence, as indicated by cont_t.
+    //     h_conted_{t-1} := cont_t * h_{t-1}
+    //
+    // Normally, cont_t is binary (i.e., 0 or 1), so:
+    //     h_conted_{t-1} := h_{t-1} if cont_t == 1
+    //                       0   otherwise
+    {
+      LayerParameter* cont_h_param = net_param->add_layer();
+      cont_h_param->CopyFrom(scale_param);
+      cont_h_param->set_name("h_conted_" + tm1s);
+      cont_h_param->add_bottom("h_" + tm1s);
+      cont_h_param->add_bottom("cont_" + ts);
+      cont_h_param->add_top("h_conted_" + tm1s);
+    }
+
+    // Add layer to compute
+    //     W_hc_h_{t-1} := W_hc * h_conted_{t-1}
+    {
+      LayerParameter* w_param = net_param->add_layer();
+      w_param->CopyFrom(hidden_param);
+      w_param->set_name("transform_" + ts);
+      w_param->add_param()->set_name("W_hc");
+      w_param->add_bottom("h_conted_" + tm1s);
+      w_param->add_top("W_hc_h_" + tm1s);
+      w_param->mutable_inner_product_param()->set_axis(2);
+    }
+
+    // Add the outputs of the linear transformations to compute the gate input.
+    //     gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c
+    //                   = W_hc_h_{t-1} + W_xc_x_t + b_c
+    {
+      LayerParameter* input_sum_layer = net_param->add_layer();
+      input_sum_layer->CopyFrom(sum_param);
+      input_sum_layer->set_name("gate_input_" + ts);
+      input_sum_layer->add_bottom("W_hc_h_" + tm1s);
+      input_sum_layer->add_bottom("W_xc_x_" + ts);
+      if (this->static_input_) {
+        input_sum_layer->add_bottom("W_xc_x_static");
+      }
+      input_sum_layer->add_top("gate_input_" + ts);
+    }
+
+    // Add LSTMUnit layer to compute the cell & hidden vectors c_t and h_t.
+    // Inputs: c_{t-1}, gate_input_t = (i_t', f_t', o_t', g_t'), cont_t
+    // Outputs: c_t, h_t
+    //     [ i_t' ]
+    //     [ f_t' ] := gate_input_t
+    //     [ o_t' ]
+    //     [ g_t' ]
+    //         i_t := \sigmoid[i_t']
+    //         f_t := \sigmoid[f_t']
+    //         o_t := \sigmoid[o_t']
+    //         g_t := \tanh[g_t']
+    //         c_t := cont_t * (f_t .* c_{t-1}) + (i_t .* g_t)
+    //         h_t := o_t .* \tanh[c_t]
+    {
+      LayerParameter* lstm_unit_param = net_param->add_layer();
+      lstm_unit_param->set_type("LSTMUnit");
+      lstm_unit_param->add_bottom("c_" + tm1s);
+      lstm_unit_param->add_bottom("gate_input_" + ts);
+      lstm_unit_param->add_bottom("cont_" + ts);
+      lstm_unit_param->add_top("c_" + ts);
+      lstm_unit_param->add_top("h_" + ts);
+      lstm_unit_param->set_name("unit_" + ts);
+    }
+    output_concat_layer.add_bottom("h_" + ts);
+  }  // for (int t = 1; t <= this->T_; ++t)
+
+  {
+    LayerParameter* c_T_copy_param = net_param->add_layer();
+    c_T_copy_param->CopyFrom(split_param);
+    c_T_copy_param->add_bottom("c_" + format_int(this->T_));
+    c_T_copy_param->add_top("c_T");
+  }
+  net_param->add_layer()->CopyFrom(output_concat_layer);
+}
+
+INSTANTIATE_CLASS(LSTMLayer);
+REGISTER_LAYER_CLASS(LSTM);
+
+}  // namespace caffe
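
For reference, the sketch below (not part of the patch, modeled on the accompanying test) shows how a caller might configure and run the unrolled layer directly from C++. The bottoms are x with shape (T, N, ...) and the continuation indicators cont with shape (T, N); the single top h comes out as (T, N, num_output). SetUp builds the unrolled net described by FillUnrolledNet above, so W_xc, b_c, and W_hc become the layer's learnable blobs. The concrete sizes, filler settings, and the standalone main() are illustrative assumptions.

    // Illustrative LSTM layer usage sketch (not part of the patch).
    #include <vector>

    #include "caffe/blob.hpp"
    #include "caffe/common.hpp"
    #include "caffe/filler.hpp"
    #include "caffe/layers/lstm_layer.hpp"

    using namespace caffe;  // NOLINT(build/namespaces)

    int main() {
      Caffe::set_mode(Caffe::CPU);
      const int T = 4, N = 2, input_dim = 8, num_output = 16;

      // x: one feature vector per timestep and instance; cont: (T, N).
      Blob<float> x(T, N, 1, input_dim), h;
      std::vector<int> cont_shape(2);
      cont_shape[0] = T;
      cont_shape[1] = N;
      Blob<float> cont(cont_shape);

      FillerParameter filler_param;
      filler_param.set_min(-1);
      filler_param.set_max(1);
      UniformFiller<float> filler(filler_param);
      filler.Fill(&x);
      for (int t = 0; t < T; ++t) {
        for (int n = 0; n < N; ++n) {
          // 0 marks the start of a new sequence; 1 continues the sequence.
          cont.mutable_cpu_data()[t * N + n] = (t == 0) ? 0 : 1;
        }
      }

      LayerParameter param;
      param.set_phase(TEST);
      param.mutable_recurrent_param()->set_num_output(num_output);
      param.mutable_recurrent_param()->mutable_weight_filler()->set_type("uniform");
      param.mutable_recurrent_param()->mutable_weight_filler()->set_min(-0.08);
      param.mutable_recurrent_param()->mutable_weight_filler()->set_max(0.08);
      param.mutable_recurrent_param()->mutable_bias_filler()->set_type("constant");

      std::vector<Blob<float>*> bottom, top;
      bottom.push_back(&x);
      bottom.push_back(&cont);
      top.push_back(&h);

      LSTMLayer<float> layer(param);
      layer.SetUp(bottom, top);    // builds and initializes the unrolled net
      layer.Forward(bottom, top);  // h now holds (T, N, num_output) outputs
      return 0;
    }
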
diff --git a/src/caffe/layers/lstm_unit_layer.cpp b/src/caffe/layers/lstm_unit_layer.cpp
new file mode 100644 (file)
index 0000000..277c031
--- /dev/null
@@ -0,0 +1,131 @@
+#include <algorithm>
+#include <cmath>
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/layers/lstm_layer.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+inline Dtype sigmoid(Dtype x) {
+  return 1. / (1. + exp(-x));
+}
+
+template <typename Dtype>
+inline Dtype tanh(Dtype x) {
+  return 2. * sigmoid(2. * x) - 1.;
+}
+
+template <typename Dtype>
+void LSTMUnitLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const int num_instances = bottom[0]->shape(1);
+  for (int i = 0; i < bottom.size(); ++i) {
+    if (i == 2) {
+      CHECK_EQ(2, bottom[i]->num_axes());
+    } else {
+      CHECK_EQ(3, bottom[i]->num_axes());
+    }
+    CHECK_EQ(1, bottom[i]->shape(0));
+    CHECK_EQ(num_instances, bottom[i]->shape(1));
+  }
+  hidden_dim_ = bottom[0]->shape(2);
+  CHECK_EQ(num_instances, bottom[1]->shape(1));
+  CHECK_EQ(4 * hidden_dim_, bottom[1]->shape(2));
+  top[0]->ReshapeLike(*bottom[0]);
+  top[1]->ReshapeLike(*bottom[0]);
+  X_acts_.ReshapeLike(*bottom[1]);
+}
+
+template <typename Dtype>
+void LSTMUnitLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const int num = bottom[0]->shape(1);
+  const int x_dim = hidden_dim_ * 4;
+  const Dtype* C_prev = bottom[0]->cpu_data();
+  const Dtype* X = bottom[1]->cpu_data();
+  const Dtype* cont = bottom[2]->cpu_data();
+  Dtype* C = top[0]->mutable_cpu_data();
+  Dtype* H = top[1]->mutable_cpu_data();
+  for (int n = 0; n < num; ++n) {
+    for (int d = 0; d < hidden_dim_; ++d) {
+      const Dtype i = sigmoid(X[d]);
+      const Dtype f = (*cont == 0) ? 0 :
+          (*cont * sigmoid(X[1 * hidden_dim_ + d]));
+      const Dtype o = sigmoid(X[2 * hidden_dim_ + d]);
+      const Dtype g = tanh(X[3 * hidden_dim_ + d]);
+      const Dtype c_prev = C_prev[d];
+      const Dtype c = f * c_prev + i * g;
+      C[d] = c;
+      const Dtype tanh_c = tanh(c);
+      H[d] = o * tanh_c;
+    }
+    C_prev += hidden_dim_;
+    X += x_dim;
+    C += hidden_dim_;
+    H += hidden_dim_;
+    ++cont;
+  }
+}
+
+template <typename Dtype>
+void LSTMUnitLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+  CHECK(!propagate_down[2]) << "Cannot backpropagate to sequence indicators.";
+  if (!propagate_down[0] && !propagate_down[1]) { return; }
+
+  const int num = bottom[0]->shape(1);
+  const int x_dim = hidden_dim_ * 4;
+  const Dtype* C_prev = bottom[0]->cpu_data();
+  const Dtype* X = bottom[1]->cpu_data();
+  const Dtype* cont = bottom[2]->cpu_data();
+  const Dtype* C = top[0]->cpu_data();
+  const Dtype* H = top[1]->cpu_data();
+  const Dtype* C_diff = top[0]->cpu_diff();
+  const Dtype* H_diff = top[1]->cpu_diff();
+  Dtype* C_prev_diff = bottom[0]->mutable_cpu_diff();
+  Dtype* X_diff = bottom[1]->mutable_cpu_diff();
+  for (int n = 0; n < num; ++n) {
+    for (int d = 0; d < hidden_dim_; ++d) {
+      const Dtype i = sigmoid(X[d]);
+      const Dtype f = (*cont == 0) ? 0 :
+          (*cont * sigmoid(X[1 * hidden_dim_ + d]));
+      const Dtype o = sigmoid(X[2 * hidden_dim_ + d]);
+      const Dtype g = tanh(X[3 * hidden_dim_ + d]);
+      const Dtype c_prev = C_prev[d];
+      const Dtype c = C[d];
+      const Dtype tanh_c = tanh(c);
+      Dtype* c_prev_diff = C_prev_diff + d;
+      Dtype* i_diff = X_diff + d;
+      Dtype* f_diff = X_diff + 1 * hidden_dim_ + d;
+      Dtype* o_diff = X_diff + 2 * hidden_dim_ + d;
+      Dtype* g_diff = X_diff + 3 * hidden_dim_ + d;
+      const Dtype c_term_diff =
+          C_diff[d] + H_diff[d] * o * (1 - tanh_c * tanh_c);
+      *c_prev_diff = c_term_diff * f;
+      *i_diff = c_term_diff * g * i * (1 - i);
+      *f_diff = c_term_diff * c_prev * f * (1 - f);
+      *o_diff = H_diff[d] * tanh_c * o * (1 - o);
+      *g_diff = c_term_diff * i * (1 - g * g);
+    }
+    C_prev += hidden_dim_;
+    X += x_dim;
+    C += hidden_dim_;
+    H += hidden_dim_;
+    C_diff += hidden_dim_;
+    H_diff += hidden_dim_;
+    X_diff += x_dim;
+    C_prev_diff += hidden_dim_;
+    ++cont;
+  }
+}
+
+#ifdef CPU_ONLY
+STUB_GPU(LSTMUnitLayer);
+#endif
+
+INSTANTIATE_CLASS(LSTMUnitLayer);
+REGISTER_LAYER_CLASS(LSTMUnit);
+
+}  // namespace caffe
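
The gradients written by Backward_cpu follow from the chain rule applied to the forward equations, using sigma'(x) = sigma(x)(1 - sigma(x)) and tanh'(x) = 1 - tanh^2(x). A short derivation is sketched below (LaTeX, for reference only); \delta_{c_t} is the accumulated cell-state gradient that the code calls c_term_diff.

    \begin{align*}
      \delta_{c_t} &= \frac{\partial E}{\partial c_t}
          + \frac{\partial E}{\partial h_t}\, o_t \left(1 - \tanh^2 c_t\right) \\
      \frac{\partial E}{\partial c_{t-1}} &= \delta_{c_t}\, f_t \\
      \frac{\partial E}{\partial i_t'} &= \delta_{c_t}\, g_t\, i_t (1 - i_t) \\
      \frac{\partial E}{\partial f_t'} &= \delta_{c_t}\, c_{t-1}\, f_t (1 - f_t) \\
      \frac{\partial E}{\partial o_t'} &= \frac{\partial E}{\partial h_t}\,
          \tanh(c_t)\, o_t (1 - o_t) \\
      \frac{\partial E}{\partial g_t'} &= \delta_{c_t}\, i_t \left(1 - g_t^2\right)
    \end{align*}

Here f_t already includes the cont factor; for binary cont, f_t(1 - f_t) equals cont_t * sigma(f_t')(1 - sigma(f_t')), so a cont of 0 also zeroes the forget-gate gradient, matching the CPU code above.
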
diff --git a/src/caffe/layers/lstm_unit_layer.cu b/src/caffe/layers/lstm_unit_layer.cu
new file mode 100644 (file)
index 0000000..15bb451
--- /dev/null
@@ -0,0 +1,154 @@
+#include <algorithm>
+#include <cmath>
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/layers/lstm_layer.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+__device__ Dtype sigmoid(const Dtype x) {
+  return Dtype(1) / (Dtype(1) + exp(-x));
+}
+
+template <typename Dtype>
+__device__ Dtype tanh(const Dtype x) {
+  return Dtype(2) * sigmoid(Dtype(2) * x) - Dtype(1);
+}
+
+template <typename Dtype>
+__global__ void LSTMActsForward(const int nthreads, const int dim,
+                                const Dtype* X, Dtype* X_acts) {
+  CUDA_KERNEL_LOOP(index, nthreads) {
+    const int x_dim = 4 * dim;
+    const int d = index % x_dim;
+    if (d < 3 * dim) {
+      X_acts[index] = sigmoid(X[index]);
+    } else {
+      X_acts[index] = tanh(X[index]);
+    }
+  }
+}
+
+template <typename Dtype>
+__global__ void LSTMUnitForward(const int nthreads, const int dim,
+    const Dtype* C_prev, const Dtype* X, const Dtype* cont,
+    Dtype* C, Dtype* H) {
+  CUDA_KERNEL_LOOP(index, nthreads) {
+    const int n = index / dim;
+    const int d = index % dim;
+    const Dtype* X_offset = X + 4 * dim * n;
+    const Dtype i = X_offset[d];
+    const Dtype f = X_offset[1 * dim + d];
+    const Dtype o = X_offset[2 * dim + d];
+    const Dtype g = X_offset[3 * dim + d];
+    const Dtype c_prev = C_prev[index];
+    const Dtype c = cont[n] * f * c_prev + i * g;
+    C[index] = c;
+    const Dtype tanh_c = tanh(c);
+    H[index] = o * tanh_c;
+  }
+}
+
+template <typename Dtype>
+void LSTMUnitLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const int count = top[1]->count();
+  const Dtype* C_prev = bottom[0]->gpu_data();
+  const Dtype* X = bottom[1]->gpu_data();
+  const Dtype* cont = bottom[2]->gpu_data();
+  Dtype* X_acts = X_acts_.mutable_gpu_data();
+  Dtype* C = top[0]->mutable_gpu_data();
+  Dtype* H = top[1]->mutable_gpu_data();
+  const int X_count = bottom[1]->count();
+  // NOLINT_NEXT_LINE(whitespace/operators)
+  LSTMActsForward<Dtype><<<CAFFE_GET_BLOCKS(X_count), CAFFE_CUDA_NUM_THREADS>>>(
+      X_count, hidden_dim_, X, X_acts);
+  CUDA_POST_KERNEL_CHECK;
+  // NOLINT_NEXT_LINE(whitespace/operators)
+  LSTMUnitForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+      count, hidden_dim_, C_prev, X_acts, cont, C, H);
+  CUDA_POST_KERNEL_CHECK;
+}
+
+template <typename Dtype>
+__global__ void LSTMUnitBackward(const int nthreads, const int dim,
+    const Dtype* C_prev, const Dtype* X, const Dtype* C, const Dtype* H,
+    const Dtype* cont, const Dtype* C_diff, const Dtype* H_diff,
+    Dtype* C_prev_diff, Dtype* X_diff) {
+  CUDA_KERNEL_LOOP(index, nthreads) {
+    const int n = index / dim;
+    const int d = index % dim;
+    const Dtype* X_offset = X + 4 * dim * n;
+    const Dtype i = X_offset[d];
+    const Dtype f = X_offset[1 * dim + d];
+    const Dtype o = X_offset[2 * dim + d];
+    const Dtype g = X_offset[3 * dim + d];
+    const Dtype c_prev = C_prev[index];
+    const Dtype c = C[index];
+    const Dtype tanh_c = tanh(c);
+    Dtype* c_prev_diff = C_prev_diff + index;
+    Dtype* X_diff_offset = X_diff + 4 * dim * n;
+    Dtype* i_diff = X_diff_offset + d;
+    Dtype* f_diff = X_diff_offset + 1 * dim + d;
+    Dtype* o_diff = X_diff_offset + 2 * dim + d;
+    Dtype* g_diff = X_diff_offset + 3 * dim + d;
+    const Dtype c_term_diff =
+        C_diff[index] + H_diff[index] * o * (1 - tanh_c * tanh_c);
+    const Dtype cont_n = cont[n];
+    *c_prev_diff = cont_n * c_term_diff * f;
+    *i_diff = c_term_diff * g;
+    *f_diff = cont_n * c_term_diff * c_prev;
+    *o_diff = H_diff[index] * tanh_c;
+    *g_diff = c_term_diff * i;
+  }
+}
+
+template <typename Dtype>
+__global__ void LSTMActsBackward(const int nthreads, const int dim,
+    const Dtype* X_acts, const Dtype* X_acts_diff, Dtype* X_diff) {
+  CUDA_KERNEL_LOOP(index, nthreads) {
+    const int x_dim = 4 * dim;
+    const int d = index % x_dim;
+    const Dtype X_act = X_acts[index];
+    if (d < 3 * dim) {
+      X_diff[index] = X_acts_diff[index] * X_act * (Dtype(1) - X_act);
+    } else {
+      X_diff[index] = X_acts_diff[index] * (Dtype(1) - X_act * X_act);
+    }
+  }
+}
+
+template <typename Dtype>
+void LSTMUnitLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down,
+    const vector<Blob<Dtype>*>& bottom) {
+  CHECK(!propagate_down[2]) << "Cannot backpropagate to sequence indicators.";
+  if (!propagate_down[0] && !propagate_down[1]) { return; }
+
+  const int count = top[1]->count();
+  const Dtype* C_prev = bottom[0]->gpu_data();
+  const Dtype* X_acts = X_acts_.gpu_data();
+  const Dtype* cont = bottom[2]->gpu_data();
+  const Dtype* C = top[0]->gpu_data();
+  const Dtype* H = top[1]->gpu_data();
+  const Dtype* C_diff = top[0]->gpu_diff();
+  const Dtype* H_diff = top[1]->gpu_diff();
+  Dtype* C_prev_diff = bottom[0]->mutable_gpu_diff();
+  Dtype* X_acts_diff = X_acts_.mutable_gpu_diff();
+  LSTMUnitBackward<Dtype>  // NOLINT_NEXT_LINE(whitespace/operators)
+      <<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(count, hidden_dim_,
+      C_prev, X_acts, C, H, cont, C_diff, H_diff, C_prev_diff, X_acts_diff);
+  CUDA_POST_KERNEL_CHECK;
+  const int X_count = bottom[1]->count();
+  Dtype* X_diff = bottom[1]->mutable_gpu_diff();
+  LSTMActsBackward<Dtype>  // NOLINT_NEXT_LINE(whitespace/operators)
+      <<<CAFFE_GET_BLOCKS(X_count), CAFFE_CUDA_NUM_THREADS>>>(
+      X_count, hidden_dim_, X_acts, X_acts_diff, X_diff);
+  CUDA_POST_KERNEL_CHECK;
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(LSTMUnitLayer);
+
+}  // namespace caffe
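
Unlike the fused CPU loop, the GPU path stages the gate nonlinearities into X_acts_ with LSTMActsForward and then combines them with c_{t-1} in LSTMUnitForward; the backward pass reuses those cached activations, running the two kernels in the reverse order. The snippet below is a host-side rendering (not part of the patch) of the kernels' flat-index arithmetic, showing which offsets of the packed (1, N, 4D) gate blob receive a sigmoid versus a tanh.

    // Host-side illustration of the LSTMActsForward indexing (illustrative only):
    // within each instance's 4*D gate row, offsets below 3*D (i', f', o') get a
    // sigmoid and the last D offsets (g') get a tanh.
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
      const int N = 2, D = 3;           // instances, hidden dimension
      const int x_dim = 4 * D;          // packed gate row: [i', f', o', g']
      std::vector<double> X(N * x_dim, 0.25), X_acts(N * x_dim);
      for (int index = 0; index < N * x_dim; ++index) {
        const int d = index % x_dim;    // offset within this instance's gate row
        if (d < 3 * D) {
          X_acts[index] = 1.0 / (1.0 + std::exp(-X[index]));  // sigmoid gates
        } else {
          X_acts[index] = std::tanh(X[index]);                // candidate g
        }
      }
      std::printf("sigmoid-activated: %f, tanh-activated: %f\n",
                  X_acts[0], X_acts[3 * D]);
      return 0;
    }
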
diff --git a/src/caffe/test/test_lstm_layer.cpp b/src/caffe/test/test_lstm_layer.cpp
new file mode 100644 (file)
index 0000000..51905ba
--- /dev/null
@@ -0,0 +1,288 @@
+#include <cstring>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/layers/lstm_layer.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+#include "caffe/test/test_gradient_check_util.hpp"
+
+namespace caffe {
+
+template <typename TypeParam>
+class LSTMLayerTest : public MultiDeviceTest<TypeParam> {
+  typedef typename TypeParam::Dtype Dtype;
+
+ protected:
+  LSTMLayerTest() : num_output_(7) {
+    blob_bottom_vec_.push_back(&blob_bottom_);
+    blob_bottom_vec_.push_back(&blob_bottom_cont_);
+    blob_top_vec_.push_back(&blob_top_);
+    unit_blob_bottom_vec_.push_back(&unit_blob_bottom_c_prev_);
+    unit_blob_bottom_vec_.push_back(&unit_blob_bottom_x_);
+    unit_blob_bottom_vec_.push_back(&unit_blob_bottom_cont_);
+    unit_blob_top_vec_.push_back(&unit_blob_top_c_);
+    unit_blob_top_vec_.push_back(&unit_blob_top_h_);
+
+    ReshapeBlobs(1, 3);
+
+    layer_param_.mutable_recurrent_param()->set_num_output(num_output_);
+    FillerParameter* weight_filler =
+        layer_param_.mutable_recurrent_param()->mutable_weight_filler();
+    weight_filler->set_type("gaussian");
+    weight_filler->set_std(0.2);
+    FillerParameter* bias_filler =
+        layer_param_.mutable_recurrent_param()->mutable_bias_filler();
+    bias_filler->set_type("gaussian");
+    bias_filler->set_std(0.1);
+
+    layer_param_.set_phase(TEST);
+  }
+
+  void ReshapeBlobs(int num_timesteps, int num_instances) {
+    blob_bottom_.Reshape(num_timesteps, num_instances, 3, 2);
+    blob_bottom_static_.Reshape(num_instances, 2, 3, 4);
+    vector<int> shape(2);
+    shape[0] = num_timesteps;
+    shape[1] = num_instances;
+    blob_bottom_cont_.Reshape(shape);
+    shape.push_back(num_output_);
+
+    shape[0] = 1; shape[1] = num_instances; shape[2] = 4 * num_output_;
+    unit_blob_bottom_x_.Reshape(shape);
+    shape[0] = 1; shape[1] = num_instances; shape[2] = num_output_;
+    unit_blob_bottom_c_prev_.Reshape(shape);
+    shape.resize(2);
+    shape[0] = 1; shape[1] = num_instances;
+    unit_blob_bottom_cont_.Reshape(shape);
+
+    FillerParameter filler_param;
+    filler_param.set_min(-1);
+    filler_param.set_max(1);
+    UniformFiller<Dtype> filler(filler_param);
+    filler.Fill(&blob_bottom_);
+    filler.Fill(&unit_blob_bottom_c_prev_);
+    filler.Fill(&unit_blob_bottom_x_);
+  }
+
+  int num_output_;
+  LayerParameter layer_param_;
+  Blob<Dtype> blob_bottom_;
+  Blob<Dtype> blob_bottom_cont_;
+  Blob<Dtype> blob_bottom_static_;
+  Blob<Dtype> blob_top_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+
+  Blob<Dtype> unit_blob_bottom_cont_;
+  Blob<Dtype> unit_blob_bottom_c_prev_;
+  Blob<Dtype> unit_blob_bottom_x_;
+  Blob<Dtype> unit_blob_top_c_;
+  Blob<Dtype> unit_blob_top_h_;
+  vector<Blob<Dtype>*> unit_blob_bottom_vec_;
+  vector<Blob<Dtype>*> unit_blob_top_vec_;
+};
+
+TYPED_TEST_CASE(LSTMLayerTest, TestDtypesAndDevices);
+
+TYPED_TEST(LSTMLayerTest, TestSetUp) {
+  typedef typename TypeParam::Dtype Dtype;
+  LSTMLayer<Dtype> layer(this->layer_param_);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  vector<int> expected_top_shape = this->blob_bottom_.shape();
+  expected_top_shape.resize(3);
+  expected_top_shape[2] = this->num_output_;
+  EXPECT_TRUE(this->blob_top_.shape() == expected_top_shape);
+}
+
+TYPED_TEST(LSTMLayerTest, TestForward) {
+  typedef typename TypeParam::Dtype Dtype;
+  const int kNumTimesteps = 3;
+  const int num = this->blob_bottom_.shape(1);
+  this->ReshapeBlobs(kNumTimesteps, num);
+
+  // Fill the cont blob with <0, 1, 1, ..., 1>,
+  // indicating a sequence that begins at the first timestep
+  // then continues for the rest of the sequence.
+  for (int t = 0; t < kNumTimesteps; ++t) {
+    for (int n = 0; n < num; ++n) {
+      this->blob_bottom_cont_.mutable_cpu_data()[t * num + n] = t > 0;
+    }
+  }
+
+  // Process the full sequence in a single batch.
+  FillerParameter filler_param;
+  filler_param.set_mean(0);
+  filler_param.set_std(1);
+  GaussianFiller<Dtype> sequence_filler(filler_param);
+  Caffe::set_random_seed(1);
+  sequence_filler.Fill(&this->blob_bottom_);
+  shared_ptr<LSTMLayer<Dtype> > layer(new LSTMLayer<Dtype>(this->layer_param_));
+  Caffe::set_random_seed(1701);
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  LOG(INFO) << "Calling forward for full sequence LSTM";
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+
+  // Copy the inputs and outputs to reuse/check them later.
+  Blob<Dtype> bottom_copy(this->blob_bottom_.shape());
+  bottom_copy.CopyFrom(this->blob_bottom_);
+  Blob<Dtype> top_copy(this->blob_top_.shape());
+  top_copy.CopyFrom(this->blob_top_);
+
+  // Process the batch one timestep at a time;
+  // check that we get the same result.
+  this->ReshapeBlobs(1, num);
+  layer.reset(new LSTMLayer<Dtype>(this->layer_param_));
+  Caffe::set_random_seed(1701);
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  const int bottom_count = this->blob_bottom_.count();
+  const int top_count = this->blob_top_.count();
+  const Dtype kEpsilon = 1e-5;
+  for (int t = 0; t < kNumTimesteps; ++t) {
+    caffe_copy(bottom_count, bottom_copy.cpu_data() + t * bottom_count,
+               this->blob_bottom_.mutable_cpu_data());
+    for (int n = 0; n < num; ++n) {
+      this->blob_bottom_cont_.mutable_cpu_data()[n] = t > 0;
+    }
+    LOG(INFO) << "Calling forward for LSTM timestep " << t;
+    layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+    for (int i = 0; i < top_count; ++i) {
+      ASSERT_LT(t * top_count + i, top_copy.count());
+      EXPECT_NEAR(this->blob_top_.cpu_data()[i],
+                  top_copy.cpu_data()[t * top_count + i], kEpsilon)
+         << "t = " << t << "; i = " << i;
+    }
+  }
+
+  // Process the batch one timestep at a time with all cont blobs set to 0.
+  // Check that we get a different result, except in the first timestep.
+  Caffe::set_random_seed(1701);
+  layer.reset(new LSTMLayer<Dtype>(this->layer_param_));
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  for (int t = 0; t < kNumTimesteps; ++t) {
+    caffe_copy(bottom_count, bottom_copy.cpu_data() + t * bottom_count,
+               this->blob_bottom_.mutable_cpu_data());
+    for (int n = 0; n < num; ++n) {
+      this->blob_bottom_cont_.mutable_cpu_data()[n] = 0;
+    }
+    LOG(INFO) << "Calling forward for LSTM timestep " << t;
+    layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+    for (int i = 0; i < top_count; ++i) {
+      if (t == 0) {
+        EXPECT_NEAR(this->blob_top_.cpu_data()[i],
+                    top_copy.cpu_data()[t * top_count + i], kEpsilon)
+           << "t = " << t << "; i = " << i;
+      } else {
+        EXPECT_NE(this->blob_top_.cpu_data()[i],
+                  top_copy.cpu_data()[t * top_count + i])
+           << "t = " << t << "; i = " << i;
+      }
+    }
+  }
+}
+
+TYPED_TEST(LSTMLayerTest, TestLSTMUnitSetUp) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  LSTMUnitLayer<Dtype> layer(layer_param);
+  layer.SetUp(this->unit_blob_bottom_vec_, this->unit_blob_top_vec_);
+  const int num_axes = this->unit_blob_bottom_c_prev_.num_axes();
+  ASSERT_EQ(num_axes, this->unit_blob_top_c_.num_axes());
+  ASSERT_EQ(num_axes, this->unit_blob_top_h_.num_axes());
+  for (int i = 0; i < num_axes; ++i) {
+    EXPECT_EQ(this->unit_blob_bottom_c_prev_.shape(i),
+              this->unit_blob_top_c_.shape(i));
+    EXPECT_EQ(this->unit_blob_bottom_c_prev_.shape(i),
+              this->unit_blob_top_h_.shape(i));
+  }
+}
+
+TYPED_TEST(LSTMLayerTest, TestLSTMUnitGradient) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  LSTMUnitLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-3);
+  Dtype* cont_data = this->blob_bottom_cont_.mutable_cpu_data();
+  cont_data[0] = 0;
+  cont_data[1] = 0;
+  cont_data[2] = 0;
+  checker.CheckGradientExhaustive(&layer, this->unit_blob_bottom_vec_,
+      this->unit_blob_top_vec_, 0);
+  checker.CheckGradientExhaustive(&layer, this->unit_blob_bottom_vec_,
+      this->unit_blob_top_vec_, 1);
+}
+
+TYPED_TEST(LSTMLayerTest, TestLSTMUnitGradientNonZeroCont) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  LSTMUnitLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-3);
+  Dtype* cont_data = this->blob_bottom_cont_.mutable_cpu_data();
+  cont_data[0] = 1;
+  cont_data[1] = 0;
+  cont_data[2] = 1;
+  checker.CheckGradientExhaustive(&layer, this->unit_blob_bottom_vec_,
+      this->unit_blob_top_vec_, 0);
+  checker.CheckGradientExhaustive(&layer, this->unit_blob_bottom_vec_,
+      this->unit_blob_top_vec_, 1);
+}
+
+TYPED_TEST(LSTMLayerTest, TestGradient) {
+  typedef typename TypeParam::Dtype Dtype;
+  LSTMLayer<Dtype> layer(this->layer_param_);
+  GradientChecker<Dtype> checker(1e-2, 1e-3);
+  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_, 0);
+}
+
+TYPED_TEST(LSTMLayerTest, TestGradientNonZeroCont) {
+  typedef typename TypeParam::Dtype Dtype;
+  LSTMLayer<Dtype> layer(this->layer_param_);
+  GradientChecker<Dtype> checker(1e-2, 1e-3);
+  for (int i = 0; i < this->blob_bottom_cont_.count(); ++i) {
+    this->blob_bottom_cont_.mutable_cpu_data()[i] = i > 2;
+  }
+  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_, 0);
+}
+
+TYPED_TEST(LSTMLayerTest, TestGradientNonZeroContBufferSize2) {
+  typedef typename TypeParam::Dtype Dtype;
+  this->ReshapeBlobs(2, 2);
+  FillerParameter filler_param;
+  UniformFiller<Dtype> filler(filler_param);
+  filler.Fill(&this->blob_bottom_);
+  LSTMLayer<Dtype> layer(this->layer_param_);
+  GradientChecker<Dtype> checker(1e-2, 1e-3);
+  for (int i = 0; i < this->blob_bottom_cont_.count(); ++i) {
+    this->blob_bottom_cont_.mutable_cpu_data()[i] = i > 2;
+  }
+  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_, 0);
+}
+
+TYPED_TEST(LSTMLayerTest, TestGradientNonZeroContBufferSize2WithStaticInput) {
+  typedef typename TypeParam::Dtype Dtype;
+  this->ReshapeBlobs(2, 2);
+  FillerParameter filler_param;
+  UniformFiller<Dtype> filler(filler_param);
+  filler.Fill(&this->blob_bottom_);
+  filler.Fill(&this->blob_bottom_static_);
+  this->blob_bottom_vec_.push_back(&this->blob_bottom_static_);
+  LSTMLayer<Dtype> layer(this->layer_param_);
+  GradientChecker<Dtype> checker(1e-2, 1e-3);
+  for (int i = 0; i < this->blob_bottom_cont_.count(); ++i) {
+    this->blob_bottom_cont_.mutable_cpu_data()[i] = i > 2;
+  }
+  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_, 0);
+  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_, 2);
+}
+
+
+}  // namespace caffe