From d07e5f796907a2bc048bdab3cdb4ace05fa60d7a Mon Sep 17 00:00:00 2001 From: Ronghang Hu Date: Fri, 29 May 2015 07:50:23 +0800 Subject: [PATCH] More tests for Blob, Layer, copy_from and step, fix some typos More testes are added into test_net.m and test_solver.m --- docs/tutorial/interfaces.md | 4 ++-- matlab/+caffe/+test/test_net.m | 24 +++++++++++++++++++++++- matlab/+caffe/+test/test_solver.m | 2 ++ matlab/+caffe/Net.m | 2 +- matlab/+caffe/Solver.m | 2 +- matlab/+caffe/io.m | 2 +- matlab/+caffe/run_tests.m | 3 +++ 7 files changed, 33 insertions(+), 6 deletions(-) diff --git a/docs/tutorial/interfaces.md b/docs/tutorial/interfaces.md index a59a410..1296331 100644 --- a/docs/tutorial/interfaces.md +++ b/docs/tutorial/interfaces.md @@ -82,9 +82,9 @@ In MatCaffe, you can * Resume training from solver snapshots * Access train net and test nets in a solver * Run for a certain number of iterations and give back control to Matlab -* Intermingle arbitrary Matlab code to with gradient steps +* Intermingle arbitrary Matlab code with gradient steps -An ILSVRC image classification demo is in caffe/matlab/demo/classification_demo.m +An ILSVRC image classification demo is in caffe/matlab/demo/classification_demo.m (you need to download BVLC CaffeNet from [Model Zoo](http://caffe.berkeleyvision.org/model_zoo.html) to run it). ### Build MatCaffe diff --git a/matlab/+caffe/+test/test_net.m b/matlab/+caffe/+test/test_net.m index 5d9ba00..3dabe84 100644 --- a/matlab/+caffe/+test/test_net.m +++ b/matlab/+caffe/+test/test_net.m @@ -48,6 +48,24 @@ classdef test_net < matlab.unittest.TestCase end end methods (Test) + function self = test_blob(self) + self.net.blobs('data').set_data(10 * ones(self.net.blobs('data').shape)); + self.verifyEqual(self.net.blobs('data').get_data(), ... + 10 * ones(self.net.blobs('data').shape, 'single')); + self.net.blobs('data').set_diff(-2 * ones(self.net.blobs('data').shape)); + self.verifyEqual(self.net.blobs('data').get_diff(), ... + -2 * ones(self.net.blobs('data').shape, 'single')); + original_shape = self.net.blobs('data').shape; + self.net.blobs('data').reshape([6 5 4 3 2 1]); + self.verifyEqual(self.net.blobs('data').shape, [6 5 4 3 2 1]); + self.net.blobs('data').reshape(original_shape); + self.net.reshape(); + end + function self = test_layer(self) + self.verifyEqual(self.net.params('conv', 1).shape, [2 2 2 11]); + self.verifyEqual(self.net.layers('conv').params(2).shape, 11); + self.verifyEqual(self.net.layers('conv').type(), 'Convolution'); + end function test_forward_backward(self) self.net.forward_prefilled(); self.net.backward_prefilled(); @@ -60,13 +78,17 @@ classdef test_net < matlab.unittest.TestCase weights_file = tempname(); self.net.save(weights_file); model_file2 = caffe.test.test_net.simple_net_file(self.num_output); - net2 = caffe.Net(model_file2, weights_file, 'train'); + net2 = caffe.Net(model_file2, 'train'); + net2.copy_from(weights_file); + net3 = caffe.Net(model_file2, weights_file, 'train'); delete(model_file2); delete(weights_file); for l = 1:length(self.net.layer_vec) for i = 1:length(self.net.layer_vec(l).params) self.verifyEqual(self.net.layer_vec(l).params(i).get_data(), ... net2.layer_vec(l).params(i).get_data()); + self.verifyEqual(self.net.layer_vec(l).params(i).get_data(), ... + net3.layer_vec(l).params(i).get_data()); end end end diff --git a/matlab/+caffe/+test/test_solver.m b/matlab/+caffe/+test/test_solver.m index 682dad4..739258b 100644 --- a/matlab/+caffe/+test/test_solver.m +++ b/matlab/+caffe/+test/test_solver.m @@ -36,6 +36,8 @@ classdef test_solver < matlab.unittest.TestCase methods (Test) function test_solve(self) self.verifyEqual(self.solver.iter(), 0) + self.solver.step(30); + self.verifyEqual(self.solver.iter(), 30) self.solver.solve() self.verifyEqual(self.solver.iter(), 100) end diff --git a/matlab/+caffe/Net.m b/matlab/+caffe/Net.m index a676106..e6295bb 100644 --- a/matlab/+caffe/Net.m +++ b/matlab/+caffe/Net.m @@ -111,7 +111,7 @@ classdef Net < handle self.blobs(self.outputs{n}).set_diff(output_diff{n}); end self.backward_prefilled(); - % retrieve diff from input_blobs + % retrieve diff from input blobs res = cell(length(self.inputs), 1); for n = 1:length(self.inputs) res{n} = self.blobs(self.inputs{n}).get_diff(); diff --git a/matlab/+caffe/Solver.m b/matlab/+caffe/Solver.m index daaa802..f8bdc4e 100644 --- a/matlab/+caffe/Solver.m +++ b/matlab/+caffe/Solver.m @@ -41,7 +41,7 @@ classdef Solver < handle end function restore(self, snapshot_filename) CHECK(ischar(snapshot_filename), 'snapshot_filename must be a string'); - CHECK_FILE_EXIST(snapshot_filename) + CHECK_FILE_EXIST(snapshot_filename); caffe_('solver_restore', self.hSolver_self, snapshot_filename); end function solve(self) diff --git a/matlab/+caffe/io.m b/matlab/+caffe/io.m index 7a30bfb..c9e07ae 100644 --- a/matlab/+caffe/io.m +++ b/matlab/+caffe/io.m @@ -17,7 +17,7 @@ classdef io function mean_data = read_mean(mean_proto_file) % mean_data = read_mean(mean_proto_file) % read image mean data from binaryproto file - CHECK(ischar(mean_proto_file), 'im_file must be a string'); + CHECK(ischar(mean_proto_file), 'mean_proto_file must be a string'); CHECK_FILE_EXIST(mean_proto_file); mean_data = caffe_('read_mean', mean_proto_file); end diff --git a/matlab/+caffe/run_tests.m b/matlab/+caffe/run_tests.m index 8773c9f..9389685 100644 --- a/matlab/+caffe/run_tests.m +++ b/matlab/+caffe/run_tests.m @@ -2,6 +2,9 @@ function results = run_tests() % results = run_tests() % run all tests in this caffe matlab wrapper package +% use CPU for testing +caffe.set_mode_cpu(); + % reset caffe before testing caffe.reset_all(); -- 2.7.4