From 6bb4b5d150ab51ed15d15ed270471848bb84d4e3 Mon Sep 17 00:00:00 2001 From: Matti Picus Date: Tue, 31 Aug 2021 18:54:44 -0700 Subject: [PATCH] disallow empty named dims list to flatten(names, name) (#61953) Summary: Fixes https://github.com/pytorch/pytorch/issues/61137 by raising an error if an empty tuple is passed in for the names: ``` >>> torch.empty((2, 3), names=['a', 'b']).flatten((), 'abc') RuntimeError: flatten(tensor, dims, out_dim): dims cannot be empty ``` or from the original issue: ``` >>> torch.empty((2, 3)).flatten((), 'abc') RuntimeError: flatten(tensor, dims, out_dim): dims cannot be empty ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/61953 Reviewed By: iramazanli Differential Revision: D30574571 Pulled By: malfet fbshipit-source-id: e606e84458a8dd66e5da6d0eb1a260f37b4ce91b --- aten/src/ATen/native/TensorShape.cpp | 2 ++ test/test_namedtensor.py | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 1dc2a27..edbfa23 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -2042,6 +2042,8 @@ Tensor flatten(const Tensor& self, Dimname start_dim, Dimname end_dim, Dimname o Tensor flatten(const Tensor& self, DimnameList dims, Dimname out_dim) { auto positions = dimnames_to_positions(self, dims); + TORCH_CHECK(positions.size() > 0, + "flatten(tensor, dims, out_dim): dims cannot be empty"); for (const auto i : c10::irange(positions.size() - 1)) { if (positions[i] + 1 == positions[i + 1]) continue; TORCH_CHECK(positions[i] + 1 == positions[i + 1], diff --git a/test/test_namedtensor.py b/test/test_namedtensor.py index b5e7aac..2c6d2d8 100644 --- a/test/test_namedtensor.py +++ b/test/test_namedtensor.py @@ -1072,6 +1072,11 @@ class TestNamedTensor(TestCase): with self.assertRaisesRegex(RuntimeError, "must be consecutive in"): tensor.flatten(['H', 'D', 'W'], 'features') + def test_flatten_nodims(self): + tensor = torch.empty((2, 3)) + with self.assertRaisesRegex(RuntimeError, "cannot be empty"): + tensor.flatten((), 'abcd') + def test_unflatten(self): # test args: tensor, int, namedshape self.assertTrue(torch.equal( -- 2.7.4