# 3. Enter dispatcher, wind your way through Autograd
# 4. Hit Python dispatch key, call __torch_dispatch__
+WRAPPER_DEVICE = "meta"
# TODO: TensorBase should work
class LoggingTensor(torch.Tensor):
elem: torch.Tensor
# The wrapping tensor (LoggingTensor) is just a meta tensor, so it
# doesn't hold any memory (meta tensor is generally the preferred type
# of tensor you want to make a subclass from)...
- r = torch.Tensor._make_subclass(cls, elem.to('meta'), elem.requires_grad)
+ r = torch.Tensor._make_subclass(cls, elem.to(WRAPPER_DEVICE), elem.requires_grad)
# ...the real tensor is held as an element on the tensor.
r.elem = elem
return r
$5 = torch._ops.aten.mul($4, $0)
$6 = torch._ops.aten.add_($1, $5)''')
+ def test_subclass_creation(self):
+ # Make sure these statements runs without error
+ # In particular checking that when internal detach returns
+ # subclasses, these are cleanly overwritten.
+ class Foo(torch.Tensor):
+ pass
+
+ err_msg = "subclass Foo but.*already associated to a python object of type LoggingTensor"
+ with self.assertRaisesRegex(RuntimeError, err_msg):
+ a = torch.Tensor._make_subclass(Foo, LoggingTensor(torch.rand(2)))
+ with self.assertRaisesRegex(RuntimeError, err_msg):
+ b = LoggingTensor(torch.rand(2)).as_subclass(Foo)
+
+ # And in case where we don't know if the user wants this subclass
+ # overwritten, raise a nice error.
+ # The standard LoggingTensor will fail because it is not on the right device
+ with self.assertRaisesRegex(TypeError, "expected.*device=cpu.*device=meta"):
+ Foo(LoggingTensor(torch.rand(2)))
+
+ # And if we put it on the right device, we still get a nice error
+ try:
+ global WRAPPER_DEVICE
+ prev_device = WRAPPER_DEVICE
+ WRAPPER_DEVICE = "cpu"
+
+ err_msg = "Creating a new Tensor subclass Foo.*python object of type LoggingTensor"
+ with self.assertRaisesRegex(RuntimeError, err_msg):
+ Foo(LoggingTensor(torch.rand(2)))
+
+ finally:
+ WRAPPER_DEVICE = prev_device
+
if __name__ == '__main__':
run_tests()
# declared with requires_grad.
self.assertTrue(t.grad is not None)
+ # Make sure invalid subclasses raise nice errors
+ class BadSubTensor():
+ member_var = object()
+
+ err_msg = "Creating a Tensor subclass from a class that does not inherit from Tensor"
+ with self.assertRaisesRegex(RuntimeError, err_msg):
+ s0 = t0.as_subclass(BadSubTensor)
+
+
def test_type(self):
x = torch.randn(3, 3).double()
self.assertEqual(x.type('torch.FloatTensor').dtype, torch.float32)
PyTypeObject* type,
Variable _var,
c10::impl::PyInterpreterStatus status) {
+ // This function overwrite the Tensor's pyobj field without extra checks
+ // Make sure it is not set otherwise we would leak memory
+ auto mb_obj = _var.unsafeGetTensorImpl()->check_pyobj(self_interpreter.get());
+ TORCH_CHECK(!mb_obj.has_value() || !mb_obj.value(), "Creating a new Tensor subclass ",
+ type->tp_name, " but the raw Tensor object is already associated to a python object ",
+ "of type ", mb_obj.value()->ob_type->tp_name);
+
+ // Make sure that the reinterpret into a THPVariable* will be valid
+ TORCH_CHECK(PyType_IsSubtype(type, &THPVariableType), "Creating a Tensor subclass from a class ",
+ "that does not inherit from Tensor is not possible. Make sure your class inherits from Tensor.");
+
PyObject* obj = type->tp_alloc(type, 0);
if (obj) {
auto v = (THPVariable*) obj;