From f7625e3632fd18ef0ebeafbe9c807990582618dc Mon Sep 17 00:00:00 2001 From: Daniel Golden Date: Wed, 5 Nov 2014 13:36:00 -0800 Subject: [PATCH] Prevent Matlab on OS X from crashing on error Replace CHECK() and LOG(FATAL) with LOG(ERROR) and mexErrMsgTxt A failed CHECK() or LOG(FATAL) causes Matlab to crash on OS X 10.9 with Matlab 2014a. --- matlab/caffe/matcaffe.cpp | 71 ++++++++++++++++++++++++++++++----------------- 1 file changed, 46 insertions(+), 25 deletions(-) diff --git a/matlab/caffe/matcaffe.cpp b/matlab/caffe/matcaffe.cpp index fc04758..3de0f02 100644 --- a/matlab/caffe/matcaffe.cpp +++ b/matlab/caffe/matcaffe.cpp @@ -3,6 +3,7 @@ // caffe::Caffe functions so that one could easily call it from matlab. // Note that for matlab, we will simply use float as the data type. +#include #include #include @@ -12,6 +13,12 @@ #define MEX_ARGS int nlhs, mxArray **plhs, int nrhs, const mxArray **prhs +// Log and throw a Mex error +inline void mex_error(const std::string &msg) { + LOG(ERROR) << msg; + mexErrMsgTxt(msg.c_str()); +} + using namespace caffe; // NOLINT(build/namespaces) // The pointer to the internal caffe::Net instance @@ -46,14 +53,22 @@ static int init_key = -2; static mxArray* do_forward(const mxArray* const bottom) { vector*>& input_blobs = net_->input_blobs(); - CHECK_EQ(static_cast(mxGetDimensions(bottom)[0]), - input_blobs.size()); + if (static_cast(mxGetDimensions(bottom)[0]) != + input_blobs.size()) { + mex_error("Invalid input size"); + } for (unsigned int i = 0; i < input_blobs.size(); ++i) { const mxArray* const elem = mxGetCell(bottom, i); - CHECK(mxIsSingle(elem)) - << "MatCaffe require single-precision float point data"; - CHECK_EQ(mxGetNumberOfElements(elem), input_blobs[i]->count()) - << "MatCaffe input size does not match the input size of the network"; + if (!mxIsSingle(elem)) { + mex_error("MatCaffe require single-precision float point data"); + } + if (mxGetNumberOfElements(elem) != input_blobs[i]->count()) { + std::string error_msg; + error_msg += "MatCaffe input size does not match the input size "; + error_msg += "of the network"; + mex_error(error_msg); + } + const float* const data_ptr = reinterpret_cast(mxGetPr(elem)); switch (Caffe::mode()) { @@ -66,7 +81,7 @@ static mxArray* do_forward(const mxArray* const bottom) { input_blobs[i]->mutable_gpu_data()); break; default: - LOG(FATAL) << "Unknown Caffe mode."; + mex_error("Unknown Caffe mode"); } // switch (Caffe::mode()) } const vector*>& output_blobs = net_->ForwardPrefilled(); @@ -89,7 +104,7 @@ static mxArray* do_forward(const mxArray* const bottom) { data_ptr); break; default: - LOG(FATAL) << "Unknown Caffe mode."; + mex_error("Unknown Caffe mode"); } // switch (Caffe::mode()) } @@ -99,8 +114,10 @@ static mxArray* do_forward(const mxArray* const bottom) { static mxArray* do_backward(const mxArray* const top_diff) { vector*>& output_blobs = net_->output_blobs(); vector*>& input_blobs = net_->input_blobs(); - CHECK_EQ(static_cast(mxGetDimensions(top_diff)[0]), - output_blobs.size()); + if (static_cast(mxGetDimensions(top_diff)[0]) != + output_blobs.size()) { + mex_error("Invalid input size"); + } // First, copy the output diff for (unsigned int i = 0; i < output_blobs.size(); ++i) { const mxArray* const elem = mxGetCell(top_diff, i); @@ -116,7 +133,7 @@ static mxArray* do_backward(const mxArray* const top_diff) { output_blobs[i]->mutable_gpu_diff()); break; default: - LOG(FATAL) << "Unknown Caffe mode."; + mex_error("Unknown Caffe mode"); } // switch (Caffe::mode()) } // LOG(INFO) << "Start"; @@ -139,7 +156,7 @@ static mxArray* do_backward(const mxArray* const top_diff) { caffe_copy(input_blobs[i]->count(), input_blobs[i]->gpu_diff(), data_ptr); break; default: - LOG(FATAL) << "Unknown Caffe mode."; + mex_error("Unknown Caffe mode"); } // switch (Caffe::mode()) } @@ -216,7 +233,7 @@ static mxArray* do_get_weights() { weights_ptr); break; default: - LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); + mex_error("Unknown Caffe mode"); } } } @@ -247,8 +264,9 @@ static void set_phase_test(MEX_ARGS) { static void set_device(MEX_ARGS) { if (nrhs != 1) { - LOG(ERROR) << "Only given " << nrhs << " arguments"; - mexErrMsgTxt("Wrong number of arguments"); + ostringstream error_msg; + error_msg << "Expected 1 argument, got " << nrhs; + mex_error(error_msg.str()); } int device_id = static_cast(mxGetScalar(prhs[0])); @@ -261,8 +279,9 @@ static void get_init_key(MEX_ARGS) { static void init(MEX_ARGS) { if (nrhs != 2) { - LOG(ERROR) << "Only given " << nrhs << " arguments"; - mexErrMsgTxt("Wrong number of arguments"); + ostringstream error_msg; + error_msg << "Expected 2 arguments, got " << nrhs; + mex_error(error_msg.str()); } char* param_file = mxArrayToString(prhs[0]); @@ -291,8 +310,9 @@ static void reset(MEX_ARGS) { static void forward(MEX_ARGS) { if (nrhs != 1) { - LOG(ERROR) << "Only given " << nrhs << " arguments"; - mexErrMsgTxt("Wrong number of arguments"); + ostringstream error_msg; + error_msg << "Expected 1 argument, got " << nrhs; + mex_error(error_msg.str()); } plhs[0] = do_forward(prhs[0]); @@ -300,8 +320,9 @@ static void forward(MEX_ARGS) { static void backward(MEX_ARGS) { if (nrhs != 1) { - LOG(ERROR) << "Only given " << nrhs << " arguments"; - mexErrMsgTxt("Wrong number of arguments"); + ostringstream error_msg; + error_msg << "Expected 1 argument, got " << nrhs; + mex_error(error_msg.str()); } plhs[0] = do_backward(prhs[0]); @@ -374,8 +395,7 @@ static handler_registry handlers[] = { void mexFunction(MEX_ARGS) { mexLock(); // Avoid clearing the mex file. if (nrhs == 0) { - LOG(ERROR) << "No API command given"; - mexErrMsgTxt("An API command is requires"); + mex_error("No API command given"); return; } @@ -391,8 +411,9 @@ void mexFunction(MEX_ARGS) { } } if (!dispatched) { - LOG(ERROR) << "Unknown command `" << cmd << "'"; - mexErrMsgTxt("API command not recognized"); + ostringstream error_msg; + error_msg << "Unknown command '" << cmd << "'"; + mex_error(error_msg.str()); } mxFree(cmd); } -- 2.7.4