test net reshaping
authorJonathan L Long <jonlong@cs.berkeley.edu>
Wed, 2 Jul 2014 19:21:10 +0000 (12:21 -0700)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Thu, 18 Sep 2014 20:17:43 +0000 (13:17 -0700)
src/caffe/test/test_net.cpp

index 24f72ae..9b10d10 100644 (file)
@@ -7,6 +7,7 @@
 #include "gtest/gtest.h"
 
 #include "caffe/common.hpp"
+#include "caffe/filler.hpp"
 #include "caffe/net.hpp"
 #include "caffe/util/math_functions.hpp"
 
@@ -533,6 +534,68 @@ class NetTest : public MultiDeviceTest<TypeParam> {
     InitNetFromProtoString(proto);
   }
 
+  virtual void InitReshapableNet() {
+    const string& proto =
+        "name: 'ReshapableNetwork' "
+        "input: 'data' "
+        "input_dim: 1 "
+        "input_dim: 3 "
+        "input_dim: 100 "
+        "input_dim: 100 "
+        "layers: { "
+        "  name: 'conv1' "
+        "  type: CONVOLUTION "
+        "  bottom: 'data' "
+        "  top: 'conv1' "
+        "  convolution_param { "
+        "    num_output: 5 "
+        "    kernel_size: 3 "
+        "    stride: 2 "
+        "    weight_filler { "
+        "      type: 'gaussian' "
+        "      std: 0.01 "
+        "    } "
+        "    bias_filler { "
+        "      type: 'constant' "
+        "      value: 0.2 "
+        "    } "
+        "  } "
+        "} "
+        "layers: { "
+        "  name: 'relu1' "
+        "  type: RELU "
+        "  bottom: 'conv1' "
+        "  top: 'conv1' "
+        "} "
+        "layers: { "
+        "  name: 'pool1' "
+        "  type: POOLING "
+        "  bottom: 'conv1' "
+        "  top: 'pool1' "
+        "  pooling_param { "
+        "    pool: MAX "
+        "    kernel_size: 2 "
+        "    stride: 2 "
+        "  } "
+        "} "
+        "layers: { "
+        "  name: 'norm1' "
+        "  type: LRN "
+        "  bottom: 'pool1' "
+        "  top: 'norm1' "
+        "  lrn_param { "
+        "    local_size: 3 "
+        "  } "
+        "} "
+        "layers: { "
+        "  name: 'softmax' "
+        "  type: SOFTMAX "
+        "  bottom: 'norm1' "
+        "  top: 'softmax' "
+        "} ";
+    InitNetFromProtoString(proto);
+  }
+
   int seed_;
   shared_ptr<Net<Dtype> > net_;
 };
@@ -2028,4 +2091,62 @@ TEST_F(FilterNetTest, TestFilterInOutByExcludeMultiRule) {
   this->RunFilterNetTest(input_proto_test, output_proto_test);
 }
 
+TYPED_TEST(NetTest, TestReshape) {
+  typedef typename TypeParam::Dtype Dtype;
+  // We set up bottom blobs of two different sizes, switch between
+  // them, and check that forward and backward both run and the results
+  // are the same.
+  Caffe::set_random_seed(this->seed_);
+  Caffe::set_mode(Caffe::CPU);
+  FillerParameter filler_param;
+  filler_param.set_std(1);
+  GaussianFiller<Dtype> filler(filler_param);
+  Blob<Dtype> blob1(4, 3, 9, 11);
+  Blob<Dtype> blob2(2, 3, 12, 10);
+  filler.Fill(&blob1);
+  filler.Fill(&blob2);
+
+  this->InitReshapableNet();
+  Blob<Dtype>* input_blob = this->net_->input_blobs()[0];
+  Blob<Dtype>* output_blob = this->net_->output_blobs()[0];
+  input_blob->Reshape(blob1.num(), blob1.channels(), blob1.height(),
+      blob1.width());
+  caffe_copy(blob1.count(), blob1.cpu_data(), input_blob->mutable_cpu_data());
+  this->net_->ForwardPrefilled();
+  // call backward just to make sure it runs
+  this->net_->Backward();
+  Blob<Dtype> output1(output_blob->num(), output_blob->channels(),
+      output_blob->height(), output_blob->width());
+  caffe_copy(output1.count(), output_blob->cpu_data(),
+      output1.mutable_cpu_data());
+
+  input_blob->Reshape(blob2.num(), blob2.channels(), blob2.height(),
+      blob2.width());
+  caffe_copy(blob2.count(), blob2.cpu_data(), input_blob->mutable_cpu_data());
+  this->net_->ForwardPrefilled();
+  this->net_->Backward();
+  Blob<Dtype> output2(output_blob->num(), output_blob->channels(),
+      output_blob->height(), output_blob->width());
+  caffe_copy(output2.count(), output_blob->cpu_data(),
+      output2.mutable_cpu_data());
+
+  input_blob->Reshape(blob1.num(), blob1.channels(), blob1.height(),
+      blob1.width());
+  caffe_copy(blob1.count(), blob1.cpu_data(), input_blob->mutable_cpu_data());
+  this->net_->ForwardPrefilled();
+  this->net_->Backward();
+  for (int i = 0; i < output1.count(); ++i) {
+    CHECK_EQ(*(output1.cpu_data() + i), *(output_blob->cpu_data() + i));
+  }
+
+  input_blob->Reshape(blob2.num(), blob2.channels(), blob2.height(),
+      blob2.width());
+  caffe_copy(blob2.count(), blob2.cpu_data(), input_blob->mutable_cpu_data());
+  this->net_->ForwardPrefilled();
+  this->net_->Backward();
+  for (int i = 0; i < output2.count(); ++i) {
+    CHECK_EQ(*(output2.cpu_data() + i), *(output_blob->cpu_data() + i));
+  }
+}
+
 }  // namespace caffe