From 6fe1867c2362e05ee324350cb72d1a29265f62b7 Mon Sep 17 00:00:00 2001
From: Edward Yang
Date: Thu, 29 Nov 2018 16:01:44 -0800
Subject: [PATCH] Expunge direct device index handling from
 tensor_conversion_dispatch (#14421)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14421

Last time I looked at this, I bailed because it seemed like there were a lot
of sites to fix. Well, I need this to work properly for out-of-place HIPify,
so I took another whack at it. Changes should be pretty self-explanatory.

Reviewed By: gchanan

Differential Revision: D13221302

fbshipit-source-id: ed21e2668a1a629898a47358baf368fe680263a0
---
 aten/src/ATen/core/Type.h                        | 14 ++++-
 aten/src/ATen/templates/Type.h                   | 14 ++++-
 torch/csrc/autograd/python_variable_indexing.cpp |  2 +-
 torch/csrc/utils/tensor_conversion_dispatch.cpp  | 19 ++-----
 torch/csrc/utils/tensor_conversion_dispatch.h    |  2 +-
 torch/csrc/utils/tensor_new.cpp                  | 71 +++++++++++-------------
 6 files changed, 65 insertions(+), 57 deletions(-)

diff --git a/aten/src/ATen/core/Type.h b/aten/src/ATen/core/Type.h
index dfa1b63..742b1d0 100644
--- a/aten/src/ATen/core/Type.h
+++ b/aten/src/ATen/core/Type.h
@@ -163,11 +163,23 @@ struct CAFFE2_API Type {
   /// Constructs the `TensorOptions` from a type and a `device_index`.
   TensorOptions options(int16_t device_index = -1) const {
     return TensorOptions().dtype(typeMeta())
-                          .device(backendToDeviceType(backend()), device_index)
+                          .device(device_type(), device_index)
                           .layout(layout())
                           .is_variable(is_variable());
   }
 
+  /// Constructs the `TensorOptions` from a type and a Device.  Asserts that
+  /// the device type matches the device type of the type.
+  TensorOptions options(optional<Device> device_opt) const {
+    if (!device_opt.has_value()) {
+      return options(-1);
+    } else {
+      Device device = device_opt.value();
+      AT_ASSERT(device.type() == device_type());
+      return options(device.index());
+    }
+  }
+
   operator TensorOptions() const {
     return options();
   }
diff --git a/aten/src/ATen/templates/Type.h b/aten/src/ATen/templates/Type.h
index b269a47..e731040 100644
--- a/aten/src/ATen/templates/Type.h
+++ b/aten/src/ATen/templates/Type.h
@@ -134,11 +134,23 @@ struct CAFFE2_API Type {
   /// Constructs the `TensorOptions` from a type and a `device_index`.
   TensorOptions options(int16_t device_index = -1) const {
     return TensorOptions().dtype(typeMeta())
-                          .device(backendToDeviceType(backend()), device_index)
+                          .device(device_type(), device_index)
                           .layout(layout())
                           .is_variable(is_variable());
   }
 
+  /// Constructs the `TensorOptions` from a type and a Device.  Asserts that
+  /// the device type matches the device type of the type.
+  TensorOptions options(optional<Device> device_opt) const {
+    if (!device_opt.has_value()) {
+      return options(-1);
+    } else {
+      Device device = device_opt.value();
+      AT_ASSERT(device.type() == device_type());
+      return options(device.index());
+    }
+  }
+
   operator TensorOptions() const {
     return options();
   }
diff --git a/torch/csrc/autograd/python_variable_indexing.cpp b/torch/csrc/autograd/python_variable_indexing.cpp
index e61f99d..fb9e58f 100644
--- a/torch/csrc/autograd/python_variable_indexing.cpp
+++ b/torch/csrc/autograd/python_variable_indexing.cpp
@@ -194,7 +194,7 @@ static Variable applySlicing(const Variable& self, PyObject* index, variable_lis
 
 static std::vector<Tensor> typeConvertIndices(const Variable& self, const variable_list& indices) {
   std::vector<Tensor> converted_inds(indices.size());
-  int32_t device = self.is_cuda() ? self.get_device() : -1;
+  c10::Device device = self.device();
   for (size_t i = 0; i < indices.size(); ++i) {
     const auto &ind = indices[i];
     if (ind.defined()) {
diff --git a/torch/csrc/utils/tensor_conversion_dispatch.cpp b/torch/csrc/utils/tensor_conversion_dispatch.cpp
index e55118b..615782a 100644
--- a/torch/csrc/utils/tensor_conversion_dispatch.cpp
+++ b/torch/csrc/utils/tensor_conversion_dispatch.cpp
@@ -14,18 +14,16 @@ namespace torch { namespace utils {
 at::Tensor dispatch_type_conversion(
     const at::Tensor& self,
     const at::Type& type,
-    c10::optional<int32_t> device_index,
+    c10::optional<at::Device> device_opt,
     bool non_blocking) {
+
   if (type.is_cuda()) {
     torch::utils::cuda_lazy_init();
   }
   AutoNoGIL no_gil;
 
   // TODO: Make this less CUDA specific
-  at::Device device = self.device();
-  if (device_index && *device_index != -1) {
-    device = at::Device(at::kCUDA, *device_index);
-  }
+  at::Device device = device_opt.value_or(self.device());
   at::DeviceGuard device_guard(device);
 
   if (self.device().type() == type.device_type()) {
@@ -34,19 +32,12 @@ at::Tensor dispatch_type_conversion(
       // Do nothing, there is only one CPU "device"
       // TODO: Maybe this wouldn't be true with NUMA
       break;
-    case at::DeviceType::CUDA:
-      if (self.device().index() != at::current_device()) {
+    default:
+      if (self.device() != device_guard.current_device()) {
         // copy if the devices are different even if the types are the same
         return type.copy(self, non_blocking);
       }
       break;
-    default:
-      // This assert failed because you tried to use copy() on a non-CUDA
-      // device. We couldn't figure out if this would have resulted in
-      // a cross-device copy, because at::current_device() only knows about
-      // the current CUDA device. Fix current_device to take a DeviceType
-      // and provide information for things that are not CUDA too!
-      AT_ASSERT(0);
   }
 }
 
diff --git a/torch/csrc/utils/tensor_conversion_dispatch.h b/torch/csrc/utils/tensor_conversion_dispatch.h
index 6dad49f..3c710dc 100644
--- a/torch/csrc/utils/tensor_conversion_dispatch.h
+++ b/torch/csrc/utils/tensor_conversion_dispatch.h
@@ -26,6 +26,6 @@ namespace torch { namespace utils {
 at::Tensor dispatch_type_conversion(
     const at::Tensor& self,
     const at::Type& type,
-    c10::optional<int32_t> device_index = c10::nullopt,
+    c10::optional<at::Device> device = c10::nullopt,
     bool non_blocking = false);
 }} // namespace torch::utils
diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp
index f9d6ffc..4452b08 100644
--- a/torch/csrc/utils/tensor_new.cpp
+++ b/torch/csrc/utils/tensor_new.cpp
@@ -48,28 +48,28 @@ void maybe_initialize_cuda(const Type &type) {
   }
 }
 
-Tensor dispatch_zeros(const Type& type, int32_t device_index, IntList sizes) {
+Tensor dispatch_zeros(const Type& type, optional<Device> device, IntList sizes) {
   maybe_initialize_cuda(type);
   AutoNoGIL no_gil;
-  return torch::zeros(sizes, type.options(device_index));
+  return torch::zeros(sizes, type.options(device));
 }
 
-Tensor dispatch_ones(const Type& type, int32_t device_index, IntList sizes) {
+Tensor dispatch_ones(const Type& type, optional<Device> device, IntList sizes) {
   maybe_initialize_cuda(type);
   AutoNoGIL no_gil;
-  return torch::ones(sizes, type.options(device_index));
+  return torch::ones(sizes, type.options(device));
 }
 
-Tensor dispatch_full(const Type& type, Scalar fill_value, int32_t device_index, IntList sizes) {
+Tensor dispatch_full(const Type& type, Scalar fill_value, optional<Device> device, IntList sizes) {
   maybe_initialize_cuda(type);
   AutoNoGIL no_gil;
-  return torch::full(sizes, fill_value, type.options(device_index));
+  return torch::full(sizes, fill_value, type.options(device));
 }
 
-Tensor new_with_sizes(const Type& type, int32_t device_index, IntList sizes) {
+Tensor new_with_sizes(const Type& type, optional<Device> device, IntList sizes) {
   maybe_initialize_cuda(type);
   AutoNoGIL no_gil;
-  return torch::empty(sizes, type.options(device_index));
+  return torch::empty(sizes, type.options(device));
 }
 
 Tensor new_with_storage(const Type& type, Storage storage) {
@@ -85,20 +85,16 @@ Tensor new_with_tensor(const Type& type, Tensor other) {
   return other.slice();
 }
 
-Tensor new_with_type_conversion(const Type& type, Tensor other, int32_t device_index) {
-  return dispatch_type_conversion(other, type, device_index, false);
+Tensor new_with_type_conversion(const Type& type, Tensor other, optional<Device> device) {
+  return dispatch_type_conversion(other, type, device, false);
 }
 
-Tensor new_with_tensor_copy(const Type& type, Tensor other, int32_t device_index) {
+Tensor new_with_tensor_copy(const Type& type, Tensor other, optional<Device> device) {
+  AT_ASSERT(device.has_value() ? device.value().type() == type.device_type()
+                               : other.device().type() == type.device_type());
   maybe_initialize_cuda(type);
   AutoNoGIL no_gil;
-  // TODO: It would be better if new_with_tensor_copy took an at::Device
-  // to begin with, but then we need to fix the situation with
-  // dispatch_type_conversion bleggg
-  at::OptionalDeviceGuard device_guard;
-  if (type.is_cuda()) {
-    device_guard.reset_device(at::Device(at::kCUDA, device_index));
-  }
+  at::OptionalDeviceGuard device_guard(device);
   return type.copy(other);
 }
 
@@ -206,10 +202,6 @@ Tensor internal_new_from_data(
     bool copy_variables,
     bool copy_numpy,
     bool type_inference) {
-  int32_t device_index = -1;
-  if (device_opt.has_value()) {
-    device_index = device_opt->index();
-  }
   if (THPUtils_checkString(data)) {
     throw TypeError("new(): invalid data type '%s'", Py_TYPE(data)->tp_name);
   }
@@ -224,16 +216,16 @@
         *torch::getLayout(type.backend()), type_inference_device_type);
     const auto& type_to_use = type_inference ? type_inference_type : type;
-    return copy_variables ? new_with_tensor_copy(type_to_use, var, device_index)
-                          : new_with_type_conversion(type_to_use, var, device_index);
+    return copy_variables ? new_with_tensor_copy(type_to_use, var, device_opt)
+                          : new_with_type_conversion(type_to_use, var, device_opt);
   }
 
 #ifdef USE_NUMPY
   if (PyArray_Check(data)) {
     auto tensor = autograd::make_variable(tensor_from_numpy(data), /*requires_grad=*/false);
     const auto& type_to_use = type_inference ? type.toScalarType(tensor.type().scalarType()) : type;
-    return copy_numpy ? new_with_tensor_copy(type_to_use, tensor, device_index) :
-                        new_with_type_conversion(type_to_use, tensor, device_index);
+    return copy_numpy ? new_with_tensor_copy(type_to_use, tensor, device_opt) :
+                        new_with_type_conversion(type_to_use, tensor, device_opt);
   }
 #endif
 
@@ -244,7 +236,7 @@
       (char*)tensor.data_ptr(), tensor.sizes(), tensor.strides(), 0, scalarType,
       tensor.type().elementSizeInBytes(), data);
   const auto& type_to_use = type_inference ? type.toScalarType(scalarType) : type;
-  return new_with_type_conversion(type_to_use, tensor, device_index);
+  return new_with_type_conversion(type_to_use, tensor, device_opt);
 }
 
 Tensor new_from_data_copy(
@@ -286,7 +278,7 @@ Tensor legacy_sparse_tensor_ctor(const Type& type, PyObject* args, PyObject* kwa
   if (r.idx == 0) {
     auto deviceOptional = r.deviceOptional(0);
     check_legacy_ctor_device(type, deviceOptional);
-    return at::empty({0}, type.options(r.device(0).index()));
+    return at::empty({0}, type.options(r.deviceOptional(0)));
   } else if (r.idx == 1) {
     auto cdata = reinterpret_cast(r.toInt64(0));
     return type.unsafeTensorFromTH(cdata, true);
@@ -309,7 +301,7 @@ Tensor legacy_sparse_tensor_ctor(const Type& type, PyObject* args, PyObject* kwa
       // unless the sequences is a torch.Size
       return legacy_new_from_sequence(type, deviceOptional, r.pyobject(0));
     }
-    return new_with_sizes(type, r.device(1).index(), r.intlist(0));
+    return new_with_sizes(type, r.deviceOptional(1), r.intlist(0));
   }
   throw std::runtime_error("new(): invalid arguments");
 }
@@ -355,11 +347,12 @@ Tensor legacy_sparse_tensor_new(const Type& type, PyObject* args, PyObject* kwar
       // unless the sequences is a torch.Size
       return legacy_new_from_sequence(type, deviceOptional, r.pyobject(0));
     }
-    return new_with_sizes(type, r.device(1).index(), r.intlist(0));
+    return new_with_sizes(type, r.deviceOptional(1), r.intlist(0));
  }
   throw std::runtime_error("new(): invalid arguments");
 }
 
+// NB: device_idx here is NOT a DeviceIndex, but index into PythonArgs
 const Type& typeWithDefault(PythonArgs& r, int64_t dtype_idx, int64_t device_idx, const Type& type) {
   const auto scalartype = r.scalartypeWithDefault(dtype_idx, type.scalarType());
   const Device types_device_type(type.device_type());
@@ -405,7 +398,7 @@ Tensor legacy_tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs) {
       // unless the sequences is a torch.Size
       return legacy_new_from_sequence(type, deviceOptional, r.pyobject(0));
     }
-    return new_with_sizes(type, r.device(1).index(), r.intlist(0));
+    return new_with_sizes(type, r.deviceOptional(1), r.intlist(0));
   } else if (r.idx == 5) {
     auto deviceOptional = r.deviceOptional(1);
     check_legacy_ctor_device(type, deviceOptional);
@@ -451,7 +444,7 @@ Tensor legacy_tensor_new(const Type& type, PyObject* args, PyObject* kwargs) {
       // unless the sequences is a torch.Size
       return legacy_new_from_sequence(type, deviceOptional, r.pyobject(0));
     }
-    return new_with_sizes(type, r.device(1).index(), r.intlist(0));
+    return new_with_sizes(type, r.deviceOptional(1), r.intlist(0));
   } else if (r.idx == 5) {
     auto deviceOptional = r.deviceOptional(1);
     check_legacy_ctor_device(type, deviceOptional);
@@ -480,7 +473,7 @@ Tensor sparse_coo_tensor_ctor(const Type& default_type, PyObject* args, PyObject
     bool type_inference = r.isNone(2);
     const auto& type = typeWithDefault(r, 2, 3, default_type);
     const auto& values_type = type.toDense();
-    at::DeviceGuard device_guard(r.device(3));
+    at::OptionalDeviceGuard device_guard(r.deviceOptional(3));
     // if no dtype provided, infer type based on value type.
     Tensor values = internal_new_from_data(values_type, r.deviceOptional(3), r.pyobject(1), false, true, type_inference);
     const auto& indices_type = values.type().toScalarType(kLong);
@@ -490,14 +483,14 @@ Tensor sparse_coo_tensor_ctor(const Type& default_type, PyObject* args, PyObject
     bool type_inference = r.isNone(3);
     const auto& type = typeWithDefault(r, 3, 4, default_type);
     const auto& values_type = type.toDense();
-    at::DeviceGuard device_guard(r.device(4));
+    at::OptionalDeviceGuard device_guard(r.deviceOptional(4));
     Tensor values = internal_new_from_data(values_type, r.deviceOptional(4), r.pyobject(1), false, true, type_inference);
     const auto& indices_type = values.type().toScalarType(kLong);
     Tensor indices = internal_new_from_data(indices_type, r.deviceOptional(4), r.pyobject(0), false, true, false);
     return at::sparse_coo_tensor(indices, values, r.intlist(2), values.options().layout(at::kSparse)).set_requires_grad(r.toBool(5));
   } else if (r.idx == 2) {
     const auto& type = typeWithDefault(r, 1, 2, default_type);
-    at::DeviceGuard device_guard(r.device(2));
+    at::OptionalDeviceGuard device_guard(r.deviceOptional(2));
     return at::sparse_coo_tensor(r.intlist(0), type.options().layout(at::kSparse)).set_requires_grad(r.toBool(3));
   }
   throw std::runtime_error("sparse_coo_tensor(): invalid arguments");
@@ -586,7 +579,7 @@ Tensor new_empty(const Type& type, PyObject* args, PyObject* kwargs) {
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
     const auto& actual_type = typeWithDefault(r, 1, 2, type);
-    return new_with_sizes(actual_type, r.device(2).index(), r.intlist(0)).set_requires_grad(r.toBool(3));
+    return new_with_sizes(actual_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3));
   }
   throw std::runtime_error("new_empty(): invalid arguments");
 }
@@ -600,7 +593,7 @@ Tensor new_full(const Type& type, PyObject* args, PyObject* kwargs) {
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
     const auto& actual_type = typeWithDefault(r, 2, 3, type);
-    return dispatch_full(actual_type, r.scalar(1), r.device(3).index(), r.intlist(0)).set_requires_grad(r.toBool(4));
+    return dispatch_full(actual_type, r.scalar(1), r.deviceOptional(3), r.intlist(0)).set_requires_grad(r.toBool(4));
   }
   throw std::runtime_error("new_full(): invalid arguments");
 }
@@ -614,7 +607,7 @@ Tensor new_ones(const Type& type, PyObject* args, PyObject* kwargs) {
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
     const auto& actual_type = typeWithDefault(r, 1, 2, type);
-    return dispatch_ones(actual_type, r.device(2).index(), r.intlist(0)).set_requires_grad(r.toBool(3));
+    return dispatch_ones(actual_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3));
   }
   throw std::runtime_error("new_ones(): invalid arguments");
 }
@@ -628,7 +621,7 @@ Tensor new_zeros(const Type& type, PyObject* args, PyObject* kwargs) {
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
     const auto& actual_type = typeWithDefault(r, 1, 2, type);
-    return dispatch_zeros(actual_type, r.device(2).index(), r.intlist(0)).set_requires_grad(r.toBool(3));
+    return dispatch_zeros(actual_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3));
  }
   throw std::runtime_error("new_zeros(): invalid arguments");
 }
--
2.7.4
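
For orientation only (this note and the snippet below are not part of the patch): a minimal sketch of how the new Type::options(optional<Device>) overload introduced above can be called, assuming an ATen build that already contains this change. The program is hypothetical and exists only to illustrate the overload's contract.

#include <ATen/ATen.h>

// Hypothetical caller, not from the PyTorch tree: it only illustrates the
// contract of the new overload added to aten/src/ATen/core/Type.h.
int main() {
  const at::Type& cpu_float = at::CPU(at::kFloat);

  // Passing nullopt reproduces the old options(-1) default-index behavior.
  at::Tensor a = at::empty({2, 3}, cpu_float.options(c10::nullopt));

  // Passing a concrete Device forwards its index; a Device whose type does
  // not match the Type would trip the AT_ASSERT inside options().
  at::Tensor b = at::empty({2, 3}, cpu_float.options(at::Device(at::kCPU)));

  return a.numel() == b.numel() ? 0 : 1;
}

A nullopt thus keeps the old default-index behavior, while a concrete Device is validated against the type's device type before its index is used; dispatch_type_conversion applies the same idea with device_opt.value_or(self.device()) instead of reconstructing an at::Device from a raw CUDA index.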