2 // caffe_.cpp provides wrappers of the caffe::Solver class, caffe::Net class,
3 // caffe::Layer class and caffe::Blob class and some caffe::Caffe functions,
4 // so that one could easily use Caffe from matlab.
5 // Note that for matlab, we will simply use float as the data type.
7 // Internally, data is stored with dimensions reversed from Caffe's:
8 // e.g., if the Caffe blob axes are (num, channels, height, width),
9 // the matcaffe data is stored as (width, height, channels, num)
10 // where width is the fastest dimension.
18 #include "caffe/caffe.hpp"
20 #define MEX_ARGS int nlhs, mxArray **plhs, int nrhs, const mxArray **prhs
22 using namespace caffe; // NOLINT(build/namespaces)
24 // Do CHECK and throw a Mex error if check fails
25 inline void mxCHECK(bool expr, const char* msg) {
30 inline void mxERROR(const char* msg) { mexErrMsgTxt(msg); }
32 // Check if a file exists and can be opened
33 void mxCHECK_FILE_EXIST(const char* file) {
34 std::ifstream f(file);
37 std::string msg("Could not open file ");
// Owning storage for every caffe::Solver and stand-alone caffe::Net created
// through this MEX file. MATLAB only ever receives the raw pointer wrapped in
// a handle struct (see setup_handle), so these shared_ptrs are what keep the
// objects alive between calls.
static vector<shared_ptr<Solver<float> > > solvers_;
static vector<shared_ptr<Net<float> > > nets_;
// init_key is a random key stamped into every handle given to MATLAB. It is
// drawn once at static-initialization time and redrawn every time 'reset' is
// called, so handles created before a reset no longer match and are rejected
// by handle_to_ptr.
static double init_key = static_cast<double>(caffe_rng_rand());
50 /** -----------------------------------------------------------------
51 ** data conversion functions
// Selects which of a Blob's two buffers a copy reads from or writes to.
// Kept as an unscoped enum because call sites use the bare DATA / DIFF names.
enum WhichMemory {
  DATA = 0,  // the blob's data buffer (cpu_data / gpu_data and mutable variants)
  DIFF = 1   // the blob's diff buffer (cpu_diff / gpu_diff and mutable variants)
};
56 // Copy matlab array to Blob data or diff
57 static void mx_mat_to_blob(const mxArray* mx_mat, Blob<float>* blob,
58 WhichMemory data_or_diff) {
59 mxCHECK(blob->count() == mxGetNumberOfElements(mx_mat),
60 "number of elements in target blob doesn't match that in input mxArray");
61 const float* mat_mem_ptr = reinterpret_cast<const float*>(mxGetData(mx_mat));
62 float* blob_mem_ptr = NULL;
63 switch (Caffe::mode()) {
65 blob_mem_ptr = (data_or_diff == DATA ?
66 blob->mutable_cpu_data() : blob->mutable_cpu_diff());
69 blob_mem_ptr = (data_or_diff == DATA ?
70 blob->mutable_gpu_data() : blob->mutable_gpu_diff());
73 mxERROR("Unknown Caffe mode");
75 caffe_copy(blob->count(), mat_mem_ptr, blob_mem_ptr);
78 // Copy Blob data or diff to matlab array
79 static mxArray* blob_to_mx_mat(const Blob<float>* blob,
80 WhichMemory data_or_diff) {
81 const int num_axes = blob->num_axes();
82 vector<mwSize> dims(num_axes);
83 for (int blob_axis = 0, mat_axis = num_axes - 1; blob_axis < num_axes;
84 ++blob_axis, --mat_axis) {
85 dims[mat_axis] = static_cast<mwSize>(blob->shape(blob_axis));
87 // matlab array needs to have at least one dimension, convert scalar to 1-dim
92 mxCreateNumericArray(dims.size(), dims.data(), mxSINGLE_CLASS, mxREAL);
93 float* mat_mem_ptr = reinterpret_cast<float*>(mxGetData(mx_mat));
94 const float* blob_mem_ptr = NULL;
95 switch (Caffe::mode()) {
97 blob_mem_ptr = (data_or_diff == DATA ? blob->cpu_data() : blob->cpu_diff());
100 blob_mem_ptr = (data_or_diff == DATA ? blob->gpu_data() : blob->gpu_diff());
103 mxERROR("Unknown Caffe mode");
105 caffe_copy(blob->count(), blob_mem_ptr, mat_mem_ptr);
109 // Convert vector<int> to matlab row vector
110 static mxArray* int_vec_to_mx_vec(const vector<int>& int_vec) {
111 mxArray* mx_vec = mxCreateDoubleMatrix(int_vec.size(), 1, mxREAL);
112 double* vec_mem_ptr = mxGetPr(mx_vec);
113 for (int i = 0; i < int_vec.size(); i++) {
114 vec_mem_ptr[i] = static_cast<double>(int_vec[i]);
119 // Convert vector<string> to matlab cell vector of strings
120 static mxArray* str_vec_to_mx_strcell(const vector<std::string>& str_vec) {
121 mxArray* mx_strcell = mxCreateCellMatrix(str_vec.size(), 1);
122 for (int i = 0; i < str_vec.size(); i++) {
123 mxSetCell(mx_strcell, i, mxCreateString(str_vec[i].c_str()));
128 /** -----------------------------------------------------------------
129 ** handle and pointer conversion functions
130 ** a handle is a struct array with the following fields
131 ** (uint64) ptr : the pointer to the C++ object
132 ** (double) init_key : caffe initialization key
134 // Convert a handle in matlab to a pointer in C++. Check if init_key matches
135 template <typename T>
136 static T* handle_to_ptr(const mxArray* mx_handle) {
137 mxArray* mx_ptr = mxGetField(mx_handle, 0, "ptr");
138 mxArray* mx_init_key = mxGetField(mx_handle, 0, "init_key");
139 mxCHECK(mxIsUint64(mx_ptr), "pointer type must be uint64");
140 mxCHECK(mxGetScalar(mx_init_key) == init_key,
141 "Could not convert handle to pointer due to invalid init_key. "
142 "The object might have been cleared.");
143 return reinterpret_cast<T*>(*reinterpret_cast<uint64_t*>(mxGetData(mx_ptr)));
146 // Create a handle struct vector, without setting up each handle in it
147 template <typename T>
148 static mxArray* create_handle_vec(int ptr_num) {
149 const int handle_field_num = 2;
150 const char* handle_fields[handle_field_num] = { "ptr", "init_key" };
151 return mxCreateStructMatrix(ptr_num, 1, handle_field_num, handle_fields);
154 // Set up a handle in a handle struct vector by its index
155 template <typename T>
156 static void setup_handle(const T* ptr, int index, mxArray* mx_handle_vec) {
157 mxArray* mx_ptr = mxCreateNumericMatrix(1, 1, mxUINT64_CLASS, mxREAL);
158 *reinterpret_cast<uint64_t*>(mxGetData(mx_ptr)) =
159 reinterpret_cast<uint64_t>(ptr);
160 mxSetField(mx_handle_vec, index, "ptr", mx_ptr);
161 mxSetField(mx_handle_vec, index, "init_key", mxCreateDoubleScalar(init_key));
164 // Convert a pointer in C++ to a handle in matlab
165 template <typename T>
166 static mxArray* ptr_to_handle(const T* ptr) {
167 mxArray* mx_handle = create_handle_vec<T>(1);
168 setup_handle(ptr, 0, mx_handle);
172 // Convert a vector of shared_ptr in C++ to handle struct vector
173 template <typename T>
174 static mxArray* ptr_vec_to_handle_vec(const vector<shared_ptr<T> >& ptr_vec) {
175 mxArray* mx_handle_vec = create_handle_vec<T>(ptr_vec.size());
176 for (int i = 0; i < ptr_vec.size(); i++) {
177 setup_handle(ptr_vec[i].get(), i, mx_handle_vec);
179 return mx_handle_vec;
182 /** -----------------------------------------------------------------
183 ** matlab command functions: caffe_(api_command, arg1, arg2, ...)
185 // Usage: caffe_('get_solver', solver_file);
186 static void get_solver(MEX_ARGS) {
187 mxCHECK(nrhs == 1 && mxIsChar(prhs[0]),
188 "Usage: caffe_('get_solver', solver_file)");
189 char* solver_file = mxArrayToString(prhs[0]);
190 mxCHECK_FILE_EXIST(solver_file);
191 SolverParameter solver_param;
192 ReadSolverParamsFromTextFileOrDie(solver_file, &solver_param);
193 shared_ptr<Solver<float> > solver(
194 SolverRegistry<float>::CreateSolver(solver_param));
195 solvers_.push_back(solver);
196 plhs[0] = ptr_to_handle<Solver<float> >(solver.get());
200 // Usage: caffe_('solver_get_attr', hSolver)
201 static void solver_get_attr(MEX_ARGS) {
202 mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
203 "Usage: caffe_('solver_get_attr', hSolver)");
204 Solver<float>* solver = handle_to_ptr<Solver<float> >(prhs[0]);
205 const int solver_attr_num = 2;
206 const char* solver_attrs[solver_attr_num] = { "hNet_net", "hNet_test_nets" };
207 mxArray* mx_solver_attr = mxCreateStructMatrix(1, 1, solver_attr_num,
209 mxSetField(mx_solver_attr, 0, "hNet_net",
210 ptr_to_handle<Net<float> >(solver->net().get()));
211 mxSetField(mx_solver_attr, 0, "hNet_test_nets",
212 ptr_vec_to_handle_vec<Net<float> >(solver->test_nets()));
213 plhs[0] = mx_solver_attr;
216 // Usage: caffe_('solver_get_iter', hSolver)
217 static void solver_get_iter(MEX_ARGS) {
218 mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
219 "Usage: caffe_('solver_get_iter', hSolver)");
220 Solver<float>* solver = handle_to_ptr<Solver<float> >(prhs[0]);
221 plhs[0] = mxCreateDoubleScalar(solver->iter());
224 // Usage: caffe_('solver_restore', hSolver, snapshot_file)
225 static void solver_restore(MEX_ARGS) {
226 mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsChar(prhs[1]),
227 "Usage: caffe_('solver_restore', hSolver, snapshot_file)");
228 Solver<float>* solver = handle_to_ptr<Solver<float> >(prhs[0]);
229 char* snapshot_file = mxArrayToString(prhs[1]);
230 mxCHECK_FILE_EXIST(snapshot_file);
231 solver->Restore(snapshot_file);
232 mxFree(snapshot_file);
235 // Usage: caffe_('solver_solve', hSolver)
236 static void solver_solve(MEX_ARGS) {
237 mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
238 "Usage: caffe_('solver_solve', hSolver)");
239 Solver<float>* solver = handle_to_ptr<Solver<float> >(prhs[0]);
243 // Usage: caffe_('solver_step', hSolver, iters)
244 static void solver_step(MEX_ARGS) {
245 mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsDouble(prhs[1]),
246 "Usage: caffe_('solver_step', hSolver, iters)");
247 Solver<float>* solver = handle_to_ptr<Solver<float> >(prhs[0]);
248 int iters = mxGetScalar(prhs[1]);
252 // Usage: caffe_('get_net', model_file, phase_name)
253 static void get_net(MEX_ARGS) {
254 mxCHECK(nrhs == 2 && mxIsChar(prhs[0]) && mxIsChar(prhs[1]),
255 "Usage: caffe_('get_net', model_file, phase_name)");
256 char* model_file = mxArrayToString(prhs[0]);
257 char* phase_name = mxArrayToString(prhs[1]);
258 mxCHECK_FILE_EXIST(model_file);
260 if (strcmp(phase_name, "train") == 0) {
262 } else if (strcmp(phase_name, "test") == 0) {
265 mxERROR("Unknown phase");
267 shared_ptr<Net<float> > net(new caffe::Net<float>(model_file, phase));
268 nets_.push_back(net);
269 plhs[0] = ptr_to_handle<Net<float> >(net.get());
274 // Usage: caffe_('net_get_attr', hNet)
275 static void net_get_attr(MEX_ARGS) {
276 mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
277 "Usage: caffe_('net_get_attr', hNet)");
278 Net<float>* net = handle_to_ptr<Net<float> >(prhs[0]);
279 const int net_attr_num = 6;
280 const char* net_attrs[net_attr_num] = { "hLayer_layers", "hBlob_blobs",
281 "input_blob_indices", "output_blob_indices", "layer_names", "blob_names"};
282 mxArray* mx_net_attr = mxCreateStructMatrix(1, 1, net_attr_num,
284 mxSetField(mx_net_attr, 0, "hLayer_layers",
285 ptr_vec_to_handle_vec<Layer<float> >(net->layers()));
286 mxSetField(mx_net_attr, 0, "hBlob_blobs",
287 ptr_vec_to_handle_vec<Blob<float> >(net->blobs()));
288 mxSetField(mx_net_attr, 0, "input_blob_indices",
289 int_vec_to_mx_vec(net->input_blob_indices()));
290 mxSetField(mx_net_attr, 0, "output_blob_indices",
291 int_vec_to_mx_vec(net->output_blob_indices()));
292 mxSetField(mx_net_attr, 0, "layer_names",
293 str_vec_to_mx_strcell(net->layer_names()));
294 mxSetField(mx_net_attr, 0, "blob_names",
295 str_vec_to_mx_strcell(net->blob_names()));
296 plhs[0] = mx_net_attr;
299 // Usage: caffe_('net_forward', hNet)
300 static void net_forward(MEX_ARGS) {
301 mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
302 "Usage: caffe_('net_forward', hNet)");
303 Net<float>* net = handle_to_ptr<Net<float> >(prhs[0]);
304 net->ForwardPrefilled();
307 // Usage: caffe_('net_backward', hNet)
308 static void net_backward(MEX_ARGS) {
309 mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
310 "Usage: caffe_('net_backward', hNet)");
311 Net<float>* net = handle_to_ptr<Net<float> >(prhs[0]);
315 // Usage: caffe_('net_copy_from', hNet, weights_file)
316 static void net_copy_from(MEX_ARGS) {
317 mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsChar(prhs[1]),
318 "Usage: caffe_('net_copy_from', hNet, weights_file)");
319 Net<float>* net = handle_to_ptr<Net<float> >(prhs[0]);
320 char* weights_file = mxArrayToString(prhs[1]);
321 mxCHECK_FILE_EXIST(weights_file);
322 net->CopyTrainedLayersFrom(weights_file);
323 mxFree(weights_file);
326 // Usage: caffe_('net_reshape', hNet)
327 static void net_reshape(MEX_ARGS) {
328 mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
329 "Usage: caffe_('net_reshape', hNet)");
330 Net<float>* net = handle_to_ptr<Net<float> >(prhs[0]);
334 // Usage: caffe_('net_save', hNet, save_file)
335 static void net_save(MEX_ARGS) {
336 mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsChar(prhs[1]),
337 "Usage: caffe_('net_save', hNet, save_file)");
338 Net<float>* net = handle_to_ptr<Net<float> >(prhs[0]);
339 char* weights_file = mxArrayToString(prhs[1]);
340 NetParameter net_param;
341 net->ToProto(&net_param, false);
342 WriteProtoToBinaryFile(net_param, weights_file);
343 mxFree(weights_file);
346 // Usage: caffe_('layer_get_attr', hLayer)
347 static void layer_get_attr(MEX_ARGS) {
348 mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
349 "Usage: caffe_('layer_get_attr', hLayer)");
350 Layer<float>* layer = handle_to_ptr<Layer<float> >(prhs[0]);
351 const int layer_attr_num = 1;
352 const char* layer_attrs[layer_attr_num] = { "hBlob_blobs" };
353 mxArray* mx_layer_attr = mxCreateStructMatrix(1, 1, layer_attr_num,
355 mxSetField(mx_layer_attr, 0, "hBlob_blobs",
356 ptr_vec_to_handle_vec<Blob<float> >(layer->blobs()));
357 plhs[0] = mx_layer_attr;
360 // Usage: caffe_('layer_get_type', hLayer)
361 static void layer_get_type(MEX_ARGS) {
362 mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
363 "Usage: caffe_('layer_get_type', hLayer)");
364 Layer<float>* layer = handle_to_ptr<Layer<float> >(prhs[0]);
365 plhs[0] = mxCreateString(layer->type());
368 // Usage: caffe_('blob_get_shape', hBlob)
369 static void blob_get_shape(MEX_ARGS) {
370 mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
371 "Usage: caffe_('blob_get_shape', hBlob)");
372 Blob<float>* blob = handle_to_ptr<Blob<float> >(prhs[0]);
373 const int num_axes = blob->num_axes();
374 mxArray* mx_shape = mxCreateDoubleMatrix(1, num_axes, mxREAL);
375 double* shape_mem_mtr = mxGetPr(mx_shape);
376 for (int blob_axis = 0, mat_axis = num_axes - 1; blob_axis < num_axes;
377 ++blob_axis, --mat_axis) {
378 shape_mem_mtr[mat_axis] = static_cast<double>(blob->shape(blob_axis));
383 // Usage: caffe_('blob_reshape', hBlob, new_shape)
384 static void blob_reshape(MEX_ARGS) {
385 mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsDouble(prhs[1]),
386 "Usage: caffe_('blob_reshape', hBlob, new_shape)");
387 Blob<float>* blob = handle_to_ptr<Blob<float> >(prhs[0]);
388 const mxArray* mx_shape = prhs[1];
389 double* shape_mem_mtr = mxGetPr(mx_shape);
390 const int num_axes = mxGetNumberOfElements(mx_shape);
391 vector<int> blob_shape(num_axes);
392 for (int blob_axis = 0, mat_axis = num_axes - 1; blob_axis < num_axes;
393 ++blob_axis, --mat_axis) {
394 blob_shape[blob_axis] = static_cast<int>(shape_mem_mtr[mat_axis]);
396 blob->Reshape(blob_shape);
399 // Usage: caffe_('blob_get_data', hBlob)
400 static void blob_get_data(MEX_ARGS) {
401 mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
402 "Usage: caffe_('blob_get_data', hBlob)");
403 Blob<float>* blob = handle_to_ptr<Blob<float> >(prhs[0]);
404 plhs[0] = blob_to_mx_mat(blob, DATA);
407 // Usage: caffe_('blob_set_data', hBlob, new_data)
408 static void blob_set_data(MEX_ARGS) {
409 mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsSingle(prhs[1]),
410 "Usage: caffe_('blob_set_data', hBlob, new_data)");
411 Blob<float>* blob = handle_to_ptr<Blob<float> >(prhs[0]);
412 mx_mat_to_blob(prhs[1], blob, DATA);
415 // Usage: caffe_('blob_get_diff', hBlob)
416 static void blob_get_diff(MEX_ARGS) {
417 mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
418 "Usage: caffe_('blob_get_diff', hBlob)");
419 Blob<float>* blob = handle_to_ptr<Blob<float> >(prhs[0]);
420 plhs[0] = blob_to_mx_mat(blob, DIFF);
423 // Usage: caffe_('blob_set_diff', hBlob, new_diff)
424 static void blob_set_diff(MEX_ARGS) {
425 mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsSingle(prhs[1]),
426 "Usage: caffe_('blob_set_diff', hBlob, new_diff)");
427 Blob<float>* blob = handle_to_ptr<Blob<float> >(prhs[0]);
428 mx_mat_to_blob(prhs[1], blob, DIFF);
431 // Usage: caffe_('set_mode_cpu')
432 static void set_mode_cpu(MEX_ARGS) {
433 mxCHECK(nrhs == 0, "Usage: caffe_('set_mode_cpu')");
434 Caffe::set_mode(Caffe::CPU);
437 // Usage: caffe_('set_mode_gpu')
438 static void set_mode_gpu(MEX_ARGS) {
439 mxCHECK(nrhs == 0, "Usage: caffe_('set_mode_gpu')");
440 Caffe::set_mode(Caffe::GPU);
443 // Usage: caffe_('set_device', device_id)
444 static void set_device(MEX_ARGS) {
445 mxCHECK(nrhs == 1 && mxIsDouble(prhs[0]),
446 "Usage: caffe_('set_device', device_id)");
447 int device_id = static_cast<int>(mxGetScalar(prhs[0]));
448 Caffe::SetDevice(device_id);
451 // Usage: caffe_('get_init_key')
452 static void get_init_key(MEX_ARGS) {
453 mxCHECK(nrhs == 0, "Usage: caffe_('get_init_key')");
454 plhs[0] = mxCreateDoubleScalar(init_key);
457 // Usage: caffe_('reset')
458 static void reset(MEX_ARGS) {
459 mxCHECK(nrhs == 0, "Usage: caffe_('reset')");
460 // Clear solvers and stand-alone nets
461 mexPrintf("Cleared %d solvers and %d stand-alone nets\n",
462 solvers_.size(), nets_.size());
465 // Generate new init_key, so that handles created before becomes invalid
466 init_key = static_cast<double>(caffe_rng_rand());
469 // Usage: caffe_('read_mean', mean_proto_file)
470 static void read_mean(MEX_ARGS) {
471 mxCHECK(nrhs == 1 && mxIsChar(prhs[0]),
472 "Usage: caffe_('read_mean', mean_proto_file)");
473 char* mean_proto_file = mxArrayToString(prhs[0]);
474 mxCHECK_FILE_EXIST(mean_proto_file);
475 Blob<float> data_mean;
476 BlobProto blob_proto;
477 bool result = ReadProtoFromBinaryFile(mean_proto_file, &blob_proto);
478 mxCHECK(result, "Could not read your mean file");
479 data_mean.FromProto(blob_proto);
480 plhs[0] = blob_to_mx_mat(&data_mean, DATA);
481 mxFree(mean_proto_file);
484 // Usage: caffe_('write_mean', mean_data, mean_proto_file)
485 static void write_mean(MEX_ARGS) {
486 mxCHECK(nrhs == 2 && mxIsSingle(prhs[0]) && mxIsChar(prhs[1]),
487 "Usage: caffe_('write_mean', mean_data, mean_proto_file)");
488 char* mean_proto_file = mxArrayToString(prhs[1]);
489 int ndims = mxGetNumberOfDimensions(prhs[0]);
490 mxCHECK(ndims >= 2 && ndims <= 3, "mean_data must have at 2 or 3 dimensions");
491 const mwSize *dims = mxGetDimensions(prhs[0]);
493 int height = dims[1];
499 Blob<float> data_mean(1, channels, height, width);
500 mx_mat_to_blob(prhs[0], &data_mean, DATA);
501 BlobProto blob_proto;
502 data_mean.ToProto(&blob_proto, false);
503 WriteProtoToBinaryFile(blob_proto, mean_proto_file);
504 mxFree(mean_proto_file);
507 // Usage: caffe_('version')
508 static void version(MEX_ARGS) {
509 mxCHECK(nrhs == 0, "Usage: caffe_('version')");
510 // Return version string
511 plhs[0] = mxCreateString(AS_STRING(CAFFE_VERSION));
514 /** -----------------------------------------------------------------
515 ** Available commands.
517 struct handler_registry {
519 void (*func)(MEX_ARGS);
522 static handler_registry handlers[] = {
523 // Public API functions
524 { "get_solver", get_solver },
525 { "solver_get_attr", solver_get_attr },
526 { "solver_get_iter", solver_get_iter },
527 { "solver_restore", solver_restore },
528 { "solver_solve", solver_solve },
529 { "solver_step", solver_step },
530 { "get_net", get_net },
531 { "net_get_attr", net_get_attr },
532 { "net_forward", net_forward },
533 { "net_backward", net_backward },
534 { "net_copy_from", net_copy_from },
535 { "net_reshape", net_reshape },
536 { "net_save", net_save },
537 { "layer_get_attr", layer_get_attr },
538 { "layer_get_type", layer_get_type },
539 { "blob_get_shape", blob_get_shape },
540 { "blob_reshape", blob_reshape },
541 { "blob_get_data", blob_get_data },
542 { "blob_set_data", blob_set_data },
543 { "blob_get_diff", blob_get_diff },
544 { "blob_set_diff", blob_set_diff },
545 { "set_mode_cpu", set_mode_cpu },
546 { "set_mode_gpu", set_mode_gpu },
547 { "set_device", set_device },
548 { "get_init_key", get_init_key },
550 { "read_mean", read_mean },
551 { "write_mean", write_mean },
552 { "version", version },
557 /** -----------------------------------------------------------------
558 ** matlab entry point.
560 // Usage: caffe_(api_command, arg1, arg2, ...)
561 void mexFunction(MEX_ARGS) {
562 mexLock(); // Avoid clearing the mex file.
563 mxCHECK(nrhs > 0, "Usage: caffe_(api_command, arg1, arg2, ...)");
564 // Handle input command
565 char* cmd = mxArrayToString(prhs[0]);
566 bool dispatched = false;
567 // Dispatch to cmd handler
568 for (int i = 0; handlers[i].func != NULL; i++) {
569 if (handlers[i].cmd.compare(cmd) == 0) {
570 handlers[i].func(nlhs, plhs, nrhs-1, prhs+1);
576 ostringstream error_msg;
577 error_msg << "Unknown command '" << cmd << "'";
578 mxERROR(error_msg.str().c_str());