From: Wanchao Liang
Date: Wed, 5 Dec 2018 02:15:14 +0000 (-0800)
Subject: Add tests for dropout/batchnorm train/eval, remove training constants (#14780)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~2473
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=d872af92826a71034da9125165b9e0433ffba0d8;p=platform%2Fupstream%2Fpytorch.git

Add tests for dropout/batchnorm train/eval, remove training constants (#14780)

Summary:
This PR:
1. adds tests for batchnorm/dropout train/eval parameter mutation
2. removes the 'training' constant from all of our standard library modules
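
For context, the sketch below illustrates the behavior the new test exercises: with `training` no longer declared in `__constants__`, a scripted module that wraps `nn.Dropout`/`nn.BatchNorm1d` follows later `eval()`/`train()` calls instead of baking the training mode in at compile time. It assumes the same 2018-era `torch.jit.ScriptModule`/`@torch.jit.script_method` API used in test_jit.py; the `Wrapper` class and its names are illustrative only, not part of this commit.

    import torch
    import torch.nn as nn

    class Wrapper(torch.jit.ScriptModule):
        """Scripted wrapper around an nn module (mirrors MyModule in the test)."""
        def __init__(self, module):
            super(Wrapper, self).__init__()
            self.module = module

        @torch.jit.script_method
        def forward(self, input):
            return self.module(input) + 1

    dropout = nn.Dropout(p=0.2)
    wrapped = Wrapper(dropout)
    x = torch.randn(6, 10)

    # In eval mode dropout is the identity, so after switching both modules to
    # eval() the scripted wrapper should match the eager module exactly.
    dropout.eval()
    wrapped.eval()
    assert torch.equal(dropout(x) + 1, wrapped(x))

Because the scripted wrapper is no longer specialized on a constant `training` flag, the same compiled `forward` produces train-mode or eval-mode behavior depending on the module's current mode.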

Pull Request resolved: https://github.com/pytorch/pytorch/pull/14780

Differential Revision: D13331578

Pulled By: wanchaol

fbshipit-source-id: d92ca3ce38cc2888688d50fe015e3e22539a20a5
---

diff --git a/test/test_jit.py b/test/test_jit.py
index a71e31f..b3a249b 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -641,42 +641,35 @@ class TestJit(JitTestCase):
                 return -input

         class MyModule(torch.jit.ScriptModule):
-            def __init__(self):
+            def __init__(self, module):
                 super(MyModule, self).__init__()
-                self.sub = Sub()
+                self.module = module

             @torch.jit.script_method
             def forward(self, input):
-                return self.sub(input) + 1
+                return self.module(input) + 1

-        m = MyModule()
+        m = MyModule(Sub())
         input = torch.rand(3, 4)
         self.assertEqual(input + 1, m(input))
         m.eval()
         self.assertEqual(-input + 1, m(input))

-    def test_train_eval_const(self):
-        class MyModule(torch.jit.ScriptModule):
-            __constants__ = ['training']
-
-            def __init__(self):
-                super(MyModule, self).__init__()
-                # TODO: it is illegal to try to call
-                # eval/train because training has already
-                # been set. Consider allowing
-                # constants to be mutable until the end of __init__
+        # test batchnorm and dropout train/eval
+        input = torch.randn(6, 10)
+        batchnorm = nn.BatchNorm1d(10)
+        dropout = nn.Dropout(p=0.2)

-            @torch.jit.script_method
-            def forward(self, input):
-                if self.training:
-                    x = 2 * input
-                else:
-                    x = -input
-                return x + 1
+        m_batchnorm = MyModule(batchnorm)
+        self.assertEqual(batchnorm(input) + 1, m_batchnorm(input))
+        batchnorm.eval()
+        m_batchnorm.eval()
+        self.assertEqual(batchnorm(input) + 1, m_batchnorm(input))

-        m = MyModule()
-        input = torch.rand(3, 4)
-        self.assertEqual(2 * input + 1, m(input))
+        m_dropout = MyModule(dropout)
+        dropout.eval()
+        m_dropout.eval()
+        self.assertEqual(dropout(input) + 1, m_dropout(input))

     def test_diff_subgraph_clones_constants(self):
         @torch.jit.script
@@ -5347,8 +5340,6 @@ a")
     def test_script_module_param_buffer_mutation(self):
         # TODO: add param mutation test case after JIT support it
         class ModuleBufferMutate(torch.jit.ScriptModule):
-            __constants__ = ['training']
-
             def __init__(self):
                 super(ModuleBufferMutate, self).__init__(False)
                 self.register_buffer('running_var', torch.tensor(0, dtype=torch.long))
@@ -5361,6 +5352,8 @@ a")

         m = ModuleBufferMutate()
         self.assertEqual(m(), 1)
+        m.eval()
+        self.assertEqual(m(), 1)

     def test_script_module_for(self):
         class M(torch.jit.ScriptModule):
diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py
index a7224a3..b3e0c41 100644
--- a/torch/nn/modules/activation.py
+++ b/torch/nn/modules/activation.py
@@ -126,7 +126,7 @@ class RReLU(Module):
     .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
         https://arxiv.org/abs/1505.00853
     """
-    __constants__ = ['lower', 'upper', 'inplace', 'training']
+    __constants__ = ['lower', 'upper', 'inplace']

     def __init__(self, lower=1. / 8, upper=1. / 3, inplace=False):
         super(RReLU, self).__init__()
diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py
index 9b09483..e6dbc21 100644
--- a/torch/nn/modules/batchnorm.py
+++ b/torch/nn/modules/batchnorm.py
@@ -13,8 +13,8 @@ from ..._jit_internal import weak_module, weak_script_method
 @weak_module
 class _BatchNorm(Module):
     _version = 2
-    __constants__ = ['training', 'track_running_stats', 'momentum', 'eps',
-                     'weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked']
+    __constants__ = ['track_running_stats', 'momentum', 'eps', 'weight', 'bias',
+                     'running_mean', 'running_var', 'num_batches_tracked']

     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                  track_running_stats=True):
diff --git a/torch/nn/modules/dropout.py b/torch/nn/modules/dropout.py
index b34e481..2b114db 100644
--- a/torch/nn/modules/dropout.py
+++ b/torch/nn/modules/dropout.py
@@ -4,7 +4,7 @@ from ..._jit_internal import weak_module, weak_script_method


 class _DropoutNd(Module):
-    __constants__ = ['p', 'inplace', 'training']
+    __constants__ = ['p', 'inplace']

     def __init__(self, p=0.5, inplace=False):
         super(_DropoutNd, self).__init__()
diff --git a/torch/nn/modules/instancenorm.py b/torch/nn/modules/instancenorm.py
index d5b8427..3a0c452 100644
--- a/torch/nn/modules/instancenorm.py
+++ b/torch/nn/modules/instancenorm.py
@@ -5,7 +5,7 @@ from ..._jit_internal import weak_module, weak_script_method

 class _InstanceNorm(_BatchNorm):
     __constants__ = ['running_mean', 'running_var', 'weight', 'bias',
-                     'training', 'track_running_stats', 'momentum', 'eps']
+                     'track_running_stats', 'momentum', 'eps']

     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=False,
                  track_running_stats=False):