Prevent Matlab on OS X from crashing on error
authorDaniel Golden <dgolden1@gmail.com>
Wed, 5 Nov 2014 21:36:00 +0000 (13:36 -0800)
committerDaniel Golden <dgolden1@gmail.com>
Thu, 6 Nov 2014 01:31:43 +0000 (17:31 -0800)
Replace CHECK() and LOG(FATAL) with LOG(ERROR) and mexErrMsgTxt

A failed CHECK() or LOG(FATAL) causes Matlab to crash on OS X 10.9 with Matlab 2014a.

matlab/caffe/matcaffe.cpp

index fc04758..3de0f02 100644 (file)
@@ -3,6 +3,7 @@
 // caffe::Caffe functions so that one could easily call it from matlab.
 // Note that for matlab, we will simply use float as the data type.
 
+#include <sstream>
 #include <string>
 #include <vector>
 
 
 #define MEX_ARGS int nlhs, mxArray **plhs, int nrhs, const mxArray **prhs
 
+// Log and throw a Mex error
+inline void mex_error(const std::string &msg) {
+  LOG(ERROR) << msg;
+  mexErrMsgTxt(msg.c_str());
+}
+
 using namespace caffe;  // NOLINT(build/namespaces)
 
 // The pointer to the internal caffe::Net instance
@@ -46,14 +53,22 @@ static int init_key = -2;
 
 static mxArray* do_forward(const mxArray* const bottom) {
   vector<Blob<float>*>& input_blobs = net_->input_blobs();
-  CHECK_EQ(static_cast<unsigned int>(mxGetDimensions(bottom)[0]),
-      input_blobs.size());
+  if (static_cast<unsigned int>(mxGetDimensions(bottom)[0]) !=
+      input_blobs.size()) {
+    mex_error("Invalid input size");
+  }
   for (unsigned int i = 0; i < input_blobs.size(); ++i) {
     const mxArray* const elem = mxGetCell(bottom, i);
-    CHECK(mxIsSingle(elem))
-        << "MatCaffe require single-precision float point data";
-    CHECK_EQ(mxGetNumberOfElements(elem), input_blobs[i]->count())
-        << "MatCaffe input size does not match the input size of the network";
+    if (!mxIsSingle(elem)) {
+      mex_error("MatCaffe require single-precision float point data");
+    }
+    if (mxGetNumberOfElements(elem) != input_blobs[i]->count()) {
+      std::string error_msg;
+      error_msg += "MatCaffe input size does not match the input size ";
+      error_msg += "of the network";
+      mex_error(error_msg);
+    }
+
     const float* const data_ptr =
         reinterpret_cast<const float* const>(mxGetPr(elem));
     switch (Caffe::mode()) {
@@ -66,7 +81,7 @@ static mxArray* do_forward(const mxArray* const bottom) {
           input_blobs[i]->mutable_gpu_data());
       break;
     default:
-      LOG(FATAL) << "Unknown Caffe mode.";
+      mex_error("Unknown Caffe mode");
     }  // switch (Caffe::mode())
   }
   const vector<Blob<float>*>& output_blobs = net_->ForwardPrefilled();
@@ -89,7 +104,7 @@ static mxArray* do_forward(const mxArray* const bottom) {
           data_ptr);
       break;
     default:
-      LOG(FATAL) << "Unknown Caffe mode.";
+      mex_error("Unknown Caffe mode");
     }  // switch (Caffe::mode())
   }
 
@@ -99,8 +114,10 @@ static mxArray* do_forward(const mxArray* const bottom) {
 static mxArray* do_backward(const mxArray* const top_diff) {
   vector<Blob<float>*>& output_blobs = net_->output_blobs();
   vector<Blob<float>*>& input_blobs = net_->input_blobs();
-  CHECK_EQ(static_cast<unsigned int>(mxGetDimensions(top_diff)[0]),
-      output_blobs.size());
+  if (static_cast<unsigned int>(mxGetDimensions(top_diff)[0]) !=
+      output_blobs.size()) {
+    mex_error("Invalid input size");
+  }
   // First, copy the output diff
   for (unsigned int i = 0; i < output_blobs.size(); ++i) {
     const mxArray* const elem = mxGetCell(top_diff, i);
@@ -116,7 +133,7 @@ static mxArray* do_backward(const mxArray* const top_diff) {
           output_blobs[i]->mutable_gpu_diff());
       break;
     default:
-      LOG(FATAL) << "Unknown Caffe mode.";
+        mex_error("Unknown Caffe mode");
     }  // switch (Caffe::mode())
   }
   // LOG(INFO) << "Start";
@@ -139,7 +156,7 @@ static mxArray* do_backward(const mxArray* const top_diff) {
       caffe_copy(input_blobs[i]->count(), input_blobs[i]->gpu_diff(), data_ptr);
       break;
     default:
-      LOG(FATAL) << "Unknown Caffe mode.";
+        mex_error("Unknown Caffe mode");
     }  // switch (Caffe::mode())
   }
 
@@ -216,7 +233,7 @@ static mxArray* do_get_weights() {
               weights_ptr);
           break;
         default:
-          LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+          mex_error("Unknown Caffe mode");
         }
       }
     }
@@ -247,8 +264,9 @@ static void set_phase_test(MEX_ARGS) {
 
 static void set_device(MEX_ARGS) {
   if (nrhs != 1) {
-    LOG(ERROR) << "Only given " << nrhs << " arguments";
-    mexErrMsgTxt("Wrong number of arguments");
+    ostringstream error_msg;
+    error_msg << "Expected 1 argument, got " << nrhs;
+    mex_error(error_msg.str());
   }
 
   int device_id = static_cast<int>(mxGetScalar(prhs[0]));
@@ -261,8 +279,9 @@ static void get_init_key(MEX_ARGS) {
 
 static void init(MEX_ARGS) {
   if (nrhs != 2) {
-    LOG(ERROR) << "Only given " << nrhs << " arguments";
-    mexErrMsgTxt("Wrong number of arguments");
+    ostringstream error_msg;
+    error_msg << "Expected 2 arguments, got " << nrhs;
+    mex_error(error_msg.str());
   }
 
   char* param_file = mxArrayToString(prhs[0]);
@@ -291,8 +310,9 @@ static void reset(MEX_ARGS) {
 
 static void forward(MEX_ARGS) {
   if (nrhs != 1) {
-    LOG(ERROR) << "Only given " << nrhs << " arguments";
-    mexErrMsgTxt("Wrong number of arguments");
+    ostringstream error_msg;
+    error_msg << "Expected 1 argument, got " << nrhs;
+    mex_error(error_msg.str());
   }
 
   plhs[0] = do_forward(prhs[0]);
@@ -300,8 +320,9 @@ static void forward(MEX_ARGS) {
 
 static void backward(MEX_ARGS) {
   if (nrhs != 1) {
-    LOG(ERROR) << "Only given " << nrhs << " arguments";
-    mexErrMsgTxt("Wrong number of arguments");
+    ostringstream error_msg;
+    error_msg << "Expected 1 argument, got " << nrhs;
+    mex_error(error_msg.str());
   }
 
   plhs[0] = do_backward(prhs[0]);
@@ -374,8 +395,7 @@ static handler_registry handlers[] = {
 void mexFunction(MEX_ARGS) {
   mexLock();  // Avoid clearing the mex file.
   if (nrhs == 0) {
-    LOG(ERROR) << "No API command given";
-    mexErrMsgTxt("An API command is requires");
+    mex_error("No API command given");
     return;
   }
 
@@ -391,8 +411,9 @@ void mexFunction(MEX_ARGS) {
       }
     }
     if (!dispatched) {
-      LOG(ERROR) << "Unknown command `" << cmd << "'";
-      mexErrMsgTxt("API command not recognized");
+      ostringstream error_msg;
+      error_msg << "Unknown command '" << cmd << "'";
+      mex_error(error_msg.str());
     }
     mxFree(cmd);
   }