From 9735f4b3b257379ac4e6c9d310aab32bfb198661 Mon Sep 17 00:00:00 2001 From: Ronghang Hu Date: Thu, 28 May 2015 13:40:26 +0800 Subject: [PATCH] Aesthetic changes on code style and some minor fix --- matlab/+caffe/Blob.m | 22 ++--- matlab/+caffe/Layer.m | 2 +- matlab/+caffe/Net.m | 10 +-- matlab/+caffe/Solver.m | 2 +- matlab/+caffe/get_net.m | 3 +- matlab/+caffe/get_solver.m | 1 - matlab/+caffe/io.m | 15 ++-- matlab/+caffe/private/caffe_.cpp | 154 ++++++++++++++++---------------- matlab/+caffe/private/is_valid_handle.m | 1 - matlab/+caffe/run_tests.m | 8 +- 10 files changed, 113 insertions(+), 105 deletions(-) diff --git a/matlab/+caffe/Blob.m b/matlab/+caffe/Blob.m index f9b6409..e39f7ee 100644 --- a/matlab/+caffe/Blob.m +++ b/matlab/+caffe/Blob.m @@ -7,9 +7,9 @@ classdef Blob < handle methods function self = Blob(hBlob_blob) - CHECK(is_valid_handle(hBlob_blob), 'invalid input handle'); + CHECK(is_valid_handle(hBlob_blob), 'invalid Blob handle'); - % setup self handle and attributes + % setup self handle self.hBlob_self = hBlob_blob; end function shape = shape(self) @@ -37,14 +37,16 @@ classdef Blob < handle methods (Access = private) function shape = check_and_preprocess_shape(~, shape) - CHECK(isempty(shape) || isnumeric(shape) && isrow(shape), ... + CHECK(isempty(shape) || (isnumeric(shape) && isrow(shape)), ... 'shape must be a integer row vector'); shape = double(shape); end function data = check_and_preprocess_data(self, data) CHECK(isnumeric(data), 'data or diff must be numeric types'); - self.check_data_size_matches(data) - data = single(data); + self.check_data_size_matches(data); + if ~isa(data, 'single') + data = single(data); + end end function check_data_size_matches(self, data) % check whether size of data matches shape of this blob @@ -59,17 +61,17 @@ classdef Blob < handle % target blob is a vector (1 dim) self_shape_extended = [self_shape_extended, 1]; end - % also, matlab cannot have tailing dimension 1 for ndim > 2, so you + % Also, matlab cannot have tailing dimension 1 for ndim > 2, so you % cannot create 20 x 10 x 1 x 1 array in matlab as it becomes 20 x 10 - % extend matlab arrays to have tailing dimension 1 during shape match + % Extend matlab arrays to have tailing dimension 1 during shape match data_size_extended = ... [size(data), ones(1, length(self_shape_extended) - ndims(data))]; is_matched = ... - (length(self_shape_extended) == length(data_size_extended)) ... + (length(self_shape_extended) == length(data_size_extended)) ... && all(self_shape_extended == data_size_extended); CHECK(is_matched, ... - sprintf('%s, data size: [ %s], blob shape: [ %s]', ... - 'data size does not match blob shape', ... + sprintf('%s, input data/diff size: [ %s] vs target blob shape: [ %s]', ... + 'input data/diff size does not match target blob shape', ... sprintf('%d ', data_size_extended), sprintf('%d ', self_shape_extended))); end end diff --git a/matlab/+caffe/Layer.m b/matlab/+caffe/Layer.m index 7587ed7..4c20231 100644 --- a/matlab/+caffe/Layer.m +++ b/matlab/+caffe/Layer.m @@ -13,7 +13,7 @@ classdef Layer < handle methods function self = Layer(hLayer_layer) - CHECK(is_valid_handle(hLayer_layer), 'invalid input handle'); + CHECK(is_valid_handle(hLayer_layer), 'invalid Layer handle'); % setup self handle and attributes self.hLayer_self = hLayer_layer; diff --git a/matlab/+caffe/Net.m b/matlab/+caffe/Net.m index 5319634..a676106 100644 --- a/matlab/+caffe/Net.m +++ b/matlab/+caffe/Net.m @@ -33,7 +33,7 @@ classdef Net < handle end % construct a net from handle hNet_net = varargin{1}; - CHECK(is_valid_handle(hNet_net), 'invalid input handle'); + CHECK(is_valid_handle(hNet_net), 'invalid Net handle'); % setup self handle and attributes self.hNet_self = hNet_net; @@ -64,7 +64,7 @@ classdef Net < handle self.name2blob_index = containers.Map(self.attributes.blob_names, ... 1:length(self.attributes.blob_names)); - % expose layer_names and blob_names for public access + % expose layer_names and blob_names for public read access self.layer_names = self.attributes.layer_names; self.blob_names = self.attributes.blob_names; end @@ -91,12 +91,12 @@ classdef Net < handle CHECK(iscell(input_data), 'input_data must be a cell array'); CHECK(length(input_data) == length(self.inputs), ... 'input data cell length must match input blob number'); - % copy data to input_blobs + % copy data to input blobs for n = 1:length(self.inputs) self.blobs(self.inputs{n}).set_data(input_data{n}); end self.forward_prefilled(); - % retrieve data from output_blobs + % retrieve data from output blobs res = cell(length(self.outputs), 1); for n = 1:length(self.outputs) res{n} = self.blobs(self.outputs{n}).get_data(); @@ -106,7 +106,7 @@ classdef Net < handle CHECK(iscell(output_diff), 'output_diff must be a cell array'); CHECK(length(output_diff) == length(self.outputs), ... 'output diff cell length must match output blob number'); - % copy diff to output_blobs + % copy diff to output blobs for n = 1:length(self.outputs) self.blobs(self.outputs{n}).set_diff(output_diff{n}); end diff --git a/matlab/+caffe/Solver.m b/matlab/+caffe/Solver.m index 80fa539..daaa802 100644 --- a/matlab/+caffe/Solver.m +++ b/matlab/+caffe/Solver.m @@ -23,7 +23,7 @@ classdef Solver < handle end % construct a solver from handle hSolver_solver = varargin{1}; - CHECK(is_valid_handle(hSolver_solver), 'invalid input handle'); + CHECK(is_valid_handle(hSolver_solver), 'invalid Solver handle'); % setup self handle and attributes self.hSolver_self = hSolver_solver; diff --git a/matlab/+caffe/get_net.m b/matlab/+caffe/get_net.m index d60979d..4b5683e 100644 --- a/matlab/+caffe/get_net.m +++ b/matlab/+caffe/get_net.m @@ -5,7 +5,7 @@ function net = get_net(varargin) % phase_name can only be 'train' or 'test' CHECK(nargin == 2 || nargin == 3, ['usage: ' ... - 'net = get_net(model_file, phase_name) or ', ... + 'net = get_net(model_file, phase_name) or ' ... 'net = get_net(model_file, weights_file, phase_name)']); if nargin == 3 model_file = varargin{1}; @@ -23,6 +23,7 @@ CHECK(strcmp(phase_name, 'train') || strcmp(phase_name, 'test'), ... sprintf('phase_name can only be %strain%s or %stest%s', ... char(39), char(39), char(39), char(39))); +% construct caffe net from model_file hNet = caffe_('get_net', model_file, phase_name); net = caffe.Net(hNet); diff --git a/matlab/+caffe/get_solver.m b/matlab/+caffe/get_solver.m index 30366d8..74d576e 100644 --- a/matlab/+caffe/get_solver.m +++ b/matlab/+caffe/get_solver.m @@ -4,7 +4,6 @@ function solver = get_solver(solver_file) CHECK(ischar(solver_file), 'solver_file must be a string'); CHECK_FILE_EXIST(solver_file); - pSolver = caffe_('get_solver', solver_file); solver = caffe.Solver(pSolver); diff --git a/matlab/+caffe/io.m b/matlab/+caffe/io.m index 7fad968..7a30bfb 100644 --- a/matlab/+caffe/io.m +++ b/matlab/+caffe/io.m @@ -3,17 +3,20 @@ classdef io methods (Static) function im_data = load_image(im_file) + % im_data = load_image(im_file) + % load an image from disk into Caffe-supported data format + % switch channels from RGB to BGR, make width the fastest dimension + % and convert to single CHECK(ischar(im_file), 'im_file must be a string'); CHECK_FILE_EXIST(im_file); - % load an image from disk into Caffe-supported data format - % switch channels from RGB to BGR, make width the fastest dimension, and - % convert to single - im = imread(im_file); - im_data = im(:, :, [3, 2, 1]); - im_data = permute(im_data, [2 1 3]); + im_data = imread(im_file); + im_data = im_data(:, :, [3, 2, 1]); + im_data = permute(im_data, [2, 1, 3]); im_data = single(im_data); end function mean_data = read_mean(mean_proto_file) + % mean_data = read_mean(mean_proto_file) + % read image mean data from binaryproto file CHECK(ischar(mean_proto_file), 'im_file must be a string'); CHECK_FILE_EXIST(mean_proto_file); mean_data = caffe_('read_mean', mean_proto_file); diff --git a/matlab/+caffe/private/caffe_.cpp b/matlab/+caffe/private/caffe_.cpp index 96a1920..4e0ebc1 100644 --- a/matlab/+caffe/private/caffe_.cpp +++ b/matlab/+caffe/private/caffe_.cpp @@ -21,21 +21,22 @@ using namespace caffe; // NOLINT(build/namespaces) -// Do CHECK and throw a Mex error if check failsf -inline void mxCHECK(bool expr, const std::string &msg) { +// Do CHECK and throw a Mex error if check fails +inline void mxCHECK(bool expr, const char* msg) { if (!expr) { - LOG(ERROR) << msg; - mexErrMsgTxt(msg.c_str()); + mexErrMsgTxt(msg); } } -inline void mxERROR(const std::string &msg) { mxCHECK(false, msg); } +inline void mxERROR(const char* msg) { mexErrMsgTxt(msg); } // Check if a file exists and can be opened void mxCHECK_FILE_EXIST(const char* file) { std::ifstream f(file); if (!f.good()) { f.close(); - mxERROR("Could not open file " + string(file)); + std::string msg("Could not open file "); + msg += file; + mxERROR(msg.c_str()); } f.close(); } @@ -43,6 +44,7 @@ void mxCHECK_FILE_EXIST(const char* file) { // The pointers to caffe::Solver and caffe::Net instances static vector > > solvers_; static vector > > nets_; +// init_key is generated at the beginning and everytime you call reset static double init_key = static_cast(caffe_rng_rand()); /** ----------------------------------------------------------------- @@ -104,17 +106,17 @@ static mxArray* blob_to_mx_mat(const Blob* blob, return mx_mat; } -// convert vector to matlab vector +// Convert vector to matlab row vector static mxArray* int_vec_to_mx_vec(const vector& int_vec) { mxArray* mx_vec = mxCreateDoubleMatrix(int_vec.size(), 1, mxREAL); double* vec_mem_ptr = mxGetPr(mx_vec); for (int i = 0; i < int_vec.size(); i++) { - vec_mem_ptr[i] = int_vec[i]; + vec_mem_ptr[i] = static_cast(int_vec[i]); } return mx_vec; } -// convert vector to matlab string cell vector +// Convert vector to matlab cell vector of strings static mxArray* str_vec_to_mx_strcell(const vector& str_vec) { mxArray* mx_strcell = mxCreateCellMatrix(str_vec.size(), 1); for (int i = 0; i < str_vec.size(); i++) { @@ -135,44 +137,46 @@ static T* handle_to_ptr(const mxArray* mx_handle) { mxArray* mx_ptr = mxGetField(mx_handle, 0, "ptr"); mxArray* mx_init_key = mxGetField(mx_handle, 0, "init_key"); mxCHECK(mxIsUint64(mx_ptr), "pointer type must be uint64"); - mxCHECK(mxGetScalar(mx_init_key) == init_key, "incorrect handle init_key"); + mxCHECK(mxGetScalar(mx_init_key) == init_key, + "Could not convert handle to pointer due to invalid init_key. " + "The object might have been cleared."); return reinterpret_cast(*reinterpret_cast(mxGetData(mx_ptr))); } -// Create an empty handle struct array +// Create a handle struct vector, without setting up each handle in it template -static mxArray* create_handles(int ptr_num) { +static mxArray* create_handle_vec(int ptr_num) { const int handle_field_num = 2; const char* handle_fields[handle_field_num] = { "ptr", "init_key" }; return mxCreateStructMatrix(ptr_num, 1, handle_field_num, handle_fields); } -// Set up each handle in a handle struct array +// Set up a handle in a handle struct vector by its index template -static void setup_handle(const T* ptr, int index, mxArray* mx_handles) { +static void setup_handle(const T* ptr, int index, mxArray* mx_handle_vec) { mxArray* mx_ptr = mxCreateNumericMatrix(1, 1, mxUINT64_CLASS, mxREAL); *reinterpret_cast(mxGetData(mx_ptr)) = reinterpret_cast(ptr); - mxSetField(mx_handles, index, "ptr", mx_ptr); - mxSetField(mx_handles, index, "init_key", mxCreateDoubleScalar(init_key)); + mxSetField(mx_handle_vec, index, "ptr", mx_ptr); + mxSetField(mx_handle_vec, index, "init_key", mxCreateDoubleScalar(init_key)); } // Convert a pointer in C++ to a handle in matlab template static mxArray* ptr_to_handle(const T* ptr) { - mxArray* mx_handle = create_handles(1); + mxArray* mx_handle = create_handle_vec(1); setup_handle(ptr, 0, mx_handle); return mx_handle; } -// Convert a vector of shared_ptr in C++ to handle struct array in matlab +// Convert a vector of shared_ptr in C++ to handle struct vector template -static mxArray* ptr_vec_to_handles(const vector >& ptr_vec) { - mxArray* mx_handle = create_handles(ptr_vec.size()); +static mxArray* ptr_vec_to_handle_vec(const vector >& ptr_vec) { + mxArray* mx_handle_vec = create_handle_vec(ptr_vec.size()); for (int i = 0; i < ptr_vec.size(); i++) { - setup_handle(ptr_vec[i].get(), i, mx_handle); + setup_handle(ptr_vec[i].get(), i, mx_handle_vec); } - return mx_handle; + return mx_handle_vec; } /** ----------------------------------------------------------------- @@ -182,12 +186,12 @@ static mxArray* ptr_vec_to_handles(const vector >& ptr_vec) { static void get_solver(MEX_ARGS) { mxCHECK(nrhs == 1 && mxIsChar(prhs[0]), "Usage: caffe_('get_solver', solver_file)"); - const char* solver_file = mxArrayToString(prhs[0]); + char* solver_file = mxArrayToString(prhs[0]); mxCHECK_FILE_EXIST(solver_file); - shared_ptr > solver; - solver.reset(new caffe::SGDSolver(solver_file)); + shared_ptr > solver(new caffe::SGDSolver(solver_file)); solvers_.push_back(solver); plhs[0] = ptr_to_handle >(solver.get()); + mxFree(solver_file); } // Usage: caffe_('solver_get_attr', hSolver) @@ -202,7 +206,7 @@ static void solver_get_attr(MEX_ARGS) { mxSetField(mx_solver_attr, 0, "hNet_net", ptr_to_handle >(solver->net().get())); mxSetField(mx_solver_attr, 0, "hNet_test_nets", - ptr_vec_to_handles >(solver->test_nets())); + ptr_vec_to_handle_vec >(solver->test_nets())); plhs[0] = mx_solver_attr; } @@ -219,9 +223,10 @@ static void solver_restore(MEX_ARGS) { mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsChar(prhs[1]), "Usage: caffe_('solver_restore', hSolver, snapshot_file)"); Solver* solver = handle_to_ptr >(prhs[0]); - const char* snapshot_file = mxArrayToString(prhs[1]); + char* snapshot_file = mxArrayToString(prhs[1]); mxCHECK_FILE_EXIST(snapshot_file); solver->Restore(snapshot_file); + mxFree(snapshot_file); } // Usage: caffe_('solver_solve', hSolver) @@ -232,10 +237,10 @@ static void solver_solve(MEX_ARGS) { solver->Solve(); } -// Usage: caffe_('solver_solve', hSolver, iters) +// Usage: caffe_('solver_step', hSolver, iters) static void solver_step(MEX_ARGS) { mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsDouble(prhs[1]), - "Usage: caffe_('solver_solve', hSolver, iters)"); + "Usage: caffe_('solver_step', hSolver, iters)"); Solver* solver = handle_to_ptr >(prhs[0]); int iters = mxGetScalar(prhs[1]); solver->Step(iters); @@ -245,21 +250,22 @@ static void solver_step(MEX_ARGS) { static void get_net(MEX_ARGS) { mxCHECK(nrhs == 2 && mxIsChar(prhs[0]) && mxIsChar(prhs[1]), "Usage: caffe_('get_net', model_file, phase_name)"); - const char* model_file = mxArrayToString(prhs[0]); + char* model_file = mxArrayToString(prhs[0]); + char* phase_name = mxArrayToString(prhs[1]); mxCHECK_FILE_EXIST(model_file); - const char* phase_name = mxArrayToString(prhs[1]); Phase phase; if (strcmp(phase_name, "train") == 0) { phase = TRAIN; } else if (strcmp(phase_name, "test") == 0) { phase = TEST; } else { - mxERROR("Unknown phase."); + mxERROR("Unknown phase"); } - shared_ptr > net; - net.reset(new caffe::Net(model_file, phase)); + shared_ptr > net(new caffe::Net(model_file, phase)); nets_.push_back(net); plhs[0] = ptr_to_handle >(net.get()); + mxFree(model_file); + mxFree(phase_name); } // Usage: caffe_('net_get_attr', hNet) @@ -273,15 +279,15 @@ static void net_get_attr(MEX_ARGS) { mxArray* mx_net_attr = mxCreateStructMatrix(1, 1, net_attr_num, net_attrs); mxSetField(mx_net_attr, 0, "hLayer_layers", - ptr_vec_to_handles >(net->layers())); + ptr_vec_to_handle_vec >(net->layers())); mxSetField(mx_net_attr, 0, "hBlob_blobs", - ptr_vec_to_handles >(net->blobs())); + ptr_vec_to_handle_vec >(net->blobs())); mxSetField(mx_net_attr, 0, "input_blob_indices", int_vec_to_mx_vec(net->input_blob_indices())); mxSetField(mx_net_attr, 0, "output_blob_indices", int_vec_to_mx_vec(net->output_blob_indices())); mxSetField(mx_net_attr, 0, "layer_names", - str_vec_to_mx_strcell(net->layer_names())); + str_vec_to_mx_strcell(net->layer_names())); mxSetField(mx_net_attr, 0, "blob_names", str_vec_to_mx_strcell(net->blob_names())); plhs[0] = mx_net_attr; @@ -308,9 +314,10 @@ static void net_copy_from(MEX_ARGS) { mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsChar(prhs[1]), "Usage: caffe_('net_copy_from', hNet, weights_file)"); Net* net = handle_to_ptr >(prhs[0]); - const char* weights_file = mxArrayToString(prhs[1]); + char* weights_file = mxArrayToString(prhs[1]); mxCHECK_FILE_EXIST(weights_file); net->CopyTrainedLayersFrom(weights_file); + mxFree(weights_file); } // Usage: caffe_('net_reshape', hNet) @@ -326,10 +333,11 @@ static void net_save(MEX_ARGS) { mxCHECK(nrhs == 2 && mxIsStruct(prhs[0]) && mxIsChar(prhs[1]), "Usage: caffe_('net_save', hNet, save_file)"); Net* net = handle_to_ptr >(prhs[0]); - const char* weights_file = mxArrayToString(prhs[1]); + char* weights_file = mxArrayToString(prhs[1]); NetParameter net_param; net->ToProto(&net_param, false); WriteProtoToBinaryFile(net_param, weights_file); + mxFree(weights_file); } // Usage: caffe_('layer_get_attr', hLayer) @@ -342,14 +350,14 @@ static void layer_get_attr(MEX_ARGS) { mxArray* mx_layer_attr = mxCreateStructMatrix(1, 1, layer_attr_num, layer_attrs); mxSetField(mx_layer_attr, 0, "hBlob_blobs", - ptr_vec_to_handles >(layer->blobs())); + ptr_vec_to_handle_vec >(layer->blobs())); plhs[0] = mx_layer_attr; } -// Usage: caffe_('layer_get_attr', hLayer) +// Usage: caffe_('layer_get_type', hLayer) static void layer_get_type(MEX_ARGS) { mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]), - "Usage: caffe_('layer_get_attr', hLayer)"); + "Usage: caffe_('layer_get_type', hLayer)"); Layer* layer = handle_to_ptr >(prhs[0]); plhs[0] = mxCreateString(layer->type()); } @@ -446,10 +454,12 @@ static void get_init_key(MEX_ARGS) { // Usage: caffe_('reset') static void reset(MEX_ARGS) { mxCHECK(nrhs == 0, "Usage: caffe_('reset')"); - mexPrintf("cleared %d solvers and %d stand-alone nets\n", solvers_.size(), - nets_.size()); + // Clear solvers and stand-alone nets + mexPrintf("Cleared %d solvers and %d stand-alone nets\n", + solvers_.size(), nets_.size()); solvers_.clear(); nets_.clear(); + // Generate new init_key, so that handles created before becomes invalid init_key = static_cast(caffe_rng_rand()); } @@ -457,21 +467,15 @@ static void reset(MEX_ARGS) { static void read_mean(MEX_ARGS) { mxCHECK(nrhs == 1 && mxIsChar(prhs[0]), "Usage: caffe_('read_mean', mean_proto_file)"); - const char* mean_proto_file = mxArrayToString(prhs[0]); + char* mean_proto_file = mxArrayToString(prhs[0]); + mxCHECK_FILE_EXIST(mean_proto_file); Blob data_mean; - mexPrintf("Loading mean file from: %s\n", mean_proto_file); BlobProto blob_proto; bool result = ReadProtoFromBinaryFile(mean_proto_file, &blob_proto); - mxCHECK(result, "Couldn't read the file"); + mxCHECK(result, "Could not read your mean file"); data_mean.FromProto(blob_proto); - mwSize dims[4] = {data_mean.width(), data_mean.height(), - data_mean.channels(), data_mean.num() }; - mxArray* mx_blob = mxCreateNumericArray(4, dims, mxSINGLE_CLASS, mxREAL); - float* data_ptr = reinterpret_cast(mxGetData(mx_blob)); - caffe_copy(data_mean.count(), data_mean.cpu_data(), data_ptr); - mexPrintf("Remember that Caffe saves in [width, height, channels]" - " format and channels are also BGR!\n"); - plhs[0] = mx_blob; + plhs[0] = blob_to_mx_mat(&data_mean, DATA); + mxFree(mean_proto_file); } /** ----------------------------------------------------------------- @@ -516,31 +520,27 @@ static handler_registry handlers[] = { }; /** ----------------------------------------------------------------- - ** matlab entry point: caffe_(api_command, arg1, arg2, ...) + ** matlab entry point. **/ +// Usage: caffe_(api_command, arg1, arg2, ...) void mexFunction(MEX_ARGS) { mexLock(); // Avoid clearing the mex file. - if (nrhs == 0) { - mxERROR("No API command given"); - return; - } - - { // Handle input command - char* cmd = mxArrayToString(prhs[0]); - bool dispatched = false; - // Dispatch to cmd handler - for (int i = 0; handlers[i].func != NULL; i++) { - if (handlers[i].cmd.compare(cmd) == 0) { - handlers[i].func(nlhs, plhs, nrhs-1, prhs+1); - dispatched = true; - break; - } - } - if (!dispatched) { - ostringstream error_msg; - error_msg << "Unknown command '" << cmd << "'"; - mxERROR(error_msg.str().c_str()); + mxCHECK(nrhs > 0, "Usage: caffe_(api_command, arg1, arg2, ...)"); + // Handle input command + char* cmd = mxArrayToString(prhs[0]); + bool dispatched = false; + // Dispatch to cmd handler + for (int i = 0; handlers[i].func != NULL; i++) { + if (handlers[i].cmd.compare(cmd) == 0) { + handlers[i].func(nlhs, plhs, nrhs-1, prhs+1); + dispatched = true; + break; } - mxFree(cmd); } + if (!dispatched) { + ostringstream error_msg; + error_msg << "Unknown command '" << cmd << "'"; + mxERROR(error_msg.str().c_str()); + } + mxFree(cmd); } diff --git a/matlab/+caffe/private/is_valid_handle.m b/matlab/+caffe/private/is_valid_handle.m index 77abf21..a0648ec 100644 --- a/matlab/+caffe/private/is_valid_handle.m +++ b/matlab/+caffe/private/is_valid_handle.m @@ -15,7 +15,6 @@ end % is_valid_handle('get_new_init_key') to get new init_key from C++; if ischar(hObj) && strcmp(hObj, 'get_new_init_key') init_key = caffe_('get_init_key'); - valid = true; return else % check whether data types are correct and init_key matches diff --git a/matlab/+caffe/run_tests.m b/matlab/+caffe/run_tests.m index fb1089c..afdd8f3 100644 --- a/matlab/+caffe/run_tests.m +++ b/matlab/+caffe/run_tests.m @@ -2,11 +2,15 @@ function results = run_tests() % results = run_tests() % run all tests in this caffe matlab wrapper package +% reset caffe before testing caffe.reset(); + +% put all test cases here results = [... run(caffe.test.test_net) ... - run(caffe.test.test_solver) - ]; + run(caffe.test.test_solver) ]; + +% reset caffe after testing caffe.reset(); end -- 2.7.4