[quant] Enable jit tracing on quantizable LSTM (resubmission) (#64638)
authorZafar Takhirov <zaf@fb.com>
Wed, 8 Sep 2021 20:32:29 +0000 (13:32 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 8 Sep 2021 20:34:18 +0000 (13:34 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64638

The quantizable LSTM didn't support jit tracing because it had several non taceable paths. We sacrifice some of the user experience to enable the tracing.
The main UX feature removed is a user-friendly message when trying to access the backwards path in a bidirectional LSTM: When the bidirectional flag is False, we used to throw a nice error message when the user tried accessing backwards weights. Now the message is default (removed properties).

Test Plan: `buck test mode/dev //caffe2/test:quantization -- test_custom_module_lstm`

Reviewed By: HDCharles

Differential Revision: D30803753

fbshipit-source-id: a639955a96cee22538d9436f1c952a5d121f50f9

test/quantization/core/test_quantized_op.py
torch/nn/quantizable/modules/rnn.py

index 49b7c96..6275174 100644 (file)
@@ -2476,6 +2476,13 @@ class TestQuantizedOps(TestCase):
                         msg=(f"Error is too high: SNR(dB): {power}, "
                              f"Signal: {signal}, MSE: {mse}"))
 
+                # Trace
+                jit_qmodule = torch.jit.trace(lstm_quantized, qx)
+
+                # Script
+                # TODO: Fix the scripting in the torch/nn/quantizable/modules/rnn.py
+                # jit_qmodule = torch.jit.script(lstm_quantized)
+
     @override_qengines
     def test_custom_module_multi_head_attention(self):
         class MultiheadAttentionModel(torch.nn.Module):
index bdfd778..7692261 100644 (file)
@@ -48,7 +48,7 @@ class LSTMCell(torch.nn.Module):
         self.ogate_cy = torch.nn.quantized.FloatFunctional()
 
     def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
-        if hidden is None or hidden == (None, None):
+        if hidden is None or hidden[0] is None or hidden[1] is None:
             hidden = self.initialize_hidden(x.shape[0], x.is_quantized)
         hx, cx = hidden
 
@@ -175,20 +175,33 @@ class _LSTMLayer(torch.nn.Module):
                 cx_bw = cx_fw[1]
                 cx_fw = cx_fw[0]
             hidden_bw = hx_bw, cx_bw
-        hidden_fw = hx_fw, cx_fw
+        if hx_fw is None and cx_fw is None:
+            hidden_fw = None
+        else:
+            hidden_fw = torch.jit._unwrap_optional(hx_fw), torch.jit._unwrap_optional(cx_fw)
         result_fw, hidden_fw = self.layer_fw(x, hidden_fw)
 
-        if self.bidirectional:
+        if hasattr(self, 'layer_bw') and self.bidirectional:
             x_reversed = x.flip(0)
             result_bw, hidden_bw = self.layer_bw(x_reversed, hidden_bw)
             result_bw = result_bw.flip(0)
 
             result = torch.cat([result_fw, result_bw], result_fw.dim() - 1)
-            h = torch.stack([hidden_fw[0], hidden_bw[0]], 0)  # type: ignore[list-item]
-            c = torch.stack([hidden_fw[1], hidden_bw[1]], 0)  # type: ignore[list-item]
+            if hidden_fw is None and hidden_bw is None:
+                h = None
+                c = None
+            elif hidden_fw is None:
+                h = hidden_bw[0]
+                c = hidden_bw[1]
+            elif hidden_bw is None:
+                h = hidden_fw[0]
+                c = hidden_fw[1]
+            else:
+                h = torch.stack([hidden_fw[0], hidden_bw[0]], 0)  # type: ignore[list-item]
+                c = torch.stack([hidden_fw[1], hidden_bw[1]], 0)  # type: ignore[list-item]
         else:
             result = result_fw
-            h, c = hidden_fw  # type: ignore[assignment]
+            h, c = torch.jit._unwrap_optional(hidden_fw)  # type: ignore[assignment]
 
         if self.batch_first:
             result.transpose_(0, 1)
@@ -227,46 +240,6 @@ class _LSTMLayer(torch.nn.Module):
             layer.layer_bw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
         return layer
 
-    # Getters for the weights and biases
-    # Note that jit currently doesn't support the `porperty`, so if you need to
-    # access the weights/biases you would need to navigate manually to the
-    # `layer_fw.cell.igates.*`: https://github.com/pytorch/pytorch/issues/37883
-    @property
-    def weight_ih(self):
-        return self.layer_fw.cell.igates.weight
-
-    @property
-    def weight_hh(self):
-        return self.layer_fw.cell.hgates.weight
-
-    @property
-    def bias_ih(self):
-        return self.layer_fw.cell.igates.bias
-
-    @property
-    def bias_hh(self):
-        return self.layer_fw.cell.hgates.bias
-
-    @property
-    def weight_ih_reverse(self):
-        assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer'
-        return self.layer_bw.cell.igates.weight
-
-    @property
-    def weight_hh_reverse(self):
-        assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer'
-        return self.layer_bw.cell.hgates.weight
-
-    @property
-    def bias_ih_reverse(self):
-        assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer'
-        return self.layer_bw.cell.igates.bias
-
-    @property
-    def bias_hh_reverse(self):
-        assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer'
-        return self.layer_bw.cell.hgates.bias
-
 
 class LSTM(torch.nn.Module):
     r"""A quantizable long short-term memory (LSTM).
@@ -362,14 +335,12 @@ class LSTM(torch.nn.Module):
                 cx = hidden_non_opt[1].reshape(self.num_layers, num_directions,
                                                max_batch_size,
                                                self.hidden_size).unbind(0)
-                hxcx = []
-                for idx in range(self.num_layers):
-                    hxcx.append((hx[idx].squeeze_(0), cx[idx].squeeze_(0)))
+                hxcx = [(hx[idx].squeeze_(0), cx[idx].squeeze_(0)) for idx in range(self.num_layers)]
             else:
                 hxcx = hidden_non_opt
 
-        for idx in range(self.num_layers):
-            x, hxcx[idx] = self.layers[idx](x, hxcx[idx])
+        for idx, layer in enumerate(self.layers):
+            x, hxcx[idx] = layer(x, hxcx[idx])
 
         hx_list = []
         cx_list = []