Aesthetic changes on code style and some minor fix
authorRonghang Hu <huronghang@hotmail.com>
Thu, 28 May 2015 05:40:26 +0000 (13:40 +0800)
committerRonghang Hu <huronghang@hotmail.com>
Fri, 29 May 2015 05:21:25 +0000 (13:21 +0800)
matlab/+caffe/Blob.m
matlab/+caffe/Layer.m
matlab/+caffe/Net.m
matlab/+caffe/Solver.m
matlab/+caffe/get_net.m
matlab/+caffe/get_solver.m
matlab/+caffe/io.m
matlab/+caffe/private/caffe_.cpp
matlab/+caffe/private/is_valid_handle.m
matlab/+caffe/run_tests.m

index f9b6409..e39f7ee 100644 (file)
@@ -7,9 +7,9 @@ classdef Blob < handle
   
   methods
     function self = Blob(hBlob_blob)
-      CHECK(is_valid_handle(hBlob_blob), 'invalid input handle');
+      CHECK(is_valid_handle(hBlob_blob), 'invalid Blob handle');
       
-      % setup self handle and attributes
+      % setup self handle
       self.hBlob_self = hBlob_blob;
     end
     function shape = shape(self)
@@ -37,14 +37,16 @@ classdef Blob < handle
   
   methods (Access = private)
     function shape = check_and_preprocess_shape(~, shape)
-      CHECK(isempty(shape) || isnumeric(shape) && isrow(shape), ...
+      CHECK(isempty(shape) || (isnumeric(shape) && isrow(shape)), ...
         'shape must be a integer row vector');
       shape = double(shape);
     end
     function data = check_and_preprocess_data(self, data)
       CHECK(isnumeric(data), 'data or diff must be numeric types');
-      self.check_data_size_matches(data)
-      data = single(data);
+      self.check_data_size_matches(data);
+      if ~isa(data, 'single')
+        data = single(data);
+      end
     end
     function check_data_size_matches(self, data)
       % check whether size of data matches shape of this blob
@@ -59,17 +61,17 @@ classdef Blob < handle
         % target blob is a vector (1 dim)
         self_shape_extended = [self_shape_extended, 1];
       end
-      % also, matlab cannot have tailing dimension 1 for ndim > 2, so you
+      % Also, matlab cannot have tailing dimension 1 for ndim > 2, so you
       % cannot create 20 x 10 x 1 x 1 array in matlab as it becomes 20 x 10
-      % extend matlab arrays to have tailing dimension 1 during shape match
+      % Extend matlab arrays to have tailing dimension 1 during shape match
       data_size_extended = ...
         [size(data), ones(1, length(self_shape_extended) - ndims(data))];
       is_matched = ...
-      (length(self_shape_extended) == length(data_size_extended)) ...
+        (length(self_shape_extended) == length(data_size_extended)) ...
         && all(self_shape_extended == data_size_extended);
       CHECK(is_matched, ...
-        sprintf('%s, data size: [ %s], blob shape: [ %s]', ...
-        'data size does not match blob shape', ...
+        sprintf('%s, input data/diff size: [ %s] vs target blob shape: [ %s]', ...
+        'input data/diff size does not match target blob shape', ...
         sprintf('%d ', data_size_extended), sprintf('%d ', self_shape_extended)));
     end
   end
index 7587ed7..4c20231 100644 (file)
@@ -13,7 +13,7 @@ classdef Layer < handle
   
   methods
     function self = Layer(hLayer_layer)
-      CHECK(is_valid_handle(hLayer_layer), 'invalid input handle');
+      CHECK(is_valid_handle(hLayer_layer), 'invalid Layer handle');
       
       % setup self handle and attributes
       self.hLayer_self = hLayer_layer;
index 5319634..a676106 100644 (file)
@@ -33,7 +33,7 @@ classdef Net < handle
       end
       % construct a net from handle
       hNet_net = varargin{1};
-      CHECK(is_valid_handle(hNet_net), 'invalid input handle');
+      CHECK(is_valid_handle(hNet_net), 'invalid Net handle');
       
       % setup self handle and attributes
       self.hNet_self = hNet_net;
@@ -64,7 +64,7 @@ classdef Net < handle
       self.name2blob_index = containers.Map(self.attributes.blob_names, ...
         1:length(self.attributes.blob_names));
       
-      % expose layer_names and blob_names for public access
+      % expose layer_names and blob_names for public read access
       self.layer_names = self.attributes.layer_names;
       self.blob_names = self.attributes.blob_names;
     end
@@ -91,12 +91,12 @@ classdef Net < handle
       CHECK(iscell(input_data), 'input_data must be a cell array');
       CHECK(length(input_data) == length(self.inputs), ...
         'input data cell length must match input blob number');
-      % copy data to input_blobs
+      % copy data to input blobs
       for n = 1:length(self.inputs)
         self.blobs(self.inputs{n}).set_data(input_data{n});
       end
       self.forward_prefilled();
-      % retrieve data from output_blobs
+      % retrieve data from output blobs
       res = cell(length(self.outputs), 1);
       for n = 1:length(self.outputs)
         res{n} = self.blobs(self.outputs{n}).get_data();
@@ -106,7 +106,7 @@ classdef Net < handle
       CHECK(iscell(output_diff), 'output_diff must be a cell array');
       CHECK(length(output_diff) == length(self.outputs), ...
         'output diff cell length must match output blob number');
-      % copy diff to output_blobs
+      % copy diff to output blobs
       for n = 1:length(self.outputs)
         self.blobs(self.outputs{n}).set_diff(output_diff{n});
       end
index 80fa539..daaa802 100644 (file)
@@ -23,7 +23,7 @@ classdef Solver < handle
       end
       % construct a solver from handle
       hSolver_solver = varargin{1};
-      CHECK(is_valid_handle(hSolver_solver), 'invalid input handle');
+      CHECK(is_valid_handle(hSolver_solver), 'invalid Solver handle');
       
       % setup self handle and attributes
       self.hSolver_self = hSolver_solver;
index d60979d..4b5683e 100644 (file)
@@ -5,7 +5,7 @@ function net = get_net(varargin)
 %   phase_name can only be 'train' or 'test'
 
 CHECK(nargin == 2 || nargin == 3, ['usage: ' ...
-  'net = get_net(model_file, phase_name) or ', ...
+  'net = get_net(model_file, phase_name) or ' ...
   'net = get_net(model_file, weights_file, phase_name)']);
 if nargin == 3
   model_file = varargin{1};
@@ -23,6 +23,7 @@ CHECK(strcmp(phase_name, 'train') || strcmp(phase_name, 'test'), ...
   sprintf('phase_name can only be %strain%s or %stest%s', ...
   char(39), char(39), char(39), char(39)));
 
+% construct caffe net from model_file
 hNet = caffe_('get_net', model_file, phase_name);
 net = caffe.Net(hNet);
 
index 30366d8..74d576e 100644 (file)
@@ -4,7 +4,6 @@ function solver = get_solver(solver_file)
 
 CHECK(ischar(solver_file), 'solver_file must be a string');
 CHECK_FILE_EXIST(solver_file);
-
 pSolver = caffe_('get_solver', solver_file);
 solver = caffe.Solver(pSolver);
 
index 7fad968..7a30bfb 100644 (file)
@@ -3,17 +3,20 @@ classdef io
   
   methods (Static)
     function im_data = load_image(im_file)
+      % im_data = load_image(im_file)
+      %   load an image from disk into Caffe-supported data format
+      %   switch channels from RGB to BGR, make width the fastest dimension
+      %   and convert to single
       CHECK(ischar(im_file), 'im_file must be a string');
       CHECK_FILE_EXIST(im_file);
-      %   load an image from disk into Caffe-supported data format
-      %   switch channels from RGB to BGR, make width the fastest dimension, and
-      %   convert to single
-      im = imread(im_file);
-      im_data = im(:, :, [3, 2, 1]);
-      im_data = permute(im_data, [2 1 3]);
+      im_data = imread(im_file);
+      im_data = im_data(:, :, [3, 2, 1]);
+      im_data = permute(im_data, [2, 1, 3]);
       im_data = single(im_data);
     end
     function mean_data = read_mean(mean_proto_file)
+      % mean_data = read_mean(mean_proto_file)
+      %   read image mean data from binaryproto file
       CHECK(ischar(mean_proto_file), 'im_file must be a string');
       CHECK_FILE_EXIST(mean_proto_file);
       mean_data = caffe_('read_mean', mean_proto_file);
index 96a1920..4e0ebc1 100644 (file)
 
 using namespace caffe;  // NOLINT(build/namespaces)
 
-// Do CHECK and throw a Mex error if check failsf
-inline void mxCHECK(bool expr, const std::string &msg) {
+// Do CHECK and throw a Mex error if check fails
+inline void mxCHECK(bool expr, const char* msg) {
   if (!expr) {
-    LOG(ERROR) << msg;
-    mexErrMsgTxt(msg.c_str());
+    mexErrMsgTxt(msg);
   }
 }
-inline void mxERROR(const std::string &msg) { mxCHECK(false, msg); }
+inline void mxERROR(const char* msg) { mexErrMsgTxt(msg); }
 
 // Check if a file exists and can be opened
 void mxCHECK_FILE_EXIST(const char* file) {
   std::ifstream f(file);
   if (!f.good()) {
     f.close();
-    mxERROR("Could not open file " + string(file));
+    std::string msg("Could not open file ");
+    msg += file;
+    mxERROR(msg.c_str());
   }
   f.close();
 }
@@ -43,6 +44,7 @@ void mxCHECK_FILE_EXIST(const char* file) {
 // The pointers to caffe::Solver and caffe::Net instances
 static vector<shared_ptr<Solver<float> > > solvers_;
 static vector<shared_ptr<Net<float> > > nets_;
+// init_key is generated at the beginning and everytime you call reset
 static double init_key = static_cast<double>(caffe_rng_rand());
 
 /** -----------------------------------------------------------------
@@ -104,17 +106,17 @@ static mxArray* blob_to_mx_mat(const Blob<float>* blob,
   return mx_mat;
 }
 
-// convert vector<int> to matlab vector
+// Convert vector<int> to matlab row vector
 static mxArray* int_vec_to_mx_vec(const vector<int>& int_vec) {
   mxArray* mx_vec = mxCreateDoubleMatrix(int_vec.size(), 1, mxREAL);
   double* vec_mem_ptr = mxGetPr(mx_vec);
   for (int i = 0; i < int_vec.size(); i++) {
-    vec_mem_ptr[i] = int_vec[i];
+    vec_mem_ptr[i] = static_cast<double>(int_vec[i]);
   }
   return mx_vec;
 }
 
-// convert vector<string> to matlab string cell vector
+// Convert vector<string> to matlab cell vector of strings
 static mxArray* str_vec_to_mx_strcell(const vector<std::string>& str_vec) {
   mxArray* mx_strcell = mxCreateCellMatrix(str_vec.size(), 1);
   for (int i = 0; i < str_vec.size(); i++) {
@@ -135,44 +137,46 @@ static T* handle_to_ptr(const mxArray* mx_handle) {
   mxArray* mx_ptr = mxGetField(mx_handle, 0, "ptr");
   mxArray* mx_init_key = mxGetField(mx_handle, 0, "init_key");
   mxCHECK(mxIsUint64(mx_ptr), "pointer type must be uint64");
-  mxCHECK(mxGetScalar(mx_init_key) == init_key, "incorrect handle init_key");
+  mxCHECK(mxGetScalar(mx_init_key) == init_key,
+      "Could not convert handle to pointer due to invalid init_key. "
+      "The object might have been cleared.");
   return reinterpret_cast<T*>(*reinterpret_cast<uint64_t*>(mxGetData(mx_ptr)));
 }
 
-// Create an empty handle struct array
+// Create a handle struct vector, without setting up each handle in it
 template <typename T>
-static mxArray* create_handles(int ptr_num) {
+static mxArray* create_handle_vec(int ptr_num) {
   const int handle_field_num = 2;
   const char* handle_fields[handle_field_num] = { "ptr", "init_key" };
   return mxCreateStructMatrix(ptr_num, 1, handle_field_num, handle_fields);
 }
 
-// Set up each handle in a handle struct array
+// Set up a handle in a handle struct vector by its index
 template <typename T>
-static void setup_handle(const T* ptr, int index, mxArray* mx_handles) {
+static void setup_handle(const T* ptr, int index, mxArray* mx_handle_vec) {
   mxArray* mx_ptr = mxCreateNumericMatrix(1, 1, mxUINT64_CLASS, mxREAL);
   *reinterpret_cast<uint64_t*>(mxGetData(mx_ptr)) =
       reinterpret_cast<uint64_t>(ptr);
-  mxSetField(mx_handles, index, "ptr", mx_ptr);
-  mxSetField(mx_handles, index, "init_key", mxCreateDoubleScalar(init_key));
+  mxSetField(mx_handle_vec, index, "ptr", mx_ptr);
+  mxSetField(mx_handle_vec, index, "init_key", mxCreateDoubleScalar(init_key));
 }
 
 // Convert a pointer in C++ to a handle in matlab
 template <typename T>
 static mxArray* ptr_to_handle(const T* ptr) {
-  mxArray* mx_handle = create_handles<T>(1);
+  mxArray* mx_handle = create_handle_vec<T>(1);
   setup_handle(ptr, 0, mx_handle);
   return mx_handle;
 }
 
-// Convert a vector of shared_ptr in C++ to handle struct array in matlab
+// Convert a vector of shared_ptr in C++ to handle struct vector
 template <typename T>
-static mxArray* ptr_vec_to_handles(const vector<shared_ptr<T> >& ptr_vec) {
-  mxArray* mx_handle = create_handles<T>(ptr_vec.size());
+static mxArray* ptr_vec_to_handle_vec(const vector<shared_ptr<T> >& ptr_vec) {
+  mxArray* mx_handle_vec = create_handle_vec<T>(ptr_vec.size());
   for (int i = 0; i < ptr_vec.size(); i++) {
-    setup_handle(ptr_vec[i].get(), i, mx_handle);
+    setup_handle(ptr_vec[i].get(), i, mx_handle_vec);
   }
-  return mx_handle;
+  return mx_handle_vec;
 }
 
 /** -----------------------------------------------------------------
@@ -182,12 +186,12 @@ static mxArray* ptr_vec_to_handles(const vector<shared_ptr<T> >& ptr_vec) {
 static void get_solver(MEX_ARGS) {
   mxCHECK(nrhs == 1 && mxIsChar(prhs[0]),
       "Usage: caffe_('get_solver', solver_file)");
-  const char* solver_file = mxArrayToString(prhs[0]);
+  char* solver_file = mxArrayToString(prhs[0]);
   mxCHECK_FILE_EXIST(solver_file);
-  shared_ptr<Solver<float> > solver;
-  solver.reset(new caffe::SGDSolver<float>(solver_file));
+  shared_ptr<Solver<float> > solver(new caffe::SGDSolver<float>(solver_file));
   solvers_.push_back(solver);
   plhs[0] = ptr_to_handle<Solver<float> >(solver.get());
+  mxFree(solver_file);
 }
 
 // Usage: caffe_('solver_get_attr', hSolver)
@@ -202,7 +206,7 @@ static void solver_get_attr(MEX_ARGS) {
   mxSetField(mx_solver_attr, 0, "hNet_net",
       ptr_to_handle<Net<float> >(solver->net().get()));
   mxSetField(mx_solver_attr, 0, "hNet_test_nets",
-      ptr_vec_to_handles<Net<float> >(solver->test_nets()));
+      ptr_vec_to_handle_vec<Net<float> >(solver->test_nets()));
   plhs[0] = mx_solver_attr;
 }
 
@@ -219,9 +223,10 @@ static void solver_restore(MEX_ARGS) {
   mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsChar(prhs[1]),
       "Usage: caffe_('solver_restore', hSolver, snapshot_file)");
   Solver<float>* solver = handle_to_ptr<Solver<float> >(prhs[0]);
-  const char* snapshot_file = mxArrayToString(prhs[1]);
+  char* snapshot_file = mxArrayToString(prhs[1]);
   mxCHECK_FILE_EXIST(snapshot_file);
   solver->Restore(snapshot_file);
+  mxFree(snapshot_file);
 }
 
 // Usage: caffe_('solver_solve', hSolver)
@@ -232,10 +237,10 @@ static void solver_solve(MEX_ARGS) {
   solver->Solve();
 }
 
-// Usage: caffe_('solver_solve', hSolver, iters)
+// Usage: caffe_('solver_step', hSolver, iters)
 static void solver_step(MEX_ARGS) {
   mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsDouble(prhs[1]),
-      "Usage: caffe_('solver_solve', hSolver, iters)");
+      "Usage: caffe_('solver_step', hSolver, iters)");
   Solver<float>* solver = handle_to_ptr<Solver<float> >(prhs[0]);
   int iters = mxGetScalar(prhs[1]);
   solver->Step(iters);
@@ -245,21 +250,22 @@ static void solver_step(MEX_ARGS) {
 static void get_net(MEX_ARGS) {
   mxCHECK(nrhs == 2 && mxIsChar(prhs[0]) && mxIsChar(prhs[1]),
       "Usage: caffe_('get_net', model_file, phase_name)");
-  const char* model_file = mxArrayToString(prhs[0]);
+  char* model_file = mxArrayToString(prhs[0]);
+  char* phase_name = mxArrayToString(prhs[1]);
   mxCHECK_FILE_EXIST(model_file);
-  const char* phase_name = mxArrayToString(prhs[1]);
   Phase phase;
   if (strcmp(phase_name, "train") == 0) {
       phase = TRAIN;
   } else if (strcmp(phase_name, "test") == 0) {
       phase = TEST;
   } else {
-    mxERROR("Unknown phase.");
+    mxERROR("Unknown phase");
   }
-  shared_ptr<Net<float> > net;
-  net.reset(new caffe::Net<float>(model_file, phase));
+  shared_ptr<Net<float> > net(new caffe::Net<float>(model_file, phase));
   nets_.push_back(net);
   plhs[0] = ptr_to_handle<Net<float> >(net.get());
+  mxFree(model_file);
+  mxFree(phase_name);
 }
 
 // Usage: caffe_('net_get_attr', hNet)
@@ -273,15 +279,15 @@ static void net_get_attr(MEX_ARGS) {
   mxArray* mx_net_attr = mxCreateStructMatrix(1, 1, net_attr_num,
       net_attrs);
   mxSetField(mx_net_attr, 0, "hLayer_layers",
-      ptr_vec_to_handles<Layer<float> >(net->layers()));
+      ptr_vec_to_handle_vec<Layer<float> >(net->layers()));
   mxSetField(mx_net_attr, 0, "hBlob_blobs",
-      ptr_vec_to_handles<Blob<float> >(net->blobs()));
+      ptr_vec_to_handle_vec<Blob<float> >(net->blobs()));
   mxSetField(mx_net_attr, 0, "input_blob_indices",
       int_vec_to_mx_vec(net->input_blob_indices()));
   mxSetField(mx_net_attr, 0, "output_blob_indices",
       int_vec_to_mx_vec(net->output_blob_indices()));
   mxSetField(mx_net_attr, 0, "layer_names",
-    str_vec_to_mx_strcell(net->layer_names()));
+      str_vec_to_mx_strcell(net->layer_names()));
   mxSetField(mx_net_attr, 0, "blob_names",
       str_vec_to_mx_strcell(net->blob_names()));
   plhs[0] = mx_net_attr;
@@ -308,9 +314,10 @@ static void net_copy_from(MEX_ARGS) {
   mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsChar(prhs[1]),
       "Usage: caffe_('net_copy_from', hNet, weights_file)");
   Net<float>* net = handle_to_ptr<Net<float> >(prhs[0]);
-  const char* weights_file = mxArrayToString(prhs[1]);
+  char* weights_file = mxArrayToString(prhs[1]);
   mxCHECK_FILE_EXIST(weights_file);
   net->CopyTrainedLayersFrom(weights_file);
+  mxFree(weights_file);
 }
 
 // Usage: caffe_('net_reshape', hNet)
@@ -326,10 +333,11 @@ static void net_save(MEX_ARGS) {
   mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsChar(prhs[1]),
       "Usage: caffe_('net_save', hNet, save_file)");
   Net<float>* net = handle_to_ptr<Net<float> >(prhs[0]);
-  const char* weights_file = mxArrayToString(prhs[1]);
+  char* weights_file = mxArrayToString(prhs[1]);
   NetParameter net_param;
   net->ToProto(&net_param, false);
   WriteProtoToBinaryFile(net_param, weights_file);
+  mxFree(weights_file);
 }
 
 // Usage: caffe_('layer_get_attr', hLayer)
@@ -342,14 +350,14 @@ static void layer_get_attr(MEX_ARGS) {
   mxArray* mx_layer_attr = mxCreateStructMatrix(1, 1, layer_attr_num,
       layer_attrs);
   mxSetField(mx_layer_attr, 0, "hBlob_blobs",
-      ptr_vec_to_handles<Blob<float> >(layer->blobs()));
+      ptr_vec_to_handle_vec<Blob<float> >(layer->blobs()));
   plhs[0] = mx_layer_attr;
 }
 
-// Usage: caffe_('layer_get_attr', hLayer)
+// Usage: caffe_('layer_get_type', hLayer)
 static void layer_get_type(MEX_ARGS) {
   mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
-      "Usage: caffe_('layer_get_attr', hLayer)");
+      "Usage: caffe_('layer_get_type', hLayer)");
   Layer<float>* layer = handle_to_ptr<Layer<float> >(prhs[0]);
   plhs[0] = mxCreateString(layer->type());
 }
@@ -446,10 +454,12 @@ static void get_init_key(MEX_ARGS) {
 // Usage: caffe_('reset')
 static void reset(MEX_ARGS) {
   mxCHECK(nrhs == 0, "Usage: caffe_('reset')");
-  mexPrintf("cleared %d solvers and %d stand-alone nets\n", solvers_.size(),
-      nets_.size());
+  // Clear solvers and stand-alone nets
+  mexPrintf("Cleared %d solvers and %d stand-alone nets\n",
+      solvers_.size(), nets_.size());
   solvers_.clear();
   nets_.clear();
+  // Generate new init_key, so that handles created before becomes invalid
   init_key = static_cast<double>(caffe_rng_rand());
 }
 
@@ -457,21 +467,15 @@ static void reset(MEX_ARGS) {
 static void read_mean(MEX_ARGS) {
   mxCHECK(nrhs == 1 && mxIsChar(prhs[0]),
       "Usage: caffe_('read_mean', mean_proto_file)");
-  const char* mean_proto_file = mxArrayToString(prhs[0]);
+  char* mean_proto_file = mxArrayToString(prhs[0]);
+  mxCHECK_FILE_EXIST(mean_proto_file);
   Blob<float> data_mean;
-  mexPrintf("Loading mean file from: %s\n", mean_proto_file);
   BlobProto blob_proto;
   bool result = ReadProtoFromBinaryFile(mean_proto_file, &blob_proto);
-  mxCHECK(result, "Couldn't read the file");
+  mxCHECK(result, "Could not read your mean file");
   data_mean.FromProto(blob_proto);
-  mwSize dims[4] = {data_mean.width(), data_mean.height(),
-                    data_mean.channels(), data_mean.num() };
-  mxArray* mx_blob =  mxCreateNumericArray(4, dims, mxSINGLE_CLASS, mxREAL);
-  float* data_ptr = reinterpret_cast<float*>(mxGetData(mx_blob));
-  caffe_copy(data_mean.count(), data_mean.cpu_data(), data_ptr);
-  mexPrintf("Remember that Caffe saves in [width, height, channels]"
-                " format and channels are also BGR!\n");
-  plhs[0] = mx_blob;
+  plhs[0] = blob_to_mx_mat(&data_mean, DATA);
+  mxFree(mean_proto_file);
 }
 
 /** -----------------------------------------------------------------
@@ -516,31 +520,27 @@ static handler_registry handlers[] = {
 };
 
 /** -----------------------------------------------------------------
- ** matlab entry point: caffe_(api_command, arg1, arg2, ...)
+ ** matlab entry point.
  **/
+// Usage: caffe_(api_command, arg1, arg2, ...)
 void mexFunction(MEX_ARGS) {
   mexLock();  // Avoid clearing the mex file.
-  if (nrhs == 0) {
-    mxERROR("No API command given");
-    return;
-  }
-
-  { // Handle input command
-    char* cmd = mxArrayToString(prhs[0]);
-    bool dispatched = false;
-    // Dispatch to cmd handler
-    for (int i = 0; handlers[i].func != NULL; i++) {
-      if (handlers[i].cmd.compare(cmd) == 0) {
-        handlers[i].func(nlhs, plhs, nrhs-1, prhs+1);
-        dispatched = true;
-        break;
-      }
-    }
-    if (!dispatched) {
-      ostringstream error_msg;
-      error_msg << "Unknown command '" << cmd << "'";
-      mxERROR(error_msg.str().c_str());
+  mxCHECK(nrhs > 0, "Usage: caffe_(api_command, arg1, arg2, ...)");
+  // Handle input command
+  char* cmd = mxArrayToString(prhs[0]);
+  bool dispatched = false;
+  // Dispatch to cmd handler
+  for (int i = 0; handlers[i].func != NULL; i++) {
+    if (handlers[i].cmd.compare(cmd) == 0) {
+      handlers[i].func(nlhs, plhs, nrhs-1, prhs+1);
+      dispatched = true;
+      break;
     }
-    mxFree(cmd);
   }
+  if (!dispatched) {
+    ostringstream error_msg;
+    error_msg << "Unknown command '" << cmd << "'";
+    mxERROR(error_msg.str().c_str());
+  }
+  mxFree(cmd);
 }
index 77abf21..a0648ec 100644 (file)
@@ -15,7 +15,6 @@ end
 % is_valid_handle('get_new_init_key') to get new init_key from C++;
 if ischar(hObj) && strcmp(hObj, 'get_new_init_key')
   init_key = caffe_('get_init_key');
-  valid = true;
   return
 else
   % check whether data types are correct and init_key matches
index fb1089c..afdd8f3 100644 (file)
@@ -2,11 +2,15 @@ function results = run_tests()
 % results = run_tests()
 %   run all tests in this caffe matlab wrapper package
 
+% reset caffe before testing
 caffe.reset();
+
+% put all test cases here
 results = [...
   run(caffe.test.test_net) ...
-  run(caffe.test.test_solver)
-  ];
+  run(caffe.test.test_solver) ];
+
+% reset caffe after testing
 caffe.reset();
 
 end