Add a caffe.io.write_mean function to the MATLAB interface
authorzoharby <Zohar Bar-Yehuda>
Fri, 11 Sep 2015 12:06:28 +0000 (15:06 +0300)
committerzoharby <Zohar Bar-Yehuda>
Mon, 5 Oct 2015 06:44:14 +0000 (09:44 +0300)
Useful for exporting models from MATLAB (e.g. MatConvNet) to Caffe

matlab/+caffe/io.m
matlab/+caffe/private/caffe_.cpp

index af8369d..4b072fe 100644 (file)
@@ -29,5 +29,13 @@ classdef io
       CHECK_FILE_EXIST(mean_proto_file);
       mean_data = caffe_('read_mean', mean_proto_file);
     end
+    function write_mean(mean_data, mean_proto_file)
+      % write_mean(mean_data, mean_proto_file)
+      %   write image mean data to binaryproto file
+      %   mean_data should be W x H x C with BGR channels
+      CHECK(ischar(mean_proto_file), 'mean_proto_file must be a string');
+      CHECK(isa(mean_data, 'single'), 'mean_data must be a SINGLE matrix');
+      caffe_('write_mean', mean_data, mean_proto_file);
+    end   
   end
 end
index 4e0ebc1..7883f79 100644 (file)
@@ -478,6 +478,29 @@ static void read_mean(MEX_ARGS) {
   mxFree(mean_proto_file);
 }
 
+// Usage: caffe_('write_mean', mean_data, mean_proto_file)
+static void write_mean(MEX_ARGS) {
+  mxCHECK(nrhs == 2 && mxIsSingle(prhs[0]) && mxIsChar(prhs[1]),
+      "Usage: caffe_('write_mean', mean_data, mean_proto_file)");
+  char* mean_proto_file = mxArrayToString(prhs[1]);
+  int ndims = mxGetNumberOfDimensions(prhs[0]);
+  mxCHECK(ndims >= 2 && ndims <= 3, "mean_data must have at 2 or 3 dimensions");
+  const mwSize *dims = mxGetDimensions(prhs[0]);
+  int width = dims[0];
+  int height = dims[1];
+  int channels;
+  if (ndims == 3)
+    channels = dims[2];
+  else
+    channels = 1;
+  Blob<float> data_mean(1, channels, height, width);
+  mx_mat_to_blob(prhs[0], &data_mean, DATA);
+  BlobProto blob_proto;
+  data_mean.ToProto(&blob_proto, false);
+  WriteProtoToBinaryFile(blob_proto, mean_proto_file);
+  mxFree(mean_proto_file);
+}
+
 /** -----------------------------------------------------------------
  ** Available commands.
  **/
@@ -515,6 +538,7 @@ static handler_registry handlers[] = {
   { "get_init_key",       get_init_key    },
   { "reset",              reset           },
   { "read_mean",          read_mean       },
+  { "write_mean",         write_mean      },
   // The end.
   { "END",                NULL            },
 };