Disable .numpy() and .tolist() for tensor subclasses subclasses and fix .tolist(...
authoranjali411 <chourdiaanjali123@gmail.com>
Thu, 14 Oct 2021 20:16:03 +0000 (16:16 -0400)
committerGitHub <noreply@github.com>
Thu, 14 Oct 2021 20:16:03 +0000 (13:16 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66082

Fixes https://github.com/pytorch/pytorch/issues/66024 #65779

cc ezyang anjali411 dylanbespalko mruberry Lezcano nikitaved albanD

Test Plan: Imported from OSS

Reviewed By: Gamrix, albanD

Differential Revision: D31615588

Pulled By: anjali411

fbshipit-source-id: c3e65ef0fe301630eb76732ccd7819683c09aa19

test/test_python_dispatch.py
test/test_torch.py
torch/csrc/utils/tensor_list.cpp
torch/csrc/utils/tensor_numpy.cpp
torch/testing/_internal/common_utils.py

index 38af5bb..c211a9b 100644 (file)
@@ -448,6 +448,15 @@ $6 = torch._ops.aten.add_($1, $5)''')
             with enable_python_mode(LoggingTensor):
                 with enable_python_mode(LoggingTensor):
                     pass
-
+    
+    def test_tolist_numpy_with_python_mode(self) -> None:
+        x = LoggingTensor(torch.tensor([2.0, 3.0]))
+        with self.assertRaisesRegex(RuntimeError, "is not supported for tensor subclasses."):
+            x.tolist()
+        with self.assertRaisesRegex(RuntimeError, "is not supported for tensor subclasses."):
+            x.numpy()
+        with self.assertRaises(AssertionError):
+            self.assertEqual(x, None)
+            
 if __name__ == '__main__':
     run_tests()
index 3c47fd8..a43f0ac 100644 (file)
@@ -8399,6 +8399,13 @@ class TestTorch(AbstractTestCases._TestTorchMixin):
         finally:
             torch.set_num_threads(num_threads)
 
+    def test_conj_neg_tolist(self):
+        x = torch.randn(2, dtype=torch.cfloat)
+        y1 = x.conj()
+        y1_expect = x.conj_physical()
+        y2 = y1.imag
+        self.assertEqual(y1, y1_expect.tolist())
+        self.assertEqual(y2, y1_expect.imag.tolist())
 
 # TODO: these empy classes are temporarily instantiated for XLA compatibility
 #   once XLA updates their test suite it should be removed
index 7948734..1cde3a1 100644 (file)
@@ -30,7 +30,8 @@ static PyObject* recursive_to_list(
 }
 
 PyObject* tensor_to_list(const Tensor& tensor) {
-  Tensor data = tensor;
+  TORCH_CHECK(!tensor.unsafeGetTensorImpl()->is_python_dispatch(), ".tolist() is not supported for tensor subclasses.");
+  Tensor data = tensor.resolve_conj().resolve_neg();
   if (!data.device().is_cpu()) {
     pybind11::gil_scoped_release no_gil;
     data = data.toBackend(Backend::CPU);
index 433d1e2..e507ffb 100644 (file)
@@ -130,6 +130,8 @@ PyObject* tensor_to_numpy(const at::Tensor& tensor) {
       "Can't call numpy() on Tensor that has negative bit set. "
       "Use tensor.resolve_neg().numpy() instead.");
 
+  TORCH_CHECK(!tensor.unsafeGetTensorImpl()->is_python_dispatch(), ".numpy() is not supported for tensor subclasses.");
+
   auto dtype = aten_to_numpy_dtype(tensor.scalar_type());
   auto sizes = to_numpy_shape(tensor.sizes());
   auto strides = to_numpy_shape(tensor.strides());
index b32c3a1..b4ef796 100644 (file)
@@ -1781,8 +1781,10 @@ class TestCase(expecttest.TestCase):
         assert (atol is None) == (rtol is None), "If one of atol or rtol is specified, then the other must be too"
         debug_msg: Optional[str] = None
 
+        if x is None or y is None:
+            self.assertTrue(x is None and y is None)
         # Tensor x Number and Number x Tensor comparisons
-        if isinstance(x, torch.Tensor) and isinstance(y, Number):
+        elif isinstance(x, torch.Tensor) and isinstance(y, Number):
             self.assertEqual(x.item(), y, atol=atol, rtol=rtol, msg=msg,
                              exact_dtype=exact_dtype, exact_device=exact_device)
         elif isinstance(y, torch.Tensor) and isinstance(x, Number):