Add unit tests and skeleton code for Net/Solver filtering functionality.
authorJeff Donahue <jeff.donahue@gmail.com>
Tue, 29 Jul 2014 18:22:52 +0000 (11:22 -0700)
committerJeff Donahue <jeff.donahue@gmail.com>
Tue, 29 Jul 2014 20:19:24 +0000 (13:19 -0700)
include/caffe/net.hpp
include/caffe/solver.hpp
src/caffe/net.cpp
src/caffe/test/test_net.cpp
src/caffe/test/test_solver.cpp [new file with mode: 0644]

index 7548011..808b244 100644 (file)
@@ -115,6 +115,14 @@ class Net {
 
   void set_debug_info(const bool value) { debug_info_ = value; }
 
+  // Helpers for Init.
+  // Remove layers that the user specified should be excluded given the current
+  // phase, level, and stage.
+  static void FilterNet(const NetParameter& param,
+      NetParameter* param_filtered);
+  static bool StateMeetsRule(const NetState& state, const NetStateRule& rule,
+      const string& layer_name);
+
  protected:
   // Helpers for Init.
   // Append a new input or top blob to the net.
index 3112c59..52dcd19 100644 (file)
@@ -6,6 +6,8 @@
 #include <string>
 #include <vector>
 
+#include "caffe/net.hpp"
+
 namespace caffe {
 
 template <typename Dtype>
@@ -20,6 +22,9 @@ class Solver {
   inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
   virtual ~Solver() {}
   inline shared_ptr<Net<Dtype> > net() { return net_; }
+  inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {
+    return test_nets_;
+  }
 
  protected:
   // PreSolve is run before any solving iteration starts, allowing one to
index 228f826..a02b111 100644 (file)
@@ -167,6 +167,17 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
   debug_info_ = false;
 }
 
+template <typename Dtype>
+void Net<Dtype>::FilterNet(const NetParameter& param,
+    NetParameter* param_filtered) {
+}
+
+template <typename Dtype>
+bool Net<Dtype>::StateMeetsRule(const NetState& state,
+    const NetStateRule& rule, const string& layer_name) {
+  return true;
+}
+
 // Helper for Net::Init: add a new input or top blob to the net.  (Inputs have
 // layer_id == -1, tops have layer_id >= 0.)
 template <typename Dtype>
index e84701d..c418134 100644 (file)
@@ -835,4 +835,824 @@ TYPED_TEST(NetTest, TestFromTo) {
   }
 }
 
+class FilterNetTest : public ::testing::Test {
+ protected:
+  void RunFilterNetTest(
+      const string& input_param_string, const string& filtered_param_string) {
+    NetParameter input_param;
+    CHECK(google::protobuf::TextFormat::ParseFromString(
+        input_param_string, &input_param));
+    NetParameter expected_filtered_param;
+    CHECK(google::protobuf::TextFormat::ParseFromString(
+        filtered_param_string, &expected_filtered_param));
+    NetParameter actual_filtered_param;
+    Net<float>::FilterNet(input_param, &actual_filtered_param);
+    EXPECT_EQ(expected_filtered_param.DebugString(),
+        actual_filtered_param.DebugString());
+    // Also test idempotence.
+    NetParameter double_filtered_param;
+    Net<float>::FilterNet(actual_filtered_param, &double_filtered_param);
+    EXPECT_EQ(actual_filtered_param.DebugString(),
+       double_filtered_param.DebugString());
+  }
+};
+
+TEST_F(FilterNetTest, TestNoFilter) {
+  const string& input_proto =
+      "name: 'TestNetwork' "
+      "layers: { "
+      "  name: 'data' "
+      "  type: DATA "
+      "  top: 'data' "
+      "  top: 'label' "
+      "} "
+      "layers: { "
+      "  name: 'innerprod' "
+      "  type: INNER_PRODUCT "
+      "  bottom: 'data' "
+      "  top: 'innerprod' "
+      "} "
+      "layers: { "
+      "  name: 'loss' "
+      "  type: SOFTMAX_LOSS "
+      "  bottom: 'innerprod' "
+      "  bottom: 'label' "
+      "} ";
+  this->RunFilterNetTest(input_proto, input_proto);
+}
+
+TEST_F(FilterNetTest, TestFilterLeNetTrainTest) {
+  const string& input_proto =
+      "name: 'LeNet' "
+      "layers { "
+      "  name: 'mnist' "
+      "  type: DATA "
+      "  top: 'data' "
+      "  top: 'label' "
+      "  data_param { "
+      "    source: 'mnist-train-leveldb' "
+      "    scale: 0.00390625 "
+      "    batch_size: 64 "
+      "  } "
+      "  include: { phase: TRAIN } "
+      "} "
+      "layers { "
+      "  name: 'mnist' "
+      "  type: DATA "
+      "  top: 'data' "
+      "  top: 'label' "
+      "  data_param { "
+      "    source: 'mnist-test-leveldb' "
+      "    scale: 0.00390625 "
+      "    batch_size: 100 "
+      "  } "
+      "  include: { phase: TEST } "
+      "} "
+      "layers { "
+      "  name: 'conv1' "
+      "  type: CONVOLUTION "
+      "  bottom: 'data' "
+      "  top: 'conv1' "
+      "  blobs_lr: 1 "
+      "  blobs_lr: 2 "
+      "  convolution_param { "
+      "    num_output: 20 "
+      "    kernel_size: 5 "
+      "    stride: 1 "
+      "    weight_filler { "
+      "      type: 'xavier' "
+      "    } "
+      "    bias_filler { "
+      "      type: 'constant' "
+      "    } "
+      "  } "
+      "} "
+      "layers { "
+      "  name: 'ip1' "
+      "  type: INNER_PRODUCT "
+      "  bottom: 'conv1' "
+      "  top: 'ip1' "
+      "  blobs_lr: 1 "
+      "  blobs_lr: 2 "
+      "  inner_product_param { "
+      "    num_output: 10 "
+      "    weight_filler { "
+      "      type: 'xavier' "
+      "    } "
+      "    bias_filler { "
+      "      type: 'constant' "
+      "    } "
+      "  } "
+      "} "
+      "layers { "
+      "  name: 'accuracy' "
+      "  type: ACCURACY "
+      "  bottom: 'ip1' "
+      "  bottom: 'label' "
+      "  top: 'accuracy' "
+      "  include: { phase: TEST } "
+      "} "
+      "layers { "
+      "  name: 'loss' "
+      "  type: SOFTMAX_LOSS "
+      "  bottom: 'ip2' "
+      "  bottom: 'label' "
+      "  top: 'loss' "
+      "} ";
+  const string input_proto_train = "state: { phase: TRAIN } " + input_proto;
+  const string input_proto_test = "state: { phase: TEST } " + input_proto;
+  const string& output_proto_train =
+      "state: { phase: TRAIN } "
+      "name: 'LeNet' "
+      "layers { "
+      "  name: 'mnist' "
+      "  type: DATA "
+      "  top: 'data' "
+      "  top: 'label' "
+      "  data_param { "
+      "    source: 'mnist-train-leveldb' "
+      "    scale: 0.00390625 "
+      "    batch_size: 64 "
+      "  } "
+      "  include: { phase: TRAIN } "
+      "} "
+      "layers { "
+      "  name: 'conv1' "
+      "  type: CONVOLUTION "
+      "  bottom: 'data' "
+      "  top: 'conv1' "
+      "  blobs_lr: 1 "
+      "  blobs_lr: 2 "
+      "  convolution_param { "
+      "    num_output: 20 "
+      "    kernel_size: 5 "
+      "    stride: 1 "
+      "    weight_filler { "
+      "      type: 'xavier' "
+      "    } "
+      "    bias_filler { "
+      "      type: 'constant' "
+      "    } "
+      "  } "
+      "} "
+      "layers { "
+      "  name: 'ip1' "
+      "  type: INNER_PRODUCT "
+      "  bottom: 'conv1' "
+      "  top: 'ip1' "
+      "  blobs_lr: 1 "
+      "  blobs_lr: 2 "
+      "  inner_product_param { "
+      "    num_output: 10 "
+      "    weight_filler { "
+      "      type: 'xavier' "
+      "    } "
+      "    bias_filler { "
+      "      type: 'constant' "
+      "    } "
+      "  } "
+      "} "
+      "layers { "
+      "  name: 'loss' "
+      "  type: SOFTMAX_LOSS "
+      "  bottom: 'ip2' "
+      "  bottom: 'label' "
+      "  top: 'loss' "
+      "} ";
+  const string& output_proto_test =
+      "state: { phase: TEST } "
+      "name: 'LeNet' "
+      "layers { "
+      "  name: 'mnist' "
+      "  type: DATA "
+      "  top: 'data' "
+      "  top: 'label' "
+      "  data_param { "
+      "    source: 'mnist-test-leveldb' "
+      "    scale: 0.00390625 "
+      "    batch_size: 100 "
+      "  } "
+      "  include: { phase: TEST } "
+      "} "
+      "layers { "
+      "  name: 'conv1' "
+      "  type: CONVOLUTION "
+      "  bottom: 'data' "
+      "  top: 'conv1' "
+      "  blobs_lr: 1 "
+      "  blobs_lr: 2 "
+      "  convolution_param { "
+      "    num_output: 20 "
+      "    kernel_size: 5 "
+      "    stride: 1 "
+      "    weight_filler { "
+      "      type: 'xavier' "
+      "    } "
+      "    bias_filler { "
+      "      type: 'constant' "
+      "    } "
+      "  } "
+      "} "
+      "layers { "
+      "  name: 'ip1' "
+      "  type: INNER_PRODUCT "
+      "  bottom: 'conv1' "
+      "  top: 'ip1' "
+      "  blobs_lr: 1 "
+      "  blobs_lr: 2 "
+      "  inner_product_param { "
+      "    num_output: 10 "
+      "    weight_filler { "
+      "      type: 'xavier' "
+      "    } "
+      "    bias_filler { "
+      "      type: 'constant' "
+      "    } "
+      "  } "
+      "} "
+      "layers { "
+      "  name: 'accuracy' "
+      "  type: ACCURACY "
+      "  bottom: 'ip1' "
+      "  bottom: 'label' "
+      "  top: 'accuracy' "
+      "  include: { phase: TEST } "
+      "} "
+      "layers { "
+      "  name: 'loss' "
+      "  type: SOFTMAX_LOSS "
+      "  bottom: 'ip2' "
+      "  bottom: 'label' "
+      "  top: 'loss' "
+      "} ";
+  this->RunFilterNetTest(input_proto_train, output_proto_train);
+  this->RunFilterNetTest(input_proto_test, output_proto_test);
+}
+
+TEST_F(FilterNetTest, TestFilterOutByStage) {
+  const string& input_proto =
+      "name: 'TestNetwork' "
+      "layers: { "
+      "  name: 'data' "
+      "  type: DATA "
+      "  top: 'data' "
+      "  top: 'label' "
+      "  include: { stage: 'mystage' } "
+      "} "
+      "layers: { "
+      "  name: 'innerprod' "
+      "  type: INNER_PRODUCT "
+      "  bottom: 'data' "
+      "  top: 'innerprod' "
+      "} "
+      "layers: { "
+      "  name: 'loss' "
+      "  type: SOFTMAX_LOSS "
+      "  bottom: 'innerprod' "
+      "  bottom: 'label' "
+      "} ";
+  const string& output_proto =
+      "name: 'TestNetwork' "
+      "layers: { "
+      "  name: 'innerprod' "
+      "  type: INNER_PRODUCT "
+      "  bottom: 'data' "
+      "  top: 'innerprod' "
+      "} "
+      "layers: { "
+      "  name: 'loss' "
+      "  type: SOFTMAX_LOSS "
+      "  bottom: 'innerprod' "
+      "  bottom: 'label' "
+      "} ";
+  this->RunFilterNetTest(input_proto, output_proto);
+}
+
+TEST_F(FilterNetTest, TestFilterOutByStage2) {
+  const string& input_proto =
+      "name: 'TestNetwork' "
+      "layers: { "
+      "  name: 'data' "
+      "  type: DATA "
+      "  top: 'data' "
+      "  top: 'label' "
+      "} "
+      "layers: { "
+      "  name: 'innerprod' "
+      "  type: INNER_PRODUCT "
+      "  bottom: 'data' "
+      "  top: 'innerprod' "
+      "  include: { stage: 'mystage' } "
+      "} "
+      "layers: { "
+      "  name: 'loss' "
+      "  type: SOFTMAX_LOSS "
+      "  bottom: 'innerprod' "
+      "  bottom: 'label' "
+      "} ";
+  const string& output_proto =
+      "name: 'TestNetwork' "
+      "layers: { "
+      "  name: 'data' "
+      "  type: DATA "
+      "  top: 'data' "
+      "  top: 'label' "
+      "} "
+      "layers: { "
+      "  name: 'loss' "
+      "  type: SOFTMAX_LOSS "
+      "  bottom: 'innerprod' "
+      "  bottom: 'label' "
+      "} ";
+  this->RunFilterNetTest(input_proto, output_proto);
+}
+
+TEST_F(FilterNetTest, TestFilterInByStage) {
+  const string& input_proto =
+      "state: { stage: 'mystage' } "
+      "name: 'TestNetwork' "
+      "layers: { "
+      "  name: 'data' "
+      "  type: DATA "
+      "  top: 'data' "
+      "  top: 'label' "
+      "} "
+      "layers: { "
+      "  name: 'innerprod' "
+      "  type: INNER_PRODUCT "
+      "  bottom: 'data' "
+      "  top: 'innerprod' "
+      "  include: { stage: 'mystage' } "
+      "} "
+      "layers: { "
+      "  name: 'loss' "
+      "  type: SOFTMAX_LOSS "
+      "  bottom: 'innerprod' "
+      "  bottom: 'label' "
+      "} ";
+  this->RunFilterNetTest(input_proto, input_proto);
+}
+
+TEST_F(FilterNetTest, TestFilterInByStage2) {
+  const string& input_proto =
+      "name: 'TestNetwork' "
+      "layers: { "
+      "  name: 'data' "
+      "  type: DATA "
+      "  top: 'data' "
+      "  top: 'label' "
+      "} "
+      "layers: { "
+      "  name: 'innerprod' "
+      "  type: INNER_PRODUCT "
+      "  bottom: 'data' "
+      "  top: 'innerprod' "
+      "  exclude: { stage: 'mystage' } "
+      "} "
+      "layers: { "
+      "  name: 'loss' "
+      "  type: SOFTMAX_LOSS "
+      "  bottom: 'innerprod' "
+      "  bottom: 'label' "
+      "} ";
+  this->RunFilterNetTest(input_proto, input_proto);
+}
+
+TEST_F(FilterNetTest, TestFilterOutByMultipleStage) {
+  const string& input_proto =
+      "state: { stage: 'mystage' } "
+      "name: 'TestNetwork' "
+      "layers: { "
+      "  name: 'data' "
+      "  type: DATA "
+      "  top: 'data' "
+      "  top: 'label' "
+      "} "
+      "layers: { "
+      "  name: 'innerprod' "
+      "  type: INNER_PRODUCT "
+      "  bottom: 'data' "
+      "  top: 'innerprod' "
+      "  include: { stage: 'mystage' stage: 'myotherstage' } "
+      "} "
+      "layers: { "
+      "  name: 'loss' "
+      "  type: SOFTMAX_LOSS "
+      "  bottom: 'innerprod' "
+      "  bottom: 'label' "
+      "  include: { stage: 'mystage' } "
+      "} ";
+  const string& output_proto =
+      "state: { stage: 'mystage' } "
+      "name: 'TestNetwork' "
+      "layers: { "
+      "  name: 'data' "
+      "  type: DATA "
+      "  top: 'data' "
+      "  top: 'label' "
+      "} "
+      "layers: { "
+      "  name: 'loss' "
+      "  type: SOFTMAX_LOSS "
+      "  bottom: 'innerprod' "
+      "  bottom: 'label' "
+      "  include: { stage: 'mystage' } "
+      "} ";
+  this->RunFilterNetTest(input_proto, output_proto);
+}
+
+TEST_F(FilterNetTest, TestFilterInByMultipleStage) {
+  const string& input_proto =
+      "state: { stage: 'mystage' } "
+      "name: 'TestNetwork' "
+      "layers: { "
+      "  name: 'data' "
+      "  type: DATA "
+      "  top: 'data' "
+      "  top: 'label' "
+      "} "
+      "layers: { "
+      "  name: 'innerprod' "
+      "  type: INNER_PRODUCT "
+      "  bottom: 'data' "
+      "  top: 'innerprod' "
+      "  include: { stage: 'myotherstage' } "
+      "  include: { stage: 'mystage' } "
+      "} "
+      "layers: { "
+      "  name: 'loss' "
+      "  type: SOFTMAX_LOSS "
+      "  bottom: 'innerprod' "
+      "  bottom: 'label' "
+      "  include: { stage: 'mystage' } "
+      "} ";
+  this->RunFilterNetTest(input_proto, input_proto);
+}
+
+TEST_F(FilterNetTest, TestFilterInByMultipleStage2) {
+  const string& input_proto =
+      "state: { stage: 'mystage' stage: 'myotherstage' } "
+      "name: 'TestNetwork' "
+      "layers: { "
+      "  name: 'data' "
+      "  type: DATA "
+      "  top: 'data' "
+      "  top: 'label' "
+      "} "
+      "layers: { "
+      "  name: 'innerprod' "
+      "  type: INNER_PRODUCT "
+      "  bottom: 'data' "
+      "  top: 'innerprod' "
+      "  include: { stage: 'mystage' stage: 'myotherstage' } "
+      "} "
+      "layers: { "
+      "  name: 'loss' "
+      "  type: SOFTMAX_LOSS "
+      "  bottom: 'innerprod' "
+      "  bottom: 'label' "
+      "  include: { stage: 'mystage' } "
+      "} ";
+  this->RunFilterNetTest(input_proto, input_proto);
+}
+
+TEST_F(FilterNetTest, TestFilterOutByMinLevel) {
+  const string& input_proto =
+      "name: 'TestNetwork' "
+      "layers: { "
+      "  name: 'data' "
+      "  type: DATA "
+      "  top: 'data' "
+      "  top: 'label' "
+      "} "
+      "layers: { "
+      "  name: 'innerprod' "
+      "  type: INNER_PRODUCT "
+      "  bottom: 'data' "
+      "  top: 'innerprod' "
+      "  include: { min_level: 3 } "
+      "} "
+      "layers: { "
+      "  name: 'loss' "
+      "  type: SOFTMAX_LOSS "
+      "  bottom: 'innerprod' "
+      "  bottom: 'label' "
+      "} ";
+  const string& output_proto =
+      "name: 'TestNetwork' "
+      "layers: { "
+      "  name: 'data' "
+      "  type: DATA "
+      "  top: 'data' "
+      "  top: 'label' "
+      "} "
+      "layers: { "
+      "  name: 'loss' "
+      "  type: SOFTMAX_LOSS "
+      "  bottom: 'innerprod' "
+      "  bottom: 'label' "
+      "} ";
+  this->RunFilterNetTest(input_proto, output_proto);
+}
+
+TEST_F(FilterNetTest, TestFilterOutByMaxLevel) {
+  const string& input_proto =
+      "name: 'TestNetwork' "
+      "layers: { "
+      "  name: 'data' "
+      "  type: DATA "
+      "  top: 'data' "
+      "  top: 'label' "
+      "} "
+      "layers: { "
+      "  name: 'innerprod' "
+      "  type: INNER_PRODUCT "
+      "  bottom: 'data' "
+      "  top: 'innerprod' "
+      "  include: { max_level: -3 } "
+      "} "
+      "layers: { "
+      "  name: 'loss' "
+      "  type: SOFTMAX_LOSS "
+      "  bottom: 'innerprod' "
+      "  bottom: 'label' "
+      "} ";
+  const string& output_proto =
+      "name: 'TestNetwork' "
+      "layers: { "
+      "  name: 'data' "
+      "  type: DATA "
+      "  top: 'data' "
+      "  top: 'label' "
+      "} "
+      "layers: { "
+      "  name: 'loss' "
+      "  type: SOFTMAX_LOSS "
+      "  bottom: 'innerprod' "
+      "  bottom: 'label' "
+      "} ";
+  this->RunFilterNetTest(input_proto, output_proto);
+}
+
+TEST_F(FilterNetTest, TestFilterInByMinLevel) {
+  const string& input_proto =
+      "name: 'TestNetwork' "
+      "layers: { "
+      "  name: 'data' "
+      "  type: DATA "
+      "  top: 'data' "
+      "  top: 'label' "
+      "} "
+      "layers: { "
+      "  name: 'innerprod' "
+      "  type: INNER_PRODUCT "
+      "  bottom: 'data' "
+      "  top: 'innerprod' "
+      "  include: { min_level: 0 } "
+      "} "
+      "layers: { "
+      "  name: 'loss' "
+      "  type: SOFTMAX_LOSS "
+      "  bottom: 'innerprod' "
+      "  bottom: 'label' "
+      "} ";
+  this->RunFilterNetTest(input_proto, input_proto);
+}
+
+TEST_F(FilterNetTest, TestFilterInByMinLevel2) {
+  const string& input_proto =
+      "state: { level: 7 } "
+      "name: 'TestNetwork' "
+      "layers: { "
+      "  name: 'data' "
+      "  type: DATA "
+      "  top: 'data' "
+      "  top: 'label' "
+      "} "
+      "layers: { "
+      "  name: 'innerprod' "
+      "  type: INNER_PRODUCT "
+      "  bottom: 'data' "
+      "  top: 'innerprod' "
+      "  include: { min_level: 3 } "
+      "} "
+      "layers: { "
+      "  name: 'loss' "
+      "  type: SOFTMAX_LOSS "
+      "  bottom: 'innerprod' "
+      "  bottom: 'label' "
+      "} ";
+  this->RunFilterNetTest(input_proto, input_proto);
+}
+
+TEST_F(FilterNetTest, TestFilterInByMaxLevel) {
+  const string& input_proto =
+      "name: 'TestNetwork' "
+      "layers: { "
+      "  name: 'data' "
+      "  type: DATA "
+      "  top: 'data' "
+      "  top: 'label' "
+      "} "
+      "layers: { "
+      "  name: 'innerprod' "
+      "  type: INNER_PRODUCT "
+      "  bottom: 'data' "
+      "  top: 'innerprod' "
+      "  include: { max_level: 0 } "
+      "} "
+      "layers: { "
+      "  name: 'loss' "
+      "  type: SOFTMAX_LOSS "
+      "  bottom: 'innerprod' "
+      "  bottom: 'label' "
+      "} ";
+  this->RunFilterNetTest(input_proto, input_proto);
+}
+
+TEST_F(FilterNetTest, TestFilterInByMaxLevel2) {
+  const string& input_proto =
+      "state: { level: -7 } "
+      "name: 'TestNetwork' "
+      "layers: { "
+      "  name: 'data' "
+      "  type: DATA "
+      "  top: 'data' "
+      "  top: 'label' "
+      "} "
+      "layers: { "
+      "  name: 'innerprod' "
+      "  type: INNER_PRODUCT "
+      "  bottom: 'data' "
+      "  top: 'innerprod' "
+      "  include: { max_level: -3 } "
+      "} "
+      "layers: { "
+      "  name: 'loss' "
+      "  type: SOFTMAX_LOSS "
+      "  bottom: 'innerprod' "
+      "  bottom: 'label' "
+      "} ";
+  this->RunFilterNetTest(input_proto, input_proto);
+}
+
+TEST_F(FilterNetTest, TestFilterInOutByIncludeMultiRule) {
+  const string& input_proto =
+      "name: 'TestNetwork' "
+      "layers: { "
+      "  name: 'data' "
+      "  type: DATA "
+      "  top: 'data' "
+      "  top: 'label' "
+      "} "
+      "layers: { "
+      "  name: 'innerprod' "
+      "  type: INNER_PRODUCT "
+      "  bottom: 'data' "
+      "  top: 'innerprod' "
+      "  include: { min_level: 2  phase: TRAIN } "
+      "} "
+      "layers: { "
+      "  name: 'loss' "
+      "  type: SOFTMAX_LOSS "
+      "  bottom: 'innerprod' "
+      "  bottom: 'label' "
+      "  include: { min_level: 2  phase: TEST } "
+      "} ";
+  const string& input_proto_train =
+      "state: { level: 4  phase: TRAIN } " + input_proto;
+  const string& input_proto_test =
+      "state: { level: 4  phase: TEST } " + input_proto;
+  const string& output_proto_train =
+      "state: { level: 4  phase: TRAIN } "
+      "name: 'TestNetwork' "
+      "layers: { "
+      "  name: 'data' "
+      "  type: DATA "
+      "  top: 'data' "
+      "  top: 'label' "
+      "} "
+      "layers: { "
+      "  name: 'innerprod' "
+      "  type: INNER_PRODUCT "
+      "  bottom: 'data' "
+      "  top: 'innerprod' "
+      "  include: { min_level: 2  phase: TRAIN } "
+      "} ";
+  const string& output_proto_test =
+      "state: { level: 4  phase: TEST } "
+      "name: 'TestNetwork' "
+      "layers: { "
+      "  name: 'data' "
+      "  type: DATA "
+      "  top: 'data' "
+      "  top: 'label' "
+      "} "
+      "layers: { "
+      "  name: 'loss' "
+      "  type: SOFTMAX_LOSS "
+      "  bottom: 'innerprod' "
+      "  bottom: 'label' "
+      "  include: { min_level: 2  phase: TEST } "
+      "} ";
+  this->RunFilterNetTest(input_proto_train, output_proto_train);
+  this->RunFilterNetTest(input_proto_test, output_proto_test);
+}
+
+TEST_F(FilterNetTest, TestFilterInByIncludeMultiRule) {
+  const string& input_proto =
+      "name: 'TestNetwork' "
+      "layers: { "
+      "  name: 'data' "
+      "  type: DATA "
+      "  top: 'data' "
+      "  top: 'label' "
+      "} "
+      "layers: { "
+      "  name: 'innerprod' "
+      "  type: INNER_PRODUCT "
+      "  bottom: 'data' "
+      "  top: 'innerprod' "
+      "  include: { min_level: 2  phase: TRAIN } "
+      "  include: { phase: TEST } "
+      "} "
+      "layers: { "
+      "  name: 'loss' "
+      "  type: SOFTMAX_LOSS "
+      "  bottom: 'innerprod' "
+      "  bottom: 'label' "
+      "  include: { min_level: 2  phase: TEST } "
+      "  include: { phase: TRAIN } "
+      "} ";
+  const string& input_proto_train =
+      "state: { level: 2  phase: TRAIN } " + input_proto;
+  const string& input_proto_test =
+      "state: { level: 2  phase: TEST } " + input_proto;
+  this->RunFilterNetTest(input_proto_train, input_proto_train);
+  this->RunFilterNetTest(input_proto_test, input_proto_test);
+}
+
+TEST_F(FilterNetTest, TestFilterInOutByExcludeMultiRule) {
+  const string& input_proto =
+      "name: 'TestNetwork' "
+      "layers: { "
+      "  name: 'data' "
+      "  type: DATA "
+      "  top: 'data' "
+      "  top: 'label' "
+      "} "
+      "layers: { "
+      "  name: 'innerprod' "
+      "  type: INNER_PRODUCT "
+      "  bottom: 'data' "
+      "  top: 'innerprod' "
+      "  exclude: { min_level: 2  phase: TRAIN } "
+      "} "
+      "layers: { "
+      "  name: 'loss' "
+      "  type: SOFTMAX_LOSS "
+      "  bottom: 'innerprod' "
+      "  bottom: 'label' "
+      "  exclude: { min_level: 2  phase: TEST } "
+      "} ";
+  const string& input_proto_train =
+      "state: { level: 4  phase: TRAIN } " + input_proto;
+  const string& input_proto_test =
+      "state: { level: 4  phase: TEST } " + input_proto;
+  const string& output_proto_train =
+      "state: { level: 4  phase: TRAIN } "
+      "name: 'TestNetwork' "
+      "layers: { "
+      "  name: 'data' "
+      "  type: DATA "
+      "  top: 'data' "
+      "  top: 'label' "
+      "} "
+      "layers: { "
+      "  name: 'loss' "
+      "  type: SOFTMAX_LOSS "
+      "  bottom: 'innerprod' "
+      "  bottom: 'label' "
+      "  exclude: { min_level: 2  phase: TEST } "
+      "} ";
+  const string& output_proto_test =
+      "state: { level: 4  phase: TEST } "
+      "name: 'TestNetwork' "
+      "layers: { "
+      "  name: 'data' "
+      "  type: DATA "
+      "  top: 'data' "
+      "  top: 'label' "
+      "} "
+      "layers: { "
+      "  name: 'innerprod' "
+      "  type: INNER_PRODUCT "
+      "  bottom: 'data' "
+      "  top: 'innerprod' "
+      "  exclude: { min_level: 2  phase: TRAIN } "
+      "} ";
+  this->RunFilterNetTest(input_proto_train, output_proto_train);
+  this->RunFilterNetTest(input_proto_test, output_proto_test);
+}
+
 }  // namespace caffe
diff --git a/src/caffe/test/test_solver.cpp b/src/caffe/test/test_solver.cpp
new file mode 100644 (file)
index 0000000..7380ab8
--- /dev/null
@@ -0,0 +1,98 @@
+// Copyright 2014 BVLC and contributors.
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "google/protobuf/text_format.h"
+
+#include "gtest/gtest.h"
+#include "caffe/common.hpp"
+#include "caffe/proto/caffe.pb.h"
+#include "caffe/solver.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+
+using std::ostringstream;
+
+namespace caffe {
+
+template <typename TypeParam>
+class SolverTest : public MultiDeviceTest<TypeParam> {
+  typedef typename TypeParam::Dtype Dtype;
+
+ protected:
+  virtual void InitSolverFromProtoString(const string& proto) {
+    SolverParameter param;
+    CHECK(google::protobuf::TextFormat::ParseFromString(proto, &param));
+    solver_.reset(new SGDSolver<Dtype>(param));
+  }
+
+  shared_ptr<Solver<Dtype> > solver_;
+};
+
+TYPED_TEST_CASE(SolverTest, TestDtypesAndDevices);
+
+TYPED_TEST(SolverTest, TestInitTrainTestNets) {
+  const string& proto =
+     "test_interval: 10 "
+     "test_iter: 10 "
+     "test_state: { stage: 'with-softmax' }"
+     "test_iter: 10 "
+     "test_state: {}"
+     "net_param { "
+     "  name: 'TestNetwork' "
+     "  layers: { "
+     "    name: 'data' "
+     "    type: DUMMY_DATA "
+     "    dummy_data_param { "
+     "      num: 5 "
+     "      channels: 3 "
+     "      height: 10 "
+     "      width: 10 "
+     "      num: 5 "
+     "      channels: 1 "
+     "      height: 1 "
+     "      width: 1 "
+     "    } "
+     "    top: 'data' "
+     "    top: 'label' "
+     "  } "
+     "  layers: { "
+     "    name: 'innerprod' "
+     "    type: INNER_PRODUCT "
+     "    inner_product_param { "
+     "      num_output: 10 "
+     "    } "
+     "    bottom: 'data' "
+     "    top: 'innerprod' "
+     "  } "
+     "  layers: { "
+     "    name: 'accuracy' "
+     "    type: ACCURACY "
+     "    bottom: 'innerprod' "
+     "    bottom: 'label' "
+     "    top: 'accuracy' "
+     "    exclude: { phase: TRAIN } "
+     "  } "
+     "  layers: { "
+     "    name: 'loss' "
+     "    type: SOFTMAX_LOSS "
+     "    bottom: 'innerprod' "
+     "    bottom: 'label' "
+     "    include: { phase: TRAIN } "
+     "    include: { phase: TEST stage: 'with-softmax' } "
+     "  } "
+     "} ";
+  this->InitSolverFromProtoString(proto);
+  ASSERT_TRUE(this->solver_->net());
+  EXPECT_TRUE(this->solver_->net()->has_layer("loss"));
+  EXPECT_FALSE(this->solver_->net()->has_layer("accuracy"));
+  ASSERT_EQ(2, this->solver_->test_nets().size());
+  EXPECT_TRUE(this->solver_->test_nets()[0]->has_layer("loss"));
+  EXPECT_TRUE(this->solver_->test_nets()[0]->has_layer("accuracy"));
+  EXPECT_FALSE(this->solver_->test_nets()[1]->has_layer("loss"));
+  EXPECT_TRUE(this->solver_->test_nets()[1]->has_layer("accuracy"));
+}
+
+}  // namespace caffe