Add ReductionLayer to reduce any number of "tail" axes to a scalar value
authorJeff Donahue <jeff.donahue@gmail.com>
Mon, 3 Nov 2014 01:21:37 +0000 (17:21 -0800)
committerJeff Donahue <jeff.donahue@gmail.com>
Wed, 3 Jun 2015 03:47:29 +0000 (20:47 -0700)
Currently implements operations SUM, MEAN, ASUM (sum of absolute
values), and SUMSQ (sum of squares)

include/caffe/common_layers.hpp
src/caffe/layers/reduction_layer.cpp [new file with mode: 0644]
src/caffe/layers/reduction_layer.cu [new file with mode: 0644]
src/caffe/proto/caffe.proto
src/caffe/test/test_reduction_layer.cpp [new file with mode: 0644]

index 3155b45..d2c0ce6 100644 (file)
@@ -400,6 +400,51 @@ class ReshapeLayer : public Layer<Dtype> {
 };
 
 /**
+ * @brief Compute "reductions" -- operations that return a scalar output Blob
+ *        for an input Blob of arbitrary size, such as the sum, absolute sum,
+ *        and sum of squares.
+ *
+ * TODO(dox): thorough documentation for Forward, Backward, and proto params.
+ */
+template <typename Dtype>
+class ReductionLayer : public Layer<Dtype> {
+ public:
+  explicit ReductionLayer(const LayerParameter& param)
+      : Layer<Dtype>(param) {}
+  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+
+  virtual inline const char* type() const { return "Reduction"; }
+  virtual inline int ExactNumBottomBlobs() const { return 1; }
+  virtual inline int ExactNumTopBlobs() const { return 1; }
+
+ protected:
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+
+  /// @brief the reduction operation performed by the layer
+  ReductionParameter_ReductionOp op_;
+  /// @brief a scalar coefficient applied to all outputs
+  Dtype coeff_;
+  /// @brief the index of the first input axis to reduce
+  int axis_;
+  /// @brief the number of reductions performed
+  int num_;
+  /// @brief the input size of each reduction
+  int dim_;
+  /// @brief a helper Blob used for summation (op_ == SUM)
+  Blob<Dtype> sum_multiplier_;
+};
+
+/**
  * @brief Ignores bottom blobs while producing no top blobs. (This is useful
  *        to suppress outputs during testing.)
  */
diff --git a/src/caffe/layers/reduction_layer.cpp b/src/caffe/layers/reduction_layer.cpp
new file mode 100644 (file)
index 0000000..8ae6329
--- /dev/null
@@ -0,0 +1,132 @@
+#include <algorithm>
+#include <cfloat>
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void ReductionLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top) {
+  op_ = this->layer_param_.reduction_param().operation();
+}
+
+template <typename Dtype>
+void ReductionLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top) {
+  axis_ = bottom[0]->CanonicalAxisIndex(
+      this->layer_param_.reduction_param().axis());
+  // In the output, we'll keep all axes up to the reduction axis, but
+  // throw away any after that.
+  // Note: currently reducing along non-tail axes is not supported; otherwise,
+  // we'd need to also copy any axes following an "end_axis".
+  vector<int> top_shape(bottom[0]->shape().begin(),
+                        bottom[0]->shape().begin() + axis_);
+  top[0]->Reshape(top_shape);
+  num_ = bottom[0]->count(0, axis_);
+  dim_ = bottom[0]->count(axis_);
+  CHECK_EQ(num_, top[0]->count());
+  if (op_ == ReductionParameter_ReductionOp_SUM ||
+      op_ == ReductionParameter_ReductionOp_MEAN) {
+    vector<int> sum_mult_shape(1, dim_);
+    sum_multiplier_.Reshape(sum_mult_shape);
+    caffe_set(dim_, Dtype(1), sum_multiplier_.mutable_cpu_data());
+  }
+  coeff_ = this->layer_param().reduction_param().coeff();
+  if (op_ == ReductionParameter_ReductionOp_MEAN) {
+    coeff_ /= dim_;
+  }
+}
+
+template <typename Dtype>
+void ReductionLayer<Dtype>::Forward_cpu(
+    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  const Dtype* mult_data = NULL;
+  if (sum_multiplier_.count() > 0) {
+    mult_data = sum_multiplier_.cpu_data();
+  }
+  Dtype* top_data = top[0]->mutable_cpu_data();
+  for (int i = 0; i < num_; ++i) {
+    switch (op_) {
+    case ReductionParameter_ReductionOp_SUM:
+    case ReductionParameter_ReductionOp_MEAN:
+      *top_data = caffe_cpu_dot(dim_, mult_data, bottom_data);
+      break;
+    case ReductionParameter_ReductionOp_ASUM:
+      *top_data = caffe_cpu_asum(dim_, bottom_data);
+      break;
+    case ReductionParameter_ReductionOp_SUMSQ:
+      *top_data = caffe_cpu_dot(dim_, bottom_data, bottom_data);
+      break;
+    default:
+      LOG(FATAL) << "Unknown reduction op: "
+          << ReductionParameter_ReductionOp_Name(op_);
+    }
+    bottom_data += dim_;
+    ++top_data;
+  }
+  if (coeff_ != Dtype(1)) {
+    // Reset the top_data pointer.
+    top_data = top[0]->mutable_cpu_data();
+    caffe_scal(num_, coeff_, top_data);
+  }
+}
+
+template <typename Dtype>
+void ReductionLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+  if (!propagate_down[0]) { return; }
+  // Get bottom_data, if needed.
+  const Dtype* bottom_data = NULL;
+  switch (op_) {
+  // Operations that don't need bottom_data
+  case ReductionParameter_ReductionOp_SUM:
+  case ReductionParameter_ReductionOp_MEAN:
+    break;
+  // Operations that need bottom_data
+  case ReductionParameter_ReductionOp_ASUM:
+  case ReductionParameter_ReductionOp_SUMSQ:
+    bottom_data = bottom[0]->cpu_data();
+    break;
+  default:
+    LOG(FATAL) << "Unknown reduction op: "
+        << ReductionParameter_ReductionOp_Name(op_);
+  }
+  const Dtype* top_diff = top[0]->cpu_diff();
+  Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
+  for (int i = 0; i < num_; ++i) {
+    const Dtype bottom_coeff = (*top_diff) * coeff_;
+    switch (op_) {
+    case ReductionParameter_ReductionOp_SUM:
+    case ReductionParameter_ReductionOp_MEAN:
+      caffe_set(dim_, bottom_coeff, bottom_diff);
+      break;
+    case ReductionParameter_ReductionOp_ASUM:
+      caffe_cpu_sign(dim_, bottom_data, bottom_diff);
+      caffe_scal(dim_, bottom_coeff, bottom_diff);
+      break;
+    case ReductionParameter_ReductionOp_SUMSQ:
+      caffe_cpu_scale(dim_, 2 * bottom_coeff, bottom_data, bottom_diff);
+      break;
+    default:
+      LOG(FATAL) << "Unknown reduction op: "
+          << ReductionParameter_ReductionOp_Name(op_);
+    }
+    bottom_data += dim_;
+    bottom_diff += dim_;
+    ++top_diff;
+  }
+}
+
+#ifdef CPU_ONLY
+STUB_GPU(ReductionLayer);
+#endif
+
+INSTANTIATE_CLASS(ReductionLayer);
+REGISTER_LAYER_CLASS(Reduction);
+
+}  // namespace caffe
diff --git a/src/caffe/layers/reduction_layer.cu b/src/caffe/layers/reduction_layer.cu
new file mode 100644 (file)
index 0000000..2dbd3bc
--- /dev/null
@@ -0,0 +1,93 @@
+#include <cfloat>
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void ReductionLayer<Dtype>::Forward_gpu(
+    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  const Dtype* mult_data = NULL;
+  if (sum_multiplier_.count() > 0) {
+    mult_data = sum_multiplier_.gpu_data();
+  }
+  Dtype* top_data = top[0]->mutable_cpu_data();
+  for (int i = 0; i < num_; ++i) {
+    switch (op_) {
+    case ReductionParameter_ReductionOp_SUM:
+    case ReductionParameter_ReductionOp_MEAN:
+      caffe_gpu_dot(dim_, mult_data, bottom_data, top_data);
+      break;
+    case ReductionParameter_ReductionOp_ASUM:
+      caffe_gpu_asum(dim_, bottom_data, top_data);
+      break;
+    case ReductionParameter_ReductionOp_SUMSQ:
+      caffe_gpu_dot(dim_, bottom_data, bottom_data, top_data);
+      break;
+    default:
+      LOG(FATAL) << "Unknown reduction op: "
+          << ReductionParameter_ReductionOp_Name(op_);
+    }
+    bottom_data += dim_;
+    ++top_data;
+  }
+  if (coeff_ != Dtype(1)) {
+    // Reset the top_data pointer.
+    top_data = top[0]->mutable_gpu_data();
+    caffe_gpu_scal(num_, coeff_, top_data);
+  }
+}
+
+template <typename Dtype>
+void ReductionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+  if (!propagate_down[0]) { return; }
+  // Get bottom_data, if needed.
+  const Dtype* bottom_data = NULL;
+  switch (op_) {
+  // Operations that don't need bottom_data
+  case ReductionParameter_ReductionOp_SUM:
+  case ReductionParameter_ReductionOp_MEAN:
+    break;
+  // Operations that need bottom_data
+  case ReductionParameter_ReductionOp_ASUM:
+  case ReductionParameter_ReductionOp_SUMSQ:
+    bottom_data = bottom[0]->gpu_data();
+    break;
+  default:
+    LOG(FATAL) << "Unknown reduction op: "
+        << ReductionParameter_ReductionOp_Name(op_);
+  }
+  const Dtype* top_diff = top[0]->cpu_diff();
+  Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+  for (int i = 0; i < num_; ++i) {
+    const Dtype bottom_coeff = (*top_diff) * coeff_;
+    switch (op_) {
+    case ReductionParameter_ReductionOp_SUM:
+    case ReductionParameter_ReductionOp_MEAN:
+      caffe_gpu_set(dim_, bottom_coeff, bottom_diff);
+      break;
+    case ReductionParameter_ReductionOp_ASUM:
+      caffe_gpu_sign(dim_, bottom_data, bottom_diff);
+      caffe_gpu_scal(dim_, bottom_coeff, bottom_diff);
+      break;
+    case ReductionParameter_ReductionOp_SUMSQ:
+      caffe_gpu_scale(dim_, 2 * bottom_coeff, bottom_data, bottom_diff);
+      break;
+    default:
+      LOG(FATAL) << "Unknown reduction op: "
+          << ReductionParameter_ReductionOp_Name(op_);
+    }
+    bottom_data += dim_;
+    bottom_diff += dim_;
+    ++top_diff;
+  }
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(ReductionLayer);
+
+}  // namespace caffe
index f79cf80..81a8c69 100644 (file)
@@ -269,7 +269,7 @@ message ParamSpec {
 // NOTE
 // Update the next available ID when you add a new LayerParameter field.
 //
-// LayerParameter next available layer-specific ID: 136 (last added: flatten_param)
+// LayerParameter next available layer-specific ID: 137 (last added: reduction_param)
 message LayerParameter {
   optional string name = 1; // the layer name
   optional string type = 2; // the layer type
@@ -341,6 +341,7 @@ message LayerParameter {
   optional PowerParameter power_param = 122;
   optional PReLUParameter prelu_param = 131;
   optional PythonParameter python_param = 130;
+  optional ReductionParameter reduction_param = 136;
   optional ReLUParameter relu_param = 123;
   optional ReshapeParameter reshape_param = 133;
   optional SigmoidParameter sigmoid_param = 124;
@@ -704,6 +705,36 @@ message PythonParameter {
   optional string layer = 2;
 }
 
+// Message that stores parameters used by ReductionLayer
+message ReductionParameter {
+  enum ReductionOp {
+    SUM = 1;
+    ASUM = 2;
+    SUMSQ = 3;
+    MEAN = 4;
+  }
+
+  optional ReductionOp operation = 1 [default = SUM]; // reduction operation
+
+  // The first axis to reduce to a scalar -- may be negative to index from the
+  // end (e.g., -1 for the last axis).
+  // (Currently, only reduction along ALL "tail" axes is supported; reduction
+  // of axis M through N, where N < num_axes - 1, is unsupported.)
+  // Suppose we have an n-axis bottom Blob with shape:
+  //     (d0, d1, d2, ..., d(m-1), dm, d(m+1), ..., d(n-1)).
+  // If axis == m, the output Blob will have shape
+  //     (d0, d1, d2, ..., d(m-1)),
+  // and the ReductionOp operation is performed (d0 * d1 * d2 * ... * d(m-1))
+  // times, each including (dm * d(m+1) * ... * d(n-1)) individual data.
+  // If axis == 0 (the default), the output Blob always has the empty shape
+  // (count 1), performing reduction across the entire input --
+  // often useful for creating new loss functions.
+  optional int32 axis = 2 [default = 0];
+
+  optional float coeff = 3 [default = 1.0]; // coefficient for output
+}
+
+// Message that stores parameters used by ReLULayer
 message ReLUParameter {
   // Allow non-zero slope for negative inputs to speed up optimization
   // Described in:
diff --git a/src/caffe/test/test_reduction_layer.cpp b/src/caffe/test/test_reduction_layer.cpp
new file mode 100644 (file)
index 0000000..f568a18
--- /dev/null
@@ -0,0 +1,297 @@
+#include <algorithm>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/vision_layers.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+#include "caffe/test/test_gradient_check_util.hpp"
+
+namespace caffe {
+
+template <typename TypeParam>
+class ReductionLayerTest : public MultiDeviceTest<TypeParam> {
+  typedef typename TypeParam::Dtype Dtype;
+
+ protected:
+  ReductionLayerTest()
+      : blob_bottom_(new Blob<Dtype>(2, 3, 4, 5)),
+        blob_top_(new Blob<Dtype>()) {
+    // fill the values
+    Caffe::set_random_seed(1701);
+    FillerParameter filler_param;
+    UniformFiller<Dtype> filler(filler_param);
+    filler.Fill(this->blob_bottom_);
+    blob_bottom_vec_.push_back(blob_bottom_);
+    blob_top_vec_.push_back(blob_top_);
+  }
+  virtual ~ReductionLayerTest() {
+    delete blob_bottom_;
+    delete blob_top_;
+  }
+
+  void TestForward(ReductionParameter_ReductionOp op,
+                   float coeff = 1, int axis = 0) {
+    LayerParameter layer_param;
+    ReductionParameter* reduction_param = layer_param.mutable_reduction_param();
+    reduction_param->set_operation(op);
+    if (coeff != 1.0) { reduction_param->set_coeff(coeff); }
+    if (axis != 0) { reduction_param->set_axis(axis); }
+    shared_ptr<ReductionLayer<Dtype> > layer(
+        new ReductionLayer<Dtype>(layer_param));
+    layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+    layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+    const Dtype* in_data = this->blob_bottom_->cpu_data();
+    const int num = this->blob_bottom_->count(0, axis);
+    const int dim = this->blob_bottom_->count(axis);
+    for (int n = 0; n < num; ++n) {
+      Dtype expected_result = 0;
+      for (int d = 0; d < dim; ++d) {
+        switch (op) {
+          case ReductionParameter_ReductionOp_SUM:
+            expected_result += *in_data;
+            break;
+          case ReductionParameter_ReductionOp_MEAN:
+            expected_result += *in_data / dim;
+            break;
+          case ReductionParameter_ReductionOp_ASUM:
+            expected_result += fabs(*in_data);
+            break;
+          case ReductionParameter_ReductionOp_SUMSQ:
+            expected_result += (*in_data) * (*in_data);
+            break;
+          default:
+            LOG(FATAL) << "Unknown reduction op: "
+                << ReductionParameter_ReductionOp_Name(op);
+        }
+        ++in_data;
+      }
+      expected_result *= coeff;
+      const Dtype computed_result = this->blob_top_->cpu_data()[n];
+      EXPECT_FLOAT_EQ(expected_result, computed_result)
+          << "Incorrect result computed with op "
+          << ReductionParameter_ReductionOp_Name(op) << ", coeff " << coeff;
+    }
+  }
+
+  void TestGradient(ReductionParameter_ReductionOp op,
+                    float coeff = 1, int axis = 0) {
+    typedef typename TypeParam::Dtype Dtype;
+    LayerParameter layer_param;
+    ReductionParameter* reduction_param = layer_param.mutable_reduction_param();
+    reduction_param->set_operation(op);
+    reduction_param->set_coeff(coeff);
+    reduction_param->set_axis(axis);
+    ReductionLayer<Dtype> layer(layer_param);
+    GradientChecker<Dtype> checker(1e-2, 2e-3);
+    checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+        this->blob_top_vec_);
+  }
+
+  Blob<Dtype>* const blob_bottom_;
+  Blob<Dtype>* const blob_top_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+TYPED_TEST_CASE(ReductionLayerTest, TestDtypesAndDevices);
+
+TYPED_TEST(ReductionLayerTest, TestSetUp) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  shared_ptr<ReductionLayer<Dtype> > layer(
+      new ReductionLayer<Dtype>(layer_param));
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  ASSERT_EQ(this->blob_top_->num_axes(), 0);
+}
+
+TYPED_TEST(ReductionLayerTest, TestSetUpWithAxis1) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  layer_param.mutable_reduction_param()->set_axis(1);
+  shared_ptr<ReductionLayer<Dtype> > layer(
+      new ReductionLayer<Dtype>(layer_param));
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  ASSERT_EQ(this->blob_top_->num_axes(), 1);
+  EXPECT_EQ(this->blob_top_->shape(0), 2);
+}
+
+TYPED_TEST(ReductionLayerTest, TestSetUpWithAxis2) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  layer_param.mutable_reduction_param()->set_axis(2);
+  shared_ptr<ReductionLayer<Dtype> > layer(
+      new ReductionLayer<Dtype>(layer_param));
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  ASSERT_EQ(this->blob_top_->num_axes(), 2);
+  EXPECT_EQ(this->blob_top_->shape(0), 2);
+  EXPECT_EQ(this->blob_top_->shape(1), 3);
+}
+
+TYPED_TEST(ReductionLayerTest, TestSum) {
+  const ReductionParameter_ReductionOp kOp = ReductionParameter_ReductionOp_SUM;
+  this->TestForward(kOp);
+}
+
+TYPED_TEST(ReductionLayerTest, TestSumCoeff) {
+  const ReductionParameter_ReductionOp kOp = ReductionParameter_ReductionOp_SUM;
+  const float kCoeff = 2.3;
+  this->TestForward(kOp, kCoeff);
+}
+
+TYPED_TEST(ReductionLayerTest, TestSumCoeffAxis1) {
+  const ReductionParameter_ReductionOp kOp = ReductionParameter_ReductionOp_SUM;
+  const float kCoeff = 2.3;
+  const int kAxis = 1;
+  this->TestForward(kOp, kCoeff, kAxis);
+}
+
+TYPED_TEST(ReductionLayerTest, TestSumGradient) {
+  const ReductionParameter_ReductionOp kOp = ReductionParameter_ReductionOp_SUM;
+  this->TestGradient(kOp);
+}
+
+TYPED_TEST(ReductionLayerTest, TestSumCoeffGradient) {
+  const ReductionParameter_ReductionOp kOp = ReductionParameter_ReductionOp_SUM;
+  const float kCoeff = 2.3;
+  this->TestGradient(kOp, kCoeff);
+}
+
+TYPED_TEST(ReductionLayerTest, TestSumCoeffAxis1Gradient) {
+  const ReductionParameter_ReductionOp kOp = ReductionParameter_ReductionOp_SUM;
+  const float kCoeff = 2.3;
+  const int kAxis = 1;
+  this->TestGradient(kOp, kCoeff, kAxis);
+}
+
+TYPED_TEST(ReductionLayerTest, TestMean) {
+  const ReductionParameter_ReductionOp kOp =
+      ReductionParameter_ReductionOp_MEAN;
+  this->TestForward(kOp);
+}
+
+TYPED_TEST(ReductionLayerTest, TestMeanCoeff) {
+  const ReductionParameter_ReductionOp kOp =
+      ReductionParameter_ReductionOp_MEAN;
+  const float kCoeff = 2.3;
+  this->TestForward(kOp, kCoeff);
+}
+
+TYPED_TEST(ReductionLayerTest, TestMeanCoeffAxis1) {
+  const ReductionParameter_ReductionOp kOp =
+      ReductionParameter_ReductionOp_MEAN;
+  const float kCoeff = 2.3;
+  const int kAxis = 1;
+  this->TestForward(kOp, kCoeff, kAxis);
+}
+
+TYPED_TEST(ReductionLayerTest, TestMeanGradient) {
+  const ReductionParameter_ReductionOp kOp =
+      ReductionParameter_ReductionOp_MEAN;
+  this->TestGradient(kOp);
+}
+
+TYPED_TEST(ReductionLayerTest, TestMeanCoeffGradient) {
+  const ReductionParameter_ReductionOp kOp =
+      ReductionParameter_ReductionOp_MEAN;
+  const float kCoeff = 2.3;
+  this->TestGradient(kOp, kCoeff);
+}
+
+TYPED_TEST(ReductionLayerTest, TestMeanCoeffGradientAxis1) {
+  const ReductionParameter_ReductionOp kOp =
+      ReductionParameter_ReductionOp_MEAN;
+  const float kCoeff = 2.3;
+  const int kAxis = 1;
+  this->TestGradient(kOp, kCoeff, kAxis);
+}
+
+TYPED_TEST(ReductionLayerTest, TestAbsSum) {
+  const ReductionParameter_ReductionOp kOp =
+      ReductionParameter_ReductionOp_ASUM;
+  this->TestForward(kOp);
+}
+
+TYPED_TEST(ReductionLayerTest, TestAbsSumCoeff) {
+  const ReductionParameter_ReductionOp kOp =
+      ReductionParameter_ReductionOp_ASUM;
+  const float kCoeff = 2.3;
+  this->TestForward(kOp, kCoeff);
+}
+
+TYPED_TEST(ReductionLayerTest, TestAbsSumCoeffAxis1) {
+  const ReductionParameter_ReductionOp kOp =
+      ReductionParameter_ReductionOp_ASUM;
+  const float kCoeff = 2.3;
+  const int kAxis = 1;
+  this->TestForward(kOp, kCoeff, kAxis);
+}
+
+TYPED_TEST(ReductionLayerTest, TestAbsSumGradient) {
+  const ReductionParameter_ReductionOp kOp =
+      ReductionParameter_ReductionOp_ASUM;
+  this->TestGradient(kOp);
+}
+
+TYPED_TEST(ReductionLayerTest, TestAbsSumCoeffGradient) {
+  const ReductionParameter_ReductionOp kOp =
+      ReductionParameter_ReductionOp_ASUM;
+  const float kCoeff = 2.3;
+  this->TestGradient(kOp, kCoeff);
+}
+
+TYPED_TEST(ReductionLayerTest, TestAbsSumCoeffAxis1Gradient) {
+  const ReductionParameter_ReductionOp kOp =
+      ReductionParameter_ReductionOp_ASUM;
+  const float kCoeff = 2.3;
+  const int kAxis = 1;
+  this->TestGradient(kOp, kCoeff, kAxis);
+}
+
+TYPED_TEST(ReductionLayerTest, TestSumOfSquares) {
+  const ReductionParameter_ReductionOp kOp =
+      ReductionParameter_ReductionOp_SUMSQ;
+  this->TestForward(kOp);
+}
+
+TYPED_TEST(ReductionLayerTest, TestSumOfSquaresCoeff) {
+  const ReductionParameter_ReductionOp kOp =
+      ReductionParameter_ReductionOp_SUMSQ;
+  const float kCoeff = 2.3;
+  this->TestForward(kOp, kCoeff);
+}
+
+TYPED_TEST(ReductionLayerTest, TestSumOfSquaresCoeffAxis1) {
+  const ReductionParameter_ReductionOp kOp =
+      ReductionParameter_ReductionOp_SUMSQ;
+  const float kCoeff = 2.3;
+  const int kAxis = 1;
+  this->TestForward(kOp, kCoeff, kAxis);
+}
+
+TYPED_TEST(ReductionLayerTest, TestSumOfSquaresGradient) {
+  const ReductionParameter_ReductionOp kOp =
+      ReductionParameter_ReductionOp_SUMSQ;
+  this->TestGradient(kOp);
+}
+
+TYPED_TEST(ReductionLayerTest, TestSumOfSquaresCoeffGradient) {
+  const ReductionParameter_ReductionOp kOp =
+      ReductionParameter_ReductionOp_SUMSQ;
+  const float kCoeff = 2.3;
+  this->TestGradient(kOp, kCoeff);
+}
+
+TYPED_TEST(ReductionLayerTest, TestSumOfSquaresCoeffAxis1Gradient) {
+  const ReductionParameter_ReductionOp kOp =
+      ReductionParameter_ReductionOp_SUMSQ;
+  const float kCoeff = 2.3;
+  const int kAxis = 1;
+  this->TestGradient(kOp, kCoeff, kAxis);
+}
+
+}  // namespace caffe