From d8ae3cc31889a151ae1c1b61bf179cc25915eadc Mon Sep 17 00:00:00 2001 From: Alban Desmaison Date: Fri, 10 Sep 2021 13:07:37 -0700 Subject: [PATCH] Add more error checking in subclass creation (#64746) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64746 This extracts the error checking that used to be in the PR above. We are not going to land the proposed fix there, but I think we want this error checking in right now as these would lead to respectively a memory leak and arbitrary memory read/write. Test Plan: Imported from OSS Reviewed By: ngimel Differential Revision: D30867569 Pulled By: albanD fbshipit-source-id: bf468033fb8b49fcb26eed423f5fad82b4a46c56 --- test/test_python_dispatch.py | 35 ++++++++++++++++++++++++++++++++- test/test_torch.py | 9 +++++++++ torch/csrc/autograd/python_variable.cpp | 11 +++++++++++ 3 files changed, 54 insertions(+), 1 deletion(-) diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index 0f5b6b9..cd09488 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -23,6 +23,7 @@ def no_dispatch() -> Iterator[None]: # 3. Enter dispatcher, wind your way through Autograd # 4. Hit Python dispatch key, call __torch_dispatch__ +WRAPPER_DEVICE = "meta" # TODO: TensorBase should work class LoggingTensor(torch.Tensor): elem: torch.Tensor @@ -34,7 +35,7 @@ class LoggingTensor(torch.Tensor): # The wrapping tensor (LoggingTensor) is just a meta tensor, so it # doesn't hold any memory (meta tensor is generally the preferred type # of tensor you want to make a subclass from)... - r = torch.Tensor._make_subclass(cls, elem.to('meta'), elem.requires_grad) + r = torch.Tensor._make_subclass(cls, elem.to(WRAPPER_DEVICE), elem.requires_grad) # ...the real tensor is held as an element on the tensor. 
r.elem = elem return r @@ -335,6 +336,38 @@ $4 = torch._ops.aten.mul($3, tensor(2)) $5 = torch._ops.aten.mul($4, $0) $6 = torch._ops.aten.add_($1, $5)''') + def test_subclass_creation(self): + # Make sure these statements raise a nice error + # In particular checking that when internal detach returns + # subclasses, these are rejected rather than overwritten. + class Foo(torch.Tensor): + pass + + err_msg = "subclass Foo but.*already associated to a python object of type LoggingTensor" + with self.assertRaisesRegex(RuntimeError, err_msg): + a = torch.Tensor._make_subclass(Foo, LoggingTensor(torch.rand(2))) + with self.assertRaisesRegex(RuntimeError, err_msg): + b = LoggingTensor(torch.rand(2)).as_subclass(Foo) + + # And in case where we don't know if the user wants this subclass + # overwritten, raise a nice error. + # The standard LoggingTensor will fail because it is not on the right device + with self.assertRaisesRegex(TypeError, "expected.*device=cpu.*device=meta"): + Foo(LoggingTensor(torch.rand(2))) + + # And if we put it on the right device, we still get a nice error + try: + global WRAPPER_DEVICE + prev_device = WRAPPER_DEVICE + WRAPPER_DEVICE = "cpu" + + err_msg = "Creating a new Tensor subclass Foo.*python object of type LoggingTensor" + with self.assertRaisesRegex(RuntimeError, err_msg): + Foo(LoggingTensor(torch.rand(2))) + + finally: + WRAPPER_DEVICE = prev_device + if __name__ == '__main__': run_tests() diff --git a/test/test_torch.py b/test/test_torch.py index 6de409b..ef76fc4 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -519,6 +519,15 @@ class AbstractTestCases: # declared with requires_grad. 
self.assertTrue(t.grad is not None) + # Make sure invalid subclasses raise nice errors + class BadSubTensor(): + member_var = object() + + err_msg = "Creating a Tensor subclass from a class that does not inherit from Tensor" + with self.assertRaisesRegex(RuntimeError, err_msg): + s0 = t0.as_subclass(BadSubTensor) + + def test_type(self): x = torch.randn(3, 3).double() self.assertEqual(x.type('torch.FloatTensor').dtype, torch.float32) diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 769f9e0..ce352ad 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -1252,6 +1252,17 @@ static PyObject* THPVariable_NewWithVar( PyTypeObject* type, Variable _var, c10::impl::PyInterpreterStatus status) { + // This function overwrites the Tensor's pyobj field without extra checks + // Make sure it is not set, otherwise we would leak memory + auto mb_obj = _var.unsafeGetTensorImpl()->check_pyobj(self_interpreter.get()); + TORCH_CHECK(!mb_obj.has_value() || !mb_obj.value(), "Creating a new Tensor subclass ", + type->tp_name, " but the raw Tensor object is already associated to a python object ", + "of type ", mb_obj.value()->ob_type->tp_name); + + // Make sure that the reinterpret into a THPVariable* will be valid + TORCH_CHECK(PyType_IsSubtype(type, &THPVariableType), "Creating a Tensor subclass from a class ", + "that does not inherit from Tensor is not possible. Make sure your class inherits from Tensor."); + PyObject* obj = type->tp_alloc(type, 0); if (obj) { auto v = (THPVariable*) obj; -- 2.7.4