From 65b00aa5972e23b2a70aa60dec5125671a3d7153 Mon Sep 17 00:00:00 2001
From: Roy Li
Date: Fri, 8 Mar 2019 16:39:04 -0800
Subject: [PATCH] Remove some simple use cases of Type::ScalarType()

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17529

Reviewed By: ezyang

Differential Revision: D14237932

fbshipit-source-id: be633a1fc19215d53cfe083fdd7196acf2b7dd2f
---
 aten/src/ATen/DLConvertor.cpp           |  8 ++---
 aten/src/ATen/cudnn/Descriptors.cpp     |  8 ++---
 aten/src/ATen/miopen/Descriptors.cpp    |  8 ++---
 aten/src/ATen/native/Indexing.cpp       |  3 +-
 aten/src/ATen/native/PackedSequence.cpp |  3 +-
 torch/csrc/cuda/nccl.cpp                | 10 +++----
 torch/csrc/cuda/nccl.h                  |  2 +-
 torch/csrc/cuda/python_nccl.cpp         |  6 ++--
 torch/csrc/nn/type_checks.h             |  4 +--
 torch/csrc/utils/tensor_list.cpp        |  3 +-
 torch/csrc/utils/tensor_numpy.cpp       | 53 +++++++++++++++++----------------
 11 files changed, 49 insertions(+), 59 deletions(-)

diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp
index 8c1bab7..428ae9e 100644
--- a/aten/src/ATen/DLConvertor.cpp
+++ b/aten/src/ATen/DLConvertor.cpp
@@ -8,11 +8,11 @@ using namespace std;
 namespace at {
 
-static DLDataType getDLDataType(const Type& type) {
+static DLDataType getDLDataType(const Tensor& t) {
   DLDataType dtype;
   dtype.lanes = 1;
-  dtype.bits = type.elementSizeInBytes() * 8;
-  switch (type.scalarType()) {
+  dtype.bits = t.dtype().itemsize() * 8;
+  switch (t.scalar_type()) {
     case ScalarType::Byte:
       dtype.code = DLDataTypeCode::kDLUInt;
       break;
@@ -160,7 +160,7 @@ DLManagedTensor* toDLPack(const Tensor& src) {
   }
   atDLMTensor->tensor.dl_tensor.ctx = getDLContext(src.type(), device_id);
   atDLMTensor->tensor.dl_tensor.ndim = src.dim();
-  atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src.type());
+  atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
   atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(src.sizes().data());
   atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(src.strides().data());
   atDLMTensor->tensor.dl_tensor.byte_offset = 0;
diff --git a/aten/src/ATen/cudnn/Descriptors.cpp b/aten/src/ATen/cudnn/Descriptors.cpp
index 1b669a5..eef5e52 100644
--- a/aten/src/ATen/cudnn/Descriptors.cpp
+++ b/aten/src/ATen/cudnn/Descriptors.cpp
@@ -10,8 +10,8 @@ namespace at { namespace native {
 
 namespace {
 
-inline cudnnDataType_t getDataType(const at::Type& t) {
-  auto scalar_type = t.scalarType();
+inline cudnnDataType_t getDataType(const at::Tensor& t) {
+  auto scalar_type = t.scalar_type();
   if (scalar_type == at::kFloat) {
     return CUDNN_DATA_FLOAT;
   } else if (scalar_type == at::kHalf) {
@@ -22,10 +22,6 @@ inline cudnnDataType_t getDataType(const at::Type& t) {
   throw std::runtime_error("TensorDescriptor only supports double, float and half tensors");
 }
 
-inline cudnnDataType_t getDataType(const at::Tensor& t) {
-  return getDataType(t.type());
-}
-
 } // anonymous namespace
 
diff --git a/aten/src/ATen/miopen/Descriptors.cpp b/aten/src/ATen/miopen/Descriptors.cpp
index 37173da..ecbd03d 100644
--- a/aten/src/ATen/miopen/Descriptors.cpp
+++ b/aten/src/ATen/miopen/Descriptors.cpp
@@ -5,8 +5,8 @@ namespace at { namespace native {
 
 namespace {
 
-inline miopenDataType_t getDataType(const at::Type& t) {
-  auto scalar_type = t.scalarType();
+inline miopenDataType_t getDataType(const at::Tensor& t) {
+  auto scalar_type = t.scalar_type();
   if (scalar_type == at::kFloat) {
     return miopenFloat;
   } else if (scalar_type == at::kHalf) {
@@ -15,10 +15,6 @@ inline miopenDataType_t getDataType(const at::Type& t) {
     return miopenHalf;
   }
   throw std::runtime_error("TensorDescriptor only supports float and half tensors");
 }
 
-inline miopenDataType_t getDataType(const at::Tensor& t) {
-  return getDataType(t.type());
-}
-
 } // anonymous namespace
diff --git a/aten/src/ATen/native/Indexing.cpp b/aten/src/ATen/native/Indexing.cpp
index b62e0ab..6ea80ba 100644
--- a/aten/src/ATen/native/Indexing.cpp
+++ b/aten/src/ATen/native/Indexing.cpp
@@ -78,8 +78,7 @@ static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask,
 static void checkIndexTensorTypes(TensorList indices) {
   for (auto& tensor : indices) {
     if (tensor.defined()) {
-      auto& type = tensor.type();
-      auto scalarType = type.scalarType();
+      auto scalarType = tensor.scalar_type();
       if (scalarType != kLong && scalarType != kByte) {
         AT_INDEX_ERROR("tensors used as indices must be long or byte tensors");
       }
diff --git a/aten/src/ATen/native/PackedSequence.cpp b/aten/src/ATen/native/PackedSequence.cpp
index c256f04..265d514 100644
--- a/aten/src/ATen/native/PackedSequence.cpp
+++ b/aten/src/ATen/native/PackedSequence.cpp
@@ -4,8 +4,7 @@ namespace at { namespace native {
 
 void checkLongTensor(const Tensor& tensor) {
-  auto & t = tensor.type();
-  AT_CHECK(tensor.dim() == 1 && t.device_type() == at::kCPU && t.scalarType() == at::kLong,
+  AT_CHECK(tensor.dim() == 1 && tensor.type().device_type() == at::kCPU && tensor.scalar_type() == at::kLong,
            "'lengths' argument should be a 1D CPU int64 tensor");
 }
 
diff --git a/torch/csrc/cuda/nccl.cpp b/torch/csrc/cuda/nccl.cpp
index a083d0d..4d51a12 100644
--- a/torch/csrc/cuda/nccl.cpp
+++ b/torch/csrc/cuda/nccl.cpp
@@ -81,11 +81,11 @@ ArrayRef<ncclComm_t> _get_communicators(TensorList inputs) {
   return it->second.ref();
 }
 
-ncclDataType_t _get_data_type(const Type& type) {
-  if (type.backend() != Backend::CUDA) {
+ncclDataType_t _get_data_type(const Tensor& t) {
+  if (t.type().backend() != Backend::CUDA) {
     throw std::runtime_error("Unconvertible NCCL type");
   }
-  switch (type.scalarType()) {
+  switch (t.scalar_type()) {
     case at::kFloat:
       return ncclFloat;
     case at::kHalf:
@@ -233,7 +233,7 @@ void broadcast(
 #ifdef USE_NCCL
   using namespace torch::cuda::nccl::detail;
   _check_inputs(tensors, tensors, 1, 1);
-  ncclDataType_t data_type = _get_data_type(tensors[0].type());
+  ncclDataType_t data_type = _get_data_type(tensors[0]);
   int64_t numel = tensors[0].numel();
 
   std::lock_guard<std::mutex> free_mutex(
@@ -281,7 +281,7 @@ void reduce(
   _check_inputs(inputs, outputs, 1, 1);
   const auto len = inputs.size();
 
-  ncclDataType_t data_type = _get_data_type(inputs[0].type());
+  ncclDataType_t data_type = _get_data_type(inputs[0]);
 
   const auto count = inputs[0].numel();
   std::lock_guard<std::mutex> lock(*(c10::cuda::CUDACachingAllocator::getFreeMutex()));
diff --git a/torch/csrc/cuda/nccl.h b/torch/csrc/cuda/nccl.h
index 2120274..6655aa5 100644
--- a/torch/csrc/cuda/nccl.h
+++ b/torch/csrc/cuda/nccl.h
@@ -45,7 +45,7 @@ void _check_inputs(
     at::TensorList outputs,
     int input_multiplier,
     int output_multiplier);
-ncclDataType_t _get_data_type(const at::Type& type);
+ncclDataType_t _get_data_type(const at::Tensor& t);
 
 } // namespace detail
 
diff --git a/torch/csrc/cuda/python_nccl.cpp b/torch/csrc/cuda/python_nccl.cpp
index 9491b85..55e796d 100644
--- a/torch/csrc/cuda/python_nccl.cpp
+++ b/torch/csrc/cuda/python_nccl.cpp
@@ -188,7 +188,7 @@ PyObject* THCPModule_nccl_all_reduce(PyObject* self, PyObject* args) {
     _check_inputs(inputs, outputs, 1, 1);
     size_t len = inputs.size();
 
-    ncclDataType_t data_type = _get_data_type(inputs[0].type());
+    ncclDataType_t data_type = _get_data_type(inputs[0]);
 
     int64_t count = inputs[0].numel();
     std::lock_guard<std::mutex> lock(*(c10::cuda::CUDACachingAllocator::getFreeMutex()));
@@ -268,7 +268,7 @@ PyObject* THCPModule_nccl_all_gather(PyObject* self, PyObject* args) {
     size_t len = inputs.size();
     _check_inputs(inputs, outputs, len, 1);
 
-    ncclDataType_t data_type = _get_data_type(inputs[0].type());
+    ncclDataType_t data_type = _get_data_type(inputs[0]);
 
     int64_t count = inputs[0].numel();
     std::lock_guard<std::mutex> lock(*(c10::cuda::CUDACachingAllocator::getFreeMutex()));
@@ -331,7 +331,7 @@ PyObject* THCPModule_nccl_reduce_scatter(PyObject* self, PyObject* args) {
     size_t len = inputs.size();
     _check_inputs(inputs, outputs, 1, len);
 
-    ncclDataType_t data_type = _get_data_type(inputs[0].type());
+    ncclDataType_t data_type = _get_data_type(inputs[0]);
 
     int64_t count = inputs[0].numel() / len;
     std::lock_guard<std::mutex> lock(*(c10::cuda::CUDACachingAllocator::getFreeMutex()));
diff --git a/torch/csrc/nn/type_checks.h b/torch/csrc/nn/type_checks.h
index 966e32f..2a00caf 100644
--- a/torch/csrc/nn/type_checks.h
+++ b/torch/csrc/nn/type_checks.h
@@ -11,8 +11,8 @@ namespace torch { namespace nn {
 
 inline bool check_type(PyObject* obj, at::TypeID typeID) {
   if (THPVariable_Check(obj)) {
-    auto& data_type = ((THPVariable*)obj)->cdata.type();
-    return at::globalContext().getNonVariableType(data_type.backend(), data_type.scalarType()).ID() == typeID;
+    auto& tensor = ((THPVariable*)obj)->cdata;
+    return at::globalContext().getNonVariableType(tensor.type().backend(), tensor.scalar_type()).ID() == typeID;
   }
   return false;
 }
diff --git a/torch/csrc/utils/tensor_list.cpp b/torch/csrc/utils/tensor_list.cpp
index 46cfd3b..1e3e48b 100644
--- a/torch/csrc/utils/tensor_list.cpp
+++ b/torch/csrc/utils/tensor_list.cpp
@@ -35,10 +35,9 @@ PyObject* tensor_to_list(const Tensor& tensor) {
       data = data.toBackend(Backend::CPU);
     });
   }
-  auto& type = data.type();
   return recursive_to_list(
       (char*)data.data_ptr(), data.sizes(), data.strides(), 0,
-      type.scalarType(), type.elementSizeInBytes());
+      data.scalar_type(), data.dtype().itemsize());
 }
 
 }} // namespace torch::utils
diff --git a/torch/csrc/utils/tensor_numpy.cpp b/torch/csrc/utils/tensor_numpy.cpp
index 4550b46..4549971 100644
--- a/torch/csrc/utils/tensor_numpy.cpp
+++ b/torch/csrc/utils/tensor_numpy.cpp
@@ -49,10 +49,23 @@ static std::vector<int64_t> to_aten_shape(int ndim, npy_intp* values) {
   return result;
 }
 
-static int aten_to_dtype(const at::Type& type);
+static int aten_to_dtype(const ScalarType scalar_type);
 
 PyObject* tensor_to_numpy(const at::Tensor& tensor) {
-  auto dtype = aten_to_dtype(tensor.type());
+  if (tensor.is_cuda()) {
+    throw TypeError(
+        "can't convert CUDA tensor to numpy. Use Tensor.cpu() to "
+        "copy the tensor to host memory first.");
+  }
+  if (tensor.is_sparse()) {
+    throw TypeError(
+        "can't convert sparse tensor to numpy. Use Tensor.to_dense() to "
+        "convert to a dense tensor first.");
+  }
+  if (tensor.type().backend() != Backend::CPU) {
+    throw TypeError("NumPy conversion for %s is not supported", tensor.type().toString());
+  }
+  auto dtype = aten_to_dtype(tensor.scalar_type());
   auto sizes = to_numpy_shape(tensor.sizes());
   auto strides = to_numpy_shape(tensor.strides());
   // NumPy strides use bytes. Torch strides use element counts.
@@ -133,31 +146,19 @@ at::Tensor tensor_from_numpy(PyObject* obj) {
   });
 }
 
-static int aten_to_dtype(const at::Type& type) {
-  if (type.is_cuda()) {
-    throw TypeError(
-        "can't convert CUDA tensor to numpy. Use Tensor.cpu() to "
-        "copy the tensor to host memory first.");
-  }
-  if (type.is_sparse()) {
-    throw TypeError(
-        "can't convert sparse tensor to numpy. Use Tensor.to_dense() to "
-        "convert to a dense tensor first.");
-  }
-  if (type.backend() == Backend::CPU) {
-    switch (type.scalarType()) {
-      case kDouble: return NPY_DOUBLE;
-      case kFloat: return NPY_FLOAT;
-      case kHalf: return NPY_HALF;
-      case kLong: return NPY_INT64;
-      case kInt: return NPY_INT32;
-      case kShort: return NPY_INT16;
-      case kChar: return NPY_INT8;
-      case kByte: return NPY_UINT8;
-      default: break;
-    }
+static int aten_to_dtype(const ScalarType scalar_type) {
+  switch (scalar_type) {
+    case kDouble: return NPY_DOUBLE;
+    case kFloat: return NPY_FLOAT;
+    case kHalf: return NPY_HALF;
+    case kLong: return NPY_INT64;
+    case kInt: return NPY_INT32;
+    case kShort: return NPY_INT16;
+    case kChar: return NPY_INT8;
+    case kByte: return NPY_UINT8;
+    default:
+      throw ValueError("Got unsupported ScalarType ", toString(scalar_type));
   }
-  throw TypeError("NumPy conversion for %s is not supported", type.toString());
 }
 
 ScalarType numpy_dtype_to_aten(int dtype) {
-- 
2.7.4
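Illustrative sketch (not part of the patch): the migration applied above replaces dtype queries routed through the dispatch Type (Type::scalarType(), Type::elementSizeInBytes()) with direct Tensor accessors (Tensor::scalar_type(), Tensor::dtype().itemsize()). The snippet below is a minimal, hypothetical example against the ATen C++ API of this era; is_float_tensor and the main driver are invented for illustration and do not appear in the patch.

    #include <cstdio>
    #include <ATen/ATen.h>

    // Hypothetical helper, for illustration only.
    static bool is_float_tensor(const at::Tensor& t) {
      // Old pattern (the style this patch removes):
      //   return t.type().scalarType() == at::kFloat;
      // New pattern: ask the Tensor directly.
      return t.scalar_type() == at::kFloat;
    }

    int main() {
      at::Tensor t = at::ones({2, 3});                  // float32 by default
      std::printf("%d\n", is_float_tensor(t) ? 1 : 0);  // prints 1
      // Element size now comes from the dtype rather than the Type:
      std::printf("%zu\n", t.dtype().itemsize());       // prints 4 (bytes per float)
      return 0;
    }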