From e161872aab00f3ca347ea32b972aab53660fc382 Mon Sep 17 00:00:00 2001
From: David Riazati
Date: Thu, 2 Sep 2021 16:58:59 -0700
Subject: [PATCH] Revert D30732630: [quant] Enable jit tracing on quantizable LSTM

Test Plan: revert-hammer

Differential Revision:
D30732630 (https://github.com/pytorch/pytorch/commit/116142143cc2d66c7e582d9f96e00862456fd736)

Original commit changeset: 443e351ebb0e

fbshipit-source-id: 49001392f01366f3b1ccc31139f824c80b86cd40
---
 test/quantization/core/test_quantized_op.py |  7 ----
 torch/nn/quantizable/modules/rnn.py         | 59 ++++++++++++++++++++++++-----
 2 files changed, 49 insertions(+), 17 deletions(-)

diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py
index 6275174..49b7c96 100644
--- a/test/quantization/core/test_quantized_op.py
+++ b/test/quantization/core/test_quantized_op.py
@@ -2476,13 +2476,6 @@ class TestQuantizedOps(TestCase):
                             msg=(f"Error is too high: SNR(dB): {power}, "
                                  f"Signal: {signal}, MSE: {mse}"))
 
-        # Trace
-        jit_qmodule = torch.jit.trace(lstm_quantized, qx)
-
-        # Script
-        # TODO: Fix the scripting in the torch/nn/quantizable/modules/rnn.py
-        # jit_qmodule = torch.jit.script(lstm_quantized)
-
     @override_qengines
     def test_custom_module_multi_head_attention(self):
         class MultiheadAttentionModel(torch.nn.Module):
diff --git a/torch/nn/quantizable/modules/rnn.py b/torch/nn/quantizable/modules/rnn.py
index cd0d094..bdfd778 100644
--- a/torch/nn/quantizable/modules/rnn.py
+++ b/torch/nn/quantizable/modules/rnn.py
@@ -48,7 +48,7 @@ class LSTMCell(torch.nn.Module):
         self.ogate_cy = torch.nn.quantized.FloatFunctional()
 
     def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
-        if hidden is None or hidden[0] is None or hidden[1] is None:
+        if hidden is None or hidden == (None, None):
             hidden = self.initialize_hidden(x.shape[0], x.is_quantized)
         hx, cx = hidden
 
@@ -175,13 +175,10 @@ class _LSTMLayer(torch.nn.Module):
             cx_bw = cx_fw[1]
             cx_fw = cx_fw[0]
             hidden_bw = hx_bw, cx_bw
-        if hx_fw is None and cx_fw is None:
-            hidden_fw = None
-        else:
-            hidden_fw = torch.jit._unwrap_optional(hx_fw), torch.jit._unwrap_optional(cx_fw)
+        hidden_fw = hx_fw, cx_fw
         result_fw, hidden_fw = self.layer_fw(x, hidden_fw)
 
-        if hasattr(self, 'layer_bw') and self.bidirectional:
+        if self.bidirectional:
             x_reversed = x.flip(0)
             result_bw, hidden_bw = self.layer_bw(x_reversed, hidden_bw)
             result_bw = result_bw.flip(0)
@@ -191,7 +188,7 @@ class _LSTMLayer(torch.nn.Module):
             c = torch.stack([hidden_fw[1], hidden_bw[1]], 0)  # type: ignore[list-item]
         else:
             result = result_fw
-            h, c = torch.jit._unwrap_optional(hidden_fw)  # type: ignore[assignment]
+            h, c = hidden_fw  # type: ignore[assignment]
 
         if self.batch_first:
             result.transpose_(0, 1)
@@ -230,6 +227,46 @@ class _LSTMLayer(torch.nn.Module):
             layer.layer_bw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
         return layer
 
+    # Getters for the weights and biases
+    # Note that jit currently doesn't support the `property`, so if you need to
+    # access the weights/biases you would need to navigate manually to the
+    # `layer_fw.cell.igates.*`: https://github.com/pytorch/pytorch/issues/37883
+    @property
+    def weight_ih(self):
+        return self.layer_fw.cell.igates.weight
+
+    @property
+    def weight_hh(self):
+        return self.layer_fw.cell.hgates.weight
+
+    @property
+    def bias_ih(self):
+        return self.layer_fw.cell.igates.bias
+
+    @property
+    def bias_hh(self):
+        return self.layer_fw.cell.hgates.bias
+
+    @property
+    def weight_ih_reverse(self):
+        assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer'
+        return self.layer_bw.cell.igates.weight
+
+    @property
+    def weight_hh_reverse(self):
+        assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer'
+        return self.layer_bw.cell.hgates.weight
+
+    @property
+    def bias_ih_reverse(self):
+        assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer'
+        return self.layer_bw.cell.igates.bias
+
+    @property
+    def bias_hh_reverse(self):
+        assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer'
+        return self.layer_bw.cell.hgates.bias
+
 
 class LSTM(torch.nn.Module):
     r"""A quantizable long short-term memory (LSTM).
@@ -325,12 +362,14 @@ class LSTM(torch.nn.Module):
             cx = hidden_non_opt[1].reshape(self.num_layers, num_directions,
                                            max_batch_size,
                                            self.hidden_size).unbind(0)
-            hxcx = [(hx[idx].squeeze_(0), cx[idx].squeeze_(0)) for idx in range(self.num_layers)]
+            hxcx = []
+            for idx in range(self.num_layers):
+                hxcx.append((hx[idx].squeeze_(0), cx[idx].squeeze_(0)))
         else:
             hxcx = hidden_non_opt
 
-        for idx, layer in enumerate(self.layers):
-            x, hxcx[idx] = layer(x, hxcx[idx])
+        for idx in range(self.num_layers):
+            x, hxcx[idx] = self.layers[idx](x, hxcx[idx])
 
         hx_list = []
         cx_list = []
-- 
2.7.4
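
Note (illustrative sketch, not part of the patch): the rnn.py hunk above restores
`@property` accessors (`weight_ih`, `weight_hh`, `bias_ih`, `bias_hh` and their
`_reverse` counterparts), and the restored comment points out that jit does not
support `property`, so scripted code has to navigate to `layer_fw.cell.igates.*`
directly. The stand-in classes below (`Gates`, `Layer`) are hypothetical names,
not the patched `LSTMCell`/`_LSTMLayer`; they only demonstrate the two access
patterns under that assumption.

    import torch

    class Gates(torch.nn.Module):
        # Stand-in for the cell that owns the input/hidden projections.
        def __init__(self, input_dim: int, hidden_dim: int):
            super().__init__()
            self.igates = torch.nn.Linear(input_dim, 4 * hidden_dim)
            self.hgates = torch.nn.Linear(hidden_dim, 4 * hidden_dim)

        def forward(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
            return self.igates(x) + self.hgates(h)

    class Layer(torch.nn.Module):
        # Stand-in for a layer that nests its weights under `cell`, mirroring
        # the `layer_fw.cell.igates` / `layer_fw.cell.hgates` layout in the patch.
        def __init__(self, input_dim: int, hidden_dim: int):
            super().__init__()
            self.cell = Gates(input_dim, hidden_dim)

        # Convenient in eager mode; per the comment restored by the patch,
        # jit does not support `property`, so scripted code should reach
        # `self.cell.igates.weight` directly instead.
        @property
        def weight_ih(self) -> torch.Tensor:
            return self.cell.igates.weight

        def forward(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
            return self.cell(x, h)

    layer = Layer(4, 8)
    print(layer.weight_ih.shape)           # eager access through the property
    print(layer.cell.igates.weight.shape)  # jit-friendly manual navigation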