Add RNNCell modules to Script standard library (#14695)
authorDavid Riazati <davidriazati@fb.com>
Wed, 19 Dec 2018 01:25:51 +0000 (17:25 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 19 Dec 2018 01:28:28 +0000 (17:28 -0800)
Summary:
Adds RNNCell modules to script standard lib

cc apaszke for argument_spec changes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14695

Differential Revision: D13467680

Pulled By: driazati

fbshipit-source-id: 13a14da87714325cc4c3d49e5fde8a850d5d757b

test/test_jit.py
torch/csrc/jit/argument_spec.h
torch/csrc/jit/graph_executor.cpp
torch/nn/modules/rnn.py

index f8d946a..760f902 100644 (file)
@@ -10787,6 +10787,21 @@ additional_module_tests = [
         input_size=(S, S),
         extra_args=((S, S),)
     ),
+    dict(
+        module_name='RNNCell',
+        constructor_args=(S, S),
+        input_size=(S, S),
+    ),
+    dict(
+        module_name='LSTMCell',
+        constructor_args=(S, S),
+        input_size=(S, S),
+    ),
+    dict(
+        module_name='GRUCell',
+        constructor_args=(S, S),
+        input_size=(S, S),
+    ),
 ]
 
 
index ec5b337..ec3d988 100644 (file)
@@ -73,13 +73,13 @@ struct ArgumentSpec {
   }
 
   void addInput(const IValue& input, size_t& offset, bool with_grad) {
-    auto & arg = args[offset];
+    auto & arg = args.at(offset);
     // Initialize all fields to 0. This is convenient, because e.g.
     // requires_grad() can be checked even on tensors AND will make
     // padding bits all 0s.
     std::memset(&arg, 0, sizeof(ArgumentInfo));
+
     if (input.isTensor()) {
-      JIT_ASSERT(offset < args.size());
       at::Tensor t = input.toTensor();
       if ((arg.defined_ = t.defined())) {
         arg.requires_grad_ = with_grad && autograd::Variable(t).requires_grad();
@@ -96,7 +96,6 @@ struct ArgumentSpec {
         addInput(elem, offset, with_grad);
       }
     } else {
-      JIT_ASSERT(offset < args.size());
       // NB: no need to set is_tensor to false, because we memset the struct to 0 above
       combineHash(arg);
       offset++;
index cbdf893..9d40546 100644 (file)
@@ -298,6 +298,9 @@ struct GraphExecutorImpl {
   }
 
   static size_t countFlatInputs(const TypePtr& ptr) {
+    if (auto optional_type = ptr->cast<OptionalType>()) {
+      return countFlatInputs(optional_type->getElementType());
+    }
     if (auto tuple_type = ptr->cast<TupleType>()) {
       size_t total = 0;
       for (auto & elem : tuple_type->elements()) {
index 4d5f2a1..89d9b94 100644 (file)
@@ -8,8 +8,9 @@ from .module import Module
 from ..parameter import Parameter
 from ..utils.rnn import PackedSequence
 from .. import init
+from .. import _VF
+from ..._jit_internal import weak_module, weak_script_method
 
-_VF = torch._C._VariableFunctions
 _rnn_impls = {
     'LSTM': _VF.lstm,
     'GRU': _VF.gru,
@@ -535,6 +536,7 @@ class GRU(RNNBase):
 
 
 class RNNCellBase(Module):
+    __constants__ = ['input_size', 'hidden_size', 'bias']
 
     def __init__(self, input_size, hidden_size, bias, num_chunks):
         super(RNNCellBase, self).__init__()
@@ -559,13 +561,16 @@ class RNNCellBase(Module):
             s += ', nonlinearity={nonlinearity}'
         return s.format(**self.__dict__)
 
+    @weak_script_method
     def check_forward_input(self, input):
         if input.size(1) != self.input_size:
             raise RuntimeError(
                 "input has inconsistent input_size: got {}, expected {}".format(
                     input.size(1), self.input_size))
 
+    @weak_script_method
     def check_forward_hidden(self, input, hx, hidden_label=''):
+        # type: (Tensor, Tensor, str)
         if input.size(0) != hx.size(0):
             raise RuntimeError(
                 "Input batch size {} doesn't match hidden{} batch size {}".format(
@@ -582,6 +587,7 @@ class RNNCellBase(Module):
             init.uniform_(weight, -stdv, stdv)
 
 
+@weak_module
 class RNNCell(RNNCellBase):
     r"""An Elman RNN cell with tanh or ReLU non-linearity.
 
@@ -630,31 +636,41 @@ class RNNCell(RNNCellBase):
                 hx = rnn(input[i], hx)
                 output.append(hx)
     """
+    __constants__ = ['input_size', 'hidden_size', 'bias', 'nonlinearity']
 
     def __init__(self, input_size, hidden_size, bias=True, nonlinearity="tanh"):
         super(RNNCell, self).__init__(input_size, hidden_size, bias, num_chunks=1)
         self.nonlinearity = nonlinearity
 
+    @weak_script_method
     def forward(self, input, hx=None):
+        # type: (Tensor, Optional[Tensor]) -> Tensor
         self.check_forward_input(input)
         if hx is None:
-            hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
-        self.check_forward_hidden(input, hx)
+            _hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+        else:
+            _hx = torch.jit._unwrap_optional(hx)
+        self.check_forward_hidden(input, _hx, '')
         if self.nonlinearity == "tanh":
-            func = _VF.rnn_tanh_cell
+            ret = _VF.rnn_tanh_cell(
+                input, _hx,
+                self.weight_ih, self.weight_hh,
+                self.bias_ih, self.bias_hh,
+            )
         elif self.nonlinearity == "relu":
-            func = _VF.rnn_relu_cell
+            ret = _VF.rnn_relu_cell(
+                input, _hx,
+                self.weight_ih, self.weight_hh,
+                self.bias_ih, self.bias_hh,
+            )
         else:
+            ret = input  # TODO: remove when jit supports exception flow
             raise RuntimeError(
                 "Unknown nonlinearity: {}".format(self.nonlinearity))
-
-        return func(
-            input, hx,
-            self.weight_ih, self.weight_hh,
-            self.bias_ih, self.bias_hh,
-        )
+        return ret
 
 
+@weak_module
 class LSTMCell(RNNCellBase):
     r"""A long short-term memory (LSTM) cell.
 
@@ -719,20 +735,25 @@ class LSTMCell(RNNCellBase):
     def __init__(self, input_size, hidden_size, bias=True):
         super(LSTMCell, self).__init__(input_size, hidden_size, bias, num_chunks=4)
 
+    @weak_script_method
     def forward(self, input, hx=None):
+        # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
         self.check_forward_input(input)
         if hx is None:
-            hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
-            hx = (hx, hx)
-        self.check_forward_hidden(input, hx[0], '[0]')
-        self.check_forward_hidden(input, hx[1], '[1]')
+            zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+            _hx = (zeros, zeros)
+        else:
+            _hx = torch.jit._unwrap_optional(hx)
+        self.check_forward_hidden(input, _hx[0], '[0]')
+        self.check_forward_hidden(input, _hx[1], '[1]')
         return _VF.lstm_cell(
-            input, hx,
+            input, _hx,
             self.weight_ih, self.weight_hh,
             self.bias_ih, self.bias_hh,
         )
 
 
+@weak_module
 class GRUCell(RNNCellBase):
     r"""A gated recurrent unit (GRU) cell
 
@@ -789,13 +810,17 @@ class GRUCell(RNNCellBase):
     def __init__(self, input_size, hidden_size, bias=True):
         super(GRUCell, self).__init__(input_size, hidden_size, bias, num_chunks=3)
 
+    @weak_script_method
     def forward(self, input, hx=None):
+        # type: (Tensor, Optional[Tensor]) -> Tensor
         self.check_forward_input(input)
         if hx is None:
-            hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
-        self.check_forward_hidden(input, hx)
+            _hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+        else:
+            _hx = torch.jit._unwrap_optional(hx)
+        self.check_forward_hidden(input, _hx, '')
         return _VF.gru_cell(
-            input, hx,
+            input, _hx,
             self.weight_ih, self.weight_hh,
             self.bias_ih, self.bias_hh,
         )