// caffe::Caffe functions so that one could easily call it from matlab.
// Note that for matlab, we will simply use float as the data type.
+#include <sstream>
#include <string>
#include <vector>
#define MEX_ARGS int nlhs, mxArray **plhs, int nrhs, const mxArray **prhs
+// Log and throw a Mex error
+inline void mex_error(const std::string &msg) {
+ LOG(ERROR) << msg;
+ mexErrMsgTxt(msg.c_str());
+}
+
using namespace caffe; // NOLINT(build/namespaces)
// The pointer to the internal caffe::Net instance
static mxArray* do_forward(const mxArray* const bottom) {
vector<Blob<float>*>& input_blobs = net_->input_blobs();
- CHECK_EQ(static_cast<unsigned int>(mxGetDimensions(bottom)[0]),
- input_blobs.size());
+ if (static_cast<unsigned int>(mxGetDimensions(bottom)[0]) !=
+ input_blobs.size()) {
+ mex_error("Invalid input size");
+ }
for (unsigned int i = 0; i < input_blobs.size(); ++i) {
const mxArray* const elem = mxGetCell(bottom, i);
- CHECK(mxIsSingle(elem))
- << "MatCaffe require single-precision float point data";
- CHECK_EQ(mxGetNumberOfElements(elem), input_blobs[i]->count())
- << "MatCaffe input size does not match the input size of the network";
+ if (!mxIsSingle(elem)) {
+ mex_error("MatCaffe require single-precision float point data");
+ }
+ if (mxGetNumberOfElements(elem) != input_blobs[i]->count()) {
+ std::string error_msg;
+ error_msg += "MatCaffe input size does not match the input size ";
+ error_msg += "of the network";
+ mex_error(error_msg);
+ }
+
const float* const data_ptr =
reinterpret_cast<const float* const>(mxGetPr(elem));
switch (Caffe::mode()) {
input_blobs[i]->mutable_gpu_data());
break;
default:
- LOG(FATAL) << "Unknown Caffe mode.";
+ mex_error("Unknown Caffe mode");
} // switch (Caffe::mode())
}
const vector<Blob<float>*>& output_blobs = net_->ForwardPrefilled();
data_ptr);
break;
default:
- LOG(FATAL) << "Unknown Caffe mode.";
+ mex_error("Unknown Caffe mode");
} // switch (Caffe::mode())
}
static mxArray* do_backward(const mxArray* const top_diff) {
vector<Blob<float>*>& output_blobs = net_->output_blobs();
vector<Blob<float>*>& input_blobs = net_->input_blobs();
- CHECK_EQ(static_cast<unsigned int>(mxGetDimensions(top_diff)[0]),
- output_blobs.size());
+ if (static_cast<unsigned int>(mxGetDimensions(top_diff)[0]) !=
+ output_blobs.size()) {
+ mex_error("Invalid input size");
+ }
// First, copy the output diff
for (unsigned int i = 0; i < output_blobs.size(); ++i) {
const mxArray* const elem = mxGetCell(top_diff, i);
output_blobs[i]->mutable_gpu_diff());
break;
default:
- LOG(FATAL) << "Unknown Caffe mode.";
+ mex_error("Unknown Caffe mode");
} // switch (Caffe::mode())
}
// LOG(INFO) << "Start";
caffe_copy(input_blobs[i]->count(), input_blobs[i]->gpu_diff(), data_ptr);
break;
default:
- LOG(FATAL) << "Unknown Caffe mode.";
+ mex_error("Unknown Caffe mode");
} // switch (Caffe::mode())
}
weights_ptr);
break;
default:
- LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+ mex_error("Unknown Caffe mode");
}
}
}
static void set_device(MEX_ARGS) {
if (nrhs != 1) {
- LOG(ERROR) << "Only given " << nrhs << " arguments";
- mexErrMsgTxt("Wrong number of arguments");
+ ostringstream error_msg;
+ error_msg << "Expected 1 argument, got " << nrhs;
+ mex_error(error_msg.str());
}
int device_id = static_cast<int>(mxGetScalar(prhs[0]));
static void init(MEX_ARGS) {
if (nrhs != 2) {
- LOG(ERROR) << "Only given " << nrhs << " arguments";
- mexErrMsgTxt("Wrong number of arguments");
+ ostringstream error_msg;
+ error_msg << "Expected 2 arguments, got " << nrhs;
+ mex_error(error_msg.str());
}
char* param_file = mxArrayToString(prhs[0]);
static void forward(MEX_ARGS) {
if (nrhs != 1) {
- LOG(ERROR) << "Only given " << nrhs << " arguments";
- mexErrMsgTxt("Wrong number of arguments");
+ ostringstream error_msg;
+ error_msg << "Expected 1 argument, got " << nrhs;
+ mex_error(error_msg.str());
}
plhs[0] = do_forward(prhs[0]);
static void backward(MEX_ARGS) {
if (nrhs != 1) {
- LOG(ERROR) << "Only given " << nrhs << " arguments";
- mexErrMsgTxt("Wrong number of arguments");
+ ostringstream error_msg;
+ error_msg << "Expected 1 argument, got " << nrhs;
+ mex_error(error_msg.str());
}
plhs[0] = do_backward(prhs[0]);
void mexFunction(MEX_ARGS) {
mexLock(); // Avoid clearing the mex file.
if (nrhs == 0) {
- LOG(ERROR) << "No API command given";
- mexErrMsgTxt("An API command is requires");
+ mex_error("No API command given");
return;
}
}
}
if (!dispatched) {
- LOG(ERROR) << "Unknown command `" << cmd << "'";
- mexErrMsgTxt("API command not recognized");
+ ostringstream error_msg;
+ error_msg << "Unknown command '" << cmd << "'";
+ mex_error(error_msg.str());
}
mxFree(cmd);
}