                return -input
        class MyModule(torch.jit.ScriptModule):
-            def __init__(self):
+            def __init__(self, module):
                super(MyModule, self).__init__()
-                self.sub = Sub()
+                self.module = module
            @torch.jit.script_method
            def forward(self, input):
-                return self.sub(input) + 1
+                return self.module(input) + 1
-        m = MyModule()
+        m = MyModule(Sub())
        input = torch.rand(3, 4)
        self.assertEqual(input + 1, m(input))
        m.eval()
        self.assertEqual(-input + 1, m(input))
-    def test_train_eval_const(self):
-        class MyModule(torch.jit.ScriptModule):
-            __constants__ = ['training']
-
-            def __init__(self):
-                super(MyModule, self).__init__()
-                # TODO: it is illegal to try to call
-                # eval/train because training has already
-                # been set. Consider allowing
-                # constants to be mutable until the end of __init__
+        # test batchnorm and dropout train/eval
+        input = torch.randn(6, 10)
+        batchnorm = nn.BatchNorm1d(10)
+        dropout = nn.Dropout(p=0.2)
-            @torch.jit.script_method
-            def forward(self, input):
-                if self.training:
-                    x = 2 * input
-                else:
-                    x = -input
-                return x + 1
+        m_batchnorm = MyModule(batchnorm)
+        self.assertEqual(batchnorm(input) + 1, m_batchnorm(input))
+        batchnorm.eval()
+        m_batchnorm.eval()
+        self.assertEqual(batchnorm(input) + 1, m_batchnorm(input))
-        m = MyModule()
-        input = torch.rand(3, 4)
-        self.assertEqual(2 * input + 1, m(input))
+        m_dropout = MyModule(dropout)
+        dropout.eval()
+        m_dropout.eval()
+        self.assertEqual(dropout(input) + 1, m_dropout(input))
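
With 'training' demoted from a compile-time constant to an ordinary attribute, the mode can be flipped after construction, and a scripted branch on self.training follows the current mode at runtime. As a minimal standalone sketch of the pattern the rewritten test exercises (assuming a Sub module like the one above, which negates its input in eval mode, and a PyTorch build that includes this change):

import torch
import torch.nn as nn

class Sub(nn.Module):
    def forward(self, input):
        # identity while training, negation in eval mode
        if self.training:
            return input
        else:
            return -input

class MyModule(torch.jit.ScriptModule):
    def __init__(self, module):
        super(MyModule, self).__init__()
        self.module = module

    @torch.jit.script_method
    def forward(self, input):
        return self.module(input) + 1

m = MyModule(Sub())
x = torch.rand(3, 4)
assert torch.equal(m(x), x + 1)    # train mode: Sub is the identity
m.eval()                           # eval() propagates to self.module
assert torch.equal(m(x), -x + 1)   # eval mode: Sub negates its input
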
    def test_diff_subgraph_clones_constants(self):
        @torch.jit.script

    def test_script_module_param_buffer_mutation(self):
        # TODO: add a param mutation test case once the JIT supports it
        class ModuleBufferMutate(torch.jit.ScriptModule):
-            __constants__ = ['training']
-
            def __init__(self):
                super(ModuleBufferMutate, self).__init__(False)
                self.register_buffer('running_var', torch.tensor(0, dtype=torch.long))
        m = ModuleBufferMutate()
        self.assertEqual(m(), 1)
+        m.eval()
+        self.assertEqual(m(), 1)
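
The forward body of ModuleBufferMutate is elided from this hunk; the new assertions only make sense if it guards the buffer update on self.training, so that eval() freezes the value after the first call. A hypothetical sketch of that shape (the guarded forward is an assumption, not shown in the diff):

import torch

class ModuleBufferMutate(torch.jit.ScriptModule):
    def __init__(self):
        super(ModuleBufferMutate, self).__init__(False)
        self.register_buffer('running_var', torch.tensor(0, dtype=torch.long))

    @torch.jit.script_method
    def forward(self):
        # assumed body: bump the buffer only in training mode
        if self.training:
            self.running_var += 1
        return self.running_var

m = ModuleBufferMutate()
assert m() == 1   # first call in train mode increments 0 -> 1
m.eval()
assert m() == 1   # eval mode: the buffer is no longer incremented
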
    def test_script_module_for(self):
        class M(torch.jit.ScriptModule):

    .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
        https://arxiv.org/abs/1505.00853
    """
-    __constants__ = ['lower', 'upper', 'inplace', 'training']
+    __constants__ = ['lower', 'upper', 'inplace']
    def __init__(self, lower=1. / 8, upper=1. / 3, inplace=False):
        super(RReLU, self).__init__()
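
RReLU is the motivating case among the nn modules: during training it samples a random negative slope uniformly from [lower, upper], while in eval mode it applies the fixed average slope (lower + upper) / 2, so 'training' must remain mutable rather than baked in as a constant. A quick sanity check of those semantics (standard nn.RReLU behavior, nothing specific to this patch):

import torch
import torch.nn as nn

rrelu = nn.RReLU(lower=1. / 8, upper=1. / 3)
x = torch.full((4,), -1.0)

rrelu.train()
y = rrelu(x)   # slope drawn uniformly from [1/8, 1/3] per element
assert ((y >= -1. / 3) & (y <= -1. / 8)).all()

rrelu.eval()
y = rrelu(x)   # deterministic slope: (lower + upper) / 2
assert torch.allclose(y, x * (1. / 8 + 1. / 3) / 2)
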
@weak_module
class _BatchNorm(Module):
    _version = 2
-    __constants__ = ['training', 'track_running_stats', 'momentum', 'eps',
-                     'weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked']
+    __constants__ = ['track_running_stats', 'momentum', 'eps', 'weight', 'bias',
+                     'running_mean', 'running_var', 'num_batches_tracked']
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):

class _InstanceNorm(_BatchNorm):
    __constants__ = ['running_mean', 'running_var', 'weight', 'bias',
-                     'training', 'track_running_stats', 'momentum', 'eps']
+                     'track_running_stats', 'momentum', 'eps']
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=False,
                 track_running_stats=False):
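
For _BatchNorm and _InstanceNorm the payoff is the same: with 'training' no longer a constant, one compiled module can both update its running statistics in train mode and normalize with them in eval mode. A short check of the expected behavior (plain nn.BatchNorm1d semantics, independent of this patch):

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(10)
x = torch.randn(6, 10)

bn.train()
bn(x)   # forward in train mode updates running_mean / running_var
assert not torch.equal(bn.running_mean, torch.zeros(10))

bn.eval()
y = bn(x)   # forward in eval mode normalizes with the running stats
expected = (x - bn.running_mean) / torch.sqrt(bn.running_var + bn.eps)
assert torch.allclose(y, expected * bn.weight + bn.bias, atol=1e-6)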