"in 'lengths' that is <= 0");
for(auto i = 0; i < batch_size - 1; i++) {
if (lengths[batch_size - 1 - i] > lengths[batch_size - 2 - i]) {
- AT_ERROR("'lengths' array has to be sorted in decreasing order");
+ // NB: enforce_sorted is implemented at a Python level, but the sortedness
+ // check lives here. If enforce_sorted=False then this error should never
+ // get called.
+ AT_ERROR("`lengths` array must be sorted in decreasing order when "
+ "`enforce_sorted` is True. You can pass `enforce_sorted=False` "
+ "to pack_padded_sequence and/or pack_sequence to sidestep this "
+ "requirement if you do not need ONNX exportability.");
}
}
self.assertTrue(packed_enforce_sorted.sorted_indices is None)
self.assertTrue(packed_enforce_sorted.unsorted_indices is None)
- with self.assertRaisesRegex(RuntimeError, 'has to be sorted in decreasing order'):
+ with self.assertRaisesRegex(RuntimeError, 'must be sorted in decreasing order'):
+ rnn_utils.pack_sequence([b, c, a], enforce_sorted=True)
+
+ with self.assertRaisesRegex(RuntimeError, 'You can pass `enforce_sorted=False`'):
rnn_utils.pack_sequence([b, c, a], enforce_sorted=True)
# more dimensions
if l < 10:
self.assertEqual(padded.grad.data[l:, i].abs().sum(), 0)
+ # test error message
+ with self.assertRaisesRegex(RuntimeError, 'You can pass `enforce_sorted=False`'):
+ packed = rnn_utils.pack_padded_sequence(torch.randn(3, 3), [1, 3, 2])
+
def _test_variable_sequence(self, device="cpu", dtype=torch.float):
def pad(var, length):
if var.size(0) == length: