self.ogate_cy = torch.nn.quantized.FloatFunctional()
def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
- if hidden is None or hidden[0] is None or hidden[1] is None:
+ if hidden is None or hidden == (None, None):
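# No hidden state was provided: build zero-initialized (h, c) tensors,
# quantized to match the input when the input itself is quantized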
hidden = self.initialize_hidden(x.shape[0], x.is_quantized)
hx, cx = hidden
cx_bw = cx_fw[1]
cx_fw = cx_fw[0]
hidden_bw = hx_bw, cx_bw
- if hx_fw is None and cx_fw is None:
- hidden_fw = None
- else:
- hidden_fw = torch.jit._unwrap_optional(hx_fw), torch.jit._unwrap_optional(cx_fw)
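+ # The hidden state is fully initialized by this point, so the (h, c)
+ # tuple can be packed directly without torch.jit._unwrap_optional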
+ hidden_fw = hx_fw, cx_fw
result_fw, hidden_fw = self.layer_fw(x, hidden_fw)
- if hasattr(self, 'layer_bw') and self.bidirectional:
+ if self.bidirectional:
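# Reverse direction: flip the input along the time dimension, run the
# backward layer, and flip its output back so it lines up with result_fw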
x_reversed = x.flip(0)
result_bw, hidden_bw = self.layer_bw(x_reversed, hidden_bw)
result_bw = result_bw.flip(0)
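# Stack the final forward and backward cell states along a new
# leading (direction) dimension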
c = torch.stack([hidden_fw[1], hidden_bw[1]], 0) # type: ignore[list-item]
else:
result = result_fw
- h, c = torch.jit._unwrap_optional(hidden_fw) # type: ignore[assignment]
+ h, c = hidden_fw # type: ignore[assignment]
if self.batch_first:
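# Restore the (batch, seq, feature) layout for batch_first inputs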
result.transpose_(0, 1)
layer.layer_bw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
return layer
+ # Getters for the weights and biases
+ # Note that jit currently doesn't support `property`, so if you need to
+ # access the weights/biases you would need to navigate manually to
+ # `layer_fw.cell.igates.*`: https://github.com/pytorch/pytorch/issues/37883
+ @property
+ def weight_ih(self):
+ return self.layer_fw.cell.igates.weight
+
+ @property
+ def weight_hh(self):
+ return self.layer_fw.cell.hgates.weight
+
+ @property
+ def bias_ih(self):
+ return self.layer_fw.cell.igates.bias
+
+ @property
+ def bias_hh(self):
+ return self.layer_fw.cell.hgates.bias
+
+ @property
+ def weight_ih_reverse(self):
+ assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer'
+ return self.layer_bw.cell.igates.weight
+
+ @property
+ def weight_hh_reverse(self):
+ assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer'
+ return self.layer_bw.cell.hgates.weight
+
+ @property
+ def bias_ih_reverse(self):
+ assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer'
+ return self.layer_bw.cell.igates.bias
+
+ @property
+ def bias_hh_reverse(self):
+ assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer'
+ return self.layer_bw.cell.hgates.bias
+
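+ # A minimal eager-mode usage sketch (the `_LSTMLayer(input_dim, hidden_dim,
+ # ..., bidirectional=True)` constructor shape here is an assumption; under
+ # jit, go through `layer_fw.cell.igates.*` as noted above):
+ #
+ #   layer = _LSTMLayer(3, 7, bidirectional=True)
+ #   w_ih = layer.weight_ih              # forward input-hidden weights
+ #   w_ih_rev = layer.weight_ih_reverse  # backward input-hidden weights
+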
class LSTM(torch.nn.Module):
r"""A quantizable long short-term memory (LSTM).
cx = hidden_non_opt[1].reshape(self.num_layers, num_directions,
max_batch_size,
self.hidden_size).unbind(0)
- hxcx = [(hx[idx].squeeze_(0), cx[idx].squeeze_(0)) for idx in range(self.num_layers)]
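+ # Collect the per-layer (hx, cx) pairs with an explicit loop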
+ hxcx = []
+ for idx in range(self.num_layers):
+ hxcx.append((hx[idx].squeeze_(0), cx[idx].squeeze_(0)))
else:
hxcx = hidden_non_opt
- for idx, layer in enumerate(self.layers):
- x, hxcx[idx] = layer(x, hxcx[idx])
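+ # Thread the input through the stacked layers, updating each
+ # layer's hidden state as we go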
+ for idx in range(self.num_layers):
+ x, hxcx[idx] = self.layers[idx](x, hxcx[idx])
hx_list = []
cx_list = []