[matcaffe] give phase to Net
authorEvan Shelhamer <shelhamer@imaginarynumber.net>
Mon, 26 Jan 2015 06:06:23 +0000 (22:06 -0800)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Tue, 17 Feb 2015 19:35:51 +0000 (11:35 -0800)
matlab/caffe/matcaffe.cpp
matlab/caffe/matcaffe_init.m

index fd8397e..996d3d2 100644 (file)
@@ -254,14 +254,6 @@ static void set_mode_gpu(MEX_ARGS) {
   Caffe::set_mode(Caffe::GPU);
 }
 
-static void set_phase_train(MEX_ARGS) {
-  Caffe::set_phase(Caffe::TRAIN);
-}
-
-static void set_phase_test(MEX_ARGS) {
-  Caffe::set_phase(Caffe::TEST);
-}
-
 static void set_device(MEX_ARGS) {
   if (nrhs != 1) {
     ostringstream error_msg;
@@ -278,7 +270,7 @@ static void get_init_key(MEX_ARGS) {
 }
 
 static void init(MEX_ARGS) {
-  if (nrhs != 2) {
+  if (nrhs != 3) {
     ostringstream error_msg;
     error_msg << "Expected 2 arguments, got " << nrhs;
     mex_error(error_msg.str());
@@ -286,12 +278,23 @@ static void init(MEX_ARGS) {
 
   char* param_file = mxArrayToString(prhs[0]);
   char* model_file = mxArrayToString(prhs[1]);
+  char* phase_name = mxArrayToString(prhs[2]);
+
+  Phase phase;
+  if (strcmp(phase_name, "train") == 0) {
+      phase = TRAIN;
+  } else if (strcmp(phase_name, "test") == 0) {
+      phase = TEST;
+  } else {
+    mex_error("Unknown phase.");
+  }
 
-  net_.reset(new Net<float>(string(param_file)));
+  net_.reset(new Net<float>(string(param_file), phase));
   net_->CopyTrainedLayersFrom(string(model_file));
 
   mxFree(param_file);
   mxFree(model_file);
+  mxFree(phase_name);
 
   init_key = random();  // NOLINT(caffe/random_fn)
 
@@ -377,8 +380,6 @@ static handler_registry handlers[] = {
   { "is_initialized",     is_initialized  },
   { "set_mode_cpu",       set_mode_cpu    },
   { "set_mode_gpu",       set_mode_gpu    },
-  { "set_phase_train",    set_phase_train },
-  { "set_phase_test",     set_phase_test  },
   { "set_device",         set_device      },
   { "get_weights",        get_weights     },
   { "get_init_key",       get_init_key    },
index 7cc6935..5d0a0a7 100644 (file)
@@ -25,7 +25,8 @@ if caffe('is_initialized') == 0
     % NOTE: you'll have to get network definition
     error('You need the network prototxt definition');
   end
-  caffe('init', model_def_file, model_file)
+  % load network in TEST phase
+  caffe('init', model_def_file, model_file, 'test')
 end
 fprintf('Done with init\n');
 
@@ -38,7 +39,3 @@ else
   caffe('set_mode_cpu');
 end
 fprintf('Done with set_mode\n');
-
-% put into test mode
-caffe('set_phase_test');
-fprintf('Done with set_phase_test\n');