Improve pack_sequence and pack_padded_sequence error message (#16084)
authorRichard Zou <zou3519@gmail.com>
Fri, 18 Jan 2019 15:56:17 +0000 (07:56 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 18 Jan 2019 15:58:54 +0000 (07:58 -0800)
Summary:
Mention that if enforce_sorted=True, the user can set
enforce_sorted=False. This is a new flag that is probably hard to
discover unless one throughly reads the docs.

Fixes #15567
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16084

Differential Revision: D13701118

Pulled By: zou3519

fbshipit-source-id: c9aeb47ae9769d28b0051bcedb8f2f51a5a5c260

aten/src/ATen/native/PackedSequence.cpp
test/test_nn.py

index a414663..f4ac1da 100644 (file)
@@ -24,7 +24,13 @@ std::tuple<Tensor, Tensor> _pack_padded_sequence(const Tensor& _input, const Ten
            "in 'lengths' that is <= 0");
   for(auto i = 0; i < batch_size - 1; i++) {
     if (lengths[batch_size - 1 - i] > lengths[batch_size - 2 - i]) {
-      AT_ERROR("'lengths' array has to be sorted in decreasing order");
+      // NB: enforce_sorted is implemented at a Python level, but the sortedness
+      // check lives here. If enforce_sorted=False then this error should never
+      // get called.
+      AT_ERROR("`lengths` array must be sorted in decreasing order when "
+               "`enforce_sorted` is True. You can pass `enforce_sorted=False` "
+               "to pack_padded_sequence and/or pack_sequence to sidestep this "
+               "requirement if you do not need ONNX exportability.");
     }
   }
 
index cc966d0..82e740b 100644 (file)
@@ -4362,7 +4362,10 @@ class TestNN(NNTestCase):
         self.assertTrue(packed_enforce_sorted.sorted_indices is None)
         self.assertTrue(packed_enforce_sorted.unsorted_indices is None)
 
-        with self.assertRaisesRegex(RuntimeError, 'has to be sorted in decreasing order'):
+        with self.assertRaisesRegex(RuntimeError, 'must be sorted in decreasing order'):
+            rnn_utils.pack_sequence([b, c, a], enforce_sorted=True)
+
+        with self.assertRaisesRegex(RuntimeError, 'You can pass `enforce_sorted=False`'):
             rnn_utils.pack_sequence([b, c, a], enforce_sorted=True)
 
         # more dimensions
@@ -4456,6 +4459,10 @@ class TestNN(NNTestCase):
                 if l < 10:
                     self.assertEqual(padded.grad.data[l:, i].abs().sum(), 0)
 
+        # test error message
+        with self.assertRaisesRegex(RuntimeError, 'You can pass `enforce_sorted=False`'):
+            packed = rnn_utils.pack_padded_sequence(torch.randn(3, 3), [1, 3, 2])
+
     def _test_variable_sequence(self, device="cpu", dtype=torch.float):
         def pad(var, length):
             if var.size(0) == length: