From f3bff2d50050bc4300c6c842163eabfc4a01f571 Mon Sep 17 00:00:00 2001
From: David Riazati
Date: Tue, 18 Dec 2018 17:25:51 -0800
Subject: [PATCH] Add RNNCell modules to Script standard library (#14695)

Summary:
Adds RNNCell modules to the script standard library

cc apaszke for the argument_spec changes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/14695

Differential Revision: D13467680

Pulled By: driazati

fbshipit-source-id: 13a14da87714325cc4c3d49e5fde8a850d5d757b
---
 test/test_jit.py                  | 15 ++++++++++
 torch/csrc/jit/argument_spec.h    |  5 ++---
 torch/csrc/jit/graph_executor.cpp |  3 ++
 torch/nn/modules/rnn.py           | 63 +++++++++++++++++++++++++++------------
 4 files changed, 64 insertions(+), 22 deletions(-)
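Note: the practical effect of this change is that the stock cell modules can be
used directly from TorchScript. A minimal sketch of the intended usage, assuming
a build with this patch applied (`SimpleLoop` is a hypothetical example module,
not code from this PR):

import torch
import torch.nn as nn

class SimpleLoop(torch.jit.ScriptModule):
    def __init__(self):
        super(SimpleLoop, self).__init__()
        # nn.RNNCell is now a weak module, so it is compiled when it is
        # used from a ScriptModule
        self.cell = nn.RNNCell(10, 20)

    @torch.jit.script_method
    def forward(self, xs):
        # type: (Tensor) -> Tensor
        # xs: (seq_len, batch, input_size); hx: (batch, hidden_size)
        hx = torch.zeros(xs.size(1), 20, dtype=xs.dtype, device=xs.device)
        for i in range(xs.size(0)):
            hx = self.cell(xs[i], hx)
        return hx

out = SimpleLoop()(torch.randn(7, 3, 10))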
diff --git a/test/test_jit.py b/test/test_jit.py
index f8d946a..760f902 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -10787,6 +10787,21 @@ additional_module_tests = [
         input_size=(S, S),
         extra_args=((S, S),)
     ),
+    dict(
+        module_name='RNNCell',
+        constructor_args=(S, S),
+        input_size=(S, S),
+    ),
+    dict(
+        module_name='LSTMCell',
+        constructor_args=(S, S),
+        input_size=(S, S),
+    ),
+    dict(
+        module_name='GRUCell',
+        constructor_args=(S, S),
+        input_size=(S, S),
+    ),
 ]
 
 
diff --git a/torch/csrc/jit/argument_spec.h b/torch/csrc/jit/argument_spec.h
index ec5b337..ec3d988 100644
--- a/torch/csrc/jit/argument_spec.h
+++ b/torch/csrc/jit/argument_spec.h
@@ -73,13 +73,13 @@ struct ArgumentSpec {
   }
 
   void addInput(const IValue& input, size_t& offset, bool with_grad) {
-    auto & arg = args[offset];
+    auto & arg = args.at(offset);
     // Initialize all fields to 0. This is convenient, because e.g.
     // requires_grad() can be checked even on tensors AND will make
     // padding bits all 0s.
     std::memset(&arg, 0, sizeof(ArgumentInfo));
+
     if (input.isTensor()) {
-      JIT_ASSERT(offset < args.size());
       at::Tensor t = input.toTensor();
       if ((arg.defined_ = t.defined())) {
         arg.requires_grad_ = with_grad && autograd::Variable(t).requires_grad();
@@ -96,7 +96,6 @@ struct ArgumentSpec {
         addInput(elem, offset, with_grad);
       }
     } else {
-      JIT_ASSERT(offset < args.size());
       // NB: no need to set is_tensor to false, because we memset the struct to 0 above
       combineHash(arg);
       offset++;
diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp
index cbdf893..9d40546 100644
--- a/torch/csrc/jit/graph_executor.cpp
+++ b/torch/csrc/jit/graph_executor.cpp
@@ -298,6 +298,9 @@ struct GraphExecutorImpl {
   }
 
   static size_t countFlatInputs(const TypePtr& ptr) {
+    if (auto optional_type = ptr->cast<OptionalType>()) {
+      return countFlatInputs(optional_type->getElementType());
+    }
     if (auto tuple_type = ptr->cast<TupleType>()) {
       size_t total = 0;
       for (auto & elem : tuple_type->elements()) {
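Note: the argument_spec.h and graph_executor.cpp changes above make the executor
count and specialize flattened inputs nested inside Optional types; without the
countFlatInputs change, the Optional[Tensor] `hx` arguments of the cell forwards
below would be miscounted. A minimal sketch of a script function that exercises
this path, again assuming this patch is applied (`step_or_identity` is an
illustrative name, not code from this PR):

import torch

@torch.jit.script
def step_or_identity(x, hx):
    # type: (Tensor, Optional[Tensor]) -> Tensor
    if hx is None:
        out = x
    else:
        # same unwrapping idiom the cell forwards use
        out = x + torch.jit._unwrap_optional(hx)
    return out

print(step_or_identity(torch.ones(2), None))           # tensor([1., 1.])
print(step_or_identity(torch.ones(2), torch.ones(2)))  # tensor([2., 2.])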
diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py
index 4d5f2a1..89d9b94 100644
--- a/torch/nn/modules/rnn.py
+++ b/torch/nn/modules/rnn.py
@@ -8,8 +8,9 @@ from .module import Module
 from ..parameter import Parameter
 from ..utils.rnn import PackedSequence
 from .. import init
+from .. import _VF
+from ..._jit_internal import weak_module, weak_script_method
 
-_VF = torch._C._VariableFunctions
 _rnn_impls = {
     'LSTM': _VF.lstm,
     'GRU': _VF.gru,
@@ -535,6 +536,7 @@ class GRU(RNNBase):
 
 
 class RNNCellBase(Module):
+    __constants__ = ['input_size', 'hidden_size', 'bias']
 
     def __init__(self, input_size, hidden_size, bias, num_chunks):
         super(RNNCellBase, self).__init__()
@@ -559,13 +561,16 @@ class RNNCellBase(Module):
             s += ', nonlinearity={nonlinearity}'
         return s.format(**self.__dict__)
 
+    @weak_script_method
     def check_forward_input(self, input):
         if input.size(1) != self.input_size:
             raise RuntimeError(
                 "input has inconsistent input_size: got {}, expected {}".format(
                     input.size(1), self.input_size))
 
+    @weak_script_method
     def check_forward_hidden(self, input, hx, hidden_label=''):
+        # type: (Tensor, Tensor, str) -> None
         if input.size(0) != hx.size(0):
             raise RuntimeError(
                 "Input batch size {} doesn't match hidden{} batch size {}".format(
@@ -582,6 +587,7 @@ class RNNCellBase(Module):
             init.uniform_(weight, -stdv, stdv)
 
 
+@weak_module
 class RNNCell(RNNCellBase):
     r"""An Elman RNN cell with tanh or ReLU non-linearity.
 
@@ -630,31 +636,41 @@ class RNNCell(RNNCellBase):
             hx = rnn(input[i], hx)
             output.append(hx)
     """
+    __constants__ = ['input_size', 'hidden_size', 'bias', 'nonlinearity']
 
     def __init__(self, input_size, hidden_size, bias=True, nonlinearity="tanh"):
         super(RNNCell, self).__init__(input_size, hidden_size, bias, num_chunks=1)
         self.nonlinearity = nonlinearity
 
+    @weak_script_method
     def forward(self, input, hx=None):
+        # type: (Tensor, Optional[Tensor]) -> Tensor
         self.check_forward_input(input)
         if hx is None:
-            hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
-        self.check_forward_hidden(input, hx)
+            _hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+        else:
+            _hx = torch.jit._unwrap_optional(hx)
+        self.check_forward_hidden(input, _hx, '')
         if self.nonlinearity == "tanh":
-            func = _VF.rnn_tanh_cell
+            ret = _VF.rnn_tanh_cell(
+                input, _hx,
+                self.weight_ih, self.weight_hh,
+                self.bias_ih, self.bias_hh,
+            )
         elif self.nonlinearity == "relu":
-            func = _VF.rnn_relu_cell
+            ret = _VF.rnn_relu_cell(
+                input, _hx,
+                self.weight_ih, self.weight_hh,
+                self.bias_ih, self.bias_hh,
+            )
         else:
+            ret = input  # TODO: remove when jit supports exception flow
             raise RuntimeError(
                 "Unknown nonlinearity: {}".format(self.nonlinearity))
-
-        return func(
-            input, hx,
-            self.weight_ih, self.weight_hh,
-            self.bias_ih, self.bias_hh,
-        )
+        return ret
 
 
+@weak_module
 class LSTMCell(RNNCellBase):
     r"""A long short-term memory (LSTM) cell.
 
@@ -719,20 +735,25 @@ class LSTMCell(RNNCellBase):
     def __init__(self, input_size, hidden_size, bias=True):
         super(LSTMCell, self).__init__(input_size, hidden_size, bias, num_chunks=4)
 
+    @weak_script_method
     def forward(self, input, hx=None):
+        # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
         self.check_forward_input(input)
         if hx is None:
-            hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
-            hx = (hx, hx)
-        self.check_forward_hidden(input, hx[0], '[0]')
-        self.check_forward_hidden(input, hx[1], '[1]')
+            zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+            _hx = (zeros, zeros)
+        else:
+            _hx = torch.jit._unwrap_optional(hx)
+        self.check_forward_hidden(input, _hx[0], '[0]')
+        self.check_forward_hidden(input, _hx[1], '[1]')
         return _VF.lstm_cell(
-            input, hx,
+            input, _hx,
             self.weight_ih, self.weight_hh,
             self.bias_ih, self.bias_hh,
         )
 
 
+@weak_module
 class GRUCell(RNNCellBase):
     r"""A gated recurrent unit (GRU) cell
 
@@ -789,13 +810,17 @@ class GRUCell(RNNCellBase):
     def __init__(self, input_size, hidden_size, bias=True):
         super(GRUCell, self).__init__(input_size, hidden_size, bias, num_chunks=3)
 
+    @weak_script_method
     def forward(self, input, hx=None):
+        # type: (Tensor, Optional[Tensor]) -> Tensor
         self.check_forward_input(input)
         if hx is None:
-            hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
-        self.check_forward_hidden(input, hx)
+            _hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+        else:
+            _hx = torch.jit._unwrap_optional(hx)
+        self.check_forward_hidden(input, _hx, '')
         return _VF.gru_cell(
-            input, hx,
+            input, _hx,
             self.weight_ih, self.weight_hh,
             self.bias_ih, self.bias_hh,
         )
-- 
2.7.4
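Note: unlike RNNCell and GRUCell, LSTMCell's hidden state is a pair, so a
scripted loop threads the (h, c) tuple explicitly through each step. A usage
sketch under the same assumptions as above (`LSTMLoop` is a hypothetical
example module, not code from this PR):

import torch
import torch.nn as nn

class LSTMLoop(torch.jit.ScriptModule):
    def __init__(self):
        super(LSTMLoop, self).__init__()
        self.cell = nn.LSTMCell(10, 20)

    @torch.jit.script_method
    def forward(self, xs):
        # type: (Tensor) -> Tuple[Tensor, Tensor]
        # h and c both have shape (batch, hidden_size)
        h = torch.zeros(xs.size(1), 20, dtype=xs.dtype, device=xs.device)
        c = torch.zeros(xs.size(1), 20, dtype=xs.dtype, device=xs.device)
        for i in range(xs.size(0)):
            h, c = self.cell(xs[i], (h, c))
        return h, c

h_n, c_n = LSTMLoop()(torch.randn(7, 3, 10))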