return model weights
authorRoss Girshick <rbg@eecs.berkeley.edu>
Fri, 6 Dec 2013 04:58:03 +0000 (20:58 -0800)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Thu, 20 Mar 2014 02:12:54 +0000 (19:12 -0700)
matlab/caffe/matcaffe.cpp

index ddbacca..99ef3f4 100644 (file)
@@ -24,8 +24,8 @@ static shared_ptr<Net<float> > net_;
 //   matlab uses RGB color channel order
 //   images need to have the data mean subtracted
 //
-// Data coming in from matlab needs to be in the order
-//   [batch_images, channels, height, width]
+// Data coming in from matlab needs to be in the order 
+//   [width, height, channels, images]
 // where width is the fastest dimension.
 // Here is the rough matlab for putting image data into the correct
 // format:
@@ -87,7 +87,92 @@ static mxArray* do_forward(const mxArray* const bottom) {
   return mx_out;
 }
 
-// The caffe::Caffe utility functions.
+static mxArray* do_get_weights() {
+  const vector<shared_ptr<Layer<float> > >& layers = net_->layers();
+  const vector<string>& layer_names = net_->layer_names();
+
+  // Step 1: count the number of layers
+  int num_layers = 0;
+  {
+    string prev_layer_name = "";
+    for (unsigned int i = 0; i < layers.size(); ++i) {
+      vector<shared_ptr<Blob<float> > >& layer_blobs = layers[i]->blobs();
+      if (layer_blobs.size() == 0) {
+        continue;
+      }
+      if (layer_names[i] != prev_layer_name) {
+        prev_layer_name = layer_names[i];
+        num_layers++;
+      }
+    }
+  }
+
+  // Step 2: prepare output array of structures
+  mxArray* mx_layers;
+  {
+    const mwSize dims[2] = {num_layers, 1};
+    const char* fnames[2] = {"weights", "layer_names"};
+    mx_layers = mxCreateStructArray(2, dims, 2, fnames);
+  }
+
+  // Step 3: copy weights into output
+  {
+    string prev_layer_name = "";
+    int mx_layer_index = 0;
+    for (unsigned int i = 0; i < layers.size(); ++i) {
+      vector<shared_ptr<Blob<float> > >& layer_blobs = layers[i]->blobs();
+      if (layer_blobs.size() == 0) {
+        continue;
+      }
+
+      mxArray* mx_layer_cells = NULL;
+      if (layer_names[i] != prev_layer_name) {
+        prev_layer_name = layer_names[i];
+        const mwSize dims[2] = {layer_blobs.size(), 1};
+        mx_layer_cells = mxCreateCellArray(2, dims);
+        mxSetField(mx_layers, mx_layer_index, "weights", mx_layer_cells); 
+        mxSetField(mx_layers, mx_layer_index, "layer_names",
+            mxCreateString(layer_names[i].c_str()));
+        mx_layer_index++;
+      }
+
+      for (unsigned int j = 0; j < layer_blobs.size(); ++j) {
+        // internally data is stored as (width, height, channels, num)
+        // where width is the fastest dimension
+        mwSize dims[4] = {layer_blobs[j]->width(), layer_blobs[j]->height(),
+            layer_blobs[j]->channels(), layer_blobs[j]->num()};
+        mxArray* mx_weights = mxCreateNumericArray(4, dims, mxSINGLE_CLASS, mxREAL);
+        mxSetCell(mx_layer_cells, j, mx_weights);
+        float* weights_ptr = reinterpret_cast<float*>(mxGetPr(mx_weights));
+
+//        mexPrintf("layer: %s (%d) blob: %d  %d: (%d, %d, %d) %d\n", 
+//            layer_names[i].c_str(), i, j, layer_blobs[j]->num(), 
+//            layer_blobs[j]->height(), layer_blobs[j]->width(), 
+//            layer_blobs[j]->channels(), layer_blobs[j]->count());
+
+        switch (Caffe::mode()) {
+        case Caffe::CPU:
+          memcpy(weights_ptr, layer_blobs[j]->cpu_data(), 
+              sizeof(float) * layer_blobs[j]->count());
+          break;
+        case Caffe::GPU:
+          CUDA_CHECK(cudaMemcpy(weights_ptr, layer_blobs[j]->gpu_data(),
+              sizeof(float) * layer_blobs[j]->count(), cudaMemcpyDeviceToHost));
+          break;
+        default:
+          LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+        }
+      }
+    }
+  }
+
+  return mx_layers;
+}
+
+static void get_weights(MEX_ARGS) {
+  plhs[0] = do_get_weights();
+}
+
 static void set_mode_cpu(MEX_ARGS) {
   Caffe::set_mode(Caffe::CPU);
 }
@@ -139,6 +224,14 @@ static void forward(MEX_ARGS) {
   plhs[0] = do_forward(prhs[0]);
 }
 
+static void is_initialized(MEX_ARGS) {
+  if (!net_) {
+    plhs[0] = mxCreateDoubleScalar(0);
+  } else {
+    plhs[0] = mxCreateDoubleScalar(1);
+  }
+}
+
 /** -----------------------------------------------------------------
  ** Available commands.
  **/
@@ -151,11 +244,13 @@ static handler_registry handlers[] = {
   // Public API functions
   { "forward",            forward         },
   { "init",               init            },
+  { "is_initialized",     is_initialized  },
   { "set_mode_cpu",       set_mode_cpu    },
   { "set_mode_gpu",       set_mode_gpu    },
   { "set_phase_train",    set_phase_train },
   { "set_phase_test",     set_phase_test  },
   { "set_device",         set_device      },
+  { "get_weights",        get_weights     },
   // The end.
   { "END",                NULL            },
 };