Fix pickling torch.float32 (#18045)
authorAiling Zhang <ailzhang@fb.com>
Thu, 18 Apr 2019 19:07:17 +0000 (12:07 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 18 Apr 2019 19:28:10 +0000 (12:28 -0700)
Summary:
Attempt fix for #14057 . This PR fixes the example script in the issue.
The old behavior is a bit confusing here. What happened to pickling is python2 failed to recognize `torch.float32` is in module `torch`, thus it's looking for `torch.float32` in module `__main__`. Python3 is smart enough to handle it.
According to the doc [here](https://docs.python.org/2/library/pickle.html#object.__reduce__), it seems `__reduce__` should return `float32` instead of the old name `torch.float32`. In this way python2 is able to find `float32` in `torch` module.
> If a string is returned, it names a global variable whose contents are pickled as normal. The string returned by __reduce__() should be the object’s local name relative to its module
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18045

Differential Revision: D14990638

Pulled By: ailzhang

fbshipit-source-id: 816b97d63a934a5dda1a910312ad69f120b0b4de

test/test_torch.py
torch/csrc/Dtype.cpp
torch/csrc/utils/tensor_dtypes.cpp

index 39abb94..aaba09d 100644 (file)
@@ -8859,6 +8859,13 @@ class _TestTorchMixin(object):
         self.assertEqual(a.requires_grad, b.requires_grad)
         self.assertEqual(a, b)
 
+    def test_pickle_dtype(self):
+        t = torch.float32
+        serialized = pickle.dumps(t)
+        b = pickle.loads(serialized)
+        self.assertTrue(isinstance(b, torch.dtype))
+        self.assertEqual(id(b), id(t))
+
     def test_norm_fastpaths(self):
         x = torch.randn(3, 5)
 
index 8dd93e2..fca83f9 100644 (file)
@@ -52,7 +52,8 @@ static PyMethodDef THPDtype_methods[] = {
 
 PyObject *THPDtype_repr(THPDtype *self)
 {
-  return THPUtils_packString(self->name);
+  std::string name = self->name;
+  return THPUtils_packString("torch." + name);
 }
 
 PyTypeObject THPDtypeType = {
index 51c8dc5..9b39677 100644 (file)
@@ -60,9 +60,7 @@ void initializeDtypes() {
   for (at::ScalarType scalarType : all_scalar_types) {
     std::string primary_name, legacy_name;
     std::tie(primary_name, legacy_name) = getDtypeNames(scalarType);
-    std::string name =
-        std::string(PyModule_GetName(torch_module.get())) + '.' + primary_name;
-    PyObject* dtype = THPDtype_New(scalarType, name);
+    PyObject *dtype = THPDtype_New(scalarType, primary_name);
     torch::registerDtypeObject((THPDtype*)dtype, scalarType);
     Py_INCREF(dtype);
     if (PyModule_AddObject(torch_module.get(), primary_name.c_str(), dtype) !=