Add more error checking in subclass creation (#64746)
author Alban Desmaison <albandes@fb.com>
Fri, 10 Sep 2021 20:07:37 +0000 (13:07 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 10 Sep 2021 23:49:10 +0000 (16:49 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64746

This extracts the error checking that used to be in the PR above.
We are not going to land the proposed fix there, but I think we want this error checking in right now, as the two unchecked cases would lead to a memory leak and an arbitrary memory read/write, respectively.
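
For illustration, a minimal Python sketch (not part of this PR) of the two failure modes these checks catch. `Foo` and `NotATensor` are made-up classes, and importing `LoggingTensor` straight from the test module below is an assumption made only for this sketch:

    import torch
    # Assumption: LoggingTensor is the wrapper subclass defined in
    # test/test_python_dispatch.py (see the diff below).
    from test_python_dispatch import LoggingTensor

    class Foo(torch.Tensor):
        pass

    class NotATensor:  # deliberately does not inherit from torch.Tensor
        pass

    # 1) Memory-leak case: the detach() inside _make_subclass is intercepted by
    #    LoggingTensor.__torch_dispatch__ and returns a tensor whose TensorImpl
    #    is already bound to a LoggingTensor python object; silently overwriting
    #    that pyobj field would leak it, so the new check raises instead.
    try:
        torch.Tensor._make_subclass(Foo, LoggingTensor(torch.rand(2)))
    except RuntimeError as e:
        print(e)  # ... already associated to a python object of type LoggingTensor

    # 2) Arbitrary read/write case: reinterpreting an instance of a non-Tensor
    #    class as a THPVariable* would touch memory beyond the object's actual
    #    layout, so the new check raises instead.
    try:
        torch.rand(2).as_subclass(NotATensor)
    except RuntimeError as e:
        print(e)  # Creating a Tensor subclass from a class that does not inherit from Tensor ...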

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D30867569

Pulled By: albanD

fbshipit-source-id: bf468033fb8b49fcb26eed423f5fad82b4a46c56

test/test_python_dispatch.py
test/test_torch.py
torch/csrc/autograd/python_variable.cpp

test/test_python_dispatch.py
index 0f5b6b9..cd09488 100644 (file)
@@ -23,6 +23,7 @@ def no_dispatch() -> Iterator[None]:
 # 3. Enter dispatcher, wind your way through Autograd
 # 4. Hit Python dispatch key, call __torch_dispatch__
 
+WRAPPER_DEVICE = "meta"
 # TODO: TensorBase should work
 class LoggingTensor(torch.Tensor):
     elem: torch.Tensor
@@ -34,7 +35,7 @@ class LoggingTensor(torch.Tensor):
         # The wrapping tensor (LoggingTensor) is just a meta tensor, so it
         # doesn't hold any memory (meta tensor is generally the preferred type
         # of tensor you want to make a subclass from)...
-        r = torch.Tensor._make_subclass(cls, elem.to('meta'), elem.requires_grad)
+        r = torch.Tensor._make_subclass(cls, elem.to(WRAPPER_DEVICE), elem.requires_grad)
         # ...the real tensor is held as an element on the tensor.
         r.elem = elem
         return r
@@ -335,6 +336,38 @@ $4 = torch._ops.aten.mul($3, tensor(2))
 $5 = torch._ops.aten.mul($4, $0)
 $6 = torch._ops.aten.add_($1, $5)''')
 
+    def test_subclass_creation(self):
+        # Make sure subclass creation from a wrapped tensor raises a clean
+        # error: the internal detach returns a LoggingTensor whose TensorImpl
+        # is already associated with a python object of a different type.
+        class Foo(torch.Tensor):
+            pass
+
+        err_msg = "subclass Foo but.*already associated to a python object of type LoggingTensor"
+        with self.assertRaisesRegex(RuntimeError, err_msg):
+            a = torch.Tensor._make_subclass(Foo, LoggingTensor(torch.rand(2)))
+        with self.assertRaisesRegex(RuntimeError, err_msg):
+            b = LoggingTensor(torch.rand(2)).as_subclass(Foo)
+
+        # And in the case where we don't know whether the user wants this
+        # subclass overwritten, raise a nice error.
+        # The standard LoggingTensor will fail because it is not on the right device
+        with self.assertRaisesRegex(TypeError, "expected.*device=cpu.*device=meta"):
+            Foo(LoggingTensor(torch.rand(2)))
+
+        # And if we put it on the right device, we still get a nice error
+        try:
+            global WRAPPER_DEVICE
+            prev_device = WRAPPER_DEVICE
+            WRAPPER_DEVICE = "cpu"
+
+            err_msg = "Creating a new Tensor subclass Foo.*python object of type LoggingTensor"
+            with self.assertRaisesRegex(RuntimeError, err_msg):
+                Foo(LoggingTensor(torch.rand(2)))
+
+        finally:
+            WRAPPER_DEVICE = prev_device
+
 
 if __name__ == '__main__':
     run_tests()
test/test_torch.py
index 6de409b..ef76fc4 100644 (file)
@@ -519,6 +519,15 @@ class AbstractTestCases:
             # declared with requires_grad.
             self.assertTrue(t.grad is not None)
 
+            # Make sure invalid subclasses raise nice errors
+            class BadSubTensor():
+                member_var = object()
+
+            err_msg = "Creating a Tensor subclass from a class that does not inherit from Tensor"
+            with self.assertRaisesRegex(RuntimeError, err_msg):
+                s0 = t0.as_subclass(BadSubTensor)
+
+
         def test_type(self):
             x = torch.randn(3, 3).double()
             self.assertEqual(x.type('torch.FloatTensor').dtype, torch.float32)
torch/csrc/autograd/python_variable.cpp
index 769f9e0..ce352ad 100644 (file)
@@ -1252,6 +1252,17 @@ static PyObject* THPVariable_NewWithVar(
     PyTypeObject* type,
     Variable _var,
     c10::impl::PyInterpreterStatus status) {
+  // This function overwrites the Tensor's pyobj field without extra checks.
+  // Make sure it is not already set, otherwise we would leak memory.
+  auto mb_obj = _var.unsafeGetTensorImpl()->check_pyobj(self_interpreter.get());
+  TORCH_CHECK(!mb_obj.has_value() || !mb_obj.value(), "Creating a new Tensor subclass ",
+    type->tp_name, " but the raw Tensor object is already associated to a python object ",
+    "of type ", mb_obj.value()->ob_type->tp_name);
+
+  // Make sure that reinterpreting the object as a THPVariable* will be valid
+  TORCH_CHECK(PyType_IsSubtype(type, &THPVariableType), "Creating a Tensor subclass from a class ",
+    "that does not inherit from Tensor is not possible. Make sure your class inherits from Tensor.");
+
   PyObject* obj = type->tp_alloc(type, 0);
   if (obj) {
     auto v = (THPVariable*) obj;
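
For context, a hedged sketch of the path these checks leave untouched: creating a plain torch.Tensor subclass from a tensor whose TensorImpl has no python object bound yet passes both TORCH_CHECKs, so existing subclassing code keeps working. `Foo` is a made-up class used only for this sketch:

    import torch

    class Foo(torch.Tensor):
        pass

    # A freshly created tensor's TensorImpl has no pyobj bound, and Foo inherits
    # from Tensor, so both new checks pass and the subclass is created as before.
    f = torch.rand(2).as_subclass(Foo)
    assert type(f) is Foo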