BatchReindexLayer to shuffle, subsample, and replicate examples in a batch
authorCarl Doersch <cdoersch@cs.cmu.edu>
Mon, 24 Aug 2015 03:47:25 +0000 (20:47 -0700)
committerCarl Doersch <cdoersch@cs.cmu.edu>
Thu, 8 Oct 2015 01:41:11 +0000 (18:41 -0700)
include/caffe/common_layers.hpp
src/caffe/layers/batch_reindex_layer.cpp [new file with mode: 0644]
src/caffe/layers/batch_reindex_layer.cu [new file with mode: 0644]
src/caffe/test/test_batch_reindex_layer.cpp [new file with mode: 0644]

index d2c0ce6..5d68e86 100644 (file)
@@ -71,6 +71,75 @@ class ArgMaxLayer : public Layer<Dtype> {
 };
 
 /**
+ * @brief Index into the input blob along its first axis.
+ *
+ * This layer can be used to select, reorder, and even replicate examples in a
+ * batch.  The second blob is cast to int and treated as an index into the
+ * first axis of the first blob.
+ */
+template <typename Dtype>
+class BatchReindexLayer : public Layer<Dtype> {
+ public:
+  explicit BatchReindexLayer(const LayerParameter& param)
+      : Layer<Dtype>(param) {}
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+
+  virtual inline const char* type() const { return "BatchReindex"; }
+  virtual inline int ExactNumBottomBlobs() const { return 2; }
+  virtual inline int ExactNumTopBlobs() const { return 1; }
+
+ protected:
+  /**
+   * @param bottom input Blob vector (length 2+)
+   *   -# @f$ (N \times ...) @f$
+   *      the inputs @f$ x_1 @f$
+   *   -# @f$ (M) @f$
+   *      the inputs @f$ x_2 @f$
+   * @param top output Blob vector (length 1)
+   *   -# @f$ (M \times ...) @f$:
+   *      the reindexed array @f$
+   *        y = x_1[x_2]
+   *      @f$
+   */
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+
+  /**
+   * @brief Computes the error gradient w.r.t. the reordered input.
+   *
+   * @param top output Blob vector (length 1), providing the error gradient
+   *        with respect to the outputs
+   *   -# @f$ (M \times ...) @f$:
+   *      containing error gradients @f$ \frac{\partial E}{\partial y} @f$
+   *      with respect to concatenated outputs @f$ y @f$
+   * @param propagate_down see Layer::Backward.
+   * @param bottom input Blob vector (length 2):
+   *   - @f$ \frac{\partial E}{\partial y} @f$ is de-indexed (summing where
+   *     required) back to the input x_1
+   *   - This layer cannot backprop to x_2, i.e. propagate_down[1] must be
+   *     false.
+   */
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+
+ private:
+  struct pair_sort_first {
+    bool operator()(const std::pair<int, int> &left,
+                    const std::pair<int, int> &right) {
+      return left.first < right.first;
+    }
+  };
+  void check_batch_reindex(int initial_num, int final_num,
+                           const Dtype* ridx_data);
+};
+
+
+/**
  * @brief Takes at least two Blob%s and concatenates them along either the num
  *        or channel dimension, outputting the result.
  */
diff --git a/src/caffe/layers/batch_reindex_layer.cpp b/src/caffe/layers/batch_reindex_layer.cpp
new file mode 100644 (file)
index 0000000..3bf757c
--- /dev/null
@@ -0,0 +1,79 @@
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template<typename Dtype>
+void BatchReindexLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+                                       const vector<Blob<Dtype>*>& top) {
+  CHECK_EQ(1, bottom[1]->num_axes());
+  vector<int> newshape;
+  newshape.push_back(bottom[1]->shape(0));
+  for (int i = 1; i < bottom[0]->shape().size(); ++i) {
+    newshape.push_back(bottom[0]->shape()[i]);
+  }
+  top[0]->Reshape(newshape);
+}
+
+template<typename Dtype>
+void BatchReindexLayer<Dtype>::check_batch_reindex(int initial_num,
+                                                   int final_num,
+                                                   const Dtype* ridx_data) {
+  for (int i = 0; i < final_num; ++i) {
+    CHECK_GE(ridx_data[i], 0)
+        << "Index specified for reindex layer was negative.";
+    CHECK_LT(ridx_data[i], initial_num)
+        << "Index specified for reindex layer was greater than batch size.";
+  }
+}
+
+template<typename Dtype>
+void BatchReindexLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+                                           const vector<Blob<Dtype>*>& top) {
+  check_batch_reindex(bottom[0]->shape(0), bottom[1]->count(),
+                      bottom[1]->cpu_data());
+  if (top[0]->count() == 0) {
+    return;
+  }
+  int inner_dim = bottom[0]->count() / bottom[0]->shape(0);
+  const Dtype* in = bottom[0]->cpu_data();
+  const Dtype* permut = bottom[1]->cpu_data();
+  Dtype* out = top[0]->mutable_cpu_data();
+  for (int index = 0; index < top[0]->count(); ++index) {
+    int n = index / (inner_dim);
+    int in_n = static_cast<int>(permut[n]);
+    out[index] = in[in_n * (inner_dim) + index % (inner_dim)];
+  }
+}
+
+template<typename Dtype>
+void BatchReindexLayer<Dtype>::Backward_cpu(
+    const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down,
+    const vector<Blob<Dtype>*>& bottom) {
+  CHECK(!propagate_down[1]) << "Cannot backprop to index.";
+  if (!propagate_down[0]) {
+    return;
+  }
+  int inner_dim = bottom[0]->count() / bottom[0]->shape(0);
+  Dtype* bot_diff = bottom[0]->mutable_cpu_diff();
+  const Dtype* permut = bottom[1]->cpu_data();
+  const Dtype* top_diff = top[0]->cpu_diff();
+  caffe_set(bottom[0]->count(), Dtype(0), bot_diff);
+  for (int index = 0; index < top[0]->count(); ++index) {
+    int n = index / (inner_dim);
+    int in_n = static_cast<int>(permut[n]);
+    bot_diff[in_n * (inner_dim) + index % (inner_dim)] += top_diff[index];
+  }
+}
+
+#ifdef CPU_ONLY
+STUB_GPU(BatchReindexLayer);
+#endif
+
+INSTANTIATE_CLASS(BatchReindexLayer);
+REGISTER_LAYER_CLASS(BatchReindex);
+
+}  // namespace caffe
diff --git a/src/caffe/layers/batch_reindex_layer.cu b/src/caffe/layers/batch_reindex_layer.cu
new file mode 100644 (file)
index 0000000..c418cab
--- /dev/null
@@ -0,0 +1,107 @@
+#include <algorithm>
+#include <utility>
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template<typename Dtype>
+__global__ void BRForward(const int count, const int inner_dim, const Dtype* in,
+                          const Dtype* permut, Dtype* out) {
+  CUDA_KERNEL_LOOP(index, count) {
+    int n = index / (inner_dim);
+    int in_n = static_cast<int>(permut[n]);
+    out[index] = in[in_n * (inner_dim) + index % (inner_dim)];
+  }
+}
+
+template<typename Dtype>
+void BatchReindexLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+                                           const vector<Blob<Dtype>*>& top) {
+  check_batch_reindex(bottom[0]->shape(0), bottom[1]->count(),
+                      bottom[1]->cpu_data());
+  if (top[0]->count() == 0) {
+    return;
+  }
+  int threads = top[0]->count();
+  // NOLINT_NEXT_LINE(whitespace/operators)
+  BRForward<Dtype> <<<CAFFE_GET_BLOCKS(threads), CAFFE_CUDA_NUM_THREADS>>>(
+      top[0]->count(), bottom[0]->count() / bottom[0]->shape(0),
+      bottom[0]->gpu_data(), bottom[1]->gpu_data(), top[0]->mutable_gpu_data());
+  CUDA_POST_KERNEL_CHECK;
+}
+
+template<typename Dtype>
+__global__ void BRBackward(const int count, const int inner_dim,
+                           const Dtype* in, const Dtype* top_indexes,
+                           const Dtype* begins, const Dtype* counts,
+                           Dtype* out) {
+  CUDA_KERNEL_LOOP(index, count) {
+    int n = index / (inner_dim);
+    out[index] = 0;
+    int lower = static_cast<int>(begins[n]);
+    int upper = lower + static_cast<int>(counts[n]);
+    for (int i = lower; i < upper; ++i) {
+      int in_n = static_cast<int>(top_indexes[i]);
+      out[index] += in[in_n * (inner_dim) + index % (inner_dim)];
+    }
+  }
+}
+
+template<typename Dtype>
+void BatchReindexLayer<Dtype>::Backward_gpu(
+    const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down,
+    const vector<Blob<Dtype>*>& bottom) {
+  CHECK(!propagate_down[1]) << "Cannot backprop to index.";
+  if (!propagate_down[0]) {
+    return;
+  }
+
+  vector<std::pair<int, int> > mapping;
+  const Dtype* perm = bottom[1]->cpu_data();
+  for (int i = 0; i < bottom[1]->count(); ++i) {
+    mapping.push_back(pair<int, int>(static_cast<int>(perm[i]), i));
+  }
+  std::sort(mapping.begin(), mapping.end(), pair_sort_first());
+
+  // Each element of the bottom diff is potentially the sum of many top diffs.
+  // However, we'd like each CUDA thread to handle exactly one output.  Hence,
+  // we first pre-compute a list of lists of indices that need to be summed for
+  // each output. `top_indexes` holds the data of this list of lists.  The
+  // k'th element of `begins` points to the location in `top_indexes` where the
+  // list for the k'th example begin, and the k'th element of `counts` is the
+  // length of that list.
+  vector<int> shape;
+  shape.push_back(bottom[1]->count());
+  Blob<Dtype> top_indexes(shape);
+  shape[0] = bottom[0]->shape(0);
+  Blob<Dtype> counts(shape);
+  Blob<Dtype> begins(shape);
+  Dtype* t_i_data = top_indexes.mutable_cpu_data();
+  Dtype* c_data = counts.mutable_cpu_data();
+  Dtype* b_data = begins.mutable_cpu_data();
+  caffe_set(begins.count(), Dtype(-1), b_data);
+  caffe_set(counts.count(), Dtype(0), c_data);
+  for (int i = 0; i < mapping.size(); ++i) {
+    t_i_data[i] = mapping[i].second;
+    if (b_data[mapping[i].first] == -1) {
+      b_data[mapping[i].first] = i;
+    }
+    c_data[mapping[i].first] += 1;
+  }
+
+  int threads = bottom[0]->count();
+  // NOLINT_NEXT_LINE(whitespace/operators)
+  BRBackward<Dtype> <<<CAFFE_GET_BLOCKS(threads), CAFFE_CUDA_NUM_THREADS>>>(
+      bottom[0]->count(), bottom[0]->count() / bottom[0]->shape(0),
+      top[0]->gpu_diff(), top_indexes.gpu_data(), begins.gpu_data(),
+      counts.gpu_data(), bottom[0]->mutable_gpu_diff());
+  CUDA_POST_KERNEL_CHECK;
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(BatchReindexLayer);
+
+}  // namespace caffe
diff --git a/src/caffe/test/test_batch_reindex_layer.cpp b/src/caffe/test/test_batch_reindex_layer.cpp
new file mode 100644 (file)
index 0000000..985db34
--- /dev/null
@@ -0,0 +1,119 @@
+#include <cstring>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/vision_layers.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+#include "caffe/test/test_gradient_check_util.hpp"
+
+namespace caffe {
+
+template<typename TypeParam>
+class BatchReindexLayerTest : public MultiDeviceTest<TypeParam> {
+  typedef typename TypeParam::Dtype Dtype;
+
+ protected:
+  BatchReindexLayerTest()
+      : blob_bottom_(new Blob<Dtype>()),
+        blob_bottom_permute_(new Blob<Dtype>()),
+        blob_top_(new Blob<Dtype>()) {
+  }
+  virtual void SetUp() {
+    Caffe::set_random_seed(1701);
+    vector<int> sz;
+    sz.push_back(5);
+    sz.push_back(4);
+    sz.push_back(3);
+    sz.push_back(2);
+    blob_bottom_->Reshape(sz);
+    vector<int> permsz;
+    permsz.push_back(6);
+    blob_bottom_permute_->Reshape(permsz);
+
+    // fill the values
+    FillerParameter filler_param;
+    GaussianFiller<Dtype> filler(filler_param);
+    filler.Fill(this->blob_bottom_);
+    int perm[] = { 4, 0, 4, 0, 1, 2 };
+    for (int i = 0; i < blob_bottom_permute_->count(); ++i) {
+      blob_bottom_permute_->mutable_cpu_data()[i] = perm[i];
+    }
+
+    blob_bottom_vec_.push_back(blob_bottom_);
+    blob_bottom_vec_.push_back(blob_bottom_permute_);
+    blob_top_vec_.push_back(blob_top_);
+  }
+  virtual ~BatchReindexLayerTest() {
+    delete blob_bottom_permute_;
+    delete blob_bottom_;
+    delete blob_top_;
+  }
+  Blob<Dtype>* const blob_bottom_;
+  Blob<Dtype>* const blob_bottom_permute_;
+  Blob<Dtype>* const blob_top_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+
+  void TestForward() {
+    LayerParameter layer_param;
+
+    vector<int> sz;
+    sz.push_back(5);
+    sz.push_back(4);
+    sz.push_back(3);
+    sz.push_back(2);
+    blob_bottom_->Reshape(sz);
+    for (int i = 0; i < blob_bottom_->count(); ++i) {
+      blob_bottom_->mutable_cpu_data()[i] = i;
+    }
+
+    vector<int> permsz;
+    permsz.push_back(6);
+    blob_bottom_permute_->Reshape(permsz);
+    int perm[] = { 4, 0, 4, 0, 1, 2 };
+    for (int i = 0; i < blob_bottom_permute_->count(); ++i) {
+      blob_bottom_permute_->mutable_cpu_data()[i] = perm[i];
+    }
+    BatchReindexLayer<Dtype> layer(layer_param);
+    layer.SetUp(blob_bottom_vec_, blob_top_vec_);
+    EXPECT_EQ(blob_top_->num(), blob_bottom_permute_->num());
+    EXPECT_EQ(blob_top_->channels(), blob_bottom_->channels());
+    EXPECT_EQ(blob_top_->height(), blob_bottom_->height());
+    EXPECT_EQ(blob_top_->width(), blob_bottom_->width());
+
+    layer.Forward(blob_bottom_vec_, blob_top_vec_);
+    int channels = blob_top_->channels();
+    int height = blob_top_->height();
+    int width = blob_top_->width();
+    for (int i = 0; i < blob_top_->count(); ++i) {
+      int n = i / (channels * width * height);
+      int inner_idx = (i % (channels * width * height));
+      EXPECT_EQ(
+          blob_top_->cpu_data()[i],
+          blob_bottom_->cpu_data()[perm[n] * channels * width * height
+              + inner_idx]);
+    }
+  }
+};
+
+TYPED_TEST_CASE(BatchReindexLayerTest, TestDtypesAndDevices);
+
+TYPED_TEST(BatchReindexLayerTest, TestForward) {
+  this->TestForward();
+}
+
+TYPED_TEST(BatchReindexLayerTest, TestGradient) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  BatchReindexLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-4, 1e-2);
+  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_, 0);
+  }
+
+}  // namespace caffe