Add RNNLayer, with tests
author     Jeff Donahue <jeff.donahue@gmail.com>
           Sun, 15 Feb 2015 22:56:50 +0000 (14:56 -0800)
committer  Jeff Donahue <jeff.donahue@gmail.com>
           Wed, 1 Jun 2016 22:29:52 +0000 (15:29 -0700)
include/caffe/layers/rnn_layer.hpp [new file with mode: 0644]
src/caffe/layers/rnn_layer.cpp [new file with mode: 0644]
src/caffe/test/test_rnn_layer.cpp [new file with mode: 0644]

diff --git a/include/caffe/layers/rnn_layer.hpp b/include/caffe/layers/rnn_layer.hpp
new file mode 100644
index 0000000..6dce238
--- /dev/null
+++ b/include/caffe/layers/rnn_layer.hpp
@@ -0,0 +1,47 @@
+#ifndef CAFFE_RNN_LAYER_HPP_
+#define CAFFE_RNN_LAYER_HPP_
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/layers/recurrent_layer.hpp"
+#include "caffe/net.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+namespace caffe {
+
+template <typename Dtype> class RecurrentLayer;
+
+/**
+ * @brief Processes time-varying inputs using a simple recurrent neural network
+ *        (RNN). Implemented as a network unrolling the RNN computation in time.
+ *
+ * Given time-varying inputs @f$ x_t @f$, computes hidden state @f$
+ *     h_t := \tanh[ W_{hh} h_{t-1} + W_{xh} x_t + b_h ]
+ * @f$, and outputs @f$
+ *     o_t := \tanh[ W_{ho} h_t + b_o ]
+ * @f$.
+ */
+template <typename Dtype>
+class RNNLayer : public RecurrentLayer<Dtype> {
+ public:
+  explicit RNNLayer(const LayerParameter& param)
+      : RecurrentLayer<Dtype>(param) {}
+
+  virtual inline const char* type() const { return "RNN"; }
+
+ protected:
+  virtual void FillUnrolledNet(NetParameter* net_param) const;
+  virtual void RecurrentInputBlobNames(vector<string>* names) const;
+  virtual void RecurrentOutputBlobNames(vector<string>* names) const;
+  virtual void RecurrentInputShapes(vector<BlobShape>* shapes) const;
+  virtual void OutputBlobNames(vector<string>* names) const;
+};
+
+}  // namespace caffe
+
+#endif  // CAFFE_RNN_LAYER_HPP_
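
The recurrence documented in rnn_layer.hpp above can be written out directly. The following is a minimal standalone sketch of that update (toy dimensions, constant weights, no Caffe types; everything in it is illustrative only), computing h_t = tanh(W_hh * h_{t-1} + W_xh * x_t + b_h) and o_t = tanh(W_ho * h_t + b_o):

    // Standalone sketch of the documented recurrence; names, dimensions, and
    // the constant weights are purely illustrative (this is not Caffe code).
    #include <cmath>
    #include <cstdio>
    #include <vector>

    typedef std::vector<float> Vec;
    typedef std::vector<Vec> Mat;  // row-major

    // y = W * x
    static Vec MatVec(const Mat& W, const Vec& x) {
      Vec y(W.size(), 0.0f);
      for (size_t i = 0; i < W.size(); ++i)
        for (size_t j = 0; j < x.size(); ++j)
          y[i] += W[i][j] * x[j];
      return y;
    }

    int main() {
      const int input_dim = 2, hidden_dim = 3, T = 4;
      Mat W_xh(hidden_dim, Vec(input_dim, 0.1f));   // input-to-hidden
      Mat W_hh(hidden_dim, Vec(hidden_dim, 0.1f));  // hidden-to-hidden
      Mat W_ho(hidden_dim, Vec(hidden_dim, 0.1f));  // hidden-to-output
      Vec b_h(hidden_dim, 0.0f), b_o(hidden_dim, 0.0f);

      Vec h(hidden_dim, 0.0f);                      // h_0 = 0
      std::vector<Vec> x(T, Vec(input_dim, 1.0f));  // constant toy inputs
      for (int t = 0; t < T; ++t) {
        // h_t = tanh(W_hh * h_{t-1} + W_xh * x_t + b_h)
        Vec hh = MatVec(W_hh, h);
        Vec xh = MatVec(W_xh, x[t]);
        for (int i = 0; i < hidden_dim; ++i)
          h[i] = std::tanh(hh[i] + xh[i] + b_h[i]);
        // o_t = tanh(W_ho * h_t + b_o)
        Vec oh = MatVec(W_ho, h);
        Vec o(hidden_dim);
        for (int i = 0; i < hidden_dim; ++i)
          o[i] = std::tanh(oh[i] + b_o[i]);
        std::printf("t=%d  o_t[0]=%f\n", t + 1, o[0]);
      }
      return 0;
    }
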
diff --git a/src/caffe/layers/rnn_layer.cpp b/src/caffe/layers/rnn_layer.cpp
new file mode 100644
index 0000000..f62ae8c
--- /dev/null
+++ b/src/caffe/layers/rnn_layer.cpp
@@ -0,0 +1,236 @@
+#include <string>
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/layers/rnn_layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void RNNLayer<Dtype>::RecurrentInputBlobNames(vector<string>* names) const {
+  names->resize(1);
+  (*names)[0] = "h_0";
+}
+
+template <typename Dtype>
+void RNNLayer<Dtype>::RecurrentOutputBlobNames(vector<string>* names) const {
+  names->resize(1);
+  (*names)[0] = "h_" + format_int(this->T_);
+}
+
+template <typename Dtype>
+void RNNLayer<Dtype>::RecurrentInputShapes(vector<BlobShape>* shapes) const {
+  const int num_output = this->layer_param_.recurrent_param().num_output();
+  shapes->resize(1);
+  (*shapes)[0].Clear();
+  (*shapes)[0].add_dim(1);  // a single timestep
+  (*shapes)[0].add_dim(this->N_);
+  (*shapes)[0].add_dim(num_output);
+}
+
+template <typename Dtype>
+void RNNLayer<Dtype>::OutputBlobNames(vector<string>* names) const {
+  names->resize(1);
+  (*names)[0] = "o";
+}
+
+template <typename Dtype>
+void RNNLayer<Dtype>::FillUnrolledNet(NetParameter* net_param) const {
+  const int num_output = this->layer_param_.recurrent_param().num_output();
+  CHECK_GT(num_output, 0) << "num_output must be positive";
+  const FillerParameter& weight_filler =
+      this->layer_param_.recurrent_param().weight_filler();
+  const FillerParameter& bias_filler =
+      this->layer_param_.recurrent_param().bias_filler();
+
+  // Add generic LayerParameters (without bottoms/tops) for the layer types
+  // we'll use below, to avoid repeating the same settings.
+  LayerParameter hidden_param;
+  hidden_param.set_type("InnerProduct");
+  hidden_param.mutable_inner_product_param()->set_num_output(num_output);
+  hidden_param.mutable_inner_product_param()->set_bias_term(false);
+  hidden_param.mutable_inner_product_param()->set_axis(2);
+  hidden_param.mutable_inner_product_param()->
+      mutable_weight_filler()->CopyFrom(weight_filler);
+
+  LayerParameter biased_hidden_param(hidden_param);
+  biased_hidden_param.mutable_inner_product_param()->set_bias_term(true);
+  biased_hidden_param.mutable_inner_product_param()->
+      mutable_bias_filler()->CopyFrom(bias_filler);
+
+  LayerParameter sum_param;
+  sum_param.set_type("Eltwise");
+  sum_param.mutable_eltwise_param()->set_operation(
+      EltwiseParameter_EltwiseOp_SUM);
+
+  LayerParameter tanh_param;
+  tanh_param.set_type("TanH");
+
+  LayerParameter scale_param;
+  scale_param.set_type("Scale");
+  scale_param.mutable_scale_param()->set_axis(0);
+
+  LayerParameter slice_param;
+  slice_param.set_type("Slice");
+  slice_param.mutable_slice_param()->set_axis(0);
+
+  vector<BlobShape> input_shapes;
+  RecurrentInputShapes(&input_shapes);
+  CHECK_EQ(1, input_shapes.size());
+
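+  // Add an Input layer producing the recurrent input h_0: the hidden state
+  // carried over from the previous chunk of the sequence, with the shape
+  // given by RecurrentInputShapes (1 x N x num_output).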
+  LayerParameter* input_layer_param = net_param->add_layer();
+  input_layer_param->set_type("Input");
+  InputParameter* input_param = input_layer_param->mutable_input_param();
+  input_layer_param->add_top("h_0");
+  input_param->add_shape()->CopyFrom(input_shapes[0]);
+
+  LayerParameter* cont_slice_param = net_param->add_layer();
+  cont_slice_param->CopyFrom(slice_param);
+  cont_slice_param->set_name("cont_slice");
+  cont_slice_param->add_bottom("cont");
+  cont_slice_param->mutable_slice_param()->set_axis(0);
+
+  // Add layer to transform all timesteps of x to the hidden state dimension.
+  //     W_xh_x = W_xh * x + b_h
+  {
+    LayerParameter* x_transform_param = net_param->add_layer();
+    x_transform_param->CopyFrom(biased_hidden_param);
+    x_transform_param->set_name("x_transform");
+    x_transform_param->add_param()->set_name("W_xh");
+    x_transform_param->add_param()->set_name("b_h");
+    x_transform_param->add_bottom("x");
+    x_transform_param->add_top("W_xh_x");
+    x_transform_param->add_propagate_down(true);
+  }
+
+  if (this->static_input_) {
+    // Add layer to transform x_static to the hidden state dimension.
+    //     W_xh_x_static = W_xh_static * x_static
+    LayerParameter* x_static_transform_param = net_param->add_layer();
+    x_static_transform_param->CopyFrom(hidden_param);
+    x_static_transform_param->mutable_inner_product_param()->set_axis(1);
+    x_static_transform_param->set_name("W_xh_x_static");
+    x_static_transform_param->add_param()->set_name("W_xh_static");
+    x_static_transform_param->add_bottom("x_static");
+    x_static_transform_param->add_top("W_xh_x_static_preshape");
+    x_static_transform_param->add_propagate_down(true);
+
+    LayerParameter* reshape_param = net_param->add_layer();
+    reshape_param->set_type("Reshape");
+    BlobShape* new_shape =
+         reshape_param->mutable_reshape_param()->mutable_shape();
+    new_shape->add_dim(1);  // One timestep.
+    // Infer the second dimension (this->N_) so the reshape adapts to the
+    // current batch size.
+    new_shape->add_dim(-1);
+    new_shape->add_dim(
+        x_static_transform_param->inner_product_param().num_output());
+    reshape_param->set_name("W_xh_x_static_reshape");
+    reshape_param->add_bottom("W_xh_x_static_preshape");
+    reshape_param->add_top("W_xh_x_static");
+  }
+
+  LayerParameter* x_slice_param = net_param->add_layer();
+  x_slice_param->CopyFrom(slice_param);
+  x_slice_param->set_name("W_xh_x_slice");
+  x_slice_param->add_bottom("W_xh_x");
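+  // Neither Slice layer has tops yet; the per-timestep tops (cont_t and
+  // W_xh_x_t for t = 1..T) are added in the unrolling loop below.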
+
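+  // Concat the per-timestep outputs o_1 ... o_T into the single top blob "o".
+  // This layer is added to the net only after the loop below has appended all
+  // of its bottoms.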
+  LayerParameter output_concat_layer;
+  output_concat_layer.set_name("o_concat");
+  output_concat_layer.set_type("Concat");
+  output_concat_layer.add_top("o");
+  output_concat_layer.mutable_concat_param()->set_axis(0);
+
+  for (int t = 1; t <= this->T_; ++t) {
+    string tm1s = format_int(t - 1);
+    string ts = format_int(t);
+
+    cont_slice_param->add_top("cont_" + ts);
+    x_slice_param->add_top("W_xh_x_" + ts);
+
+    // Add layer to flush the hidden state when beginning a new sequence,
+    // as indicated by cont_t.
+    //     h_conted_{t-1} := cont_t * h_{t-1}
+    //
+    // Normally, cont_t is binary (i.e., 0 or 1), so:
+    //     h_conted_{t-1} := h_{t-1} if cont_t == 1
+    //                       0   otherwise
+    {
+      LayerParameter* cont_h_param = net_param->add_layer();
+      cont_h_param->CopyFrom(scale_param);
+      cont_h_param->set_name("h_conted_" + tm1s);
+      cont_h_param->add_bottom("h_" + tm1s);
+      cont_h_param->add_bottom("cont_" + ts);
+      cont_h_param->add_top("h_conted_" + tm1s);
+    }
+
+    // Add layer to compute
+    //     W_hh_h_{t-1} := W_hh * h_conted_{t-1}
+    {
+      LayerParameter* w_param = net_param->add_layer();
+      w_param->CopyFrom(hidden_param);
+      w_param->set_name("W_hh_h_" + tm1s);
+      w_param->add_param()->set_name("W_hh");
+      w_param->add_bottom("h_conted_" + tm1s);
+      w_param->add_top("W_hh_h_" + tm1s);
+      w_param->mutable_inner_product_param()->set_axis(2);
+    }
+
+    // Add layers to compute
+    //     h_t := \tanh( W_hh * h_conted_{t-1} + W_xh * x_t + b_h )
+    //          = \tanh( W_hh_h_{t-1} + W_xh_x_t )
+    {
+      LayerParameter* h_input_sum_param = net_param->add_layer();
+      h_input_sum_param->CopyFrom(sum_param);
+      h_input_sum_param->set_name("h_input_sum_" + ts);
+      h_input_sum_param->add_bottom("W_hh_h_" + tm1s);
+      h_input_sum_param->add_bottom("W_xh_x_" + ts);
+      if (this->static_input_) {
+        h_input_sum_param->add_bottom("W_xh_x_static");
+      }
+      h_input_sum_param->add_top("h_neuron_input_" + ts);
+    }
+    {
+      LayerParameter* h_neuron_param = net_param->add_layer();
+      h_neuron_param->CopyFrom(tanh_param);
+      h_neuron_param->set_name("h_neuron_" + ts);
+      h_neuron_param->add_bottom("h_neuron_input_" + ts);
+      h_neuron_param->add_top("h_" + ts);
+    }
+
+    // Add layer to compute
+    //     W_ho_h_t := W_ho * h_t + b_o
+    {
+      LayerParameter* w_param = net_param->add_layer();
+      w_param->CopyFrom(biased_hidden_param);
+      w_param->set_name("W_ho_h_" + ts);
+      w_param->add_param()->set_name("W_ho");
+      w_param->add_param()->set_name("b_o");
+      w_param->add_bottom("h_" + ts);
+      w_param->add_top("W_ho_h_" + ts);
+      w_param->mutable_inner_product_param()->set_axis(2);
+    }
+
+    // Add layers to compute
+    //     o_t := \tanh( W_ho * h_t + b_o )
+    //          = \tanh( W_ho_h_t )
+    {
+      LayerParameter* o_neuron_param = net_param->add_layer();
+      o_neuron_param->CopyFrom(tanh_param);
+      o_neuron_param->set_name("o_neuron_" + ts);
+      o_neuron_param->add_bottom("W_ho_h_" + ts);
+      o_neuron_param->add_top("o_" + ts);
+    }
+    output_concat_layer.add_bottom("o_" + ts);
+  }  // for (int t = 1; t <= this->T_; ++t)
+
+  net_param->add_layer()->CopyFrom(output_concat_layer);
+}
+
+INSTANTIATE_CLASS(RNNLayer);
+REGISTER_LAYER_CLASS(RNN);
+
+}  // namespace caffe
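
The subtlest part of FillUnrolledNet is the per-timestep Scale layer that multiplies the previous hidden state by cont_t, so a stream restarts from a zero state whenever a new sequence begins (cont_t == 0). Below is a standalone sketch of that gating (illustrative only, not Caffe code; the tanh line is a constant stand-in for the real W_hh/W_xh sum):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
      const int T = 5, hidden_dim = 3;
      // Hypothetical cont sequence: new sequences begin at timesteps 1 and 4.
      const int cont[T] = {0, 1, 1, 0, 1};
      std::vector<float> h(hidden_dim, 0.0f);
      for (int t = 0; t < T; ++t) {
        // h_conted_{t-1} := cont_t * h_{t-1}  (the generated Scale layer)
        for (int i = 0; i < hidden_dim; ++i) h[i] *= cont[t];
        // Stand-in for tanh(W_hh * h_conted_{t-1} + W_xh * x_t + b_h):
        for (int i = 0; i < hidden_dim; ++i) h[i] = std::tanh(h[i] + 0.5f);
        std::printf("t=%d  cont=%d  h[0]=%f\n", t + 1, cont[t], h[0]);
      }
      return 0;
    }
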
diff --git a/src/caffe/test/test_rnn_layer.cpp b/src/caffe/test/test_rnn_layer.cpp
new file mode 100644
index 0000000..dd8952d
--- /dev/null
+++ b/src/caffe/test/test_rnn_layer.cpp
@@ -0,0 +1,217 @@
+#include <cstring>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/layers/rnn_layer.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+#include "caffe/test/test_gradient_check_util.hpp"
+
+namespace caffe {
+
+template <typename TypeParam>
+class RNNLayerTest : public MultiDeviceTest<TypeParam> {
+  typedef typename TypeParam::Dtype Dtype;
+
+ protected:
+  RNNLayerTest() : num_output_(7) {
+    blob_bottom_vec_.push_back(&blob_bottom_);
+    blob_bottom_vec_.push_back(&blob_bottom_cont_);
+    blob_top_vec_.push_back(&blob_top_);
+
+    ReshapeBlobs(1, 3);
+
+    layer_param_.mutable_recurrent_param()->set_num_output(num_output_);
+    FillerParameter* weight_filler =
+        layer_param_.mutable_recurrent_param()->mutable_weight_filler();
+    weight_filler->set_type("gaussian");
+    weight_filler->set_std(0.2);
+    FillerParameter* bias_filler =
+        layer_param_.mutable_recurrent_param()->mutable_bias_filler();
+    bias_filler->set_type("gaussian");
+    bias_filler->set_std(0.1);
+
+    layer_param_.set_phase(TEST);
+  }
+
+  void ReshapeBlobs(int num_timesteps, int num_instances) {
+    blob_bottom_.Reshape(num_timesteps, num_instances, 3, 2);
+    blob_bottom_static_.Reshape(num_instances, 2, 3, 4);
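+    // cont holds one sequence-continuation indicator per timestep and per
+    // stream, i.e. shape (T x N).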
+    vector<int> shape(2);
+    shape[0] = num_timesteps;
+    shape[1] = num_instances;
+    blob_bottom_cont_.Reshape(shape);
+
+    FillerParameter filler_param;
+    filler_param.set_min(-1);
+    filler_param.set_max(1);
+    UniformFiller<Dtype> filler(filler_param);
+    filler.Fill(&blob_bottom_);
+  }
+
+  int num_output_;
+  LayerParameter layer_param_;
+  Blob<Dtype> blob_bottom_;
+  Blob<Dtype> blob_bottom_cont_;
+  Blob<Dtype> blob_bottom_static_;
+  Blob<Dtype> blob_top_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+TYPED_TEST_CASE(RNNLayerTest, TestDtypesAndDevices);
+
+TYPED_TEST(RNNLayerTest, TestSetUp) {
+  typedef typename TypeParam::Dtype Dtype;
+  RNNLayer<Dtype> layer(this->layer_param_);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  vector<int> expected_top_shape = this->blob_bottom_.shape();
+  expected_top_shape.resize(3);
+  expected_top_shape[2] = this->num_output_;
+  EXPECT_TRUE(this->blob_top_.shape() == expected_top_shape);
+}
+
+TYPED_TEST(RNNLayerTest, TestForward) {
+  typedef typename TypeParam::Dtype Dtype;
+  const int kNumTimesteps = 3;
+  const int num = this->blob_bottom_.shape(1);
+  this->ReshapeBlobs(kNumTimesteps, num);
+
+  // Fill the cont blob with <0, 1, 1, ..., 1>, indicating a sequence that
+  // begins at the first timestep and continues for the rest of the sequence.
+  for (int t = 0; t < kNumTimesteps; ++t) {
+    for (int n = 0; n < num; ++n) {
+      this->blob_bottom_cont_.mutable_cpu_data()[t * num + n] = t > 0;
+    }
+  }
+
+  // Process the full sequence in a single batch.
+  FillerParameter filler_param;
+  filler_param.set_mean(0);
+  filler_param.set_std(1);
+  GaussianFiller<Dtype> sequence_filler(filler_param);
+  sequence_filler.Fill(&this->blob_bottom_);
+  shared_ptr<RNNLayer<Dtype> > layer(new RNNLayer<Dtype>(this->layer_param_));
+  Caffe::set_random_seed(1701);
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  LOG(INFO) << "Calling forward for full sequence RNN";
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+
+  // Copy the inputs and outputs to reuse/check them later.
+  Blob<Dtype> bottom_copy(this->blob_bottom_.shape());
+  bottom_copy.CopyFrom(this->blob_bottom_);
+  Blob<Dtype> top_copy(this->blob_top_.shape());
+  top_copy.CopyFrom(this->blob_top_);
+
+  // Process the batch one timestep at a time;
+  // check that we get the same result.
+  this->ReshapeBlobs(1, num);
+  layer.reset(new RNNLayer<Dtype>(this->layer_param_));
+  Caffe::set_random_seed(1701);
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  const int bottom_count = this->blob_bottom_.count();
+  const int top_count = this->blob_top_.count();
+  const Dtype kEpsilon = 1e-5;
+  for (int t = 0; t < kNumTimesteps; ++t) {
+    caffe_copy(bottom_count, bottom_copy.cpu_data() + t * bottom_count,
+               this->blob_bottom_.mutable_cpu_data());
+    for (int n = 0; n < num; ++n) {
+      this->blob_bottom_cont_.mutable_cpu_data()[n] = t > 0;
+    }
+    LOG(INFO) << "Calling forward for RNN timestep " << t;
+    layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+    for (int i = 0; i < top_count; ++i) {
+      ASSERT_LT(t * top_count + i, top_copy.count());
+      EXPECT_NEAR(this->blob_top_.cpu_data()[i],
+                  top_copy.cpu_data()[t * top_count + i], kEpsilon)
+         << "t = " << t << "; i = " << i;
+    }
+  }
+
+  // Process the batch one timestep at a time with all cont blobs set to 0.
+  // Check that we get a different result, except in the first timestep.
+  Caffe::set_random_seed(1701);
+  layer.reset(new RNNLayer<Dtype>(this->layer_param_));
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  for (int t = 0; t < kNumTimesteps; ++t) {
+    caffe_copy(bottom_count, bottom_copy.cpu_data() + t * bottom_count,
+               this->blob_bottom_.mutable_cpu_data());
+    for (int n = 0; n < num; ++n) {
+      this->blob_bottom_cont_.mutable_cpu_data()[n] = 0;
+    }
+    LOG(INFO) << "Calling forward for RNN timestep " << t;
+    layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+    for (int i = 0; i < top_count; ++i) {
+      if (t == 0) {
+        EXPECT_NEAR(this->blob_top_.cpu_data()[i],
+                    top_copy.cpu_data()[t * top_count + i], kEpsilon)
+           << "t = " << t << "; i = " << i;
+      } else {
+        EXPECT_NE(this->blob_top_.cpu_data()[i],
+                  top_copy.cpu_data()[t * top_count + i])
+           << "t = " << t << "; i = " << i;
+      }
+    }
+  }
+}
+
+TYPED_TEST(RNNLayerTest, TestGradient) {
+  typedef typename TypeParam::Dtype Dtype;
+  RNNLayer<Dtype> layer(this->layer_param_);
+  GradientChecker<Dtype> checker(1e-2, 1e-3);
+  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_, 0);
+}
+
+TYPED_TEST(RNNLayerTest, TestGradientNonZeroCont) {
+  typedef typename TypeParam::Dtype Dtype;
+  RNNLayer<Dtype> layer(this->layer_param_);
+  GradientChecker<Dtype> checker(1e-2, 1e-3);
+  for (int i = 0; i < this->blob_bottom_cont_.count(); ++i) {
+    this->blob_bottom_cont_.mutable_cpu_data()[i] = i > 2;
+  }
+  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_, 0);
+}
+
+TYPED_TEST(RNNLayerTest, TestGradientNonZeroContBufferSize2) {
+  typedef typename TypeParam::Dtype Dtype;
+  this->ReshapeBlobs(2, 2);
+  // Fill the input with random values.
+  FillerParameter filler_param;
+  UniformFiller<Dtype> filler(filler_param);
+  filler.Fill(&this->blob_bottom_);
+  RNNLayer<Dtype> layer(this->layer_param_);
+  GradientChecker<Dtype> checker(1e-2, 1e-3);
+  for (int i = 0; i < this->blob_bottom_cont_.count(); ++i) {
+    this->blob_bottom_cont_.mutable_cpu_data()[i] = i > 2;
+  }
+  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_, 0);
+}
+
+TYPED_TEST(RNNLayerTest, TestGradientNonZeroContBufferSize2WithStaticInput) {
+  typedef typename TypeParam::Dtype Dtype;
+  this->ReshapeBlobs(2, 2);
+  FillerParameter filler_param;
+  UniformFiller<Dtype> filler(filler_param);
+  filler.Fill(&this->blob_bottom_);
+  filler.Fill(&this->blob_bottom_static_);
+  this->blob_bottom_vec_.push_back(&this->blob_bottom_static_);
+  RNNLayer<Dtype> layer(this->layer_param_);
+  GradientChecker<Dtype> checker(1e-2, 1e-3);
+  for (int i = 0; i < this->blob_bottom_cont_.count(); ++i) {
+    this->blob_bottom_cont_.mutable_cpu_data()[i] = i > 2;
+  }
+  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_, 0);
+  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_, 2);
+}
+
+}  // namespace caffe
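
Outside the gtest harness, driving the layer directly might look like the sketch below. It mirrors the test setup above; the shapes, filler settings, and the standalone main() are illustrative assumptions, not part of this patch.

    // A minimal sketch of driving RNNLayer directly (illustrative only).
    #include <vector>

    #include "caffe/blob.hpp"
    #include "caffe/common.hpp"
    #include "caffe/filler.hpp"
    #include "caffe/layers/rnn_layer.hpp"

    using namespace caffe;  // NOLINT(build/namespaces)

    int main() {
      Caffe::set_mode(Caffe::CPU);
      const int T = 4, N = 2, num_output = 7;

      LayerParameter param;
      param.set_phase(TEST);
      RecurrentParameter* rp = param.mutable_recurrent_param();
      rp->set_num_output(num_output);
      rp->mutable_weight_filler()->set_type("uniform");
      rp->mutable_weight_filler()->set_min(-0.1f);
      rp->mutable_weight_filler()->set_max(0.1f);

      // x is (T x N x ...); cont is (T x N) with 0 marking a sequence start.
      Blob<float> x(T, N, 3, 2);
      Blob<float> cont, top;
      std::vector<int> cont_shape(2);
      cont_shape[0] = T;
      cont_shape[1] = N;
      cont.Reshape(cont_shape);
      for (int t = 0; t < T; ++t)
        for (int n = 0; n < N; ++n)
          cont.mutable_cpu_data()[t * N + n] = t > 0;

      FillerParameter filler_param;
      filler_param.set_min(-1);
      filler_param.set_max(1);
      UniformFiller<float> filler(filler_param);
      filler.Fill(&x);

      std::vector<Blob<float>*> bottom, top_vec;
      bottom.push_back(&x);
      bottom.push_back(&cont);
      top_vec.push_back(&top);

      RNNLayer<float> layer(param);
      layer.SetUp(bottom, top_vec);    // builds the unrolled net
      layer.Forward(bottom, top_vec);  // top is now (T x N x num_output)
      return 0;
    }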