From 2d56df7892ff6588da3fd079ee95d679644d6730 Mon Sep 17 00:00:00 2001
From: Thomas Viehmann
Date: Tue, 4 Dec 2018 13:58:31 -0800
Subject: [PATCH] Use .to to convert new tensors in new_tensor (#14097)

Summary:
This would solve the tracing problems of #13969.
Fixes: #14732

I would appreciate it if this got good scrutiny before being applied.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14097

Differential Revision: D13323181

Pulled By: ezyang

fbshipit-source-id: dcd104b497c0bfddb751923c6166a3824b7a3702
---
 .../TestJit.test_export_tensoroption_to.expect     |  9 +--
 ...estScript.test_index_put_trace_with_view.expect | 14 +++--
 ...Script.test_index_put_trace_without_view.expect | 14 +++--
 .../autograd/templates/python_variable_methods.cpp | 22 ++++++--
 torch/CMakeLists.txt                               |  1 -
 torch/csrc/autograd/python_variable_indexing.cpp   |  9 +--
 torch/csrc/utils/tensor_conversion_dispatch.cpp    | 64 ----------------------
 torch/csrc/utils/tensor_conversion_dispatch.h      | 31 -----------
 torch/csrc/utils/tensor_new.cpp                    | 56 +++++++++----------
 9 files changed, 65 insertions(+), 155 deletions(-)
 delete mode 100644 torch/csrc/utils/tensor_conversion_dispatch.cpp
 delete mode 100644 torch/csrc/utils/tensor_conversion_dispatch.h

diff --git a/test/expect/TestJit.test_export_tensoroption_to.expect b/test/expect/TestJit.test_export_tensoroption_to.expect
index d3b9aea..f0f2841 100644
--- a/test/expect/TestJit.test_export_tensoroption_to.expect
+++ b/test/expect/TestJit.test_export_tensoroption_to.expect
@@ -6,16 +6,11 @@ ModelProto {
   GraphProto {
     name: "torch-jit-export"
     inputs: [{name: "x", type:Tensor dims: 2}]
-    outputs: [{name: "7", type:Tensor dims: 2}]
+    outputs: [{name: "2", type:Tensor dims: 2}]
     initializers: []
     nodes: [
      Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
-     Node {type: "Gather", inputs: [x,1], outputs: [2], attributes: [{ name: 'axis', type: int, value: 0}]},
-     Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
-     Node {type: "Shape", inputs: [3], outputs: [4], attributes: []},
-     Node {type: "Expand", inputs: [2,4], outputs: [5], attributes: []},
-     Node {type: "Cast", inputs: [5], outputs: [6], attributes: [{ name: 'to', type: int, value: 11}]},
-     Node {type: "Add", inputs: [6,x], outputs: [7], attributes: []}
+     Node {type: "Add", inputs: [1,x], outputs: [2], attributes: []}
     ]
   }
  opset_import: [OperatorSetIdProto { domain: }],
diff --git a/test/expect/TestScript.test_index_put_trace_with_view.expect b/test/expect/TestScript.test_index_put_trace_with_view.expect
index 236ad85..39c92a4 100644
--- a/test/expect/TestScript.test_index_put_trace_with_view.expect
+++ b/test/expect/TestScript.test_index_put_trace_with_view.expect
@@ -4,10 +4,14 @@ graph(%target : Double(100)
   %3 : int = prim::Constant[value=4]()
   %4 : int[] = prim::ListConstruct(%3)
   %5 : Double(4) = aten::view(%rhs, %4)
-  %6 : bool = prim::Constant[value=0]()
-  %indices : Long(4) = aten::_cast_Long(%indices.1, %6)
-  %8 : Tensor[] = prim::ListConstruct(%indices)
+  %6 : int = prim::Constant[value=4]()
+  %7 : int = prim::Constant[value=0]()
+  %8 : Device = prim::Constant[value="cpu"]()
   %9 : bool = prim::Constant[value=0]()
-  %10 : Double(100) = aten::index_put_(%target, %8, %5, %9)
-  return (%10);
+  %10 : bool = prim::Constant[value=0]()
+  %indices : Long(4) = aten::to(%indices.1, %6, %7, %8, %9, %10)
+  %12 : Tensor[] = prim::ListConstruct(%indices)
+  %13 : bool = prim::Constant[value=0]()
+  %14 : Double(100) = aten::index_put_(%target, %12, %5, %13)
+  return (%14);
 }
diff --git a/test/expect/TestScript.test_index_put_trace_without_view.expect b/test/expect/TestScript.test_index_put_trace_without_view.expect
index cb67d2d..dfe403f 100644
--- a/test/expect/TestScript.test_index_put_trace_without_view.expect
+++ b/test/expect/TestScript.test_index_put_trace_without_view.expect
@@ -1,10 +1,14 @@
 graph(%target : Double(100)
       %indices.1 : Long(4)
       %rhs : Double(4)) {
-  %3 : bool = prim::Constant[value=0]()
-  %indices : Long(4) = aten::_cast_Long(%indices.1, %3)
-  %5 : Tensor[] = prim::ListConstruct(%indices)
+  %3 : int = prim::Constant[value=4]()
+  %4 : int = prim::Constant[value=0]()
+  %5 : Device = prim::Constant[value="cpu"]()
   %6 : bool = prim::Constant[value=0]()
-  %7 : Double(100) = aten::index_put_(%target, %5, %rhs, %6)
-  return (%7);
+  %7 : bool = prim::Constant[value=0]()
+  %indices : Long(4) = aten::to(%indices.1, %3, %4, %5, %6, %7)
+  %9 : Tensor[] = prim::ListConstruct(%indices)
+  %10 : bool = prim::Constant[value=0]()
+  %11 : Double(100) = aten::index_put_(%target, %9, %rhs, %10)
+  return (%11);
 }
diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp
index 5104a54..08ff373 100644
--- a/tools/autograd/templates/python_variable_methods.cpp
+++ b/tools/autograd/templates/python_variable_methods.cpp
@@ -20,7 +20,6 @@
 #include "torch/csrc/utils/python_strings.h"
 #include "torch/csrc/utils/python_tuples.h"
 #include "torch/csrc/utils/tensor_apply.h"
-#include "torch/csrc/utils/tensor_conversion_dispatch.h"
 #include "torch/csrc/utils/tensor_list.h"
 #include "torch/csrc/utils/tensor_new.h"
 #include "torch/csrc/utils/tensor_numpy.h"
@@ -605,11 +604,22 @@ static PyObject * THPVariable_type(PyObject* self, PyObject* args, PyObject* kwa
   } else {
     throw TypeError("dtype must be a type, str, or dtype object");
   }
-  auto self_device_type = torch::getDeviceType(self_.type());
-  auto& type = is_dtype ? torch::getVariableType(r.scalartype(0), *torch::getLayout(self_.type().backend()), self_device_type) :
-                          torch::utils::type_from_string(type_name);
-  return THPVariable_Wrap(torch::utils::dispatch_type_conversion(
-      self_, type, c10::nullopt, r.toBool(1)));
+  ScalarType scalar_type;
+  Device device = self_.device();
+  if (is_dtype) {
+    scalar_type = r.scalartype(0);
+  } else {
+    auto& type = torch::utils::type_from_string(type_name);
+    scalar_type = type.scalarType();
+    auto device_type = backendToDeviceType(type.backend());
+    if (device_type != device.type()) {
+      device = at::Device(device_type);
+    }
+  }
+  if (device.is_cuda()) {
+    torch::utils::cuda_lazy_init();
+  }
+  return THPVariable_Wrap(dispatch_to(self_, device, scalar_type, /*non_blocking=*/ r.toBool(1), /*copy=*/ false));
   END_HANDLE_TH_ERRORS
 }
 
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index 90cc3d4..9c4a018 100644
--- a/torch/CMakeLists.txt
+++ b/torch/CMakeLists.txt
@@ -552,7 +552,6 @@ if (BUILD_PYTHON)
     ${TORCH_SRC_DIR}/csrc/utils/object_ptr.cpp
     ${TORCH_SRC_DIR}/csrc/utils/python_arg_parser.cpp
     ${TORCH_SRC_DIR}/csrc/utils/tensor_apply.cpp
-    ${TORCH_SRC_DIR}/csrc/utils/tensor_conversion_dispatch.cpp
     ${TORCH_SRC_DIR}/csrc/utils/tensor_dtypes.cpp
     ${TORCH_SRC_DIR}/csrc/utils/tensor_flatten.cpp
     ${TORCH_SRC_DIR}/csrc/utils/tensor_layouts.cpp
diff --git a/torch/csrc/autograd/python_variable_indexing.cpp b/torch/csrc/autograd/python_variable_indexing.cpp
index fb9e58f..c747a96 100644
--- a/torch/csrc/autograd/python_variable_indexing.cpp
+++ b/torch/csrc/autograd/python_variable_indexing.cpp
@@ -10,7 +10,6 @@
 #include "torch/csrc/utils/python_compat.h"
 #include "torch/csrc/utils/python_numbers.h"
 #include "torch/csrc/utils/tensor_new.h"
-#include "torch/csrc/utils/tensor_conversion_dispatch.h"
 #include "torch/csrc/jit/tracer.h"
 
 #include 
@@ -194,12 +193,10 @@ static Variable applySlicing(const Variable& self, PyObject* index, variable_lis
 
 static std::vector<Tensor> typeConvertIndices(const Variable& self, const variable_list& indices) {
   std::vector<Tensor> converted_inds(indices.size());
-  c10::Device device = self.device();
   for (size_t i = 0; i < indices.size(); ++i) {
     const auto &ind = indices[i];
     if (ind.defined()) {
-      auto& new_type = ind.type().toBackend(self.type().backend());
-      converted_inds[i] = torch::utils::dispatch_type_conversion(ind, new_type, device, false);
+      converted_inds[i] = ind.to(ind.options().device(self.device()));
     } else {
       converted_inds[i] = indices[i];
     }
@@ -208,15 +205,15 @@
 }
 
 static Variable dispatch_index(const Variable& self, const variable_list& indices) {
-  std::vector<Tensor> converted_indices = typeConvertIndices(self, indices);
   AutoNoGIL no_gil;
+  std::vector<Tensor> converted_indices = typeConvertIndices(self, indices);
   OptionalDeviceGuard device_guard(device_of(self));
   return self.index(converted_indices);
 }
 
 static Variable dispatch_index_put_(Variable& self, const variable_list& indices, const Variable& value) {
-  std::vector<Tensor> converted_indices = typeConvertIndices(self, indices);
   AutoNoGIL no_gil;
+  std::vector<Tensor> converted_indices = typeConvertIndices(self, indices);
   OptionalDeviceGuard device_guard(device_of(self));
   return self.index_put_(converted_indices, value);
 }
diff --git a/torch/csrc/utils/tensor_conversion_dispatch.cpp b/torch/csrc/utils/tensor_conversion_dispatch.cpp
deleted file mode 100644
index 615782a..0000000
--- a/torch/csrc/utils/tensor_conversion_dispatch.cpp
+++ /dev/null
@@ -1,64 +0,0 @@
-#include <Python.h>
-
-#include "tensor_conversion_dispatch.h"
"tensor_conversion_dispatch.h" - -#include "torch/csrc/utils/auto_gil.h" -#include "torch/csrc/utils/cuda_lazy_init.h" - -#include - -#include - -namespace torch { namespace utils { - -at::Tensor dispatch_type_conversion( - const at::Tensor& self, - const at::Type& type, - c10::optional device_opt, - bool non_blocking) { - - if (type.is_cuda()) { - torch::utils::cuda_lazy_init(); - } - AutoNoGIL no_gil; - - // TODO: Make this less CUDA specific - at::Device device = device_opt.value_or(self.device()); - at::DeviceGuard device_guard(device); - - if (self.device().type() == type.device_type()) { - switch (self.device().type()) { - case at::DeviceType::CPU: - // Do nothing, there is only one CPU "device" - // TODO: Maybe this wouldn't be true with NUMA - break; - default: - if (self.device() != device_guard.current_device()) { - // copy if the devices are different even if the types are the same - return type.copy(self, non_blocking); - } - break; - } - } - - // Don't specialize cross-backend copies - if (self.type().backend() != type.backend()) { - return self.toType(type, non_blocking); - } - - // Dispatch to specialized, traceable cast operators for the JIT. These - // specialized ops are ATen native and thus have the tracing mechanisms auto- - // generated, whereas the default case is not traceable since it requires a - // Type as a parameter/attribute. TODO: support Types in the JIT and remove - // this once we have that - switch (type.scalarType()) { -#define DEFINE_CAST_DISPATCH(_1, n, _2) \ - case at::ScalarType::n: { \ - return at::_cast_##n(self, non_blocking); \ - } break; - AT_FORALL_SCALAR_TYPES(DEFINE_CAST_DISPATCH) -#undef DEFINE_CAST_DISPATCH - default: { return self.toType(type, non_blocking); } break; - } -} -}} // namespace torch::utils diff --git a/torch/csrc/utils/tensor_conversion_dispatch.h b/torch/csrc/utils/tensor_conversion_dispatch.h deleted file mode 100644 index 3c710dc..0000000 --- a/torch/csrc/utils/tensor_conversion_dispatch.h +++ /dev/null @@ -1,31 +0,0 @@ -#pragma once - -// "Convert" tensor a different type and / or device - -#include - -#include - -namespace torch { namespace utils { - -// Returns a tensor with the same data as `self` and the specified type and -// device. Returns `self` unmodified if neither the type nor device change; -// otherwise a copy is made. -// -// The `device` argument is only relevant if `type` is a CUDA type. There are -// a few special cases for device: -// -// - if device is -1 then the returned tensor will be on the current device -// - if device is nullopt then the returned tensor will be on the same device -// as `self` if possible; otherwise it will be on the current device. -// -// If `non_blocking` is true, then the copy may be performed asynchronously -// w.r.t the host if `self` is a CPU tensor in pinned memory and `type` is a -// CUDA type. Note that copies between CUDA devices are always asynchronous -// w.r.t the host. 
-at::Tensor dispatch_type_conversion(
-    const at::Tensor& self,
-    const at::Type& type,
-    c10::optional<c10::Device> device = c10::nullopt,
-    bool non_blocking = false);
-}} // namespace torch::utils
diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp
index 4452b08..d6dea49 100644
--- a/torch/csrc/utils/tensor_new.cpp
+++ b/torch/csrc/utils/tensor_new.cpp
@@ -12,7 +12,6 @@
 #include "torch/csrc/utils/python_numbers.h"
 #include "torch/csrc/utils/python_scalars.h"
 #include "torch/csrc/utils/python_strings.h"
-#include "torch/csrc/utils/tensor_conversion_dispatch.h"
 #include "torch/csrc/utils/tensor_numpy.h"
 #include "torch/csrc/autograd/generated/variable_factories.h"
 
@@ -48,6 +47,12 @@ void maybe_initialize_cuda(const Type &type) {
   }
 }
 
+void maybe_initialize_cuda(const Device device) {
+  if (device.is_cuda()) {
+    torch::utils::cuda_lazy_init();
+  }
+}
+
 Tensor dispatch_zeros(const Type& type, optional<Device> device, IntList sizes) {
   maybe_initialize_cuda(type);
   AutoNoGIL no_gil;
@@ -85,19 +90,6 @@ Tensor new_with_tensor(const Type& type, Tensor other) {
   return other.slice();
 }
 
-Tensor new_with_type_conversion(const Type& type, Tensor other, optional<Device> device) {
-  return dispatch_type_conversion(other, type, device, false);
-}
-
-Tensor new_with_tensor_copy(const Type& type, Tensor other, optional<Device> device) {
-  AT_ASSERT(device.has_value() ? device.value().type() == type.device_type()
-                               : other.device().type() == type.device_type());
-  maybe_initialize_cuda(type);
-  AutoNoGIL no_gil;
-  at::OptionalDeviceGuard device_guard(device);
-  return type.copy(other);
-}
-
 std::vector<int64_t> compute_sizes(PyObject* seq) {
   std::vector<int64_t> sizes;
   THPObjectPtr handle;
@@ -208,35 +200,39 @@ Tensor internal_new_from_data(
 
   if (THPVariable_Check(data)) {
     auto var = reinterpret_cast<THPVariable*>(data)->cdata;
-    auto type_inference_device_type = device_opt.has_value() ? device_opt->type()
-                                                             : torch::getDeviceType(var.type());
+    if (copy_variables) {
+      var = var.detach();
+    }
     // infer the scalar type and device type; it's not expected to infer the layout since these constructors
     // are defined per-layout-type (e.g. tensor vs sparse_coo_tensor).
-    const auto& type_inference_type = torch::getVariableType(var.type().scalarType(),
-                                                             *torch::getLayout(type.backend()),
-                                                             type_inference_device_type);
-    const auto& type_to_use = type_inference ? type_inference_type : type;
-    return copy_variables ? new_with_tensor_copy(type_to_use, var, device_opt)
-                          : new_with_type_conversion(type_to_use, var, device_opt);
+    const auto& scalar_type = type_inference ? var.type().scalarType() : type.scalarType();
+    auto device = device_opt.has_value() ? *device_opt : (type_inference ? var.device() : at::Device(torch::getDeviceType(type)));
+    AutoNoGIL no_gil;
+    maybe_initialize_cuda(device);
+    return var.to(device, scalar_type, /*blocking=*/false, /*copy=*/copy_variables);
   }
 
 #ifdef USE_NUMPY
   if (PyArray_Check(data)) {
     auto tensor = autograd::make_variable(tensor_from_numpy(data), /*requires_grad=*/false);
-    const auto& type_to_use = type_inference ? type.toScalarType(tensor.type().scalarType()) : type;
-    return copy_numpy ? new_with_tensor_copy(type_to_use, tensor, device_opt) :
-                        new_with_type_conversion(type_to_use, tensor, device_opt);
+    const auto& scalar_type = type_inference ? tensor.type().scalarType() : type.scalarType();
+    auto device = device_opt.has_value() ? *device_opt : at::Device(type.device_type());
+    AutoNoGIL no_gil;
+    maybe_initialize_cuda(device);
+    return tensor.to(device, scalar_type, /*blocking=*/false, /*copy=*/copy_numpy);
   }
 #endif
 
   auto sizes = compute_sizes(data);
-  ScalarType scalarType = type_inference ? infer_scalar_type(data) : type.scalarType();
-  auto tensor = autograd::make_variable(at::empty(sizes, at::initialTensorOptions().dtype(scalarType)), /*requires_grad=*/false);
+  ScalarType scalar_type = type_inference ? infer_scalar_type(data) : type.scalarType();
+  auto tensor = autograd::make_variable(at::empty(sizes, at::initialTensorOptions().dtype(scalar_type)), /*requires_grad=*/false);
   recursive_store(
       (char*)tensor.data_ptr(), tensor.sizes(), tensor.strides(), 0,
-      scalarType, tensor.type().elementSizeInBytes(), data);
-  const auto& type_to_use = type_inference ? type.toScalarType(scalarType) : type;
-  return new_with_type_conversion(type_to_use, tensor, device_opt);
+      scalar_type, tensor.type().elementSizeInBytes(), data);
+  auto device = device_opt.has_value() ? *device_opt : at::Device(torch::getDeviceType(type));
+  AutoNoGIL no_gil;
+  maybe_initialize_cuda(device);
+  return tensor.to(device, scalar_type, /*blocking=*/false, /*copy=*/false);
 }
 
 Tensor new_from_data_copy(
-- 
2.7.4
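A minimal Python sketch of the traced behavior the updated expect files describe; the function name fill_at and the concrete shapes are illustrative, not taken from the patch. With this change, the index-tensor conversion recorded by the tracer should appear as an aten::to node feeding aten::index_put_, rather than the old aten::_cast_Long node.

import torch

def fill_at(target, indices, rhs):
    # Advanced-index assignment; tracing it records a dtype/device
    # conversion of `indices`, which this patch routes through .to().
    target[indices] = rhs
    return target

traced = torch.jit.trace(
    fill_at,
    (torch.zeros(100, dtype=torch.double),
     torch.arange(4),                       # int64 index tensor
     torch.ones(4, dtype=torch.double)))
print(traced.graph)  # expected to show aten::to feeding aten::index_put_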