self.ogate_cy = torch.nn.quantized.FloatFunctional()
def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
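+ # Either element of `hidden` may independently be None, so the state must
+ # be (re)initialized unless both tensors are present; a plain comparison
+ # against (None, None) does not cover the mixed case.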
- if hidden is None or hidden == (None, None):
+ if hidden is None or hidden[0] is None or hidden[1] is None:
hidden = self.initialize_hidden(x.shape[0], x.is_quantized)
hx, cx = hidden
cx_bw = cx_fw[1]
cx_fw = cx_fw[0]
hidden_bw = hx_bw, cx_bw
- hidden_fw = hx_fw, cx_fw
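+ # hx_fw and cx_fw are Optional[Tensor] at this point; _unwrap_optional
+ # refines them to Tensor for TorchScript and raises if a value is None.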
+ if hx_fw is None and cx_fw is None:
+ hidden_fw = None
+ else:
+ hidden_fw = torch.jit._unwrap_optional(hx_fw), torch.jit._unwrap_optional(cx_fw)
result_fw, hidden_fw = self.layer_fw(x, hidden_fw)
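+ # `layer_bw` is only registered when the layer is bidirectional; the
+ # hasattr guard lets TorchScript avoid compiling this branch for modules
+ # that never define the backward sub-layer.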
- if self.bidirectional:
+ if hasattr(self, 'layer_bw') and self.bidirectional:
x_reversed = x.flip(0)
result_bw, hidden_bw = self.layer_bw(x_reversed, hidden_bw)
result_bw = result_bw.flip(0)
result = torch.cat([result_fw, result_bw], result_fw.dim() - 1)
- h = torch.stack([hidden_fw[0], hidden_bw[0]], 0) # type: ignore[list-item]
- c = torch.stack([hidden_fw[1], hidden_bw[1]], 0) # type: ignore[list-item]
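+ # Each returned state is typed Optional in TorchScript, so every None
+ # combination is handled explicitly; indexing an Optional tuple directly
+ # would not type-check, hence torch.jit._unwrap_optional below.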
+ if hidden_fw is None and hidden_bw is None:
+ h = None
+ c = None
+ elif hidden_fw is None:
+     h, c = torch.jit._unwrap_optional(hidden_bw)
+ elif hidden_bw is None:
+     h, c = torch.jit._unwrap_optional(hidden_fw)
+ else:
+ h = torch.stack([hidden_fw[0], hidden_bw[0]], 0) # type: ignore[list-item]
+ c = torch.stack([hidden_fw[1], hidden_bw[1]], 0) # type: ignore[list-item]
else:
result = result_fw
- h, c = hidden_fw # type: ignore[assignment]
+ h, c = torch.jit._unwrap_optional(hidden_fw) # type: ignore[assignment]
if self.batch_first:
result.transpose_(0, 1)
layer.layer_bw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
return layer
- # Getters for the weights and biases
- # Note that jit currently doesn't support `@property`, so if you need to
- # access the weights/biases you have to navigate manually to
- # `layer_fw.cell.igates.*`: https://github.com/pytorch/pytorch/issues/37883
- @property
- def weight_ih(self):
- return self.layer_fw.cell.igates.weight
-
- @property
- def weight_hh(self):
- return self.layer_fw.cell.hgates.weight
-
- @property
- def bias_ih(self):
- return self.layer_fw.cell.igates.bias
-
- @property
- def bias_hh(self):
- return self.layer_fw.cell.hgates.bias
-
- @property
- def weight_ih_reverse(self):
- assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer'
- return self.layer_bw.cell.igates.weight
-
- @property
- def weight_hh_reverse(self):
- assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer'
- return self.layer_bw.cell.hgates.weight
-
- @property
- def bias_ih_reverse(self):
- assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer'
- return self.layer_bw.cell.igates.bias
-
- @property
- def bias_hh_reverse(self):
- assert self.bidirectional, 'There is no reverse path in the non-bidirectional layer'
- return self.layer_bw.cell.hgates.bias
-
class LSTM(torch.nn.Module):
r"""A quantizable long short-term memory (LSTM).
cx = hidden_non_opt[1].reshape(self.num_layers, num_directions,
max_batch_size,
self.hidden_size).unbind(0)
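+ # reshape(...).unbind(0) splits the flat (num_layers * num_directions,
+ # batch, hidden_size) state into one (num_directions, batch, hidden_size)
+ # slice per layer.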
- hxcx = []
- for idx in range(self.num_layers):
- hxcx.append((hx[idx].squeeze_(0), cx[idx].squeeze_(0)))
+ hxcx = [(hx[idx].squeeze_(0), cx[idx].squeeze_(0)) for idx in range(self.num_layers)]
else:
hxcx = hidden_non_opt
- for idx in range(self.num_layers):
- x, hxcx[idx] = self.layers[idx](x, hxcx[idx])
+ for idx, layer in enumerate(self.layers):
+ x, hxcx[idx] = layer(x, hxcx[idx])
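+ # Collect each layer's final (h, c) so they can be re-stacked into the
+ # standard (num_layers * num_directions, batch, hidden_size) state tensors.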
hx_list = []
cx_list = []