From 83878e19ff2af39443016beb31952476c9f049d0 Mon Sep 17 00:00:00 2001 From: Rodrigo Berriel Date: Thu, 16 Sep 2021 06:33:40 -0700 Subject: [PATCH] Improve LSTM documentation for proj_size > 0 (#65102) Summary: Fixes https://github.com/pytorch/pytorch/issues/65053. Although the documentation already addresses this (see https://github.com/pytorch/pytorch/blob/fe0f9d1dafb9791cb08635636a01128850d17538/torch/nn/modules/rnn.py#L500-L506), the definition of `weight_ih_l[k]` could be improved by specifying what happens when `k > 0` and `proj_size > 0`. As `proj_size` is only used in LSTM, no changes are needed for the other RNNs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/65102 Reviewed By: supriyar Differential Revision: D30975781 Pulled By: jbschlosser fbshipit-source-id: 12df06e5e6a8d5de0ad10fb15e33c3e6311c11d3 --- torch/nn/modules/rnn.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py index 0d92428..47dc5eb 100644 --- a/torch/nn/modules/rnn.py +++ b/torch/nn/modules/rnn.py @@ -563,7 +563,9 @@ class LSTM(RNNBase): Attributes: weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer `(W_ii|W_if|W_ig|W_io)`, of shape `(4*hidden_size, input_size)` for `k = 0`. - Otherwise, the shape is `(4*hidden_size, num_directions * hidden_size)` + Otherwise, the shape is `(4*hidden_size, num_directions * hidden_size)`. If + ``proj_size > 0`` was specified, the shape will be + `(4*hidden_size, num_directions * proj_size)` for `k > 0` weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer `(W_hi|W_hf|W_hg|W_ho)`, of shape `(4*hidden_size, hidden_size)`. If ``proj_size > 0`` was specified, the shape will be `(4*hidden_size, proj_size)`. -- 2.7.4