Revert D30732630: [quant] Enable jit tracing on quantizable LSTM
author    David Riazati <davidriazati@fb.com>
          Thu, 2 Sep 2021 23:58:59 +0000 (16:58 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
          Fri, 3 Sep 2021 00:08:26 +0000 (17:08 -0700)
Test Plan: revert-hammer

Differential Revision:
D30732630 (https://github.com/pytorch/pytorch/commit/116142143cc2d66c7e582d9f96e00862456fd736)

Original commit changeset: 443e351ebb0e

fbshipit-source-id: 49001392f01366f3b1ccc31139f824c80b86cd40

test/quantization/core/test_quantized_op.py
torch/nn/quantizable/modules/rnn.py

diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py
index 6275174..49b7c96 100644
@@ -2476,13 +2476,6 @@ class TestQuantizedOps(TestCase):
                         msg=(f"Error is too high: SNR(dB): {power}, "
                              f"Signal: {signal}, MSE: {mse}"))
 
-                # Trace
-                jit_qmodule = torch.jit.trace(lstm_quantized, qx)
-
-                # Script
-                # TODO: Fix the scripting in the torch/nn/quantizable/modules/rnn.py
-                # jit_qmodule = torch.jit.script(lstm_quantized)
-
     @override_qengines
     def test_custom_module_multi_head_attention(self):
         class MultiheadAttentionModel(torch.nn.Module):
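For context, the hunk above removes the tracing check that D30732630 had added to the custom-module LSTM test. A minimal sketch of that pattern, assuming the surrounding test context where `lstm_quantized` is the converted quantizable LSTM and `qx` is the quantized example input (both names appear in the removed lines):

    # Hedged reconstruction of the removed check, not the exact test code.
    jit_qmodule = torch.jit.trace(lstm_quantized, qx)  # record a graph from one example run
    y_traced, _ = jit_qmodule(qx)    # the traced module is called like the original
    y_eager, _ = lstm_quantized(qx)
    # Assuming the outputs are quantized tensors (as in the SNR check above),
    # compare them after dequantizing.
    assert torch.allclose(y_traced.dequantize(), y_eager.dequantize())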
diff --git a/torch/nn/quantizable/modules/rnn.py b/torch/nn/quantizable/modules/rnn.py
index cd0d094..bdfd778 100644
@@ -48,7 +48,7 @@ class LSTMCell(torch.nn.Module):
         self.ogate_cy = torch.nn.quantized.FloatFunctional()
 
     def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
-        if hidden is None or hidden[0] is None or hidden[1] is None:
+        if hidden is None or hidden == (None, None):
             hidden = self.initialize_hidden(x.shape[0], x.is_quantized)
         hx, cx = hidden
 
@@ -175,13 +175,10 @@ class _LSTMLayer(torch.nn.Module):
                 cx_bw = cx_fw[1]
                 cx_fw = cx_fw[0]
             hidden_bw = hx_bw, cx_bw
-        if hx_fw is None and cx_fw is None:
-            hidden_fw = None
-        else:
-            hidden_fw = torch.jit._unwrap_optional(hx_fw), torch.jit._unwrap_optional(cx_fw)
+        hidden_fw = hx_fw, cx_fw
         result_fw, hidden_fw = self.layer_fw(x, hidden_fw)
 
-        if hasattr(self, 'layer_bw') and self.bidirectional:
+        if self.bidirectional:
             x_reversed = x.flip(0)
             result_bw, hidden_bw = self.layer_bw(x_reversed, hidden_bw)
             result_bw = result_bw.flip(0)
@@ -191,7 +188,7 @@ class _LSTMLayer(torch.nn.Module):
             c = torch.stack([hidden_fw[1], hidden_bw[1]], 0)  # type: ignore[list-item]
         else:
             result = result_fw
-            h, c = torch.jit._unwrap_optional(hidden_fw)  # type: ignore[assignment]
+            h, c = hidden_fw  # type: ignore[assignment]
 
         if self.batch_first:
             result.transpose_(0, 1)
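The `torch.stack([...], 0)` calls in the hunk above merge the forward and backward hidden states along a new leading "direction" dimension. A tiny standalone sketch with made-up sizes (batch=3, hidden_size=16):

    import torch

    h_fw = torch.zeros(3, 16)          # forward-direction hidden state: (batch, hidden_size)
    h_bw = torch.zeros(3, 16)          # backward-direction hidden state
    h = torch.stack([h_fw, h_bw], 0)   # one slice per direction
    print(h.shape)                     # torch.Size([2, 3, 16])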
@@ -230,6 +227,46 @@ class _LSTMLayer(torch.nn.Module):
             layer.layer_bw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
         return layer
 
+    # Getters for the weights and biases
+    # Note that jit currently doesn't support `property`, so to access the
+    # weights/biases you need to navigate manually to `layer_fw.cell.igates.*`
+    # (see the sketch after this hunk): https://github.com/pytorch/pytorch/issues/37883
+    @property
+    def weight_ih(self):
+        return self.layer_fw.cell.igates.weight
+
+    @property
+    def weight_hh(self):
+        return self.layer_fw.cell.hgates.weight
+
+    @property
+    def bias_ih(self):
+        return self.layer_fw.cell.igates.bias
+
+    @property
+    def bias_hh(self):
+        return self.layer_fw.cell.hgates.bias
+
+    @property
+    def weight_ih_reverse(self):
+        assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer'
+        return self.layer_bw.cell.igates.weight
+
+    @property
+    def weight_hh_reverse(self):
+        assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer'
+        return self.layer_bw.cell.hgates.weight
+
+    @property
+    def bias_ih_reverse(self):
+        assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer'
+        return self.layer_bw.cell.igates.bias
+
+    @property
+    def bias_hh_reverse(self):
+        assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer'
+        return self.layer_bw.cell.hgates.bias
+
 
 class LSTM(torch.nn.Module):
     r"""A quantizable long short-term memory (LSTM).
@@ -325,12 +362,14 @@ class LSTM(torch.nn.Module):
                 cx = hidden_non_opt[1].reshape(self.num_layers, num_directions,
                                                max_batch_size,
                                                self.hidden_size).unbind(0)
-                hxcx = [(hx[idx].squeeze_(0), cx[idx].squeeze_(0)) for idx in range(self.num_layers)]
+                hxcx = []
+                for idx in range(self.num_layers):
+                    hxcx.append((hx[idx].squeeze_(0), cx[idx].squeeze_(0)))
             else:
                 hxcx = hidden_non_opt
 
-        for idx, layer in enumerate(self.layers):
-            x, hxcx[idx] = layer(x, hxcx[idx])
+        for idx in range(self.num_layers):
+            x, hxcx[idx] = self.layers[idx](x, hxcx[idx])
 
         hx_list = []
         cx_list = []
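For orientation, a short eager-mode usage sketch of the forward path this last hunk touches. It assumes `torch.nn.quantizable.LSTM` follows the constructor and return convention of `torch.nn.LSTM`; the concrete sizes are made up for illustration:

    import torch

    lstm = torch.nn.quantizable.LSTM(input_size=8, hidden_size=16, num_layers=2)
    x = torch.randn(5, 3, 8)   # (seq_len, batch, input_size); batch_first is False by default
    y, (h, c) = lstm(x)        # hidden defaults to None and is zero-initialized per layer
    # Expected shapes if the nn.LSTM convention holds:
    #   y: (5, 3, 16); h and c: (2, 3, 16) -- (num_layers, batch, hidden_size)
    print(y.shape, h.shape, c.shape)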