from ..parameter import Parameter
from ..utils.rnn import PackedSequence
from .. import init
+from .. import _VF
+from ..._jit_internal import weak_module, weak_script_method
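+# weak_module / weak_script_method mark these cells for lazy TorchScript compilation when they are used from a ScriptModule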
-_VF = torch._C._VariableFunctions
_rnn_impls = {
'LSTM': _VF.lstm,
'GRU': _VF.gru,
class RNNCellBase(Module):
+ __constants__ = ['input_size', 'hidden_size', 'bias']
def __init__(self, input_size, hidden_size, bias, num_chunks):
super(RNNCellBase, self).__init__()
s += ', nonlinearity={nonlinearity}'
return s.format(**self.__dict__)
+ @weak_script_method
def check_forward_input(self, input):
if input.size(1) != self.input_size:
raise RuntimeError(
"input has inconsistent input_size: got {}, expected {}".format(
input.size(1), self.input_size))
+ @weak_script_method
def check_forward_hidden(self, input, hx, hidden_label=''):
+ # type: (Tensor, Tensor, str) -> None
if input.size(0) != hx.size(0):
raise RuntimeError(
"Input batch size {} doesn't match hidden{} batch size {}".format(
init.uniform_(weight, -stdv, stdv)
+@weak_module
class RNNCell(RNNCellBase):
r"""An Elman RNN cell with tanh or ReLU non-linearity.
hx = rnn(input[i], hx)
output.append(hx)
"""
+ __constants__ = ['input_size', 'hidden_size', 'bias', 'nonlinearity']
def __init__(self, input_size, hidden_size, bias=True, nonlinearity="tanh"):
super(RNNCell, self).__init__(input_size, hidden_size, bias, num_chunks=1)
self.nonlinearity = nonlinearity
+ @weak_script_method
def forward(self, input, hx=None):
+ # type: (Tensor, Optional[Tensor]) -> Tensor
self.check_forward_input(input)
if hx is None:
- hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
- self.check_forward_hidden(input, hx)
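+ # default hidden state: zeros with the input's dtype and device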
+ _hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+ else:
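+ # TorchScript cannot yet narrow Optional[Tensor] after a None check, so unwrap explicitly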
+ _hx = torch.jit._unwrap_optional(hx)
+ self.check_forward_hidden(input, _hx, '')
if self.nonlinearity == "tanh":
- func = _VF.rnn_tanh_cell
+ ret = _VF.rnn_tanh_cell(
+ input, _hx,
+ self.weight_ih, self.weight_hh,
+ self.bias_ih, self.bias_hh,
+ )
elif self.nonlinearity == "relu":
- func = _VF.rnn_relu_cell
+ ret = _VF.rnn_relu_cell(
+ input, _hx,
+ self.weight_ih, self.weight_hh,
+ self.bias_ih, self.bias_hh,
+ )
else:
+ ret = input # TODO: remove when jit supports exception flow
raise RuntimeError(
"Unknown nonlinearity: {}".format(self.nonlinearity))
-
- return func(
- input, hx,
- self.weight_ih, self.weight_hh,
- self.bias_ih, self.bias_hh,
- )
+ return ret
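A minimal usage sketch of the rewritten forward (illustrative only, not part of the patch; the sizes are arbitrary assumptions):

    import torch
    import torch.nn as nn

    rnn = nn.RNNCell(10, 20)   # input_size=10, hidden_size=20
    x = torch.randn(3, 10)     # batch of 3, one time step
    h = rnn(x)                 # hx=None path: zeros built from x's dtype/device
    h = rnn(x, h)              # explicit hidden state, shape (3, 20)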
+@weak_module
class LSTMCell(RNNCellBase):
r"""A long short-term memory (LSTM) cell.
def __init__(self, input_size, hidden_size, bias=True):
super(LSTMCell, self).__init__(input_size, hidden_size, bias, num_chunks=4)
+ @weak_script_method
def forward(self, input, hx=None):
+ # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
self.check_forward_input(input)
if hx is None:
- hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
- hx = (hx, hx)
- self.check_forward_hidden(input, hx[0], '[0]')
- self.check_forward_hidden(input, hx[1], '[1]')
+ zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+ _hx = (zeros, zeros)
+ else:
+ _hx = torch.jit._unwrap_optional(hx)
+ self.check_forward_hidden(input, _hx[0], '[0]')
+ self.check_forward_hidden(input, _hx[1], '[1]')
return _VF.lstm_cell(
- input, hx,
+ input, _hx,
self.weight_ih, self.weight_hh,
self.bias_ih, self.bias_hh,
)
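LSTMCell carries a (hidden, cell) pair, so the hx=None path above builds a tuple of two zero tensors. A sketch under the same assumed sizes:

    lstm = nn.LSTMCell(10, 20)
    x = torch.randn(3, 10)
    h, c = lstm(x)             # both h_0 and c_0 default to zeros
    h, c = lstm(x, (h, c))     # next step with explicit state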
+@weak_module
class GRUCell(RNNCellBase):
r"""A gated recurrent unit (GRU) cell
def __init__(self, input_size, hidden_size, bias=True):
super(GRUCell, self).__init__(input_size, hidden_size, bias, num_chunks=3)
+ @weak_script_method
def forward(self, input, hx=None):
+ # type: (Tensor, Optional[Tensor]) -> Tensor
self.check_forward_input(input)
if hx is None:
- hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
- self.check_forward_hidden(input, hx)
+ _hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+ else:
+ _hx = torch.jit._unwrap_optional(hx)
+ self.check_forward_hidden(input, _hx, '')
return _VF.gru_cell(
- input, hx,
+ input, _hx,
self.weight_ih, self.weight_hh,
self.bias_ih, self.bias_hh,
)
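What @weak_module buys: the cells can now be compiled when used from a ScriptModule. A sketch using the torch.jit API of this era (GRULoop and the sizes are hypothetical):

    class GRULoop(torch.jit.ScriptModule):
        def __init__(self):
            super(GRULoop, self).__init__()
            self.cell = nn.GRUCell(10, 20)  # weak module: scripted when attached to a ScriptModule

        @torch.jit.script_method
        def forward(self, x, h):
            # type: (Tensor, Tensor) -> Tensor
            for i in range(x.size(0)):      # x: (seq_len, batch, input_size)
                h = self.cell(x[i], h)
            return h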