Test for python forward and backward with start and end layer.
authorCarl Doersch <cdoersch@cs.cmu.edu>
Tue, 25 Aug 2015 18:26:14 +0000 (11:26 -0700)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Fri, 14 Apr 2017 06:26:44 +0000 (23:26 -0700)
python/caffe/test/test_net.py

index 24391cc..afd2769 100644 (file)
@@ -25,11 +25,11 @@ def simple_net_file(num_output):
         bias_filler { type: 'constant' value: 2 } }
         param { decay_mult: 1 } param { decay_mult: 0 }
         }
-    layer { type: 'InnerProduct' name: 'ip' bottom: 'conv' top: 'ip'
+    layer { type: 'InnerProduct' name: 'ip' bottom: 'conv' top: 'ip_blob'
       inner_product_param { num_output: """ + str(num_output) + """
         weight_filler { type: 'gaussian' std: 2.5 }
         bias_filler { type: 'constant' value: -3 } } }
-    layer { type: 'SoftmaxWithLoss' name: 'loss' bottom: 'ip' bottom: 'label'
+    layer { type: 'SoftmaxWithLoss' name: 'loss' bottom: 'ip_blob' bottom: 'label'
       top: 'loss' }""")
     f.close()
     return f.name
@@ -71,6 +71,43 @@ class TestNet(unittest.TestCase):
         self.net.forward()
         self.net.backward()
 
+    def test_forward_start_end(self):
+        conv_blob=self.net.blobs['conv'];
+        ip_blob=self.net.blobs['ip_blob'];
+        sample_data=np.random.uniform(size=conv_blob.data.shape);
+        sample_data=sample_data.astype(np.float32);
+        conv_blob.data[:]=sample_data;
+        forward_blob=self.net.forward(start='ip',end='ip');
+        self.assertIn('ip_blob',forward_blob);
+
+        manual_forward=[];
+        for i in range(0,conv_blob.data.shape[0]):
+          dot=np.dot(self.net.params['ip'][0].data,
+                     conv_blob.data[i].reshape(-1));
+          manual_forward.append(dot+self.net.params['ip'][1].data);
+        manual_forward=np.array(manual_forward);
+
+        np.testing.assert_allclose(ip_blob.data,manual_forward,rtol=1e-3);
+
+    def test_backward_start_end(self):
+        conv_blob=self.net.blobs['conv'];
+        ip_blob=self.net.blobs['ip_blob'];
+        sample_data=np.random.uniform(size=ip_blob.data.shape)
+        sample_data=sample_data.astype(np.float32);
+        ip_blob.diff[:]=sample_data;
+        backward_blob=self.net.backward(start='ip',end='ip');
+        self.assertIn('conv',backward_blob);
+
+        manual_backward=[];
+        for i in range(0,conv_blob.data.shape[0]):
+          dot=np.dot(self.net.params['ip'][0].data.transpose(),
+                     sample_data[i].reshape(-1));
+          manual_backward.append(dot);
+        manual_backward=np.array(manual_backward);
+        manual_backward=manual_backward.reshape(conv_blob.data.shape);
+
+        np.testing.assert_allclose(conv_blob.diff,manual_backward,rtol=1e-3);
+
     def test_clear_param_diffs(self):
         # Run a forward/backward step to have non-zero diffs
         self.net.forward()
@@ -90,13 +127,13 @@ class TestNet(unittest.TestCase):
         self.assertEqual(self.net.top_names,
                          OrderedDict([('data', ['data', 'label']),
                                       ('conv', ['conv']),
-                                      ('ip', ['ip']),
+                                      ('ip', ['ip_blob']),
                                       ('loss', ['loss'])]))
         self.assertEqual(self.net.bottom_names,
                          OrderedDict([('data', []),
                                       ('conv', ['data']),
                                       ('ip', ['conv']),
-                                      ('loss', ['ip', 'label'])]))
+                                      ('loss', ['ip_blob', 'label'])]))
 
     def test_save_and_read(self):
         f = tempfile.NamedTemporaryFile(mode='w+', delete=False)