From: Zafar Takhirov Date: Wed, 8 Sep 2021 20:32:29 +0000 (-0700) Subject: [quant] Enable jit tracing on quantizable LSTM (resubmission) (#64638) X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~368 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=24e1315d4b0b4d064472868afe8ae492e25887b8;p=platform%2Fupstream%2Fpytorch.git [quant] Enable jit tracing on quantizable LSTM (resubmission) (#64638) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64638 The quantizable LSTM didn't support jit tracing because it had several non taceable paths. We sacrifice some of the user experience to enable the tracing. The main UX feature removed is a user-friendly message when trying to access the backwards path in a bidirectional LSTM: When the bidirectional flag is False, we used to throw a nice error message when the user tried accessing backwards weights. Now the message is default (removed properties). Test Plan: `buck test mode/dev //caffe2/test:quantization -- test_custom_module_lstm` Reviewed By: HDCharles Differential Revision: D30803753 fbshipit-source-id: a639955a96cee22538d9436f1c952a5d121f50f9 --- diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index 49b7c96..6275174 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -2476,6 +2476,13 @@ class TestQuantizedOps(TestCase): msg=(f"Error is too high: SNR(dB): {power}, " f"Signal: {signal}, MSE: {mse}")) + # Trace + jit_qmodule = torch.jit.trace(lstm_quantized, qx) + + # Script + # TODO: Fix the scripting in the torch/nn/quantizable/modules/rnn.py + # jit_qmodule = torch.jit.script(lstm_quantized) + @override_qengines def test_custom_module_multi_head_attention(self): class MultiheadAttentionModel(torch.nn.Module): diff --git a/torch/nn/quantizable/modules/rnn.py b/torch/nn/quantizable/modules/rnn.py index bdfd778..7692261 100644 --- a/torch/nn/quantizable/modules/rnn.py +++ b/torch/nn/quantizable/modules/rnn.py @@ -48,7 +48,7 @@ class LSTMCell(torch.nn.Module): self.ogate_cy = torch.nn.quantized.FloatFunctional() def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]: - if hidden is None or hidden == (None, None): + if hidden is None or hidden[0] is None or hidden[1] is None: hidden = self.initialize_hidden(x.shape[0], x.is_quantized) hx, cx = hidden @@ -175,20 +175,33 @@ class _LSTMLayer(torch.nn.Module): cx_bw = cx_fw[1] cx_fw = cx_fw[0] hidden_bw = hx_bw, cx_bw - hidden_fw = hx_fw, cx_fw + if hx_fw is None and cx_fw is None: + hidden_fw = None + else: + hidden_fw = torch.jit._unwrap_optional(hx_fw), torch.jit._unwrap_optional(cx_fw) result_fw, hidden_fw = self.layer_fw(x, hidden_fw) - if self.bidirectional: + if hasattr(self, 'layer_bw') and self.bidirectional: x_reversed = x.flip(0) result_bw, hidden_bw = self.layer_bw(x_reversed, hidden_bw) result_bw = result_bw.flip(0) result = torch.cat([result_fw, result_bw], result_fw.dim() - 1) - h = torch.stack([hidden_fw[0], hidden_bw[0]], 0) # type: ignore[list-item] - c = torch.stack([hidden_fw[1], hidden_bw[1]], 0) # type: ignore[list-item] + if hidden_fw is None and hidden_bw is None: + h = None + c = None + elif hidden_fw is None: + h = hidden_bw[0] + c = hidden_bw[1] + elif hidden_bw is None: + h = hidden_fw[0] + c = hidden_fw[1] + else: + h = torch.stack([hidden_fw[0], hidden_bw[0]], 0) # type: ignore[list-item] + c = torch.stack([hidden_fw[1], hidden_bw[1]], 0) # type: ignore[list-item] else: result = result_fw - h, c = hidden_fw # type: ignore[assignment] + h, c = torch.jit._unwrap_optional(hidden_fw) # type: ignore[assignment] if self.batch_first: result.transpose_(0, 1) @@ -227,46 +240,6 @@ class _LSTMLayer(torch.nn.Module): layer.layer_bw = _LSTMSingleLayer.from_params(wi, wh, bi, bh) return layer - # Getters for the weights and biases - # Note that jit currently doesn't support the `porperty`, so if you need to - # access the weights/biases you would need to navigate manually to the - # `layer_fw.cell.igates.*`: https://github.com/pytorch/pytorch/issues/37883 - @property - def weight_ih(self): - return self.layer_fw.cell.igates.weight - - @property - def weight_hh(self): - return self.layer_fw.cell.hgates.weight - - @property - def bias_ih(self): - return self.layer_fw.cell.igates.bias - - @property - def bias_hh(self): - return self.layer_fw.cell.hgates.bias - - @property - def weight_ih_reverse(self): - assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer' - return self.layer_bw.cell.igates.weight - - @property - def weight_hh_reverse(self): - assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer' - return self.layer_bw.cell.hgates.weight - - @property - def bias_ih_reverse(self): - assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer' - return self.layer_bw.cell.igates.bias - - @property - def bias_hh_reverse(self): - assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer' - return self.layer_bw.cell.hgates.bias - class LSTM(torch.nn.Module): r"""A quantizable long short-term memory (LSTM). @@ -362,14 +335,12 @@ class LSTM(torch.nn.Module): cx = hidden_non_opt[1].reshape(self.num_layers, num_directions, max_batch_size, self.hidden_size).unbind(0) - hxcx = [] - for idx in range(self.num_layers): - hxcx.append((hx[idx].squeeze_(0), cx[idx].squeeze_(0))) + hxcx = [(hx[idx].squeeze_(0), cx[idx].squeeze_(0)) for idx in range(self.num_layers)] else: hxcx = hidden_non_opt - for idx in range(self.num_layers): - x, hxcx[idx] = self.layers[idx](x, hxcx[idx]) + for idx, layer in enumerate(self.layers): + x, hxcx[idx] = layer(x, hxcx[idx]) hx_list = [] cx_list = []