self.net.forward()
self.net.backward()
+ def test_clear_param_diffs(self):
+ # Run a forward/backward step to have non-zero diffs
+ self.net.forward()
+ self.net.backward()
+ diff = self.net.params["conv"][0].diff
+ # Check that we have non-zero diffs
+ self.assertTrue(diff.max() > 0)
+ self.net.clear_param_diffs()
+ # Check that the diffs are now 0
+ self.assertTrue((diff == 0).all())
+
def test_inputs_outputs(self):
self.assertEqual(self.net.inputs, [])
self.assertEqual(self.net.outputs, ['loss'])