add test for input/input_dim and fix bug, wasn't copying input
authorJeff Donahue <jeff.donahue@gmail.com>
Mon, 17 Mar 2014 12:18:55 +0000 (05:18 -0700)
committerJeff Donahue <jeff.donahue@gmail.com>
Fri, 28 Mar 2014 06:42:28 +0000 (23:42 -0700)
src/caffe/test/test_upgrade_proto.cpp
src/caffe/util/upgrade_proto.cpp
tools/upgrade_net_proto.cpp

index d6ca6d0..8538f11 100644 (file)
@@ -1107,8 +1107,6 @@ class V0UpgradeTest : public ::testing::Test {
         output_param_string, &expected_output_param));
     NetParameter actual_output_param;
     UpgradeV0Net(input_param, &actual_output_param);
-    CHECK_EQ(expected_output_param.DebugString(),
-        actual_output_param.DebugString());
     EXPECT_EQ(expected_output_param.DebugString(),
         actual_output_param.DebugString());
   }
@@ -1266,6 +1264,11 @@ TYPED_TEST(V0UpgradeTest, TestSimple) {
 TYPED_TEST(V0UpgradeTest, TestAllParams) {
   const string& input_proto =
       "name: 'CaffeNet' "
+      "input: 'input_data' "
+      "input_dim: 64 "
+      "input_dim: 3 "
+      "input_dim: 32 "
+      "input_dim: 32 "
       "layers { "
       "  layer { "
       "    name: 'data' "
@@ -1408,6 +1411,11 @@ TYPED_TEST(V0UpgradeTest, TestAllParams) {
       "} ";
   const string& expected_output_proto =
       "name: 'CaffeNet' "
+      "input: 'input_data' "
+      "input_dim: 64 "
+      "input_dim: 3 "
+      "input_dim: 32 "
+      "input_dim: 32 "
       "layers { "
       "  name: 'data' "
       "  type: 'data' "
index 0248252..38bc216 100644 (file)
@@ -32,6 +32,9 @@ bool UpgradeV0Net(const V0NetParameter& v0_net_param_padding_layers,
     is_fully_compatible &= UpgradeV0LayerConnection(v0_net_param.layers(i),
                                                     net_param->add_layers());
   }
+  for (int i = 0; i < v0_net_param.input_size(); ++i) {
+    net_param->add_input(v0_net_param.input(i));
+  }
   for (int i = 0; i < v0_net_param.input_dim_size(); ++i) {
     net_param->add_input_dim(v0_net_param.input_dim(i));
   }
index 8c50c51..a416878 100644 (file)
@@ -18,7 +18,7 @@ using namespace caffe;  // NOLINT(build/namespaces)
 
 int main(int argc, char** argv) {
   ::google::InitGoogleLogging(argv[0]);
-  if (argc < 3) {
+  if (argc != 3) {
     LOG(ERROR) << "Usage: "
         << "upgrade_net_proto v0_net_proto_file_in net_proto_file_out";
     return 0;