Add option to automatically handle unsorted variable-length sequences in RNNs (#15225)
author Richard Zou <zou3519@gmail.com>
Fri, 21 Dec 2018 01:34:41 +0000 (17:34 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 21 Dec 2018 01:37:18 +0000 (17:37 -0800)
Summary:
Fixes #3584.

Motivation: many users have complained about having to manually sort their
sequences, pack them, and then unsort the results, especially when the
library can do this for them.

Overview: we internally sort sequences before packing them and store
`unsorted_indices`, which records how to unsort the sequences, inside the
PackedSequence. The packing helper functions return a PackedSequence with
this field populated, and the unpacking helper functions use it to unsort.
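
For context, here is a minimal usage sketch of the new behavior (a hedged
illustration; the shapes, sizes, and variable names are arbitrary, not taken
from this PR):

```python
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils

# Three sequences in the user's (unsorted) order: lengths 2, 1, 3.
a = torch.randn(2, 5)
b = torch.randn(1, 5)
c = torch.randn(3, 5)

# No manual sorting needed: packing sorts internally and remembers how.
packed = rnn_utils.pack_sequence([a, b, c], enforce_sorted=False)

lstm = nn.LSTM(input_size=5, hidden_size=4)
packed_out, (h_n, c_n) = lstm(packed)

# Unpacking unsorts, so batch element i of `out` (and of h_n/c_n)
# corresponds to the i-th sequence as the user passed it in.
out, lengths = rnn_utils.pad_packed_sequence(packed_out)
print(lengths)  # tensor([2, 1, 3])
```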

To implement this, the following changes were made:
- PackedSequence now keeps `sorted_indices` and `unsorted_indices`.
  These two can be thought of as permutations and are inverses of each other
  (see the sketch after this list). `sorted_indices` is how the sequences
  were sorted; `unsorted_indices` is how to unsort the sequences.
- Added an `enforce_sorted` argument to pack_sequence and pack_padded_sequence
  that preserves the legacy behavior of erroring out on unsorted sequences.
  When `enforce_sorted=True`, these functions maintain their ONNX exportability.
- pack_sequence(sequences, enforce_sorted=False) accepts unsorted sequences.
- With `enforce_sorted=False`, pack_padded_sequence can take in a padded
  tensor that represents padded, unsorted sequences.
- pad_packed_sequence unsorts the PackedSequence such that it is still the
  inverse operation of pack_padded_sequence.
- RNNs apply `sorted_indices` to their input hidden state and apply
  `unsorted_indices` to their output hidden state. This ensures that the
  hidden-state batches correspond to the user's ordering of input sequences.
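
To illustrate the inverse-permutation relationship, here is a small sketch
that mirrors the `invert_permutation` helper added in torch/nn/utils/rnn.py
(the values follow the `pack_sequence([b, c, a])` case from the tests):

```python
import torch

# Packing sequences of lengths [2, 1, 3] sorts by decreasing length,
# visiting element 2 first, then 0, then 1:
sorted_indices = torch.tensor([2, 0, 1])

# Compute the inverse permutation, as invert_permutation() does.
unsorted_indices = torch.empty_like(sorted_indices)
unsorted_indices.scatter_(0, sorted_indices,
                          torch.arange(0, sorted_indices.numel()))
print(unsorted_indices)  # tensor([1, 2, 0])

# Composing the two permutations gives the identity, as inverses should.
assert torch.equal(sorted_indices[unsorted_indices], torch.arange(3))
```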

NOT BC-Breaking:
- The default for pack_sequence and pack_padded_sequence is
  `enforce_sorted=True` to avoid breaking ONNX export. To use the new
  functionality, pass in `enforce_sorted=False` (see the sketch below).
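
As a concrete sketch of what this means (mirroring the updated
test_wrong_order test; the tensor sizes are from that test):

```python
import torch
import torch.nn.utils.rnn as rnn_utils

a = torch.ones(25, 300)
b = torch.ones(22, 300)
padded = rnn_utils.pad_sequence([b, a])  # unsorted: lengths [22, 25]

# Default (enforce_sorted=True) keeps the legacy behavior: RuntimeError
# because the lengths are not sorted in decreasing order.
try:
    rnn_utils.pack_padded_sequence(padded, [22, 25])
except RuntimeError as e:
    print(e)

# New opt-in behavior: sorts internally and records the permutation.
packed = rnn_utils.pack_padded_sequence(padded, [22, 25],
                                        enforce_sorted=False)
print(packed.sorted_indices)    # tensor([1, 0])
print(packed.unsorted_indices)  # tensor([1, 0])
```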

Testing Plan:
- Modified TestNN.test_pack_sequence, TestNN.test_pack_padded_sequence,
  and TestNN.test_variable_sequence (RNN test) to check the behavior of
  unsorted sequences, sorted sequences, and sorted sequences with
  enforce_sorted=True.
- test/test_jit.py has a test that checks RNNs are still exportable with
  enforce_sorted=True.

cc colesbury
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15225

Reviewed By: soumith

Differential Revision: D13507138

Pulled By: zou3519

fbshipit-source-id: b871dccd6abefffca81bc4e3efef1873faa242ef

test/test_nn.py
torch/nn/modules/rnn.py
torch/nn/utils/rnn.py

diff --git a/test/test_nn.py b/test/test_nn.py
index 58a0ca8..bf37b98 100644
@@ -123,31 +123,36 @@ class PackedSequenceTest(TestCase):
         """Test type casting of `PackedSequence` against type casting of tensor"""
         for _, (input_type, _) in self._type_by_name.items():
             for expected_type_str, (_, cast_str) in self._type_by_name.items():
-                padded, lengths = self._padded_sequence(input_type)
-                packed = rnn_utils.pack_padded_sequence(padded, lengths)
-                # Apply cast to `PackedSequence` instance and unpack
-                masked = getattr(packed, cast_str)()
-                unpacked, lengths_out = rnn_utils.pad_packed_sequence(masked)
-                self.assertEqual(unpacked.type(), expected_type_str)
+                for enforce_sorted in [True, False]:
+                    padded, lengths = self._padded_sequence(input_type)
+                    packed = rnn_utils.pack_padded_sequence(
+                        padded, lengths, enforce_sorted=enforce_sorted)
+                    # Apply cast to `PackedSequence` instance and unpack
+                    masked = getattr(packed, cast_str)()
+                    unpacked, lengths_out = rnn_utils.pad_packed_sequence(masked)
+                    self.assertEqual(unpacked.type(), expected_type_str)
 
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     def test_cuda_mask(self):
-        tensor_type = torch.FloatTensor
-        cuda_type_str = 'torch.cuda.FloatTensor'
-        padded, lengths = self._padded_sequence(tensor_type)
-        packed = rnn_utils.pack_padded_sequence(padded, lengths)
-        self.assertFalse(packed.is_cuda)
-        packed = packed.cuda()
-        self.assertTrue(packed.is_cuda)
-        unpacked, _ = rnn_utils.pad_packed_sequence(packed)
-        self.assertEqual(unpacked.type(), cuda_type_str)
+        for enforce_sorted in [True, False]:
+            tensor_type = torch.FloatTensor
+            cuda_type_str = 'torch.cuda.FloatTensor'
+            padded, lengths = self._padded_sequence(tensor_type)
+            packed = rnn_utils.pack_padded_sequence(
+                padded, lengths, enforce_sorted=enforce_sorted)
+            self.assertFalse(packed.is_cuda)
+            packed = packed.cuda()
+            self.assertTrue(packed.is_cuda)
+            unpacked, _ = rnn_utils.pad_packed_sequence(packed)
+            self.assertEqual(unpacked.type(), cuda_type_str)
 
     def test_wrong_order(self):
-        # https://github.com/pytorch/pytorch/issues/13324
         a = torch.ones(25, 300)
         b = torch.ones(22, 300)
         b_a = rnn_utils.pad_sequence([b, a])
-        self.assertRaises(RuntimeError, lambda: rnn_utils.pack_padded_sequence(b_a, [22, 25]))
+        self.assertRaises(
+            RuntimeError,
+            lambda: rnn_utils.pack_padded_sequence(b_a, [22, 25], enforce_sorted=True))
 
     def test_total_length(self):
         padded, lengths = self._padded_sequence(torch.FloatTensor)
@@ -183,22 +188,24 @@ class PackedSequenceTest(TestCase):
                 self.assertEqual(unpacked, ref_output)
 
     def test_to(self):
-        padded, lengths = self._padded_sequence(torch.IntTensor)
-        a = rnn_utils.pack_padded_sequence(padded, lengths).cpu()
-
-        self.assertIs(a, a.to('cpu'))
-        self.assertIs(a, a.to('cpu', dtype=torch.int32))
-        self.assertEqual(a.long(), a.to(torch.int64))
-
-        if torch.cuda.is_available():
-            for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']:
-                b = a.cuda(device=cuda)
-                self.assertIs(b, b.to(cuda))
-                self.assertEqual(a, b.to('cpu'))
-                self.assertEqual(b, a.to(cuda))
-                self.assertEqual(a, b.to('cpu', dtype=torch.int32))
-                self.assertIs(b, b.to(dtype=torch.int32))
-                self.assertEqual(b.long(), b.to(dtype=torch.int64))
+        for enforce_sorted in (True, False):
+            padded, lengths = self._padded_sequence(torch.IntTensor)
+            a = rnn_utils.pack_padded_sequence(
+                padded, lengths, enforce_sorted=enforce_sorted).cpu()
+
+            self.assertIs(a, a.to('cpu'))
+            self.assertIs(a, a.to('cpu', dtype=torch.int32))
+            self.assertEqual(a.long(), a.to(torch.int64))
+
+            if torch.cuda.is_available():
+                for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']:
+                    b = a.cuda(device=cuda)
+                    self.assertIs(b, b.to(cuda))
+                    self.assertEqual(a, b.to('cpu'))
+                    self.assertEqual(b, a.to(cuda))
+                    self.assertEqual(a, b.to('cpu', dtype=torch.int32))
+                    self.assertIs(b, b.to(dtype=torch.int32))
+                    self.assertEqual(b.long(), b.to(dtype=torch.int64))
 
 
 def default_tensor_type(type):
@@ -4203,22 +4210,41 @@ class TestNN(NNTestCase):
             self.assertEqual(padded, expected.transpose(0, 1))
 
     def test_pack_sequence(self):
-        def _compatibility_test(sequences, lengths, batch_first):
+        def _compatibility_test(sequences, lengths, batch_first, enforce_sorted=False):
             padded = rnn_utils.pad_sequence(sequences, batch_first)
-            packed = rnn_utils.pack_sequence(sequences)
+            packed = rnn_utils.pack_sequence(sequences, enforce_sorted)
             unpacked = rnn_utils.pad_packed_sequence(packed, batch_first)
             self.assertEqual(padded, unpacked[0])
-            pack_padded = rnn_utils.pack_padded_sequence(padded, lengths, batch_first)
+            pack_padded = rnn_utils.pack_padded_sequence(
+                padded, lengths, batch_first, enforce_sorted)
             self.assertEqual(packed, pack_padded)
 
         # single dimensional
         a = torch.tensor([1, 2, 3])
         b = torch.tensor([4, 5])
         c = torch.tensor([6])
-        packed = rnn_utils.pack_sequence([a, b, c])
+        packed = rnn_utils.pack_sequence([a, b, c], enforce_sorted=False)
         expected = torch.tensor([1, 4, 6, 2, 5, 3])
         self.assertEqual(packed.batch_sizes, [3, 2, 1])
         self.assertEqual(packed.data.data, expected)
+        self.assertEqual(packed.sorted_indices, [0, 1, 2])
+        self.assertEqual(packed.unsorted_indices, [0, 1, 2])
+
+        packed_unsorted = rnn_utils.pack_sequence([b, c, a], enforce_sorted=False)
+        self.assertEqual(packed_unsorted.batch_sizes, [3, 2, 1])
+        self.assertEqual(packed_unsorted.data.data, expected)
+        self.assertEqual(packed_unsorted.sorted_indices, [2, 0, 1])
+        self.assertEqual(packed_unsorted.unsorted_indices, [1, 2, 0])
+
+        # single dimensional, enforce_sorted = True
+        packed_enforce_sorted = rnn_utils.pack_sequence([a, b, c], enforce_sorted=True)
+        self.assertEqual(packed_enforce_sorted.batch_sizes, [3, 2, 1])
+        self.assertEqual(packed_enforce_sorted.data.data, expected)
+        self.assertTrue(packed_enforce_sorted.sorted_indices is None)
+        self.assertTrue(packed_enforce_sorted.unsorted_indices is None)
+
+        with self.assertRaisesRegex(RuntimeError, 'has to be sorted in decreasing order'):
+            rnn_utils.pack_sequence([b, c, a], enforce_sorted=True)
 
         # more dimensions
         maxlen = 9
@@ -4230,34 +4256,69 @@ class TestNN(NNTestCase):
                 seq_len = i * i
                 lengths.append(seq_len)
                 sequences.append(torch.rand(seq_len, 5, *trailing_dims))
+            unsorted_sequences = [s.clone() for s in sequences]
+            random.shuffle(unsorted_sequences)
+            unsorted_sequences_lengths = [t.size(0) for t in unsorted_sequences]
 
             # compatibility with other utilities
             for batch_first in (True, False):
-                _compatibility_test(sequences, lengths, batch_first)
+                for enforce_sorted in (True, False):
+                    _compatibility_test(sequences, lengths, batch_first, enforce_sorted)
+                _compatibility_test(unsorted_sequences, unsorted_sequences_lengths,
+                                    batch_first)
 
     def test_pack_padded_sequence(self):
-        def pad(tensor, length):
-            return torch.cat([tensor, tensor.new(length - tensor.size(0), *tensor.size()[1:]).zero_()])
-        lengths = [10, 8, 4, 2, 2, 2, 1]
-        max_length = lengths[0]
-        batch_sizes = [sum(map(bool, filter(lambda x: x >= i, lengths))) for i in range(1, max_length + 1)]
-        offset = 0
-        padded = torch.cat([pad(i * 100 + torch.arange(1., 5 * l + 1).view(l, 1, 5), max_length)
-                            for i, l in enumerate(lengths, 1)], 1).requires_grad_()
-        expected_data = [[torch.arange(1., 6) + (i + 1) * 100 + 5 * n for i in range(batch_size)]
-                         for n, batch_size in enumerate(batch_sizes)]
-        expected_data = list(itertools.chain.from_iterable(expected_data))
-        expected_data = torch.stack(expected_data, dim=0)
+        def generate_test_case(sorted_lengths, should_shuffle):
+            def pad(tensor, length):
+                return torch.cat([tensor, tensor.new(length - tensor.size(0), *tensor.size()[1:]).zero_()])
+
+            max_length = sorted_lengths[0]
+            batch_sizes = [sum(map(bool, filter(lambda x: x >= i, sorted_lengths)))
+                           for i in range(1, max_length + 1)]
+            offset = 0
+            padded = torch.cat([pad(i * 100 + torch.arange(1., 5 * l + 1).view(l, 1, 5), max_length)
+                                for i, l in enumerate(sorted_lengths, 1)], 1)
+            expected_data = [[torch.arange(1., 6) + (i + 1) * 100 + 5 * n for i in range(batch_size)]
+                             for n, batch_size in enumerate(batch_sizes)]
+            expected_data = list(itertools.chain.from_iterable(expected_data))
+            expected_data = torch.stack(expected_data, dim=0)
+
+            if should_shuffle:
+                # Shuffle the padded sequence to create an unsorted sequence
+                permutation = list(range(len(sorted_lengths)))
+                random.shuffle(permutation)
+
+                unsorted_indices = torch.tensor(permutation)
+                padded = padded.index_select(1, unsorted_indices)
+                lengths = torch.tensor(sorted_lengths).index_select(0, unsorted_indices)
+            else:
+                unsorted_indices = None
+                lengths = sorted_lengths
+
+            return padded.requires_grad_(), lengths, expected_data, batch_sizes, unsorted_indices
+
+        test_cases = [
+            # sorted_lengths, should_shuffle
+            [[10, 8, 4, 2, 2, 2, 1], False],
+            [[11, 10, 8, 6, 4, 3, 1], False],
+            [[11, 10, 8, 6, 4, 3, 1], True],
+        ]
+
+        for test_case, batch_first in itertools.product(test_cases, (True, False)):
+            sorted_lengths, should_shuffle = test_case
+            padded, lengths, expected_data, batch_sizes, unsorted_indices = generate_test_case(
+                sorted_lengths, should_shuffle)
 
-        for batch_first in (True, False):
             src = padded
             if batch_first:
                 src = src.transpose(0, 1)
 
             # check output
-            packed = rnn_utils.pack_padded_sequence(src, lengths, batch_first=batch_first)
+            packed = rnn_utils.pack_padded_sequence(src, lengths, batch_first=batch_first,
+                                                    enforce_sorted=not should_shuffle)
             self.assertEqual(packed.data.data, expected_data)
             self.assertEqual(packed.batch_sizes, batch_sizes)
+            self.assertEqual(packed.unsorted_indices, unsorted_indices)
 
             # test inverse
             unpacked, unpacked_len = rnn_utils.pad_packed_sequence(packed, batch_first=batch_first)
@@ -4282,46 +4343,80 @@ class TestNN(NNTestCase):
                 return var
             return torch.cat([var, var.new_zeros(length - var.size(0), *var.size()[1:])])
 
-        lengths = [10, 10, 6, 2, 2, 1, 1]
-        max_length = lengths[0]
-        x_leaf = torch.randn(max_length, len(lengths), 3, device=device, dtype=dtype, requires_grad=True)
-        lstm = nn.LSTM(3, 4, bidirectional=True, num_layers=2).to(device, dtype)
-        lstm2 = deepcopy(lstm).to(device, dtype)
-        x = x_leaf
-
-        # Compute sequences separately
-        seq_outs = []
-        seq_hiddens = []
-        for i, l in enumerate(lengths):
-            out, hid = lstm2(x[:l, i:i + 1])
-            out_pad = pad(out, max_length)
-            seq_outs.append(out_pad)
-            seq_hiddens.append(hid)
-        seq_out = torch.cat(seq_outs, 1)
-        seq_hidden = tuple(torch.cat(hids, 1) for hids in zip(*seq_hiddens))
-
-        # Use packed format
-        packed = rnn_utils.pack_padded_sequence(x, lengths)
-        packed_out, packed_hidden = lstm(packed)
-        unpacked, unpacked_len = rnn_utils.pad_packed_sequence(packed_out)
+        def maybe_index_tuple(maybe_tuple_of_tensors, index):
+            if maybe_tuple_of_tensors is None:
+                return None
+            return tuple(maybe_tuple_of_tensors[j][:, index:index + 1, :].contiguous()
+                         for j in range(2))
+
+        def check_lengths(lengths, enforce_sorted, use_default_hiddens):
+            input_size = 3
+            hidden_size = 4
+            num_layers = 2
+            bidirectional = True
+
+            max_length = max(lengths)
+            x_leaf = torch.randn(max_length, len(lengths), input_size, device=device,
+                                 dtype=dtype, requires_grad=True)
+            num_directions = 2 if bidirectional else 1
+            lstm = nn.LSTM(input_size, hidden_size, bidirectional=bidirectional,
+                           num_layers=num_layers).to(device, dtype)
+            lstm2 = deepcopy(lstm).to(device, dtype)
+            x = x_leaf
+
+            hidden0 = None
+            if not use_default_hiddens:
+                hidden0 = tuple(torch.randn(num_directions * num_layers, len(lengths), hidden_size,
+                                            device=device, dtype=dtype)
+                                for _ in range(2))
+
+            # Compute sequences separately
+            seq_outs = []
+            seq_hiddens = []
+            for i, l in enumerate(lengths):
+                hidden_i = maybe_index_tuple(hidden0, i)
+                out, hid = lstm2(x[:l, i:i + 1], hidden_i)
+                out_pad = pad(out, max_length)
+                seq_outs.append(out_pad)
+                seq_hiddens.append(hid)
+            seq_out = torch.cat(seq_outs, 1)
+            seq_hidden = tuple(torch.cat(hids, 1) for hids in zip(*seq_hiddens))
+
+            # Use packed format
+            packed = rnn_utils.pack_padded_sequence(x, lengths, enforce_sorted=enforce_sorted)
+            packed_out, packed_hidden = lstm(packed, hidden0)
+            unpacked, unpacked_len = rnn_utils.pad_packed_sequence(packed_out)
+
+            # Check forward
+            self.assertEqual(packed_hidden, seq_hidden)
+            self.assertEqual(unpacked, seq_out)
+            self.assertEqual(unpacked_len, lengths)
 
-        # Check forward
-        self.assertEqual(packed_hidden, seq_hidden)
-        self.assertEqual(unpacked, seq_out)
-        self.assertEqual(unpacked_len, lengths)
-
-        # Check backward
-        seq_out.sum().backward()
-        grad_x = x_leaf.grad.data.clone()
-        x_leaf.grad.data.zero_()
-        unpacked.sum().backward()
-
-        self.assertEqual(x_leaf.grad, grad_x, dtype2prec[dtype])
-        for p1, p2 in zip(lstm.parameters(), lstm2.parameters()):
-            prec = dtype2prec[dtype]
-            if dtype == torch.float16:
-                prec = 2e-2
-            self.assertEqual(p1.grad, p2.grad, prec)
+            # Check backward
+            seq_out.sum().backward()
+            grad_x = x_leaf.grad.data.clone()
+            x_leaf.grad.data.zero_()
+            unpacked.sum().backward()
+
+            self.assertEqual(x_leaf.grad, grad_x, dtype2prec[dtype])
+            for p1, p2 in zip(lstm.parameters(), lstm2.parameters()):
+                prec = dtype2prec[dtype]
+                if dtype == torch.float16:
+                    prec = 2e-2
+                self.assertEqual(p1.grad, p2.grad, prec)
+
+        tests = [
+            # enforce_sorted, lengths
+            [True, [5]],
+            [False, [5]],
+            [True, [10, 10, 6, 2, 2, 1, 1]],
+            [False, [10, 10, 6, 2, 2, 1, 1]],
+            [False, [2, 1, 3, 2, 10, 5, 3]],
+        ]
+
+        for enforce_sorted, seq_lens, in tests:
+            for use_default_hiddens in (True, False):
+                check_lengths(seq_lens, enforce_sorted, use_default_hiddens)
 
     def test_variable_sequence(self):
         self._test_variable_sequence()
diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py
index 3db8ebf..fa3f2c3 100644
@@ -19,6 +19,10 @@ _rnn_impls = {
 }
 
 
+def apply_permutation(tensor, permutation, dim=1):
+    return tensor.index_select(dim, permutation)
+
+
 class RNNBase(Module):
 
     def __init__(self, mode, input_size, hidden_size,
@@ -156,14 +160,24 @@ class RNNBase(Module):
         else:
             check_hidden_size(hidden, expected_hidden_size)
 
+    def permute_hidden(self, hx, permutation):
+        if permutation is None:
+            return hx
+        if self.mode == 'LSTM':
+            return tuple(apply_permutation(state, permutation) for state in hx)
+        else:
+            return apply_permutation(hx, permutation)
+
     def forward(self, input, hx=None):
         is_packed = isinstance(input, PackedSequence)
         if is_packed:
-            input, batch_sizes = input
+            input, batch_sizes, sorted_indices, unsorted_indices = input
             max_batch_size = int(batch_sizes[0])
         else:
             batch_sizes = None
             max_batch_size = input.size(0) if self.batch_first else input.size(1)
+            sorted_indices = None
+            unsorted_indices = None
 
         if hx is None:
             num_directions = 2 if self.bidirectional else 1
@@ -172,6 +186,10 @@ class RNNBase(Module):
                                  requires_grad=False)
             if self.mode == 'LSTM':
                 hx = (hx, hx)
+        else:
+            # Each batch of the hidden state should match the input sequence that
+            # the user believes he/she is passing in.
+            hx = self.permute_hidden(hx, sorted_indices)
 
         self.check_forward_args(input, hx, batch_sizes)
         _impl = _rnn_impls[self.mode]
@@ -185,8 +203,8 @@ class RNNBase(Module):
         hidden = result[1:] if self.mode == 'LSTM' else result[1]
 
         if is_packed:
-            output = PackedSequence(output, batch_sizes)
-        return output, hidden
+            output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
+        return output, self.permute_hidden(hidden, unsorted_indices)
 
     def extra_repr(self):
         s = '{input_size}, {hidden_size}'
diff --git a/torch/nn/utils/rnn.py b/torch/nn/utils/rnn.py
index c5ad556..c394667 100644
@@ -4,7 +4,14 @@ import warnings
 import torch
 
 
-PackedSequence_ = namedtuple('PackedSequence', ['data', 'batch_sizes'])
+PackedSequence_ = namedtuple('PackedSequence',
+                             ['data', 'batch_sizes', 'sorted_indices', 'unsorted_indices'])
+
+
+def bind(optional, fn):
+    if optional is None:
+        return None
+    return fn(optional)
 
 
 class PackedSequence(PackedSequence_):
@@ -28,65 +35,86 @@ class PackedSequence(PackedSequence_):
             information about the batch size at each sequence step
 
     """
-    def __new__(cls, data, batch_sizes=None):
+    def __new__(cls, data, batch_sizes=None, sorted_indices=None, unsorted_indices=None):
         # PackedSequence used to only have __init__(self, data, batch_sizes)
         # without a __new__ like this. So to preserve BC for calling in keyword
         # arg style (e.g., `PackedSequence(data=..., batch_sizes=...)`), we have
         # to provide two arguments with exact names `data` and `batch_sizes`.
-        #
-        # support being called as `PackedSequence(data, batch_sizes)`
+
+        # NB: if unsorted_indices is provided, it should be the inverse permutation
+        # to sorted_indices. Don't assert it here because the PackedSequence ctor
+        # should only be used internally.
+        if unsorted_indices is None:
+            unsorted_indices = invert_permutation(sorted_indices)
+
+        # support being called as `PackedSequence(data, batch_sizes, sorted_indices)`
         if batch_sizes is not None:
-            return super(PackedSequence, cls).__new__(cls, data, batch_sizes)
-        # support being called as `PackedSequence((data, batch_sizes))`
+            return super(PackedSequence, cls).__new__(
+                cls, data, batch_sizes, sorted_indices, unsorted_indices)
+
+        # support being called as `PackedSequence((data, batch_sizes), *, sorted_indices)`
         else:
             assert isinstance(data, (list, tuple)) and len(data) == 2
-            return super(PackedSequence, cls).__new__(cls, *data)
+            return super(PackedSequence, cls).__new__(
+                cls, data[0], data[1], sorted_indices)
 
     def cuda(self, *args, **kwargs):
         """Returns a GPU copy if `self.data` not already on the GPU"""
         if self.is_cuda:
             return self
         else:
-            return type(self)(self.data.cuda(*args, **kwargs), self.batch_sizes)
+            return type(self)(self.data.cuda(*args, **kwargs), self.batch_sizes,
+                              bind(self.sorted_indices, lambda t: t.cuda(*args, **kwargs)),
+                              bind(self.unsorted_indices, lambda t: t.cuda(*args, **kwargs)))
 
     def cpu(self):
         """Returns a CPU copy if `self.data` not already on the CPU"""
         if self.is_cuda:
-            return type(self)(self.data.cpu(), self.batch_sizes)
+            return type(self)(self.data.cpu(), self.batch_sizes,
+                              bind(self.sorted_indices, lambda t: t.cpu()),
+                              bind(self.unsorted_indices, lambda t: t.cpu()))
         else:
             return self
 
     def double(self):
         r"""Returns copy with `self.data` cast to double type"""
-        return type(self)(self.data.double(), self.batch_sizes)
+        return type(self)(self.data.double(), self.batch_sizes,
+                          self.sorted_indices, self.unsorted_indices)
 
     def float(self):
         r"""Returns copy with `self.data` cast to float type"""
-        return type(self)(self.data.float(), self.batch_sizes)
+        return type(self)(self.data.float(), self.batch_sizes,
+                          self.sorted_indices, self.unsorted_indices)
 
     def half(self):
         r"""Returns copy with `self.data` cast to half type"""
-        return type(self)(self.data.half(), self.batch_sizes)
+        return type(self)(self.data.half(), self.batch_sizes,
+                          self.sorted_indices, self.unsorted_indices)
 
     def long(self):
         r"""Returns copy with `self.data` cast to long type"""
-        return type(self)(self.data.long(), self.batch_sizes)
+        return type(self)(self.data.long(), self.batch_sizes,
+                          self.sorted_indices, self.unsorted_indices)
 
     def int(self):
         r"""Returns copy with `self.data` cast to int type"""
-        return type(self)(self.data.int(), self.batch_sizes)
+        return type(self)(self.data.int(), self.batch_sizes,
+                          self.sorted_indices, self.unsorted_indices)
 
     def short(self):
         r"""Returns copy with `self.data` cast to short type"""
-        return type(self)(self.data.short(), self.batch_sizes)
+        return type(self)(self.data.short(), self.batch_sizes,
+                          self.sorted_indices, self.unsorted_indices)
 
     def char(self):
         r"""Returns copy with `self.data` cast to char type"""
-        return type(self)(self.data.char(), self.batch_sizes)
+        return type(self)(self.data.char(), self.batch_sizes,
+                          self.sorted_indices, self.unsorted_indices)
 
     def byte(self):
         r"""Returns copy with `self.data` cast to byte type"""
-        return type(self)(self.data.byte(), self.batch_sizes)
+        return type(self)(self.data.byte(), self.batch_sizes,
+                          self.sorted_indices, self.unsorted_indices)
 
     def to(self, *args, **kwargs):
         r"""Performs dtype and/or device conversion on `self.data`.
@@ -100,10 +128,17 @@ class PackedSequence(PackedSequence_):
             Otherwise, returns a copy with the desired configuration.
         """
         data = self.data.to(*args, **kwargs)
+        sorted_indices = self.sorted_indices
+        unsorted_indices = self.unsorted_indices
+        device_kw = 'device'
+        if device_kw in kwargs:
+            sorted_indices = bind(sorted_indices, lambda t: t.to(kwargs[device_kw]))
+            unsorted_indices = bind(unsorted_indices, lambda t: t.to(kwargs[device_kw]))
         if data is self.data:
             return self
         else:
-            return type(self)(data, self.batch_sizes)
+            return type(self)(data, self.batch_sizes,
+                              sorted_indices, unsorted_indices)
 
     @property
     def is_cuda(self):
@@ -111,7 +146,16 @@ class PackedSequence(PackedSequence_):
         return self.data.is_cuda
 
 
-def pack_padded_sequence(input, lengths, batch_first=False):
+def invert_permutation(permutation):
+    if permutation is None:
+        return None
+    output = torch.empty_like(permutation)
+    output.scatter_(0, permutation,
+                    torch.arange(0, permutation.numel(), device=permutation.device))
+    return output
+
+
+def pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True):
     r"""Packs a Tensor containing padded sequences of variable length.
 
     Input can be of size ``T x B x *`` where `T` is the length of the longest sequence
@@ -119,9 +163,10 @@ def pack_padded_sequence(input, lengths, batch_first=False):
     dimensions (including 0). If ``batch_first`` is True ``B x T x *`` inputs are
     expected.
 
-    The sequences should be sorted by length in a decreasing order, i.e.
-    ``input[:,0]`` should be the longest sequence, and ``input[:,B-1]`` the
-    shortest one.
+    For unsorted sequences, use `enforce_sorted = False`. If ``enforce_sorted`` is
+    ``True``, the sequences should be sorted by length in a decreasing order, i.e.
+    ``input[:,0]`` should be the longest sequence, and ``input[:,B-1]`` the shortest
+    one. `enforce_sorted = True` is only necessary for ONNX export.
 
     Note:
         This function accepts any input that has at least two dimensions. You
@@ -134,6 +179,9 @@ def pack_padded_sequence(input, lengths, batch_first=False):
         lengths (Tensor): list of sequences lengths of each batch element.
         batch_first (bool, optional): if ``True``, the input is expected in ``B x T x *``
             format.
+        enforce_sorted (bool, optional): if ``True``, the input is expected to
+            contain sequences sorted by length in a decreasing order. If
+            ``False``, this condition is not checked. Default: ``True``.
 
     Returns:
         a :class:`PackedSequence` object
@@ -145,7 +193,17 @@ def pack_padded_sequence(input, lengths, batch_first=False):
                       'the trace incorrect for any other combination of lengths.',
                       category=torch.jit.TracerWarning, stacklevel=2)
     lengths = torch.as_tensor(lengths, dtype=torch.int64)
-    return PackedSequence(torch._C._VariableFunctions._pack_padded_sequence(input, lengths, batch_first))
+    if enforce_sorted:
+        sorted_indices = None
+    else:
+        lengths, sorted_indices = torch.sort(lengths, descending=True)
+        sorted_indices = sorted_indices.to(input.device)
+        batch_dim = 0 if batch_first else 1
+        input = input.index_select(batch_dim, sorted_indices)
+
+    data, batch_sizes = \
+        torch._C._VariableFunctions._pack_padded_sequence(input, lengths, batch_first)
+    return PackedSequence(data, batch_sizes, sorted_indices)
 
 
 def pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_length=None):
@@ -189,8 +247,13 @@ def pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_length=None):
                              "total_length={} and max sequence length being {}"
                              .format(total_length, max_seq_length))
         max_seq_length = total_length
-    return torch._C._VariableFunctions._pad_packed_sequence(
+    padded_output, lengths = torch._C._VariableFunctions._pad_packed_sequence(
         sequence.data, sequence.batch_sizes, batch_first, padding_value, max_seq_length)
+    if sequence.unsorted_indices is not None:
+        batch_dim = 0 if batch_first else 1
+        return padded_output.index_select(batch_dim, sequence.unsorted_indices), \
+            lengths[sequence.unsorted_indices]
+    return padded_output, lengths
 
 
 def pad_sequence(sequences, batch_first=False, padding_value=0):
@@ -252,12 +315,17 @@ def pad_sequence(sequences, batch_first=False, padding_value=0):
     return out_tensor
 
 
-def pack_sequence(sequences):
+def pack_sequence(sequences, enforce_sorted=True):
     r"""Packs a list of variable length Tensors
 
     ``sequences`` should be a list of Tensors of size ``L x *``, where `L` is
     the length of a sequence and `*` is any number of trailing dimensions,
-    including zero. They should be sorted in the order of decreasing length.
+    including zero.
+
+    For unsorted sequences, use `enforce_sorted = False`. If ``enforce_sorted``
+    is ``True``, the sequences should be sorted in the order of decreasing length.
+    ``enforce_sorted = True`` is only necessary for ONNX export.
+
 
     Example:
         >>> from torch.nn.utils.rnn import pack_sequence
@@ -270,8 +338,12 @@ def pack_sequence(sequences):
 
     Arguments:
         sequences (list[Tensor]): A list of sequences of decreasing length.
+        enforce_sorted (bool, optional): if ``True``, checks that the input
+            contains sequences sorted by length in a decreasing order. If
+            ``False``, this condition is not checked. Default: ``True``.
 
     Returns:
         a :class:`PackedSequence` object
     """
-    return pack_padded_sequence(pad_sequence(sequences), [v.size(0) for v in sequences])
+    lengths = [v.size(0) for v in sequences]
+    return pack_padded_sequence(pad_sequence(sequences), lengths, enforce_sorted=enforce_sorted)