Summary:
Attempt fix for #14057 . This PR fixes the example script in the issue.
The old behavior is a bit confusing here. What happened to pickling is python2 failed to recognize `torch.float32` is in module `torch`, thus it's looking for `torch.float32` in module `__main__`. Python3 is smart enough to handle it.
According to the doc [here](https://docs.python.org/2/library/pickle.html#object.__reduce__), it seems `__reduce__` should return `float32` instead of the old name `torch.float32`. In this way python2 is able to find `float32` in `torch` module.
> If a string is returned, it names a global variable whose contents are pickled as normal. The string returned by __reduce__() should be the object’s local name relative to its module
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18045
Differential Revision:
D14990638
Pulled By: ailzhang
fbshipit-source-id:
816b97d63a934a5dda1a910312ad69f120b0b4de
self.assertEqual(a.requires_grad, b.requires_grad)
self.assertEqual(a, b)
+ def test_pickle_dtype(self):
+ t = torch.float32
+ serialized = pickle.dumps(t)
+ b = pickle.loads(serialized)
+ self.assertTrue(isinstance(b, torch.dtype))
+ self.assertEqual(id(b), id(t))
+
def test_norm_fastpaths(self):
x = torch.randn(3, 5)
PyObject *THPDtype_repr(THPDtype *self)
{
- return THPUtils_packString(self->name);
+ std::string name = self->name;
+ return THPUtils_packString("torch." + name);
}
PyTypeObject THPDtypeType = {
for (at::ScalarType scalarType : all_scalar_types) {
std::string primary_name, legacy_name;
std::tie(primary_name, legacy_name) = getDtypeNames(scalarType);
- std::string name =
- std::string(PyModule_GetName(torch_module.get())) + '.' + primary_name;
- PyObject* dtype = THPDtype_New(scalarType, name);
+ PyObject *dtype = THPDtype_New(scalarType, primary_name);
torch::registerDtypeObject((THPDtype*)dtype, scalarType);
Py_INCREF(dtype);
if (PyModule_AddObject(torch_module.get(), primary_name.c_str(), dtype) !=