Add tests for dropout/batchnorm train/eval, remove training constants (#14780)
author Wanchao Liang <wanchaol@users.noreply.github.com>
Wed, 5 Dec 2018 02:15:14 +0000 (18:15 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 5 Dec 2018 02:17:43 +0000 (18:17 -0800)
Summary:
This PR:

1. Add tests for batchnorm/dropout covering train/eval parameter mutation (a sketch of the resulting behavior follows below).
2. Remove the `training` constant from all of our standard library modules.
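A minimal sketch of the behavior these changes enable (not part of this commit; the `WrapModule` wrapper is illustrative only): since `training` is no longer a compile-time constant, calling `eval()` on a scripted module mutates the attribute and changes the module's output after compilation.

```python
import torch
import torch.nn as nn


class WrapModule(torch.jit.ScriptModule):
    # Hypothetical wrapper, mirroring the MyModule helper used in the new tests.
    def __init__(self, module):
        super(WrapModule, self).__init__()
        self.module = module

    @torch.jit.script_method
    def forward(self, input):
        return self.module(input)


m = WrapModule(nn.Dropout(p=0.5))
x = torch.randn(4, 4)
m.eval()  # mutates `training` on the scripted submodule; dropout becomes a no-op
assert torch.equal(m(x), x)
```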
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14780

Differential Revision: D13331578

Pulled By: wanchaol

fbshipit-source-id: d92ca3ce38cc2888688d50fe015e3e22539a20a5

test/test_jit.py
torch/nn/modules/activation.py
torch/nn/modules/batchnorm.py
torch/nn/modules/dropout.py
torch/nn/modules/instancenorm.py

diff --git a/test/test_jit.py b/test/test_jit.py
index a71e31f..b3a249b 100644
@@ -641,42 +641,35 @@ class TestJit(JitTestCase):
                     return -input
 
         class MyModule(torch.jit.ScriptModule):
-            def __init__(self):
+            def __init__(self, module):
                 super(MyModule, self).__init__()
-                self.sub = Sub()
+                self.module = module
 
             @torch.jit.script_method
             def forward(self, input):
-                return self.sub(input) + 1
+                return self.module(input) + 1
 
-        m = MyModule()
+        m = MyModule(Sub())
         input = torch.rand(3, 4)
         self.assertEqual(input + 1, m(input))
         m.eval()
         self.assertEqual(-input + 1, m(input))
 
-    def test_train_eval_const(self):
-        class MyModule(torch.jit.ScriptModule):
-            __constants__ = ['training']
-
-            def __init__(self):
-                super(MyModule, self).__init__()
-                # TODO: it is illegal to try to call
-                # eval/train because training has already
-                # been set. Consider allowing
-                # constants to be mutable until the end of __init__
+        # test batchnorm and dropout train/eval
+        input = torch.randn(6, 10)
+        batchnorm = nn.BatchNorm1d(10)
+        dropout = nn.Dropout(p=0.2)
 
-            @torch.jit.script_method
-            def forward(self, input):
-                if self.training:
-                    x = 2 * input
-                else:
-                    x = -input
-                return x + 1
+        m_batchnorm = MyModule(batchnorm)
+        self.assertEqual(batchnorm(input) + 1, m_batchnorm(input))
+        batchnorm.eval()
+        m_batchnorm.eval()
+        self.assertEqual(batchnorm(input) + 1, m_batchnorm(input))
 
-        m = MyModule()
-        input = torch.rand(3, 4)
-        self.assertEqual(2 * input + 1, m(input))
+        m_dropout = MyModule(dropout)
+        dropout.eval()
+        m_dropout.eval()
+        self.assertEqual(dropout(input) + 1, m_dropout(input))
 
     def test_diff_subgraph_clones_constants(self):
         @torch.jit.script
@@ -5347,8 +5340,6 @@ a")
     def test_script_module_param_buffer_mutation(self):
         # TODO: add param mutation test case after JIT support it
         class ModuleBufferMutate(torch.jit.ScriptModule):
-            __constants__ = ['training']
-
             def __init__(self):
                 super(ModuleBufferMutate, self).__init__(False)
                 self.register_buffer('running_var', torch.tensor(0, dtype=torch.long))
@@ -5361,6 +5352,8 @@ a")
 
         m = ModuleBufferMutate()
         self.assertEqual(m(), 1)
+        m.eval()
+        self.assertEqual(m(), 1)
 
     def test_script_module_for(self):
         class M(torch.jit.ScriptModule):
diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py
index a7224a3..b3e0c41 100644
@@ -126,7 +126,7 @@ class RReLU(Module):
     .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
         https://arxiv.org/abs/1505.00853
     """
-    __constants__ = ['lower', 'upper', 'inplace', 'training']
+    __constants__ = ['lower', 'upper', 'inplace']
 
     def __init__(self, lower=1. / 8, upper=1. / 3, inplace=False):
         super(RReLU, self).__init__()
diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py
index 9b09483..e6dbc21 100644
@@ -13,8 +13,8 @@ from ..._jit_internal import weak_module, weak_script_method
 @weak_module
 class _BatchNorm(Module):
     _version = 2
-    __constants__ = ['training', 'track_running_stats', 'momentum', 'eps',
-                     'weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked']
+    __constants__ = ['track_running_stats', 'momentum', 'eps', 'weight', 'bias',
+                     'running_mean', 'running_var', 'num_batches_tracked']
 
     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                  track_running_stats=True):
diff --git a/torch/nn/modules/dropout.py b/torch/nn/modules/dropout.py
index b34e481..2b114db 100644
@@ -4,7 +4,7 @@ from ..._jit_internal import weak_module, weak_script_method
 
 
 class _DropoutNd(Module):
-    __constants__ = ['p', 'inplace', 'training']
+    __constants__ = ['p', 'inplace']
 
     def __init__(self, p=0.5, inplace=False):
         super(_DropoutNd, self).__init__()
diff --git a/torch/nn/modules/instancenorm.py b/torch/nn/modules/instancenorm.py
index d5b8427..3a0c452 100644
@@ -5,7 +5,7 @@ from ..._jit_internal import weak_module, weak_script_method
 
 class _InstanceNorm(_BatchNorm):
     __constants__ = ['running_mean', 'running_var', 'weight', 'bias',
-                     'training', 'track_running_stats', 'momentum', 'eps']
+                     'track_running_stats', 'momentum', 'eps']
 
     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=False,
                  track_running_stats=False):