Fix variable checking in THCPModule_setRNGState (#17474)
authorWill Feng <willfeng@fb.com>
Mon, 25 Feb 2019 18:56:03 +0000 (10:56 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 25 Feb 2019 19:05:51 +0000 (11:05 -0800)
Summary:
See https://github.com/pytorch/pytorch/pull/16325/files#r259576901
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17474

Differential Revision: D14209549

Pulled By: yf225

fbshipit-source-id: 2ae091955ae17f5d1540f7d465739c4809c327f8

torch/csrc/cuda/Module.cpp

index 7571f11..e344813 100644 (file)
@@ -149,8 +149,8 @@ PyObject * THCPModule_getRNGState(PyObject *_unused)
 PyObject * THCPModule_setRNGState(PyObject *_unused, PyObject *obj)
 {
   HANDLE_TH_ERRORS
-  auto& data_type = THPVariable_Unpack(obj).type();
-  if (!THPVariable_Check(obj) || at::globalContext().getNonVariableType(data_type.backend(), data_type.scalarType()).ID() != at::TypeID::CPUByte) {
+  if (!THPVariable_Check(obj) ||
+      at::globalContext().getNonVariableType(THPVariable_Unpack(obj).type().backend(), THPVariable_Unpack(obj).type().scalarType()).ID() != at::TypeID::CPUByte) {
     throw TypeError("set_rng_state expects a torch.ByteTensor, but got %s",
         Py_TYPE(obj)->tp_name);
   }