Use .to to convert new tensors in new_tensor (#14097)
authorThomas Viehmann <tv.code@beamnet.de>
Tue, 4 Dec 2018 21:58:31 +0000 (13:58 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 4 Dec 2018 22:03:56 +0000 (14:03 -0800)
Summary:
This would solve the tracing problems of #13969.
Fixes: #14732

I would appreciate if this got good scrutiny before applied.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14097

Differential Revision: D13323181

Pulled By: ezyang

fbshipit-source-id: dcd104b497c0bfddb751923c6166a3824b7a3702

test/expect/TestJit.test_export_tensoroption_to.expect
test/expect/TestScript.test_index_put_trace_with_view.expect
test/expect/TestScript.test_index_put_trace_without_view.expect
tools/autograd/templates/python_variable_methods.cpp
torch/CMakeLists.txt
torch/csrc/autograd/python_variable_indexing.cpp
torch/csrc/utils/tensor_conversion_dispatch.cpp [deleted file]
torch/csrc/utils/tensor_conversion_dispatch.h [deleted file]
torch/csrc/utils/tensor_new.cpp

index d3b9aea..f0f2841 100644 (file)
@@ -6,16 +6,11 @@ ModelProto {
     GraphProto {
       name: "torch-jit-export"
       inputs: [{name: "x", type:Tensor dims: 2}]
-      outputs: [{name: "7", type:Tensor dims: 2}]
+      outputs: [{name: "2", type:Tensor dims: 2}]
       initializers: []
       nodes: [
         Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
-        Node {type: "Gather", inputs: [x,1], outputs: [2], attributes: [{ name: 'axis', type: int, value: 0}]},
-        Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
-        Node {type: "Shape", inputs: [3], outputs: [4], attributes: []},
-        Node {type: "Expand", inputs: [2,4], outputs: [5], attributes: []},
-        Node {type: "Cast", inputs: [5], outputs: [6], attributes: [{ name: 'to', type: int, value: 11}]},
-        Node {type: "Add", inputs: [6,x], outputs: [7], attributes: []}
+        Node {type: "Add", inputs: [1,x], outputs: [2], attributes: []}
       ]
     }
   opset_import: [OperatorSetIdProto { domain: }],
index 236ad85..39c92a4 100644 (file)
@@ -4,10 +4,14 @@ graph(%target : Double(100)
   %3 : int = prim::Constant[value=4]()
   %4 : int[] = prim::ListConstruct(%3)
   %5 : Double(4) = aten::view(%rhs, %4)
-  %6 : bool = prim::Constant[value=0]()
-  %indices : Long(4) = aten::_cast_Long(%indices.1, %6)
-  %8 : Tensor[] = prim::ListConstruct(%indices)
+  %6 : int = prim::Constant[value=4]()
+  %7 : int = prim::Constant[value=0]()
+  %8 : Device = prim::Constant[value="cpu"]()
   %9 : bool = prim::Constant[value=0]()
-  %10 : Double(100) = aten::index_put_(%target, %8, %5, %9)
-  return (%10);
+  %10 : bool = prim::Constant[value=0]()
+  %indices : Long(4) = aten::to(%indices.1, %6, %7, %8, %9, %10)
+  %12 : Tensor[] = prim::ListConstruct(%indices)
+  %13 : bool = prim::Constant[value=0]()
+  %14 : Double(100) = aten::index_put_(%target, %12, %5, %13)
+  return (%14);
 }
index cb67d2d..dfe403f 100644 (file)
@@ -1,10 +1,14 @@
 graph(%target : Double(100)
       %indices.1 : Long(4)
       %rhs : Double(4)) {
-  %3 : bool = prim::Constant[value=0]()
-  %indices : Long(4) = aten::_cast_Long(%indices.1, %3)
-  %5 : Tensor[] = prim::ListConstruct(%indices)
+  %3 : int = prim::Constant[value=4]()
+  %4 : int = prim::Constant[value=0]()
+  %5 : Device = prim::Constant[value="cpu"]()
   %6 : bool = prim::Constant[value=0]()
-  %7 : Double(100) = aten::index_put_(%target, %5, %rhs, %6)
-  return (%7);
+  %7 : bool = prim::Constant[value=0]()
+  %indices : Long(4) = aten::to(%indices.1, %3, %4, %5, %6, %7)
+  %9 : Tensor[] = prim::ListConstruct(%indices)
+  %10 : bool = prim::Constant[value=0]()
+  %11 : Double(100) = aten::index_put_(%target, %9, %rhs, %10)
+  return (%11);
 }
index 5104a54..08ff373 100644 (file)
@@ -20,7 +20,6 @@
 #include "torch/csrc/utils/python_strings.h"
 #include "torch/csrc/utils/python_tuples.h"
 #include "torch/csrc/utils/tensor_apply.h"
-#include "torch/csrc/utils/tensor_conversion_dispatch.h"
 #include "torch/csrc/utils/tensor_list.h"
 #include "torch/csrc/utils/tensor_new.h"
 #include "torch/csrc/utils/tensor_numpy.h"
@@ -605,11 +604,22 @@ static PyObject * THPVariable_type(PyObject* self, PyObject* args, PyObject* kwa
   } else {
     throw TypeError("dtype must be a type, str, or dtype object");
   }
-  auto self_device_type = torch::getDeviceType(self_.type());
-  auto& type = is_dtype ? torch::getVariableType(r.scalartype(0), *torch::getLayout(self_.type().backend()), self_device_type) :
-                          torch::utils::type_from_string(type_name);
-  return THPVariable_Wrap(torch::utils::dispatch_type_conversion(
-      self_, type, c10::nullopt, r.toBool(1)));
+  ScalarType scalar_type;
+  Device device = self_.device();
+  if (is_dtype) {
+    scalar_type = r.scalartype(0);
+  } else {
+    auto& type = torch::utils::type_from_string(type_name);
+    scalar_type = type.scalarType();
+    auto device_type = backendToDeviceType(type.backend());
+    if (device_type != device.type()) {
+      device = at::Device(device_type);
+    }
+  }
+  if (device.is_cuda()) {
+    torch::utils::cuda_lazy_init();
+  }
+  return THPVariable_Wrap(dispatch_to(self_, device, scalar_type, /*non_blocking=*/ r.toBool(1), /*copy=*/ false));
   END_HANDLE_TH_ERRORS
 }
 
index 90cc3d4..9c4a018 100644 (file)
@@ -552,7 +552,6 @@ if (BUILD_PYTHON)
     ${TORCH_SRC_DIR}/csrc/utils/object_ptr.cpp
     ${TORCH_SRC_DIR}/csrc/utils/python_arg_parser.cpp
     ${TORCH_SRC_DIR}/csrc/utils/tensor_apply.cpp
-    ${TORCH_SRC_DIR}/csrc/utils/tensor_conversion_dispatch.cpp
     ${TORCH_SRC_DIR}/csrc/utils/tensor_dtypes.cpp
     ${TORCH_SRC_DIR}/csrc/utils/tensor_flatten.cpp
     ${TORCH_SRC_DIR}/csrc/utils/tensor_layouts.cpp
index fb9e58f..c747a96 100644 (file)
@@ -10,7 +10,6 @@
 #include "torch/csrc/utils/python_compat.h"
 #include "torch/csrc/utils/python_numbers.h"
 #include "torch/csrc/utils/tensor_new.h"
-#include "torch/csrc/utils/tensor_conversion_dispatch.h"
 #include "torch/csrc/jit/tracer.h"
 
 #include <ATen/DeviceGuard.h>
@@ -194,12 +193,10 @@ static Variable applySlicing(const Variable& self, PyObject* index, variable_lis
 
 static std::vector<Tensor> typeConvertIndices(const Variable& self, const variable_list& indices) {
   std::vector<Tensor> converted_inds(indices.size());
-  c10::Device device = self.device();
   for (size_t i = 0; i < indices.size(); ++i) {
     const auto &ind = indices[i];
     if (ind.defined()) {
-      auto& new_type = ind.type().toBackend(self.type().backend());
-      converted_inds[i] = torch::utils::dispatch_type_conversion(ind, new_type, device, false);
+      converted_inds[i] = ind.to(ind.options().device(self.device()));
     } else {
       converted_inds[i] = indices[i];
     }
@@ -208,15 +205,15 @@ static std::vector<Tensor> typeConvertIndices(const Variable& self, const variab
 }
 
 static Variable dispatch_index(const Variable& self, const variable_list& indices) {
-  std::vector<Tensor> converted_indices = typeConvertIndices(self, indices);
   AutoNoGIL no_gil;
+  std::vector<Tensor> converted_indices = typeConvertIndices(self, indices);
   OptionalDeviceGuard device_guard(device_of(self));
   return self.index(converted_indices);
 }
 
 static Variable dispatch_index_put_(Variable& self, const variable_list& indices, const Variable& value) {
-  std::vector<Tensor> converted_indices = typeConvertIndices(self, indices);
   AutoNoGIL no_gil;
+  std::vector<Tensor> converted_indices = typeConvertIndices(self, indices);
   OptionalDeviceGuard device_guard(device_of(self));
   return self.index_put_(converted_indices, value);
 }
diff --git a/torch/csrc/utils/tensor_conversion_dispatch.cpp b/torch/csrc/utils/tensor_conversion_dispatch.cpp
deleted file mode 100644 (file)
index 615782a..0000000
+++ /dev/null
@@ -1,64 +0,0 @@
-#include <Python.h>
-
-#include "tensor_conversion_dispatch.h"
-
-#include "torch/csrc/utils/auto_gil.h"
-#include "torch/csrc/utils/cuda_lazy_init.h"
-
-#include <ATen/DeviceGuard.h>
-
-#include <cstddef>
-
-namespace torch { namespace utils {
-
-at::Tensor dispatch_type_conversion(
-    const at::Tensor& self,
-    const at::Type& type,
-    c10::optional<c10::Device> device_opt,
-    bool non_blocking) {
-
-  if (type.is_cuda()) {
-    torch::utils::cuda_lazy_init();
-  }
-  AutoNoGIL no_gil;
-
-  // TODO: Make this less CUDA specific
-  at::Device device = device_opt.value_or(self.device());
-  at::DeviceGuard device_guard(device);
-
-  if (self.device().type() == type.device_type()) {
-    switch (self.device().type()) {
-      case at::DeviceType::CPU:
-        // Do nothing, there is only one CPU "device"
-        // TODO: Maybe this wouldn't be true with NUMA
-        break;
-      default:
-        if (self.device() != device_guard.current_device()) {
-          // copy if the devices are different even if the types are the same
-          return type.copy(self, non_blocking);
-        }
-        break;
-    }
-  }
-
-  // Don't specialize cross-backend copies
-  if (self.type().backend() != type.backend()) {
-    return self.toType(type, non_blocking);
-  }
-
-  // Dispatch to specialized, traceable cast operators for the JIT. These
-  // specialized ops are ATen native and thus have the tracing mechanisms auto-
-  // generated, whereas the default case is not traceable since it requires a
-  // Type as a parameter/attribute. TODO: support Types in the JIT and remove
-  // this once we have that
-  switch (type.scalarType()) {
-#define DEFINE_CAST_DISPATCH(_1, n, _2)   \
-  case at::ScalarType::n: {               \
-    return at::_cast_##n(self, non_blocking); \
-  } break;
-    AT_FORALL_SCALAR_TYPES(DEFINE_CAST_DISPATCH)
-#undef DEFINE_CAST_DISPATCH
-    default: { return self.toType(type, non_blocking); } break;
-  }
-}
-}} // namespace torch::utils
diff --git a/torch/csrc/utils/tensor_conversion_dispatch.h b/torch/csrc/utils/tensor_conversion_dispatch.h
deleted file mode 100644 (file)
index 3c710dc..0000000
+++ /dev/null
@@ -1,31 +0,0 @@
-#pragma once
-
-// "Convert" tensor a different type and / or device
-
-#include <ATen/ATen.h>
-
-#include <cstddef>
-
-namespace torch { namespace utils {
-
-// Returns a tensor with the same data as `self` and the specified type and
-// device. Returns `self` unmodified if neither the type nor device change;
-// otherwise a copy is made.
-//
-// The `device` argument is only relevant if `type` is a CUDA type. There are
-// a few special cases for device:
-//
-//  - if device is -1 then the returned tensor will be on the current device
-//  - if device is nullopt then the returned tensor will be on the same device
-//    as `self` if possible; otherwise it will be on the current device.
-//
-// If `non_blocking` is true, then the copy may be performed asynchronously
-// w.r.t the host if `self` is a CPU tensor in pinned memory and `type` is a
-// CUDA type. Note that copies between CUDA devices are always asynchronous
-// w.r.t the host.
-at::Tensor dispatch_type_conversion(
-    const at::Tensor& self,
-    const at::Type& type,
-    c10::optional<c10::Device> device = c10::nullopt,
-    bool non_blocking = false);
-}} // namespace torch::utils
index 4452b08..d6dea49 100644 (file)
@@ -12,7 +12,6 @@
 #include "torch/csrc/utils/python_numbers.h"
 #include "torch/csrc/utils/python_scalars.h"
 #include "torch/csrc/utils/python_strings.h"
-#include "torch/csrc/utils/tensor_conversion_dispatch.h"
 #include "torch/csrc/utils/tensor_numpy.h"
 #include "torch/csrc/autograd/generated/variable_factories.h"
 
@@ -48,6 +47,12 @@ void maybe_initialize_cuda(const Type &type) {
   }
 }
 
+void maybe_initialize_cuda(const Device device) {
+  if (device.is_cuda()) {
+    torch::utils::cuda_lazy_init();
+  }
+}
+
 Tensor dispatch_zeros(const Type& type, optional<Device> device, IntList sizes) {
   maybe_initialize_cuda(type);
   AutoNoGIL no_gil;
@@ -85,19 +90,6 @@ Tensor new_with_tensor(const Type& type, Tensor other) {
   return other.slice();
 }
 
-Tensor new_with_type_conversion(const Type& type, Tensor other, optional<Device> device) {
-  return dispatch_type_conversion(other, type, device, false);
-}
-
-Tensor new_with_tensor_copy(const Type& type, Tensor other, optional<Device> device) {
-  AT_ASSERT(device.has_value() ? device.value().type() == type.device_type()
-                               : other.device().type() == type.device_type());
-  maybe_initialize_cuda(type);
-  AutoNoGIL no_gil;
-  at::OptionalDeviceGuard device_guard(device);
-  return type.copy(other);
-}
-
 std::vector<int64_t> compute_sizes(PyObject* seq) {
   std::vector<int64_t> sizes;
   THPObjectPtr handle;
@@ -208,35 +200,39 @@ Tensor internal_new_from_data(
 
   if (THPVariable_Check(data)) {
     auto var = reinterpret_cast<THPVariable*>(data)->cdata;
-    auto type_inference_device_type = device_opt.has_value() ? device_opt->type()
-                                                             : torch::getDeviceType(var.type());
+    if (copy_variables) {
+      var = var.detach();
+    }
     // infer the scalar type and device type; it's not expected to infer the layout since these constructors
     // are defined per-layout-type (e.g. tensor vs sparse_coo_tensor).
-    const auto& type_inference_type = torch::getVariableType(var.type().scalarType(),
-                                                     *torch::getLayout(type.backend()),
-                                                     type_inference_device_type);
-    const auto& type_to_use = type_inference ? type_inference_type : type;
-    return copy_variables ? new_with_tensor_copy(type_to_use, var, device_opt)
-                          : new_with_type_conversion(type_to_use, var, device_opt);
+    const auto& scalar_type = type_inference ? var.type().scalarType() : type.scalarType();
+    auto device = device_opt.has_value() ? *device_opt : (type_inference ? var.device() : at::Device(torch::getDeviceType(type)));
+    AutoNoGIL no_gil;
+    maybe_initialize_cuda(device);
+    return var.to(device, scalar_type, /*blocking=*/false, /*copy=*/copy_variables);
   }
 
 #ifdef USE_NUMPY
   if (PyArray_Check(data)) {
     auto tensor = autograd::make_variable(tensor_from_numpy(data), /*requires_grad=*/false);
-    const auto& type_to_use = type_inference ? type.toScalarType(tensor.type().scalarType()) : type;
-    return copy_numpy ? new_with_tensor_copy(type_to_use, tensor, device_opt) :
-                        new_with_type_conversion(type_to_use, tensor, device_opt);
+    const auto& scalar_type = type_inference ? tensor.type().scalarType() : type.scalarType();
+    auto device = device_opt.has_value() ? *device_opt : at::Device(type.device_type());
+    AutoNoGIL no_gil;
+    maybe_initialize_cuda(device);
+    return tensor.to(device, scalar_type, /*blocking=*/false, /*copy=*/copy_numpy);
   }
 #endif
 
   auto sizes = compute_sizes(data);
-  ScalarType scalarType = type_inference ? infer_scalar_type(data) : type.scalarType();
-  auto tensor = autograd::make_variable(at::empty(sizes, at::initialTensorOptions().dtype(scalarType)), /*requires_grad=*/false);
+  ScalarType scalar_type = type_inference ? infer_scalar_type(data) : type.scalarType();
+  auto tensor = autograd::make_variable(at::empty(sizes, at::initialTensorOptions().dtype(scalar_type)), /*requires_grad=*/false);
   recursive_store(
       (char*)tensor.data_ptr(), tensor.sizes(), tensor.strides(), 0,
-      scalarType, tensor.type().elementSizeInBytes(), data);
-  const auto& type_to_use = type_inference ? type.toScalarType(scalarType) : type;
-  return new_with_type_conversion(type_to_use, tensor, device_opt);
+      scalar_type, tensor.type().elementSizeInBytes(), data);
+  auto device = device_opt.has_value() ? *device_opt : at::Device(torch::getDeviceType(type));
+  AutoNoGIL no_gil;
+  maybe_initialize_cuda(device);
+  return tensor.to(device, scalar_type, /*blocking=*/false, /*copy=*/false);
 }
 
 Tensor new_from_data_copy(