give phase to Net and Layer
authorEvan Shelhamer <shelhamer@imaginarynumber.net>
Fri, 23 Jan 2015 08:31:28 +0000 (00:31 -0800)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Tue, 17 Feb 2015 19:35:50 +0000 (11:35 -0800)
Give the responsibility for phase to Net and Layer, making phase an
immutable choice at instantiation and dropping it from the Caffe
singleton.

20 files changed:
include/caffe/common.hpp
include/caffe/data_layers.hpp
include/caffe/layer.hpp
include/caffe/loss_layers.hpp
include/caffe/net.hpp
include/caffe/neuron_layers.hpp
src/caffe/common.cpp
src/caffe/layers/base_data_layer.cpp
src/caffe/layers/dropout_layer.cpp
src/caffe/layers/dropout_layer.cu
src/caffe/layers/pooling_layer.cu
src/caffe/net.cpp
src/caffe/proto/caffe.proto
src/caffe/test/test_common.cpp
src/caffe/test/test_data_layer.cpp
src/caffe/test/test_data_transformer.cpp
src/caffe/test/test_maxpool_dropout_layers.cpp
src/caffe/test/test_net.cpp
src/caffe/test/test_neuron_layer.cpp
src/caffe/test/test_stochastic_pooling.cpp

index 5fc4ed3..890673c 100644 (file)
@@ -104,8 +104,6 @@ class Caffe {
     return *singleton_;
   }
   enum Brew { CPU, GPU };
-  enum Phase { TRAIN, TEST };
-
 
   // This random number generator facade hides boost and CUDA rng
   // implementation from one another (for cross-platform compatibility).
@@ -137,16 +135,12 @@ class Caffe {
 
   // Returns the mode: running on CPU or GPU.
   inline static Brew mode() { return Get().mode_; }
-  // Returns the phase: TRAIN or TEST.
-  inline static Phase phase() { return Get().phase_; }
   // The setters for the variables
   // Sets the mode. It is recommended that you don't change the mode halfway
   // into the program since that may cause allocation of pinned memory being
   // freed in a non-pinned way, which may cause problems - I haven't verified
   // it personally but better to note it here in the header file.
   inline static void set_mode(Brew mode) { Get().mode_ = mode; }
-  // Sets the phase.
-  inline static void set_phase(Phase phase) { Get().phase_ = phase; }
   // Sets the random seed of both boost and curand
   static void set_random_seed(const unsigned int seed);
   // Sets the device. Since we have cublas and curand stuff, set device also
@@ -163,7 +157,6 @@ class Caffe {
   shared_ptr<RNG> random_generator_;
 
   Brew mode_;
-  Phase phase_;
   static shared_ptr<Caffe> singleton_;
 
  private:
index e0d5a8a..aa7e696 100644 (file)
@@ -14,6 +14,7 @@
 #include "caffe/filler.hpp"
 #include "caffe/internal_thread.hpp"
 #include "caffe/layer.hpp"
+#include "caffe/net.hpp"
 #include "caffe/proto/caffe.pb.h"
 #include "caffe/util/db.hpp"
 
index c6461c1..34e00d7 100644 (file)
@@ -33,7 +33,8 @@ class Layer {
    */
   explicit Layer(const LayerParameter& param)
     : layer_param_(param) {
-      // The only thing we do is to copy blobs if there are any.
+      // Set phase and copy blobs (if there are any).
+      phase_ = param.phase();
       if (layer_param_.blobs_size() > 0) {
         blobs_.resize(layer_param_.blobs_size());
         for (int i = 0; i < layer_param_.blobs_size(); ++i) {
@@ -288,6 +289,8 @@ class Layer {
  protected:
   /** The protobuf that stores the layer parameters */
   LayerParameter layer_param_;
+  /** The phase: TRAIN or TEST */
+  Phase phase_;
   /** The vector that stores the learnable parameters as a set of blobs. */
   vector<shared_ptr<Blob<Dtype> > > blobs_;
   /** Vector indicating whether to compute the diff of each param blob. */
index ea52f56..36413cc 100644 (file)
@@ -169,8 +169,8 @@ class ContrastiveLossLayer : public LossLayer<Dtype> {
 
   /**
    * @brief Computes the Contrastive error gradient w.r.t. the inputs.
-   * 
-   * Computes the gradients with respect to the two input vectors (bottom[0] and 
+   *
+   * Computes the gradients with respect to the two input vectors (bottom[0] and
    * bottom[1]), but not the similarity label (bottom[2]).
    *
    * @param top output Blob vector (length 1), providing the error gradient with
@@ -189,7 +189,7 @@ class ContrastiveLossLayer : public LossLayer<Dtype> {
    *      the features @f$a@f$; Backward fills their diff with
    *      gradients if propagate_down[0]
    *   -# @f$ (N \times C \times 1 \times 1) @f$
-   *      the features @f$b@f$; Backward fills their diff with gradients if 
+   *      the features @f$b@f$; Backward fills their diff with gradients if
    *      propagate_down[1]
    */
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
index 23da5db..153b881 100644 (file)
@@ -115,6 +115,8 @@ class Net {
   inline const vector<shared_ptr<Layer<Dtype> > >& layers() const {
     return layers_;
   }
+  /// @brief returns the phase: TRAIN or TEST
+  inline Phase phase() const { return phase_; }
   /**
    * @brief returns the bottom vecs for each layer -- usually you won't
    *        need this unless you do per-layer checks such as gradients.
@@ -207,6 +209,10 @@ class Net {
   /// @brief Get misc parameters, e.g. the LR multiplier and weight decay.
   void GetLearningRateAndWeightDecay();
 
+  /// @brief The network name
+  string name_;
+  /// @brief The phase: TRAIN or TEST
+  Phase phase_;
   /// @brief Individual layers in the net
   vector<shared_ptr<Layer<Dtype> > > layers_;
   vector<string> layer_names_;
@@ -239,7 +245,6 @@ class Net {
   vector<int> net_output_blob_indices_;
   vector<Blob<Dtype>*> net_input_blobs_;
   vector<Blob<Dtype>*> net_output_blobs_;
-  string name_;
   /// The parameters in the network.
   vector<shared_ptr<Blob<Dtype> > > params_;
   /// the learning rate multipliers
index 46c8417..0c306fb 100644 (file)
@@ -8,6 +8,7 @@
 #include "caffe/blob.hpp"
 #include "caffe/common.hpp"
 #include "caffe/layer.hpp"
+#include "caffe/net.hpp"
 #include "caffe/proto/caffe.pb.h"
 
 #define HDF5_DATA_DATASET_NAME "data"
index 834d569..af96cac 100644 (file)
@@ -42,7 +42,7 @@ void GlobalInit(int* pargc, char*** pargv) {
 #ifdef CPU_ONLY  // CPU-only Caffe.
 
 Caffe::Caffe()
-    : random_generator_(), mode_(Caffe::CPU), phase_(Caffe::TRAIN) { }
+    : random_generator_(), mode_(Caffe::CPU) { }
 
 Caffe::~Caffe() { }
 
@@ -86,7 +86,7 @@ void* Caffe::RNG::generator() {
 
 Caffe::Caffe()
     : cublas_handle_(NULL), curand_generator_(NULL), random_generator_(),
-    mode_(Caffe::CPU), phase_(Caffe::TRAIN) {
+    mode_(Caffe::CPU) {
   // Try to create a cublas handler, and report an error if failed (but we will
   // keep the program running as one might just want to run CPU code).
   if (cublasCreate(&cublas_handle_) != CUBLAS_STATUS_SUCCESS) {
index c3b9bc4..d7e0b90 100644 (file)
@@ -23,7 +23,7 @@ void BaseDataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
   }
   // The subclasses should setup the size of bottom and top
   DataLayerSetUp(bottom, top);
-  data_transformer_.InitRand();
+  data_transformer_->InitRand();
 }
 
 template <typename Dtype>
index 5f81cc1..ec1256f 100644 (file)
@@ -37,7 +37,7 @@ void DropoutLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   Dtype* top_data = top[0]->mutable_cpu_data();
   unsigned int* mask = rand_vec_.mutable_cpu_data();
   const int count = bottom[0]->count();
-  if (Caffe::phase() == Caffe::TRAIN) {
+  if (this->phase_ == TRAIN) {
     // Create random numbers
     caffe_rng_bernoulli(count, 1. - threshold_, mask);
     for (int i = 0; i < count; ++i) {
@@ -55,7 +55,7 @@ void DropoutLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
   if (propagate_down[0]) {
     const Dtype* top_diff = top[0]->cpu_diff();
     Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
-    if (Caffe::phase() == Caffe::TRAIN) {
+    if (this->phase_ == TRAIN) {
       const unsigned int* mask = rand_vec_.cpu_data();
       const int count = bottom[0]->count();
       for (int i = 0; i < count; ++i) {
index df13d8e..f9ea04f 100644 (file)
@@ -26,7 +26,7 @@ void DropoutLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = top[0]->mutable_gpu_data();
   const int count = bottom[0]->count();
-  if (Caffe::phase() == Caffe::TRAIN) {
+  if (this->phase_ == TRAIN) {
     unsigned int* mask =
         static_cast<unsigned int*>(rand_vec_.mutable_gpu_data());
     caffe_gpu_rng_uniform(count, mask);
@@ -56,7 +56,7 @@ void DropoutLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   if (propagate_down[0]) {
     const Dtype* top_diff = top[0]->gpu_diff();
     Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
-    if (Caffe::phase() == Caffe::TRAIN) {
+    if (this->phase_ == TRAIN) {
       const unsigned int* mask =
           static_cast<const unsigned int*>(rand_vec_.gpu_data());
       const int count = bottom[0]->count();
index 0d3f218..d1d4850 100644 (file)
@@ -182,7 +182,7 @@ void PoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
         kernel_w_, stride_h_, stride_w_, pad_h_, pad_w_, top_data);
     break;
   case PoolingParameter_PoolMethod_STOCHASTIC:
-    if (Caffe::phase() == Caffe::TRAIN) {
+    if (this->phase_ == TRAIN) {
       // We need to create the random index as well.
       caffe_gpu_rng_uniform(count, Dtype(0), Dtype(1),
                             rand_idx_.mutable_gpu_data());
index 2e44911..d1ae0bc 100644 (file)
@@ -32,6 +32,8 @@ Net<Dtype>::Net(const string& param_file) {
 
 template <typename Dtype>
 void Net<Dtype>::Init(const NetParameter& in_param) {
+  // Set phase from the state.
+  phase_ = in_param.state().phase();
   // Filter layers based on their include/exclude rules and
   // the current NetState.
   NetParameter filtered_param;
@@ -62,6 +64,11 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
   top_id_vecs_.resize(param.layer_size());
   bottom_need_backward_.resize(param.layer_size());
   for (int layer_id = 0; layer_id < param.layer_size(); ++layer_id) {
+    // Inherit phase from net if unset.
+    if (!param.layer(layer_id).has_phase()) {
+      param.mutable_layer(layer_id)->set_phase(phase_);
+    }
+    // Setup layer.
     const LayerParameter& layer_param = param.layer(layer_id);
     layers_.push_back(LayerRegistry<Dtype>::CreateLayer(layer_param));
     layer_names_.push_back(layer_param.name());
@@ -210,20 +217,6 @@ template <typename Dtype>
 void Net<Dtype>::FilterNet(const NetParameter& param,
     NetParameter* param_filtered) {
   NetState net_state(param.state());
-  // Let the phase of the net be the current global phase provided in the Caffe
-  // singleton, unless explicitly provided by the state.
-  if (!net_state.has_phase()) {
-    switch (Caffe::phase()) {
-      case Caffe::TRAIN:
-        net_state.set_phase(TRAIN);
-        break;
-      case Caffe::TEST:
-        net_state.set_phase(TEST);
-        break;
-      default:
-        LOG(FATAL) << "Unknown phase: " << Caffe::phase();
-    }
-  }
   param_filtered->CopyFrom(param);
   param_filtered->clear_layer();
   for (int i = 0; i < param.layer_size(); ++i) {
index 61dd186..6bdda33 100644 (file)
@@ -253,6 +253,9 @@ message LayerParameter {
   repeated string bottom = 3; // the name of each bottom blob
   repeated string top = 4; // the name of each top blob
 
+  // The train / test phase for computation.
+  optional Phase phase = 10;
+
   // The amount of weight to assign each top blob in the objective.
   // Each layer assigns a default value, usually of either 0 or 1,
   // to each top blob.
index 0b3639c..b3a61b0 100644 (file)
@@ -29,13 +29,6 @@ TEST_F(CommonTest, TestBrewMode) {
   EXPECT_EQ(Caffe::mode(), Caffe::GPU);
 }
 
-TEST_F(CommonTest, TestPhase) {
-  Caffe::set_phase(Caffe::TRAIN);
-  EXPECT_EQ(Caffe::phase(), Caffe::TRAIN);
-  Caffe::set_phase(Caffe::TEST);
-  EXPECT_EQ(Caffe::phase(), Caffe::TEST);
-}
-
 TEST_F(CommonTest, TestRandSeedCPU) {
   SyncedMemory data_a(10 * sizeof(int));
   SyncedMemory data_b(10 * sizeof(int));
index adc99dd..afe2a40 100644 (file)
@@ -69,6 +69,7 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
   void TestRead() {
     const Dtype scale = 3;
     LayerParameter param;
+    param.set_phase(TRAIN);
     DataParameter* data_param = param.mutable_data_param();
     data_param->set_batch_size(5);
     data_param->set_source(filename_->c_str());
@@ -132,6 +133,7 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
 
     // Load and check data of various shapes.
     LayerParameter param;
+    param.set_phase(TEST);
     DataParameter* data_param = param.mutable_data_param();
     data_param->set_batch_size(1);
     data_param->set_source(filename_->c_str());
@@ -167,9 +169,10 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
     }
   }
 
-  void TestReadCrop() {
+  void TestReadCrop(Phase phase) {
     const Dtype scale = 3;
     LayerParameter param;
+    param.set_phase(phase);
     Caffe::set_random_seed(1701);
 
     DataParameter* data_param = param.mutable_data_param();
@@ -205,7 +208,7 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
           num_with_center_value +=
               (center_value == blob_top_data_->cpu_data()[i * 2 + j]);
           // At TEST time, check that we always get center value.
-          if (Caffe::phase() == Caffe::TEST) {
+          if (phase == caffe::TEST) {
             EXPECT_EQ(center_value, this->blob_top_data_->cpu_data()[i * 2 + j])
                 << "debug: iter " << iter << " i " << i << " j " << j;
           }
@@ -214,7 +217,7 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
       // At TRAIN time, check that we did not get the center crop all 10 times.
       // (This check fails with probability 1-1/12^10 in a correct
       // implementation, so we call set_random_seed.)
-      if (Caffe::phase() == Caffe::TRAIN) {
+      if (phase == caffe::TRAIN) {
         EXPECT_LT(num_with_center_value, 10);
       }
     }
@@ -222,6 +225,7 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
 
   void TestReadCropTrainSequenceSeeded() {
     LayerParameter param;
+    param.set_phase(TRAIN);
     DataParameter* data_param = param.mutable_data_param();
     data_param->set_batch_size(5);
     data_param->set_source(filename_->c_str());
@@ -276,6 +280,7 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
 
   void TestReadCropTrainSequenceUnseeded() {
     LayerParameter param;
+    param.set_phase(TRAIN);
     DataParameter* data_param = param.mutable_data_param();
     data_param->set_batch_size(5);
     data_param->set_source(filename_->c_str());
@@ -354,16 +359,14 @@ TYPED_TEST(DataLayerTest, TestReshapeLevelDB) {
 }
 
 TYPED_TEST(DataLayerTest, TestReadCropTrainLevelDB) {
-  Caffe::set_phase(Caffe::TRAIN);
   const bool unique_pixels = true;  // all images the same; pixels different
   this->Fill(unique_pixels, DataParameter_DB_LEVELDB);
-  this->TestReadCrop();
+  this->TestReadCrop(TRAIN);
 }
 
 // Test that the sequence of random crops is consistent when using
 // Caffe::set_random_seed.
 TYPED_TEST(DataLayerTest, TestReadCropTrainSequenceSeededLevelDB) {
-  Caffe::set_phase(Caffe::TRAIN);
   const bool unique_pixels = true;  // all images the same; pixels different
   this->Fill(unique_pixels, DataParameter_DB_LEVELDB);
   this->TestReadCropTrainSequenceSeeded();
@@ -372,17 +375,15 @@ TYPED_TEST(DataLayerTest, TestReadCropTrainSequenceSeededLevelDB) {
 // Test that the sequence of random crops differs across iterations when
 // Caffe::set_random_seed isn't called (and seeds from srand are ignored).
 TYPED_TEST(DataLayerTest, TestReadCropTrainSequenceUnseededLevelDB) {
-  Caffe::set_phase(Caffe::TRAIN);
   const bool unique_pixels = true;  // all images the same; pixels different
   this->Fill(unique_pixels, DataParameter_DB_LEVELDB);
   this->TestReadCropTrainSequenceUnseeded();
 }
 
 TYPED_TEST(DataLayerTest, TestReadCropTestLevelDB) {
-  Caffe::set_phase(Caffe::TEST);
   const bool unique_pixels = true;  // all images the same; pixels different
   this->Fill(unique_pixels, DataParameter_DB_LEVELDB);
-  this->TestReadCrop();
+  this->TestReadCrop(TEST);
 }
 
 TYPED_TEST(DataLayerTest, TestReadLMDB) {
@@ -396,16 +397,14 @@ TYPED_TEST(DataLayerTest, TestReshapeLMDB) {
 }
 
 TYPED_TEST(DataLayerTest, TestReadCropTrainLMDB) {
-  Caffe::set_phase(Caffe::TRAIN);
   const bool unique_pixels = true;  // all images the same; pixels different
   this->Fill(unique_pixels, DataParameter_DB_LMDB);
-  this->TestReadCrop();
+  this->TestReadCrop(TRAIN);
 }
 
 // Test that the sequence of random crops is consistent when using
 // Caffe::set_random_seed.
 TYPED_TEST(DataLayerTest, TestReadCropTrainSequenceSeededLMDB) {
-  Caffe::set_phase(Caffe::TRAIN);
   const bool unique_pixels = true;  // all images the same; pixels different
   this->Fill(unique_pixels, DataParameter_DB_LMDB);
   this->TestReadCropTrainSequenceSeeded();
@@ -414,17 +413,15 @@ TYPED_TEST(DataLayerTest, TestReadCropTrainSequenceSeededLMDB) {
 // Test that the sequence of random crops differs across iterations when
 // Caffe::set_random_seed isn't called (and seeds from srand are ignored).
 TYPED_TEST(DataLayerTest, TestReadCropTrainSequenceUnseededLMDB) {
-  Caffe::set_phase(Caffe::TRAIN);
   const bool unique_pixels = true;  // all images the same; pixels different
   this->Fill(unique_pixels, DataParameter_DB_LMDB);
   this->TestReadCropTrainSequenceUnseeded();
 }
 
 TYPED_TEST(DataLayerTest, TestReadCropTestLMDB) {
-  Caffe::set_phase(Caffe::TEST);
   const bool unique_pixels = true;  // all images the same; pixels different
   this->Fill(unique_pixels, DataParameter_DB_LMDB);
-  this->TestReadCrop();
+  this->TestReadCrop(TEST);
 }
 
 }  // namespace caffe
index 28c7241..16570e2 100644 (file)
@@ -37,10 +37,10 @@ class DataTransformTest : public ::testing::Test {
       num_iter_(10) {}
 
   int NumSequenceMatches(const TransformationParameter transform_param,
-      const Datum& datum) {
+      const Datum& datum, Phase phase) {
     // Get crop sequence with Caffe seed 1701.
     DataTransformer<Dtype>* transformer =
-        new DataTransformer<Dtype>(transform_param);
+        new DataTransformer<Dtype>(transform_param, phase);
     const int crop_size = transform_param.crop_size();
     Caffe::set_random_seed(seed_);
     transformer->InitRand();
@@ -92,7 +92,7 @@ TYPED_TEST(DataTransformTest, TestEmptyTransform) {
   FillDatum(label, channels, height, width, unique_pixels, &datum);
   Blob<TypeParam>* blob = new Blob<TypeParam>(1, channels, height, width);
   DataTransformer<TypeParam>* transformer =
-      new DataTransformer<TypeParam>(transform_param);
+      new DataTransformer<TypeParam>(transform_param, TEST);
   transformer->InitRand();
   transformer->Transform(datum, blob);
   EXPECT_EQ(blob->num(), 1);
@@ -116,7 +116,7 @@ TYPED_TEST(DataTransformTest, TestEmptyTransformUniquePixels) {
   FillDatum(label, channels, height, width, unique_pixels, &datum);
   Blob<TypeParam>* blob = new Blob<TypeParam>(1, 3, 4, 5);
   DataTransformer<TypeParam>* transformer =
-      new DataTransformer<TypeParam>(transform_param);
+      new DataTransformer<TypeParam>(transform_param, TEST);
   transformer->InitRand();
   transformer->Transform(datum, blob);
   EXPECT_EQ(blob->num(), 1);
@@ -141,7 +141,7 @@ TYPED_TEST(DataTransformTest, TestCropSize) {
   Datum datum;
   FillDatum(label, channels, height, width, unique_pixels, &datum);
   DataTransformer<TypeParam>* transformer =
-      new DataTransformer<TypeParam>(transform_param);
+      new DataTransformer<TypeParam>(transform_param, TEST);
   transformer->InitRand();
   Blob<TypeParam>* blob =
       new Blob<TypeParam>(1, channels, crop_size, crop_size);
@@ -170,8 +170,7 @@ TYPED_TEST(DataTransformTest, TestCropTrain) {
   transform_param.set_crop_size(crop_size);
   Datum datum;
   FillDatum(label, channels, height, width, unique_pixels, &datum);
-  Caffe::set_phase(Caffe::TRAIN);
-  int num_matches = this->NumSequenceMatches(transform_param, datum);
+  int num_matches = this->NumSequenceMatches(transform_param, datum, TRAIN);
   EXPECT_LT(num_matches, size * this->num_iter_);
 }
 
@@ -188,8 +187,7 @@ TYPED_TEST(DataTransformTest, TestCropTest) {
   transform_param.set_crop_size(crop_size);
   Datum datum;
   FillDatum(label, channels, height, width, unique_pixels, &datum);
-  Caffe::set_phase(Caffe::TEST);
-  int num_matches = this->NumSequenceMatches(transform_param, datum);
+  int num_matches = this->NumSequenceMatches(transform_param, datum, TEST);
   EXPECT_EQ(num_matches, size * this->num_iter_);
 }
 
@@ -205,8 +203,7 @@ TYPED_TEST(DataTransformTest, TestMirrorTrain) {
   transform_param.set_mirror(true);
   Datum datum;
   FillDatum(label, channels, height, width, unique_pixels, &datum);
-  Caffe::set_phase(Caffe::TRAIN);
-  int num_matches = this->NumSequenceMatches(transform_param, datum);
+  int num_matches = this->NumSequenceMatches(transform_param, datum, TRAIN);
   EXPECT_LT(num_matches, size * this->num_iter_);
 }
 
@@ -222,8 +219,7 @@ TYPED_TEST(DataTransformTest, TestMirrorTest) {
   transform_param.set_mirror(true);
   Datum datum;
   FillDatum(label, channels, height, width, unique_pixels, &datum);
-  Caffe::set_phase(Caffe::TEST);
-  int num_matches = this->NumSequenceMatches(transform_param, datum);
+  int num_matches = this->NumSequenceMatches(transform_param, datum, TEST);
   EXPECT_LT(num_matches, size * this->num_iter_);
 }
 
@@ -239,12 +235,12 @@ TYPED_TEST(DataTransformTest, TestCropMirrorTrain) {
   Datum datum;
   FillDatum(label, channels, height, width, unique_pixels, &datum);
   transform_param.set_crop_size(crop_size);
-  Caffe::set_phase(Caffe::TRAIN);
-  int num_matches_crop = this->NumSequenceMatches(transform_param, datum);
+  int num_matches_crop = this->NumSequenceMatches(
+      transform_param, datum, TRAIN);
 
   transform_param.set_mirror(true);
   int num_matches_crop_mirror =
-      this->NumSequenceMatches(transform_param, datum);
+      this->NumSequenceMatches(transform_param, datum, TRAIN);
   // When doing crop and mirror we expect less num_matches than just crop
   EXPECT_LE(num_matches_crop_mirror, num_matches_crop);
 }
@@ -261,12 +257,11 @@ TYPED_TEST(DataTransformTest, TestCropMirrorTest) {
   Datum datum;
   FillDatum(label, channels, height, width, unique_pixels, &datum);
   transform_param.set_crop_size(crop_size);
-  Caffe::set_phase(Caffe::TEST);
-  int num_matches_crop = this->NumSequenceMatches(transform_param, datum);
+  int num_matches_crop = this->NumSequenceMatches(transform_param, datum, TEST);
 
   transform_param.set_mirror(true);
   int num_matches_crop_mirror =
-      this->NumSequenceMatches(transform_param, datum);
+      this->NumSequenceMatches(transform_param, datum, TEST);
   // When doing crop and mirror we expect less num_matches than just crop
   EXPECT_LT(num_matches_crop_mirror, num_matches_crop);
 }
@@ -286,7 +281,7 @@ TYPED_TEST(DataTransformTest, TestMeanValue) {
   FillDatum(label, channels, height, width, unique_pixels, &datum);
   Blob<TypeParam>* blob = new Blob<TypeParam>(1, channels, height, width);
   DataTransformer<TypeParam>* transformer =
-      new DataTransformer<TypeParam>(transform_param);
+      new DataTransformer<TypeParam>(transform_param, TEST);
   transformer->InitRand();
   transformer->Transform(datum, blob);
   for (int j = 0; j < blob->count(); ++j) {
@@ -309,7 +304,7 @@ TYPED_TEST(DataTransformTest, TestMeanValues) {
   FillDatum(label, channels, height, width, unique_pixels, &datum);
   Blob<TypeParam>* blob = new Blob<TypeParam>(1, channels, height, width);
   DataTransformer<TypeParam>* transformer =
-      new DataTransformer<TypeParam>(transform_param);
+      new DataTransformer<TypeParam>(transform_param, TEST);
   transformer->InitRand();
   transformer->Transform(datum, blob);
   for (int c = 0; c < channels; ++c) {
@@ -349,7 +344,7 @@ TYPED_TEST(DataTransformTest, TestMeanFile) {
   FillDatum(label, channels, height, width, unique_pixels, &datum);
   Blob<TypeParam>* blob = new Blob<TypeParam>(1, channels, height, width);
   DataTransformer<TypeParam>* transformer =
-      new DataTransformer<TypeParam>(transform_param);
+      new DataTransformer<TypeParam>(transform_param, TEST);
   transformer->InitRand();
   transformer->Transform(datum, blob);
   for (int j = 0; j < blob->count(); ++j) {
index b1f4e4e..611d979 100644 (file)
@@ -88,8 +88,8 @@ TYPED_TEST(MaxPoolingDropoutTest, TestForward) {
 
 TYPED_TEST(MaxPoolingDropoutTest, TestBackward) {
   typedef typename TypeParam::Dtype Dtype;
-  Caffe::set_phase(Caffe::TRAIN);
   LayerParameter layer_param;
+  layer_param.set_phase(TRAIN);
   PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
   pooling_param->set_kernel_size(3);
   pooling_param->set_stride(2);
index bc0dae3..1680a3f 100644 (file)
@@ -1534,20 +1534,6 @@ TEST_F(FilterNetTest, TestFilterLeNetTrainTest) {
       output_proto_test + " state: { phase: TEST } ";
   this->RunFilterNetTest(input_proto_train, output_proto_train_explicit);
   this->RunFilterNetTest(input_proto_test, output_proto_test_explicit);
-
-  // Also check that nets are filtered according to the Caffe singleton phase,
-  // if not explicitly specified in the input proto.
-  Caffe::set_phase(Caffe::TRAIN);
-  this->RunFilterNetTest(input_proto, output_proto_train);
-  Caffe::set_phase(Caffe::TEST);
-  this->RunFilterNetTest(input_proto, output_proto_test);
-
-  // Finally, check that the current Caffe singleton phase is ignored if the
-  // phase is explicitly specified in the input proto.
-  Caffe::set_phase(Caffe::TEST);
-  this->RunFilterNetTest(input_proto_train, output_proto_train_explicit);
-  Caffe::set_phase(Caffe::TRAIN);
-  this->RunFilterNetTest(input_proto_test, output_proto_test_explicit);
 }
 
 TEST_F(FilterNetTest, TestFilterOutByStage) {
index b19a5ab..ad10720 100644 (file)
@@ -43,8 +43,8 @@ class NeuronLayerTest : public MultiDeviceTest<TypeParam> {
     if (dropout_ratio != 0.5) {
       layer_param.mutable_dropout_param()->set_dropout_ratio(dropout_ratio);
     }
-    Caffe::set_phase(Caffe::TRAIN);
     DropoutLayer<Dtype> layer(layer_param);
+    layer_param.set_phase(TRAIN);
     layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
     layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
     // Now, check values
@@ -334,7 +334,7 @@ TYPED_TEST(NeuronLayerTest, TestDropoutThreeQuarters) {
 TYPED_TEST(NeuronLayerTest, TestDropoutTestPhase) {
   typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
-  Caffe::set_phase(Caffe::TEST);
+  layer_param.set_phase(TEST);
   DropoutLayer<Dtype> layer(layer_param);
   layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
   layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
@@ -351,7 +351,7 @@ TYPED_TEST(NeuronLayerTest, TestDropoutTestPhase) {
 TYPED_TEST(NeuronLayerTest, TestDropoutGradient) {
   typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
-  Caffe::set_phase(Caffe::TRAIN);
+  layer_param.set_phase(TRAIN);
   DropoutLayer<Dtype> layer(layer_param);
   GradientChecker<Dtype> checker(1e-2, 1e-3);
   checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_,
@@ -361,7 +361,7 @@ TYPED_TEST(NeuronLayerTest, TestDropoutGradient) {
 TYPED_TEST(NeuronLayerTest, TestDropoutGradientTest) {
   typedef typename TypeParam::Dtype Dtype;
   LayerParameter layer_param;
-  Caffe::set_phase(Caffe::TEST);
+  layer_param.set_phase(TEST);
   DropoutLayer<Dtype> layer(layer_param);
   GradientChecker<Dtype> checker(1e-2, 1e-3);
   checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_,
index ad51510..12962c6 100644 (file)
@@ -62,8 +62,8 @@ TYPED_TEST(StochasticPoolingLayerTest, TestSetup) {
 
 TYPED_TEST(StochasticPoolingLayerTest, TestStochasticGPU) {
   Caffe::set_mode(Caffe::GPU);
-  Caffe::set_phase(Caffe::TRAIN);
   LayerParameter layer_param;
+  layer_param.set_phase(TRAIN);
   PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
   pooling_param->set_kernel_size(3);
   pooling_param->set_stride(2);
@@ -106,8 +106,8 @@ TYPED_TEST(StochasticPoolingLayerTest, TestStochasticGPU) {
 
 TYPED_TEST(StochasticPoolingLayerTest, TestStochasticGPUTestPhase) {
   Caffe::set_mode(Caffe::GPU);
-  Caffe::set_phase(Caffe::TEST);
   LayerParameter layer_param;
+  layer_param.set_phase(TEST);
   PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
   pooling_param->set_kernel_size(3);
   pooling_param->set_stride(2);
@@ -144,8 +144,8 @@ TYPED_TEST(StochasticPoolingLayerTest, TestStochasticGPUTestPhase) {
 
 TYPED_TEST(StochasticPoolingLayerTest, TestGradientGPU) {
   Caffe::set_mode(Caffe::GPU);
-  Caffe::set_phase(Caffe::TRAIN);
   LayerParameter layer_param;
+  layer_param.set_phase(TRAIN);
   PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
   pooling_param->set_kernel_size(3);
   pooling_param->set_stride(2);