a32bd5e536d995ede8f96b73cbd5bca18c67afce
[platform/upstream/caffeonacl.git] / matlab / +caffe / private / caffe_.cpp
1 //
2 // caffe_.cpp provides wrappers of the caffe::Solver class, caffe::Net class,
3 // caffe::Layer class and caffe::Blob class and some caffe::Caffe functions,
4 // so that one could easily use Caffe from matlab.
5 // Note that for matlab, we will simply use float as the data type.
6
7 // Internally, data is stored with dimensions reversed from Caffe's:
8 // e.g., if the Caffe blob axes are (num, channels, height, width),
9 // the matcaffe data is stored as (width, height, channels, num)
10 // where width is the fastest dimension.
11
12 #include <sstream>
13 #include <string>
14 #include <vector>
15
16 #include "mex.h"
17
18 #include "caffe/caffe.hpp"
19
20 #define MEX_ARGS int nlhs, mxArray **plhs, int nrhs, const mxArray **prhs
21
22 using namespace caffe;  // NOLINT(build/namespaces)
23
24 // Do CHECK and throw a Mex error if check fails
25 inline void mxCHECK(bool expr, const char* msg) {
26   if (!expr) {
27     mexErrMsgTxt(msg);
28   }
29 }
30 inline void mxERROR(const char* msg) { mexErrMsgTxt(msg); }
31
32 // Check if a file exists and can be opened
33 void mxCHECK_FILE_EXIST(const char* file) {
34   std::ifstream f(file);
35   if (!f.good()) {
36     f.close();
37     std::string msg("Could not open file ");
38     msg += file;
39     mxERROR(msg.c_str());
40   }
41   f.close();
42 }
43
44 // The pointers to caffe::Solver and caffe::Net instances
45 static vector<shared_ptr<Solver<float> > > solvers_;
46 static vector<shared_ptr<Net<float> > > nets_;
47 // init_key is generated at the beginning and every time you call reset
48 static double init_key = static_cast<double>(caffe_rng_rand());
49
50 /** -----------------------------------------------------------------
51  ** data conversion functions
52  **/
53 // Enum indicates which blob memory to use
54 enum WhichMemory { DATA, DIFF };
55
56 // Copy matlab array to Blob data or diff
57 static void mx_mat_to_blob(const mxArray* mx_mat, Blob<float>* blob,
58     WhichMemory data_or_diff) {
59   mxCHECK(blob->count() == mxGetNumberOfElements(mx_mat),
60       "number of elements in target blob doesn't match that in input mxArray");
61   const float* mat_mem_ptr = reinterpret_cast<const float*>(mxGetData(mx_mat));
62   float* blob_mem_ptr = NULL;
63   switch (Caffe::mode()) {
64   case Caffe::CPU:
65     blob_mem_ptr = (data_or_diff == DATA ?
66         blob->mutable_cpu_data() : blob->mutable_cpu_diff());
67     break;
68   case Caffe::GPU:
69     blob_mem_ptr = (data_or_diff == DATA ?
70         blob->mutable_gpu_data() : blob->mutable_gpu_diff());
71     break;
72   default:
73     mxERROR("Unknown Caffe mode");
74   }
75   caffe_copy(blob->count(), mat_mem_ptr, blob_mem_ptr);
76 }
77
78 // Copy Blob data or diff to matlab array
79 static mxArray* blob_to_mx_mat(const Blob<float>* blob,
80     WhichMemory data_or_diff) {
81   const int num_axes = blob->num_axes();
82   vector<mwSize> dims(num_axes);
83   for (int blob_axis = 0, mat_axis = num_axes - 1; blob_axis < num_axes;
84        ++blob_axis, --mat_axis) {
85     dims[mat_axis] = static_cast<mwSize>(blob->shape(blob_axis));
86   }
87   // matlab array needs to have at least one dimension, convert scalar to 1-dim
88   if (num_axes == 0) {
89     dims.push_back(1);
90   }
91   mxArray* mx_mat =
92       mxCreateNumericArray(dims.size(), dims.data(), mxSINGLE_CLASS, mxREAL);
93   float* mat_mem_ptr = reinterpret_cast<float*>(mxGetData(mx_mat));
94   const float* blob_mem_ptr = NULL;
95   switch (Caffe::mode()) {
96   case Caffe::CPU:
97     blob_mem_ptr = (data_or_diff == DATA ? blob->cpu_data() : blob->cpu_diff());
98     break;
99   case Caffe::GPU:
100     blob_mem_ptr = (data_or_diff == DATA ? blob->gpu_data() : blob->gpu_diff());
101     break;
102   default:
103     mxERROR("Unknown Caffe mode");
104   }
105   caffe_copy(blob->count(), blob_mem_ptr, mat_mem_ptr);
106   return mx_mat;
107 }
108
109 // Convert vector<int> to matlab row vector
110 static mxArray* int_vec_to_mx_vec(const vector<int>& int_vec) {
111   mxArray* mx_vec = mxCreateDoubleMatrix(int_vec.size(), 1, mxREAL);
112   double* vec_mem_ptr = mxGetPr(mx_vec);
113   for (int i = 0; i < int_vec.size(); i++) {
114     vec_mem_ptr[i] = static_cast<double>(int_vec[i]);
115   }
116   return mx_vec;
117 }
118
119 // Convert vector<string> to matlab cell vector of strings
120 static mxArray* str_vec_to_mx_strcell(const vector<std::string>& str_vec) {
121   mxArray* mx_strcell = mxCreateCellMatrix(str_vec.size(), 1);
122   for (int i = 0; i < str_vec.size(); i++) {
123     mxSetCell(mx_strcell, i, mxCreateString(str_vec[i].c_str()));
124   }
125   return mx_strcell;
126 }
127
128 /** -----------------------------------------------------------------
129  ** handle and pointer conversion functions
130  ** a handle is a struct array with the following fields
131  **   (uint64) ptr      : the pointer to the C++ object
132  **   (double) init_key : caffe initialization key
133  **/
134 // Convert a handle in matlab to a pointer in C++. Check if init_key matches
135 template <typename T>
136 static T* handle_to_ptr(const mxArray* mx_handle) {
137   mxArray* mx_ptr = mxGetField(mx_handle, 0, "ptr");
138   mxArray* mx_init_key = mxGetField(mx_handle, 0, "init_key");
139   mxCHECK(mxIsUint64(mx_ptr), "pointer type must be uint64");
140   mxCHECK(mxGetScalar(mx_init_key) == init_key,
141       "Could not convert handle to pointer due to invalid init_key. "
142       "The object might have been cleared.");
143   return reinterpret_cast<T*>(*reinterpret_cast<uint64_t*>(mxGetData(mx_ptr)));
144 }
145
146 // Create a handle struct vector, without setting up each handle in it
147 template <typename T>
148 static mxArray* create_handle_vec(int ptr_num) {
149   const int handle_field_num = 2;
150   const char* handle_fields[handle_field_num] = { "ptr", "init_key" };
151   return mxCreateStructMatrix(ptr_num, 1, handle_field_num, handle_fields);
152 }
153
154 // Set up a handle in a handle struct vector by its index
155 template <typename T>
156 static void setup_handle(const T* ptr, int index, mxArray* mx_handle_vec) {
157   mxArray* mx_ptr = mxCreateNumericMatrix(1, 1, mxUINT64_CLASS, mxREAL);
158   *reinterpret_cast<uint64_t*>(mxGetData(mx_ptr)) =
159       reinterpret_cast<uint64_t>(ptr);
160   mxSetField(mx_handle_vec, index, "ptr", mx_ptr);
161   mxSetField(mx_handle_vec, index, "init_key", mxCreateDoubleScalar(init_key));
162 }
163
164 // Convert a pointer in C++ to a handle in matlab
165 template <typename T>
166 static mxArray* ptr_to_handle(const T* ptr) {
167   mxArray* mx_handle = create_handle_vec<T>(1);
168   setup_handle(ptr, 0, mx_handle);
169   return mx_handle;
170 }
171
172 // Convert a vector of shared_ptr in C++ to handle struct vector
173 template <typename T>
174 static mxArray* ptr_vec_to_handle_vec(const vector<shared_ptr<T> >& ptr_vec) {
175   mxArray* mx_handle_vec = create_handle_vec<T>(ptr_vec.size());
176   for (int i = 0; i < ptr_vec.size(); i++) {
177     setup_handle(ptr_vec[i].get(), i, mx_handle_vec);
178   }
179   return mx_handle_vec;
180 }
181
182 /** -----------------------------------------------------------------
183  ** matlab command functions: caffe_(api_command, arg1, arg2, ...)
184  **/
185 // Usage: caffe_('get_solver', solver_file);
186 static void get_solver(MEX_ARGS) {
187   mxCHECK(nrhs == 1 && mxIsChar(prhs[0]),
188       "Usage: caffe_('get_solver', solver_file)");
189   char* solver_file = mxArrayToString(prhs[0]);
190   mxCHECK_FILE_EXIST(solver_file);
191   SolverParameter solver_param;
192   ReadSolverParamsFromTextFileOrDie(solver_file, &solver_param);
193   shared_ptr<Solver<float> > solver(
194       SolverRegistry<float>::CreateSolver(solver_param));
195   solvers_.push_back(solver);
196   plhs[0] = ptr_to_handle<Solver<float> >(solver.get());
197   mxFree(solver_file);
198 }
199
200 // Usage: caffe_('delete_solver', hSolver)
201 static void delete_solver(MEX_ARGS) {
202   mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
203       "Usage: caffe_('delete_solver', hSolver)");
204   Solver<float>* solver = handle_to_ptr<Solver<float> >(prhs[0]);
205   solvers_.erase(std::remove_if(solvers_.begin(), solvers_.end(),
206       [solver] (const shared_ptr< Solver<float> > &solverPtr) {
207       return solverPtr.get() == solver;
208   }), solvers_.end());
209 }
210
211 // Usage: caffe_('solver_get_attr', hSolver)
212 static void solver_get_attr(MEX_ARGS) {
213   mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
214       "Usage: caffe_('solver_get_attr', hSolver)");
215   Solver<float>* solver = handle_to_ptr<Solver<float> >(prhs[0]);
216   const int solver_attr_num = 2;
217   const char* solver_attrs[solver_attr_num] = { "hNet_net", "hNet_test_nets" };
218   mxArray* mx_solver_attr = mxCreateStructMatrix(1, 1, solver_attr_num,
219       solver_attrs);
220   mxSetField(mx_solver_attr, 0, "hNet_net",
221       ptr_to_handle<Net<float> >(solver->net().get()));
222   mxSetField(mx_solver_attr, 0, "hNet_test_nets",
223       ptr_vec_to_handle_vec<Net<float> >(solver->test_nets()));
224   plhs[0] = mx_solver_attr;
225 }
226
227 // Usage: caffe_('solver_get_iter', hSolver)
228 static void solver_get_iter(MEX_ARGS) {
229   mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
230       "Usage: caffe_('solver_get_iter', hSolver)");
231   Solver<float>* solver = handle_to_ptr<Solver<float> >(prhs[0]);
232   plhs[0] = mxCreateDoubleScalar(solver->iter());
233 }
234
235 // Usage: caffe_('solver_restore', hSolver, snapshot_file)
236 static void solver_restore(MEX_ARGS) {
237   mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsChar(prhs[1]),
238       "Usage: caffe_('solver_restore', hSolver, snapshot_file)");
239   Solver<float>* solver = handle_to_ptr<Solver<float> >(prhs[0]);
240   char* snapshot_file = mxArrayToString(prhs[1]);
241   mxCHECK_FILE_EXIST(snapshot_file);
242   solver->Restore(snapshot_file);
243   mxFree(snapshot_file);
244 }
245
246 // Usage: caffe_('solver_solve', hSolver)
247 static void solver_solve(MEX_ARGS) {
248   mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
249       "Usage: caffe_('solver_solve', hSolver)");
250   Solver<float>* solver = handle_to_ptr<Solver<float> >(prhs[0]);
251   solver->Solve();
252 }
253
254 // Usage: caffe_('solver_step', hSolver, iters)
255 static void solver_step(MEX_ARGS) {
256   mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsDouble(prhs[1]),
257       "Usage: caffe_('solver_step', hSolver, iters)");
258   Solver<float>* solver = handle_to_ptr<Solver<float> >(prhs[0]);
259   int iters = mxGetScalar(prhs[1]);
260   solver->Step(iters);
261 }
262
263 // Usage: caffe_('get_net', model_file, phase_name)
264 static void get_net(MEX_ARGS) {
265   mxCHECK(nrhs == 2 && mxIsChar(prhs[0]) && mxIsChar(prhs[1]),
266       "Usage: caffe_('get_net', model_file, phase_name)");
267   char* model_file = mxArrayToString(prhs[0]);
268   char* phase_name = mxArrayToString(prhs[1]);
269   mxCHECK_FILE_EXIST(model_file);
270   Phase phase;
271   if (strcmp(phase_name, "train") == 0) {
272       phase = TRAIN;
273   } else if (strcmp(phase_name, "test") == 0) {
274       phase = TEST;
275   } else {
276     mxERROR("Unknown phase");
277   }
278   shared_ptr<Net<float> > net(new caffe::Net<float>(model_file, phase));
279   nets_.push_back(net);
280   plhs[0] = ptr_to_handle<Net<float> >(net.get());
281   mxFree(model_file);
282   mxFree(phase_name);
283 }
284
285 // Usage: caffe_('delete_solver', hSolver)
286 static void delete_net(MEX_ARGS) {
287   mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
288       "Usage: caffe_('delete_solver', hNet)");
289   Net<float>* net = handle_to_ptr<Net<float> >(prhs[0]);
290   nets_.erase(std::remove_if(nets_.begin(), nets_.end(),
291       [net] (const shared_ptr< Net<float> > &netPtr) {
292       return netPtr.get() == net;
293   }), nets_.end());
294 }
295
296 // Usage: caffe_('net_get_attr', hNet)
297 static void net_get_attr(MEX_ARGS) {
298   mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
299       "Usage: caffe_('net_get_attr', hNet)");
300   Net<float>* net = handle_to_ptr<Net<float> >(prhs[0]);
301   const int net_attr_num = 6;
302   const char* net_attrs[net_attr_num] = { "hLayer_layers", "hBlob_blobs",
303       "input_blob_indices", "output_blob_indices", "layer_names", "blob_names"};
304   mxArray* mx_net_attr = mxCreateStructMatrix(1, 1, net_attr_num,
305       net_attrs);
306   mxSetField(mx_net_attr, 0, "hLayer_layers",
307       ptr_vec_to_handle_vec<Layer<float> >(net->layers()));
308   mxSetField(mx_net_attr, 0, "hBlob_blobs",
309       ptr_vec_to_handle_vec<Blob<float> >(net->blobs()));
310   mxSetField(mx_net_attr, 0, "input_blob_indices",
311       int_vec_to_mx_vec(net->input_blob_indices()));
312   mxSetField(mx_net_attr, 0, "output_blob_indices",
313       int_vec_to_mx_vec(net->output_blob_indices()));
314   mxSetField(mx_net_attr, 0, "layer_names",
315       str_vec_to_mx_strcell(net->layer_names()));
316   mxSetField(mx_net_attr, 0, "blob_names",
317       str_vec_to_mx_strcell(net->blob_names()));
318   plhs[0] = mx_net_attr;
319 }
320
321 // Usage: caffe_('net_forward', hNet)
322 static void net_forward(MEX_ARGS) {
323   mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
324       "Usage: caffe_('net_forward', hNet)");
325   Net<float>* net = handle_to_ptr<Net<float> >(prhs[0]);
326   net->ForwardPrefilled();
327 }
328
329 // Usage: caffe_('net_backward', hNet)
330 static void net_backward(MEX_ARGS) {
331   mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
332       "Usage: caffe_('net_backward', hNet)");
333   Net<float>* net = handle_to_ptr<Net<float> >(prhs[0]);
334   net->Backward();
335 }
336
337 // Usage: caffe_('net_copy_from', hNet, weights_file)
338 static void net_copy_from(MEX_ARGS) {
339   mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsChar(prhs[1]),
340       "Usage: caffe_('net_copy_from', hNet, weights_file)");
341   Net<float>* net = handle_to_ptr<Net<float> >(prhs[0]);
342   char* weights_file = mxArrayToString(prhs[1]);
343   mxCHECK_FILE_EXIST(weights_file);
344   net->CopyTrainedLayersFrom(weights_file);
345   mxFree(weights_file);
346 }
347
348 // Usage: caffe_('net_reshape', hNet)
349 static void net_reshape(MEX_ARGS) {
350   mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
351       "Usage: caffe_('net_reshape', hNet)");
352   Net<float>* net = handle_to_ptr<Net<float> >(prhs[0]);
353   net->Reshape();
354 }
355
356 // Usage: caffe_('net_save', hNet, save_file)
357 static void net_save(MEX_ARGS) {
358   mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsChar(prhs[1]),
359       "Usage: caffe_('net_save', hNet, save_file)");
360   Net<float>* net = handle_to_ptr<Net<float> >(prhs[0]);
361   char* weights_file = mxArrayToString(prhs[1]);
362   NetParameter net_param;
363   net->ToProto(&net_param, false);
364   WriteProtoToBinaryFile(net_param, weights_file);
365   mxFree(weights_file);
366 }
367
368 // Usage: caffe_('layer_get_attr', hLayer)
369 static void layer_get_attr(MEX_ARGS) {
370   mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
371       "Usage: caffe_('layer_get_attr', hLayer)");
372   Layer<float>* layer = handle_to_ptr<Layer<float> >(prhs[0]);
373   const int layer_attr_num = 1;
374   const char* layer_attrs[layer_attr_num] = { "hBlob_blobs" };
375   mxArray* mx_layer_attr = mxCreateStructMatrix(1, 1, layer_attr_num,
376       layer_attrs);
377   mxSetField(mx_layer_attr, 0, "hBlob_blobs",
378       ptr_vec_to_handle_vec<Blob<float> >(layer->blobs()));
379   plhs[0] = mx_layer_attr;
380 }
381
382 // Usage: caffe_('layer_get_type', hLayer)
383 static void layer_get_type(MEX_ARGS) {
384   mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
385       "Usage: caffe_('layer_get_type', hLayer)");
386   Layer<float>* layer = handle_to_ptr<Layer<float> >(prhs[0]);
387   plhs[0] = mxCreateString(layer->type());
388 }
389
390 // Usage: caffe_('blob_get_shape', hBlob)
391 static void blob_get_shape(MEX_ARGS) {
392   mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
393       "Usage: caffe_('blob_get_shape', hBlob)");
394   Blob<float>* blob = handle_to_ptr<Blob<float> >(prhs[0]);
395   const int num_axes = blob->num_axes();
396   mxArray* mx_shape = mxCreateDoubleMatrix(1, num_axes, mxREAL);
397   double* shape_mem_mtr = mxGetPr(mx_shape);
398   for (int blob_axis = 0, mat_axis = num_axes - 1; blob_axis < num_axes;
399        ++blob_axis, --mat_axis) {
400     shape_mem_mtr[mat_axis] = static_cast<double>(blob->shape(blob_axis));
401   }
402   plhs[0] = mx_shape;
403 }
404
405 // Usage: caffe_('blob_reshape', hBlob, new_shape)
406 static void blob_reshape(MEX_ARGS) {
407   mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsDouble(prhs[1]),
408       "Usage: caffe_('blob_reshape', hBlob, new_shape)");
409   Blob<float>* blob = handle_to_ptr<Blob<float> >(prhs[0]);
410   const mxArray* mx_shape = prhs[1];
411   double* shape_mem_mtr = mxGetPr(mx_shape);
412   const int num_axes = mxGetNumberOfElements(mx_shape);
413   vector<int> blob_shape(num_axes);
414   for (int blob_axis = 0, mat_axis = num_axes - 1; blob_axis < num_axes;
415        ++blob_axis, --mat_axis) {
416     blob_shape[blob_axis] = static_cast<int>(shape_mem_mtr[mat_axis]);
417   }
418   blob->Reshape(blob_shape);
419 }
420
421 // Usage: caffe_('blob_get_data', hBlob)
422 static void blob_get_data(MEX_ARGS) {
423   mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
424       "Usage: caffe_('blob_get_data', hBlob)");
425   Blob<float>* blob = handle_to_ptr<Blob<float> >(prhs[0]);
426   plhs[0] = blob_to_mx_mat(blob, DATA);
427 }
428
429 // Usage: caffe_('blob_set_data', hBlob, new_data)
430 static void blob_set_data(MEX_ARGS) {
431   mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsSingle(prhs[1]),
432       "Usage: caffe_('blob_set_data', hBlob, new_data)");
433   Blob<float>* blob = handle_to_ptr<Blob<float> >(prhs[0]);
434   mx_mat_to_blob(prhs[1], blob, DATA);
435 }
436
437 // Usage: caffe_('blob_get_diff', hBlob)
438 static void blob_get_diff(MEX_ARGS) {
439   mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
440       "Usage: caffe_('blob_get_diff', hBlob)");
441   Blob<float>* blob = handle_to_ptr<Blob<float> >(prhs[0]);
442   plhs[0] = blob_to_mx_mat(blob, DIFF);
443 }
444
445 // Usage: caffe_('blob_set_diff', hBlob, new_diff)
446 static void blob_set_diff(MEX_ARGS) {
447   mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsSingle(prhs[1]),
448       "Usage: caffe_('blob_set_diff', hBlob, new_diff)");
449   Blob<float>* blob = handle_to_ptr<Blob<float> >(prhs[0]);
450   mx_mat_to_blob(prhs[1], blob, DIFF);
451 }
452
453 // Usage: caffe_('set_mode_cpu')
454 static void set_mode_cpu(MEX_ARGS) {
455   mxCHECK(nrhs == 0, "Usage: caffe_('set_mode_cpu')");
456   Caffe::set_mode(Caffe::CPU);
457 }
458
459 // Usage: caffe_('set_mode_gpu')
460 static void set_mode_gpu(MEX_ARGS) {
461   mxCHECK(nrhs == 0, "Usage: caffe_('set_mode_gpu')");
462   Caffe::set_mode(Caffe::GPU);
463 }
464
465 // Usage: caffe_('set_device', device_id)
466 static void set_device(MEX_ARGS) {
467   mxCHECK(nrhs == 1 && mxIsDouble(prhs[0]),
468       "Usage: caffe_('set_device', device_id)");
469   int device_id = static_cast<int>(mxGetScalar(prhs[0]));
470   Caffe::SetDevice(device_id);
471 }
472
473 // Usage: caffe_('get_init_key')
474 static void get_init_key(MEX_ARGS) {
475   mxCHECK(nrhs == 0, "Usage: caffe_('get_init_key')");
476   plhs[0] = mxCreateDoubleScalar(init_key);
477 }
478
479 // Usage: caffe_('reset')
480 static void reset(MEX_ARGS) {
481   mxCHECK(nrhs == 0, "Usage: caffe_('reset')");
482   // Clear solvers and stand-alone nets
483   mexPrintf("Cleared %d solvers and %d stand-alone nets\n",
484       solvers_.size(), nets_.size());
485   solvers_.clear();
486   nets_.clear();
487   // Generate new init_key, so that handles created before becomes invalid
488   init_key = static_cast<double>(caffe_rng_rand());
489 }
490
491 // Usage: caffe_('read_mean', mean_proto_file)
492 static void read_mean(MEX_ARGS) {
493   mxCHECK(nrhs == 1 && mxIsChar(prhs[0]),
494       "Usage: caffe_('read_mean', mean_proto_file)");
495   char* mean_proto_file = mxArrayToString(prhs[0]);
496   mxCHECK_FILE_EXIST(mean_proto_file);
497   Blob<float> data_mean;
498   BlobProto blob_proto;
499   bool result = ReadProtoFromBinaryFile(mean_proto_file, &blob_proto);
500   mxCHECK(result, "Could not read your mean file");
501   data_mean.FromProto(blob_proto);
502   plhs[0] = blob_to_mx_mat(&data_mean, DATA);
503   mxFree(mean_proto_file);
504 }
505
506 // Usage: caffe_('write_mean', mean_data, mean_proto_file)
507 static void write_mean(MEX_ARGS) {
508   mxCHECK(nrhs == 2 && mxIsSingle(prhs[0]) && mxIsChar(prhs[1]),
509       "Usage: caffe_('write_mean', mean_data, mean_proto_file)");
510   char* mean_proto_file = mxArrayToString(prhs[1]);
511   int ndims = mxGetNumberOfDimensions(prhs[0]);
512   mxCHECK(ndims >= 2 && ndims <= 3, "mean_data must have at 2 or 3 dimensions");
513   const mwSize *dims = mxGetDimensions(prhs[0]);
514   int width = dims[0];
515   int height = dims[1];
516   int channels;
517   if (ndims == 3)
518     channels = dims[2];
519   else
520     channels = 1;
521   Blob<float> data_mean(1, channels, height, width);
522   mx_mat_to_blob(prhs[0], &data_mean, DATA);
523   BlobProto blob_proto;
524   data_mean.ToProto(&blob_proto, false);
525   WriteProtoToBinaryFile(blob_proto, mean_proto_file);
526   mxFree(mean_proto_file);
527 }
528
529 // Usage: caffe_('version')
530 static void version(MEX_ARGS) {
531   mxCHECK(nrhs == 0, "Usage: caffe_('version')");
532   // Return version string
533   plhs[0] = mxCreateString(AS_STRING(CAFFE_VERSION));
534 }
535
536 /** -----------------------------------------------------------------
537  ** Available commands.
538  **/
539 struct handler_registry {
540   string cmd;
541   void (*func)(MEX_ARGS);
542 };
543
544 static handler_registry handlers[] = {
545   // Public API functions
546   { "get_solver",         get_solver      },
547   { "delete_solver",      delete_solver   },
548   { "solver_get_attr",    solver_get_attr },
549   { "solver_get_iter",    solver_get_iter },
550   { "solver_restore",     solver_restore  },
551   { "solver_solve",       solver_solve    },
552   { "solver_step",        solver_step     },
553   { "get_net",            get_net         },
554   { "delete_net",         delete_net      },
555   { "net_get_attr",       net_get_attr    },
556   { "net_forward",        net_forward     },
557   { "net_backward",       net_backward    },
558   { "net_copy_from",      net_copy_from   },
559   { "net_reshape",        net_reshape     },
560   { "net_save",           net_save        },
561   { "layer_get_attr",     layer_get_attr  },
562   { "layer_get_type",     layer_get_type  },
563   { "blob_get_shape",     blob_get_shape  },
564   { "blob_reshape",       blob_reshape    },
565   { "blob_get_data",      blob_get_data   },
566   { "blob_set_data",      blob_set_data   },
567   { "blob_get_diff",      blob_get_diff   },
568   { "blob_set_diff",      blob_set_diff   },
569   { "set_mode_cpu",       set_mode_cpu    },
570   { "set_mode_gpu",       set_mode_gpu    },
571   { "set_device",         set_device      },
572   { "get_init_key",       get_init_key    },
573   { "reset",              reset           },
574   { "read_mean",          read_mean       },
575   { "write_mean",         write_mean      },
576   { "version",            version         },
577   // The end.
578   { "END",                NULL            },
579 };
580
581 /** -----------------------------------------------------------------
582  ** matlab entry point.
583  **/
584 // Usage: caffe_(api_command, arg1, arg2, ...)
585 void mexFunction(MEX_ARGS) {
586   mexLock();  // Avoid clearing the mex file.
587   mxCHECK(nrhs > 0, "Usage: caffe_(api_command, arg1, arg2, ...)");
588   // Handle input command
589   char* cmd = mxArrayToString(prhs[0]);
590   bool dispatched = false;
591   // Dispatch to cmd handler
592   for (int i = 0; handlers[i].func != NULL; i++) {
593     if (handlers[i].cmd.compare(cmd) == 0) {
594       handlers[i].func(nlhs, plhs, nrhs-1, prhs+1);
595       dispatched = true;
596       break;
597     }
598   }
599   if (!dispatched) {
600     ostringstream error_msg;
601     error_msg << "Unknown command '" << cmd << "'";
602     mxERROR(error_msg.str().c_str());
603   }
604   mxFree(cmd);
605 }