Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66082
Fixes https://github.com/pytorch/pytorch/issues/66024 #65779
cc ezyang anjali411 dylanbespalko mruberry Lezcano nikitaved albanD
Test Plan: Imported from OSS
Reviewed By: Gamrix, albanD
Differential Revision:
D31615588
Pulled By: anjali411
fbshipit-source-id:
c3e65ef0fe301630eb76732ccd7819683c09aa19
with enable_python_mode(LoggingTensor):
with enable_python_mode(LoggingTensor):
pass
-
+
+ def test_tolist_numpy_with_python_mode(self) -> None:
+ x = LoggingTensor(torch.tensor([2.0, 3.0]))
+ with self.assertRaisesRegex(RuntimeError, "is not supported for tensor subclasses."):
+ x.tolist()
+ with self.assertRaisesRegex(RuntimeError, "is not supported for tensor subclasses."):
+ x.numpy()
+ with self.assertRaises(AssertionError):
+ self.assertEqual(x, None)
+
if __name__ == '__main__':
run_tests()
finally:
torch.set_num_threads(num_threads)
+ def test_conj_neg_tolist(self):
+ x = torch.randn(2, dtype=torch.cfloat)
+ y1 = x.conj()
+ y1_expect = x.conj_physical()
+ y2 = y1.imag
+ self.assertEqual(y1, y1_expect.tolist())
+ self.assertEqual(y2, y1_expect.imag.tolist())
# TODO: these empy classes are temporarily instantiated for XLA compatibility
# once XLA updates their test suite it should be removed
}
PyObject* tensor_to_list(const Tensor& tensor) {
- Tensor data = tensor;
+ TORCH_CHECK(!tensor.unsafeGetTensorImpl()->is_python_dispatch(), ".tolist() is not supported for tensor subclasses.");
+ Tensor data = tensor.resolve_conj().resolve_neg();
if (!data.device().is_cpu()) {
pybind11::gil_scoped_release no_gil;
data = data.toBackend(Backend::CPU);
"Can't call numpy() on Tensor that has negative bit set. "
"Use tensor.resolve_neg().numpy() instead.");
+ TORCH_CHECK(!tensor.unsafeGetTensorImpl()->is_python_dispatch(), ".numpy() is not supported for tensor subclasses.");
+
auto dtype = aten_to_numpy_dtype(tensor.scalar_type());
auto sizes = to_numpy_shape(tensor.sizes());
auto strides = to_numpy_shape(tensor.strides());
assert (atol is None) == (rtol is None), "If one of atol or rtol is specified, then the other must be too"
debug_msg: Optional[str] = None
+ if x is None or y is None:
+ self.assertTrue(x is None and y is None)
# Tensor x Number and Number x Tensor comparisons
- if isinstance(x, torch.Tensor) and isinstance(y, Number):
+ elif isinstance(x, torch.Tensor) and isinstance(y, Number):
self.assertEqual(x.item(), y, atol=atol, rtol=rtol, msg=msg,
exact_dtype=exact_dtype, exact_device=exact_device)
elif isinstance(y, torch.Tensor) and isinstance(x, Number):