From 202eaa4ef40824d08ae88b3062b9c82d78794432 Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Sun, 10 Feb 2019 09:38:50 -0800
Subject: [PATCH] Use non-Variable type for callsites that check type equality (#16325)

Summary:
When Variable and Tensor are merged, the dynamic type of the tensors
passed to certain functions will become Variable, and expecting `type()`
on those variables to still return a non-Variable type will cause a
type mismatch error.

One way to fix this problem is to use the thread-local guard
`at::AutoNonVariableTypeMode` to force `type()` to return a non-Variable
type, but ideally we want to limit the use of
`at::AutoNonVariableTypeMode` to VariableType.cpp only. Another way to
fix the problem is to use `at::globalContext().getNonVariableType()`
instead to get the non-Variable type of the tensor, which is what this
PR does.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/16325

Differential Revision: D14012022

Pulled By: yf225

fbshipit-source-id: 77ef1d2a02f78bff0063bdd72596e34046f1e00d
---
 aten/src/ATen/native/TensorIterator.cpp | 18 ++++++++++--------
 torch/csrc/Generator.cpp                |  5 +++--
 torch/csrc/autograd/python_variable.cpp |  3 ++-
 torch/csrc/cuda/Module.cpp              |  3 ++-
 torch/csrc/nn/type_checks.h             |  3 ++-
 5 files changed, 19 insertions(+), 13 deletions(-)

diff --git a/aten/src/ATen/native/TensorIterator.cpp b/aten/src/ATen/native/TensorIterator.cpp
index 29257aa..f89f047 100644
--- a/aten/src/ATen/native/TensorIterator.cpp
+++ b/aten/src/ATen/native/TensorIterator.cpp
@@ -95,20 +95,21 @@ void TensorIterator::compute_types() {
   if (missing_dtypes || compute_common_dtype_) {
     auto& type = compute_common_type();
     for (auto& op : operands_) {
+      auto& op_tensor_type = at::globalContext().getNonVariableType(op.tensor.type().backend(), op.tensor.type().scalarType());
       if (!op.type) {
         op.type = &type;
       } else if (compute_common_dtype_ && op.type != &type) {
         if (allow_cpu_scalars_ && op.tensor.defined() && op.tensor.dim() == 0 &&
-            type.device_type() == kCUDA && op.tensor.type().device_type() == kCPU) {
+            type.device_type() == kCUDA && op_tensor_type.device_type() == kCPU) {
           // don't cast CPU scalars in CUDA ops that directly support them
-          op.type = &op.tensor.type();
+          op.type = &op_tensor_type;
         } else if (promote_gpu_output_dtypes_ && op.tensor.defined() &&
-            !op.is_output && op.tensor.type().scalarType() == kHalf &&
+            !op.is_output && op_tensor_type.scalarType() == kHalf &&
             type.scalarType() == kFloat && type.device_type() == kCUDA &&
-            op.tensor.type().device_type() == kCUDA) {
+            op_tensor_type.device_type() == kCUDA) {
           // allow input tensor type upcasting for fp16 to fp32 in fused kernel
           // on GPU
-          op.type = &op.tensor.type();
+          op.type = &op_tensor_type;
         } else {
           op.type = &type;
         }
@@ -117,15 +118,16 @@ void TensorIterator::compute_types() {
   }
 
   for (auto& op : operands_) {
-    if (op.tensor.defined() && op.tensor.type() != *op.type) {
+    auto& op_tensor_type = at::globalContext().getNonVariableType(op.tensor.type().backend(), op.tensor.type().scalarType());
+    if (op.tensor.defined() && op_tensor_type != *op.type) {
       if (op.is_output) {
-        AT_ERROR("output with type ", op.tensor.type().toString(),
+        AT_ERROR("output with type ", op_tensor_type.toString(),
                  " doesn't match the desired type ", op.type->toString());
       } else if (op.tensor.dim() == 0) {
         op.tensor = op.tensor.to(*op.type);
       } else {
         AT_ERROR("expected type ", op.type->toString(), " but got ",
-                 op.tensor.type().toString());
+                 op_tensor_type.toString());
       }
     }
   }
diff --git a/torch/csrc/Generator.cpp b/torch/csrc/Generator.cpp
index 1b0be10..d13032d 100644
--- a/torch/csrc/Generator.cpp
+++ b/torch/csrc/Generator.cpp
@@ -82,8 +82,9 @@ static PyObject * THPGenerator_setState(THPGenerator *self, PyObject *_new_state
     throw TypeError("expected a torch.ByteTensor, but got %s", Py_TYPE(_new_state)->tp_name);
   }
   auto& tensor = ((THPVariable*)_new_state)->cdata.data();
-  if (tensor.type() != CPU(kByte)) {
-    auto type_name = torch::utils::type_to_string(tensor.type());
+  auto& tensor_type = at::globalContext().getNonVariableType(tensor.type().backend(), tensor.type().scalarType());
+  if (tensor_type != CPU(kByte)) {
+    auto type_name = torch::utils::type_to_string(tensor_type);
     throw TypeError("expected a torch.ByteTensor, but got %s", type_name.c_str());
   }
   THGenerator *generator = THPGenerator_TH_CData(self);
diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp
index f747c6f..ab2037a 100644
--- a/torch/csrc/autograd/python_variable.cpp
+++ b/torch/csrc/autograd/python_variable.cpp
@@ -251,7 +251,8 @@ int THPVariable_set_grad(THPVariable *self, PyObject *py_grad)
   auto typeOpt = at::globalContext().getNonVariableTypeOpt(backend, var.type().scalarType());
   if (typeOpt) {
     auto& sparseType = at::globalContext().getNonVariableType(backend, var.type().scalarType());
-    gradIsSparse = grad.type() == sparseType;
+    auto& gradType = at::globalContext().getNonVariableType(grad.type().backend(), grad.type().scalarType());
+    gradIsSparse = gradType == sparseType;
   }
 
   THPUtils_assertRet(-1, grad.type() == var.type() || gradIsSparse,
diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index e5577c0..8d1571a 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -149,7 +149,8 @@ PyObject * THCPModule_getRNGState(PyObject *_unused)
 PyObject * THCPModule_setRNGState(PyObject *_unused, PyObject *obj)
 {
   HANDLE_TH_ERRORS
-  if (!THPVariable_Check(obj) || THPVariable_UnpackData(obj).type().ID() != at::TypeID::CPUByte) {
+  auto& data_type = THPVariable_Unpack(obj).type();
+  if (!THPVariable_Check(obj) || at::globalContext().getNonVariableType(data_type.backend(), data_type.scalarType()).ID() != at::TypeID::CPUByte) {
     throw TypeError("set_rng_state expects a torch.ByteTensor, but got %s",
                     Py_TYPE(obj)->tp_name);
   }
diff --git a/torch/csrc/nn/type_checks.h b/torch/csrc/nn/type_checks.h
index 1f3140c..966e32f 100644
--- a/torch/csrc/nn/type_checks.h
+++ b/torch/csrc/nn/type_checks.h
@@ -11,7 +11,8 @@ namespace torch { namespace nn {
 
 inline bool check_type(PyObject* obj, at::TypeID typeID) {
   if (THPVariable_Check(obj)) {
-    return ((THPVariable*)obj)->cdata.data().type().ID() == typeID;
+    auto& data_type = ((THPVariable*)obj)->cdata.type();
+    return at::globalContext().getNonVariableType(data_type.backend(), data_type.scalarType()).ID() == typeID;
  }
  return false;
}
-- 
2.7.4
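
Note: the callsite pattern the patch applies everywhere can be summarized
by the sketch below. It is a minimal standalone example, assuming the
2019-era ATen Type API that this patch targets (getNonVariableType and
the Type class were removed in later releases); the helper name
is_cpu_byte is hypothetical and is not code from the PR.

#include <ATen/ATen.h>

// Sketch only: mimics the check in Generator.cpp under the assumptions above.
bool is_cpu_byte(const at::Tensor& tensor) {
  // Old callsite style, which breaks once Variable and Tensor are merged,
  // because tensor.type() may then be a Variable type and the equality
  // check against a non-Variable type fails:
  //
  //   return tensor.type() == at::CPU(at::kByte);
  //
  // Fixed style: map the tensor's (possibly Variable) type to the
  // equivalent non-Variable type before comparing.
  auto& non_var_type = at::globalContext().getNonVariableType(
      tensor.type().backend(), tensor.type().scalarType());
  return non_var_type == at::CPU(at::kByte);
}

The alternative named in the summary, wrapping the check in an
at::AutoNonVariableTypeMode guard so that type() itself returns a
non-Variable type, would also work, but this PR deliberately keeps that
guard confined to VariableType.cpp and makes the mapping explicit at
each callsite instead.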