Improve error message w/ size inference on empty tensors
authorSsnL <tongzhou.wang.1994@gmail.com>
Wed, 20 Feb 2019 16:58:49 +0000 (08:58 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 20 Feb 2019 17:12:26 +0000 (09:12 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17255

Differential Revision: D14143094

Pulled By: soumith

fbshipit-source-id: f96fa7f8eb6eaac72887d3e837546cbfa505f101

aten/src/ATen/InferSize.h
test/test_torch.py

index 907a867..d001bbd 100644 (file)
@@ -28,9 +28,17 @@ static std::vector<int64_t> infer_size(IntArrayRef shape, int64_t numel) {
 
   if (numel == newsize || (infer_dim && newsize > 0 && numel % newsize == 0)) {
     if (infer_dim) {
-      // we have a degree of freedom here to select the dimension size; follow NumPy semantics
-      // and just bail.
-      AT_CHECK(newsize != 0, "cannot reshape tensor of 0 elements into shape ", shape);
+      // We have a degree of freedom here to select the dimension size; follow
+      // NumPy semantics and just bail.  However, a nice error message is needed
+      // because users often use `view` as a way to flatten & unflatten
+      // dimensions and will otherwise be confused why
+      //   empty_tensor.view( 0, 0)
+      // works yet
+      //   empty_tensor.view(-1, 0)
+      // doesn't.
+      AT_CHECK(newsize != 0, "cannot reshape tensor of 0 elements into shape ",
+               shape, " because the unspecified dimension size -1 can be any "
+               "value and is ambiguous");
       res[*infer_dim] = numel / newsize;
     }
     return res;
index cc76a22..3d95d82 100644 (file)
@@ -7186,7 +7186,7 @@ class _TestTorchMixin(object):
     def _test_view(self, cast):
         tensor = cast(torch.rand(15))
         template = cast(torch.rand(3, 5))
-        empty = cast(torch.Tensor())
+        empty = cast(torch.empty(0))
         target = template.size()
         self.assertEqual(tensor.view_as(template).size(), target)
         self.assertEqual(tensor.view(3, 5).size(), target)
@@ -7197,9 +7197,23 @@ class _TestTorchMixin(object):
         tensor_view.fill_(random.uniform(0, 1))
         self.assertEqual(empty.view_as(empty), empty)
         self.assertEqual(empty.view(0), empty)
+        self.assertEqual(empty.view(0, 3, 0, 1).size(), torch.Size([0, 3, 0, 1]))
+        self.assertEqual(empty.view(0, 3, 0, 1).view(0), empty)
+
+        # test size inference with empty tensors
+        self.assertEqual(empty.view(-1).size(), torch.Size([0]))
+        self.assertEqual(empty.view(10, 3, -1).size(), torch.Size([10, 3, 0]))
+
+        with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"):
+            empty.view(-1, 0)
+
+        with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"):
+            empty.view(3, 0, -1, 0)
+
         self.assertRaises(RuntimeError, lambda: tensor.view(15, 0))
         self.assertRaises(RuntimeError, lambda: tensor.view(7, -1))
         self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1))
+
         # test view when tensor is not contiguous in every dimension, but only
         # contiguous dimensions are touched.
         tensor = cast(torch.rand(4, 2, 5, 1, 6, 2, 9, 3)).transpose(-1, 2).transpose(-2, 3)
@@ -7236,7 +7250,7 @@ class _TestTorchMixin(object):
         self.assertRaises(RuntimeError, lambda: tensor.view(8, 3, 54, 2, 1, 5))
 
         # view with stride 0 dims
-        tensor = cast(torch.Tensor(1, 1)).expand(3, 4)  # all dims are contiguous
+        tensor = cast(torch.empty(1, 1)).expand(3, 4)  # all dims are contiguous
         contig_tensor = tensor.clone()
         self.assertEqual(tensor.view(-1), contig_tensor.view(-1))
         self.assertEqual(tensor.view(1, -1, 1), contig_tensor.view(1, -1, 1))