Add at::scalar_tensor factory function, use it instead of Type.scalar… (#15074)
authorGregory Chanan <gchanan@fb.com>
Wed, 12 Dec 2018 04:35:37 +0000 (20:35 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 12 Dec 2018 04:37:41 +0000 (20:37 -0800)
Summary:
…_tensor.

This is part of a long series of paring down the Type interface.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15074

Differential Revision: D13421482

Pulled By: gchanan

fbshipit-source-id: 84010ee71fef2cb74d32d5de7858d8ed9f36b885

15 files changed:
aten/src/ATen/ScalarOps.h
aten/src/ATen/core/Type.h
aten/src/ATen/function_wrapper.py
aten/src/ATen/native/BinaryOps.cpp
aten/src/ATen/native/ReduceOps.cpp
aten/src/ATen/native/ReduceOpsUtils.h
aten/src/ATen/native/TensorFactories.cpp
aten/src/ATen/native/cuda/TensorFactories.cu
aten/src/ATen/native/native_functions.yaml
aten/src/ATen/native/sparse/SparseTensorMath.cpp
aten/src/ATen/templates/Type.h
aten/src/ATen/templates/TypeDefault.cpp
aten/src/ATen/templates/TypeDefault.h
aten/src/ATen/test/basic.cpp
torch/csrc/autograd/python_variable_indexing.cpp

index d98f136..70f6163 100644 (file)
@@ -10,10 +10,10 @@ namespace c10 {
 // to implement this without going through Derived Types (which are not part of core).
 inline at::Tensor scalar_to_tensor(Scalar s) {
   if (s.isFloatingPoint()) {
-    return at::CPU(kDouble).scalarTensor(s);
+    return at::scalar_tensor(s, at::CPU(kDouble).options());
   } else {
     AT_ASSERT(s.isIntegral());
-    return at::CPU(kLong).scalarTensor(s);
+    return at::scalar_tensor(s, at::CPU(kLong).options());
   }
 }
 
index afeba43..cf6c2e3 100644 (file)
@@ -154,7 +154,6 @@ struct CAFFE2_API Type {
   virtual Tensor tensorFromBlob(void * data, IntList sizes, IntList strides, const std::function<void(void*)> & deleter=noop_deleter) const = 0;
   virtual Tensor tensorWithAllocator(IntList sizes, Allocator* allocator) const = 0;
   virtual Tensor tensorWithAllocator(IntList sizes, IntList strides, Allocator* allocator) const = 0;
-  virtual Tensor scalarTensor(Scalar s) const = 0;
 
   bool operator==(const Type& other) const {
     return this == &other;
index 6930182..fd6a377 100644 (file)
@@ -1518,9 +1518,9 @@ def create_derived(backend_type_env, declarations):
                     env, wrapped_tensor=wrapped_tensor, maybe_scalar=maybe_scalar))
             # return the same underlying Tensor type for both real and accreal; this ensures
             # e.g. x.sum(0) and x.sum() return the same type. We explicitly cast to the
-            # ScalarType before constructing the scalarTensor to avoid overflow checking.
+            # ScalarType before constructing the scalar_tensor to avoid overflow checking.
             elif ret['type'] == 'accreal' or ret['type'] == 'real':
-                return_scalar = 'return scalarTensor(convert<${ScalarType}>(${call}));'
+                return_scalar = 'return at::scalar_tensor(convert<${ScalarType}>(${call}), options());'
                 body.append(CodeTemplate(return_scalar).substitute(env, call=call))
             else:
                 # we using int64_t for long in the API, so correct it here...
index 3a58401..318d6e5 100644 (file)
@@ -130,46 +130,46 @@ Tensor rsub(const Tensor& self, const Tensor& other, Scalar alpha) {
 // types (int, float, etc.) to Tensor (only to Scalar). They're not exposed
 // to Python.
 
-static Tensor scalar_tensor(Scalar scalar) {
+static Tensor wrapped_scalar_tensor(Scalar scalar) {
   auto tensor = scalar_to_tensor(scalar);
   tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
   return tensor;
 }
 
 Tensor add(const Tensor& self, Scalar other, Scalar alpha) {
-  return native::add(self, scalar_tensor(other), alpha);
+  return native::add(self, wrapped_scalar_tensor(other), alpha);
 }
 
 Tensor& add_(Tensor& self, Scalar other, Scalar alpha) {
-  return native::add_(self, scalar_tensor(other), alpha);
+  return native::add_(self, wrapped_scalar_tensor(other), alpha);
 }
 
 Tensor div(const Tensor& self, Scalar other) {
-  return native::div(self, scalar_tensor(other));
+  return native::div(self, wrapped_scalar_tensor(other));
 }
 
 Tensor& div_(Tensor& self, Scalar other) {
-  return native::div_(self, scalar_tensor(other));
+  return native::div_(self, wrapped_scalar_tensor(other));
 }
 
 Tensor mul(const Tensor& self, Scalar other) {
-  return native::mul(self, scalar_tensor(other));
+  return native::mul(self, wrapped_scalar_tensor(other));
 }
 
 Tensor& mul_(Tensor& self, Scalar other) {
-  return native::mul_(self, scalar_tensor(other));
+  return native::mul_(self, wrapped_scalar_tensor(other));
 }
 
 Tensor sub(const Tensor& self, Scalar other, Scalar alpha) {
-  return native::sub(self, scalar_tensor(other), alpha);
+  return native::sub(self, wrapped_scalar_tensor(other), alpha);
 }
 
 Tensor& sub_(Tensor& self, Scalar other, Scalar alpha) {
-  return native::sub_(self, scalar_tensor(other), alpha);
+  return native::sub_(self, wrapped_scalar_tensor(other), alpha);
 }
 
 Tensor rsub(const Tensor& self, Scalar other, Scalar alpha) {
-  return native::rsub(self, scalar_tensor(other), alpha);
+  return native::rsub(self, wrapped_scalar_tensor(other), alpha);
 }
 
 }
index 9050ca7..96345c7 100644 (file)
@@ -194,7 +194,7 @@ static inline Tensor mean(const Tensor &self, optional<ScalarType> dtype) {
     Tensor result = at::native::sum(self);
     return result.div_(self.numel());
   } else {
-    return self.type().scalarTensor(std::numeric_limits<double>::quiet_NaN());
+    return at::scalar_tensor(std::numeric_limits<double>::quiet_NaN(), self.options());
   }
 }
 
@@ -457,7 +457,7 @@ Tensor _norm(const Tensor &self, Scalar p) {
       return at::legacy::th::_th_norm(self, p);
     } else {
       if (self.is_contiguous()) {
-        Tensor result = CPU(kFloat).scalarTensor(0).toType(self.type());
+        Tensor result = at::scalar_tensor(0, CPU(kFloat).options()).toType(self.type());
         norm_kernel(kCPU, result, self, p, c10::nullopt);
         return result;
       } else {
index af8fdcf..70f4509 100644 (file)
@@ -49,7 +49,7 @@ static c10::optional<Tensor> _allreduce_return_trivial(
     Scalar ident) {
   // Return identity
   if (self.numel() == 0) {
-    return self.type().scalarTensor(ident);
+    return at::scalar_tensor(ident, self.options());
   }
   return c10::nullopt;
 }
index 48914dc..ed4537f 100644 (file)
@@ -300,6 +300,12 @@ Tensor ones_like(const Tensor& self, const TensorOptions& options) {
   return native::ones(self.sizes(), options);
 }
 
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ scalar_tensor ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Tensor scalar_tensor(Scalar s, const TensorOptions& options) {
+  return at::empty({}, options).fill_(s);
+}
+
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ rand ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 Tensor rand(IntList size, const TensorOptions& options) {
index 66eb583..4afded7 100644 (file)
@@ -63,7 +63,7 @@ Tensor empty_cuda(IntList size, const TensorOptions& options) {
 
 Tensor& randperm_out_cuda(Tensor& result, int64_t n, Generator* generator) {
   AT_CHECK(n >= 0, "n must be non-negative, got", n);
-  AT_CHECK(result.type().scalarTensor(n).defined(),
+  AT_CHECK(at::scalar_tensor(n, result.options()).defined(),
   "n is too large for result tensor type: '", result.type().toString(), "'");
 
   result.resize_({n});
index 93328f2..b1a20b4 100644 (file)
 - func: pinverse(Tensor self, double rcond=1e-15) -> Tensor
   variants: function, method
 
+- func: scalar_tensor(Scalar s, *, TensorOptions options={}) -> Tensor
+
 - func: rand(IntList size, *, TensorOptions options={}) -> Tensor
 
 - func: rand(IntList size, *, Generator* generator, TensorOptions options={}) -> Tensor
index ad11871..d8f4d6c 100644 (file)
@@ -55,7 +55,7 @@ SparseTensor& zero_sparse_(SparseTensor& self) {
 // mul(SparseTensor, Scalar)
 // --------------------------------------------------------------------
 
-static Tensor scalar_tensor(Scalar s) {
+static Tensor wrapped_scalar_tensor(Scalar s) {
   auto tensor = scalar_to_tensor(s);
   tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
   return tensor;
@@ -82,7 +82,7 @@ SparseTensor& mul_out_sparse_zerodim(SparseTensor& r, const SparseTensor& t, con
 }
 
 SparseTensor& mul_out_sparse_scalar(SparseTensor& r, const SparseTensor& t, Scalar value) {
-  return mul_out_sparse_zerodim(r, t, scalar_tensor(value));
+  return mul_out_sparse_zerodim(r, t, wrapped_scalar_tensor(value));
 }
 
 // --------------------------------------------------------------------
@@ -167,7 +167,7 @@ SparseTensor& div_out_sparse_zerodim(SparseTensor& r, const SparseTensor& t, con
 }
 
 SparseTensor& div_out_sparse_scalar(SparseTensor& r, const SparseTensor& t, Scalar value) {
-  return div_out_sparse_zerodim(r, t, scalar_tensor(value));
+  return div_out_sparse_zerodim(r, t, wrapped_scalar_tensor(value));
 }
 
 // --------------------------------------------------------------------
@@ -546,7 +546,7 @@ Tensor& s_addmm_out_sparse_dense_cpu(
   int64_t nnz        = sparse._nnz();
 
   if (nnz == 0) {
-    at::mul_out(r, t, r.type().scalarTensor(beta));
+    at::mul_out(r, t, at::scalar_tensor(beta, r.options()));
     return r;
   }
 
index cc8aad6..c927538 100644 (file)
@@ -125,7 +125,6 @@ struct CAFFE2_API Type {
   virtual Tensor tensorFromBlob(void * data, IntList sizes, IntList strides, const std::function<void(void*)> & deleter=noop_deleter) const = 0;
   virtual Tensor tensorWithAllocator(IntList sizes, Allocator* allocator) const = 0;
   virtual Tensor tensorWithAllocator(IntList sizes, IntList strides, Allocator* allocator) const = 0;
-  virtual Tensor scalarTensor(Scalar s) const = 0;
 
   bool operator==(const Type& other) const {
     return this == &other;
index a393e66..7fa816d 100644 (file)
@@ -124,11 +124,6 @@ Storage TypeDefault::unsafeStorageFromTH(void * th_pointer, bool retain) const {
   return Storage(c10::intrusive_ptr<StorageImpl>::reclaim(static_cast<StorageImpl*>(th_pointer)));
 }
 
-
-Tensor TypeDefault::scalarTensor(Scalar s) const {
-  return at::empty({}, this->options()).fill_(s);
-}
-
 ${type_method_definitions}
 
 }
index 9929c8d..12c12f2 100644 (file)
@@ -42,7 +42,6 @@ struct CAFFE2_API TypeDefault : public TypeExtendedInterface {
   Tensor tensorFromBlob(void * data, IntList sizes, IntList strides, const std::function<void(void*)> & deleter=noop_deleter) const override;
   Tensor tensorWithAllocator(IntList sizes, Allocator* allocator) const override;
   Tensor tensorWithAllocator(IntList sizes, IntList strides, Allocator* allocator) const override;
-  Tensor scalarTensor(Scalar s) const override;
 
   Storage storage(bool resizable = false) const override;
   Storage storage(size_t size, bool resizable = false) const override;
index ed03028..7d301b4 100644 (file)
@@ -157,7 +157,7 @@ void TestCopyBroadcasting(Type& type) {
   }
 }
 void TestAbsValue(Type& type) {
-  Tensor r = at::abs(type.scalarTensor(-3));
+  Tensor r = at::abs(at::scalar_tensor(-3, type.options()));
   ASSERT_EQ_RESOLVED(r.item<int32_t>(), 3);
 }
 /*
@@ -187,7 +187,7 @@ void TestSelect(Type& type) {
 }
 
 void TestZeroDim(Type& type) {
-  Tensor a = type.scalarTensor(4); // rand(type, {1});
+  Tensor a = at::scalar_tensor(4, type.options()); // rand(type, {1});
 
   Tensor b = rand({3, 4}, type);
   ASSERT_EQ_RESOLVED((a + a).dim(), 0);
index 36e19a6..0ab7f8b 100644 (file)
@@ -115,10 +115,10 @@ static Variable valueToTensor(const at::Type & type, PyObject* value) {
     return reinterpret_cast<THPVariable*>(value)->cdata;
   }
   if (THPUtils_checkLong(value)) {
-    return type.scalarTensor(Scalar(THPUtils_unpackLong(value)));
+    return at::scalar_tensor(Scalar(THPUtils_unpackLong(value)), type.options());
   }
   if (PyFloat_Check(value)) {
-    return type.scalarTensor(Scalar(THPUtils_unpackDouble(value)));
+    return at::scalar_tensor(Scalar(THPUtils_unpackDouble(value)), type.options());
   }
   throw TypeError("can't assign a %s to a %s", Py_TYPE(value)->tp_name, type.toString());
 }