Remove copy and copy_ special case on Type (#18972)
authorRoy Li <royboy@fb.com>
Thu, 18 Apr 2019 07:18:35 +0000 (00:18 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 18 Apr 2019 07:21:43 +0000 (00:21 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18972
ghimport-source-id: b5d3012b00530145fa24ab0cab693a7e80cb5989

Differential Revision: D14816530

Pulled By: li-roy

fbshipit-source-id: 9c7a166abb22d2cd1f81f352e44d9df1541b1774

21 files changed:
aten/src/ATen/SparseTensorUtils.h
aten/src/ATen/core/DeprecatedTypeProperties.cpp
aten/src/ATen/core/Tensor.h
aten/src/ATen/core/TensorMethods.h
aten/src/ATen/core/Type.h
aten/src/ATen/native/Copy.cpp
aten/src/ATen/native/TensorConversions.cpp
aten/src/ATen/native/TensorFactories.cpp
aten/src/ATen/native/native_functions.yaml
aten/src/ATen/templates/Tensor.h
aten/src/ATen/templates/TensorMethods.h
aten/src/ATen/templates/Type.h
aten/src/ATen/templates/TypeDefault.cpp
aten/src/ATen/templates/TypeDefault.h
aten/src/ATen/test/scalar_test.cpp
tools/autograd/gen_variable_type.py
tools/autograd/templates/python_variable_methods.cpp
tools/pyi/gen_pyi.py
torch/csrc/autograd/VariableTypeManual.cpp
torch/csrc/autograd/functions/tensor.cpp
torch/csrc/cuda/comm.cpp

index f602f09..1a0669a 100644 (file)
@@ -33,8 +33,8 @@ inline void alias_into_sparse(const SparseTensor& self, const LongTensor& indice
 inline void copy_into_sparse(const SparseTensor& self, const LongTensor& indices, const Tensor& values, bool non_blocking) {
   alias_into_sparse(
       self,
-      self._indices().dispatch_type().copy(indices, non_blocking),
-      self._values().dispatch_type().copy(values, non_blocking));
+      indices.to(self._indices().options(), non_blocking, /*copy*/true),
+      values.to(self._values().options(), non_blocking, /*copy*/true));
 }
 
 // TODO: put this into the public API
index b634a50..a762bb0 100644 (file)
@@ -10,7 +10,10 @@ Tensor DeprecatedTypeProperties::unsafeTensorFromTH(void * th_pointer, bool reta
 }
 
 Tensor DeprecatedTypeProperties::copy(const Tensor & src, bool non_blocking, c10::optional<Device> to_device) const {
-  return getDispatchType().copy(src, non_blocking, to_device);
+  if (to_device) {
+    return src.to(src.options().dtype(scalarType()).device(to_device), non_blocking, /*copy*/true);
+  }
+  return src.to(src.options().dtype(scalarType()), non_blocking, /*copy*/true);
 }
 
 std::unique_ptr<Generator> DeprecatedTypeProperties::generator() const {
index 68977a8..e6552f2 100644 (file)
@@ -209,7 +209,6 @@ class CAFFE2_API Tensor {
   bool is_alias_of(const at::Tensor& other) const{
     return impl_->storage().is_alias_of(other.storage());
   }
-  Tensor & copy_(const Tensor & src, bool non_blocking=false);
   Tensor toType(const DeprecatedTypeProperties & t, bool non_blocking=false) const;
   Tensor toType(ScalarType t) const;
   Tensor toBackend(Backend b) const;
@@ -368,6 +367,7 @@ class CAFFE2_API Tensor {
   Tensor clamp_min(Scalar min) const;
   Tensor & clamp_min_(Scalar min);
   Tensor contiguous() const;
+  Tensor & copy_(const Tensor & src, bool non_blocking=false);
   Tensor cos() const;
   Tensor & cos_();
   Tensor cosh() const;
index a7ff82c..c5eda66 100644 (file)
@@ -31,10 +31,6 @@ inline Tensor Tensor::hip() const {
   return toType(type().hip());
 }
 
-inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) {
-  return dispatch_type().copy_(*this, src, non_blocking);
-}
-
 inline Tensor Tensor::toType(ScalarType t) const {
   return toType(type().toScalarType(t));
 }
@@ -185,6 +181,9 @@ inline Tensor & Tensor::clamp_min_(Scalar min) {
 inline Tensor Tensor::contiguous() const {
     return dispatch_type().contiguous(*this);
 }
+inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) {
+    return dispatch_type().copy_(*this, src, non_blocking);
+}
 inline Tensor Tensor::cos() const {
     return dispatch_type().cos(*this);
 }
index 85b5387..d86877c 100644 (file)
@@ -164,12 +164,6 @@ struct CAFFE2_API Type {
     return backendToDeviceType(backend());
   }
 
-  virtual Tensor copy(
-      const Tensor& src,
-      bool non_blocking = false,
-      c10::optional<Device> to_device = {}) const = 0;
-  virtual Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking=false) const = 0;
-
   virtual void backward(
       Tensor& self,
       c10::optional<Tensor> gradient,
@@ -251,6 +245,7 @@ struct CAFFE2_API Type {
   virtual Tensor clamp_min(const Tensor & self, Scalar min) const = 0;
   virtual Tensor & clamp_min_(Tensor & self, Scalar min) const = 0;
   virtual Tensor contiguous(const Tensor & self) const = 0;
+  virtual Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking) const = 0;
   virtual Tensor cos(const Tensor & self) const = 0;
   virtual Tensor & cos_(Tensor & self) const = 0;
   virtual Tensor cosh(const Tensor & self) const = 0;
index 848108a..9f6e303 100644 (file)
@@ -3,6 +3,7 @@
 #include <ATen/ATen.h>
 #include <ATen/CPUApplyUtils.h>
 #include <ATen/Dispatch.h>
+#include <ATen/ExpandUtils.h>
 #include <ATen/NativeFunctions.h>
 #include <ATen/native/cpu/CopyKernel.h>
 
@@ -37,6 +38,19 @@ bool copy_transpose_valid(const at::Tensor& self, const at::Tensor& src) {
 namespace at {
 namespace native {
 
+Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking) {
+  Tensor b_src;
+  if (self.is_sparse() && src.is_sparse()) {
+    return at::copy_sparse_to_sparse_(self, src, non_blocking);
+  }
+  if (!self.is_sparse() && !src.is_sparse()) {
+    std::tie(b_src) = expand_inplace(self, src, "copy");
+    return s_copy_(self, b_src, non_blocking);
+  }
+  AT_ERROR("copy_() between dense and sparse Tensors is not implemented! Found self type = ",
+           self.type(), " and src type = ", src.type());
+}
+
 Tensor& _s_copy__cpu(Tensor& self, const Tensor& src, bool non_blocking) {
   if (src.type_id() != CPUTensorId()) {
     _s_copy_from(src, self, non_blocking);
index ec4c6a1..a4006f9 100644 (file)
@@ -20,8 +20,9 @@ static inline Device ensure_has_index(Device device) {
 }
 
 static inline Tensor to_impl(const Tensor& self, const TensorOptions& options, bool non_blocking) {
-  return self.dispatch_type().toBackend(options.backend()).toScalarType(typeMetaToScalarType(options.dtype()))
-                             .copy(self, non_blocking, options.device());
+  auto r = at::empty(self.sizes(), options);
+  r.copy_(self, non_blocking);
+  return r;
 }
 
 Tensor to(const Tensor& self, const TensorOptions& options, bool non_blocking, bool copy) {
index d08676d..c7891ed 100644 (file)
@@ -140,12 +140,11 @@ Tensor& empty_out(Tensor& result, IntArrayRef size) {
 // specialized operators for each datatype.
 // TODO: remove when we have Type support in the IR
 
-#define DEFINE_CAST_OP(_1, n, _2)                                         \
-  Tensor _cast_##n(const Tensor& self, bool non_blocking) {               \
-    auto& target_type = self.dispatch_type().toScalarType(ScalarType::n); \
-    if (self.dispatch_type() == target_type)                              \
-      return self;                                                        \
-    return target_type.copy(self, non_blocking);                          \
+#define DEFINE_CAST_OP(_1, n, _2)                                \
+  Tensor _cast_##n(const Tensor& self, bool non_blocking) {      \
+    if (self.scalar_type() == ScalarType::n)                     \
+      return self;                                               \
+    return self.to(ScalarType::n, non_blocking);                 \
   }
 
 AT_FORALL_SCALAR_TYPES_AND_BOOL_EXCEPT_QINT(DEFINE_CAST_OP)
index 0c5b24e..7c1ec4b 100644 (file)
 
 - func: conv_transpose3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] output_padding=0, int groups=1, int[3] dilation=1) -> Tensor
 
+- func: copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)
+  matches_jit_signature: True
+  variants: function, method
+  device_guard: False
+
 - func: s_copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)
   cpu_half: True
   cpu_bool: True
index 8824c26..3944676 100644 (file)
@@ -209,7 +209,6 @@ class CAFFE2_API Tensor {
   bool is_alias_of(const at::Tensor& other) const{
     return impl_->storage().is_alias_of(other.storage());
   }
-  Tensor & copy_(const Tensor & src, bool non_blocking=false);
   Tensor toType(const DeprecatedTypeProperties & t, bool non_blocking=false) const;
   Tensor toType(ScalarType t) const;
   Tensor toBackend(Backend b) const;
index 365b80c..218e2f3 100644 (file)
@@ -31,10 +31,6 @@ inline Tensor Tensor::hip() const {
   return toType(type().hip());
 }
 
-inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) {
-  return dispatch_type().copy_(*this, src, non_blocking);
-}
-
 inline Tensor Tensor::toType(ScalarType t) const {
   return toType(type().toScalarType(t));
 }
index 65c3cb1..14900ac 100644 (file)
@@ -105,12 +105,6 @@ struct CAFFE2_API Type {
     return backendToDeviceType(backend());
   }
 
-  virtual Tensor copy(
-      const Tensor& src,
-      bool non_blocking = false,
-      c10::optional<Device> to_device = {}) const = 0;
-  virtual Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking=false) const = 0;
-
   virtual void backward(
       Tensor& self,
       c10::optional<Tensor> gradient,
index 94b4553..f580967 100644 (file)
 
 namespace at {
 
-Tensor & TypeDefault::copy_(Tensor & self, const Tensor & src, bool non_blocking) const {
-  Tensor b_src;
-  if (is_sparse()) {
-    b_src = src;
-  } else {
-    std::tie(b_src) = expand_inplace(self, src, "copy");
-  }
-  return s_copy_(self, b_src, non_blocking);
-}
-
-Tensor TypeDefault::copy(const Tensor & src, bool non_blocking, optional<Device> to_device) const {
-  AT_CHECK(src.defined(), "attempt to copy an undefined tensor");
-  Tensor r;
-  if (is_sparse()) {
-    r = at::empty({0}, this->options(to_device));
-  } else {
-    r = at::empty(src.sizes(), this->options(to_device));
-  }
-  r.copy_(src, non_blocking);
-  return r;
-}
-
 void TypeDefault::backward(
     Tensor& self,
     c10::optional<Tensor> gradient,
index 5235190..e5109bf 100644 (file)
@@ -31,9 +31,6 @@ struct CAFFE2_API TypeDefault : public TypeExtendedInterface {
   Type & toBackend(Backend b) const override;
   Type & toScalarType(ScalarType s) const override;
 
-  Tensor copy(const Tensor & src, bool non_blocking=false, optional<Device> to_device={}) const override;
-  Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking=false) const override;
-
   void backward(
       Tensor& self,
       c10::optional<Tensor> gradient,
index 08b4bdf..8470a2f 100644 (file)
@@ -90,8 +90,8 @@ TEST(TestScalar, TestScalar) {
   test_overflow();
 
   if (at::hasCUDA()) {
-    auto r = CUDA(Float).copy(next_h);
-    ASSERT_TRUE(CPU(Float).copy(r).equal(next_h));
+    auto r = next_h.to(at::Device(kCUDA), kFloat, /*non_blocking*/ false, /*copy*/ true);
+    ASSERT_TRUE(r.to(at::Device(kCPU), kFloat, /*non_blocking*/ false, /*copy*/ true).equal(next_h));
   }
   ASSERT_NO_THROW(randn({10, 10, 2}, options));
 
index dd076b9..ac7a041 100644 (file)
@@ -30,7 +30,7 @@ from .gen_autograd_functions import uses_single_grad
 
 # These functions are written manually in templates/VariableType.cpp
 MANUAL_IMPLEMENTATIONS = {
-    'resize_', 'resize_as_', 'detach', 'detach_', 's_copy_', '_s_copy_from'
+    'resize_', 'resize_as_', 'detach', 'detach_', 'copy_'
 }
 
 # These functions we don't want to record for tracing, because we always want
index bdc2cf8..f617179 100644 (file)
@@ -172,26 +172,6 @@ static Tensor dispatch_contiguous(const Tensor & self) {
   END_HANDLE_TH_ERRORS
 }
 
-static Tensor dispatch_copy_(Tensor & self, const Tensor & other, bool non_blocking) {
-  AutoNoGIL no_gil;
-  OptionalDeviceGuard device_guard(device_of(self));
-  return self.copy_(other, non_blocking);
-}
-
-static PyObject * THPVariable_copy_(PyObject* self, PyObject* args, PyObject* kwargs)
-{
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser({
-    "copy_(Tensor other, bool non_blocking=False)",
-    "copy_(Tensor other, bool async=False)|deprecated"
-  });
-  auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
-  ParsedArgs<2> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  return THPVariable_Wrap(dispatch_copy_(self_, r.tensor(0), r.toBool(1)));
-  END_HANDLE_TH_ERRORS
-}
-
 static double dispatch_to_CDouble(const Tensor & self) {
   AutoNoGIL no_gil;
   OptionalDeviceGuard device_guard(device_of(self));
@@ -687,7 +667,6 @@ PyMethodDef variable_methods[] = {
   {"byte", (PyCFunction)THPVariable_byte, METH_NOARGS, NULL},
   {"char", (PyCFunction)THPVariable_char, METH_NOARGS, NULL},
   {"contiguous", (PyCFunction)THPVariable_contiguous, METH_NOARGS, NULL},
-  {"copy_", (PyCFunction)THPVariable_copy_, METH_VARARGS | METH_KEYWORDS, NULL},
   {"cpu", (PyCFunction)THPVariable_cpu, METH_NOARGS, NULL},
   {"cuda", (PyCFunction)THPVariable_cuda, METH_VARARGS | METH_KEYWORDS, NULL},
   {"dim", (PyCFunction)THPVariable_dim, METH_NOARGS, NULL},
index 5c7eb4a..f3ddc10 100644 (file)
@@ -425,7 +425,6 @@ def gen_pyi(declarations_path, out):
         'numpy': ['def numpy(self) -> Any: ...'],
         'apply_': ['def apply_(self, callable: Callable) -> Tensor: ...'],
         'map_': ['def map_(tensor: Tensor, callable: Callable) -> Tensor: ...'],
-        'copy_': ['def copy_(self, src: Tensor, non_blocking: bool=False) -> Tensor: ...'],
         'storage': ['def storage(self) -> Storage: ...'],
         'type': ['def type(self, dtype: Union[None, str, _dtype]=None, non_blocking: bool=False)'
                  ' -> Union[str, Tensor]: ...'],
index 2a7414c..e3d0af5 100644 (file)
@@ -229,7 +229,8 @@ void VariableType::backward(
 void VariableType::set_data(Tensor & self, Tensor new_data) const {
   as_variable_ref(self).set_data(new_data);
 }
-Tensor & VariableType::s_copy_(Tensor & self, const Tensor & src, bool non_blocking) const {
+
+Tensor & VariableType::copy_(Tensor & self, const Tensor & src, bool non_blocking) const {
   jit::Value* output = nullptr;
   if(torch::jit::tracer::isTracing()) {
     const jit::tracer::TracingState& state = *jit::tracer::getTracingState();
@@ -265,9 +266,7 @@ Tensor & VariableType::s_copy_(Tensor & self, const Tensor & src, bool non_block
   }
   {
     at::AutoNonVariableTypeMode non_var_type_mode(true);
-    if (self.is_sparse() && src.is_sparse()) baseType->copy_sparse_to_sparse_(self_, src_, non_blocking);
-    else if (!self.is_sparse() && !src.is_sparse()) baseType->s_copy_(self_, src_, non_blocking);
-    else AT_ERROR("copy_() between dense and sparse Tensors is not implemented! Found self type = ", self.type(), " and src type = ", src.type());
+    baseType->copy_(self_, src_, non_blocking);
   }
   increment_version(self);
   rebase_history(as_variable_ref( self ), std::move(grad_fn));
@@ -277,10 +276,6 @@ Tensor & VariableType::s_copy_(Tensor & self, const Tensor & src, bool non_block
   return self;
 }
 
-Tensor VariableType::_s_copy_from(const Tensor & self, const Tensor & dst, bool non_blocking) const {
-  AT_ERROR("copy_from does not support automatic differentiation; use copy_ instead");
-}
-
 Tensor & VariableType::resize_(Tensor & self, IntArrayRef size) const {
   auto& self_ = unpack(self, "self", 0);
   if (as_variable_ref(self).requires_grad()) {
index af4f4ec..fe2bea3 100644 (file)
@@ -27,7 +27,11 @@ auto CopyBackwards::apply(variable_list&& grads) -> variable_list {
     // TODO: What if !grad.is_cuda(), but src_device is CUDA?
     // This code is kind of weirdly asymmetric.
     if (grad.is_cuda() && grad.device() != src_device) {
-      grad_inputs[1] = src_type->copy(grad);
+      grad_inputs[1] = grad.to(
+          src_type->device_type(),
+          src_type->scalarType(),
+          /*non_blocking*/false,
+          /*copy*/true);
     } else {
       grad_inputs[1] = grad.toType(*src_type);
     }
index c1f1b43..53faa6b 100644 (file)
@@ -66,14 +66,17 @@ std::vector<Tensor> broadcast(const Tensor& tensor, IntArrayRef devices) {
 #else
   {
 #endif
-    auto & gpu_type = type.toBackend(type.is_sparse() ? at::Backend::SparseCUDA : at::Backend::CUDA);
     if (type.is_cuda()) {
       tensors.push_back(tensor);
     }
     IntArrayRef loop_devices = type.is_cuda() ? devices.slice(1) : devices;
     for (auto device : loop_devices) {
       _device_guard.set_index(device);
-      tensors.push_back(gpu_type.copy(tensor, true));
+      tensors.push_back(tensor.to(
+          kCUDA,
+          type.scalarType(),
+          /*non_blocking*/true,
+          /*copy*/true));
     }
   }
   return tensors;