Add LSTM to standard library (#15744)
authorDavid Riazati <davidriazati@fb.com>
Fri, 22 Feb 2019 00:11:37 +0000 (16:11 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 22 Feb 2019 00:24:19 +0000 (16:24 -0800)
commit2370c989d810635e307fbbc1e7a8ed5b7376a3b9
treefd047f91fd9d3ed8c63705d16dcc84278c63f856
parentac00a0cd479173f7fa800283a4ca899aa6328778
Add LSTM to standard library (#15744)

Summary:
**WIP**

Attempt 2 at #14831

This adds `nn.LSTM` to the jit standard library. Necessary changes to the module itself are detailed in comments. The main limitation is the lack of a true `PackedSequence`, instead this PR uses an ordinary `tuple` to stand in for `PackedSequence`.

Most of the new code in `rnn.py` is copied to `nn.LSTM` from `nn.RNNBase` to specialize it for LSTM since `hx` is a `Tuple[Tensor, Tensor]` (rather than just a `Tensor` as in the other RNN modules) for LSTM.

As a hack it adds an internal annotation `@_parameter_list` to mark that a function returns all the parameters of a module. The weights for `RNN` modules are passed to the corresponding op as a `List[Tensor]`. In Python this has to be gathered dynamically since Parameters could be moved from CPU to GPU or be deleted and replaced (i.e. if someone calls `weight_norm` on their module, #15766), but in the JIT parameter lists are immutable, hence a builtin to handle this differently in Python/JIT.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15744

Differential Revision: D14173198

Pulled By: driazati

fbshipit-source-id: 4ee8113159b3a8f29a9f56fe661cfbb6b30dffcd
test/test_jit.py
torch/_jit_internal.py
torch/csrc/jit/register_special_ops.cpp
torch/csrc/jit/script/compiler.cpp
torch/csrc/jit/script/init.cpp
torch/jit/__init__.py
torch/nn/modules/rnn.py
torch/nn/utils/rnn.py