output_param_string, &expected_output_param));
NetParameter actual_output_param;
UpgradeV0Net(input_param, &actual_output_param);
- CHECK_EQ(expected_output_param.DebugString(),
- actual_output_param.DebugString());
EXPECT_EQ(expected_output_param.DebugString(),
actual_output_param.DebugString());
}
TYPED_TEST(V0UpgradeTest, TestAllParams) {
const string& input_proto =
"name: 'CaffeNet' "
+ "input: 'input_data' "
+ "input_dim: 64 "
+ "input_dim: 3 "
+ "input_dim: 32 "
+ "input_dim: 32 "
"layers { "
" layer { "
" name: 'data' "
"} ";
const string& expected_output_proto =
"name: 'CaffeNet' "
+ "input: 'input_data' "
+ "input_dim: 64 "
+ "input_dim: 3 "
+ "input_dim: 32 "
+ "input_dim: 32 "
"layers { "
" name: 'data' "
" type: 'data' "
is_fully_compatible &= UpgradeV0LayerConnection(v0_net_param.layers(i),
net_param->add_layers());
}
+ for (int i = 0; i < v0_net_param.input_size(); ++i) {
+ net_param->add_input(v0_net_param.input(i));
+ }
for (int i = 0; i < v0_net_param.input_dim_size(); ++i) {
net_param->add_input_dim(v0_net_param.input_dim(i));
}
int main(int argc, char** argv) {
::google::InitGoogleLogging(argv[0]);
- if (argc < 3) {
+ if (argc != 3) {
LOG(ERROR) << "Usage: "
<< "upgrade_net_proto v0_net_proto_file_in net_proto_file_out";
return 0;