From: Ailing Zhang Date: Thu, 18 Apr 2019 19:07:17 +0000 (-0700) Subject: Fix pickling torch.float32 (#18045) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~158 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=88f70a16708369ad5d179fbe515f43739c0f2591;p=platform%2Fupstream%2Fpytorch.git Fix pickling torch.float32 (#18045) Summary: Attempt fix for #14057 . This PR fixes the example script in the issue. The old behavior is a bit confusing here. What happened to pickling is python2 failed to recognize `torch.float32` is in module `torch`, thus it's looking for `torch.float32` in module `__main__`. Python3 is smart enough to handle it. According to the doc [here](https://docs.python.org/2/library/pickle.html#object.__reduce__), it seems `__reduce__` should return `float32` instead of the old name `torch.float32`. In this way python2 is able to find `float32` in `torch` module. > If a string is returned, it names a global variable whose contents are pickled as normal. The string returned by __reduce__() should be the object’s local name relative to its module Pull Request resolved: https://github.com/pytorch/pytorch/pull/18045 Differential Revision: D14990638 Pulled By: ailzhang fbshipit-source-id: 816b97d63a934a5dda1a910312ad69f120b0b4de --- diff --git a/test/test_torch.py b/test/test_torch.py index 39abb94..aaba09d 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -8859,6 +8859,13 @@ class _TestTorchMixin(object): self.assertEqual(a.requires_grad, b.requires_grad) self.assertEqual(a, b) + def test_pickle_dtype(self): + t = torch.float32 + serialized = pickle.dumps(t) + b = pickle.loads(serialized) + self.assertTrue(isinstance(b, torch.dtype)) + self.assertEqual(id(b), id(t)) + def test_norm_fastpaths(self): x = torch.randn(3, 5) diff --git a/torch/csrc/Dtype.cpp b/torch/csrc/Dtype.cpp index 8dd93e2..fca83f9 100644 --- a/torch/csrc/Dtype.cpp +++ b/torch/csrc/Dtype.cpp @@ -52,7 +52,8 @@ static PyMethodDef THPDtype_methods[] = { PyObject *THPDtype_repr(THPDtype *self) { - return THPUtils_packString(self->name); + std::string name = self->name; + return THPUtils_packString("torch." + name); } PyTypeObject THPDtypeType = { diff --git a/torch/csrc/utils/tensor_dtypes.cpp b/torch/csrc/utils/tensor_dtypes.cpp index 51c8dc5..9b39677 100644 --- a/torch/csrc/utils/tensor_dtypes.cpp +++ b/torch/csrc/utils/tensor_dtypes.cpp @@ -60,9 +60,7 @@ void initializeDtypes() { for (at::ScalarType scalarType : all_scalar_types) { std::string primary_name, legacy_name; std::tie(primary_name, legacy_name) = getDtypeNames(scalarType); - std::string name = - std::string(PyModule_GetName(torch_module.get())) + '.' + primary_name; - PyObject* dtype = THPDtype_New(scalarType, name); + PyObject *dtype = THPDtype_New(scalarType, primary_name); torch::registerDtypeObject((THPDtype*)dtype, scalarType); Py_INCREF(dtype); if (PyModule_AddObject(torch_module.get(), primary_name.c_str(), dtype) !=