disallow empty named dims list to flatten(names, name) (#61953)
authorMatti Picus <matti.picus@gmail.com>
Wed, 1 Sep 2021 01:54:44 +0000 (18:54 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 1 Sep 2021 02:32:30 +0000 (19:32 -0700)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/61137 by raising an error if an empty tuple is passed in for the names:
```
>>> torch.empty((2, 3), names=['a', 'b']).flatten((), 'abc')
RuntimeError: flatten(tensor, dims, out_dim): dims cannot be empty
```

or from the original issue:
```
>>> torch.empty((2, 3)).flatten((), 'abc')
RuntimeError: flatten(tensor, dims, out_dim): dims cannot be empty
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61953

Reviewed By: iramazanli

Differential Revision: D30574571

Pulled By: malfet

fbshipit-source-id: e606e84458a8dd66e5da6d0eb1a260f37b4ce91b

aten/src/ATen/native/TensorShape.cpp
test/test_namedtensor.py

index 1dc2a27..edbfa23 100644 (file)
@@ -2042,6 +2042,8 @@ Tensor flatten(const Tensor& self, Dimname start_dim, Dimname end_dim, Dimname o
 
 Tensor flatten(const Tensor& self, DimnameList dims, Dimname out_dim) {
   auto positions = dimnames_to_positions(self, dims);
+  TORCH_CHECK(positions.size() > 0,
+      "flatten(tensor, dims, out_dim): dims cannot be empty");
   for (const auto i : c10::irange(positions.size() - 1)) {
     if (positions[i] + 1 == positions[i + 1]) continue;
     TORCH_CHECK(positions[i] + 1 == positions[i + 1],
index b5e7aac..2c6d2d8 100644 (file)
@@ -1072,6 +1072,11 @@ class TestNamedTensor(TestCase):
         with self.assertRaisesRegex(RuntimeError, "must be consecutive in"):
             tensor.flatten(['H', 'D', 'W'], 'features')
 
+    def test_flatten_nodims(self):
+        tensor = torch.empty((2, 3))
+        with self.assertRaisesRegex(RuntimeError, "cannot be empty"):
+            tensor.flatten((), 'abcd')
+
     def test_unflatten(self):
         # test args: tensor, int, namedshape
         self.assertTrue(torch.equal(